diff --git a/Cargo.lock b/Cargo.lock index a6147fa..4826769 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -83,7 +133,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -216,6 +266,52 @@ dependencies = [ "windows-link", ] +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" + +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -663,6 +759,17 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "hostname" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd" +dependencies = [ + "cfg-if", + "libc", + "windows-link", +] + [[package]] name = "http" version = "1.4.0" @@ -750,7 +857,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -936,6 +1043,12 @@ dependencies = [ "serde", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itoa" version = "1.0.17" @@ -1207,6 +1320,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "parking" version = "2.2.1" @@ -1564,7 +1683,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -2063,6 +2182,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -2234,6 +2359,22 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tungstenite 0.26.2", + "webpki-roots 0.26.11", +] + [[package]] name = "tokio-tungstenite" version = "0.28.0" @@ -2243,7 +2384,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.28.0", ] [[package]] @@ -2268,7 +2409,9 @@ dependencies = [ "axum-extra", "base64", "chrono", + "clap", "futures", + "hostname", "jsonwebtoken", "mime_guess", "nix", @@ -2280,6 +2423,7 @@ dependencies = [ "sqlx", "time", "tokio", + "tokio-tungstenite 0.26.2", "tokio-util", "tower-http", "tracing", @@ -2411,6 +2555,25 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "rustls", + "rustls-pki-types", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.28.0" @@ -2515,6 +2678,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.21.0" @@ -2709,6 +2878,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.4", +] + [[package]] name = "webpki-roots" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index bab0528..47d0b5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,14 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] } tower-http = { version = "0.6", features = ["cors", "fs"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } futures = "0.3" +tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1", features = ["v4"] } anyhow = "1" +clap = { version = "4", features = ["derive", "env"] } +hostname = "0.4" mime_guess = "2" tokio-util = { version = "0.7", features = ["io"] } nix = { version = "0.29", features = ["signal"] } diff --git a/src/agent.rs b/src/agent.rs index 3e55c95..a3ef56d 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -12,6 +12,7 @@ use crate::template::{self, LoadedTemplate}; use crate::tools::ExternalToolManager; use crate::worker::WorkerManager; use crate::LlmConfig; +use crate::sink::{AgentUpdate, ServiceManager}; use crate::state::{AgentState, AgentPhase, Artifact, Step, StepStatus, StepResult, StepResultStatus, check_scratchpad_size}; @@ -186,6 +187,17 @@ async fn agent_loop( let llm = LlmClient::new(&llm_config); let exec = LocalExecutor::new(mgr.jwt_private_key_path.clone()); let workdir = format!("/app/data/workspaces/{}", project_id); + let svc_mgr = ServiceManager::new(9100); + + // Create update channel and spawn handler + let (update_tx, update_rx) = mpsc::channel::(64); + { + let handler_pool = pool.clone(); + let handler_btx = broadcast_tx.clone(); + tokio::spawn(async move { + crate::sink::handle_agent_updates(update_rx, handler_pool, handler_btx).await; + }); + } tracing::info!("Agent loop started for project {}", project_id); @@ -193,44 +205,27 @@ async fn agent_loop( match event { AgentEvent::NewRequirement { workflow_id, requirement, template_id: forced_template } => { tracing::info!("Processing new requirement for workflow {}", workflow_id); - // Generate project title in background (don't block the agent loop) + // Generate project title from requirement (heuristic, no LLM) { - let title_llm = LlmClient::new(&llm_config); - let title_pool = pool.clone(); - let title_btx = broadcast_tx.clone(); - let title_pid = project_id.clone(); - let title_req = requirement.clone(); - tokio::spawn(async move { - if let Ok(title) = generate_title(&title_llm, &title_req).await { - let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?") - .bind(&title) - .bind(&title_pid) - .execute(&title_pool) - .await; - let _ = title_btx.send(WsMessage::ProjectUpdate { - project_id: title_pid, - name: title, - }); - } + let title = generate_title_heuristic(&requirement); + let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?") + .bind(&title) + .bind(&project_id) + .execute(&pool) + .await; + let _ = broadcast_tx.send(WsMessage::ProjectUpdate { + project_id: project_id.clone(), + name: title, }); } - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.clone(), status: "executing".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") - .bind(&workflow_id) - .execute(&pool) - .await; + }).await; - // Template selection + workspace setup - let template_id = if forced_template.is_some() { - tracing::info!("Using forced template: {:?}", forced_template); - forced_template - } else { - template::select_template(&llm, &requirement, mgr.template_repo.as_ref()).await - }; + // Template: must be explicitly provided (no LLM selection) + let template_id = forced_template; // Persist template_id to workflow if let Some(ref tid) = template_id { let _ = sqlx::query("UPDATE workflows SET template_id = ? WHERE id = ?") @@ -362,9 +357,9 @@ async fn agent_loop( tracing::info!("Starting agent loop for workflow {}", workflow_id); // Run tool-calling agent loop let result = run_agent_loop( - &llm, &exec, &pool, &broadcast_tx, - &project_id, &workflow_id, &requirement, &workdir, &mgr, - &instructions, None, ext_tools, &mut rx, + &llm, &exec, &update_tx, &mut rx, + &project_id, &workflow_id, &requirement, &workdir, &svc_mgr, + &instructions, None, ext_tools, plan_approval, ).await; @@ -372,40 +367,16 @@ async fn agent_loop( tracing::info!("Agent loop finished for workflow {}, status: {}", workflow_id, final_status); if let Err(e) = &result { tracing::error!("Agent error for workflow {}: {}", workflow_id, e); - let _ = broadcast_tx.send(WsMessage::Error { + let _ = update_tx.send(AgentUpdate::Error { message: format!("Agent error: {}", e), - }); + }).await; } - let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") - .bind(final_status) - .bind(&workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowComplete { workflow_id: workflow_id.clone(), status: final_status.into(), - }); - - // Generate report from execution log - let log_entries = sqlx::query_as::<_, crate::db::ExecutionLogEntry>( - "SELECT * FROM execution_log WHERE workflow_id = ? ORDER BY created_at" - ) - .bind(&workflow_id) - .fetch_all(&pool) - .await - .unwrap_or_default(); - - if let Ok(report) = generate_report(&llm, &requirement, &log_entries, &project_id).await { - let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?") - .bind(&report) - .bind(&workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::ReportReady { - workflow_id: workflow_id.clone(), - }); - } + report: None, // Report generation will be handled separately + }).await; } AgentEvent::Comment { workflow_id, content } => { tracing::info!("Comment on workflow {}: {}", workflow_id, content); @@ -462,7 +433,7 @@ async fn agent_loop( } else { // Active workflow: LLM decides whether to revise plan state = process_feedback( - &llm, &pool, &broadcast_tx, + &llm, &update_tx, &project_id, &workflow_id, state, &content, ).await; } @@ -471,14 +442,10 @@ async fn agent_loop( if state.first_actionable_step().is_some() { ensure_workspace(&exec, &workdir).await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.clone(), status: "executing".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") - .bind(&workflow_id) - .execute(&pool) - .await; + }).await; // Prepare state for execution: set first pending step to Running if let Some(next) = state.first_actionable_step() { @@ -521,55 +488,30 @@ async fn agent_loop( let plan_approval = loaded_template.as_ref().map_or(false, |t| t.require_plan_approval); let result = run_agent_loop( - &llm, &exec, &pool, &broadcast_tx, - &project_id, &workflow_id, &wf.requirement, &workdir, &mgr, - &instructions, Some(state), ext_tools, &mut rx, + &llm, &exec, &update_tx, &mut rx, + &project_id, &workflow_id, &wf.requirement, &workdir, &svc_mgr, + &instructions, Some(state), ext_tools, plan_approval, ).await; let final_status = if result.is_ok() { "done" } else { "failed" }; if let Err(e) = &result { - let _ = broadcast_tx.send(WsMessage::Error { + let _ = update_tx.send(AgentUpdate::Error { message: format!("Agent error: {}", e), - }); + }).await; } - let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") - .bind(final_status) - .bind(&workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowComplete { workflow_id: workflow_id.clone(), status: final_status.into(), - }); - - // Regenerate report - let log_entries = sqlx::query_as::<_, crate::db::ExecutionLogEntry>( - "SELECT * FROM execution_log WHERE workflow_id = ? ORDER BY created_at" - ) - .bind(&workflow_id) - .fetch_all(&pool) - .await - .unwrap_or_default(); - - if let Ok(report) = generate_report(&llm, &wf.requirement, &log_entries, &project_id).await { - let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?") - .bind(&report) - .bind(&workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::ReportReady { - workflow_id: workflow_id.clone(), - }); - } + report: None, + }).await; } else { // No actionable steps — feedback was informational only - // Mark workflow back to done - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.clone(), status: "done".into(), - }); + }).await; } } } @@ -610,22 +552,6 @@ fn tool_list_files() -> Tool { })) } -fn tool_kb_search() -> Tool { - make_tool("kb_search", "搜索知识库中与查询相关的内容片段。返回最相关的 top-5 片段。", serde_json::json!({ - "type": "object", - "properties": { - "query": { "type": "string", "description": "搜索查询" } - }, - "required": ["query"] - })) -} - -fn tool_kb_read() -> Tool { - make_tool("kb_read", "读取知识库全文内容。", serde_json::json!({ - "type": "object", - "properties": {} - })) -} fn build_planning_tools() -> Vec { vec![ @@ -649,8 +575,6 @@ fn build_planning_tools() -> Vec { })), tool_list_files(), tool_read_file(), - tool_kb_search(), - tool_kb_read(), ] } @@ -757,21 +681,6 @@ fn build_step_tools() -> Vec { }, "required": ["summary", "artifacts"] })), - tool_kb_search(), - tool_kb_read(), - make_tool("list_workers", "列出所有已注册的远程 worker 节点及其硬件/软件信息(CPU、内存、GPU、OS、内核)。", serde_json::json!({ - "type": "object", - "properties": {} - })), - make_tool("execute_on_worker", "在指定的远程 worker 上执行脚本。脚本以 bash 执行。可以通过 HTTP 访问项目文件:GET/POST /api/obj/{project_id}/files/{path}", serde_json::json!({ - "type": "object", - "properties": { - "worker": { "type": "string", "description": "Worker 名称(从 list_workers 获取)" }, - "script": { "type": "string", "description": "要执行的 bash 脚本内容" }, - "timeout": { "type": "integer", "description": "超时秒数(默认 300)", "default": 300 } - }, - "required": ["worker", "script"] - })), ] } @@ -989,25 +898,18 @@ async fn execute_tool( // --- Tool-calling agent loop (state machine) --- -/// Save an AgentState snapshot to DB. -async fn save_state_snapshot(pool: &SqlitePool, workflow_id: &str, step_order: i32, state: &AgentState) { - let id = uuid::Uuid::new_v4().to_string(); - let json = serde_json::to_string(state).unwrap_or_default(); - let _ = sqlx::query( - "INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(&json) - .execute(pool) - .await; +/// Send a state snapshot via update channel. +async fn send_snapshot(tx: &mpsc::Sender, workflow_id: &str, step_order: i32, state: &AgentState) { + let _ = tx.send(AgentUpdate::StateSnapshot { + workflow_id: workflow_id.to_string(), + step_order, + state: state.clone(), + }).await; } -/// Log a tool call to execution_log and broadcast to frontend. -async fn log_execution( - pool: &SqlitePool, - broadcast_tx: &broadcast::Sender, +/// Send an execution log entry via update channel. +async fn send_execution( + tx: &mpsc::Sender, workflow_id: &str, step_order: i32, tool_name: &str, @@ -1015,32 +917,20 @@ async fn log_execution( output: &str, status: &str, ) { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(tool_name) - .bind(tool_input) - .bind(output) - .bind(status) - .execute(pool) - .await; - - let _ = broadcast_tx.send(WsMessage::StepStatusUpdate { - step_id: id, - status: status.into(), + let _ = tx.send(AgentUpdate::ExecutionLog { + workflow_id: workflow_id.to_string(), + step_order, + tool_name: tool_name.to_string(), + tool_input: tool_input.to_string(), output: output.to_string(), - }); + status: status.to_string(), + }).await; } -/// Log an LLM call to llm_call_log and broadcast to frontend. +/// Send an LLM call log entry via update channel. #[allow(clippy::too_many_arguments)] -async fn log_llm_call( - pool: &SqlitePool, - broadcast_tx: &broadcast::Sender, +async fn send_llm_call( + tx: &mpsc::Sender, workflow_id: &str, step_order: i32, phase: &str, @@ -1052,26 +942,7 @@ async fn log_llm_call( completion_tokens: Option, latency_ms: i64, ) { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(phase) - .bind(messages_count) - .bind(tools_count) - .bind(tool_calls_json) - .bind(text_response) - .bind(prompt_tokens.map(|v| v as i32)) - .bind(completion_tokens.map(|v| v as i32)) - .bind(latency_ms as i32) - .execute(pool) - .await; - - let entry = crate::db::LlmCallLogEntry { - id: id.clone(), + let _ = tx.send(AgentUpdate::LlmCallLog { workflow_id: workflow_id.to_string(), step_order, phase: phase.to_string(), @@ -1079,24 +950,17 @@ async fn log_llm_call( tools_count, tool_calls: tool_calls_json.to_string(), text_response: text_response.to_string(), - prompt_tokens: prompt_tokens.map(|v| v as i32), - completion_tokens: completion_tokens.map(|v| v as i32), - latency_ms: latency_ms as i32, - created_at: String::new(), - }; - - let _ = broadcast_tx.send(WsMessage::LlmCallLog { - workflow_id: workflow_id.to_string(), - entry, - }); + prompt_tokens, + completion_tokens, + latency_ms, + }).await; } /// Process user feedback: call LLM to decide whether to revise the plan. /// Returns the (possibly modified) AgentState ready for resumed execution. async fn process_feedback( llm: &LlmClient, - pool: &SqlitePool, - broadcast_tx: &broadcast::Sender, + update_tx: &mpsc::Sender, project_id: &str, workflow_id: &str, mut state: AgentState, @@ -1114,7 +978,6 @@ async fn process_feedback( Ok(r) => r, Err(e) => { tracing::error!("[workflow {}] Feedback LLM call failed: {}", workflow_id, e); - // On failure, attach feedback to first non-done step and return unchanged if let Some(step) = state.steps.iter_mut().find(|s| !matches!(s.status, StepStatus::Done)) { step.user_feedbacks.push(feedback.to_string()); } @@ -1147,60 +1010,54 @@ async fn process_feedback( } }).collect(); - // Apply docker-cache diff let diff = state.apply_plan_diff(new_steps); - // Broadcast updated plan - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); + }).await; tracing::info!("[workflow {}] Plan revised via feedback. First actionable: {:?}", workflow_id, state.first_actionable_step()); - // Log the diff so frontend can show what changed let diff_display = format!("```diff\n{}\n```", diff); - log_execution(pool, broadcast_tx, workflow_id, 0, "revise_plan", "计划变更", &diff_display, "done").await; + send_execution(update_tx, workflow_id, 0, "revise_plan", "计划变更", &diff_display, "done").await; } } } else { - // Text response only — feedback is informational, no plan change let text = choice.message.content.as_deref().unwrap_or(""); tracing::info!("[workflow {}] Feedback processed, no plan change: {}", workflow_id, truncate_str(text, 200)); - log_execution(pool, broadcast_tx, workflow_id, state.current_step(), "text_response", "", text, "done").await; + send_execution(update_tx, workflow_id, state.current_step(), "text_response", "", text, "done").await; } - // Attach feedback to the first actionable step (or last step) let target_order = state.first_actionable_step() .unwrap_or_else(|| state.steps.last().map(|s| s.order).unwrap_or(0)); if let Some(step) = state.steps.iter_mut().find(|s| s.order == target_order) { step.user_feedbacks.push(feedback.to_string()); } - // Snapshot after feedback processing - save_state_snapshot(pool, workflow_id, state.current_step(), &state).await; + send_snapshot(update_tx, workflow_id, state.current_step(), &state).await; state } /// Run an isolated sub-loop for a single step. Returns StepResult. #[allow(clippy::too_many_arguments)] -async fn run_step_loop( +pub async fn run_step_loop( llm: &LlmClient, exec: &LocalExecutor, - pool: &SqlitePool, - broadcast_tx: &broadcast::Sender, + update_tx: &mpsc::Sender, + event_rx: &mut mpsc::Receiver, project_id: &str, workflow_id: &str, workdir: &str, - mgr: &Arc, + svc_mgr: &ServiceManager, instructions: &str, step: &Step, completed_summaries: &[(i32, String, String, Vec)], parent_scratchpad: &str, external_tools: Option<&ExternalToolManager>, - event_rx: &mut mpsc::Receiver, + all_steps: &[Step], ) -> StepResult { let system_prompt = build_step_execution_prompt(project_id, instructions); let user_message = build_step_user_message(step, completed_summaries, parent_scratchpad); @@ -1236,10 +1093,10 @@ async fn run_step_loop( let phase_label = format!("step({})", step_order); tracing::info!("[workflow {}] Step {} LLM call #{} msgs={}", workflow_id, step_order, iteration + 1, messages.len()); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: format!("步骤 {} — 等待 LLM 响应...", step_order), - }); + }).await; let call_start = std::time::Instant::now(); let response = match llm.chat_with_tools(messages, &step_tools).await { Ok(r) => r, @@ -1302,24 +1159,16 @@ async fn run_step_loop( }).collect()) .unwrap_or_default(); - // Save artifacts to DB + // Save artifacts for art in &artifacts { - let art_id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)" - ) - .bind(&art_id) - .bind(workflow_id) - .bind(step_order) - .bind(&art.name) - .bind(&art.path) - .bind(&art.artifact_type) - .bind(&art.description) - .execute(pool) - .await; + let _ = update_tx.send(AgentUpdate::ArtifactSave { + workflow_id: workflow_id.to_string(), + step_order, + artifact: art.clone(), + }).await; } - log_execution(pool, broadcast_tx, workflow_id, step_order, "step_done", &summary, "步骤完成", "done").await; + send_execution(update_tx, workflow_id, step_order, "step_done", &summary, "步骤完成", "done").await; step_chat_history.push(ChatMessage::tool_result(&tc.id, "步骤已完成。")); step_done_result = Some(StepResult { status: StepResultStatus::Done, @@ -1348,26 +1197,22 @@ async fn run_step_loop( "ask_user" => { let reason = args["question"].as_str().unwrap_or("等待确认"); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: format!("步骤 {} — 等待用户确认: {}", step_order, reason), - }); + }).await; - // Broadcast waiting status - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + // Broadcast waiting status using local steps data + let waiting_steps = plan_infos_with_override(all_steps, step_order, "waiting_user"); + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), - steps: plan_infos_from_state_with_override(step_order, "waiting_user", - pool, workflow_id).await, - }); - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + steps: waiting_steps, + }).await; + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.to_string(), status: "waiting_user".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'waiting_user' WHERE id = ?") - .bind(workflow_id) - .execute(pool) - .await; - log_execution(pool, broadcast_tx, workflow_id, step_order, "ask_user", reason, reason, "waiting").await; + }).await; + send_execution(update_tx, workflow_id, step_order, "ask_user", reason, reason, "waiting").await; tracing::info!("[workflow {}] Step {} waiting for approval: {}", workflow_id, step_order, reason); @@ -1390,7 +1235,7 @@ async fn run_step_loop( if approval_content.starts_with("rejected:") { let reason = approval_content.strip_prefix("rejected:").unwrap_or("").trim(); - log_execution(pool, broadcast_tx, workflow_id, step_order, "ask_user", "rejected", reason, "failed").await; + send_execution(update_tx, workflow_id, step_order, "ask_user", "rejected", reason, "failed").await; step_chat_history.push(ChatMessage::tool_result(&tc.id, &format!("用户拒绝: {}", reason))); step_done_result = Some(StepResult { status: StepResultStatus::Failed { error: format!("用户终止: {}", reason) }, @@ -1407,14 +1252,10 @@ async fn run_step_loop( approval_content.clone() }; - let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") - .bind(workflow_id) - .execute(pool) - .await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.to_string(), status: "executing".into(), - }); + }).await; let tool_msg = if feedback.is_empty() { "用户已确认,继续执行。".to_string() @@ -1427,7 +1268,7 @@ async fn run_step_loop( "start_service" => { let cmd = args["command"].as_str().unwrap_or(""); { - let mut services = mgr.services.write().await; + let mut services = svc_mgr.services.write().await; if let Some(old) = services.remove(project_id) { let _ = nix::sys::signal::kill( nix::unistd::Pid::from_raw(old.pid as i32), @@ -1435,7 +1276,7 @@ async fn run_step_loop( ); } } - let port = mgr.allocate_port(); + let port = svc_mgr.allocate_port(); let cmd_with_port = cmd.replace("$PORT", &port.to_string()); let venv_bin = format!("{}/.venv/bin", workdir); let path_env = match std::env::var("PATH") { @@ -1455,19 +1296,19 @@ async fn run_step_loop( { Ok(child) => { let pid = child.id().unwrap_or(0); - mgr.services.write().await.insert(project_id.to_string(), ServiceInfo { port, pid }); + svc_mgr.services.write().await.insert(project_id.to_string(), ServiceInfo { port, pid }); tokio::time::sleep(std::time::Duration::from_secs(2)).await; format!("服务已启动,端口 {},访问地址:/api/projects/{}/app/", port, project_id) } Err(e) => format!("Error: 启动失败:{}", e), }; let status = if result.starts_with("Error:") { "failed" } else { "done" }; - log_execution(pool, broadcast_tx, workflow_id, step_order, "start_service", cmd, &result, status).await; + send_execution(update_tx, workflow_id, step_order, "start_service", cmd, &result, status).await; step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); } "stop_service" => { - let mut services = mgr.services.write().await; + let mut services = svc_mgr.services.write().await; let result = if let Some(svc) = services.remove(project_id) { let _ = nix::sys::signal::kill( nix::unistd::Pid::from_raw(svc.pid as i32), @@ -1477,98 +1318,16 @@ async fn run_step_loop( } else { "当前没有运行中的服务。".to_string() }; - log_execution(pool, broadcast_tx, workflow_id, step_order, "stop_service", "", &result, "done").await; - step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); - } - - "kb_search" => { - let query = args["query"].as_str().unwrap_or(""); - let result = if let Some(kb) = &mgr.kb { - match kb.search(query).await { - Ok(results) if results.is_empty() => "知识库为空或没有匹配结果。".to_string(), - Ok(results) => { - results.iter().enumerate().map(|(i, r)| { - let article_label = if r.article_title.is_empty() { - String::new() - } else { - format!(" [文章: {}]", r.article_title) - }; - format!("--- 片段 {} (相似度: {:.2}){} ---\n{}", i + 1, r.score, article_label, r.content) - }).collect::>().join("\n\n") - } - Err(e) => format!("Error: {}", e), - } - } else { - "知识库未初始化。".to_string() - }; - step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); - } - - "kb_read" => { - let result = if let Some(kb) = &mgr.kb { - match kb.read_all().await { - Ok(content) if content.is_empty() => "知识库为空。".to_string(), - Ok(content) => content, - Err(e) => format!("Error: {}", e), - } - } else { - "知识库未初始化。".to_string() - }; - step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); - } - - "list_workers" => { - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { - workflow_id: workflow_id.to_string(), - activity: format!("步骤 {} — 列出 Workers", step_order), - }); - let workers = mgr.worker_mgr.list().await; - let result = if workers.is_empty() { - "没有已注册的 worker。".to_string() - } else { - let items: Vec = workers.iter().map(|(name, info)| { - format!("- {} (cpu={}, mem={}, gpu={}, os={}, kernel={})", - name, info.cpu, info.memory, info.gpu, info.os, info.kernel) - }).collect(); - format!("已注册的 workers:\n{}", items.join("\n")) - }; - log_execution(pool, broadcast_tx, workflow_id, step_order, "list_workers", "", &result, "done").await; - step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); - } - - "execute_on_worker" => { - let worker_name = args.get("worker").and_then(|v| v.as_str()).unwrap_or(""); - let script = args.get("script").and_then(|v| v.as_str()).unwrap_or(""); - let timeout = args.get("timeout").and_then(|v| v.as_u64()).unwrap_or(300); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { - workflow_id: workflow_id.to_string(), - activity: format!("步骤 {} — 在 {} 上执行脚本", step_order, worker_name), - }); - let result = match mgr.worker_mgr.execute(worker_name, script, timeout).await { - Ok(wr) => { - let mut out = String::new(); - out.push_str(&format!("exit_code: {}\n", wr.exit_code)); - if !wr.stdout.is_empty() { - out.push_str(&format!("stdout:\n{}\n", truncate_str(&wr.stdout, 8192))); - } - if !wr.stderr.is_empty() { - out.push_str(&format!("stderr:\n{}\n", truncate_str(&wr.stderr, 4096))); - } - out - } - Err(e) => format!("Error: {}", e), - }; - let status = if result.starts_with("Error:") { "failed" } else { "done" }; - log_execution(pool, broadcast_tx, workflow_id, step_order, "execute_on_worker", &tc.function.arguments, &result, status).await; + send_execution(update_tx, workflow_id, step_order, "stop_service", "", &result, "done").await; step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); } // External tools name if external_tools.as_ref().is_some_and(|e| e.has_tool(name)) => { - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: format!("步骤 {} — 工具: {}", step_order, name), - }); + }).await; let result = match external_tools.unwrap().invoke(name, &tc.function.arguments, workdir).await { Ok(output) => { let truncated = truncate_str(&output, 8192); @@ -1577,7 +1336,7 @@ async fn run_step_loop( Err(e) => format!("Tool error: {}", e), }; let status = if result.starts_with("Tool error:") { "failed" } else { "done" }; - log_execution(pool, broadcast_tx, workflow_id, step_order, &tc.function.name, &tc.function.arguments, &result, status).await; + send_execution(update_tx, workflow_id, step_order, &tc.function.name, &tc.function.arguments, &result, status).await; step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); } @@ -1593,13 +1352,13 @@ async fn run_step_loop( "list_files" => "列出文件".to_string(), other => format!("工具: {}", other), }; - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: format!("步骤 {} — {}", step_order, tool_desc), - }); + }).await; let result = execute_tool(&tc.function.name, &tc.function.arguments, workdir, exec).await; let status = if result.starts_with("Error:") { "failed" } else { "done" }; - log_execution(pool, broadcast_tx, workflow_id, step_order, &tc.function.name, &tc.function.arguments, &result, status).await; + send_execution(update_tx, workflow_id, step_order, &tc.function.name, &tc.function.arguments, &result, status).await; step_chat_history.push(ChatMessage::tool_result(&tc.id, &result)); } } @@ -1613,8 +1372,8 @@ async fn run_step_loop( }) }).collect(); let tc_json_str = serde_json::to_string(&tc_json).unwrap_or_else(|_| "[]".to_string()); - log_llm_call( - pool, broadcast_tx, workflow_id, step_order, + send_llm_call( + update_tx, workflow_id, step_order, &phase_label, msg_count, tool_count, &tc_json_str, &llm_text_response, prompt_tokens, completion_tokens, latency_ms, @@ -1627,9 +1386,9 @@ async fn run_step_loop( // Text response without tool calls let content = choice.message.content.as_deref().unwrap_or("(no content)"); tracing::info!("[workflow {}] Step {} text response: {}", workflow_id, step_order, truncate_str(content, 200)); - log_execution(pool, broadcast_tx, workflow_id, step_order, "text_response", "", content, "done").await; - log_llm_call( - pool, broadcast_tx, workflow_id, step_order, + send_execution(update_tx, workflow_id, step_order, "text_response", "", content, "done").await; + send_llm_call( + update_tx, workflow_id, step_order, &phase_label, msg_count, tool_count, "[]", content, prompt_tokens, completion_tokens, latency_ms, @@ -1647,67 +1406,44 @@ async fn run_step_loop( } } -/// Helper to get plan step infos with a status override for a specific step. -/// Used during ask_user in the step sub-loop where we don't have -/// mutable access to the AgentState. -async fn plan_infos_from_state_with_override( - step_order: i32, - override_status: &str, - pool: &SqlitePool, - workflow_id: &str, -) -> Vec { - // Read the latest state snapshot to get step info - let snapshot = sqlx::query_scalar::<_, String>( - "SELECT state_json FROM agent_state_snapshots WHERE workflow_id = ? ORDER BY created_at DESC LIMIT 1" - ) - .bind(workflow_id) - .fetch_optional(pool) - .await - .ok() - .flatten(); - - if let Some(json) = snapshot { - if let Ok(state) = serde_json::from_str::(&json) { - return state.steps.iter().map(|s| { - let status = if s.order == step_order { - override_status.to_string() - } else { - match s.status { - StepStatus::Pending => "pending", - StepStatus::Running => "running", - StepStatus::WaitingUser => "waiting_user", - StepStatus::Done => "done", - StepStatus::Failed => "failed", - }.to_string() - }; - PlanStepInfo { - order: s.order, - description: s.title.clone(), - command: s.description.clone(), - status: Some(status), - artifacts: s.artifacts.clone(), - } - }).collect(); +/// Compute plan step infos with a status override for a specific step. +fn plan_infos_with_override(steps: &[Step], override_order: i32, override_status: &str) -> Vec { + steps.iter().map(|s| { + let status = if s.order == override_order { + override_status.to_string() + } else { + match s.status { + StepStatus::Pending => "pending", + StepStatus::Running => "running", + StepStatus::WaitingUser => "waiting_user", + StepStatus::Done => "done", + StepStatus::Failed => "failed", + }.to_string() + }; + PlanStepInfo { + order: s.order, + description: s.title.clone(), + command: s.description.clone(), + status: Some(status), + artifacts: s.artifacts.clone(), } - } - Vec::new() + }).collect() } #[allow(clippy::too_many_arguments)] -async fn run_agent_loop( +pub async fn run_agent_loop( llm: &LlmClient, exec: &LocalExecutor, - pool: &SqlitePool, - broadcast_tx: &broadcast::Sender, + update_tx: &mpsc::Sender, + event_rx: &mut mpsc::Receiver, project_id: &str, workflow_id: &str, requirement: &str, workdir: &str, - mgr: &Arc, + svc_mgr: &ServiceManager, instructions: &str, initial_state: Option, external_tools: Option<&ExternalToolManager>, - event_rx: &mut mpsc::Receiver, require_plan_approval: bool, ) -> anyhow::Result<()> { let planning_tools = build_planning_tools(); @@ -1728,10 +1464,10 @@ async fn run_agent_loop( let tool_count = planning_tools.len() as i32; tracing::info!("[workflow {}] Planning LLM call #{} msgs={}", workflow_id, iteration + 1, messages.len()); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: "规划中 — 等待 LLM 响应...".to_string(), - }); + }).await; let call_start = std::time::Instant::now(); let response = match llm.chat_with_tools(messages, &planning_tools).await { Ok(r) => r, @@ -1780,30 +1516,26 @@ async fn run_agent_loop( }); } - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); + }).await; - save_state_snapshot(pool, workflow_id, 0, &state).await; + send_snapshot(update_tx, workflow_id, 0, &state).await; tracing::info!("[workflow {}] Plan set ({} steps)", workflow_id, state.steps.len()); // If require_plan_approval, wait for user to confirm the plan if require_plan_approval { tracing::info!("[workflow {}] Waiting for plan approval", workflow_id); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: "计划已生成 — 等待用户确认...".to_string(), - }); - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + }).await; + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.to_string(), status: "waiting_user".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'waiting_user' WHERE id = ?") - .bind(workflow_id) - .execute(pool) - .await; - log_execution(pool, broadcast_tx, workflow_id, 0, "plan_approval", "等待确认计划", "等待用户确认执行计划", "waiting").await; + }).await; + send_execution(update_tx, workflow_id, 0, "plan_approval", "等待确认计划", "等待用户确认执行计划", "waiting").await; // Block until Comment event let approval_content = loop { @@ -1821,23 +1553,18 @@ async fn run_agent_loop( if approval_content.starts_with("rejected:") { let reason = approval_content.strip_prefix("rejected:").unwrap_or("").trim(); tracing::info!("[workflow {}] Plan rejected: {}", workflow_id, reason); - log_execution(pool, broadcast_tx, workflow_id, 0, "plan_approval", "rejected", reason, "failed").await; + send_execution(update_tx, workflow_id, 0, "plan_approval", "rejected", reason, "failed").await; - // Feed rejection back into planning conversation so LLM can re-plan state.current_step_chat_history.push(ChatMessage::tool_result( &tc.id, &format!("用户拒绝了此计划: {}。请根据反馈修改计划后重新调用 update_plan。", reason), )); state.steps.clear(); - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.to_string(), status: "executing".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") - .bind(workflow_id) - .execute(pool) - .await; + }).await; // Stay in Planning phase, continue the loop continue; } @@ -1849,30 +1576,26 @@ async fn run_agent_loop( String::new() }; - log_execution(pool, broadcast_tx, workflow_id, 0, "plan_approval", "approved", &feedback, "done").await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + send_execution(update_tx, workflow_id, 0, "plan_approval", "approved", &feedback, "done").await; + let _ = update_tx.send(AgentUpdate::WorkflowStatus { workflow_id: workflow_id.to_string(), status: "executing".into(), - }); - let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") - .bind(workflow_id) - .execute(pool) - .await; + }).await; } // Enter execution phase if let Some(first) = state.steps.first_mut() { first.status = StepStatus::Running; } - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); + }).await; state.current_step_chat_history.clear(); state.phase = AgentPhase::Executing { step: 1 }; phase_transition = true; - save_state_snapshot(pool, workflow_id, 0, &state).await; + send_snapshot(update_tx, workflow_id, 0, &state).await; tracing::info!("[workflow {}] Entering Executing", workflow_id); } // Planning phase IO tools @@ -1887,13 +1610,13 @@ async fn run_agent_loop( serde_json::json!({ "name": tc.function.name, "arguments_preview": truncate_str(&tc.function.arguments, 200) }) }).collect(); let tc_json_str = serde_json::to_string(&tc_json).unwrap_or_else(|_| "[]".to_string()); - log_llm_call(pool, broadcast_tx, workflow_id, 0, "planning", msg_count, tool_count, + send_llm_call(update_tx, workflow_id, 0, "planning", msg_count, tool_count, &tc_json_str, &llm_text_response, prompt_tokens, completion_tokens, latency_ms).await; } else { let content = choice.message.content.as_deref().unwrap_or("(no content)"); tracing::info!("[workflow {}] Planning text response: {}", workflow_id, truncate_str(content, 200)); - log_execution(pool, broadcast_tx, workflow_id, 0, "text_response", "", content, "done").await; - log_llm_call(pool, broadcast_tx, workflow_id, 0, "planning", msg_count, tool_count, + send_execution(update_tx, workflow_id, 0, "text_response", "", content, "done").await; + send_llm_call(update_tx, workflow_id, 0, "planning", msg_count, tool_count, "[]", content, prompt_tokens, completion_tokens, latency_ms).await; } } @@ -1915,11 +1638,11 @@ async fn run_agent_loop( state.phase = AgentPhase::Executing { step: step_order }; state.current_step_chat_history.clear(); - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); - save_state_snapshot(pool, workflow_id, step_order, &state).await; + }).await; + send_snapshot(update_tx, workflow_id, step_order, &state).await; // Build completed summaries for context let completed_summaries: Vec<(i32, String, String, Vec)> = state.steps.iter() @@ -1933,10 +1656,10 @@ async fn run_agent_loop( // Run the isolated step sub-loop let step_result = run_step_loop( - llm, exec, pool, broadcast_tx, - project_id, workflow_id, workdir, mgr, + llm, exec, update_tx, event_rx, + project_id, workflow_id, workdir, svc_mgr, instructions, &step, &completed_summaries, &state.scratchpad, - external_tools, event_rx, + external_tools, &state.steps, ).await; tracing::info!("[workflow {}] Step {} completed: {:?}", workflow_id, step_order, step_result.status); @@ -1955,29 +1678,27 @@ async fn run_agent_loop( s.status = StepStatus::Failed; s.summary = Some(step_result.summary.clone()); } - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); - save_state_snapshot(pool, workflow_id, step_order, &state).await; + }).await; + send_snapshot(update_tx, workflow_id, step_order, &state).await; return Err(anyhow::anyhow!("Step {} failed: {}", step_order, error)); } StepResultStatus::NeedsInput { message: _ } => { - // This shouldn't normally happen since ask_user is handled inside - // run_step_loop, but handle gracefully if let Some(s) = state.steps.iter_mut().find(|s| s.order == step_order) { s.status = StepStatus::WaitingUser; } - save_state_snapshot(pool, workflow_id, step_order, &state).await; + send_snapshot(update_tx, workflow_id, step_order, &state).await; continue; } } - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); - save_state_snapshot(pool, workflow_id, step_order, &state).await; + }).await; + send_snapshot(update_tx, workflow_id, step_order, &state).await; // --- Coordinator review --- // Check if there are more steps; if not, we're done @@ -2023,10 +1744,10 @@ async fn run_agent_loop( state.current_step_chat_history.clear(); tracing::info!("[workflow {}] Coordinator review for step {}", workflow_id, step_order); - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { + let _ = update_tx.send(AgentUpdate::Activity { workflow_id: workflow_id.to_string(), activity: format!("步骤 {} 完成 — 协调器审核中...", step_order), - }); + }).await; let call_start = std::time::Instant::now(); let coord_response = match llm.chat_with_tools(coord_messages.clone(), &coordinator_tools).await { Ok(r) => r, @@ -2041,7 +1762,7 @@ async fn run_agent_loop( .map(|u| (Some(u.prompt_tokens), Some(u.completion_tokens))) .unwrap_or((None, None)); - log_llm_call(pool, broadcast_tx, workflow_id, step_order, "coordinator", + send_llm_call(update_tx, workflow_id, step_order, "coordinator", coord_messages.len() as i32, coordinator_tools.len() as i32, "[]", "", prompt_tokens, completion_tokens, latency_ms).await; @@ -2064,12 +1785,12 @@ async fn run_agent_loop( }).collect(); state.apply_plan_diff(new_steps); - let _ = broadcast_tx.send(WsMessage::PlanUpdate { + let _ = update_tx.send(AgentUpdate::PlanUpdate { workflow_id: workflow_id.to_string(), steps: plan_infos_from_state(&state), - }); + }).await; tracing::info!("[workflow {}] Coordinator revised plan", workflow_id); - save_state_snapshot(pool, workflow_id, step_order, &state).await; + send_snapshot(update_tx, workflow_id, step_order, &state).await; } "update_scratchpad" => { let content = args["content"].as_str().unwrap_or(""); @@ -2082,13 +1803,11 @@ async fn run_agent_loop( } "update_requirement" => { let new_req = args["requirement"].as_str().unwrap_or(""); - let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?") - .bind(new_req).bind(workflow_id).execute(pool).await; let _ = tokio::fs::write(format!("{}/requirement.md", workdir), new_req).await; - let _ = broadcast_tx.send(WsMessage::RequirementUpdate { + let _ = update_tx.send(AgentUpdate::RequirementUpdate { workflow_id: workflow_id.to_string(), requirement: new_req.to_string(), - }); + }).await; } _ => {} } @@ -2102,15 +1821,27 @@ async fn run_agent_loop( } // Final snapshot - save_state_snapshot(pool, workflow_id, state.current_step(), &state).await; + send_snapshot(update_tx, workflow_id, state.current_step(), &state).await; Ok(()) } -async fn generate_report( +/// Simple log entry for report generation (no DB dependency). +/// Used by the worker binary to collect execution logs during agent loop. +#[allow(dead_code)] +pub struct SimpleLogEntry { + pub step_order: i32, + pub tool_name: String, + pub tool_input: String, + pub output: String, + pub status: String, +} + +#[allow(dead_code)] +pub async fn generate_report( llm: &LlmClient, requirement: &str, - entries: &[crate::db::ExecutionLogEntry], + entries: &[SimpleLogEntry], project_id: &str, ) -> anyhow::Result { let steps_detail: String = entries @@ -2147,23 +1878,10 @@ async fn generate_report( Ok(report) } -async fn generate_title(llm: &LlmClient, requirement: &str) -> anyhow::Result { - let response = llm - .chat(vec![ - ChatMessage::system("为给定的需求生成一个简短的项目标题(最多15个汉字)。只回复标题本身,不要加任何其他内容。使用中文。"), - ChatMessage::user(requirement), - ]) - .await?; - - let mut title = response.trim().trim_matches('"').to_string(); - // Hard limit: if LLM returns garbage, take only the first line, max 80 chars - if let Some(first_line) = title.lines().next() { - title = first_line.to_string(); - } - if title.len() > 80 { - title = truncate_str(&title, 80).to_string(); - } - Ok(title) +fn generate_title_heuristic(requirement: &str) -> String { + let first_line = requirement.lines().next().unwrap_or(requirement); + let trimmed = first_line.trim().trim_start_matches('#').trim(); + truncate_str(trimmed, 50).to_string() } #[cfg(test)] @@ -2304,7 +2022,7 @@ mod tests { let names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect(); for expected in &["execute", "read_file", "write_file", "list_files", "start_service", "stop_service", "update_scratchpad", - "ask_user", "kb_search", "kb_read"] { + "ask_user"] { assert!(names.contains(expected), "{} must be in step tools", expected); } } @@ -2329,8 +2047,6 @@ mod tests { assert!(names.contains(&"update_plan")); assert!(names.contains(&"list_files")); assert!(names.contains(&"read_file")); - assert!(names.contains(&"kb_search")); - assert!(names.contains(&"kb_read")); assert!(!names.contains(&"execute")); assert!(!names.contains(&"step_done")); } diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..82d3383 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,74 @@ +pub mod api; +pub mod agent; +pub mod db; +pub mod kb; +pub mod llm; +pub mod exec; +pub mod state; +pub mod template; +pub mod timer; +pub mod tools; +pub mod worker; +pub mod sink; +pub mod worker_runner; +pub mod ws; +pub mod ws_worker; + +use std::sync::Arc; +use serde::Deserialize; + +pub struct AppState { + pub db: db::Database, + pub config: Config, + pub agent_mgr: Arc, + pub kb: Option>, + pub obj_root: String, + pub auth: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub llm: LlmConfig, + pub server: ServerConfig, + pub database: DatabaseConfig, + #[serde(default)] + pub template_repo: Option, + /// Path to EC private key PEM file for JWT signing + #[serde(default)] + pub jwt_private_key: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TemplateRepoConfig { + pub gitea_url: String, + pub owner: String, + pub repo: String, + #[serde(default = "default_repo_path")] + pub local_path: String, +} + +fn default_repo_path() -> String { + if std::path::Path::new("/app/oseng-templates").is_dir() { + "/app/oseng-templates".to_string() + } else { + "oseng-templates".to_string() + } +} + +#[derive(Debug, Clone, serde::Serialize, Deserialize)] +pub struct LlmConfig { + pub base_url: String, + pub api_key: String, + pub model: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + pub host: String, + pub port: u16, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct DatabaseConfig { + pub path: String, +} diff --git a/src/main.rs b/src/main.rs index d70f685..1ae0fb0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,77 +1,33 @@ -mod api; -mod agent; -mod db; -mod kb; -mod llm; -mod exec; -pub mod state; -mod template; -mod timer; -mod tools; -mod worker; -mod ws; -mod ws_worker; - use std::sync::Arc; use axum::Router; +use clap::{Parser, Subcommand}; use sqlx::sqlite::SqlitePool; use tower_http::cors::CorsLayer; use tower_http::services::{ServeDir, ServeFile}; -pub struct AppState { - pub db: db::Database, - pub config: Config, - pub agent_mgr: Arc, - pub kb: Option>, - pub obj_root: String, - pub auth: Option, +use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker}; +use tori::{AppState, Config}; + +#[derive(Parser)] +#[command(name = "tori", about = "Tori AI agent orchestration")] +struct Cli { + #[command(subcommand)] + command: Command, } -#[derive(Debug, Clone, serde::Deserialize)] -pub struct Config { - pub llm: LlmConfig, - pub server: ServerConfig, - pub database: DatabaseConfig, - #[serde(default)] - pub template_repo: Option, - /// Path to EC private key PEM file for JWT signing - #[serde(default)] - pub jwt_private_key: Option, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct TemplateRepoConfig { - pub gitea_url: String, - pub owner: String, - pub repo: String, - #[serde(default = "default_repo_path")] - pub local_path: String, -} - -fn default_repo_path() -> String { - if std::path::Path::new("/app/oseng-templates").is_dir() { - "/app/oseng-templates".to_string() - } else { - "oseng-templates".to_string() - } -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct LlmConfig { - pub base_url: String, - pub api_key: String, - pub model: String, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct ServerConfig { - pub host: String, - pub port: u16, -} - -#[derive(Debug, Clone, serde::Deserialize)] -pub struct DatabaseConfig { - pub path: String, +#[derive(Subcommand)] +enum Command { + /// Start the API server + Server, + /// Start a worker that connects to the server + Worker { + /// Server WebSocket URL + #[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")] + server: String, + /// Worker name + #[arg(long, env = "TORI_WORKER_NAME")] + name: Option, + }, } #[tokio::main] @@ -80,6 +36,22 @@ async fn main() -> anyhow::Result<()> { .with_env_filter("tori=debug,tower_http=debug") .init(); + let cli = Cli::parse(); + + match cli.command { + Command::Server => run_server().await, + Command::Worker { server, name } => { + let name = name.unwrap_or_else(|| { + hostname::get() + .map(|h| h.to_string_lossy().to_string()) + .unwrap_or_else(|_| "worker-1".to_string()) + }); + worker_runner::run(&server, &name).await + } + } +} + +async fn run_server() -> anyhow::Result<()> { let config_str = std::fs::read_to_string("config.yaml") .expect("Failed to read config.yaml"); let config: Config = serde_yaml::from_str(&config_str) @@ -88,7 +60,6 @@ async fn main() -> anyhow::Result<()> { let database = db::Database::new(&config.database.path).await?; database.migrate().await?; - // Initialize KB manager let kb_arc = match kb::KbManager::new(database.pool.clone()) { Ok(kb) => { tracing::info!("KB manager initialized"); @@ -100,7 +71,6 @@ async fn main() -> anyhow::Result<()> { } }; - // Ensure template repo is cloned before serving if let Some(ref repo_cfg) = config.template_repo { template::ensure_repo_ready(repo_cfg).await; } @@ -117,8 +87,6 @@ async fn main() -> anyhow::Result<()> { ); timer::start_timer_runner(database.pool.clone(), agent_mgr.clone()); - - // Resume incomplete workflows after restart resume_workflows(database.pool.clone(), agent_mgr.clone()).await; let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string()); @@ -129,7 +97,6 @@ async fn main() -> anyhow::Result<()> { let public_url = std::env::var("PUBLIC_URL") .unwrap_or_else(|_| "https://tori.euphon.cloud".to_string()); - // Try TikTok SSO first, then Google OAuth if let (Ok(id), Ok(secret)) = ( std::env::var("SSO_CLIENT_ID"), std::env::var("SSO_CLIENT_SECRET"), @@ -157,7 +124,7 @@ async fn main() -> anyhow::Result<()> { public_url, }) } else { - tracing::warn!("No OAuth configured (set SSO_CLIENT_ID/SSO_CLIENT_SECRET or GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET)"); + tracing::warn!("No OAuth configured"); None } }; @@ -172,13 +139,10 @@ async fn main() -> anyhow::Result<()> { }); let app = Router::new() - // Health check (public, for k8s probes) .route("/tori/api/health", axum::routing::get(|| async { axum::Json(serde_json::json!({"status": "ok"})) })) - // Auth routes are public .nest("/tori/api/auth", api::auth::router(state.clone())) - // Protected API routes .nest("/tori/api", api::router(state.clone()) .layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth)) ) diff --git a/src/sink.rs b/src/sink.rs new file mode 100644 index 0000000..d16f3bb --- /dev/null +++ b/src/sink.rs @@ -0,0 +1,239 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use serde::{Deserialize, Serialize}; +use sqlx::sqlite::SqlitePool; +use tokio::sync::{RwLock, broadcast, mpsc}; + +use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo}; +use crate::state::{AgentState, Artifact}; + +/// All updates produced by the agent loop. This is the single output interface +/// that decouples the agent logic from DB persistence and WebSocket broadcasting. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "kind")] +pub enum AgentUpdate { + PlanUpdate { + workflow_id: String, + steps: Vec, + }, + WorkflowStatus { + workflow_id: String, + status: String, + }, + Activity { + workflow_id: String, + activity: String, + }, + ExecutionLog { + workflow_id: String, + step_order: i32, + tool_name: String, + tool_input: String, + output: String, + status: String, + }, + LlmCallLog { + workflow_id: String, + step_order: i32, + phase: String, + messages_count: i32, + tools_count: i32, + tool_calls: String, + text_response: String, + prompt_tokens: Option, + completion_tokens: Option, + latency_ms: i64, + }, + StateSnapshot { + workflow_id: String, + step_order: i32, + state: AgentState, + }, + WorkflowComplete { + workflow_id: String, + status: String, + report: Option, + }, + ArtifactSave { + workflow_id: String, + step_order: i32, + artifact: Artifact, + }, + RequirementUpdate { + workflow_id: String, + requirement: String, + }, + Error { + message: String, + }, +} + +/// Manages local services (start_service / stop_service tools). +/// Created per-worker or per-agent-loop. +pub struct ServiceManager { + pub services: RwLock>, + next_port: AtomicU16, +} + +impl ServiceManager { + pub fn new(start_port: u16) -> Arc { + Arc::new(Self { + services: RwLock::new(HashMap::new()), + next_port: AtomicU16::new(start_port), + }) + } + + pub fn allocate_port(&self) -> u16 { + self.next_port.fetch_add(1, Ordering::Relaxed) + } +} + +/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend. +pub async fn handle_agent_updates( + mut rx: mpsc::Receiver, + pool: SqlitePool, + broadcast_tx: broadcast::Sender, +) { + while let Some(update) = rx.recv().await { + match update { + AgentUpdate::PlanUpdate { workflow_id, steps } => { + let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps }); + } + AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => { + let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") + .bind(status) + .bind(workflow_id) + .execute(&pool) + .await; + let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + workflow_id: workflow_id.clone(), + status: status.clone(), + }); + } + AgentUpdate::Activity { workflow_id, activity } => { + let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity }); + } + AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))" + ) + .bind(&id) + .bind(workflow_id) + .bind(step_order) + .bind(tool_name) + .bind(tool_input) + .bind(output) + .bind(status) + .execute(&pool) + .await; + let _ = broadcast_tx.send(WsMessage::StepStatusUpdate { + step_id: id, + status: status.clone(), + output: output.clone(), + }); + } + AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))" + ) + .bind(&id) + .bind(workflow_id) + .bind(step_order) + .bind(phase) + .bind(messages_count) + .bind(tools_count) + .bind(tool_calls) + .bind(text_response) + .bind(prompt_tokens.map(|v| v as i32)) + .bind(completion_tokens.map(|v| v as i32)) + .bind(latency_ms as i32) + .execute(&pool) + .await; + let entry = crate::db::LlmCallLogEntry { + id, + workflow_id: workflow_id.clone(), + step_order, + phase: phase.clone(), + messages_count, + tools_count, + tool_calls: tool_calls.clone(), + text_response: text_response.clone(), + prompt_tokens: prompt_tokens.map(|v| v as i32), + completion_tokens: completion_tokens.map(|v| v as i32), + latency_ms: latency_ms as i32, + created_at: String::new(), + }; + let _ = broadcast_tx.send(WsMessage::LlmCallLog { + workflow_id: workflow_id.clone(), + entry, + }); + } + AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => { + let id = uuid::Uuid::new_v4().to_string(); + let json = serde_json::to_string(state).unwrap_or_default(); + let _ = sqlx::query( + "INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))" + ) + .bind(&id) + .bind(workflow_id) + .bind(step_order) + .bind(&json) + .execute(&pool) + .await; + } + AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => { + let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") + .bind(status) + .bind(workflow_id) + .execute(&pool) + .await; + if let Some(ref r) = report { + let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?") + .bind(r) + .bind(workflow_id) + .execute(&pool) + .await; + let _ = broadcast_tx.send(WsMessage::ReportReady { + workflow_id: workflow_id.clone(), + }); + } + let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { + workflow_id: workflow_id.clone(), + status: status.clone(), + }); + } + AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)" + ) + .bind(&id) + .bind(workflow_id) + .bind(step_order) + .bind(&artifact.name) + .bind(&artifact.path) + .bind(&artifact.artifact_type) + .bind(&artifact.description) + .execute(&pool) + .await; + } + AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => { + let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?") + .bind(requirement) + .bind(workflow_id) + .execute(&pool) + .await; + let _ = broadcast_tx.send(WsMessage::RequirementUpdate { + workflow_id: workflow_id.clone(), + requirement: requirement.clone(), + }); + } + AgentUpdate::Error { message } => { + let _ = broadcast_tx.send(WsMessage::Error { message }); + } + } + } +} diff --git a/src/template.rs b/src/template.rs index 091fdcf..c399612 100644 --- a/src/template.rs +++ b/src/template.rs @@ -3,7 +3,6 @@ use std::path::{Path, PathBuf}; use serde::Deserialize; use crate::TemplateRepoConfig; -use crate::llm::{ChatMessage, LlmClient}; use crate::tools::ExternalToolManager; #[derive(Debug, Deserialize)] @@ -463,42 +462,6 @@ pub fn is_repo_template(template_id: &str) -> bool { template_id.contains('/') } -// --- LLM template selection --- - -pub async fn select_template(llm: &LlmClient, requirement: &str, repo_cfg: Option<&TemplateRepoConfig>) -> Option { - let all = list_all_templates(repo_cfg).await; - if all.is_empty() { - return None; - } - - let listing: String = all - .iter() - .map(|t| format!("- id: {}\n 名称: {}\n 描述: {}", t.id, t.name, t.description)) - .collect::>() - .join("\n"); - - let prompt = format!( - "以下是可用的项目模板:\n{}\n\n用户需求:{}\n\n选择最匹配的模板 ID,如果都不合适则回复 none。只回复模板 ID 或 none,不要其他内容。", - listing, requirement - ); - - let response = llm - .chat(vec![ - ChatMessage::system("你是一个模板选择助手。根据用户需求选择最合适的项目模板。只回复模板 ID 或 none。"), - ChatMessage::user(&prompt), - ]) - .await - .ok()?; - - let answer = response.trim().to_lowercase(); - tracing::info!("Template selection LLM response: '{}' (available: {:?})", - answer, all.iter().map(|t| t.id.as_str()).collect::>()); - if answer == "none" { - return None; - } - - all.iter().find(|t| t.id == answer).map(|t| t.id.clone()) -} // --- Template loading --- diff --git a/src/worker.rs b/src/worker.rs index b998c7b..99cf576 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -131,3 +131,56 @@ impl WorkerManager { } } } + +// --- Extended protocol for workflow execution --- + +use crate::LlmConfig; +use crate::state::AgentState; +use crate::llm::Tool; + +/// Messages sent from server to worker. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ServerToWorker { + /// Assign a full workflow for execution. + #[serde(rename = "workflow_assign")] + WorkflowAssign { + workflow_id: String, + project_id: String, + requirement: String, + workdir: String, + instructions: String, + llm_config: LlmConfig, + #[serde(default)] + initial_state: Option, + #[serde(default)] + require_plan_approval: bool, + #[serde(default)] + external_tools: Vec, + }, + /// Forward a user comment to the worker executing this workflow. + #[serde(rename = "comment")] + Comment { + workflow_id: String, + content: String, + }, +} + +/// Messages sent from worker to server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum WorkerToServer { + /// Worker registration. + #[serde(rename = "register")] + Register { info: WorkerInfo }, + /// Script execution result (legacy). + #[serde(rename = "result")] + Result(WorkerResult), + /// Agent update from workflow execution. + #[serde(rename = "update")] + Update { + workflow_id: String, + #[serde(flatten)] + update: crate::sink::AgentUpdate, + }, +} diff --git a/src/worker_runner.rs b/src/worker_runner.rs new file mode 100644 index 0000000..ba0c8b7 --- /dev/null +++ b/src/worker_runner.rs @@ -0,0 +1,203 @@ +use std::sync::Arc; +use futures::{SinkExt, StreamExt}; +use tokio::sync::mpsc; +use tokio_tungstenite::{connect_async, tungstenite::Message}; + +use crate::agent::{self, AgentEvent}; +use crate::exec::LocalExecutor; +use crate::llm::LlmClient; +use crate::sink::{AgentUpdate, ServiceManager}; +use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer}; + +fn collect_worker_info(name: &str) -> WorkerInfo { + let cpu = std::fs::read_to_string("/proc/cpuinfo") + .ok() + .and_then(|s| { + s.lines() + .find(|l| l.starts_with("model name")) + .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) + }) + .unwrap_or_else(|| "unknown".into()); + + let memory = std::fs::read_to_string("/proc/meminfo") + .ok() + .and_then(|s| { + s.lines() + .find(|l| l.starts_with("MemTotal")) + .and_then(|l| l.split_whitespace().nth(1)) + .and_then(|kb| kb.parse::().ok()) + .map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0)) + }) + .unwrap_or_else(|| "unknown".into()); + + let gpu = std::process::Command::new("nvidia-smi") + .arg("--query-gpu=name") + .arg("--format=csv,noheader") + .output() + .ok() + .and_then(|o| { + if o.status.success() { + Some(String::from_utf8_lossy(&o.stdout).trim().to_string()) + } else { + None + } + }) + .unwrap_or_else(|| "none".into()); + + WorkerInfo { + name: name.to_string(), + cpu, + memory, + gpu, + os: std::env::consts::OS.to_string(), + kernel: std::process::Command::new("uname") + .arg("-r") + .output() + .ok() + .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) + .unwrap_or_else(|| "unknown".into()), + } +} + +pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> { + tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url); + + loop { + match connect_and_run(server_url, worker_name).await { + Ok(()) => { + tracing::info!("Worker connection closed, reconnecting in 5s..."); + } + Err(e) => { + tracing::error!("Worker error: {}, reconnecting in 5s...", e); + } + } + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } +} + +async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> { + let (ws_stream, _) = connect_async(server_url).await?; + let (mut ws_tx, mut ws_rx) = ws_stream.split(); + + // Register + let info = collect_worker_info(worker_name); + let register_msg = serde_json::to_string(&WorkerToServer::Register { info })?; + ws_tx.send(Message::Text(register_msg.into())).await?; + + // Wait for registration ack + while let Some(msg) = ws_rx.next().await { + match msg? { + Message::Text(text) => { + let v: serde_json::Value = serde_json::from_str(&text)?; + if v["type"] == "registered" { + tracing::info!("Registered as '{}'", v["name"]); + break; + } + } + Message::Close(_) => anyhow::bail!("Connection closed during registration"), + _ => {} + } + } + + let svc_mgr = ServiceManager::new(9100); + let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx)); + + // Channel for forwarding comments to the running workflow + let comment_tx: Arc>>> = + Arc::new(tokio::sync::Mutex::new(None)); + + // Main message loop + while let Some(msg) = ws_rx.next().await { + let text = match msg? { + Message::Text(t) => t, + Message::Close(_) => break, + _ => continue, + }; + + let server_msg: ServerToWorker = match serde_json::from_str(&text) { + Ok(m) => m, + Err(e) => { + tracing::warn!("Failed to parse server message: {}", e); + continue; + } + }; + + match server_msg { + ServerToWorker::WorkflowAssign { + workflow_id, + project_id, + requirement, + workdir, + instructions, + llm_config, + initial_state, + require_plan_approval, + external_tools: _, + } => { + tracing::info!("Received workflow: {} (project {})", workflow_id, project_id); + + let llm = LlmClient::new(&llm_config); + let exec = LocalExecutor::new(None); + + // update channel → serialize → WebSocket + let (update_tx, mut update_rx) = mpsc::channel::(64); + let ws_tx_clone = ws_tx.clone(); + let wf_id_clone = workflow_id.clone(); + tokio::spawn(async move { + while let Some(update) = update_rx.recv().await { + let msg = WorkerToServer::Update { + workflow_id: wf_id_clone.clone(), + update, + }; + if let Ok(json) = serde_json::to_string(&msg) { + let mut tx = ws_tx_clone.lock().await; + if tx.send(Message::Text(json.into())).await.is_err() { + break; + } + } + } + }); + + // event channel for comments + let (evt_tx, mut evt_rx) = mpsc::channel::(32); + *comment_tx.lock().await = Some(evt_tx); + + let _ = tokio::fs::create_dir_all(&workdir).await; + + let result = agent::run_agent_loop( + &llm, &exec, &update_tx, &mut evt_rx, + &project_id, &workflow_id, &requirement, &workdir, &svc_mgr, + &instructions, initial_state, None, require_plan_approval, + ).await; + + let final_status = if result.is_ok() { "done" } else { "failed" }; + if let Err(e) = &result { + tracing::error!("Workflow {} failed: {}", workflow_id, e); + let _ = update_tx.send(AgentUpdate::Error { + message: format!("Agent error: {}", e), + }).await; + } + + let _ = update_tx.send(AgentUpdate::WorkflowComplete { + workflow_id: workflow_id.clone(), + status: final_status.into(), + report: None, + }).await; + + *comment_tx.lock().await = None; + tracing::info!("Workflow {} completed: {}", workflow_id, final_status); + } + + ServerToWorker::Comment { workflow_id, content } => { + if let Some(ref tx) = *comment_tx.lock().await { + let _ = tx.send(AgentEvent::Comment { + workflow_id, + content, + }).await; + } + } + } + } + + Ok(()) +}