use std::sync::Arc; use axum::{ extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}}, response::Response, routing::get, Router, }; use futures::{SinkExt, StreamExt}; use sqlx::sqlite::SqlitePool; use tokio::sync::broadcast; use crate::agent::WsMessage; use crate::worker::{WorkerInfo, WorkerManager, WorkerToServer}; pub struct WsWorkerState { pub mgr: Arc, pub pool: SqlitePool, pub broadcast_fn: Arc broadcast::Sender + Send + Sync>, } pub fn router(mgr: Arc, pool: SqlitePool, broadcast_fn: Arc broadcast::Sender + Send + Sync>) -> Router { let state = Arc::new(WsWorkerState { mgr, pool, broadcast_fn }); Router::new() .route("/", get(ws_handler)) .with_state(state) } async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, ) -> Response { ws.on_upgrade(move |socket| handle_worker_socket(socket, state)) } async fn handle_worker_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); // First message must be registration let (name, mut msg_rx) = loop { match receiver.next().await { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(v) if v["type"] == "register" => { let info: WorkerInfo = match serde_json::from_value(v["info"].clone()) { Ok(i) => i, Err(_) => { let _ = sender.send(Message::Text( r#"{"type":"error","message":"Invalid worker info"}"#.into(), )).await; return; } }; let name = info.name.clone(); let rx = state.mgr.register(name.clone(), info).await; let ack = serde_json::json!({ "type": "registered", "name": &name }); let _ = sender.send(Message::Text(ack.to_string().into())).await; break (name, rx); } _ => { let _ = sender.send(Message::Text( r#"{"type":"error","message":"First message must be register"}"#.into(), )).await; return; } } } Some(Ok(Message::Close(_))) | None => return, _ => continue, } }; let name_clone = name.clone(); let mgr_for_cleanup = state.mgr.clone(); // Task: send ServerToWorker messages from msg_rx to the WebSocket let send_task = tokio::spawn(async move { while let Some(msg) = msg_rx.recv().await { if let Ok(json) = serde_json::to_string(&msg) { if sender.send(Message::Text(json.into())).await.is_err() { break; } } } }); // Task: receive WorkerToServer messages from the WebSocket → process let state_clone = state.clone(); let recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { match msg { Message::Text(text) => { if let Ok(worker_msg) = serde_json::from_str::(&text) { handle_worker_message(&state_clone, worker_msg).await; } } Message::Close(_) => break, _ => {} } } }); tokio::select! { _ = send_task => {}, _ = recv_task => {}, } mgr_for_cleanup.unregister(&name_clone).await; } async fn handle_worker_message(state: &WsWorkerState, msg: WorkerToServer) { match msg { WorkerToServer::Register { .. } => { // Already handled during initial handshake } WorkerToServer::Update { workflow_id, update } => { // Get project_id for broadcasting (look up from DB) let project_id: Option = sqlx::query_scalar( "SELECT project_id FROM workflows WHERE id = ?" ) .bind(&workflow_id) .fetch_optional(&state.pool) .await .ok() .flatten(); let broadcast_tx = if let Some(ref pid) = project_id { Some((state.broadcast_fn)(pid)) } else { None }; // Check if this is a workflow completion if let crate::sink::AgentUpdate::WorkflowComplete { ref workflow_id, .. } = update { state.mgr.complete_workflow(workflow_id).await; } // Process the update: write to DB + broadcast crate::sink::handle_single_update(&update, &state.pool, broadcast_tx.as_ref()).await; } } }