Files
apes/crates/colony/src/routes.rs
limiteinductive 17cca7b077 S5: WebSocket real-time — per-channel broadcast, auto-reconnect
Backend:
- AppState with per-channel broadcast::Sender map
- WS handler: auth via first message, keepalive pings, broadcast forwarding
- post_message broadcasts WsEvent::Message to all subscribers

Frontend:
- useChannelSocket hook: connects, auths, appends messages, auto-reconnects
- Removed 3s polling — WebSocket is primary, initial load via REST
- Deduplication on WS messages (sender also fetches after post)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-29 20:29:07 +02:00

329 lines
9.7 KiB
Rust

use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use colony_types::*;
use sqlx::SqlitePool;
use uuid::Uuid;
use crate::db::*;
use crate::state::AppState;
// ── Error handling ──
/// Errors returned by the HTTP handlers in this module.
///
/// Each variant carries the human-readable message that is sent back to the
/// client in the `{"error": ...}` JSON body.
#[derive(Debug)]
pub enum AppError {
    /// Rendered as `404 Not Found`.
    NotFound(String),
    /// Rendered as `409 Conflict` (e.g. a UNIQUE constraint violation).
    Conflict(String),
    /// Rendered as `400 Bad Request`.
    BadRequest(String),
    /// Rendered as `500 Internal Server Error`.
    Internal(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, message) = match self {
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m),
AppError::Conflict(m) => (StatusCode::CONFLICT, m),
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m),
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m),
};
(status, Json(serde_json::json!({"error": message}))).into_response()
}
}
impl From<sqlx::Error> for AppError {
fn from(e: sqlx::Error) -> Self {
match &e {
sqlx::Error::Database(db_err) if db_err.message().contains("UNIQUE") => {
AppError::Conflict(format!("Already exists: {}", db_err.message()))
}
sqlx::Error::RowNotFound => AppError::NotFound("Not found".into()),
_ => AppError::Internal(format!("Database error: {e}")),
}
}
}
/// Handler result type: any `Err` is rendered to the client through
/// `AppError`'s `IntoResponse` implementation.
type Result<T> = core::result::Result<T, AppError>;
// ── User identity from ?user= param ──
/// Query-string parameter selecting the acting user: `?user=<username>`.
///
/// The handlers in this module fall back to the username `benji` when it is absent.
#[derive(serde::Deserialize, Debug)]
pub struct UserParam {
    /// Username to act as; `None` when the query string has no `user` key.
    pub user: Option<String>,
}
/// Resolves the `?user=` parameter to that user's id, defaulting to the username
/// `benji` when the parameter is missing.
///
/// NOTE(review): the identity comes straight from an unauthenticated query
/// parameter, so any caller can act as any user.
///
/// # Errors
/// [`AppError::BadRequest`] if no user has that username. Database failures are
/// converted through `From<sqlx::Error>`.
async fn resolve_user(db: &SqlitePool, param: &UserParam) -> Result<String> {
    let username = param.user.as_deref().unwrap_or("benji");
    sqlx::query_scalar::<_, String>("SELECT id FROM users WHERE username = ?")
        .bind(username)
        .fetch_optional(db)
        .await?
        .ok_or_else(|| AppError::BadRequest(format!("Unknown user: {username}")))
}
// ── Health ──
/// Liveness probe. Always answers with the plain-text body `ok`.
pub async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
// ── Channels & users ──
/// Lists every channel, ordered by `created_at` ascending.
pub async fn list_channels(State(state): State<AppState>) -> Result<Json<Vec<Channel>>> {
    let channel_rows =
        sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels ORDER BY created_at")
            .fetch_all(&state.db)
            .await?;
    let mut channels = Vec::with_capacity(channel_rows.len());
    for row in &channel_rows {
        channels.push(row.to_api());
    }
    Ok(Json(channels))
}
/// Lists every user, ordered by `created_at` ascending, converted to the API `User` type.
pub async fn list_users(State(state): State<AppState>) -> Result<Json<Vec<User>>> {
    let users = sqlx::query_as::<_, UserRow>("SELECT * FROM users ORDER BY created_at")
        .fetch_all(&state.db)
        .await?
        .iter()
        .map(|user_row| user_row.to_api())
        .collect::<Vec<User>>();
    Ok(Json(users))
}
/// Returns the user named by `?user=`, defaulting to the username `benji`.
///
/// # Errors
/// [`AppError::NotFound`] if no user has that username.
pub async fn get_me(
    State(state): State<AppState>,
    Query(user_param): Query<UserParam>,
) -> Result<Json<User>> {
    let username = user_param.user.as_deref().unwrap_or("benji");
    let me = sqlx::query_as::<_, UserRow>("SELECT * FROM users WHERE username = ?")
        .bind(username)
        .fetch_optional(&state.db)
        .await?
        .ok_or_else(|| AppError::NotFound(format!("User {username} not found")))?;
    Ok(Json(me.to_api()))
}
/// Creates a channel owned by the `?user=` caller and responds `201 Created`
/// with the stored channel.
///
/// # Errors
/// [`AppError::BadRequest`] for an unknown user. A UNIQUE violation on insert
/// surfaces as [`AppError::Conflict`] via `From<sqlx::Error>`.
pub async fn create_channel(
    State(state): State<AppState>,
    Query(user_param): Query<UserParam>,
    Json(body): Json<CreateChannel>,
) -> Result<impl IntoResponse> {
    let creator_id = resolve_user(&state.db, &user_param).await?;
    let channel_id = Uuid::new_v4().to_string();

    sqlx::query("INSERT INTO channels (id, name, description, created_by) VALUES (?, ?, ?, ?)")
        .bind(&channel_id)
        .bind(&body.name)
        .bind(&body.description)
        .bind(creator_id)
        .execute(&state.db)
        .await?;

    // Read the row back so columns the INSERT did not supply (presumably
    // schema defaults such as created_at) are included in the response.
    let created = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
        .bind(&channel_id)
        .fetch_one(&state.db)
        .await?;
    Ok((StatusCode::CREATED, Json(created.to_api())))
}
/// Fetches a single channel by id.
///
/// # Errors
/// [`AppError::NotFound`] if no channel has that id.
pub async fn get_channel(
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> Result<Json<Channel>> {
    let found = sqlx::query_as::<_, ChannelRow>("SELECT * FROM channels WHERE id = ?")
        .bind(&id)
        .fetch_optional(&state.db)
        .await?;
    found
        .map(|channel_row| Json(channel_row.to_api()))
        .ok_or_else(|| AppError::NotFound(format!("Channel {id} not found")))
}
// ── Messages ──
/// Lists the messages of a channel, each joined with its author, ordered by `seq`.
///
/// Optional query filters: `after_seq` (only messages with a larger `seq`),
/// `type` (a single message type), and `user_id` (a single author).
pub async fn list_messages(
    State(state): State<AppState>,
    Path(channel_id): Path<String>,
    Query(query): Query<MessageQuery>,
) -> Result<Json<Vec<Message>>> {
    // Each filter is a WHERE clause together with the value bound to its `?`.
    // The pairs stay in a single Vec so clause order and bind order always agree.
    let mut filters: Vec<(&str, String)> = vec![("m.channel_id = ?", channel_id)];
    if let Some(after_seq) = &query.after_seq {
        filters.push(("m.seq > ?", after_seq.to_string()));
    }
    if let Some(msg_type) = &query.r#type {
        let type_name = match msg_type {
            MessageType::Text => "text",
            MessageType::Code => "code",
            MessageType::Result => "result",
            MessageType::Error => "error",
            MessageType::Plan => "plan",
        };
        filters.push(("m.type = ?", type_name.to_string()));
    }
    if let Some(user_id) = &query.user_id {
        filters.push(("m.user_id = ?", user_id.to_string()));
    }

    let conditions = filters
        .iter()
        .map(|(clause, _)| *clause)
        .collect::<Vec<&str>>()
        .join(" AND ");
    let sql = format!(
        "SELECT m.*, u.id as u_id, u.username, u.display_name, u.role, u.created_at as u_created_at \
         FROM messages m JOIN users u ON m.user_id = u.id \
         WHERE {conditions} ORDER BY m.seq ASC"
    );

    let mut prepared = sqlx::query_as::<_, MessageWithUserRow>(&sql);
    for (_, value) in &filters {
        prepared = prepared.bind(value);
    }
    let joined_rows = prepared.fetch_all(&state.db).await?;
    Ok(Json(
        joined_rows.iter().map(|row| row.to_api_message()).collect(),
    ))
}
/// Posts a new message to `channel_id` as the `?user=` caller, then broadcasts
/// it to the channel's WebSocket subscribers.
///
/// On success it responds `201 Created` with the stored message, re-read from
/// the database together with its author.
///
/// # Errors
/// - [`AppError::NotFound`] if the channel does not exist.
/// - [`AppError::BadRequest`] if `reply_to` references a missing message or a
///   message in another channel, or if the `?user=` username is unknown.
/// - [`AppError::Internal`] if `metadata` cannot be serialized to JSON, or for
///   database failures (see the `From<sqlx::Error>` conversion).
pub async fn post_message(
    State(state): State<AppState>,
    Path(channel_id): Path<String>,
    Query(user_param): Query<UserParam>,
    Json(body): Json<PostMessage>,
) -> Result<impl IntoResponse> {
    // Verify channel exists
    let channel_exists =
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM channels WHERE id = ?")
            .bind(&channel_id)
            .fetch_one(&state.db)
            .await?;
    if channel_exists == 0 {
        return Err(AppError::NotFound(format!("Channel {channel_id} not found")));
    }

    // A reply must point at an existing message in the *same* channel.
    if let Some(ref reply_id) = body.reply_to {
        let reply_channel =
            sqlx::query_scalar::<_, String>("SELECT channel_id FROM messages WHERE id = ?")
                .bind(reply_id.to_string())
                .fetch_optional(&state.db)
                .await?;
        match reply_channel {
            None => {
                return Err(AppError::BadRequest(format!(
                    "reply_to message {reply_id} not found"
                )));
            }
            Some(ch) if ch != channel_id => {
                return Err(AppError::BadRequest(
                    "reply_to must reference a message in the same channel".into(),
                ));
            }
            Some(_) => {}
        }
    }

    let id = Uuid::new_v4().to_string();
    let user_id = resolve_user(&state.db, &user_param).await?;
    let msg_type = match body.r#type {
        MessageType::Text => "text",
        MessageType::Code => "code",
        MessageType::Result => "result",
        MessageType::Error => "error",
        MessageType::Plan => "plan",
    };
    // `metadata` was just deserialized from the request body, so serializing it
    // back is expected to succeed. If it ever fails, answer with a 500 instead of
    // panicking inside the handler (the previous `unwrap()`).
    let metadata_json = body
        .metadata
        .as_ref()
        .map(serde_json::to_string)
        .transpose()
        .map_err(|e| AppError::Internal(format!("Failed to serialize metadata: {e}")))?;
    let reply_to = body.reply_to.map(|r| r.to_string());

    // seq is AUTOINCREMENT — no race conditions, no manual tracking
    sqlx::query(
        "INSERT INTO messages (id, channel_id, user_id, type, content, metadata, reply_to) \
         VALUES (?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(&id)
    .bind(&channel_id)
    .bind(user_id)
    .bind(msg_type)
    .bind(&body.content)
    .bind(&metadata_json)
    .bind(&reply_to)
    .execute(&state.db)
    .await?;

    // Re-read the inserted row joined with its author, so the response and the
    // broadcast both carry the full message.
    let row = sqlx::query_as::<_, MessageWithUserRow>(
        "SELECT m.*, u.id as u_id, u.username, u.display_name, u.role, u.created_at as u_created_at \
         FROM messages m JOIN users u ON m.user_id = u.id WHERE m.id = ?",
    )
    .bind(&id)
    .fetch_one(&state.db)
    .await?;
    let message = row.to_api_message();

    // Broadcast to WebSocket subscribers. The send result is deliberately ignored.
    // Presumably this is a tokio broadcast sender, whose `send` only fails when no
    // receiver is currently subscribed.
    let tx = state.get_sender(&channel_id).await;
    let _ = tx.send(WsEvent::Message(message.clone()));
    Ok((StatusCode::CREATED, Json(message)))
}
// ── Joined row type for message + user ──
/// One row of the `messages m JOIN users u ON m.user_id = u.id` queries used by
/// the message handlers.
///
/// It holds every `messages` column (selected via `m.*`) followed by the
/// author's columns. `u.id` and `u.created_at` are aliased to `u_id` and
/// `u_created_at` so they do not clash with the message's own `id` and
/// `created_at`. Field names must keep matching those SQL column names/aliases
/// for `FromRow` to work.
#[derive(Debug, sqlx::FromRow)]
pub struct MessageWithUserRow {
    /// Message id (a UUID string when created by `post_message`).
    pub id: String,
    /// Sequence number: messages are ordered by it and filtered with `after_seq`.
    pub seq: i64,
    /// Id of the channel the message belongs to.
    pub channel_id: String,
    /// Author's user id (equal to `u_id` because of the join condition).
    pub user_id: String,
    /// Message type as stored: `text`, `code`, `result`, `error` or `plan`.
    pub r#type: String,
    /// Message body.
    pub content: String,
    /// JSON-serialized metadata, if any.
    pub metadata: Option<String>,
    /// Id of the message this one replies to, if any.
    pub reply_to: Option<String>,
    /// Message timestamp column as stored (text).
    pub created_at: String,
    /// Message timestamp column as stored, if set.
    pub updated_at: Option<String>,
    /// Message timestamp column as stored, if set (presumably a soft-delete marker — TODO confirm).
    pub deleted_at: Option<String>,
    /// Author's `users.id`, aliased from `u.id`.
    pub u_id: String,
    /// Author's username.
    pub username: String,
    /// Author's display name.
    pub display_name: String,
    /// Author's role as stored.
    pub role: String,
    /// Author's `users.created_at`, aliased from `u.created_at`.
    pub u_created_at: String,
}
impl MessageWithUserRow {
    /// Converts the flattened join row into the API `Message`.
    ///
    /// The row is split back into its `UserRow` (author) and `MessageRow` halves,
    /// and the conversion is delegated to `MessageRow::to_api`.
    pub fn to_api_message(&self) -> Message {
        let Self {
            id,
            seq,
            channel_id,
            user_id,
            r#type,
            content,
            metadata,
            reply_to,
            created_at,
            updated_at,
            deleted_at,
            u_id,
            username,
            display_name,
            role,
            u_created_at,
        } = self;

        let author = UserRow {
            id: u_id.clone(),
            username: username.clone(),
            display_name: display_name.clone(),
            role: role.clone(),
            // The join query never selects the password hash, so it is always empty here.
            password_hash: None,
            created_at: u_created_at.clone(),
        };

        MessageRow {
            id: id.clone(),
            seq: *seq,
            channel_id: channel_id.clone(),
            user_id: user_id.clone(),
            r#type: r#type.clone(),
            content: content.clone(),
            metadata: metadata.clone(),
            reply_to: reply_to.clone(),
            created_at: created_at.clone(),
            updated_at: updated_at.clone(),
            deleted_at: deleted_at.clone(),
        }
        .to_api(&author)
    }
}