Files
apes/crates/colony/src/ws.rs
limiteinductive 17cca7b077 S5: WebSocket real-time — per-channel broadcast, auto-reconnect
Backend:
- AppState with per-channel broadcast::Sender map
- WS handler: auth via first message, keepalive pings, broadcast forwarding
- post_message broadcasts WsEvent::Message to all subscribers

Frontend:
- useChannelSocket hook: connects, auths, appends messages, auto-reconnects
- Removed 3s polling — WebSocket is primary, initial load via REST
- Deduplicate incoming WS messages (the sender also refetches after posting, so it would otherwise see its own message twice)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-29 20:29:07 +02:00

103 lines
3.4 KiB
Rust

use axum::{
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::time::Duration;
use tokio::time::interval;
use crate::state::AppState;
/// First frame a client must send after connecting, e.g.
/// `{"type":"auth","user":"alice"}`.
#[derive(Deserialize)]
struct AuthMessage {
    /// JSON key `type`; must equal `"auth"` or the connection is rejected.
    #[serde(rename = "type")]
    msg_type: String,
    /// Parsed from the auth frame but never read by this module.
    // NOTE(review): the token is not validated anywhere here, so `user` is
    // whatever the client claims — confirm this is acceptable before trusting it.
    #[allow(dead_code)] // accepted on the wire but not yet checked; see note above
    token: Option<String>,
    /// Name the client identifies as; the socket handler substitutes
    /// `"anonymous"` when it is omitted.
    user: Option<String>,
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
Path(channel_id): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, channel_id, state))
}
async fn handle_socket(socket: WebSocket, channel_id: String, state: AppState) {
let (mut sender, mut receiver) = socket.split();
// Wait for auth message (first message must be {"type":"auth", "user":"..."})
let _user = match tokio::time::timeout(Duration::from_secs(10), receiver.next()).await {
Ok(Some(Ok(msg))) => {
if let axum::extract::ws::Message::Text(text) = msg {
match serde_json::from_str::<AuthMessage>(&text) {
Ok(auth) if auth.msg_type == "auth" => {
auth.user.unwrap_or_else(|| "anonymous".to_string())
}
_ => {
let _ = sender
.send(axum::extract::ws::Message::Text(
r#"{"error":"first message must be {\"type\":\"auth\",\"user\":\"...\"}}"#.into(),
))
.await;
return;
}
}
} else {
return;
}
}
_ => return, // Timeout or disconnect
};
// Subscribe to channel broadcast
let mut rx = state.subscribe(&channel_id).await;
// Send confirmation
let _ = sender
.send(axum::extract::ws::Message::Text(
r#"{"event":"connected"}"#.into(),
))
.await;
// Ping interval for keepalive
let mut ping_interval = interval(Duration::from_secs(30));
loop {
tokio::select! {
// Broadcast message received → forward to client
msg = rx.recv() => {
match msg {
Ok(event) => {
let json = serde_json::to_string(&event).unwrap();
if sender.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
break; // Client disconnected
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
eprintln!("colony: ws client lagged by {} messages", n);
}
Err(_) => break, // Channel closed
}
}
// Client message (we don't expect any after auth, but drain to detect disconnect)
msg = receiver.next() => {
match msg {
Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
Some(Err(_)) => break,
_ => {} // Ignore other messages
}
}
// Keepalive ping
_ = ping_interval.tick() => {
if sender.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() {
break;
}
}
}
}
}