diff --git a/Cargo.lock b/Cargo.lock index c4faaff..2a55a26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,7 @@ dependencies = [ "axum", "chrono", "colony-types", + "futures-util", "serde", "serde_json", "sqlx", @@ -443,6 +444,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -463,6 +475,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 5e295e5..ef4e6cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,4 +12,5 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono", "uuid axum = { version = "0.8", features = ["ws"] } tower-http = { version = "0.6", features = ["cors", "fs"] } ts-rs = { version = "10", features = ["serde-json-impl", "uuid-impl", "chrono-impl"] } +futures-util = "0.3" thiserror = "2" diff --git a/crates/colony/Cargo.toml b/crates/colony/Cargo.toml index ea67d82..ade5973 100644 --- a/crates/colony/Cargo.toml +++ b/crates/colony/Cargo.toml @@ -13,4 +13,5 @@ tokio = { workspace = true } sqlx = { workspace = true } axum = { workspace = true } tower-http = { workspace = true } +futures-util = { workspace = true } thiserror = { workspace = true } diff --git a/crates/colony/src/main.rs b/crates/colony/src/main.rs index ad34f84..b9c692d 100644 --- a/crates/colony/src/main.rs +++ b/crates/colony/src/main.rs @@ -1,8 +1,11 @@ mod db; mod routes; +mod state; +mod ws; use axum::{routing::get, Router}; use sqlx::sqlite::SqlitePoolOptions; +use state::AppState; use std::env; use tower_http::services::{ServeDir, ServeFile}; @@ -19,18 +22,18 @@ async fn main() { eprintln!("colony: connected to {}", db_url); - // Enable WAL mode sqlx::query("PRAGMA journal_mode=WAL") .execute(&pool) .await .unwrap(); - // Run embedded migrations sqlx::migrate!("./migrations") .run(&pool) .await .expect("Failed to run migrations"); + let state = AppState::new(pool); + eprintln!("colony: migrations done, starting on port {}", port); let app = Router::new() @@ -46,11 +49,11 @@ async fn main() { "/api/channels/{channel_id}/messages", get(routes::list_messages).post(routes::post_message), ) - // Serve frontend static files, fallback to index.html for SPA routing + .route("/ws/{channel_id}", get(ws::ws_handler)) .fallback_service( ServeDir::new("static").fallback(ServeFile::new("static/index.html")), ) - .with_state(pool); + .with_state(state); let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)) .await diff --git a/crates/colony/src/routes.rs b/crates/colony/src/routes.rs index a0c6ffb..b034d85 100644 --- a/crates/colony/src/routes.rs +++ b/crates/colony/src/routes.rs @@ -9,6 +9,7 @@ use sqlx::SqlitePool; use uuid::Uuid; use crate::db::*; +use crate::state::AppState; // ── Error handling ── @@ -75,30 +76,30 @@ pub async fn health() -> &'static str { // ── Channels ── -pub async fn list_channels(State(db): State) -> Result>> { +pub async fn list_channels(State(state): State) -> Result>> { let rows = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels ORDER BY created_at") - .fetch_all(&db) + .fetch_all(&state.db) .await?; let channels: Vec = rows.iter().map(|r| r.to_api()).collect(); Ok(Json(channels)) } -pub async fn list_users(State(db): State) -> Result>> { +pub async fn list_users(State(state): State) -> Result>> { let rows = sqlx::query_as::<_, UserRow>("SELECT * FROM users ORDER BY created_at") - .fetch_all(&db) + .fetch_all(&state.db) .await?; Ok(Json(rows.iter().map(|r| r.to_api()).collect())) } pub async fn get_me( - State(db): State, + State(state): State, Query(user_param): Query, ) -> Result> { let username = user_param.user.as_deref().unwrap_or("benji"); let row = sqlx::query_as::<_, UserRow>("SELECT * FROM users WHERE username = ?") .bind(username) - .fetch_optional(&db) + .fetch_optional(&state.db) .await?; match row { Some(r) => Ok(Json(r.to_api())), @@ -107,36 +108,36 @@ pub async fn get_me( } pub async fn create_channel( - State(db): State, + State(state): State, Query(user_param): Query, Json(body): Json, ) -> Result { let id = Uuid::new_v4().to_string(); - let created_by = resolve_user(&db, &user_param).await?; + let created_by = resolve_user(&state.db, &user_param).await?; sqlx::query("INSERT INTO channels (id, name, description, created_by) VALUES (?, ?, ?, ?)") .bind(&id) .bind(&body.name) .bind(&body.description) .bind(created_by) - .execute(&db) + .execute(&state.db) .await?; let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?") .bind(&id) - .fetch_one(&db) + .fetch_one(&state.db) .await?; Ok((StatusCode::CREATED, Json(row.to_api()))) } pub async fn get_channel( - State(db): State, + State(state): State, Path(id): Path, ) -> Result> { let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?") .bind(&id) - .fetch_optional(&db) + .fetch_optional(&state.db) .await?; match row { @@ -148,7 +149,7 @@ pub async fn get_channel( // ── Messages ── pub async fn list_messages( - State(db): State, + State(state): State, Path(channel_id): Path, Query(query): Query, ) -> Result>> { @@ -186,13 +187,13 @@ pub async fn list_messages( q = q.bind(b); } - let rows = q.fetch_all(&db).await?; + let rows = q.fetch_all(&state.db).await?; let messages: Vec = rows.iter().map(|r| r.to_api_message()).collect(); Ok(Json(messages)) } pub async fn post_message( - State(db): State, + State(state): State, Path(channel_id): Path, Query(user_param): Query, Json(body): Json, @@ -202,7 +203,7 @@ pub async fn post_message( "SELECT COUNT(*) FROM channels WHERE id = ?", ) .bind(&channel_id) - .fetch_one(&db) + .fetch_one(&state.db) .await?; if channel_exists == 0 { @@ -215,7 +216,7 @@ pub async fn post_message( "SELECT channel_id FROM messages WHERE id = ?", ) .bind(reply_id.to_string()) - .fetch_optional(&db) + .fetch_optional(&state.db) .await?; match reply_channel { @@ -228,7 +229,7 @@ pub async fn post_message( } let id = Uuid::new_v4().to_string(); - let user_id = resolve_user(&db, &user_param).await?; + let user_id = resolve_user(&state.db, &user_param).await?; let msg_type = match body.r#type { MessageType::Text => "text", @@ -256,7 +257,7 @@ pub async fn post_message( .bind(&body.content) .bind(&metadata_json) .bind(&reply_to) - .execute(&db) + .execute(&state.db) .await?; // Fetch the full message with user @@ -265,10 +266,16 @@ pub async fn post_message( FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = ?", ) .bind(&id) - .fetch_one(&db) + .fetch_one(&state.db) .await?; - Ok((StatusCode::CREATED, Json(row.to_api_message()))) + let message = row.to_api_message(); + + // Broadcast to WebSocket subscribers + let tx = state.get_sender(&channel_id).await; + let _ = tx.send(WsEvent::Message(message.clone())); + + Ok((StatusCode::CREATED, Json(message))) } // ── Joined row type for message + user ── diff --git a/crates/colony/src/state.rs b/crates/colony/src/state.rs new file mode 100644 index 0000000..0dab77d --- /dev/null +++ b/crates/colony/src/state.rs @@ -0,0 +1,43 @@ +use colony_types::WsEvent; +use sqlx::SqlitePool; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{broadcast, RwLock}; + +const BROADCAST_CAPACITY: usize = 256; + +#[derive(Clone)] +pub struct AppState { + pub db: SqlitePool, + channels: Arc>>>, +} + +impl AppState { + pub fn new(db: SqlitePool) -> Self { + Self { + db, + channels: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn get_sender(&self, channel_id: &str) -> broadcast::Sender { + let read = self.channels.read().await; + if let Some(tx) = read.get(channel_id) { + return tx.clone(); + } + drop(read); + + let mut write = self.channels.write().await; + // Double-check after acquiring write lock + if let Some(tx) = write.get(channel_id) { + return tx.clone(); + } + let (tx, _) = broadcast::channel(BROADCAST_CAPACITY); + write.insert(channel_id.to_string(), tx.clone()); + tx + } + + pub async fn subscribe(&self, channel_id: &str) -> broadcast::Receiver { + self.get_sender(channel_id).await.subscribe() + } +} diff --git a/crates/colony/src/ws.rs b/crates/colony/src/ws.rs new file mode 100644 index 0000000..8ccac3e --- /dev/null +++ b/crates/colony/src/ws.rs @@ -0,0 +1,102 @@ +use axum::{ + extract::{ws::WebSocket, Path, State, WebSocketUpgrade}, + response::IntoResponse, +}; +use futures_util::{SinkExt, StreamExt}; +use serde::Deserialize; +use std::time::Duration; +use tokio::time::interval; + +use crate::state::AppState; + +#[derive(Deserialize)] +struct AuthMessage { + #[serde(rename = "type")] + msg_type: String, + #[allow(dead_code)] + token: Option, + user: Option, +} + +pub async fn ws_handler( + ws: WebSocketUpgrade, + Path(channel_id): Path, + State(state): State, +) -> impl IntoResponse { + ws.on_upgrade(move |socket| handle_socket(socket, channel_id, state)) +} + +async fn handle_socket(socket: WebSocket, channel_id: String, state: AppState) { + let (mut sender, mut receiver) = socket.split(); + + // Wait for auth message (first message must be {"type":"auth", "user":"..."}) + let _user = match tokio::time::timeout(Duration::from_secs(10), receiver.next()).await { + Ok(Some(Ok(msg))) => { + if let axum::extract::ws::Message::Text(text) = msg { + match serde_json::from_str::(&text) { + Ok(auth) if auth.msg_type == "auth" => { + auth.user.unwrap_or_else(|| "anonymous".to_string()) + } + _ => { + let _ = sender + .send(axum::extract::ws::Message::Text( + r#"{"error":"first message must be {\"type\":\"auth\",\"user\":\"...\"}}"#.into(), + )) + .await; + return; + } + } + } else { + return; + } + } + _ => return, // Timeout or disconnect + }; + + // Subscribe to channel broadcast + let mut rx = state.subscribe(&channel_id).await; + + // Send confirmation + let _ = sender + .send(axum::extract::ws::Message::Text( + r#"{"event":"connected"}"#.into(), + )) + .await; + + // Ping interval for keepalive + let mut ping_interval = interval(Duration::from_secs(30)); + + loop { + tokio::select! { + // Broadcast message received → forward to client + msg = rx.recv() => { + match msg { + Ok(event) => { + let json = serde_json::to_string(&event).unwrap(); + if sender.send(axum::extract::ws::Message::Text(json.into())).await.is_err() { + break; // Client disconnected + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + eprintln!("colony: ws client lagged by {} messages", n); + } + Err(_) => break, // Channel closed + } + } + // Client message (we don't expect any after auth, but drain to detect disconnect) + msg = receiver.next() => { + match msg { + Some(Ok(axum::extract::ws::Message::Close(_))) | None => break, + Some(Err(_)) => break, + _ => {} // Ignore other messages + } + } + // Keepalive ping + _ = ping_interval.tick() => { + if sender.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() { + break; + } + } + } + } +} diff --git a/ui/colony/src/App.tsx b/ui/colony/src/App.tsx index 0b79506..79a3302 100644 --- a/ui/colony/src/App.tsx +++ b/ui/colony/src/App.tsx @@ -6,6 +6,7 @@ import { ChannelSidebar } from "@/components/ChannelSidebar"; import { MessageItem } from "@/components/MessageItem"; import { ComposeBox } from "@/components/ComposeBox"; import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet"; +import { useChannelSocket } from "@/hooks/useChannelSocket"; export default function App() { const [channels, setChannels] = useState([]); @@ -34,10 +35,21 @@ export default function App() { setMessages(msgs); } } catch { - // Silently ignore fetch errors during polling + // Silently ignore fetch errors } }, []); + // WebSocket: append new messages in real-time + const handleWsMessage = useCallback((msg: Message) => { + setMessages((prev) => { + // Deduplicate — the sender also fetches after posting + if (prev.some((m) => m.id === msg.id)) return prev; + return [...prev, msg]; + }); + }, []); + + useChannelSocket(activeChannelId, handleWsMessage); + useEffect(() => { loadChannels(); }, [loadChannels]); useEffect(() => { @@ -47,6 +59,7 @@ export default function App() { loadMessages(); }, [activeChannelId, loadMessages]); + // Auto-scroll only on new messages useEffect(() => { if (messages.length > prevMsgCountRef.current && scrollRef.current) { scrollRef.current.scrollTop = scrollRef.current.scrollHeight; @@ -54,11 +67,6 @@ export default function App() { prevMsgCountRef.current = messages.length; }, [messages]); - useEffect(() => { - const interval = setInterval(loadMessages, 3000); - return () => clearInterval(interval); - }, [loadMessages]); - const messagesById = new Map(messages.map((m) => [m.id, m])); const activeChannel = channels.find((c) => c.id === activeChannelId); @@ -76,15 +84,12 @@ export default function App() { return (
- {/* Desktop sidebar */}
{sidebar}
- {/* Channel header */}
- {/* Mobile: Sheet trigger */} = @@ -112,7 +117,6 @@ export default function App() { )}
- {/* Messages */}
{messages.length === 0 && activeChannelId && (
diff --git a/ui/colony/src/hooks/useChannelSocket.ts b/ui/colony/src/hooks/useChannelSocket.ts new file mode 100644 index 0000000..6846f02 --- /dev/null +++ b/ui/colony/src/hooks/useChannelSocket.ts @@ -0,0 +1,61 @@ +import { useEffect, useRef, useCallback } from "react"; +import type { Message } from "@/types/Message"; +import { getCurrentUsername } from "@/api"; + +interface WsMessageEvent { + event: "message"; + data: Message; +} + +interface WsConnectedEvent { + event: "connected"; +} + +type WsEvent = WsMessageEvent | WsConnectedEvent; + +export function useChannelSocket( + channelId: string | null, + onMessage: (msg: Message) => void, +) { + const wsRef = useRef(null); + const reconnectTimer = useRef>(); + + const connect = useCallback(() => { + if (!channelId) return; + + const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; + const host = window.location.host; + const ws = new WebSocket(`${protocol}//${host}/ws/${channelId}`); + + ws.onopen = () => { + // Send auth message + ws.send(JSON.stringify({ + type: "auth", + user: getCurrentUsername(), + })); + }; + + ws.onmessage = (e) => { + const event: WsEvent = JSON.parse(e.data); + if (event.event === "message") { + onMessage(event.data); + } + }; + + ws.onclose = () => { + // Reconnect after 3s + reconnectTimer.current = setTimeout(connect, 3000); + }; + + wsRef.current = ws; + }, [channelId, onMessage]); + + useEffect(() => { + connect(); + return () => { + clearTimeout(reconnectTimer.current); + wsRef.current?.close(); + wsRef.current = null; + }; + }, [connect]); +}