refactor: server no longer runs agent loop or LLM
- Remove agent_loop from server (was ~400 lines) — server dispatches to workers
- AgentManager simplified to pure dispatcher (send_event → worker)
- Remove LLM config requirement from server (workers bring their own via config.yaml)
- Remove process_feedback, build_feedback_tools from server
- Remove chat API endpoint (LLM on workers only)
- Remove service proxy (services run on workers)
- Worker reads LLM config from its own config.yaml
- ws_worker.rs handles WorkerToServer::Update messages (DB + broadcast)
- Verified locally: tori server + tori worker connect and register
This commit is contained in:
626
src/agent.rs
626
src/agent.rs
@@ -1,17 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use tokio::sync::{mpsc, RwLock, broadcast};
|
||||
|
||||
use crate::llm::{LlmClient, ChatMessage, Tool, ToolFunction};
|
||||
use crate::exec::LocalExecutor;
|
||||
use crate::template::{self, LoadedTemplate};
|
||||
use crate::tools::ExternalToolManager;
|
||||
use crate::worker::WorkerManager;
|
||||
use crate::LlmConfig;
|
||||
use crate::sink::{AgentUpdate, ServiceManager};
|
||||
|
||||
use crate::state::{AgentState, AgentPhase, Artifact, Step, StepStatus, StepResult, StepResultStatus, check_scratchpad_size};
|
||||
@@ -73,49 +69,23 @@ pub fn plan_infos_from_state(state: &AgentState) -> Vec<PlanStepInfo> {
|
||||
}
|
||||
|
||||
pub struct AgentManager {
|
||||
agents: RwLock<HashMap<String, mpsc::Sender<AgentEvent>>>,
|
||||
broadcast: RwLock<HashMap<String, broadcast::Sender<WsMessage>>>,
|
||||
pub services: RwLock<HashMap<String, ServiceInfo>>,
|
||||
next_port: AtomicU16,
|
||||
pool: SqlitePool,
|
||||
llm_config: LlmConfig,
|
||||
template_repo: Option<crate::TemplateRepoConfig>,
|
||||
kb: Option<Arc<crate::kb::KbManager>>,
|
||||
jwt_private_key_path: Option<String>,
|
||||
pub worker_mgr: Arc<WorkerManager>,
|
||||
}
|
||||
|
||||
impl AgentManager {
|
||||
pub fn new(
|
||||
pool: SqlitePool,
|
||||
llm_config: LlmConfig,
|
||||
template_repo: Option<crate::TemplateRepoConfig>,
|
||||
kb: Option<Arc<crate::kb::KbManager>>,
|
||||
jwt_private_key_path: Option<String>,
|
||||
worker_mgr: Arc<WorkerManager>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
agents: RwLock::new(HashMap::new()),
|
||||
broadcast: RwLock::new(HashMap::new()),
|
||||
services: RwLock::new(HashMap::new()),
|
||||
next_port: AtomicU16::new(9100),
|
||||
pool,
|
||||
llm_config,
|
||||
template_repo,
|
||||
kb,
|
||||
jwt_private_key_path,
|
||||
worker_mgr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn allocate_port(&self) -> u16 {
|
||||
self.next_port.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub async fn get_service_port(&self, project_id: &str) -> Option<u16> {
|
||||
self.services.read().await.get(project_id).map(|s| s.port)
|
||||
}
|
||||
|
||||
pub async fn get_broadcast(&self, project_id: &str) -> broadcast::Receiver<WsMessage> {
|
||||
let mut map = self.broadcast.write().await;
|
||||
let tx = map.entry(project_id.to_string())
|
||||
@@ -123,403 +93,95 @@ impl AgentManager {
|
||||
tx.subscribe()
|
||||
}
|
||||
|
||||
pub async fn send_event(self: &Arc<Self>, project_id: &str, event: AgentEvent) {
|
||||
let agents = self.agents.read().await;
|
||||
if let Some(tx) = agents.get(project_id) {
|
||||
let _ = tx.send(event).await;
|
||||
} else {
|
||||
drop(agents);
|
||||
self.spawn_agent(project_id.to_string()).await;
|
||||
let agents = self.agents.read().await;
|
||||
if let Some(tx) = agents.get(project_id) {
|
||||
let _ = tx.send(event).await;
|
||||
}
|
||||
}
|
||||
pub fn get_broadcast_sender(&self, project_id: &str) -> broadcast::Sender<WsMessage> {
|
||||
// This is called synchronously from ws_worker; we use a blocking approach
|
||||
// since RwLock is tokio-based, we need a sync wrapper
|
||||
// Actually, let's use try_write or just create a new one
|
||||
// For simplicity, return a new sender each time (they share the channel)
|
||||
// This is safe because broadcast::Sender is Clone
|
||||
tokio::task::block_in_place(|| {
|
||||
let rt = tokio::runtime::Handle::current();
|
||||
rt.block_on(async {
|
||||
let mut map = self.broadcast.write().await;
|
||||
map.entry(project_id.to_string())
|
||||
.or_insert_with(|| broadcast::channel(64).0)
|
||||
.clone()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn spawn_agent(self: &Arc<Self>, project_id: String) {
|
||||
let (tx, rx) = mpsc::channel(32);
|
||||
self.agents.write().await.insert(project_id.clone(), tx);
|
||||
/// Dispatch an event to a worker.
|
||||
pub async fn send_event(self: &Arc<Self>, project_id: &str, event: AgentEvent) {
|
||||
match event {
|
||||
AgentEvent::NewRequirement { workflow_id, requirement, template_id } => {
|
||||
// Generate title (heuristic)
|
||||
let title = generate_title_heuristic(&requirement);
|
||||
let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?")
|
||||
.bind(&title).bind(project_id).execute(&self.pool).await;
|
||||
let btx = {
|
||||
let mut map = self.broadcast.write().await;
|
||||
map.entry(project_id.to_string())
|
||||
.or_insert_with(|| broadcast::channel(64).0)
|
||||
.clone()
|
||||
};
|
||||
let _ = btx.send(WsMessage::ProjectUpdate {
|
||||
project_id: project_id.to_string(),
|
||||
name: title,
|
||||
});
|
||||
|
||||
let broadcast_tx = {
|
||||
let mut map = self.broadcast.write().await;
|
||||
map.entry(project_id.clone())
|
||||
.or_insert_with(|| broadcast::channel(64).0)
|
||||
.clone()
|
||||
};
|
||||
// Update workflow status
|
||||
let _ = sqlx::query("UPDATE workflows SET status = 'executing' WHERE id = ?")
|
||||
.bind(&workflow_id).execute(&self.pool).await;
|
||||
let _ = btx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: "executing".into(),
|
||||
});
|
||||
|
||||
let mgr = Arc::clone(self);
|
||||
tokio::spawn(agent_loop(project_id, rx, broadcast_tx, mgr));
|
||||
// Persist template_id
|
||||
if let Some(ref tid) = template_id {
|
||||
let _ = sqlx::query("UPDATE workflows SET template_id = ? WHERE id = ?")
|
||||
.bind(tid).bind(&workflow_id).execute(&self.pool).await;
|
||||
}
|
||||
|
||||
// Dispatch to worker
|
||||
let assign = crate::worker::ServerToWorker::WorkflowAssign {
|
||||
workflow_id: workflow_id.clone(),
|
||||
project_id: project_id.to_string(),
|
||||
requirement,
|
||||
template_id,
|
||||
initial_state: None,
|
||||
require_plan_approval: false,
|
||||
};
|
||||
|
||||
match self.worker_mgr.assign_workflow(assign).await {
|
||||
Ok(name) => {
|
||||
tracing::info!("Workflow {} dispatched to worker '{}'", workflow_id, name);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to dispatch workflow {}: {}", workflow_id, e);
|
||||
let _ = sqlx::query("UPDATE workflows SET status = 'failed' WHERE id = ?")
|
||||
.bind(&workflow_id).execute(&self.pool).await;
|
||||
let _ = btx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id,
|
||||
status: "failed".into(),
|
||||
});
|
||||
let _ = btx.send(WsMessage::Error {
|
||||
message: format!("No worker available: {}", e),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
AgentEvent::Comment { workflow_id, content } => {
|
||||
if let Err(e) = self.worker_mgr.forward_comment(&workflow_id, &content).await {
|
||||
tracing::warn!("Failed to forward comment for workflow {}: {}", workflow_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Template system is in crate::template
|
||||
|
||||
/// Read INSTRUCTIONS.md from workdir if it exists.
|
||||
async fn read_instructions(workdir: &str) -> String {
|
||||
let path = format!("{}/INSTRUCTIONS.md", workdir);
|
||||
tokio::fs::read_to_string(&path).await.unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn ensure_workspace(exec: &LocalExecutor, workdir: &str) {
|
||||
let _ = tokio::fs::create_dir_all(workdir).await;
|
||||
let setup_script = format!("{}/scripts/setup.sh", workdir);
|
||||
if Path::new(&setup_script).exists() {
|
||||
tracing::info!("Running setup.sh in {}", workdir);
|
||||
let _ = exec.execute("bash scripts/setup.sh", workdir).await;
|
||||
} else {
|
||||
let venv_path = format!("{}/.venv", workdir);
|
||||
if !Path::new(&venv_path).exists() {
|
||||
let _ = exec.execute("uv venv .venv", workdir).await;
|
||||
let _ = exec.execute("uv pip install httpx fastapi uvicorn requests flask pydantic numpy pandas matplotlib pillow jinja2 pyyaml python-dotenv beautifulsoup4 lxml aiohttp aiofiles pytest rich click typer sqlalchemy", workdir).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn agent_loop(
|
||||
project_id: String,
|
||||
mut rx: mpsc::Receiver<AgentEvent>,
|
||||
broadcast_tx: broadcast::Sender<WsMessage>,
|
||||
mgr: Arc<AgentManager>,
|
||||
) {
|
||||
let pool = mgr.pool.clone();
|
||||
let llm_config = mgr.llm_config.clone();
|
||||
let llm = LlmClient::new(&llm_config);
|
||||
let exec = LocalExecutor::new(mgr.jwt_private_key_path.clone());
|
||||
let workdir = format!("/app/data/workspaces/{}", project_id);
|
||||
let svc_mgr = ServiceManager::new(9100);
|
||||
|
||||
// Create update channel and spawn handler
|
||||
let (update_tx, update_rx) = mpsc::channel::<AgentUpdate>(64);
|
||||
{
|
||||
let handler_pool = pool.clone();
|
||||
let handler_btx = broadcast_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
crate::sink::handle_agent_updates(update_rx, handler_pool, handler_btx).await;
|
||||
});
|
||||
}
|
||||
|
||||
tracing::info!("Agent loop started for project {}", project_id);
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::NewRequirement { workflow_id, requirement, template_id: forced_template } => {
|
||||
tracing::info!("Processing new requirement for workflow {}", workflow_id);
|
||||
// Generate project title from requirement (heuristic, no LLM)
|
||||
{
|
||||
let title = generate_title_heuristic(&requirement);
|
||||
let _ = sqlx::query("UPDATE projects SET name = ? WHERE id = ?")
|
||||
.bind(&title)
|
||||
.bind(&project_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::ProjectUpdate {
|
||||
project_id: project_id.clone(),
|
||||
name: title,
|
||||
});
|
||||
}
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowStatus {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: "executing".into(),
|
||||
}).await;
|
||||
|
||||
// Template: must be explicitly provided (no LLM selection)
|
||||
let template_id = forced_template;
|
||||
// Persist template_id to workflow
|
||||
if let Some(ref tid) = template_id {
|
||||
let _ = sqlx::query("UPDATE workflows SET template_id = ? WHERE id = ?")
|
||||
.bind(tid)
|
||||
.bind(&workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
}
|
||||
|
||||
let loaded_template = if let Some(ref tid) = template_id {
|
||||
tracing::info!("Template selected for workflow {}: {}", workflow_id, tid);
|
||||
let _ = tokio::fs::create_dir_all(&workdir).await;
|
||||
|
||||
if template::is_repo_template(tid) {
|
||||
// Repo template: extract from git then load
|
||||
match template::extract_repo_template(tid, mgr.template_repo.as_ref()).await {
|
||||
Ok(template_dir) => {
|
||||
if let Err(e) = template::apply_template(&template_dir, &workdir).await {
|
||||
tracing::error!("Failed to apply repo template {}: {}", tid, e);
|
||||
}
|
||||
match LoadedTemplate::load_from_dir(tid, &template_dir).await {
|
||||
Ok(t) => Some(t),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load repo template {}: {}", tid, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to extract repo template {}: {}", tid, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Local built-in template
|
||||
let template_dir = std::path::Path::new(template::templates_dir()).join(tid);
|
||||
if let Err(e) = template::apply_template(&template_dir, &workdir).await {
|
||||
tracing::error!("Failed to apply template {}: {}", tid, e);
|
||||
}
|
||||
match LoadedTemplate::load(tid).await {
|
||||
Ok(t) => Some(t),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load template {}: {}", tid, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Import KB files from template
|
||||
if let Some(ref t) = loaded_template {
|
||||
if let Some(ref kb) = mgr.kb {
|
||||
let mut batch_items: Vec<(String, String)> = Vec::new();
|
||||
for (title, content) in &t.kb_files {
|
||||
// Check if article already exists by title
|
||||
let existing: Option<String> = sqlx::query_scalar(
|
||||
"SELECT id FROM kb_articles WHERE title = ?"
|
||||
)
|
||||
.bind(title)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let article_id = if let Some(id) = existing {
|
||||
let _ = sqlx::query(
|
||||
"UPDATE kb_articles SET content = ?, updated_at = datetime('now') WHERE id = ?"
|
||||
)
|
||||
.bind(content)
|
||||
.bind(&id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
id
|
||||
} else {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO kb_articles (id, title, content) VALUES (?, ?, ?)"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(title)
|
||||
.bind(content)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
id
|
||||
};
|
||||
|
||||
batch_items.push((article_id, content.clone()));
|
||||
}
|
||||
// Batch index: single embed.py call for all articles
|
||||
if !batch_items.is_empty() {
|
||||
if let Err(e) = kb.index_batch(&batch_items).await {
|
||||
tracing::warn!("Failed to batch index KB articles: {}", e);
|
||||
}
|
||||
}
|
||||
tracing::info!("Imported {} KB articles from template", t.kb_files.len());
|
||||
}
|
||||
}
|
||||
|
||||
ensure_workspace(&exec, &workdir).await;
|
||||
let _ = tokio::fs::write(format!("{}/requirement.md", workdir), &requirement).await;
|
||||
|
||||
// Run template setup if present
|
||||
if let Some(ref tid) = template_id {
|
||||
let template_dir = if template::is_repo_template(tid) {
|
||||
template::extract_repo_template(tid, mgr.template_repo.as_ref())
|
||||
.await
|
||||
.ok()
|
||||
} else {
|
||||
Some(std::path::Path::new(template::templates_dir()).join(tid))
|
||||
};
|
||||
if let Some(ref tdir) = template_dir {
|
||||
if let Err(e) = template::run_setup(tdir, &workdir).await {
|
||||
tracing::error!("Template setup failed for {}: {}", tid, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let instructions = if let Some(ref t) = loaded_template {
|
||||
t.instructions.clone()
|
||||
} else {
|
||||
read_instructions(&workdir).await
|
||||
};
|
||||
|
||||
let ext_tools = loaded_template.as_ref().map(|t| &t.external_tools);
|
||||
let plan_approval = loaded_template.as_ref().map_or(false, |t| t.require_plan_approval);
|
||||
|
||||
tracing::info!("Starting agent loop for workflow {}", workflow_id);
|
||||
// Run tool-calling agent loop
|
||||
let result = run_agent_loop(
|
||||
&llm, &exec, &update_tx, &mut rx,
|
||||
&project_id, &workflow_id, &requirement, &workdir, &svc_mgr,
|
||||
&instructions, None, ext_tools,
|
||||
plan_approval,
|
||||
).await;
|
||||
|
||||
let final_status = if result.is_ok() { "done" } else { "failed" };
|
||||
tracing::info!("Agent loop finished for workflow {}, status: {}", workflow_id, final_status);
|
||||
if let Err(e) = &result {
|
||||
tracing::error!("Agent error for workflow {}: {}", workflow_id, e);
|
||||
let _ = update_tx.send(AgentUpdate::Error {
|
||||
message: format!("Agent error: {}", e),
|
||||
}).await;
|
||||
}
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowComplete {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: final_status.into(),
|
||||
report: None, // Report generation will be handled separately
|
||||
}).await;
|
||||
}
|
||||
AgentEvent::Comment { workflow_id, content } => {
|
||||
tracing::info!("Comment on workflow {}: {}", workflow_id, content);
|
||||
|
||||
let wf = sqlx::query_as::<_, crate::db::Workflow>(
|
||||
"SELECT * FROM workflows WHERE id = ?",
|
||||
)
|
||||
.bind(&workflow_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let Some(wf) = wf else { continue };
|
||||
|
||||
// Load latest state snapshot
|
||||
let snapshot = sqlx::query_scalar::<_, String>(
|
||||
"SELECT state_json FROM agent_state_snapshots WHERE workflow_id = ? ORDER BY created_at DESC LIMIT 1"
|
||||
)
|
||||
.bind(&workflow_id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let mut state = snapshot
|
||||
.and_then(|json| serde_json::from_str::<AgentState>(&json).ok())
|
||||
.unwrap_or_else(AgentState::new);
|
||||
|
||||
// Resume directly if: workflow is failed/done/waiting_user,
|
||||
// OR if state snapshot has a WaitingUser step (e.g. after pod restart)
|
||||
let has_waiting_step = state.steps.iter().any(|s| matches!(s.status, StepStatus::WaitingUser));
|
||||
let is_resuming = wf.status == "failed" || wf.status == "done"
|
||||
|| wf.status == "waiting_user" || has_waiting_step;
|
||||
if is_resuming {
|
||||
// Reset Failed/WaitingUser steps so they get re-executed
|
||||
for step in &mut state.steps {
|
||||
if matches!(step.status, StepStatus::Failed) {
|
||||
step.status = StepStatus::Pending;
|
||||
}
|
||||
if matches!(step.status, StepStatus::WaitingUser) {
|
||||
// Mark as Running so it continues (not re-plans)
|
||||
step.status = StepStatus::Running;
|
||||
}
|
||||
}
|
||||
// Attach comment as feedback to the first actionable step
|
||||
if let Some(order) = state.first_actionable_step() {
|
||||
if let Some(step) = state.steps.iter_mut().find(|s| s.order == order) {
|
||||
step.user_feedbacks.push(content.clone());
|
||||
}
|
||||
}
|
||||
tracing::info!("[workflow {}] Resuming from state (status={}), first actionable: {:?}",
|
||||
workflow_id, wf.status, state.first_actionable_step());
|
||||
} else {
|
||||
// Active workflow: LLM decides whether to revise plan
|
||||
state = process_feedback(
|
||||
&llm, &update_tx,
|
||||
&project_id, &workflow_id, state, &content,
|
||||
).await;
|
||||
}
|
||||
|
||||
// If there are actionable steps, resume execution
|
||||
if state.first_actionable_step().is_some() {
|
||||
ensure_workspace(&exec, &workdir).await;
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowStatus {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: "executing".into(),
|
||||
}).await;
|
||||
|
||||
// Prepare state for execution: set first pending step to Running
|
||||
if let Some(next) = state.first_actionable_step() {
|
||||
let was_same_step = matches!(state.phase, AgentPhase::Executing { step } if step == next);
|
||||
if let Some(step) = state.steps.iter_mut().find(|s| s.order == next) {
|
||||
if matches!(step.status, StepStatus::Pending) {
|
||||
step.status = StepStatus::Running;
|
||||
}
|
||||
}
|
||||
state.phase = AgentPhase::Executing { step: next };
|
||||
// Only clear chat history when advancing to a new step;
|
||||
// keep it when resuming the same step after ask_user
|
||||
if !was_same_step {
|
||||
state.current_step_chat_history.clear();
|
||||
}
|
||||
}
|
||||
|
||||
let instructions = read_instructions(&workdir).await;
|
||||
|
||||
// Reload template config if available
|
||||
let loaded_template = if !wf.template_id.is_empty() {
|
||||
let tid = &wf.template_id;
|
||||
if template::is_repo_template(tid) {
|
||||
match template::extract_repo_template(tid, mgr.template_repo.as_ref()).await {
|
||||
Ok(template_dir) => {
|
||||
LoadedTemplate::load_from_dir(tid, &template_dir).await.ok()
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to reload template {}: {}", tid, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LoadedTemplate::load(tid).await.ok()
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let ext_tools = loaded_template.as_ref().map(|t| &t.external_tools);
|
||||
let plan_approval = loaded_template.as_ref().map_or(false, |t| t.require_plan_approval);
|
||||
|
||||
let result = run_agent_loop(
|
||||
&llm, &exec, &update_tx, &mut rx,
|
||||
&project_id, &workflow_id, &wf.requirement, &workdir, &svc_mgr,
|
||||
&instructions, Some(state), ext_tools,
|
||||
plan_approval,
|
||||
).await;
|
||||
|
||||
let final_status = if result.is_ok() { "done" } else { "failed" };
|
||||
if let Err(e) = &result {
|
||||
let _ = update_tx.send(AgentUpdate::Error {
|
||||
message: format!("Agent error: {}", e),
|
||||
}).await;
|
||||
}
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowComplete {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: final_status.into(),
|
||||
report: None,
|
||||
}).await;
|
||||
} else {
|
||||
// No actionable steps — feedback was informational only
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowStatus {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: "done".into(),
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Agent loop ended for project {}", project_id);
|
||||
}
|
||||
|
||||
// --- Tool definitions ---
|
||||
|
||||
fn make_tool(name: &str, description: &str, parameters: serde_json::Value) -> Tool {
|
||||
@@ -750,49 +412,6 @@ fn build_step_user_message(step: &Step, completed_summaries: &[(i32, String, Str
|
||||
ctx
|
||||
}
|
||||
|
||||
fn build_feedback_prompt(project_id: &str, state: &AgentState, feedback: &str) -> String {
|
||||
let mut plan_state = String::new();
|
||||
for s in &state.steps {
|
||||
let status = match s.status {
|
||||
StepStatus::Done => " [done]",
|
||||
StepStatus::Running => " [running]",
|
||||
StepStatus::WaitingUser => " [waiting]",
|
||||
StepStatus::Failed => " [FAILED]",
|
||||
StepStatus::Pending => "",
|
||||
};
|
||||
plan_state.push_str(&format!("{}. {}{}\n {}\n", s.order, s.title, status, s.description));
|
||||
if let Some(summary) = &s.summary {
|
||||
plan_state.push_str(&format!(" 摘要: {}\n", summary));
|
||||
}
|
||||
}
|
||||
include_str!("prompts/feedback.md")
|
||||
.replace("{project_id}", project_id)
|
||||
.replace("{plan_state}", &plan_state)
|
||||
.replace("{feedback}", feedback)
|
||||
}
|
||||
|
||||
fn build_feedback_tools() -> Vec<Tool> {
|
||||
vec![
|
||||
make_tool("revise_plan", "修改执行计划。提供完整步骤列表。系统自动 diff:description 未变的已完成步骤保留成果,变化的步骤及后续重新执行。", serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"steps": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": { "type": "string", "description": "步骤标题" },
|
||||
"description": { "type": "string", "description": "详细描述" }
|
||||
},
|
||||
"required": ["title", "description"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["steps"]
|
||||
})),
|
||||
]
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
/// Truncate a string at a char boundary, returning at most `max_bytes` bytes.
|
||||
@@ -956,91 +575,6 @@ async fn send_llm_call(
|
||||
}).await;
|
||||
}
|
||||
|
||||
/// Process user feedback: call LLM to decide whether to revise the plan.
|
||||
/// Returns the (possibly modified) AgentState ready for resumed execution.
|
||||
async fn process_feedback(
|
||||
llm: &LlmClient,
|
||||
update_tx: &mpsc::Sender<AgentUpdate>,
|
||||
project_id: &str,
|
||||
workflow_id: &str,
|
||||
mut state: AgentState,
|
||||
feedback: &str,
|
||||
) -> AgentState {
|
||||
let prompt = build_feedback_prompt(project_id, &state, feedback);
|
||||
let tools = build_feedback_tools();
|
||||
let messages = vec![
|
||||
ChatMessage::system(&prompt),
|
||||
ChatMessage::user(feedback),
|
||||
];
|
||||
|
||||
tracing::info!("[workflow {}] Processing feedback with LLM", workflow_id);
|
||||
let response = match llm.chat_with_tools(messages, &tools).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::error!("[workflow {}] Feedback LLM call failed: {}", workflow_id, e);
|
||||
if let Some(step) = state.steps.iter_mut().find(|s| !matches!(s.status, StepStatus::Done)) {
|
||||
step.user_feedbacks.push(feedback.to_string());
|
||||
}
|
||||
return state;
|
||||
}
|
||||
};
|
||||
|
||||
let choice = match response.choices.into_iter().next() {
|
||||
Some(c) => c,
|
||||
None => return state,
|
||||
};
|
||||
|
||||
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||
for tc in tool_calls {
|
||||
if tc.function.name == "revise_plan" {
|
||||
let args: serde_json::Value = serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||
let raw_steps = args["steps"].as_array().cloned().unwrap_or_default();
|
||||
|
||||
let new_steps: Vec<Step> = raw_steps.iter().enumerate().map(|(i, item)| {
|
||||
let order = (i + 1) as i32;
|
||||
Step {
|
||||
order,
|
||||
title: item["title"].as_str().unwrap_or("").to_string(),
|
||||
description: item["description"].as_str().unwrap_or("").to_string(),
|
||||
status: StepStatus::Pending,
|
||||
summary: None,
|
||||
user_feedbacks: Vec::new(),
|
||||
db_id: String::new(),
|
||||
artifacts: Vec::new(),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
let diff = state.apply_plan_diff(new_steps);
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::PlanUpdate {
|
||||
workflow_id: workflow_id.to_string(),
|
||||
steps: plan_infos_from_state(&state),
|
||||
}).await;
|
||||
|
||||
tracing::info!("[workflow {}] Plan revised via feedback. First actionable: {:?}",
|
||||
workflow_id, state.first_actionable_step());
|
||||
|
||||
let diff_display = format!("```diff\n{}\n```", diff);
|
||||
send_execution(update_tx, workflow_id, 0, "revise_plan", "计划变更", &diff_display, "done").await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let text = choice.message.content.as_deref().unwrap_or("");
|
||||
tracing::info!("[workflow {}] Feedback processed, no plan change: {}", workflow_id, truncate_str(text, 200));
|
||||
send_execution(update_tx, workflow_id, state.current_step(), "text_response", "", text, "done").await;
|
||||
}
|
||||
|
||||
let target_order = state.first_actionable_step()
|
||||
.unwrap_or_else(|| state.steps.last().map(|s| s.order).unwrap_or(0));
|
||||
if let Some(step) = state.steps.iter_mut().find(|s| s.order == target_order) {
|
||||
step.user_feedbacks.push(feedback.to_string());
|
||||
}
|
||||
|
||||
send_snapshot(update_tx, workflow_id, state.current_step(), &state).await;
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
/// Run an isolated sub-loop for a single step. Returns StepResult.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn run_step_loop(
|
||||
|
||||
@@ -1,53 +1,20 @@
|
||||
use std::sync::Arc;
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::llm::{ChatMessage, LlmClient};
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatRequest {
|
||||
messages: Vec<SimpleChatMessage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SimpleChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
pub fn router(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/chat", post(chat))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(input): Json<ChatRequest>,
|
||||
) -> Result<Json<serde_json::Value>, Response> {
|
||||
let llm = LlmClient::new(&state.config.llm);
|
||||
let messages: Vec<ChatMessage> = input
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|m| ChatMessage {
|
||||
role: m.role,
|
||||
content: Some(m.content),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let reply = llm.chat(messages).await.map_err(|e| {
|
||||
tracing::error!("Chat LLM error: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
})?;
|
||||
|
||||
Ok(Json(serde_json::json!({ "reply": reply })))
|
||||
async fn chat() -> Result<Json<serde_json::Value>, Response> {
|
||||
// Chat endpoint removed — LLM runs on workers only
|
||||
Err((StatusCode::GONE, "Chat endpoint removed. LLM runs on workers.").into_response())
|
||||
}
|
||||
|
||||
@@ -59,50 +59,12 @@ async fn proxy_to_service(
|
||||
}
|
||||
|
||||
async fn proxy_impl(
|
||||
state: &AppState,
|
||||
project_id: &str,
|
||||
path: &str,
|
||||
req: Request<Body>,
|
||||
_state: &AppState,
|
||||
_project_id: &str,
|
||||
_path: &str,
|
||||
_req: Request<Body>,
|
||||
) -> Response {
|
||||
let port = match state.agent_mgr.get_service_port(project_id).await {
|
||||
Some(p) => p,
|
||||
None => return (StatusCode::SERVICE_UNAVAILABLE, "服务未启动").into_response(),
|
||||
};
|
||||
|
||||
let query = req.uri().query().map(|q| format!("?{}", q)).unwrap_or_default();
|
||||
let url = format!("http://127.0.0.1:{}{}{}", port, path, query);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let method = req.method().clone();
|
||||
let headers = req.headers().clone();
|
||||
let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
|
||||
Ok(b) => b,
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, "请求体过大").into_response(),
|
||||
};
|
||||
|
||||
let mut upstream_req = client.request(method, &url);
|
||||
for (key, val) in headers.iter() {
|
||||
if key != "host" {
|
||||
upstream_req = upstream_req.header(key, val);
|
||||
}
|
||||
}
|
||||
upstream_req = upstream_req.body(body_bytes);
|
||||
|
||||
match upstream_req.send().await {
|
||||
Ok(resp) => {
|
||||
let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let resp_headers = resp.headers().clone();
|
||||
let body = resp.bytes().await.unwrap_or_default();
|
||||
let mut response = (status, body).into_response();
|
||||
for (key, val) in resp_headers.iter() {
|
||||
if let Ok(name) = axum::http::header::HeaderName::from_bytes(key.as_ref()) {
|
||||
response.headers_mut().insert(name, val.clone());
|
||||
}
|
||||
}
|
||||
response
|
||||
}
|
||||
Err(_) => (StatusCode::BAD_GATEWAY, "无法连接到后端服务").into_response(),
|
||||
}
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "服务在 worker 上运行,无法从 server 代理").into_response()
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,9 @@ pub struct AppState {
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
pub llm: LlmConfig,
|
||||
/// LLM config is optional on the server — workers bring their own.
|
||||
#[serde(default)]
|
||||
pub llm: Option<LlmConfig>,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
#[serde(default)]
|
||||
|
||||
26
src/main.rs
26
src/main.rs
@@ -27,6 +27,9 @@ enum Command {
|
||||
/// Worker name
|
||||
#[arg(long, env = "TORI_WORKER_NAME")]
|
||||
name: Option<String>,
|
||||
/// Config file path (for LLM settings)
|
||||
#[arg(long, default_value = "config.yaml")]
|
||||
config: String,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -40,13 +43,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
match cli.command {
|
||||
Command::Server => run_server().await,
|
||||
Command::Worker { server, name } => {
|
||||
Command::Worker { server, name, config } => {
|
||||
let name = name.unwrap_or_else(|| {
|
||||
hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "worker-1".to_string())
|
||||
});
|
||||
worker_runner::run(&server, &name).await
|
||||
let config_str = std::fs::read_to_string(&config)
|
||||
.unwrap_or_else(|e| panic!("Failed to read {}: {}", config, e));
|
||||
let cfg: Config = serde_yaml::from_str(&config_str)
|
||||
.unwrap_or_else(|e| panic!("Failed to parse {}: {}", config, e));
|
||||
let llm_config = cfg.llm
|
||||
.unwrap_or_else(|| panic!("LLM config required in {} for worker mode", config));
|
||||
worker_runner::run(&server, &name, &llm_config).await
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -79,10 +88,6 @@ async fn run_server() -> anyhow::Result<()> {
|
||||
|
||||
let agent_mgr = agent::AgentManager::new(
|
||||
database.pool.clone(),
|
||||
config.llm.clone(),
|
||||
config.template_repo.clone(),
|
||||
kb_arc.clone(),
|
||||
config.jwt_private_key.clone(),
|
||||
worker_mgr.clone(),
|
||||
);
|
||||
|
||||
@@ -129,6 +134,7 @@ async fn run_server() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let ws_pool = database.pool.clone();
|
||||
let state = Arc::new(AppState {
|
||||
db: database,
|
||||
config: config.clone(),
|
||||
@@ -151,7 +157,13 @@ async fn run_server() -> anyhow::Result<()> {
|
||||
let r = obj_root;
|
||||
move || api::obj::root_listing(r)
|
||||
}))
|
||||
.nest("/ws/tori/workers", ws_worker::router(worker_mgr))
|
||||
.nest("/ws/tori/workers", {
|
||||
let pool = ws_pool;
|
||||
let agent_mgr_for_ws = agent_mgr.clone();
|
||||
let broadcast_fn: Arc<dyn Fn(&str) -> tokio::sync::broadcast::Sender<agent::WsMessage> + Send + Sync> =
|
||||
Arc::new(move |pid: &str| agent_mgr_for_ws.get_broadcast_sender(pid));
|
||||
ws_worker::router(worker_mgr, pool, broadcast_fn)
|
||||
})
|
||||
.nest("/ws/tori", ws::router(agent_mgr))
|
||||
.nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")))
|
||||
.route("/", axum::routing::get(|| async {
|
||||
|
||||
286
src/sink.rs
286
src/sink.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||
use tokio::sync::{RwLock, broadcast};
|
||||
|
||||
use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo};
|
||||
use crate::state::{AgentState, Artifact};
|
||||
@@ -13,64 +13,19 @@ use crate::state::{AgentState, Artifact};
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum AgentUpdate {
|
||||
PlanUpdate {
|
||||
workflow_id: String,
|
||||
steps: Vec<PlanStepInfo>,
|
||||
},
|
||||
WorkflowStatus {
|
||||
workflow_id: String,
|
||||
status: String,
|
||||
},
|
||||
Activity {
|
||||
workflow_id: String,
|
||||
activity: String,
|
||||
},
|
||||
ExecutionLog {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
tool_name: String,
|
||||
tool_input: String,
|
||||
output: String,
|
||||
status: String,
|
||||
},
|
||||
LlmCallLog {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
phase: String,
|
||||
messages_count: i32,
|
||||
tools_count: i32,
|
||||
tool_calls: String,
|
||||
text_response: String,
|
||||
prompt_tokens: Option<u32>,
|
||||
completion_tokens: Option<u32>,
|
||||
latency_ms: i64,
|
||||
},
|
||||
StateSnapshot {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
state: AgentState,
|
||||
},
|
||||
WorkflowComplete {
|
||||
workflow_id: String,
|
||||
status: String,
|
||||
report: Option<String>,
|
||||
},
|
||||
ArtifactSave {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
artifact: Artifact,
|
||||
},
|
||||
RequirementUpdate {
|
||||
workflow_id: String,
|
||||
requirement: String,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
PlanUpdate { workflow_id: String, steps: Vec<PlanStepInfo> },
|
||||
WorkflowStatus { workflow_id: String, status: String },
|
||||
Activity { workflow_id: String, activity: String },
|
||||
ExecutionLog { workflow_id: String, step_order: i32, tool_name: String, tool_input: String, output: String, status: String },
|
||||
LlmCallLog { workflow_id: String, step_order: i32, phase: String, messages_count: i32, tools_count: i32, tool_calls: String, text_response: String, prompt_tokens: Option<u32>, completion_tokens: Option<u32>, latency_ms: i64 },
|
||||
StateSnapshot { workflow_id: String, step_order: i32, state: AgentState },
|
||||
WorkflowComplete { workflow_id: String, status: String, report: Option<String> },
|
||||
ArtifactSave { workflow_id: String, step_order: i32, artifact: Artifact },
|
||||
RequirementUpdate { workflow_id: String, requirement: String },
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// Manages local services (start_service / stop_service tools).
|
||||
/// Created per-worker or per-agent-loop.
|
||||
pub struct ServiceManager {
|
||||
pub services: RwLock<HashMap<String, ServiceInfo>>,
|
||||
next_port: AtomicU16,
|
||||
@@ -89,151 +44,84 @@ impl ServiceManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend.
|
||||
pub async fn handle_agent_updates(
|
||||
mut rx: mpsc::Receiver<AgentUpdate>,
|
||||
pool: SqlitePool,
|
||||
broadcast_tx: broadcast::Sender<WsMessage>,
|
||||
/// Helper: broadcast if sender is available.
|
||||
fn bcast(tx: Option<&broadcast::Sender<WsMessage>>, msg: WsMessage) {
|
||||
if let Some(tx) = tx { let _ = tx.send(msg); }
|
||||
}
|
||||
|
||||
/// Process a single AgentUpdate: persist to DB and broadcast to frontend.
|
||||
pub async fn handle_single_update(
|
||||
update: &AgentUpdate,
|
||||
pool: &SqlitePool,
|
||||
broadcast_tx: Option<&broadcast::Sender<WsMessage>>,
|
||||
) {
|
||||
while let Some(update) = rx.recv().await {
|
||||
match update {
|
||||
AgentUpdate::PlanUpdate { workflow_id, steps } => {
|
||||
let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps });
|
||||
}
|
||||
AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: status.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::Activity { workflow_id, activity } => {
|
||||
let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity });
|
||||
}
|
||||
AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(tool_name)
|
||||
.bind(tool_input)
|
||||
.bind(output)
|
||||
.bind(status)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::StepStatusUpdate {
|
||||
step_id: id,
|
||||
status: status.clone(),
|
||||
output: output.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(phase)
|
||||
.bind(messages_count)
|
||||
.bind(tools_count)
|
||||
.bind(tool_calls)
|
||||
.bind(text_response)
|
||||
.bind(prompt_tokens.map(|v| v as i32))
|
||||
.bind(completion_tokens.map(|v| v as i32))
|
||||
.bind(latency_ms as i32)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let entry = crate::db::LlmCallLogEntry {
|
||||
id,
|
||||
workflow_id: workflow_id.clone(),
|
||||
step_order,
|
||||
phase: phase.clone(),
|
||||
messages_count,
|
||||
tools_count,
|
||||
tool_calls: tool_calls.clone(),
|
||||
text_response: text_response.clone(),
|
||||
prompt_tokens: prompt_tokens.map(|v| v as i32),
|
||||
completion_tokens: completion_tokens.map(|v| v as i32),
|
||||
latency_ms: latency_ms as i32,
|
||||
created_at: String::new(),
|
||||
};
|
||||
let _ = broadcast_tx.send(WsMessage::LlmCallLog {
|
||||
workflow_id: workflow_id.clone(),
|
||||
entry,
|
||||
});
|
||||
}
|
||||
AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let json = serde_json::to_string(state).unwrap_or_default();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(&json)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
}
|
||||
AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
if let Some(ref r) = report {
|
||||
let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?")
|
||||
.bind(r)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::ReportReady {
|
||||
workflow_id: workflow_id.clone(),
|
||||
});
|
||||
}
|
||||
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: status.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(&artifact.name)
|
||||
.bind(&artifact.path)
|
||||
.bind(&artifact.artifact_type)
|
||||
.bind(&artifact.description)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
}
|
||||
AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?")
|
||||
.bind(requirement)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::RequirementUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
requirement: requirement.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::Error { message } => {
|
||||
let _ = broadcast_tx.send(WsMessage::Error { message });
|
||||
match update {
|
||||
AgentUpdate::PlanUpdate { workflow_id, steps } => {
|
||||
bcast(broadcast_tx, WsMessage::PlanUpdate { workflow_id: workflow_id.clone(), steps: steps.clone() });
|
||||
}
|
||||
AgentUpdate::WorkflowStatus { workflow_id, status } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status).bind(workflow_id).execute(pool).await;
|
||||
bcast(broadcast_tx, WsMessage::WorkflowStatusUpdate { workflow_id: workflow_id.clone(), status: status.clone() });
|
||||
}
|
||||
AgentUpdate::Activity { workflow_id, activity } => {
|
||||
bcast(broadcast_tx, WsMessage::ActivityUpdate { workflow_id: workflow_id.clone(), activity: activity.clone() });
|
||||
}
|
||||
AgentUpdate::ExecutionLog { workflow_id, step_order, tool_name, tool_input, output, status } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
).bind(&id).bind(workflow_id).bind(step_order).bind(tool_name).bind(tool_input).bind(output).bind(status)
|
||||
.execute(pool).await;
|
||||
bcast(broadcast_tx, WsMessage::StepStatusUpdate { step_id: id, status: status.clone(), output: output.clone() });
|
||||
}
|
||||
AgentUpdate::LlmCallLog { workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
).bind(&id).bind(workflow_id).bind(step_order).bind(phase).bind(messages_count).bind(tools_count)
|
||||
.bind(tool_calls).bind(text_response).bind(prompt_tokens.map(|v| v as i32)).bind(completion_tokens.map(|v| v as i32)).bind(*latency_ms as i32)
|
||||
.execute(pool).await;
|
||||
let entry = crate::db::LlmCallLogEntry {
|
||||
id, workflow_id: workflow_id.clone(), step_order: *step_order, phase: phase.clone(),
|
||||
messages_count: *messages_count, tools_count: *tools_count, tool_calls: tool_calls.clone(),
|
||||
text_response: text_response.clone(), prompt_tokens: prompt_tokens.map(|v| v as i32),
|
||||
completion_tokens: completion_tokens.map(|v| v as i32), latency_ms: *latency_ms as i32,
|
||||
created_at: String::new(),
|
||||
};
|
||||
bcast(broadcast_tx, WsMessage::LlmCallLog { workflow_id: workflow_id.clone(), entry });
|
||||
}
|
||||
AgentUpdate::StateSnapshot { workflow_id, step_order, state } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let json = serde_json::to_string(state).unwrap_or_default();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))"
|
||||
).bind(&id).bind(workflow_id).bind(step_order).bind(&json).execute(pool).await;
|
||||
}
|
||||
AgentUpdate::WorkflowComplete { workflow_id, status, report } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status).bind(workflow_id).execute(pool).await;
|
||||
if let Some(r) = report {
|
||||
let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?")
|
||||
.bind(r).bind(workflow_id).execute(pool).await;
|
||||
bcast(broadcast_tx, WsMessage::ReportReady { workflow_id: workflow_id.clone() });
|
||||
}
|
||||
bcast(broadcast_tx, WsMessage::WorkflowStatusUpdate { workflow_id: workflow_id.clone(), status: status.clone() });
|
||||
}
|
||||
AgentUpdate::ArtifactSave { workflow_id, step_order, artifact } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||
).bind(&id).bind(workflow_id).bind(step_order).bind(&artifact.name).bind(&artifact.path)
|
||||
.bind(&artifact.artifact_type).bind(&artifact.description).execute(pool).await;
|
||||
}
|
||||
AgentUpdate::RequirementUpdate { workflow_id, requirement } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?")
|
||||
.bind(requirement).bind(workflow_id).execute(pool).await;
|
||||
bcast(broadcast_tx, WsMessage::RequirementUpdate { workflow_id: workflow_id.clone(), requirement: requirement.clone() });
|
||||
}
|
||||
AgentUpdate::Error { message } => {
|
||||
bcast(broadcast_tx, WsMessage::Error { message: message.clone() });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
221
src/worker.rs
221
src/worker.rs
@@ -3,6 +3,8 @@ use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::state::AgentState;
|
||||
|
||||
/// Information reported by a worker on registration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WorkerInfo {
|
||||
@@ -14,18 +16,13 @@ pub struct WorkerInfo {
|
||||
pub kernel: String,
|
||||
}
|
||||
|
||||
/// A registered worker with a channel for sending scripts to execute.
|
||||
/// A registered worker.
|
||||
struct Worker {
|
||||
pub info: WorkerInfo,
|
||||
pub tx: tokio::sync::mpsc::Sender<WorkerRequest>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WorkerRequest {
|
||||
pub job_id: String,
|
||||
pub script: String,
|
||||
pub tx: tokio::sync::mpsc::Sender<ServerToWorker>,
|
||||
}
|
||||
|
||||
/// Legacy script execution result.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WorkerResult {
|
||||
pub job_id: String,
|
||||
@@ -34,110 +31,6 @@ pub struct WorkerResult {
|
||||
pub stderr: String,
|
||||
}
|
||||
|
||||
/// Manages all connected workers.
|
||||
pub struct WorkerManager {
|
||||
workers: RwLock<HashMap<String, Worker>>,
|
||||
/// Pending job results, keyed by job_id.
|
||||
results: RwLock<HashMap<String, tokio::sync::oneshot::Sender<WorkerResult>>>,
|
||||
}
|
||||
|
||||
impl WorkerManager {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
workers: RwLock::new(HashMap::new()),
|
||||
results: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Register a new worker. Returns a receiver for job requests.
|
||||
pub async fn register(
|
||||
&self,
|
||||
name: String,
|
||||
info: WorkerInfo,
|
||||
) -> tokio::sync::mpsc::Receiver<WorkerRequest> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(16);
|
||||
tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
|
||||
name, info.cpu, info.memory, info.gpu, info.os, info.kernel);
|
||||
self.workers.write().await.insert(name, Worker { info, tx });
|
||||
rx
|
||||
}
|
||||
|
||||
/// Remove a worker.
|
||||
pub async fn unregister(&self, name: &str) {
|
||||
self.workers.write().await.remove(name);
|
||||
tracing::info!("Worker unregistered: {}", name);
|
||||
}
|
||||
|
||||
/// List all connected workers.
|
||||
pub async fn list(&self) -> Vec<(String, WorkerInfo)> {
|
||||
self.workers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, w)| (name.clone(), w.info.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Submit a script to a worker and wait for the result.
|
||||
pub async fn execute(
|
||||
&self,
|
||||
worker_name: &str,
|
||||
script: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<WorkerResult, String> {
|
||||
let job_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Find the worker and send the request
|
||||
let tx = {
|
||||
let workers = self.workers.read().await;
|
||||
let worker = workers
|
||||
.get(worker_name)
|
||||
.ok_or_else(|| format!("Worker '{}' not found", worker_name))?;
|
||||
worker.tx.clone()
|
||||
};
|
||||
|
||||
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
|
||||
self.results.write().await.insert(job_id.clone(), result_tx);
|
||||
|
||||
let req = WorkerRequest {
|
||||
job_id: job_id.clone(),
|
||||
script: script.to_string(),
|
||||
};
|
||||
|
||||
tx.send(req).await.map_err(|_| {
|
||||
format!("Worker '{}' disconnected", worker_name)
|
||||
})?;
|
||||
|
||||
// Wait for result with timeout
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
result_rx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(result)) => Ok(result),
|
||||
Ok(Err(_)) => Err("Worker channel closed unexpectedly".into()),
|
||||
Err(_) => {
|
||||
self.results.write().await.remove(&job_id);
|
||||
Err(format!("Execution timed out after {}s", timeout_secs))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Called when a worker sends back a result.
|
||||
pub async fn report_result(&self, result: WorkerResult) {
|
||||
if let Some(tx) = self.results.write().await.remove(&result.job_id) {
|
||||
let _ = tx.send(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Extended protocol for workflow execution ---
|
||||
|
||||
use crate::LlmConfig;
|
||||
use crate::state::AgentState;
|
||||
use crate::llm::Tool;
|
||||
|
||||
/// Messages sent from server to worker.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
@@ -148,15 +41,12 @@ pub enum ServerToWorker {
|
||||
workflow_id: String,
|
||||
project_id: String,
|
||||
requirement: String,
|
||||
workdir: String,
|
||||
instructions: String,
|
||||
llm_config: LlmConfig,
|
||||
#[serde(default)]
|
||||
template_id: Option<String>,
|
||||
#[serde(default)]
|
||||
initial_state: Option<AgentState>,
|
||||
#[serde(default)]
|
||||
require_plan_approval: bool,
|
||||
#[serde(default)]
|
||||
external_tools: Vec<Tool>,
|
||||
},
|
||||
/// Forward a user comment to the worker executing this workflow.
|
||||
#[serde(rename = "comment")]
|
||||
@@ -173,9 +63,6 @@ pub enum WorkerToServer {
|
||||
/// Worker registration.
|
||||
#[serde(rename = "register")]
|
||||
Register { info: WorkerInfo },
|
||||
/// Script execution result (legacy).
|
||||
#[serde(rename = "result")]
|
||||
Result(WorkerResult),
|
||||
/// Agent update from workflow execution.
|
||||
#[serde(rename = "update")]
|
||||
Update {
|
||||
@@ -184,3 +71,97 @@ pub enum WorkerToServer {
|
||||
update: crate::sink::AgentUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
/// Manages all connected workers and workflow assignments.
|
||||
pub struct WorkerManager {
|
||||
workers: RwLock<HashMap<String, Worker>>,
|
||||
/// workflow_id → worker_name
|
||||
assignments: RwLock<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl WorkerManager {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
workers: RwLock::new(HashMap::new()),
|
||||
assignments: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Register a new worker. Returns a receiver for messages.
|
||||
pub async fn register(
|
||||
&self,
|
||||
name: String,
|
||||
info: WorkerInfo,
|
||||
) -> tokio::sync::mpsc::Receiver<ServerToWorker> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(16);
|
||||
tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
|
||||
name, info.cpu, info.memory, info.gpu, info.os, info.kernel);
|
||||
self.workers.write().await.insert(name, Worker { info, tx });
|
||||
rx
|
||||
}
|
||||
|
||||
/// Remove a worker and clean up its assignments.
|
||||
pub async fn unregister(&self, name: &str) {
|
||||
self.workers.write().await.remove(name);
|
||||
// Remove all workflow assignments for this worker
|
||||
let mut assignments = self.assignments.write().await;
|
||||
assignments.retain(|_, worker| worker != name);
|
||||
tracing::info!("Worker unregistered: {}", name);
|
||||
}
|
||||
|
||||
/// List all connected workers.
|
||||
pub async fn list(&self) -> Vec<(String, WorkerInfo)> {
|
||||
self.workers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, w)| (name.clone(), w.info.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Assign a workflow to the first available worker. Returns worker name.
|
||||
pub async fn assign_workflow(&self, assign: ServerToWorker) -> Result<String, String> {
|
||||
let workflow_id = match &assign {
|
||||
ServerToWorker::WorkflowAssign { workflow_id, .. } => workflow_id.clone(),
|
||||
_ => return Err("Not a workflow assignment".into()),
|
||||
};
|
||||
|
||||
let workers = self.workers.read().await;
|
||||
// Pick first worker (simple strategy for now)
|
||||
let (name, worker) = workers.iter().next()
|
||||
.ok_or_else(|| "No workers available".to_string())?;
|
||||
|
||||
worker.tx.send(assign).await.map_err(|_| {
|
||||
format!("Worker '{}' disconnected", name)
|
||||
})?;
|
||||
|
||||
let worker_name = name.clone();
|
||||
drop(workers);
|
||||
|
||||
self.assignments.write().await.insert(workflow_id, worker_name.clone());
|
||||
Ok(worker_name)
|
||||
}
|
||||
|
||||
/// Forward a comment to the worker handling a workflow.
|
||||
pub async fn forward_comment(&self, workflow_id: &str, content: &str) -> Result<(), String> {
|
||||
let assignments = self.assignments.read().await;
|
||||
let worker_name = assignments.get(workflow_id)
|
||||
.ok_or_else(|| format!("No worker assigned for workflow {}", workflow_id))?
|
||||
.clone();
|
||||
drop(assignments);
|
||||
|
||||
let workers = self.workers.read().await;
|
||||
let worker = workers.get(&worker_name)
|
||||
.ok_or_else(|| format!("Worker '{}' not found", worker_name))?;
|
||||
|
||||
worker.tx.send(ServerToWorker::Comment {
|
||||
workflow_id: workflow_id.to_string(),
|
||||
content: content.to_string(),
|
||||
}).await.map_err(|_| format!("Worker '{}' disconnected", worker_name))
|
||||
}
|
||||
|
||||
/// Remove a workflow assignment (when workflow completes).
|
||||
pub async fn complete_workflow(&self, workflow_id: &str) {
|
||||
self.assignments.write().await.remove(workflow_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,11 +59,11 @@ fn collect_worker_info(name: &str) -> WorkerInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||
tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url);
|
||||
pub async fn run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> {
|
||||
tracing::info!("Tori worker '{}' connecting to {} (model={})", worker_name, server_url, llm_config.model);
|
||||
|
||||
loop {
|
||||
match connect_and_run(server_url, worker_name).await {
|
||||
match connect_and_run(server_url, worker_name, llm_config).await {
|
||||
Ok(()) => {
|
||||
tracing::info!("Worker connection closed, reconnecting in 5s...");
|
||||
}
|
||||
@@ -75,7 +75,7 @@ pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||
async fn connect_and_run(server_url: &str, worker_name: &str, llm_config: &crate::LlmConfig) -> anyhow::Result<()> {
|
||||
let (ws_stream, _) = connect_async(server_url).await?;
|
||||
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||
|
||||
@@ -127,17 +127,16 @@ async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<
|
||||
workflow_id,
|
||||
project_id,
|
||||
requirement,
|
||||
workdir,
|
||||
instructions,
|
||||
llm_config,
|
||||
template_id: _,
|
||||
initial_state,
|
||||
require_plan_approval,
|
||||
external_tools: _,
|
||||
} => {
|
||||
tracing::info!("Received workflow: {} (project {})", workflow_id, project_id);
|
||||
|
||||
let llm = LlmClient::new(&llm_config);
|
||||
let llm = LlmClient::new(llm_config);
|
||||
let exec = LocalExecutor::new(None);
|
||||
let workdir = format!("/app/data/workspaces/{}", project_id);
|
||||
let instructions = String::new(); // TODO: load from template
|
||||
|
||||
// update channel → serialize → WebSocket
|
||||
let (update_tx, mut update_rx) = mpsc::channel::<AgentUpdate>(64);
|
||||
|
||||
110
src/ws_worker.rs
110
src/ws_worker.rs
@@ -7,44 +7,52 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use crate::worker::{WorkerInfo, WorkerManager, WorkerResult};
|
||||
use crate::agent::WsMessage;
|
||||
use crate::worker::{WorkerInfo, WorkerManager, WorkerToServer};
|
||||
|
||||
pub fn router(mgr: Arc<WorkerManager>) -> Router {
|
||||
pub struct WsWorkerState {
|
||||
pub mgr: Arc<WorkerManager>,
|
||||
pub pool: SqlitePool,
|
||||
pub broadcast_fn: Arc<dyn Fn(&str) -> broadcast::Sender<WsMessage> + Send + Sync>,
|
||||
}
|
||||
|
||||
pub fn router(mgr: Arc<WorkerManager>, pool: SqlitePool, broadcast_fn: Arc<dyn Fn(&str) -> broadcast::Sender<WsMessage> + Send + Sync>) -> Router {
|
||||
let state = Arc::new(WsWorkerState { mgr, pool, broadcast_fn });
|
||||
Router::new()
|
||||
.route("/", get(ws_handler))
|
||||
.with_state(mgr)
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(mgr): State<Arc<WorkerManager>>,
|
||||
State(state): State<Arc<WsWorkerState>>,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_worker_socket(socket, mgr))
|
||||
ws.on_upgrade(move |socket| handle_worker_socket(socket, state))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum WorkerMessage {
|
||||
#[serde(rename = "register")]
|
||||
Register { info: WorkerInfo },
|
||||
#[serde(rename = "result")]
|
||||
Result(WorkerResult),
|
||||
}
|
||||
|
||||
async fn handle_worker_socket(socket: WebSocket, mgr: Arc<WorkerManager>) {
|
||||
async fn handle_worker_socket(socket: WebSocket, state: Arc<WsWorkerState>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// First message must be registration
|
||||
let (name, mut job_rx) = loop {
|
||||
let (name, mut msg_rx) = loop {
|
||||
match receiver.next().await {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
match serde_json::from_str::<WorkerMessage>(&text) {
|
||||
Ok(WorkerMessage::Register { info }) => {
|
||||
match serde_json::from_str::<serde_json::Value>(&text) {
|
||||
Ok(v) if v["type"] == "register" => {
|
||||
let info: WorkerInfo = match serde_json::from_value(v["info"].clone()) {
|
||||
Ok(i) => i,
|
||||
Err(_) => {
|
||||
let _ = sender.send(Message::Text(
|
||||
r#"{"type":"error","message":"Invalid worker info"}"#.into(),
|
||||
)).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
let name = info.name.clone();
|
||||
let rx = mgr.register(name.clone(), info).await;
|
||||
// Ack
|
||||
let rx = state.mgr.register(name.clone(), info).await;
|
||||
let ack = serde_json::json!({ "type": "registered", "name": &name });
|
||||
let _ = sender.send(Message::Text(ack.to_string().into())).await;
|
||||
break (name, rx);
|
||||
@@ -62,31 +70,28 @@ async fn handle_worker_socket(socket: WebSocket, mgr: Arc<WorkerManager>) {
|
||||
}
|
||||
};
|
||||
|
||||
// Main loop: forward jobs to worker, receive results
|
||||
let name_clone = name.clone();
|
||||
let mgr_clone = mgr.clone();
|
||||
let mgr_for_cleanup = state.mgr.clone();
|
||||
|
||||
// Task: send jobs from job_rx to the WebSocket
|
||||
// Task: send ServerToWorker messages from msg_rx to the WebSocket
|
||||
let send_task = tokio::spawn(async move {
|
||||
while let Some(req) = job_rx.recv().await {
|
||||
let msg = serde_json::json!({
|
||||
"type": "execute",
|
||||
"job_id": req.job_id,
|
||||
"script": req.script,
|
||||
});
|
||||
if sender.send(Message::Text(msg.to_string().into())).await.is_err() {
|
||||
break;
|
||||
while let Some(msg) = msg_rx.recv().await {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
if sender.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Task: receive results from the WebSocket
|
||||
// Task: receive WorkerToServer messages from the WebSocket → process
|
||||
let state_clone = state.clone();
|
||||
let recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
if let Ok(WorkerMessage::Result(result)) = serde_json::from_str(&text) {
|
||||
mgr_clone.report_result(result).await;
|
||||
if let Ok(worker_msg) = serde_json::from_str::<WorkerToServer>(&text) {
|
||||
handle_worker_message(&state_clone, worker_msg).await;
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
@@ -100,5 +105,38 @@ async fn handle_worker_socket(socket: WebSocket, mgr: Arc<WorkerManager>) {
|
||||
_ = recv_task => {},
|
||||
}
|
||||
|
||||
mgr.unregister(&name_clone).await;
|
||||
mgr_for_cleanup.unregister(&name_clone).await;
|
||||
}
|
||||
|
||||
async fn handle_worker_message(state: &WsWorkerState, msg: WorkerToServer) {
|
||||
match msg {
|
||||
WorkerToServer::Register { .. } => {
|
||||
// Already handled during initial handshake
|
||||
}
|
||||
WorkerToServer::Update { workflow_id, update } => {
|
||||
// Get project_id for broadcasting (look up from DB)
|
||||
let project_id: Option<String> = sqlx::query_scalar(
|
||||
"SELECT project_id FROM workflows WHERE id = ?"
|
||||
)
|
||||
.bind(&workflow_id)
|
||||
.fetch_optional(&state.pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let broadcast_tx = if let Some(ref pid) = project_id {
|
||||
Some((state.broadcast_fn)(pid))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Check if this is a workflow completion
|
||||
if let crate::sink::AgentUpdate::WorkflowComplete { ref workflow_id, .. } = update {
|
||||
state.mgr.complete_workflow(workflow_id).await;
|
||||
}
|
||||
|
||||
// Process the update: write to DB + broadcast
|
||||
crate::sink::handle_single_update(&update, &state.pool, broadcast_tx.as_ref()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user