- Remove agent_loop from server (was ~400 lines) — server dispatches to workers
- AgentManager simplified to pure dispatcher (send_event → worker)
- Remove LLM config requirement from server (workers bring their own via config.yaml)
- Remove process_feedback, build_feedback_tools from server
- Remove chat API endpoint (LLM on workers only)
- Remove service proxy (services run on workers)
- Worker reads LLM config from its own config.yaml
- ws_worker.rs handles WorkerToServer::Update messages (DB + broadcast)
- Verified locally: tori server + tori worker connect and register
215 lines
7.3 KiB
Rust
215 lines
7.3 KiB
Rust
use std::sync::Arc;
|
|
use axum::Router;
|
|
use clap::{Parser, Subcommand};
|
|
use sqlx::sqlite::SqlitePool;
|
|
use tower_http::cors::CorsLayer;
|
|
use tower_http::services::{ServeDir, ServeFile};
|
|
|
|
use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker};
|
|
use tori::{AppState, Config};
|
|
|
|
#[derive(Parser)]
|
|
#[command(name = "tori", about = "Tori AI agent orchestration")]
|
|
struct Cli {
|
|
#[command(subcommand)]
|
|
command: Command,
|
|
}
|
|
|
|
#[derive(Subcommand)]
|
|
enum Command {
|
|
/// Start the API server
|
|
Server,
|
|
/// Start a worker that connects to the server
|
|
Worker {
|
|
/// Server WebSocket URL
|
|
#[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")]
|
|
server: String,
|
|
/// Worker name
|
|
#[arg(long, env = "TORI_WORKER_NAME")]
|
|
name: Option<String>,
|
|
/// Config file path (for LLM settings)
|
|
#[arg(long, default_value = "config.yaml")]
|
|
config: String,
|
|
},
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter("tori=debug,tower_http=debug")
|
|
.init();
|
|
|
|
let cli = Cli::parse();
|
|
|
|
match cli.command {
|
|
Command::Server => run_server().await,
|
|
Command::Worker { server, name, config } => {
|
|
let name = name.unwrap_or_else(|| {
|
|
hostname::get()
|
|
.map(|h| h.to_string_lossy().to_string())
|
|
.unwrap_or_else(|_| "worker-1".to_string())
|
|
});
|
|
let config_str = std::fs::read_to_string(&config)
|
|
.unwrap_or_else(|e| panic!("Failed to read {}: {}", config, e));
|
|
let cfg: Config = serde_yaml::from_str(&config_str)
|
|
.unwrap_or_else(|e| panic!("Failed to parse {}: {}", config, e));
|
|
let llm_config = cfg.llm
|
|
.unwrap_or_else(|| panic!("LLM config required in {} for worker mode", config));
|
|
worker_runner::run(&server, &name, &llm_config).await
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn run_server() -> anyhow::Result<()> {
|
|
let config_str = std::fs::read_to_string("config.yaml")
|
|
.expect("Failed to read config.yaml");
|
|
let config: Config = serde_yaml::from_str(&config_str)
|
|
.expect("Failed to parse config.yaml");
|
|
|
|
let database = db::Database::new(&config.database.path).await?;
|
|
database.migrate().await?;
|
|
|
|
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
|
|
Ok(kb) => {
|
|
tracing::info!("KB manager initialized");
|
|
Some(Arc::new(kb))
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("KB manager init failed (will retry on use): {}", e);
|
|
None
|
|
}
|
|
};
|
|
|
|
if let Some(ref repo_cfg) = config.template_repo {
|
|
template::ensure_repo_ready(repo_cfg).await;
|
|
}
|
|
|
|
let worker_mgr = worker::WorkerManager::new();
|
|
|
|
let agent_mgr = agent::AgentManager::new(
|
|
database.pool.clone(),
|
|
worker_mgr.clone(),
|
|
);
|
|
|
|
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
|
|
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
|
|
|
|
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
|
|
|
|
let auth_config = {
|
|
let jwt_secret = std::env::var("JWT_SECRET")
|
|
.unwrap_or_else(|_| uuid::Uuid::new_v4().to_string());
|
|
let public_url = std::env::var("PUBLIC_URL")
|
|
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
|
|
|
|
if let (Ok(id), Ok(secret)) = (
|
|
std::env::var("SSO_CLIENT_ID"),
|
|
std::env::var("SSO_CLIENT_SECRET"),
|
|
) {
|
|
tracing::info!("TikTok SSO enabled (public_url={})", public_url);
|
|
Some(api::auth::AuthConfig {
|
|
provider: api::auth::OAuthProvider::TikTokSso {
|
|
client_id: id,
|
|
client_secret: secret,
|
|
},
|
|
jwt_secret,
|
|
public_url,
|
|
})
|
|
} else if let (Ok(id), Ok(secret)) = (
|
|
std::env::var("GOOGLE_CLIENT_ID"),
|
|
std::env::var("GOOGLE_CLIENT_SECRET"),
|
|
) {
|
|
tracing::info!("Google OAuth enabled (public_url={})", public_url);
|
|
Some(api::auth::AuthConfig {
|
|
provider: api::auth::OAuthProvider::Google {
|
|
client_id: id,
|
|
client_secret: secret,
|
|
},
|
|
jwt_secret,
|
|
public_url,
|
|
})
|
|
} else {
|
|
tracing::warn!("No OAuth configured");
|
|
None
|
|
}
|
|
};
|
|
|
|
let ws_pool = database.pool.clone();
|
|
let state = Arc::new(AppState {
|
|
db: database,
|
|
config: config.clone(),
|
|
agent_mgr: agent_mgr.clone(),
|
|
kb: kb_arc,
|
|
obj_root: obj_root.clone(),
|
|
auth: auth_config,
|
|
});
|
|
|
|
let app = Router::new()
|
|
.route("/tori/api/health", axum::routing::get(|| async {
|
|
axum::Json(serde_json::json!({"status": "ok"}))
|
|
}))
|
|
.nest("/tori/api/auth", api::auth::router(state.clone()))
|
|
.nest("/tori/api", api::router(state.clone())
|
|
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
|
|
)
|
|
.nest("/api/obj", api::obj::router(obj_root.clone()))
|
|
.route("/api/obj/", axum::routing::get({
|
|
let r = obj_root;
|
|
move || api::obj::root_listing(r)
|
|
}))
|
|
.nest("/ws/tori/workers", {
|
|
let pool = ws_pool;
|
|
let agent_mgr_for_ws = agent_mgr.clone();
|
|
let broadcast_fn: Arc<dyn Fn(&str) -> tokio::sync::broadcast::Sender<agent::WsMessage> + Send + Sync> =
|
|
Arc::new(move |pid: &str| agent_mgr_for_ws.get_broadcast_sender(pid));
|
|
ws_worker::router(worker_mgr, pool, broadcast_fn)
|
|
})
|
|
.nest("/ws/tori", ws::router(agent_mgr))
|
|
.nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")))
|
|
.route("/", axum::routing::get(|| async {
|
|
axum::response::Redirect::permanent("/tori/")
|
|
}))
|
|
.layer(CorsLayer::permissive());
|
|
|
|
let addr = format!("{}:{}", &config.server.host, config.server.port);
|
|
tracing::info!("Tori server listening on {}", addr);
|
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
|
axum::serve(listener, app).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn resume_workflows(pool: SqlitePool, agent_mgr: Arc<agent::AgentManager>) {
|
|
let rows: Vec<(String, String, String)> = match sqlx::query_as(
|
|
"SELECT w.id, w.project_id, w.requirement FROM workflows w \
|
|
JOIN projects p ON w.project_id = p.id \
|
|
WHERE w.status IN ('pending', 'planning', 'executing') \
|
|
AND p.deleted = 0 \
|
|
ORDER BY w.created_at ASC"
|
|
)
|
|
.fetch_all(&pool)
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
tracing::error!("Failed to query incomplete workflows: {}", e);
|
|
return;
|
|
}
|
|
};
|
|
|
|
if rows.is_empty() {
|
|
tracing::info!("No incomplete workflows to resume");
|
|
return;
|
|
}
|
|
|
|
tracing::info!("Resuming {} incomplete workflow(s)", rows.len());
|
|
for (workflow_id, project_id, requirement) in rows {
|
|
tracing::info!("Resuming workflow {} (project {})", workflow_id, project_id);
|
|
agent_mgr.send_event(&project_id, agent::AgentEvent::NewRequirement {
|
|
workflow_id,
|
|
requirement,
|
|
template_id: None,
|
|
}).await;
|
|
}
|
|
}
|