S5: WebSocket real-time — per-channel broadcast, auto-reconnect
Backend: - AppState with per-channel broadcast::Sender map - WS handler: auth via first message, keepalive pings, broadcast forwarding - post_message broadcasts WsEvent::Message to all subscribers Frontend: - useChannelSocket hook: connects, auths, appends messages, auto-reconnects - Removed 3s polling — WebSocket is primary, initial load via REST - Deduplication on WS messages (sender also fetches after post) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
13
Cargo.lock
generated
13
Cargo.lock
generated
@@ -184,6 +184,7 @@ dependencies = [
|
|||||||
"axum",
|
"axum",
|
||||||
"chrono",
|
"chrono",
|
||||||
"colony-types",
|
"colony-types",
|
||||||
|
"futures-util",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
@@ -443,6 +444,17 @@ version = "0.3.32"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
@@ -463,6 +475,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
|||||||
@@ -12,4 +12,5 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono", "uuid
|
|||||||
axum = { version = "0.8", features = ["ws"] }
|
axum = { version = "0.8", features = ["ws"] }
|
||||||
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
||||||
ts-rs = { version = "10", features = ["serde-json-impl", "uuid-impl", "chrono-impl"] }
|
ts-rs = { version = "10", features = ["serde-json-impl", "uuid-impl", "chrono-impl"] }
|
||||||
|
futures-util = "0.3"
|
||||||
thiserror = "2"
|
thiserror = "2"
|
||||||
|
|||||||
@@ -13,4 +13,5 @@ tokio = { workspace = true }
|
|||||||
sqlx = { workspace = true }
|
sqlx = { workspace = true }
|
||||||
axum = { workspace = true }
|
axum = { workspace = true }
|
||||||
tower-http = { workspace = true }
|
tower-http = { workspace = true }
|
||||||
|
futures-util = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
mod db;
|
mod db;
|
||||||
mod routes;
|
mod routes;
|
||||||
|
mod state;
|
||||||
|
mod ws;
|
||||||
|
|
||||||
use axum::{routing::get, Router};
|
use axum::{routing::get, Router};
|
||||||
use sqlx::sqlite::SqlitePoolOptions;
|
use sqlx::sqlite::SqlitePoolOptions;
|
||||||
|
use state::AppState;
|
||||||
use std::env;
|
use std::env;
|
||||||
use tower_http::services::{ServeDir, ServeFile};
|
use tower_http::services::{ServeDir, ServeFile};
|
||||||
|
|
||||||
@@ -19,18 +22,18 @@ async fn main() {
|
|||||||
|
|
||||||
eprintln!("colony: connected to {}", db_url);
|
eprintln!("colony: connected to {}", db_url);
|
||||||
|
|
||||||
// Enable WAL mode
|
|
||||||
sqlx::query("PRAGMA journal_mode=WAL")
|
sqlx::query("PRAGMA journal_mode=WAL")
|
||||||
.execute(&pool)
|
.execute(&pool)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Run embedded migrations
|
|
||||||
sqlx::migrate!("./migrations")
|
sqlx::migrate!("./migrations")
|
||||||
.run(&pool)
|
.run(&pool)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to run migrations");
|
.expect("Failed to run migrations");
|
||||||
|
|
||||||
|
let state = AppState::new(pool);
|
||||||
|
|
||||||
eprintln!("colony: migrations done, starting on port {}", port);
|
eprintln!("colony: migrations done, starting on port {}", port);
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
@@ -46,11 +49,11 @@ async fn main() {
|
|||||||
"/api/channels/{channel_id}/messages",
|
"/api/channels/{channel_id}/messages",
|
||||||
get(routes::list_messages).post(routes::post_message),
|
get(routes::list_messages).post(routes::post_message),
|
||||||
)
|
)
|
||||||
// Serve frontend static files, fallback to index.html for SPA routing
|
.route("/ws/{channel_id}", get(ws::ws_handler))
|
||||||
.fallback_service(
|
.fallback_service(
|
||||||
ServeDir::new("static").fallback(ServeFile::new("static/index.html")),
|
ServeDir::new("static").fallback(ServeFile::new("static/index.html")),
|
||||||
)
|
)
|
||||||
.with_state(pool);
|
.with_state(state);
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
|
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ use sqlx::SqlitePool;
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::db::*;
|
use crate::db::*;
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
// ── Error handling ──
|
// ── Error handling ──
|
||||||
|
|
||||||
@@ -75,30 +76,30 @@ pub async fn health() -> &'static str {
|
|||||||
|
|
||||||
// ── Channels ──
|
// ── Channels ──
|
||||||
|
|
||||||
pub async fn list_channels(State(db): State<SqlitePool>) -> Result<Json<Vec<Channel>>> {
|
pub async fn list_channels(State(state): State<AppState>) -> Result<Json<Vec<Channel>>> {
|
||||||
let rows = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels ORDER BY created_at")
|
let rows = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels ORDER BY created_at")
|
||||||
.fetch_all(&db)
|
.fetch_all(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let channels: Vec<Channel> = rows.iter().map(|r| r.to_api()).collect();
|
let channels: Vec<Channel> = rows.iter().map(|r| r.to_api()).collect();
|
||||||
Ok(Json(channels))
|
Ok(Json(channels))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_users(State(db): State<SqlitePool>) -> Result<Json<Vec<User>>> {
|
pub async fn list_users(State(state): State<AppState>) -> Result<Json<Vec<User>>> {
|
||||||
let rows = sqlx::query_as::<_, UserRow>("SELECT * FROM users ORDER BY created_at")
|
let rows = sqlx::query_as::<_, UserRow>("SELECT * FROM users ORDER BY created_at")
|
||||||
.fetch_all(&db)
|
.fetch_all(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(Json(rows.iter().map(|r| r.to_api()).collect()))
|
Ok(Json(rows.iter().map(|r| r.to_api()).collect()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_me(
|
pub async fn get_me(
|
||||||
State(db): State<SqlitePool>,
|
State(state): State<AppState>,
|
||||||
Query(user_param): Query<UserParam>,
|
Query(user_param): Query<UserParam>,
|
||||||
) -> Result<Json<User>> {
|
) -> Result<Json<User>> {
|
||||||
let username = user_param.user.as_deref().unwrap_or("benji");
|
let username = user_param.user.as_deref().unwrap_or("benji");
|
||||||
let row = sqlx::query_as::<_, UserRow>("SELECT * FROM users WHERE username = ?")
|
let row = sqlx::query_as::<_, UserRow>("SELECT * FROM users WHERE username = ?")
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.fetch_optional(&db)
|
.fetch_optional(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
match row {
|
match row {
|
||||||
Some(r) => Ok(Json(r.to_api())),
|
Some(r) => Ok(Json(r.to_api())),
|
||||||
@@ -107,36 +108,36 @@ pub async fn get_me(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_channel(
|
pub async fn create_channel(
|
||||||
State(db): State<SqlitePool>,
|
State(state): State<AppState>,
|
||||||
Query(user_param): Query<UserParam>,
|
Query(user_param): Query<UserParam>,
|
||||||
Json(body): Json<CreateChannel>,
|
Json(body): Json<CreateChannel>,
|
||||||
) -> Result<impl IntoResponse> {
|
) -> Result<impl IntoResponse> {
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
let created_by = resolve_user(&db, &user_param).await?;
|
let created_by = resolve_user(&state.db, &user_param).await?;
|
||||||
|
|
||||||
sqlx::query("INSERT INTO channels (id, name, description, created_by) VALUES (?, ?, ?, ?)")
|
sqlx::query("INSERT INTO channels (id, name, description, created_by) VALUES (?, ?, ?, ?)")
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.bind(&body.name)
|
.bind(&body.name)
|
||||||
.bind(&body.description)
|
.bind(&body.description)
|
||||||
.bind(created_by)
|
.bind(created_by)
|
||||||
.execute(&db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
|
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_one(&db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(row.to_api())))
|
Ok((StatusCode::CREATED, Json(row.to_api())))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_channel(
|
pub async fn get_channel(
|
||||||
State(db): State<SqlitePool>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
) -> Result<Json<Channel>> {
|
) -> Result<Json<Channel>> {
|
||||||
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
|
let row = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_optional(&db)
|
.fetch_optional(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
match row {
|
match row {
|
||||||
@@ -148,7 +149,7 @@ pub async fn get_channel(
|
|||||||
// ── Messages ──
|
// ── Messages ──
|
||||||
|
|
||||||
pub async fn list_messages(
|
pub async fn list_messages(
|
||||||
State(db): State<SqlitePool>,
|
State(state): State<AppState>,
|
||||||
Path(channel_id): Path<String>,
|
Path(channel_id): Path<String>,
|
||||||
Query(query): Query<MessageQuery>,
|
Query(query): Query<MessageQuery>,
|
||||||
) -> Result<Json<Vec<Message>>> {
|
) -> Result<Json<Vec<Message>>> {
|
||||||
@@ -186,13 +187,13 @@ pub async fn list_messages(
|
|||||||
q = q.bind(b);
|
q = q.bind(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
let rows = q.fetch_all(&db).await?;
|
let rows = q.fetch_all(&state.db).await?;
|
||||||
let messages: Vec<Message> = rows.iter().map(|r| r.to_api_message()).collect();
|
let messages: Vec<Message> = rows.iter().map(|r| r.to_api_message()).collect();
|
||||||
Ok(Json(messages))
|
Ok(Json(messages))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post_message(
|
pub async fn post_message(
|
||||||
State(db): State<SqlitePool>,
|
State(state): State<AppState>,
|
||||||
Path(channel_id): Path<String>,
|
Path(channel_id): Path<String>,
|
||||||
Query(user_param): Query<UserParam>,
|
Query(user_param): Query<UserParam>,
|
||||||
Json(body): Json<PostMessage>,
|
Json(body): Json<PostMessage>,
|
||||||
@@ -202,7 +203,7 @@ pub async fn post_message(
|
|||||||
"SELECT COUNT(*) FROM channels WHERE id = ?",
|
"SELECT COUNT(*) FROM channels WHERE id = ?",
|
||||||
)
|
)
|
||||||
.bind(&channel_id)
|
.bind(&channel_id)
|
||||||
.fetch_one(&db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if channel_exists == 0 {
|
if channel_exists == 0 {
|
||||||
@@ -215,7 +216,7 @@ pub async fn post_message(
|
|||||||
"SELECT channel_id FROM messages WHERE id = ?",
|
"SELECT channel_id FROM messages WHERE id = ?",
|
||||||
)
|
)
|
||||||
.bind(reply_id.to_string())
|
.bind(reply_id.to_string())
|
||||||
.fetch_optional(&db)
|
.fetch_optional(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
match reply_channel {
|
match reply_channel {
|
||||||
@@ -228,7 +229,7 @@ pub async fn post_message(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
let user_id = resolve_user(&db, &user_param).await?;
|
let user_id = resolve_user(&state.db, &user_param).await?;
|
||||||
|
|
||||||
let msg_type = match body.r#type {
|
let msg_type = match body.r#type {
|
||||||
MessageType::Text => "text",
|
MessageType::Text => "text",
|
||||||
@@ -256,7 +257,7 @@ pub async fn post_message(
|
|||||||
.bind(&body.content)
|
.bind(&body.content)
|
||||||
.bind(&metadata_json)
|
.bind(&metadata_json)
|
||||||
.bind(&reply_to)
|
.bind(&reply_to)
|
||||||
.execute(&db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Fetch the full message with user
|
// Fetch the full message with user
|
||||||
@@ -265,10 +266,16 @@ pub async fn post_message(
|
|||||||
FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = ?",
|
FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = ?",
|
||||||
)
|
)
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_one(&db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(row.to_api_message())))
|
let message = row.to_api_message();
|
||||||
|
|
||||||
|
// Broadcast to WebSocket subscribers
|
||||||
|
let tx = state.get_sender(&channel_id).await;
|
||||||
|
let _ = tx.send(WsEvent::Message(message.clone()));
|
||||||
|
|
||||||
|
Ok((StatusCode::CREATED, Json(message)))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Joined row type for message + user ──
|
// ── Joined row type for message + user ──
|
||||||
|
|||||||
43
crates/colony/src/state.rs
Normal file
43
crates/colony/src/state.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use colony_types::WsEvent;
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::{broadcast, RwLock};
|
||||||
|
|
||||||
|
const BROADCAST_CAPACITY: usize = 256;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub db: SqlitePool,
|
||||||
|
channels: Arc<RwLock<HashMap<String, broadcast::Sender<WsEvent>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub fn new(db: SqlitePool) -> Self {
|
||||||
|
Self {
|
||||||
|
db,
|
||||||
|
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_sender(&self, channel_id: &str) -> broadcast::Sender<WsEvent> {
|
||||||
|
let read = self.channels.read().await;
|
||||||
|
if let Some(tx) = read.get(channel_id) {
|
||||||
|
return tx.clone();
|
||||||
|
}
|
||||||
|
drop(read);
|
||||||
|
|
||||||
|
let mut write = self.channels.write().await;
|
||||||
|
// Double-check after acquiring write lock
|
||||||
|
if let Some(tx) = write.get(channel_id) {
|
||||||
|
return tx.clone();
|
||||||
|
}
|
||||||
|
let (tx, _) = broadcast::channel(BROADCAST_CAPACITY);
|
||||||
|
write.insert(channel_id.to_string(), tx.clone());
|
||||||
|
tx
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn subscribe(&self, channel_id: &str) -> broadcast::Receiver<WsEvent> {
|
||||||
|
self.get_sender(channel_id).await.subscribe()
|
||||||
|
}
|
||||||
|
}
|
||||||
102
crates/colony/src/ws.rs
Normal file
102
crates/colony/src/ws.rs
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::interval;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct AuthMessage {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
msg_type: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
token: Option<String>,
|
||||||
|
user: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ws_handler(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
Path(channel_id): Path<String>,
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
ws.on_upgrade(move |socket| handle_socket(socket, channel_id, state))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socket(socket: WebSocket, channel_id: String, state: AppState) {
|
||||||
|
let (mut sender, mut receiver) = socket.split();
|
||||||
|
|
||||||
|
// Wait for auth message (first message must be {"type":"auth", "user":"..."})
|
||||||
|
let _user = match tokio::time::timeout(Duration::from_secs(10), receiver.next()).await {
|
||||||
|
Ok(Some(Ok(msg))) => {
|
||||||
|
if let axum::extract::ws::Message::Text(text) = msg {
|
||||||
|
match serde_json::from_str::<AuthMessage>(&text) {
|
||||||
|
Ok(auth) if auth.msg_type == "auth" => {
|
||||||
|
auth.user.unwrap_or_else(|| "anonymous".to_string())
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let _ = sender
|
||||||
|
.send(axum::extract::ws::Message::Text(
|
||||||
|
r#"{"error":"first message must be {\"type\":\"auth\",\"user\":\"...\"}}"#.into(),
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return, // Timeout or disconnect
|
||||||
|
};
|
||||||
|
|
||||||
|
// Subscribe to channel broadcast
|
||||||
|
let mut rx = state.subscribe(&channel_id).await;
|
||||||
|
|
||||||
|
// Send confirmation
|
||||||
|
let _ = sender
|
||||||
|
.send(axum::extract::ws::Message::Text(
|
||||||
|
r#"{"event":"connected"}"#.into(),
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Ping interval for keepalive
|
||||||
|
let mut ping_interval = interval(Duration::from_secs(30));
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
// Broadcast message received → forward to client
|
||||||
|
msg = rx.recv() => {
|
||||||
|
match msg {
|
||||||
|
Ok(event) => {
|
||||||
|
let json = serde_json::to_string(&event).unwrap();
|
||||||
|
if sender.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
|
||||||
|
break; // Client disconnected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||||
|
eprintln!("colony: ws client lagged by {} messages", n);
|
||||||
|
}
|
||||||
|
Err(_) => break, // Channel closed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Client message (we don't expect any after auth, but drain to detect disconnect)
|
||||||
|
msg = receiver.next() => {
|
||||||
|
match msg {
|
||||||
|
Some(Ok(axum::extract::ws::Message::Close(_))) | None => break,
|
||||||
|
Some(Err(_)) => break,
|
||||||
|
_ => {} // Ignore other messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Keepalive ping
|
||||||
|
_ = ping_interval.tick() => {
|
||||||
|
if sender.send(axum::extract::ws::Message::Ping(vec![].into())).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import { ChannelSidebar } from "@/components/ChannelSidebar";
|
|||||||
import { MessageItem } from "@/components/MessageItem";
|
import { MessageItem } from "@/components/MessageItem";
|
||||||
import { ComposeBox } from "@/components/ComposeBox";
|
import { ComposeBox } from "@/components/ComposeBox";
|
||||||
import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet";
|
import { Sheet, SheetContent, SheetTrigger } from "@/components/ui/sheet";
|
||||||
|
import { useChannelSocket } from "@/hooks/useChannelSocket";
|
||||||
|
|
||||||
export default function App() {
|
export default function App() {
|
||||||
const [channels, setChannels] = useState<Channel[]>([]);
|
const [channels, setChannels] = useState<Channel[]>([]);
|
||||||
@@ -34,10 +35,21 @@ export default function App() {
|
|||||||
setMessages(msgs);
|
setMessages(msgs);
|
||||||
}
|
}
|
||||||
} catch {
|
} catch {
|
||||||
// Silently ignore fetch errors during polling
|
// Silently ignore fetch errors
|
||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// WebSocket: append new messages in real-time
|
||||||
|
const handleWsMessage = useCallback((msg: Message) => {
|
||||||
|
setMessages((prev) => {
|
||||||
|
// Deduplicate — the sender also fetches after posting
|
||||||
|
if (prev.some((m) => m.id === msg.id)) return prev;
|
||||||
|
return [...prev, msg];
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useChannelSocket(activeChannelId, handleWsMessage);
|
||||||
|
|
||||||
useEffect(() => { loadChannels(); }, [loadChannels]);
|
useEffect(() => { loadChannels(); }, [loadChannels]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -47,6 +59,7 @@ export default function App() {
|
|||||||
loadMessages();
|
loadMessages();
|
||||||
}, [activeChannelId, loadMessages]);
|
}, [activeChannelId, loadMessages]);
|
||||||
|
|
||||||
|
// Auto-scroll only on new messages
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (messages.length > prevMsgCountRef.current && scrollRef.current) {
|
if (messages.length > prevMsgCountRef.current && scrollRef.current) {
|
||||||
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
||||||
@@ -54,11 +67,6 @@ export default function App() {
|
|||||||
prevMsgCountRef.current = messages.length;
|
prevMsgCountRef.current = messages.length;
|
||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const interval = setInterval(loadMessages, 3000);
|
|
||||||
return () => clearInterval(interval);
|
|
||||||
}, [loadMessages]);
|
|
||||||
|
|
||||||
const messagesById = new Map(messages.map((m) => [m.id, m]));
|
const messagesById = new Map(messages.map((m) => [m.id, m]));
|
||||||
const activeChannel = channels.find((c) => c.id === activeChannelId);
|
const activeChannel = channels.find((c) => c.id === activeChannelId);
|
||||||
|
|
||||||
@@ -76,15 +84,12 @@ export default function App() {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full">
|
<div className="flex h-full">
|
||||||
{/* Desktop sidebar */}
|
|
||||||
<div className="hidden md:block">
|
<div className="hidden md:block">
|
||||||
{sidebar}
|
{sidebar}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex-1 flex flex-col min-w-0">
|
<div className="flex-1 flex flex-col min-w-0">
|
||||||
{/* Channel header */}
|
|
||||||
<div className="px-3 py-2 md:px-4 border-b border-border flex items-center gap-2">
|
<div className="px-3 py-2 md:px-4 border-b border-border flex items-center gap-2">
|
||||||
{/* Mobile: Sheet trigger */}
|
|
||||||
<Sheet open={sheetOpen} onOpenChange={setSheetOpen}>
|
<Sheet open={sheetOpen} onOpenChange={setSheetOpen}>
|
||||||
<SheetTrigger className="md:hidden p-1 h-8 w-8 text-muted-foreground hover:text-foreground rounded-sm">
|
<SheetTrigger className="md:hidden p-1 h-8 w-8 text-muted-foreground hover:text-foreground rounded-sm">
|
||||||
=
|
=
|
||||||
@@ -112,7 +117,6 @@ export default function App() {
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Messages */}
|
|
||||||
<div ref={scrollRef} className="flex-1 overflow-y-auto">
|
<div ref={scrollRef} className="flex-1 overflow-y-auto">
|
||||||
{messages.length === 0 && activeChannelId && (
|
{messages.length === 0 && activeChannelId && (
|
||||||
<div className="flex items-center justify-center h-full text-muted-foreground text-xs">
|
<div className="flex items-center justify-center h-full text-muted-foreground text-xs">
|
||||||
|
|||||||
61
ui/colony/src/hooks/useChannelSocket.ts
Normal file
61
ui/colony/src/hooks/useChannelSocket.ts
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import { useEffect, useRef, useCallback } from "react";
|
||||||
|
import type { Message } from "@/types/Message";
|
||||||
|
import { getCurrentUsername } from "@/api";
|
||||||
|
|
||||||
|
interface WsMessageEvent {
|
||||||
|
event: "message";
|
||||||
|
data: Message;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface WsConnectedEvent {
|
||||||
|
event: "connected";
|
||||||
|
}
|
||||||
|
|
||||||
|
type WsEvent = WsMessageEvent | WsConnectedEvent;
|
||||||
|
|
||||||
|
export function useChannelSocket(
|
||||||
|
channelId: string | null,
|
||||||
|
onMessage: (msg: Message) => void,
|
||||||
|
) {
|
||||||
|
const wsRef = useRef<WebSocket | null>(null);
|
||||||
|
const reconnectTimer = useRef<ReturnType<typeof setTimeout>>();
|
||||||
|
|
||||||
|
const connect = useCallback(() => {
|
||||||
|
if (!channelId) return;
|
||||||
|
|
||||||
|
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||||
|
const host = window.location.host;
|
||||||
|
const ws = new WebSocket(`${protocol}//${host}/ws/${channelId}`);
|
||||||
|
|
||||||
|
ws.onopen = () => {
|
||||||
|
// Send auth message
|
||||||
|
ws.send(JSON.stringify({
|
||||||
|
type: "auth",
|
||||||
|
user: getCurrentUsername(),
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = (e) => {
|
||||||
|
const event: WsEvent = JSON.parse(e.data);
|
||||||
|
if (event.event === "message") {
|
||||||
|
onMessage(event.data);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = () => {
|
||||||
|
// Reconnect after 3s
|
||||||
|
reconnectTimer.current = setTimeout(connect, 3000);
|
||||||
|
};
|
||||||
|
|
||||||
|
wsRef.current = ws;
|
||||||
|
}, [channelId, onMessage]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
connect();
|
||||||
|
return () => {
|
||||||
|
clearTimeout(reconnectTimer.current);
|
||||||
|
wsRef.current?.close();
|
||||||
|
wsRef.current = null;
|
||||||
|
};
|
||||||
|
}, [connect]);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user