S5: WebSocket real-time — per-channel broadcast, auto-reconnect

Backend:
- AppState with per-channel broadcast::Sender map
- WS handler: auth via first message, keepalive pings, broadcast forwarding
- post_message broadcasts WsEvent::Message to all subscribers

Frontend:
- useChannelSocket hook: connects, auths, appends messages, auto-reconnects
- Removed 3s polling — WebSocket is primary, initial load via REST
- Deduplication on WS messages (sender also fetches after post)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-29 20:29:07 +02:00
parent 9303641daf
commit 17cca7b077
9 changed files with 270 additions and 35 deletions

View File

@@ -1,8 +1,11 @@
mod db;
mod routes;
mod state;
mod ws;
use axum::{routing::get, Router};
use sqlx::sqlite::SqlitePoolOptions;
use state::AppState;
use std::env;
use tower_http::services::{ServeDir, ServeFile};
@@ -19,18 +22,18 @@ async fn main() {
eprintln!("colony: connected to {}", db_url);
// Enable WAL mode
sqlx::query("PRAGMA journal_mode=WAL")
.execute(&pool)
.await
.unwrap();
// Run embedded migrations
sqlx::migrate!("./migrations")
.run(&pool)
.await
.expect("Failed to run migrations");
let state = AppState::new(pool);
eprintln!("colony: migrations done, starting on port {}", port);
let app = Router::new()
@@ -46,11 +49,11 @@ async fn main() {
"/api/channels/{channel_id}/messages",
get(routes::list_messages).post(routes::post_message),
)
// Serve frontend static files, fallback to index.html for SPA routing
.route("/ws/{channel_id}", get(ws::ws_handler))
.fallback_service(
ServeDir::new("static").fallback(ServeFile::new("static/index.html")),
)
.with_state(pool);
.with_state(state);
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
.await

View File

@@ -9,6 +9,7 @@ use sqlx::SqlitePool;
use uuid::Uuid;
use crate::db::*;
use crate::state::AppState;
// ── Error handling ──
@@ -75,30 +76,30 @@ pub async fn health() -> &'static str {
// ── Channels ──
pub async fn list_channels(State(db): State<SqlitePool>) -> Result<Json<Vec<Channel>>> {
pub async fn list_channels(State(state): State<AppState>) -> Result<Json<Vec<Channel>>> {
let rows = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels ORDER BY created_at")
.fetch_all(&db)
.fetch_all(&state.db)
.await?;
let channels: Vec<Channel> = rows.iter().map(|r| r.to_api()).collect();
Ok(Json(channels))
}
pub async fn list_users(State(db): State<SqlitePool>) -> Result<Json<Vec<User>>> {
pub async fn list_users(State(state): State<AppState>) -> Result<Json<Vec<User>>> {
let rows = sqlx::query_as::<_, UserRow>("SELECT * FROM users ORDER BY created_at")
.fetch_all(&db)
.fetch_all(&state.db)
.await?;
Ok(Json(rows.iter().map(|r| r.to_api()).collect()))
}
pub async fn get_me(
State(db): State<SqlitePool>,
State(state): State<AppState>,
Query(user_param): Query<UserParam>,
) -> Result<Json<User>> {
let username = user_param.user.as_deref().unwrap_or("benji");
let row = sqlx::query_as::<_, UserRow>("SELECT * FROM users WHERE username = ?")
.bind(username)
.fetch_optional(&db)
.fetch_optional(&state.db)
.await?;
match row {
Some(r) => Ok(Json(r.to_api())),
@@ -107,36 +108,36 @@ pub async fn get_me(
}
pub async fn create_channel(
State(db): State<SqlitePool>,
State(state): State<AppState>,
Query(user_param): Query<UserParam>,
Json(body): Json<CreateChannel>,
) -> Result<impl IntoResponse> {
let id = Uuid::new_v4().to_string();
let created_by = resolve_user(&db, &user_param).await?;
let created_by = resolve_user(&state.db, &user_param).await?;
sqlx::query("INSERT INTO channels (id, name, description, created_by) VALUES (?, ?, ?, ?)")
.bind(&id)
.bind(&body.name)
.bind(&body.description)
.bind(created_by)
.execute(&db)
.execute(&state.db)
.await?;
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
.bind(&id)
.fetch_one(&db)
.fetch_one(&state.db)
.await?;
Ok((StatusCode::CREATED, Json(row.to_api())))
}
pub async fn get_channel(
State(db): State<SqlitePool>,
State(state): State<AppState>,
Path(id): Path<String>,
) -> Result<Json<Channel>> {
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
.bind(&id)
.fetch_optional(&db)
.fetch_optional(&state.db)
.await?;
match row {
@@ -148,7 +149,7 @@ pub async fn get_channel(
// ── Messages ──
pub async fn list_messages(
State(db): State<SqlitePool>,
State(state): State<AppState>,
Path(channel_id): Path<String>,
Query(query): Query<MessageQuery>,
) -> Result<Json<Vec<Message>>> {
@@ -186,13 +187,13 @@ pub async fn list_messages(
q = q.bind(b);
}
let rows = q.fetch_all(&db).await?;
let rows = q.fetch_all(&state.db).await?;
let messages: Vec<Message> = rows.iter().map(|r| r.to_api_message()).collect();
Ok(Json(messages))
}
pub async fn post_message(
State(db): State<SqlitePool>,
State(state): State<AppState>,
Path(channel_id): Path<String>,
Query(user_param): Query<UserParam>,
Json(body): Json<PostMessage>,
@@ -202,7 +203,7 @@ pub async fn post_message(
"SELECT COUNT(*) FROM channels WHERE id = ?",
)
.bind(&channel_id)
.fetch_one(&db)
.fetch_one(&state.db)
.await?;
if channel_exists == 0 {
@@ -215,7 +216,7 @@ pub async fn post_message(
"SELECT channel_id FROM messages WHERE id = ?",
)
.bind(reply_id.to_string())
.fetch_optional(&db)
.fetch_optional(&state.db)
.await?;
match reply_channel {
@@ -228,7 +229,7 @@ pub async fn post_message(
}
let id = Uuid::new_v4().to_string();
let user_id = resolve_user(&db, &user_param).await?;
let user_id = resolve_user(&state.db, &user_param).await?;
let msg_type = match body.r#type {
MessageType::Text => "text",
@@ -256,7 +257,7 @@ pub async fn post_message(
.bind(&body.content)
.bind(&metadata_json)
.bind(&reply_to)
.execute(&db)
.execute(&state.db)
.await?;
// Fetch the full message with user
@@ -265,10 +266,16 @@ pub async fn post_message(
FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = ?",
)
.bind(&id)
.fetch_one(&db)
.fetch_one(&state.db)
.await?;
Ok((StatusCode::CREATED, Json(row.to_api_message())))
let message = row.to_api_message();
// Broadcast to WebSocket subscribers
let tx = state.get_sender(&channel_id).await;
let _ = tx.send(WsEvent::Message(message.clone()));
Ok((StatusCode::CREATED, Json(message)))
}
// ── Joined row type for message + user ──

View File

@@ -0,0 +1,43 @@
use colony_types::WsEvent;
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
const BROADCAST_CAPACITY: usize = 256;
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
channels: Arc<RwLock<HashMap<String, broadcast::Sender<WsEvent>>>>,
}
impl AppState {
pub fn new(db: SqlitePool) -> Self {
Self {
db,
channels: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn get_sender(&self, channel_id: &str) -> broadcast::Sender<WsEvent> {
let read = self.channels.read().await;
if let Some(tx) = read.get(channel_id) {
return tx.clone();
}
drop(read);
let mut write = self.channels.write().await;
// Double-check after acquiring write lock
if let Some(tx) = write.get(channel_id) {
return tx.clone();
}
let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
write.insert(channel_id.to_string(), tx.clone());
tx
}
pub async fn subscribe(&self, channel_id: &str) -> broadcast::Receiver<WsEvent> {
self.get_sender(channel_id).await.subscribe()
}
}

102
crates/colony/src/ws.rs Normal file
View File

@@ -0,0 +1,102 @@
use axum::{
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use std::time::Duration;
use tokio::time::interval;
use crate::state::AppState;
#[derive(Deserialize)]
struct AuthMessage {
#[serde(rename = "type")]
msg_type: String,
#[allow(dead_code)]
token: Option<String>,
user: Option<String>,
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
Path(channel_id): Path<String>,
State(state): State<AppState>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, channel_id, state))
}
async fn handle_socket(socket: WebSocket, channel_id: String, state: AppState) {
let (mut sender, mut receiver) = socket.split();
// Wait for auth message (first message must be {"type":"auth", "user":"..."})
let _user = match tokio::time::timeout(Duration::from_secs(10), receiver.next()).await {
Ok(Some(Ok(msg))) => {
if let axum::extract::ws::Message::Text(text) = msg {
match serde_json::from_str::<AuthMessage>(&text) {
Ok(auth) if auth.msg_type == "auth" => {
auth.user.unwrap_or_else(|| "anonymous".to_string())
}
_ => {
let _ = sender
.send(axum::extract::ws::Message::Text(
r#"{"error":"first message must be {\"type\":\"auth\",\"user\":\"...\"}}"#.into(),
))
.await;
return;
}
}
} else {
return;
}
}
_ => return, // Timeout or disconnect
};
// Subscribe to channel broadcast
let mut rx = state.subscribe(&channel_id).await;
// Send confirmation
let _ = sender
.send(axum::extract::ws::Message::Text(
r#"{"event":"connected"}"#.into(),
))
.await;
// Ping interval for keepalive
let mut ping_interval = interval(Duration::from_secs(30));
loop {
tokio::select! {
// Broadcast message received → forward to client
msg = rx.recv() => {
match msg {
Ok(event) => {
let json = serde_json::to_string(&event).unwrap();
if sender.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
break; // Client disconnected
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
eprintln!("colony: ws client lagged by {} messages", n);
}
Err(_) => break, // Channel closed
}
}
// Client message (we don't expect any after auth, but drain to detect disconnect)
msg = receiver.next() => {
match msg {
Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
Some(Err(_)) => break,
_ => {} // Ignore other messages
}
}
// Keepalive ping
_ = ping_interval.tick() => {
if sender.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() {
break;
}
}
}
}
}