use std::sync::Arc; use axum::{ extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}}, response::Response, routing::get, Router, }; use futures::{SinkExt, StreamExt}; use serde::Deserialize; use crate::worker::{WorkerInfo, WorkerManager, WorkerResult}; pub fn router(mgr: Arc) -> Router { Router::new() .route("/", get(ws_handler)) .with_state(mgr) } async fn ws_handler( ws: WebSocketUpgrade, State(mgr): State>, ) -> Response { ws.on_upgrade(move |socket| handle_worker_socket(socket, mgr)) } #[derive(Deserialize)] #[serde(tag = "type")] enum WorkerMessage { #[serde(rename = "register")] Register { info: WorkerInfo }, #[serde(rename = "result")] Result(WorkerResult), } async fn handle_worker_socket(socket: WebSocket, mgr: Arc) { let (mut sender, mut receiver) = socket.split(); // First message must be registration let (name, mut job_rx) = loop { match receiver.next().await { Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(WorkerMessage::Register { info }) => { let name = info.name.clone(); let rx = mgr.register(name.clone(), info).await; // Ack let ack = serde_json::json!({ "type": "registered", "name": &name }); let _ = sender.send(Message::Text(ack.to_string().into())).await; break (name, rx); } _ => { let _ = sender.send(Message::Text( r#"{"type":"error","message":"First message must be register"}"#.into(), )).await; return; } } } Some(Ok(Message::Close(_))) | None => return, _ => continue, } }; // Main loop: forward jobs to worker, receive results let name_clone = name.clone(); let mgr_clone = mgr.clone(); // Task: send jobs from job_rx to the WebSocket let send_task = tokio::spawn(async move { while let Some(req) = job_rx.recv().await { let msg = serde_json::json!({ "type": "execute", "job_id": req.job_id, "script": req.script, }); if sender.send(Message::Text(msg.to_string().into())).await.is_err() { break; } } }); // Task: receive results from the WebSocket let recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { match msg { Message::Text(text) => { if let Ok(WorkerMessage::Result(result)) = serde_json::from_str(&text) { mgr_clone.report_result(result).await; } } Message::Close(_) => break, _ => {} } } }); tokio::select! { _ = send_task => {}, _ = recv_task => {}, } mgr.unregister(&name_clone).await; }