From 2066a727b0a641971484434de125179ff197b0ce Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Mon, 6 Apr 2026 21:08:53 +0100 Subject: [PATCH] fix: shared ws_tx across reconnects, workflow runs in spawned task --- src/worker_runner.rs | 225 +++++++++++++++++++------------------------ 1 file changed, 99 insertions(+), 126 deletions(-) diff --git a/src/worker_runner.rs b/src/worker_runner.rs index 5290561..7516ae6 100644 --- a/src/worker_runner.rs +++ b/src/worker_runner.rs @@ -10,72 +10,62 @@ use crate::sink::{AgentUpdate, ServiceManager}; use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer}; fn collect_worker_info(name: &str) -> WorkerInfo { - let cpu = std::fs::read_to_string("/proc/cpuinfo") - .ok() - .and_then(|s| { - s.lines() - .find(|l| l.starts_with("model name")) - .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) - }) + let cpu = std::fs::read_to_string("/proc/cpuinfo").ok() + .and_then(|s| s.lines().find(|l| l.starts_with("model name")) + .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())) .unwrap_or_else(|| "unknown".into()); - - let memory = std::fs::read_to_string("/proc/meminfo") - .ok() - .and_then(|s| { - s.lines() - .find(|l| l.starts_with("MemTotal")) - .and_then(|l| l.split_whitespace().nth(1)) - .and_then(|kb| kb.parse::().ok()) - .map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0)) - }) + let memory = std::fs::read_to_string("/proc/meminfo").ok() + .and_then(|s| s.lines().find(|l| l.starts_with("MemTotal")) + .and_then(|l| l.split_whitespace().nth(1)) + .and_then(|kb| kb.parse::().ok()) + .map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0))) .unwrap_or_else(|| "unknown".into()); - let gpu = std::process::Command::new("nvidia-smi") - .arg("--query-gpu=name") - .arg("--format=csv,noheader") - .output() - .ok() - .and_then(|o| { - if o.status.success() { - Some(String::from_utf8_lossy(&o.stdout).trim().to_string()) - } else { - None - } - }) + .arg("--query-gpu=name").arg("--format=csv,noheader").output().ok() + .and_then(|o| if o.status.success() { Some(String::from_utf8_lossy(&o.stdout).trim().to_string()) } else { None }) .unwrap_or_else(|| "none".into()); - WorkerInfo { - name: name.to_string(), - cpu, - memory, - gpu, + name: name.to_string(), cpu, memory, gpu, os: std::env::consts::OS.to_string(), - kernel: std::process::Command::new("uname") - .arg("-r") - .output() - .ok() + kernel: std::process::Command::new("uname").arg("-r").output().ok() .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) .unwrap_or_else(|| "unknown".into()), } } +/// Shared WebSocket sender that can be swapped on reconnect. +type SharedWsTx = Arc>, + Message +>>>>; + pub async fn run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> { tracing::info!("Tori worker '{}' connecting to {} (model={})", worker_name, server_url, llm_config.model); + let svc_mgr = ServiceManager::new(9100); + let ws_tx: SharedWsTx = Arc::new(tokio::sync::Mutex::new(None)); + let comment_tx: Arc>>> = + Arc::new(tokio::sync::Mutex::new(None)); + loop { - match connect_and_run(server_url, worker_name, llm_config).await { - Ok(()) => { - tracing::info!("Worker connection closed, reconnecting in 5s..."); - } - Err(e) => { - tracing::error!("Worker error: {}, reconnecting in 5s...", e); - } + match connect_and_run(server_url, worker_name, llm_config, &svc_mgr, &ws_tx, &comment_tx).await { + Ok(()) => tracing::info!("Connection closed, reconnecting in 5s..."), + Err(e) => tracing::error!("Worker error: {}, reconnecting in 5s...", e), } + // Clear ws_tx so relay tasks know the connection is gone + *ws_tx.lock().await = None; tokio::time::sleep(std::time::Duration::from_secs(5)).await; } } -async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> { +async fn connect_and_run( + server_url: &str, + worker_name: &str, + llm_config: &crate::LlmConfig, + svc_mgr: &ServiceManager, + shared_ws_tx: &SharedWsTx, + comment_tx: &Arc>>>, +) -> anyhow::Result<()> { let (ws_stream, _) = connect_async(server_url).await?; let (mut ws_tx, mut ws_rx) = ws_stream.split(); @@ -99,26 +89,31 @@ async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate } } - let svc_mgr = ServiceManager::new(9100); - let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx)); + // Store the new ws_tx so relay tasks can use it + *shared_ws_tx.lock().await = Some(ws_tx); - // Ping task to keep connection alive - let ping_tx = ws_tx.clone(); - tokio::spawn(async move { + // Ping keepalive + let ping_tx = shared_ws_tx.clone(); + let ping_task = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(30)); loop { interval.tick().await; - let mut tx = ping_tx.lock().await; - if tx.send(Message::Ping(vec![].into())).await.is_err() { + let guard = ping_tx.lock().await; + if guard.is_none() { break; } + // Can't send while holding mutex with Option, drop and re-acquire + drop(guard); + let mut guard = ping_tx.lock().await; + if let Some(ref mut tx) = *guard { + if tx.send(Message::Ping(vec![].into())).await.is_err() { + *guard = None; + break; + } + } else { break; } } }); - // Channel for forwarding comments to the running workflow - let comment_tx: Arc>>> = - Arc::new(tokio::sync::Mutex::new(None)); - // Main message loop while let Some(msg) = ws_rx.next().await { let text = match msg? { @@ -130,20 +125,13 @@ async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate let server_msg: ServerToWorker = match serde_json::from_str(&text) { Ok(m) => m, - Err(e) => { - tracing::warn!("Failed to parse server message: {}", e); - continue; - } + Err(e) => { tracing::warn!("Bad server message: {}", e); continue; } }; match server_msg { ServerToWorker::WorkflowAssign { - workflow_id, - project_id, - requirement, - template_id: _, - initial_state, - require_plan_approval, + workflow_id, project_id, requirement, + template_id: _, initial_state, require_plan_approval, } => { tracing::info!("Received workflow: {} (project {})", workflow_id, project_id); @@ -160,9 +148,9 @@ async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate let _ = exec.execute("uv venv .venv", &workdir).await; } - // update channel → serialize → WebSocket + // update channel → relay to shared ws_tx let (update_tx, mut update_rx) = mpsc::channel::(64); - let ws_tx_clone = ws_tx.clone(); + let relay_ws_tx = shared_ws_tx.clone(); let wf_id_clone = workflow_id.clone(); tokio::spawn(async move { while let Some(update) = update_rx.recv().await { @@ -170,12 +158,19 @@ async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate workflow_id: wf_id_clone.clone(), update, }; - if let Ok(json) = serde_json::to_string(&msg) { - let mut tx = ws_tx_clone.lock().await; + let json = match serde_json::to_string(&msg) { + Ok(j) => j, + Err(_) => continue, + }; + let mut guard = relay_ws_tx.lock().await; + if let Some(ref mut tx) = *guard { if tx.send(Message::Text(json.into())).await.is_err() { - break; + tracing::warn!("WebSocket send failed, buffering..."); + *guard = None; + // Don't break — keep draining update_rx so agent doesn't block } } + // If ws_tx is None, updates are lost (reconnect will happen) } }); @@ -183,83 +178,61 @@ async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate let (evt_tx, mut evt_rx) = mpsc::channel::(32); *comment_tx.lock().await = Some(evt_tx); - let _ = tokio::fs::create_dir_all(&workdir).await; + let svc = svc_mgr.clone(); + let wf_id = workflow_id.clone(); + let pid = project_id.clone(); + tokio::spawn(async move { + let result = agent::run_agent_loop( + &llm, &exec, &update_tx, &mut evt_rx, + &pid, &wf_id, &requirement, &workdir, &svc, + &instructions, initial_state, None, require_plan_approval, + ).await; - let result = agent::run_agent_loop( - &llm, &exec, &update_tx, &mut evt_rx, - &project_id, &workflow_id, &requirement, &workdir, &svc_mgr, - &instructions, initial_state, None, require_plan_approval, - ).await; + let final_status = if result.is_ok() { "done" } else { "failed" }; + let reason = if let Err(ref e) = result { format!("{}", e) } else { String::new() }; + if let Err(ref e) = result { + tracing::error!("Workflow {} failed: {}", wf_id, e); + let _ = update_tx.send(AgentUpdate::Error { message: format!("{}", e) }).await; + } - let final_status = if result.is_ok() { "done" } else { "failed" }; - if let Err(e) = &result { - tracing::error!("Workflow {} failed: {}", workflow_id, e); - let _ = update_tx.send(AgentUpdate::Error { - message: format!("Agent error: {}", e), + // Sync workspace files to server + sync_workspace(&update_tx, &pid, &workdir).await; + + let _ = update_tx.send(AgentUpdate::WorkflowComplete { + workflow_id: wf_id.clone(), status: final_status.into(), reason, }).await; - } - - // Sync all workspace files to server - sync_workspace(&update_tx, &project_id, &workdir).await; - - let reason = if let Err(ref e) = result { format!("{}", e) } else { String::new() }; - let _ = update_tx.send(AgentUpdate::WorkflowComplete { - workflow_id: workflow_id.clone(), - status: final_status.into(), - reason, - }).await; - - *comment_tx.lock().await = None; - tracing::info!("Workflow {} completed: {}", workflow_id, final_status); + tracing::info!("Workflow {} completed: {}", wf_id, final_status); + }); } ServerToWorker::Comment { workflow_id, content } => { if let Some(ref tx) = *comment_tx.lock().await { - let _ = tx.send(AgentEvent::Comment { - workflow_id, - content, - }).await; + let _ = tx.send(AgentEvent::Comment { workflow_id, content }).await; } } } } + ping_task.abort(); Ok(()) } /// Sync all workspace files to server via FileSync updates. -/// Skips .venv/, __pycache__/, .git/ and files > 1MB. -async fn sync_workspace( - update_tx: &mpsc::Sender, - project_id: &str, - workdir: &str, -) { +async fn sync_workspace(update_tx: &mpsc::Sender, project_id: &str, workdir: &str) { use base64::Engine; let base = std::path::Path::new(workdir); - if !base.exists() { - return; - } + if !base.exists() { return; } let mut stack = vec![base.to_path_buf()]; let mut count = 0u32; while let Some(dir) = stack.pop() { - let mut entries = match tokio::fs::read_dir(&dir).await { - Ok(e) => e, - Err(_) => continue, - }; + let mut entries = match tokio::fs::read_dir(&dir).await { Ok(e) => e, Err(_) => continue }; while let Ok(Some(entry)) = entries.next_entry().await { let name = entry.file_name().to_string_lossy().to_string(); - // Skip dirs we don't want to sync - if matches!(name.as_str(), ".venv" | "__pycache__" | ".git" | "node_modules" | ".mypy_cache") { - continue; - } + if matches!(name.as_str(), ".venv" | "__pycache__" | ".git" | "node_modules" | ".mypy_cache") { continue; } let path = entry.path(); - if path.is_dir() { - stack.push(path); - } else if let Ok(meta) = entry.metadata().await { - // Skip files > 1MB - if meta.len() > 1_048_576 { - continue; - } + if path.is_dir() { stack.push(path); } + else if let Ok(meta) = entry.metadata().await { + if meta.len() > 1_048_576 { continue; } if let Ok(bytes) = tokio::fs::read(&path).await { let rel = path.strip_prefix(base).unwrap_or(&path); let _ = update_tx.send(AgentUpdate::FileSync { @@ -272,5 +245,5 @@ async fn sync_workspace( } } } - tracing::info!("Synced {} files from workspace {}", count, workdir); + tracing::info!("Synced {} files from {}", count, workdir); }