refactor: worker mode — server offloads all LLM/exec to worker
- Split into `tori server` / `tori worker` subcommands (clap derive) - Extract lib.rs for shared crate (agent, llm, exec, state, etc.) - Introduce AgentUpdate channel to decouple agent loop from DB/broadcast - New sink.rs: AgentUpdate enum + ServiceManager + handle_agent_updates - New worker_runner.rs: connects to server WS, runs full agent loop - Expand worker protocol: ServerToWorker (workflow_assign, comment) and WorkerToServer (register, result, update) - Remove LLM from title generation (heuristic) and template selection (must be explicit) - Remove KB tools (kb_search, kb_read) and remote worker tools (list_workers, execute_on_worker) from agent loop - run_agent_loop/run_step_loop now take mpsc::Sender<AgentUpdate> instead of direct DB pool + broadcast sender
This commit is contained in:
186
Cargo.lock
generated
186
Cargo.lock
generated
@@ -26,6 +26,56 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-query"
|
||||
version = "1.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"once_cell_polyfill",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.102"
|
||||
@@ -83,7 +133,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-tungstenite 0.28.0",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -216,6 +266,52 @@ dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.60"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.60"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.55"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
|
||||
|
||||
[[package]]
|
||||
name = "concurrent-queue"
|
||||
version = "2.5.0"
|
||||
@@ -663,6 +759,17 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hostname"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "1.4.0"
|
||||
@@ -750,7 +857,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tower-service",
|
||||
"webpki-roots",
|
||||
"webpki-roots 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -936,6 +1043,12 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
@@ -1207,6 +1320,12 @@ version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell_polyfill"
|
||||
version = "1.70.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe"
|
||||
|
||||
[[package]]
|
||||
name = "parking"
|
||||
version = "2.2.1"
|
||||
@@ -1564,7 +1683,7 @@ dependencies = [
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots",
|
||||
"webpki-roots 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2063,6 +2182,12 @@ dependencies = [
|
||||
"unicode-properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.6.1"
|
||||
@@ -2234,6 +2359,22 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tungstenite 0.26.2",
|
||||
"webpki-roots 0.26.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.28.0"
|
||||
@@ -2243,7 +2384,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
"tungstenite 0.28.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2268,7 +2409,9 @@ dependencies = [
|
||||
"axum-extra",
|
||||
"base64",
|
||||
"chrono",
|
||||
"clap",
|
||||
"futures",
|
||||
"hostname",
|
||||
"jsonwebtoken",
|
||||
"mime_guess",
|
||||
"nix",
|
||||
@@ -2280,6 +2423,7 @@ dependencies = [
|
||||
"sqlx",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.26.2",
|
||||
"tokio-util",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
@@ -2411,6 +2555,25 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.28.0"
|
||||
@@ -2515,6 +2678,12 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "1.21.0"
|
||||
@@ -2709,6 +2878,15 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
|
||||
dependencies = [
|
||||
"webpki-roots 1.0.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.4"
|
||||
|
||||
@@ -19,11 +19,14 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
|
||||
tower-http = { version = "0.6", features = ["cors", "fs"] }
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||
futures = "0.3"
|
||||
tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
anyhow = "1"
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
hostname = "0.4"
|
||||
mime_guess = "2"
|
||||
tokio-util = { version = "0.7", features = ["io"] }
|
||||
nix = { version = "0.29", features = ["signal"] }
|
||||
|
||||
664
src/agent.rs
664
src/agent.rs
File diff suppressed because it is too large
Load Diff
74
src/lib.rs
Normal file
74
src/lib.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
pub mod api;
|
||||
pub mod agent;
|
||||
pub mod db;
|
||||
pub mod kb;
|
||||
pub mod llm;
|
||||
pub mod exec;
|
||||
pub mod state;
|
||||
pub mod template;
|
||||
pub mod timer;
|
||||
pub mod tools;
|
||||
pub mod worker;
|
||||
pub mod sink;
|
||||
pub mod worker_runner;
|
||||
pub mod ws;
|
||||
pub mod ws_worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
use serde::Deserialize;
|
||||
|
||||
pub struct AppState {
|
||||
pub db: db::Database,
|
||||
pub config: Config,
|
||||
pub agent_mgr: Arc<agent::AgentManager>,
|
||||
pub kb: Option<Arc<kb::KbManager>>,
|
||||
pub obj_root: String,
|
||||
pub auth: Option<api::auth::AuthConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
pub llm: LlmConfig,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
#[serde(default)]
|
||||
pub template_repo: Option<TemplateRepoConfig>,
|
||||
/// Path to EC private key PEM file for JWT signing
|
||||
#[serde(default)]
|
||||
pub jwt_private_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TemplateRepoConfig {
|
||||
pub gitea_url: String,
|
||||
pub owner: String,
|
||||
pub repo: String,
|
||||
#[serde(default = "default_repo_path")]
|
||||
pub local_path: String,
|
||||
}
|
||||
|
||||
fn default_repo_path() -> String {
|
||||
if std::path::Path::new("/app/oseng-templates").is_dir() {
|
||||
"/app/oseng-templates".to_string()
|
||||
} else {
|
||||
"oseng-templates".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, Deserialize)]
|
||||
pub struct LlmConfig {
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DatabaseConfig {
|
||||
pub path: String,
|
||||
}
|
||||
114
src/main.rs
114
src/main.rs
@@ -1,77 +1,33 @@
|
||||
mod api;
|
||||
mod agent;
|
||||
mod db;
|
||||
mod kb;
|
||||
mod llm;
|
||||
mod exec;
|
||||
pub mod state;
|
||||
mod template;
|
||||
mod timer;
|
||||
mod tools;
|
||||
mod worker;
|
||||
mod ws;
|
||||
mod ws_worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
use clap::{Parser, Subcommand};
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
|
||||
pub struct AppState {
|
||||
pub db: db::Database,
|
||||
pub config: Config,
|
||||
pub agent_mgr: Arc<agent::AgentManager>,
|
||||
pub kb: Option<Arc<kb::KbManager>>,
|
||||
pub obj_root: String,
|
||||
pub auth: Option<api::auth::AuthConfig>,
|
||||
use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker};
|
||||
use tori::{AppState, Config};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "tori", about = "Tori AI agent orchestration")]
|
||||
struct Cli {
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub llm: LlmConfig,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
#[serde(default)]
|
||||
pub template_repo: Option<TemplateRepoConfig>,
|
||||
/// Path to EC private key PEM file for JWT signing
|
||||
#[serde(default)]
|
||||
pub jwt_private_key: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct TemplateRepoConfig {
|
||||
pub gitea_url: String,
|
||||
pub owner: String,
|
||||
pub repo: String,
|
||||
#[serde(default = "default_repo_path")]
|
||||
pub local_path: String,
|
||||
}
|
||||
|
||||
fn default_repo_path() -> String {
|
||||
if std::path::Path::new("/app/oseng-templates").is_dir() {
|
||||
"/app/oseng-templates".to_string()
|
||||
} else {
|
||||
"oseng-templates".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct LlmConfig {
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct DatabaseConfig {
|
||||
pub path: String,
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
/// Start the API server
|
||||
Server,
|
||||
/// Start a worker that connects to the server
|
||||
Worker {
|
||||
/// Server WebSocket URL
|
||||
#[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")]
|
||||
server: String,
|
||||
/// Worker name
|
||||
#[arg(long, env = "TORI_WORKER_NAME")]
|
||||
name: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -80,6 +36,22 @@ async fn main() -> anyhow::Result<()> {
|
||||
.with_env_filter("tori=debug,tower_http=debug")
|
||||
.init();
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
Command::Server => run_server().await,
|
||||
Command::Worker { server, name } => {
|
||||
let name = name.unwrap_or_else(|| {
|
||||
hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "worker-1".to_string())
|
||||
});
|
||||
worker_runner::run(&server, &name).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_server() -> anyhow::Result<()> {
|
||||
let config_str = std::fs::read_to_string("config.yaml")
|
||||
.expect("Failed to read config.yaml");
|
||||
let config: Config = serde_yaml::from_str(&config_str)
|
||||
@@ -88,7 +60,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let database = db::Database::new(&config.database.path).await?;
|
||||
database.migrate().await?;
|
||||
|
||||
// Initialize KB manager
|
||||
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
|
||||
Ok(kb) => {
|
||||
tracing::info!("KB manager initialized");
|
||||
@@ -100,7 +71,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
// Ensure template repo is cloned before serving
|
||||
if let Some(ref repo_cfg) = config.template_repo {
|
||||
template::ensure_repo_ready(repo_cfg).await;
|
||||
}
|
||||
@@ -117,8 +87,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
|
||||
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
|
||||
|
||||
// Resume incomplete workflows after restart
|
||||
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
|
||||
|
||||
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
|
||||
@@ -129,7 +97,6 @@ async fn main() -> anyhow::Result<()> {
|
||||
let public_url = std::env::var("PUBLIC_URL")
|
||||
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
|
||||
|
||||
// Try TikTok SSO first, then Google OAuth
|
||||
if let (Ok(id), Ok(secret)) = (
|
||||
std::env::var("SSO_CLIENT_ID"),
|
||||
std::env::var("SSO_CLIENT_SECRET"),
|
||||
@@ -157,7 +124,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
public_url,
|
||||
})
|
||||
} else {
|
||||
tracing::warn!("No OAuth configured (set SSO_CLIENT_ID/SSO_CLIENT_SECRET or GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET)");
|
||||
tracing::warn!("No OAuth configured");
|
||||
None
|
||||
}
|
||||
};
|
||||
@@ -172,13 +139,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
// Health check (public, for k8s probes)
|
||||
.route("/tori/api/health", axum::routing::get(|| async {
|
||||
axum::Json(serde_json::json!({"status": "ok"}))
|
||||
}))
|
||||
// Auth routes are public
|
||||
.nest("/tori/api/auth", api::auth::router(state.clone()))
|
||||
// Protected API routes
|
||||
.nest("/tori/api", api::router(state.clone())
|
||||
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
|
||||
)
|
||||
|
||||
239
src/sink.rs
Normal file
239
src/sink.rs
Normal file
@@ -0,0 +1,239 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||
|
||||
use crate::agent::{PlanStepInfo, WsMessage, ServiceInfo};
|
||||
use crate::state::{AgentState, Artifact};
|
||||
|
||||
/// All updates produced by the agent loop. This is the single output interface
|
||||
/// that decouples the agent logic from DB persistence and WebSocket broadcasting.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum AgentUpdate {
|
||||
PlanUpdate {
|
||||
workflow_id: String,
|
||||
steps: Vec<PlanStepInfo>,
|
||||
},
|
||||
WorkflowStatus {
|
||||
workflow_id: String,
|
||||
status: String,
|
||||
},
|
||||
Activity {
|
||||
workflow_id: String,
|
||||
activity: String,
|
||||
},
|
||||
ExecutionLog {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
tool_name: String,
|
||||
tool_input: String,
|
||||
output: String,
|
||||
status: String,
|
||||
},
|
||||
LlmCallLog {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
phase: String,
|
||||
messages_count: i32,
|
||||
tools_count: i32,
|
||||
tool_calls: String,
|
||||
text_response: String,
|
||||
prompt_tokens: Option<u32>,
|
||||
completion_tokens: Option<u32>,
|
||||
latency_ms: i64,
|
||||
},
|
||||
StateSnapshot {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
state: AgentState,
|
||||
},
|
||||
WorkflowComplete {
|
||||
workflow_id: String,
|
||||
status: String,
|
||||
report: Option<String>,
|
||||
},
|
||||
ArtifactSave {
|
||||
workflow_id: String,
|
||||
step_order: i32,
|
||||
artifact: Artifact,
|
||||
},
|
||||
RequirementUpdate {
|
||||
workflow_id: String,
|
||||
requirement: String,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Manages local services (start_service / stop_service tools).
|
||||
/// Created per-worker or per-agent-loop.
|
||||
pub struct ServiceManager {
|
||||
pub services: RwLock<HashMap<String, ServiceInfo>>,
|
||||
next_port: AtomicU16,
|
||||
}
|
||||
|
||||
impl ServiceManager {
|
||||
pub fn new(start_port: u16) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
services: RwLock::new(HashMap::new()),
|
||||
next_port: AtomicU16::new(start_port),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn allocate_port(&self) -> u16 {
|
||||
self.next_port.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side handler: consumes AgentUpdate from channel, persists to DB and broadcasts to frontend.
|
||||
pub async fn handle_agent_updates(
|
||||
mut rx: mpsc::Receiver<AgentUpdate>,
|
||||
pool: SqlitePool,
|
||||
broadcast_tx: broadcast::Sender<WsMessage>,
|
||||
) {
|
||||
while let Some(update) = rx.recv().await {
|
||||
match update {
|
||||
AgentUpdate::PlanUpdate { workflow_id, steps } => {
|
||||
let _ = broadcast_tx.send(WsMessage::PlanUpdate { workflow_id, steps });
|
||||
}
|
||||
AgentUpdate::WorkflowStatus { ref workflow_id, ref status } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: status.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::Activity { workflow_id, activity } => {
|
||||
let _ = broadcast_tx.send(WsMessage::ActivityUpdate { workflow_id, activity });
|
||||
}
|
||||
AgentUpdate::ExecutionLog { ref workflow_id, step_order, ref tool_name, ref tool_input, ref output, ref status } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO execution_log (id, workflow_id, step_order, tool_name, tool_input, output, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(tool_name)
|
||||
.bind(tool_input)
|
||||
.bind(output)
|
||||
.bind(status)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::StepStatusUpdate {
|
||||
step_id: id,
|
||||
status: status.clone(),
|
||||
output: output.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::LlmCallLog { ref workflow_id, step_order, ref phase, messages_count, tools_count, ref tool_calls, ref text_response, prompt_tokens, completion_tokens, latency_ms } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO llm_call_log (id, workflow_id, step_order, phase, messages_count, tools_count, tool_calls, text_response, prompt_tokens, completion_tokens, latency_ms, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(phase)
|
||||
.bind(messages_count)
|
||||
.bind(tools_count)
|
||||
.bind(tool_calls)
|
||||
.bind(text_response)
|
||||
.bind(prompt_tokens.map(|v| v as i32))
|
||||
.bind(completion_tokens.map(|v| v as i32))
|
||||
.bind(latency_ms as i32)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let entry = crate::db::LlmCallLogEntry {
|
||||
id,
|
||||
workflow_id: workflow_id.clone(),
|
||||
step_order,
|
||||
phase: phase.clone(),
|
||||
messages_count,
|
||||
tools_count,
|
||||
tool_calls: tool_calls.clone(),
|
||||
text_response: text_response.clone(),
|
||||
prompt_tokens: prompt_tokens.map(|v| v as i32),
|
||||
completion_tokens: completion_tokens.map(|v| v as i32),
|
||||
latency_ms: latency_ms as i32,
|
||||
created_at: String::new(),
|
||||
};
|
||||
let _ = broadcast_tx.send(WsMessage::LlmCallLog {
|
||||
workflow_id: workflow_id.clone(),
|
||||
entry,
|
||||
});
|
||||
}
|
||||
AgentUpdate::StateSnapshot { ref workflow_id, step_order, ref state } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let json = serde_json::to_string(state).unwrap_or_default();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO agent_state_snapshots (id, workflow_id, step_order, state_json, created_at) VALUES (?, ?, ?, ?, datetime('now'))"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(&json)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
}
|
||||
AgentUpdate::WorkflowComplete { ref workflow_id, ref status, ref report } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET status = ? WHERE id = ?")
|
||||
.bind(status)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
if let Some(ref r) = report {
|
||||
let _ = sqlx::query("UPDATE workflows SET report = ? WHERE id = ?")
|
||||
.bind(r)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::ReportReady {
|
||||
workflow_id: workflow_id.clone(),
|
||||
});
|
||||
}
|
||||
let _ = broadcast_tx.send(WsMessage::WorkflowStatusUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: status.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::ArtifactSave { ref workflow_id, step_order, ref artifact } => {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO step_artifacts (id, workflow_id, step_order, name, path, artifact_type, description) VALUES (?, ?, ?, ?, ?, ?, ?)"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(workflow_id)
|
||||
.bind(step_order)
|
||||
.bind(&artifact.name)
|
||||
.bind(&artifact.path)
|
||||
.bind(&artifact.artifact_type)
|
||||
.bind(&artifact.description)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
}
|
||||
AgentUpdate::RequirementUpdate { ref workflow_id, ref requirement } => {
|
||||
let _ = sqlx::query("UPDATE workflows SET requirement = ? WHERE id = ?")
|
||||
.bind(requirement)
|
||||
.bind(workflow_id)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
let _ = broadcast_tx.send(WsMessage::RequirementUpdate {
|
||||
workflow_id: workflow_id.clone(),
|
||||
requirement: requirement.clone(),
|
||||
});
|
||||
}
|
||||
AgentUpdate::Error { message } => {
|
||||
let _ = broadcast_tx.send(WsMessage::Error { message });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ use std::path::{Path, PathBuf};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::TemplateRepoConfig;
|
||||
use crate::llm::{ChatMessage, LlmClient};
|
||||
use crate::tools::ExternalToolManager;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -463,42 +462,6 @@ pub fn is_repo_template(template_id: &str) -> bool {
|
||||
template_id.contains('/')
|
||||
}
|
||||
|
||||
// --- LLM template selection ---
|
||||
|
||||
pub async fn select_template(llm: &LlmClient, requirement: &str, repo_cfg: Option<&TemplateRepoConfig>) -> Option<String> {
|
||||
let all = list_all_templates(repo_cfg).await;
|
||||
if all.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let listing: String = all
|
||||
.iter()
|
||||
.map(|t| format!("- id: {}\n 名称: {}\n 描述: {}", t.id, t.name, t.description))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let prompt = format!(
|
||||
"以下是可用的项目模板:\n{}\n\n用户需求:{}\n\n选择最匹配的模板 ID,如果都不合适则回复 none。只回复模板 ID 或 none,不要其他内容。",
|
||||
listing, requirement
|
||||
);
|
||||
|
||||
let response = llm
|
||||
.chat(vec![
|
||||
ChatMessage::system("你是一个模板选择助手。根据用户需求选择最合适的项目模板。只回复模板 ID 或 none。"),
|
||||
ChatMessage::user(&prompt),
|
||||
])
|
||||
.await
|
||||
.ok()?;
|
||||
|
||||
let answer = response.trim().to_lowercase();
|
||||
tracing::info!("Template selection LLM response: '{}' (available: {:?})",
|
||||
answer, all.iter().map(|t| t.id.as_str()).collect::<Vec<_>>());
|
||||
if answer == "none" {
|
||||
return None;
|
||||
}
|
||||
|
||||
all.iter().find(|t| t.id == answer).map(|t| t.id.clone())
|
||||
}
|
||||
|
||||
// --- Template loading ---
|
||||
|
||||
|
||||
@@ -131,3 +131,56 @@ impl WorkerManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Extended protocol for workflow execution ---
|
||||
|
||||
use crate::LlmConfig;
|
||||
use crate::state::AgentState;
|
||||
use crate::llm::Tool;
|
||||
|
||||
/// Messages sent from server to worker.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ServerToWorker {
|
||||
/// Assign a full workflow for execution.
|
||||
#[serde(rename = "workflow_assign")]
|
||||
WorkflowAssign {
|
||||
workflow_id: String,
|
||||
project_id: String,
|
||||
requirement: String,
|
||||
workdir: String,
|
||||
instructions: String,
|
||||
llm_config: LlmConfig,
|
||||
#[serde(default)]
|
||||
initial_state: Option<AgentState>,
|
||||
#[serde(default)]
|
||||
require_plan_approval: bool,
|
||||
#[serde(default)]
|
||||
external_tools: Vec<Tool>,
|
||||
},
|
||||
/// Forward a user comment to the worker executing this workflow.
|
||||
#[serde(rename = "comment")]
|
||||
Comment {
|
||||
workflow_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Messages sent from worker to server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WorkerToServer {
|
||||
/// Worker registration.
|
||||
#[serde(rename = "register")]
|
||||
Register { info: WorkerInfo },
|
||||
/// Script execution result (legacy).
|
||||
#[serde(rename = "result")]
|
||||
Result(WorkerResult),
|
||||
/// Agent update from workflow execution.
|
||||
#[serde(rename = "update")]
|
||||
Update {
|
||||
workflow_id: String,
|
||||
#[serde(flatten)]
|
||||
update: crate::sink::AgentUpdate,
|
||||
},
|
||||
}
|
||||
|
||||
203
src/worker_runner.rs
Normal file
203
src/worker_runner.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
use std::sync::Arc;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::{connect_async, tungstenite::Message};
|
||||
|
||||
use crate::agent::{self, AgentEvent};
|
||||
use crate::exec::LocalExecutor;
|
||||
use crate::llm::LlmClient;
|
||||
use crate::sink::{AgentUpdate, ServiceManager};
|
||||
use crate::worker::{ServerToWorker, WorkerInfo, WorkerToServer};
|
||||
|
||||
fn collect_worker_info(name: &str) -> WorkerInfo {
|
||||
let cpu = std::fs::read_to_string("/proc/cpuinfo")
|
||||
.ok()
|
||||
.and_then(|s| {
|
||||
s.lines()
|
||||
.find(|l| l.starts_with("model name"))
|
||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".into());
|
||||
|
||||
let memory = std::fs::read_to_string("/proc/meminfo")
|
||||
.ok()
|
||||
.and_then(|s| {
|
||||
s.lines()
|
||||
.find(|l| l.starts_with("MemTotal"))
|
||||
.and_then(|l| l.split_whitespace().nth(1))
|
||||
.and_then(|kb| kb.parse::<u64>().ok())
|
||||
.map(|kb| format!("{:.1} GB", kb as f64 / 1_048_576.0))
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".into());
|
||||
|
||||
let gpu = std::process::Command::new("nvidia-smi")
|
||||
.arg("--query-gpu=name")
|
||||
.arg("--format=csv,noheader")
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|o| {
|
||||
if o.status.success() {
|
||||
Some(String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| "none".into());
|
||||
|
||||
WorkerInfo {
|
||||
name: name.to_string(),
|
||||
cpu,
|
||||
memory,
|
||||
gpu,
|
||||
os: std::env::consts::OS.to_string(),
|
||||
kernel: std::process::Command::new("uname")
|
||||
.arg("-r")
|
||||
.output()
|
||||
.ok()
|
||||
.map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
|
||||
.unwrap_or_else(|| "unknown".into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||
tracing::info!("Tori worker '{}' connecting to {}", worker_name, server_url);
|
||||
|
||||
loop {
|
||||
match connect_and_run(server_url, worker_name).await {
|
||||
Ok(()) => {
|
||||
tracing::info!("Worker connection closed, reconnecting in 5s...");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Worker error: {}, reconnecting in 5s...", e);
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_and_run(server_url: &str, worker_name: &str) -> anyhow::Result<()> {
|
||||
let (ws_stream, _) = connect_async(server_url).await?;
|
||||
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||
|
||||
// Register
|
||||
let info = collect_worker_info(worker_name);
|
||||
let register_msg = serde_json::to_string(&WorkerToServer::Register { info })?;
|
||||
ws_tx.send(Message::Text(register_msg.into())).await?;
|
||||
|
||||
// Wait for registration ack
|
||||
while let Some(msg) = ws_rx.next().await {
|
||||
match msg? {
|
||||
Message::Text(text) => {
|
||||
let v: serde_json::Value = serde_json::from_str(&text)?;
|
||||
if v["type"] == "registered" {
|
||||
tracing::info!("Registered as '{}'", v["name"]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Message::Close(_) => anyhow::bail!("Connection closed during registration"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let svc_mgr = ServiceManager::new(9100);
|
||||
let ws_tx = Arc::new(tokio::sync::Mutex::new(ws_tx));
|
||||
|
||||
// Channel for forwarding comments to the running workflow
|
||||
let comment_tx: Arc<tokio::sync::Mutex<Option<mpsc::Sender<AgentEvent>>>> =
|
||||
Arc::new(tokio::sync::Mutex::new(None));
|
||||
|
||||
// Main message loop
|
||||
while let Some(msg) = ws_rx.next().await {
|
||||
let text = match msg? {
|
||||
Message::Text(t) => t,
|
||||
Message::Close(_) => break,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let server_msg: ServerToWorker = match serde_json::from_str(&text) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse server message: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match server_msg {
|
||||
ServerToWorker::WorkflowAssign {
|
||||
workflow_id,
|
||||
project_id,
|
||||
requirement,
|
||||
workdir,
|
||||
instructions,
|
||||
llm_config,
|
||||
initial_state,
|
||||
require_plan_approval,
|
||||
external_tools: _,
|
||||
} => {
|
||||
tracing::info!("Received workflow: {} (project {})", workflow_id, project_id);
|
||||
|
||||
let llm = LlmClient::new(&llm_config);
|
||||
let exec = LocalExecutor::new(None);
|
||||
|
||||
// update channel → serialize → WebSocket
|
||||
let (update_tx, mut update_rx) = mpsc::channel::<AgentUpdate>(64);
|
||||
let ws_tx_clone = ws_tx.clone();
|
||||
let wf_id_clone = workflow_id.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(update) = update_rx.recv().await {
|
||||
let msg = WorkerToServer::Update {
|
||||
workflow_id: wf_id_clone.clone(),
|
||||
update,
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
let mut tx = ws_tx_clone.lock().await;
|
||||
if tx.send(Message::Text(json.into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// event channel for comments
|
||||
let (evt_tx, mut evt_rx) = mpsc::channel::<AgentEvent>(32);
|
||||
*comment_tx.lock().await = Some(evt_tx);
|
||||
|
||||
let _ = tokio::fs::create_dir_all(&workdir).await;
|
||||
|
||||
let result = agent::run_agent_loop(
|
||||
&llm, &exec, &update_tx, &mut evt_rx,
|
||||
&project_id, &workflow_id, &requirement, &workdir, &svc_mgr,
|
||||
&instructions, initial_state, None, require_plan_approval,
|
||||
).await;
|
||||
|
||||
let final_status = if result.is_ok() { "done" } else { "failed" };
|
||||
if let Err(e) = &result {
|
||||
tracing::error!("Workflow {} failed: {}", workflow_id, e);
|
||||
let _ = update_tx.send(AgentUpdate::Error {
|
||||
message: format!("Agent error: {}", e),
|
||||
}).await;
|
||||
}
|
||||
|
||||
let _ = update_tx.send(AgentUpdate::WorkflowComplete {
|
||||
workflow_id: workflow_id.clone(),
|
||||
status: final_status.into(),
|
||||
report: None,
|
||||
}).await;
|
||||
|
||||
*comment_tx.lock().await = None;
|
||||
tracing::info!("Workflow {} completed: {}", workflow_id, final_status);
|
||||
}
|
||||
|
||||
ServerToWorker::Comment { workflow_id, content } => {
|
||||
if let Some(ref tx) = *comment_tx.lock().await {
|
||||
let _ = tx.send(AgentEvent::Comment {
|
||||
workflow_id,
|
||||
content,
|
||||
}).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user