use axum::{ extract::{ws::WebSocket, Path, State, WebSocketUpgrade}, response::IntoResponse, }; use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use std::time::Duration; use tokio::time::interval; use crate::state::AppState; #[derive(Deserialize)] struct AuthMessage { #[serde(rename = "type")] msg_type: String, #[allow(dead_code)] token: Option, user: Option, } pub async fn ws_handler( ws: WebSocketUpgrade, Path(channel_id): Path, State(state): State, ) -> impl IntoResponse { ws.on_upgrade(move |socket| handle_socket(socket, channel_id, state)) } async fn handle_socket(socket: WebSocket, channel_id: String, state: AppState) { let (mut sender, mut receiver) = socket.split(); // Wait for auth message (first message must be {"type":"auth", "user":"..."}) let _user = match tokio::time::timeout(Duration::from_secs(10), receiver.next()).await { Ok(Some(Ok(msg))) => { if let axum::extract::ws::Message::Text(text) = msg { match serde_json::from_str::(&text) { Ok(auth) if auth.msg_type == "auth" => { auth.user.unwrap_or_else(|| "anonymous".to_string()) } _ => { let _ = sender .send(axum::extract::ws::Message::Text( r#"{"error":"first message must be {\"type\":\"auth\",\"user\":\"...\"}}"#.into(), )) .await; return; } } } else { return; } } _ => return, // Timeout or disconnect }; // Subscribe to channel broadcast let mut rx = state.subscribe(&channel_id).await; // Send confirmation let _ = sender .send(axum::extract::ws::Message::Text( r#"{"event":"connected"}"#.into(), )) .await; // Ping interval for keepalive let mut ping_interval = interval(Duration::from_secs(30)); loop { tokio::select! { // Broadcast message received → forward to client msg = rx.recv() => { match msg { Ok(event) => { let json = serde_json::to_string(&event).unwrap(); if sender.send(axum::extract::ws::Message::Text(json.into())).await.is_err() { break; // Client disconnected } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { eprintln!("colony: ws client lagged by {} messages", n); } Err(_) => break, // Channel closed } } // Client message (we don't expect any after auth, but drain to detect disconnect) msg = receiver.next() => { match msg { Some(Ok(axum::extract::ws::Message::Close(_))) | None => break, Some(Err(_)) => break, _ => {} // Ignore other messages } } // Keepalive ping _ = ping_interval.tick() => { if sender.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() { break; } } } } }