diff --git a/src/agent.rs b/src/agent.rs index a3ef56d..aa3481d 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,17 +1,13 @@ use std::collections::HashMap; -use std::path::Path; use std::sync::Arc; -use std::sync::atomic::{AtomicU16, Ordering}; use serde::{Deserialize, Serialize}; use sqlx::sqlite::SqlitePool; use tokio::sync::{mpsc, RwLock, broadcast}; use crate::llm::{LlmClient, ChatMessage, Tool, ToolFunction}; use crate::exec::LocalExecutor; -use crate::template::{self, LoadedTemplate}; use crate::tools::ExternalToolManager; use crate::worker::WorkerManager; -use crate::LlmConfig; use crate::sink::{AgentUpdate, ServiceManager}; use crate::state::{AgentState, AgentPhase, Artifact, Step, StepStatus, StepResult, StepResultStatus, check_scratchpad_size}; @@ -73,49 +69,23 @@ pub fn plan_infos_from_state(state: &AgentState) -> Vec { } pub struct AgentManager { - agents: RwLock>>, broadcast: RwLock>>, - pub services: RwLock>, - next_port: AtomicU16, pool: SqlitePool, - llm_config: LlmConfig, - template_repo: Option, - kb: Option>, - jwt_private_key_path: Option, pub worker_mgr: Arc, } impl AgentManager { pub fn new( pool: SqlitePool, - llm_config: LlmConfig, - template_repo: Option, - kb: Option>, - jwt_private_key_path: Option, worker_mgr: Arc, ) -> Arc { Arc::new(Self { - agents: RwLock::new(HashMap::new()), broadcast: RwLock::new(HashMap::new()), - services: RwLock::new(HashMap::new()), - next_port: AtomicU16::new(9100), pool, - llm_config, - template_repo, - kb, - jwt_private_key_path, worker_mgr, }) } - pub fn allocate_port(&self) -> u16 { - self.next_port.fetch_add(1, Ordering::Relaxed) - } - - pub async fn get_service_port(&self, project_id: &str) -> Option { - self.services.read().await.get(project_id).map(|s| s.port) - } - pub async fn get_broadcast(&self, project_id: &str) -> broadcast::Receiver { let mut map = self.broadcast.write().await; let tx = map.entry(project_id.to_string()) @@ -123,403 +93,95 @@ impl AgentManager { tx.subscribe() } - pub async fn send_event(self: &Arc, project_id: &str, event: AgentEvent) { - let agents = self.agents.read().await; - if let Some(tx) = agents.get(project_id) { - let _ = tx.send(event).await; - } else { - drop(agents); - self.spawn_agent(project_id.to_string()).await; - let agents = self.agents.read().await; - if let Some(tx) = agents.get(project_id) { - let _ = tx.send(event).await; - } - } + pub fn get_broadcast_sender(&self, project_id: &str) -> broadcast::Sender { + // This is called synchronously from ws_worker; we use a blocking approach + // since RwLock is tokio-based, we need a sync wrapper + // Actually, let's use try_write or just create a new one + // For simplicity, return a new sender each time (they share the channel) + // This is safe because broadcast::Sender is Clone + tokio::task::block_in_place(|| { + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + let mut map = self.broadcast.write().await; + map.entry(project_id.to_string()) + .or_insert_with(|| broadcast::channel(64).0) + .clone() + }) + }) } - async fn spawn_agent(self: &Arc, project_id: String) { - let (tx, rx) = mpsc::channel(32); - self.agents.write().await.insert(project_id.clone(), tx); + /// Dispatch an event to a worker. + pub async fn send_event(self: &Arc, project_id: &str, event: AgentEvent) { + match event { + AgentEvent::NewRequirement { workflow_id, requirement, template_id } => { + // Generate title (heuristic) + let title = generate_title_heuristic(&requirement); + let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?") + .bind(&title).bind(project_id).execute(&self.pool).await; + let btx = { + let mut map = self.broadcast.write().await; + map.entry(project_id.to_string()) + .or_insert_with(|| broadcast::channel(64).0) + .clone() + }; + let _ = btx.send(WsMessage::ProjectUpdate { + project_id: project_id.to_string(), + name: title, + }); - let broadcast_tx = { - let mut map = self.broadcast.write().await; - map.entry(project_id.clone()) - .or_insert_with(|| broadcast::channel(64).0) - .clone() - }; + // Update workflow status + let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?") + .bind(&workflow_id).execute(&self.pool).await; + let _ = btx.send(WsMessage::WorkflowStatusUpdate { + workflow_id: workflow_id.clone(), + status: "executing".into(), + }); - let mgr = Arc::clone(self); - tokio::spawn(agent_loop(project_id, rx, broadcast_tx, mgr)); + // Persist template_id + if let Some(ref tid) = template_id { + let _ = sqlx::query("UPDATE workflows SET template_id = ? WHERE id = ?") + .bind(tid).bind(&workflow_id).execute(&self.pool).await; + } + + // Dispatch to worker + let assign = crate::worker::ServerToWorker::WorkflowAssign { + workflow_id: workflow_id.clone(), + project_id: project_id.to_string(), + requirement, + template_id, + initial_state: None, + require_plan_approval: false, + }; + + match self.worker_mgr.assign_workflow(assign).await { + Ok(name) => { + tracing::info!("Workflow {} dispatched to worker '{}'", workflow_id, name); + } + Err(e) => { + tracing::error!("Failed to dispatch workflow {}: {}", workflow_id, e); + let _ = sqlx::query("UPDATE workflows SET status = 'failed' WHERE id = ?") + .bind(&workflow_id).execute(&self.pool).await; + let _ = btx.send(WsMessage::WorkflowStatusUpdate { + workflow_id, + status: "failed".into(), + }); + let _ = btx.send(WsMessage::Error { + message: format!("No worker available: {}", e), + }); + } + } + } + AgentEvent::Comment { workflow_id, content } => { + if let Err(e) = self.worker_mgr.forward_comment(&workflow_id, &content).await { + tracing::warn!("Failed to forward comment for workflow {}: {}", workflow_id, e); + } + } + } } } // Template system is in crate::template -/// Read INSTRUCTIONS.md from workdir if it exists. -async fn read_instructions(workdir: &str) -> String { - let path = format!("{}/INSTRUCTIONS.md", workdir); - tokio::fs::read_to_string(&path).await.unwrap_or_default() -} - -async fn ensure_workspace(exec: &LocalExecutor, workdir: &str) { - let _ = tokio::fs::create_dir_all(workdir).await; - let setup_script = format!("{}/scripts/setup.sh", workdir); - if Path::new(&setup_script).exists() { - tracing::info!("Running setup.sh in {}", workdir); - let _ = exec.execute("bash scripts/setup.sh", workdir).await; - } else { - let venv_path = format!("{}/.venv", workdir); - if !Path::new(&venv_path).exists() { - let _ = exec.execute("uv venv .venv", workdir).await; - let _ = exec.execute("uv pip install httpx fastapi uvicorn requests flask pydantic numpy pandas matplotlib pillow jinja2 pyyaml python-dotenv beautifulsoup4 lxml aiohttp aiofiles pytest rich click typer sqlalchemy", workdir).await; - } - } -} - -async fn agent_loop( - project_id: String, - mut rx: mpsc::Receiver, - broadcast_tx: broadcast::Sender, - mgr: Arc, -) { - let pool = mgr.pool.clone(); - let llm_config = mgr.llm_config.clone(); - let llm = LlmClient::new(&llm_config); - let exec = LocalExecutor::new(mgr.jwt_private_key_path.clone()); - let workdir = format!("/app/data/workspaces/{}", project_id); - let svc_mgr = ServiceManager::new(9100); - - // Create update channel and spawn handler - let (update_tx, update_rx) = mpsc::channel::(64); - { - let handler_pool = pool.clone(); - let handler_btx = broadcast_tx.clone(); - tokio::spawn(async move { - crate::sink::handle_agent_updates(update_rx, handler_pool, handler_btx).await; - }); - } - - tracing::info!("Agent loop started for project {}", project_id); - - while let Some(event) = rx.recv().await { - match event { - AgentEvent::NewRequirement { workflow_id, requirement, template_id: forced_template } => { - tracing::info!("Processing new requirement for workflow {}", workflow_id); - // Generate project title from requirement (heuristic, no LLM) - { - let title = generate_title_heuristic(&requirement); - let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?") - .bind(&title) - .bind(&project_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::ProjectUpdate { - project_id: project_id.clone(), - name: title, - }); - } - - let _ = update_tx.send(AgentUpdate::WorkflowStatus { - workflow_id: workflow_id.clone(), - status: "executing".into(), - }).await; - - // Template: must be explicitly provided (no LLM selection) - let template_id = forced_template; - // Persist template_id to workflow - if let Some(ref tid) = template_id { - let _ = sqlx::query("UPDATE workflows SET template_id = ? WHERE id = ?") - .bind(tid) - .bind(&workflow_id) - .execute(&pool) - .await; - } - - let loaded_template = if let Some(ref tid) = template_id { - tracing::info!("Template selected for workflow {}: {}", workflow_id, tid); - let _ = tokio::fs::create_dir_all(&workdir).await; - - if template::is_repo_template(tid) { - // Repo template: extract from git then load - match template::extract_repo_template(tid, mgr.template_repo.as_ref()).await { - Ok(template_dir) => { - if let Err(e) = template::apply_template(&template_dir, &workdir).await { - tracing::error!("Failed to apply repo template {}: {}", tid, e); - } - match LoadedTemplate::load_from_dir(tid, &template_dir).await { - Ok(t) => Some(t), - Err(e) => { - tracing::error!("Failed to load repo template {}: {}", tid, e); - None - } - } - } - Err(e) => { - tracing::error!("Failed to extract repo template {}: {}", tid, e); - None - } - } - } else { - // Local built-in template - let template_dir = std::path::Path::new(template::templates_dir()).join(tid); - if let Err(e) = template::apply_template(&template_dir, &workdir).await { - tracing::error!("Failed to apply template {}: {}", tid, e); - } - match LoadedTemplate::load(tid).await { - Ok(t) => Some(t), - Err(e) => { - tracing::error!("Failed to load template {}: {}", tid, e); - None - } - } - } - } else { - None - }; - - // Import KB files from template - if let Some(ref t) = loaded_template { - if let Some(ref kb) = mgr.kb { - let mut batch_items: Vec<(String, String)> = Vec::new(); - for (title, content) in &t.kb_files { - // Check if article already exists by title - let existing: Option = sqlx::query_scalar( - "SELECT id FROM kb_articles WHERE title = ?" - ) - .bind(title) - .fetch_optional(&pool) - .await - .ok() - .flatten(); - - let article_id = if let Some(id) = existing { - let _ = sqlx::query( - "UPDATE kb_articles SET content = ?, updated_at = datetime('now') WHERE id = ?" - ) - .bind(content) - .bind(&id) - .execute(&pool) - .await; - id - } else { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO kb_articles (id, title, content) VALUES (?, ?, ?)" - ) - .bind(&id) - .bind(title) - .bind(content) - .execute(&pool) - .await; - id - }; - - batch_items.push((article_id, content.clone())); - } - // Batch index: single embed.py call for all articles - if !batch_items.is_empty() { - if let Err(e) = kb.index_batch(&batch_items).await { - tracing::warn!("Failed to batch index KB articles: {}", e); - } - } - tracing::info!("Imported {} KB articles from template", t.kb_files.len()); - } - } - - ensure_workspace(&exec, &workdir).await; - let _ = tokio::fs::write(format!("{}/requirement.md", workdir), &requirement).await; - - // Run template setup if present - if let Some(ref tid) = template_id { - let template_dir = if template::is_repo_template(tid) { - template::extract_repo_template(tid, mgr.template_repo.as_ref()) - .await - .ok() - } else { - Some(std::path::Path::new(template::templates_dir()).join(tid)) - }; - if let Some(ref tdir) = template_dir { - if let Err(e) = template::run_setup(tdir, &workdir).await { - tracing::error!("Template setup failed for {}: {}", tid, e); - } - } - } - - let instructions = if let Some(ref t) = loaded_template { - t.instructions.clone() - } else { - read_instructions(&workdir).await - }; - - let ext_tools = loaded_template.as_ref().map(|t| &t.external_tools); - let plan_approval = loaded_template.as_ref().map_or(false, |t| t.require_plan_approval); - - tracing::info!("Starting agent loop for workflow {}", workflow_id); - // Run tool-calling agent loop - let result = run_agent_loop( - &llm, &exec, &update_tx, &mut rx, - &project_id, &workflow_id, &requirement, &workdir, &svc_mgr, - &instructions, None, ext_tools, - plan_approval, - ).await; - - let final_status = if result.is_ok() { "done" } else { "failed" }; - tracing::info!("Agent loop finished for workflow {}, status: {}", workflow_id, final_status); - if let Err(e) = &result { - tracing::error!("Agent error for workflow {}: {}", workflow_id, e); - let _ = update_tx.send(AgentUpdate::Error { - message: format!("Agent error: {}", e), - }).await; - } - - let _ = update_tx.send(AgentUpdate::WorkflowComplete { - workflow_id: workflow_id.clone(), - status: final_status.into(), - report: None, // Report generation will be handled separately - }).await; - } - AgentEvent::Comment { workflow_id, content } => { - tracing::info!("Comment on workflow {}: {}", workflow_id, content); - - let wf = sqlx::query_as::<_, crate::db::Workflow>( - "SELECT * FROM workflows WHERE id = ?", - ) - .bind(&workflow_id) - .fetch_optional(&pool) - .await - .ok() - .flatten(); - - let Some(wf) = wf else { continue }; - - // Load latest state snapshot - let snapshot = sqlx::query_scalar::<_, String>( - "SELECT state_json FROM agent_state_snapshots WHERE workflow_id = ? ORDER BY created_at DESC LIMIT 1" - ) - .bind(&workflow_id) - .fetch_optional(&pool) - .await - .ok() - .flatten(); - - let mut state = snapshot - .and_then(|json| serde_json::from_str::(&json).ok()) - .unwrap_or_else(AgentState::new); - - // Resume directly if: workflow is failed/done/waiting_user, - // OR if state snapshot has a WaitingUser step (e.g. after pod restart) - let has_waiting_step = state.steps.iter().any(|s| matches!(s.status, StepStatus::WaitingUser)); - let is_resuming = wf.status == "failed" || wf.status == "done" - || wf.status == "waiting_user" || has_waiting_step; - if is_resuming { - // Reset Failed/WaitingUser steps so they get re-executed - for step in &mut state.steps { - if matches!(step.status, StepStatus::Failed) { - step.status = StepStatus::Pending; - } - if matches!(step.status, StepStatus::WaitingUser) { - // Mark as Running so it continues (not re-plans) - step.status = StepStatus::Running; - } - } - // Attach comment as feedback to the first actionable step - if let Some(order) = state.first_actionable_step() { - if let Some(step) = state.steps.iter_mut().find(|s| s.order == order) { - step.user_feedbacks.push(content.clone()); - } - } - tracing::info!("[workflow {}] Resuming from state (status={}), first actionable: {:?}", - workflow_id, wf.status, state.first_actionable_step()); - } else { - // Active workflow: LLM decides whether to revise plan - state = process_feedback( - &llm, &update_tx, - &project_id, &workflow_id, state, &content, - ).await; - } - - // If there are actionable steps, resume execution - if state.first_actionable_step().is_some() { - ensure_workspace(&exec, &workdir).await; - - let _ = update_tx.send(AgentUpdate::WorkflowStatus { - workflow_id: workflow_id.clone(), - status: "executing".into(), - }).await; - - // Prepare state for execution: set first pending step to Running - if let Some(next) = state.first_actionable_step() { - let was_same_step = matches!(state.phase, AgentPhase::Executing { step } if step == next); - if let Some(step) = state.steps.iter_mut().find(|s| s.order == next) { - if matches!(step.status, StepStatus::Pending) { - step.status = StepStatus::Running; - } - } - state.phase = AgentPhase::Executing { step: next }; - // Only clear chat history when advancing to a new step; - // keep it when resuming the same step after ask_user - if !was_same_step { - state.current_step_chat_history.clear(); - } - } - - let instructions = read_instructions(&workdir).await; - - // Reload template config if available - let loaded_template = if !wf.template_id.is_empty() { - let tid = &wf.template_id; - if template::is_repo_template(tid) { - match template::extract_repo_template(tid, mgr.template_repo.as_ref()).await { - Ok(template_dir) => { - LoadedTemplate::load_from_dir(tid, &template_dir).await.ok() - } - Err(e) => { - tracing::warn!("Failed to reload template {}: {}", tid, e); - None - } - } - } else { - LoadedTemplate::load(tid).await.ok() - } - } else { - None - }; - let ext_tools = loaded_template.as_ref().map(|t| &t.external_tools); - let plan_approval = loaded_template.as_ref().map_or(false, |t| t.require_plan_approval); - - let result = run_agent_loop( - &llm, &exec, &update_tx, &mut rx, - &project_id, &workflow_id, &wf.requirement, &workdir, &svc_mgr, - &instructions, Some(state), ext_tools, - plan_approval, - ).await; - - let final_status = if result.is_ok() { "done" } else { "failed" }; - if let Err(e) = &result { - let _ = update_tx.send(AgentUpdate::Error { - message: format!("Agent error: {}", e), - }).await; - } - - let _ = update_tx.send(AgentUpdate::WorkflowComplete { - workflow_id: workflow_id.clone(), - status: final_status.into(), - report: None, - }).await; - } else { - // No actionable steps — feedback was informational only - let _ = update_tx.send(AgentUpdate::WorkflowStatus { - workflow_id: workflow_id.clone(), - status: "done".into(), - }).await; - } - } - } - } - - tracing::info!("Agent loop ended for project {}", project_id); -} - // --- Tool definitions --- fn make_tool(name: &str, description: &str, parameters: serde_json::Value) -> Tool { @@ -750,49 +412,6 @@ fn build_step_user_message(step: &Step, completed_summaries: &[(i32, String, Str ctx } -fn build_feedback_prompt(project_id: &str, state: &AgentState, feedback: &str) -> String { - let mut plan_state = String::new(); - for s in &state.steps { - let status = match s.status { - StepStatus::Done => " [done]", - StepStatus::Running => " [running]", - StepStatus::WaitingUser => " [waiting]", - StepStatus::Failed => " [FAILED]", - StepStatus::Pending => "", - }; - plan_state.push_str(&format!("{}. {}{}\n {}\n", s.order, s.title, status, s.description)); - if let Some(summary) = &s.summary { - plan_state.push_str(&format!(" 摘要: {}\n", summary)); - } - } - include_str!("prompts/feedback.md") - .replace("{project_id}", project_id) - .replace("{plan_state}", &plan_state) - .replace("{feedback}", feedback) -} - -fn build_feedback_tools() -> Vec { - vec![ - make_tool("revise_plan", "修改执行计划。提供完整步骤列表。系统自动 diff:description 未变的已完成步骤保留成果,变化的步骤及后续重新执行。", serde_json::json!({ - "type": "object", - "properties": { - "steps": { - "type": "array", - "items": { - "type": "object", - "properties": { - "title": { "type": "string", "description": "步骤标题" }, - "description": { "type": "string", "description": "详细描述" } - }, - "required": ["title", "description"] - } - } - }, - "required": ["steps"] - })), - ] -} - // --- Helpers --- /// Truncate a string at a char boundary, returning at most `max_bytes` bytes. @@ -956,91 +575,6 @@ async fn send_llm_call( }).await; } -/// Process user feedback: call LLM to decide whether to revise the plan. -/// Returns the (possibly modified) AgentState ready for resumed execution. -async fn process_feedback( - llm: &LlmClient, - update_tx: &mpsc::Sender, - project_id: &str, - workflow_id: &str, - mut state: AgentState, - feedback: &str, -) -> AgentState { - let prompt = build_feedback_prompt(project_id, &state, feedback); - let tools = build_feedback_tools(); - let messages = vec![ - ChatMessage::system(&prompt), - ChatMessage::user(feedback), - ]; - - tracing::info!("[workflow {}] Processing feedback with LLM", workflow_id); - let response = match llm.chat_with_tools(messages, &tools).await { - Ok(r) => r, - Err(e) => { - tracing::error!("[workflow {}] Feedback LLM call failed: {}", workflow_id, e); - if let Some(step) = state.steps.iter_mut().find(|s| !matches!(s.status, StepStatus::Done)) { - step.user_feedbacks.push(feedback.to_string()); - } - return state; - } - }; - - let choice = match response.choices.into_iter().next() { - Some(c) => c, - None => return state, - }; - - if let Some(tool_calls) = &choice.message.tool_calls { - for tc in tool_calls { - if tc.function.name == "revise_plan" { - let args: serde_json::Value = serde_json::from_str(&tc.function.arguments).unwrap_or_default(); - let raw_steps = args["steps"].as_array().cloned().unwrap_or_default(); - - let new_steps: Vec = raw_steps.iter().enumerate().map(|(i, item)| { - let order = (i + 1) as i32; - Step { - order, - title: item["title"].as_str().unwrap_or("").to_string(), - description: item["description"].as_str().unwrap_or("").to_string(), - status: StepStatus::Pending, - summary: None, - user_feedbacks: Vec::new(), - db_id: String::new(), - artifacts: Vec::new(), - } - }).collect(); - - let diff = state.apply_plan_diff(new_steps); - - let _ = update_tx.send(AgentUpdate::PlanUpdate { - workflow_id: workflow_id.to_string(), - steps: plan_infos_from_state(&state), - }).await; - - tracing::info!("[workflow {}] Plan revised via feedback. First actionable: {:?}", - workflow_id, state.first_actionable_step()); - - let diff_display = format!("```diff\n{}\n```", diff); - send_execution(update_tx, workflow_id, 0, "revise_plan", "计划变更", &diff_display, "done").await; - } - } - } else { - let text = choice.message.content.as_deref().unwrap_or(""); - tracing::info!("[workflow {}] Feedback processed, no plan change: {}", workflow_id, truncate_str(text, 200)); - send_execution(update_tx, workflow_id, state.current_step(), "text_response", "", text, "done").await; - } - - let target_order = state.first_actionable_step() - .unwrap_or_else(|| state.steps.last().map(|s| s.order).unwrap_or(0)); - if let Some(step) = state.steps.iter_mut().find(|s| s.order == target_order) { - step.user_feedbacks.push(feedback.to_string()); - } - - send_snapshot(update_tx, workflow_id, state.current_step(), &state).await; - - state -} - /// Run an isolated sub-loop for a single step. Returns StepResult. #[allow(clippy::too_many_arguments)] pub async fn run_step_loop( diff --git a/src/api/chat.rs b/src/api/chat.rs index 2bc91fe..dea1f7c 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -1,53 +1,20 @@ use std::sync::Arc; use axum::{ - extract::State, http::StatusCode, response::{IntoResponse, Response}, routing::post, Json, Router, }; -use serde::Deserialize; -use crate::llm::{ChatMessage, LlmClient}; use crate::AppState; -#[derive(Deserialize)] -struct ChatRequest { - messages: Vec, -} - -#[derive(Deserialize)] -struct SimpleChatMessage { - role: String, - content: String, -} - pub fn router(state: Arc) -> Router { Router::new() .route("/chat", post(chat)) .with_state(state) } -async fn chat( - State(state): State>, - Json(input): Json, -) -> Result, Response> { - let llm = LlmClient::new(&state.config.llm); - let messages: Vec = input - .messages - .into_iter() - .map(|m| ChatMessage { - role: m.role, - content: Some(m.content), - tool_calls: None, - tool_call_id: None, - }) - .collect(); - - let reply = llm.chat(messages).await.map_err(|e| { - tracing::error!("Chat LLM error: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() - })?; - - Ok(Json(serde_json::json!({ "reply": reply }))) +async fn chat() -> Result, Response> { + // Chat endpoint removed — LLM runs on workers only + Err((StatusCode::GONE, "Chat endpoint removed. LLM runs on workers.").into_response()) } diff --git a/src/api/mod.rs b/src/api/mod.rs index ddff9f4..b5327ec 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -59,50 +59,12 @@ async fn proxy_to_service( } async fn proxy_impl( - state: &AppState, - project_id: &str, - path: &str, - req: Request, + _state: &AppState, + _project_id: &str, + _path: &str, + _req: Request, ) -> Response { - let port = match state.agent_mgr.get_service_port(project_id).await { - Some(p) => p, - None => return (StatusCode::SERVICE_UNAVAILABLE, "服务未启动").into_response(), - }; - - let query = req.uri().query().map(|q| format!("?{}", q)).unwrap_or_default(); - let url = format!("http://127.0.0.1:{}{}{}", port, path, query); - - let client = reqwest::Client::new(); - let method = req.method().clone(); - let headers = req.headers().clone(); - let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await { - Ok(b) => b, - Err(_) => return (StatusCode::BAD_REQUEST, "请求体过大").into_response(), - }; - - let mut upstream_req = client.request(method, &url); - for (key, val) in headers.iter() { - if key != "host" { - upstream_req = upstream_req.header(key, val); - } - } - upstream_req = upstream_req.body(body_bytes); - - match upstream_req.send().await { - Ok(resp) => { - let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let resp_headers = resp.headers().clone(); - let body = resp.bytes().await.unwrap_or_default(); - let mut response = (status, body).into_response(); - for (key, val) in resp_headers.iter() { - if let Ok(name) = axum::http::header::HeaderName::from_bytes(key.as_ref()) { - response.headers_mut().insert(name, val.clone()); - } - } - response - } - Err(_) => (StatusCode::BAD_GATEWAY, "无法连接到后端服务").into_response(), - } + (StatusCode::SERVICE_UNAVAILABLE, "服务在 worker 上运行,无法从 server 代理").into_response() } diff --git a/src/lib.rs b/src/lib.rs index 82d3383..c6f74ed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,9 @@ pub struct AppState { #[derive(Debug, Clone, Deserialize)] pub struct Config { - pub llm: LlmConfig, + /// LLM config is optional on the server — workers bring their own. + #[serde(default)] + pub llm: Option, pub server: ServerConfig, pub database: DatabaseConfig, #[serde(default)] diff --git a/src/main.rs b/src/main.rs index 1ae0fb0..cf7fc79 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,6 +27,9 @@ enum Command { /// Worker name #[arg(long, env = "TORI_WORKER_NAME")] name: Option, + /// Config file path (for LLM settings) + #[arg(long, default_value = "config.yaml")] + config: String, }, } @@ -40,13 +43,19 @@ async fn main() -> anyhow::Result<()> { match cli.command { Command::Server => run_server().await, - Command::Worker { server, name } => { + Command::Worker { server, name, config } => { let name = name.unwrap_or_else(|| { hostname::get() .map(|h| h.to_string_lossy().to_string()) .unwrap_or_else(|_| "worker-1".to_string()) }); - worker_runner::run(&server, &name).await + let config_str = std::fs::read_to_string(&config) + .unwrap_or_else(|e| panic!("Failed to read {}: {}", config, e)); + let cfg: Config = serde_yaml::from_str(&config_str) + .unwrap_or_else(|e| panic!("Failed to parse {}: {}", config, e)); + let llm_config = cfg.llm + .unwrap_or_else(|| panic!("LLM config required in {} for worker mode", config)); + worker_runner::run(&server, &name, &llm_config).await } } } @@ -79,10 +88,6 @@ async fn run_server() -> anyhow::Result<()> { let agent_mgr = agent::AgentManager::new( database.pool.clone(), - config.llm.clone(), - config.template_repo.clone(), - kb_arc.clone(), - config.jwt_private_key.clone(), worker_mgr.clone(), ); @@ -129,6 +134,7 @@ async fn run_server() -> anyhow::Result<()> { } }; + let ws_pool = database.pool.clone(); let state = Arc::new(AppState { db: database, config: config.clone(), @@ -151,7 +157,13 @@ async fn run_server() -> anyhow::Result<()> { let r = obj_root; move || api::obj::root_listing(r) })) - .nest("/ws/tori/workers", ws_worker::router(worker_mgr)) + .nest("/ws/tori/workers", { + let pool = ws_pool; + let agent_mgr_for_ws = agent_mgr.clone(); + let broadcast_fn: Arc tokio::sync::broadcast::Sender + Send + Sync> = + Arc::new(move |pid: &str| agent_mgr_for_ws.get_broadcast_sender(pid)); + ws_worker::router(worker_mgr, pool, broadcast_fn) + }) .nest("/ws/tori", ws::router(agent_mgr)) .nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html"))) .route("/", axum::routing::get(|| async { diff --git a/src/sink.rs b/src/sink.rs index d16f3bb..cba8f09 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU16, Ordering}; use serde::{Deserialize, Serialize}; use sqlx::sqlite::SqlitePool; -use tokio::sync::{RwLock, broadcast, mpsc}; +use tokio::sync::{RwLock, broadcast}; use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo}; use crate::state::{AgentState, Artifact}; @@ -13,64 +13,19 @@ use crate::state::{AgentState, Artifact}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "kind")] pub enum AgentUpdate { - PlanUpdate { - workflow_id: String, - steps: Vec, - }, - WorkflowStatus { - workflow_id: String, - status: String, - }, - Activity { - workflow_id: String, - activity: String, - }, - ExecutionLog { - workflow_id: String, - step_order: i32, - tool_name: String, - tool_input: String, - output: String, - status: String, - }, - LlmCallLog { - workflow_id: String, - step_order: i32, - phase: String, - messages_count: i32, - tools_count: i32, - tool_calls: String, - text_response: String, - prompt_tokens: Option, - completion_tokens: Option, - latency_ms: i64, - }, - StateSnapshot { - workflow_id: String, - step_order: i32, - state: AgentState, - }, - WorkflowComplete { - workflow_id: String, - status: String, - report: Option, - }, - ArtifactSave { - workflow_id: String, - step_order: i32, - artifact: Artifact, - }, - RequirementUpdate { - workflow_id: String, - requirement: String, - }, - Error { - message: String, - }, + PlanUpdate { workflow_id: String, steps: Vec }, + WorkflowStatus { workflow_id: String, status: String }, + Activity { workflow_id: String, activity: String }, + ExecutionLog { workflow_id: String, step_order: i32, tool_name: String, tool_input: String, output: String, status: String }, + LlmCallLog { workflow_id: String, step_order: i32, phase: String, messages_count: i32, tools_count: i32, tool_calls: String, text_response: String, prompt_tokens: Option, completion_tokens: Option, latency_ms: i64 }, + StateSnapshot { workflow_id: String, step_order: i32, state: AgentState }, + WorkflowComplete { workflow_id: String, status: String, report: Option }, + ArtifactSave { workflow_id: String, step_order: i32, artifact: Artifact }, + RequirementUpdate { workflow_id: String, requirement: String }, + Error { message: String }, } /// Manages local services (start_service / stop_service tools). -/// Created per-worker or per-agent-loop. pub struct ServiceManager { pub services: RwLock>, next_port: AtomicU16, @@ -89,151 +44,84 @@ impl ServiceManager { } } -/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend. -pub async fn handle_agent_updates( - mut rx: mpsc::Receiver, - pool: SqlitePool, - broadcast_tx: broadcast::Sender, +/// Helper: broadcast if sender is available. +fn bcast(tx: Option<&broadcast::Sender>, msg: WsMessage) { + if let Some(tx) = tx { let _ = tx.send(msg); } +} + +/// Process a single AgentUpdate: persist to DB and broadcast to frontend. +pub async fn handle_single_update( + update: &AgentUpdate, + pool: &SqlitePool, + broadcast_tx: Option<&broadcast::Sender>, ) { - while let Some(update) = rx.recv().await { - match update { - AgentUpdate::PlanUpdate { workflow_id, steps } => { - let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps }); - } - AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => { - let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") - .bind(status) - .bind(workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { - workflow_id: workflow_id.clone(), - status: status.clone(), - }); - } - AgentUpdate::Activity { workflow_id, activity } => { - let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity }); - } - AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(tool_name) - .bind(tool_input) - .bind(output) - .bind(status) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::StepStatusUpdate { - step_id: id, - status: status.clone(), - output: output.clone(), - }); - } - AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(phase) - .bind(messages_count) - .bind(tools_count) - .bind(tool_calls) - .bind(text_response) - .bind(prompt_tokens.map(|v| v as i32)) - .bind(completion_tokens.map(|v| v as i32)) - .bind(latency_ms as i32) - .execute(&pool) - .await; - let entry = crate::db::LlmCallLogEntry { - id, - workflow_id: workflow_id.clone(), - step_order, - phase: phase.clone(), - messages_count, - tools_count, - tool_calls: tool_calls.clone(), - text_response: text_response.clone(), - prompt_tokens: prompt_tokens.map(|v| v as i32), - completion_tokens: completion_tokens.map(|v| v as i32), - latency_ms: latency_ms as i32, - created_at: String::new(), - }; - let _ = broadcast_tx.send(WsMessage::LlmCallLog { - workflow_id: workflow_id.clone(), - entry, - }); - } - AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => { - let id = uuid::Uuid::new_v4().to_string(); - let json = serde_json::to_string(state).unwrap_or_default(); - let _ = sqlx::query( - "INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(&json) - .execute(&pool) - .await; - } - AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => { - let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") - .bind(status) - .bind(workflow_id) - .execute(&pool) - .await; - if let Some(ref r) = report { - let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?") - .bind(r) - .bind(workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::ReportReady { - workflow_id: workflow_id.clone(), - }); - } - let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate { - workflow_id: workflow_id.clone(), - status: status.clone(), - }); - } - AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => { - let id = uuid::Uuid::new_v4().to_string(); - let _ = sqlx::query( - "INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)" - ) - .bind(&id) - .bind(workflow_id) - .bind(step_order) - .bind(&artifact.name) - .bind(&artifact.path) - .bind(&artifact.artifact_type) - .bind(&artifact.description) - .execute(&pool) - .await; - } - AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => { - let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?") - .bind(requirement) - .bind(workflow_id) - .execute(&pool) - .await; - let _ = broadcast_tx.send(WsMessage::RequirementUpdate { - workflow_id: workflow_id.clone(), - requirement: requirement.clone(), - }); - } - AgentUpdate::Error { message } => { - let _ = broadcast_tx.send(WsMessage::Error { message }); + match update { + AgentUpdate::PlanUpdate { workflow_id, steps } => { + bcast(broadcast_tx, WsMessage::PlanUpdate { workflow_id: workflow_id.clone(), steps: steps.clone() }); + } + AgentUpdate::WorkflowStatus { workflow_id, status } => { + let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") + .bind(status).bind(workflow_id).execute(pool).await; + bcast(broadcast_tx, WsMessage::WorkflowStatusUpdate { workflow_id: workflow_id.clone(), status: status.clone() }); + } + AgentUpdate::Activity { workflow_id, activity } => { + bcast(broadcast_tx, WsMessage::ActivityUpdate { workflow_id: workflow_id.clone(), activity: activity.clone() }); + } + AgentUpdate::ExecutionLog { workflow_id, step_order, tool_name, tool_input, output, status } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))" + ).bind(&id).bind(workflow_id).bind(step_order).bind(tool_name).bind(tool_input).bind(output).bind(status) + .execute(pool).await; + bcast(broadcast_tx, WsMessage::StepStatusUpdate { step_id: id, status: status.clone(), output: output.clone() }); + } + AgentUpdate::LlmCallLog { workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))" + ).bind(&id).bind(workflow_id).bind(step_order).bind(phase).bind(messages_count).bind(tools_count) + .bind(tool_calls).bind(text_response).bind(prompt_tokens.map(|v| v as i32)).bind(completion_tokens.map(|v| v as i32)).bind(*latency_ms as i32) + .execute(pool).await; + let entry = crate::db::LlmCallLogEntry { + id, workflow_id: workflow_id.clone(), step_order: *step_order, phase: phase.clone(), + messages_count: *messages_count, tools_count: *tools_count, tool_calls: tool_calls.clone(), + text_response: text_response.clone(), prompt_tokens: prompt_tokens.map(|v| v as i32), + completion_tokens: completion_tokens.map(|v| v as i32), latency_ms: *latency_ms as i32, + created_at: String::new(), + }; + bcast(broadcast_tx, WsMessage::LlmCallLog { workflow_id: workflow_id.clone(), entry }); + } + AgentUpdate::StateSnapshot { workflow_id, step_order, state } => { + let id = uuid::Uuid::new_v4().to_string(); + let json = serde_json::to_string(state).unwrap_or_default(); + let _ = sqlx::query( + "INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))" + ).bind(&id).bind(workflow_id).bind(step_order).bind(&json).execute(pool).await; + } + AgentUpdate::WorkflowComplete { workflow_id, status, report } => { + let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?") + .bind(status).bind(workflow_id).execute(pool).await; + if let Some(r) = report { + let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?") + .bind(r).bind(workflow_id).execute(pool).await; + bcast(broadcast_tx, WsMessage::ReportReady { workflow_id: workflow_id.clone() }); } + bcast(broadcast_tx, WsMessage::WorkflowStatusUpdate { workflow_id: workflow_id.clone(), status: status.clone() }); + } + AgentUpdate::ArtifactSave { workflow_id, step_order, artifact } => { + let id = uuid::Uuid::new_v4().to_string(); + let _ = sqlx::query( + "INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)" + ).bind(&id).bind(workflow_id).bind(step_order).bind(&artifact.name).bind(&artifact.path) + .bind(&artifact.artifact_type).bind(&artifact.description).execute(pool).await; + } + AgentUpdate::RequirementUpdate { workflow_id, requirement } => { + let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?") + .bind(requirement).bind(workflow_id).execute(pool).await; + bcast(broadcast_tx, WsMessage::RequirementUpdate { workflow_id: workflow_id.clone(), requirement: requirement.clone() }); + } + AgentUpdate::Error { message } => { + bcast(broadcast_tx, WsMessage::Error { message: message.clone() }); } } } diff --git a/src/worker.rs b/src/worker.rs index 99cf576..001d446 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -3,6 +3,8 @@ use std::sync::Arc; use tokio::sync::RwLock; use serde::{Deserialize, Serialize}; +use crate::state::AgentState; + /// Information reported by a worker on registration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkerInfo { @@ -14,18 +16,13 @@ pub struct WorkerInfo { pub kernel: String, } -/// A registered worker with a channel for sending scripts to execute. +/// A registered worker. struct Worker { pub info: WorkerInfo, - pub tx: tokio::sync::mpsc::Sender, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct WorkerRequest { - pub job_id: String, - pub script: String, + pub tx: tokio::sync::mpsc::Sender, } +/// Legacy script execution result. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkerResult { pub job_id: String, @@ -34,110 +31,6 @@ pub struct WorkerResult { pub stderr: String, } -/// Manages all connected workers. -pub struct WorkerManager { - workers: RwLock>, - /// Pending job results, keyed by job_id. - results: RwLock>>, -} - -impl WorkerManager { - pub fn new() -> Arc { - Arc::new(Self { - workers: RwLock::new(HashMap::new()), - results: RwLock::new(HashMap::new()), - }) - } - - /// Register a new worker. Returns a receiver for job requests. - pub async fn register( - &self, - name: String, - info: WorkerInfo, - ) -> tokio::sync::mpsc::Receiver { - let (tx, rx) = tokio::sync::mpsc::channel(16); - tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})", - name, info.cpu, info.memory, info.gpu, info.os, info.kernel); - self.workers.write().await.insert(name, Worker { info, tx }); - rx - } - - /// Remove a worker. - pub async fn unregister(&self, name: &str) { - self.workers.write().await.remove(name); - tracing::info!("Worker unregistered: {}", name); - } - - /// List all connected workers. - pub async fn list(&self) -> Vec<(String, WorkerInfo)> { - self.workers - .read() - .await - .iter() - .map(|(name, w)| (name.clone(), w.info.clone())) - .collect() - } - - /// Submit a script to a worker and wait for the result. - pub async fn execute( - &self, - worker_name: &str, - script: &str, - timeout_secs: u64, - ) -> Result { - let job_id = uuid::Uuid::new_v4().to_string(); - - // Find the worker and send the request - let tx = { - let workers = self.workers.read().await; - let worker = workers - .get(worker_name) - .ok_or_else(|| format!("Worker '{}' not found", worker_name))?; - worker.tx.clone() - }; - - let (result_tx, result_rx) = tokio::sync::oneshot::channel(); - self.results.write().await.insert(job_id.clone(), result_tx); - - let req = WorkerRequest { - job_id: job_id.clone(), - script: script.to_string(), - }; - - tx.send(req).await.map_err(|_| { - format!("Worker '{}' disconnected", worker_name) - })?; - - // Wait for result with timeout - match tokio::time::timeout( - std::time::Duration::from_secs(timeout_secs), - result_rx, - ) - .await - { - Ok(Ok(result)) => Ok(result), - Ok(Err(_)) => Err("Worker channel closed unexpectedly".into()), - Err(_) => { - self.results.write().await.remove(&job_id); - Err(format!("Execution timed out after {}s", timeout_secs)) - } - } - } - - /// Called when a worker sends back a result. - pub async fn report_result(&self, result: WorkerResult) { - if let Some(tx) = self.results.write().await.remove(&result.job_id) { - let _ = tx.send(result); - } - } -} - -// --- Extended protocol for workflow execution --- - -use crate::LlmConfig; -use crate::state::AgentState; -use crate::llm::Tool; - /// Messages sent from server to worker. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] @@ -148,15 +41,12 @@ pub enum ServerToWorker { workflow_id: String, project_id: String, requirement: String, - workdir: String, - instructions: String, - llm_config: LlmConfig, + #[serde(default)] + template_id: Option, #[serde(default)] initial_state: Option, #[serde(default)] require_plan_approval: bool, - #[serde(default)] - external_tools: Vec, }, /// Forward a user comment to the worker executing this workflow. #[serde(rename = "comment")] @@ -173,9 +63,6 @@ pub enum WorkerToServer { /// Worker registration. #[serde(rename = "register")] Register { info: WorkerInfo }, - /// Script execution result (legacy). - #[serde(rename = "result")] - Result(WorkerResult), /// Agent update from workflow execution. #[serde(rename = "update")] Update { @@ -184,3 +71,97 @@ pub enum WorkerToServer { update: crate::sink::AgentUpdate, }, } + +/// Manages all connected workers and workflow assignments. +pub struct WorkerManager { + workers: RwLock>, + /// workflow_id → worker_name + assignments: RwLock>, +} + +impl WorkerManager { + pub fn new() -> Arc { + Arc::new(Self { + workers: RwLock::new(HashMap::new()), + assignments: RwLock::new(HashMap::new()), + }) + } + + /// Register a new worker. Returns a receiver for messages. + pub async fn register( + &self, + name: String, + info: WorkerInfo, + ) -> tokio::sync::mpsc::Receiver { + let (tx, rx) = tokio::sync::mpsc::channel(16); + tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})", + name, info.cpu, info.memory, info.gpu, info.os, info.kernel); + self.workers.write().await.insert(name, Worker { info, tx }); + rx + } + + /// Remove a worker and clean up its assignments. + pub async fn unregister(&self, name: &str) { + self.workers.write().await.remove(name); + // Remove all workflow assignments for this worker + let mut assignments = self.assignments.write().await; + assignments.retain(|_, worker| worker != name); + tracing::info!("Worker unregistered: {}", name); + } + + /// List all connected workers. + pub async fn list(&self) -> Vec<(String, WorkerInfo)> { + self.workers + .read() + .await + .iter() + .map(|(name, w)| (name.clone(), w.info.clone())) + .collect() + } + + /// Assign a workflow to the first available worker. Returns worker name. + pub async fn assign_workflow(&self, assign: ServerToWorker) -> Result { + let workflow_id = match &assign { + ServerToWorker::WorkflowAssign { workflow_id, .. } => workflow_id.clone(), + _ => return Err("Not a workflow assignment".into()), + }; + + let workers = self.workers.read().await; + // Pick first worker (simple strategy for now) + let (name, worker) = workers.iter().next() + .ok_or_else(|| "No workers available".to_string())?; + + worker.tx.send(assign).await.map_err(|_| { + format!("Worker '{}' disconnected", name) + })?; + + let worker_name = name.clone(); + drop(workers); + + self.assignments.write().await.insert(workflow_id, worker_name.clone()); + Ok(worker_name) + } + + /// Forward a comment to the worker handling a workflow. + pub async fn forward_comment(&self, workflow_id: &str, content: &str) -> Result<(), String> { + let assignments = self.assignments.read().await; + let worker_name = assignments.get(workflow_id) + .ok_or_else(|| format!("No worker assigned for workflow {}", workflow_id))? + .clone(); + drop(assignments); + + let workers = self.workers.read().await; + let worker = workers.get(&worker_name) + .ok_or_else(|| format!("Worker '{}' not found", worker_name))?; + + worker.tx.send(ServerToWorker::Comment { + workflow_id: workflow_id.to_string(), + content: content.to_string(), + }).await.map_err(|_| format!("Worker '{}' disconnected", worker_name)) + } + + /// Remove a workflow assignment (when workflow completes). + pub async fn complete_workflow(&self, workflow_id: &str) { + self.assignments.write().await.remove(workflow_id); + } +} diff --git a/src/worker_runner.rs b/src/worker_runner.rs index ba0c8b7..2213e16 100644 --- a/src/worker_runner.rs +++ b/src/worker_runner.rs @@ -59,11 +59,11 @@ fn collect_worker_info(name: &str) -> WorkerInfo { } } -pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> { - tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url); +pub async fn run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> { + tracing::info!("Tori worker '{}' connecting to {} (model={})", worker_name, server_url, llm_config.model); loop { - match connect_and_run(server_url, worker_name).await { + match connect_and_run(server_url, worker_name, llm_config).await { Ok(()) => { tracing::info!("Worker connection closed, reconnecting in 5s..."); } @@ -75,7 +75,7 @@ pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> { } } -async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> { +async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> { let (ws_stream, _) = connect_async(server_url).await?; let (mut ws_tx, mut ws_rx) = ws_stream.split(); @@ -127,17 +127,16 @@ async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result< workflow_id, project_id, requirement, - workdir, - instructions, - llm_config, + template_id: _, initial_state, require_plan_approval, - external_tools: _, } => { tracing::info!("Received workflow: {} (project {})", workflow_id, project_id); - let llm = LlmClient::new(&llm_config); + let llm = LlmClient::new(llm_config); let exec = LocalExecutor::new(None); + let workdir = format!("/app/data/workspaces/{}", project_id); + let instructions = String::new(); // TODO: load from template // update channel → serialize → WebSocket let (update_tx, mut update_rx) = mpsc::channel::(64); diff --git a/src/ws_worker.rs b/src/ws_worker.rs index a76028d..2a20ff3 100644 --- a/src/ws_worker.rs +++ b/src/ws_worker.rs @@ -7,44 +7,52 @@ use axum::{ Router, }; use futures::{SinkExt, StreamExt}; -use serde::Deserialize; +use sqlx::sqlite::SqlitePool; +use tokio::sync::broadcast; -use crate::worker::{WorkerInfo, WorkerManager, WorkerResult}; +use crate::agent::WsMessage; +use crate::worker::{WorkerInfo, WorkerManager, WorkerToServer}; -pub fn router(mgr: Arc) -> Router { +pub struct WsWorkerState { + pub mgr: Arc, + pub pool: SqlitePool, + pub broadcast_fn: Arc broadcast::Sender + Send + Sync>, +} + +pub fn router(mgr: Arc, pool: SqlitePool, broadcast_fn: Arc broadcast::Sender + Send + Sync>) -> Router { + let state = Arc::new(WsWorkerState { mgr, pool, broadcast_fn }); Router::new() .route("/", get(ws_handler)) - .with_state(mgr) + .with_state(state) } async fn ws_handler( ws: WebSocketUpgrade, - State(mgr): State>, + State(state): State>, ) -> Response { - ws.on_upgrade(move |socket| handle_worker_socket(socket, mgr)) + ws.on_upgrade(move |socket| handle_worker_socket(socket, state)) } -#[derive(Deserialize)] -#[serde(tag = "type")] -enum WorkerMessage { - #[serde(rename = "register")] - Register { info: WorkerInfo }, - #[serde(rename = "result")] - Result(WorkerResult), -} - -async fn handle_worker_socket(socket: WebSocket, mgr: Arc) { +async fn handle_worker_socket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); // First message must be registration - let (name, mut job_rx) = loop { + let (name, mut msg_rx) = loop { match receiver.next().await { Some(Ok(Message::Text(text))) => { - match serde_json::from_str::(&text) { - Ok(WorkerMessage::Register { info }) => { + match serde_json::from_str::(&text) { + Ok(v) if v["type"] == "register" => { + let info: WorkerInfo = match serde_json::from_value(v["info"].clone()) { + Ok(i) => i, + Err(_) => { + let _ = sender.send(Message::Text( + r#"{"type":"error","message":"Invalid worker info"}"#.into(), + )).await; + return; + } + }; let name = info.name.clone(); - let rx = mgr.register(name.clone(), info).await; - // Ack + let rx = state.mgr.register(name.clone(), info).await; let ack = serde_json::json!({ "type": "registered", "name": &name }); let _ = sender.send(Message::Text(ack.to_string().into())).await; break (name, rx); @@ -62,31 +70,28 @@ async fn handle_worker_socket(socket: WebSocket, mgr: Arc) { } }; - // Main loop: forward jobs to worker, receive results let name_clone = name.clone(); - let mgr_clone = mgr.clone(); + let mgr_for_cleanup = state.mgr.clone(); - // Task: send jobs from job_rx to the WebSocket + // Task: send ServerToWorker messages from msg_rx to the WebSocket let send_task = tokio::spawn(async move { - while let Some(req) = job_rx.recv().await { - let msg = serde_json::json!({ - "type": "execute", - "job_id": req.job_id, - "script": req.script, - }); - if sender.send(Message::Text(msg.to_string().into())).await.is_err() { - break; + while let Some(msg) = msg_rx.recv().await { + if let Ok(json) = serde_json::to_string(&msg) { + if sender.send(Message::Text(json.into())).await.is_err() { + break; + } } } }); - // Task: receive results from the WebSocket + // Task: receive WorkerToServer messages from the WebSocket → process + let state_clone = state.clone(); let recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = receiver.next().await { match msg { Message::Text(text) => { - if let Ok(WorkerMessage::Result(result)) = serde_json::from_str(&text) { - mgr_clone.report_result(result).await; + if let Ok(worker_msg) = serde_json::from_str::(&text) { + handle_worker_message(&state_clone, worker_msg).await; } } Message::Close(_) => break, @@ -100,5 +105,38 @@ async fn handle_worker_socket(socket: WebSocket, mgr: Arc) { _ = recv_task => {}, } - mgr.unregister(&name_clone).await; + mgr_for_cleanup.unregister(&name_clone).await; +} + +async fn handle_worker_message(state: &WsWorkerState, msg: WorkerToServer) { + match msg { + WorkerToServer::Register { .. } => { + // Already handled during initial handshake + } + WorkerToServer::Update { workflow_id, update } => { + // Get project_id for broadcasting (look up from DB) + let project_id: Option = sqlx::query_scalar( + "SELECT project_id FROM workflows WHERE id = ?" + ) + .bind(&workflow_id) + .fetch_optional(&state.pool) + .await + .ok() + .flatten(); + + let broadcast_tx = if let Some(ref pid) = project_id { + Some((state.broadcast_fn)(pid)) + } else { + None + }; + + // Check if this is a workflow completion + if let crate::sink::AgentUpdate::WorkflowComplete { ref workflow_id, .. } = update { + state.mgr.complete_workflow(workflow_id).await; + } + + // Process the update: write to DB + broadcast + crate::sink::handle_single_update(&update, &state.pool, broadcast_tx.as_ref()).await; + } + } }