use std::sync::Arc; use axum::Router; use clap::{Parser, Subcommand}; use sqlx::sqlite::SqlitePool; use tower_http::cors::CorsLayer; use tower_http::services::{ServeDir, ServeFile}; use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker}; use tori::{AppState, Config}; #[derive(Parser)] #[command(name = "tori", about = "Tori AI agent orchestration")] struct Cli { #[command(subcommand)] command: Command, } #[derive(Subcommand)] enum Command { /// Start the API server Server, /// Start a worker that connects to the server Worker { /// Server WebSocket URL #[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")] server: String, /// Worker name #[arg(long, env = "TORI_WORKER_NAME")] name: Option, /// Config file path (for LLM settings) #[arg(long, default_value = "config.yaml")] config: String, }, } #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter("tori=debug,tower_http=debug") .init(); let cli = Cli::parse(); match cli.command { Command::Server => run_server().await, Command::Worker { server, name, config } => { let name = name.unwrap_or_else(|| { hostname::get() .map(|h| h.to_string_lossy().to_string()) .unwrap_or_else(|_| "worker-1".to_string()) }); let config_str = std::fs::read_to_string(&config) .unwrap_or_else(|e| panic!("Failed to read {}: {}", config, e)); let cfg: Config = serde_yaml::from_str(&config_str) .unwrap_or_else(|e| panic!("Failed to parse {}: {}", config, e)); let llm_config = cfg.llm .unwrap_or_else(|| panic!("LLM config required in {} for worker mode", config)); worker_runner::run(&server, &name, &llm_config).await } } } async fn run_server() -> anyhow::Result<()> { let config_str = std::fs::read_to_string("config.yaml") .expect("Failed to read config.yaml"); let config: Config = serde_yaml::from_str(&config_str) .expect("Failed to parse config.yaml"); let database = db::Database::new(&config.database.path).await?; database.migrate().await?; let kb_arc = match kb::KbManager::new(database.pool.clone()) { Ok(kb) => { tracing::info!("KB manager initialized"); Some(Arc::new(kb)) } Err(e) => { tracing::warn!("KB manager init failed (will retry on use): {}", e); None } }; if let Some(ref repo_cfg) = config.template_repo { template::ensure_repo_ready(repo_cfg).await; } let worker_mgr = worker::WorkerManager::new(); let agent_mgr = agent::AgentManager::new( database.pool.clone(), worker_mgr.clone(), ); timer::start_timer_runner(database.pool.clone(), agent_mgr.clone()); resume_workflows(database.pool.clone(), agent_mgr.clone()).await; let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string()); let auth_config = { let jwt_secret = std::env::var("JWT_SECRET") .unwrap_or_else(|_| uuid::Uuid::new_v4().to_string()); let public_url = std::env::var("PUBLIC_URL") .unwrap_or_else(|_| "https://tori.euphon.cloud".to_string()); if let (Ok(id), Ok(secret)) = ( std::env::var("SSO_CLIENT_ID"), std::env::var("SSO_CLIENT_SECRET"), ) { tracing::info!("TikTok SSO enabled (public_url={})", public_url); Some(api::auth::AuthConfig { provider: api::auth::OAuthProvider::TikTokSso { client_id: id, client_secret: secret, }, jwt_secret, public_url, }) } else if let (Ok(id), Ok(secret)) = ( std::env::var("GOOGLE_CLIENT_ID"), std::env::var("GOOGLE_CLIENT_SECRET"), ) { tracing::info!("Google OAuth enabled (public_url={})", public_url); Some(api::auth::AuthConfig { provider: api::auth::OAuthProvider::Google { client_id: id, client_secret: secret, }, jwt_secret, public_url, }) } else { tracing::warn!("No OAuth configured"); None } }; let ws_pool = database.pool.clone(); let state = Arc::new(AppState { db: database, config: config.clone(), agent_mgr: agent_mgr.clone(), kb: kb_arc, obj_root: obj_root.clone(), auth: auth_config, }); let app = Router::new() .route("/tori/api/health", axum::routing::get(|| async { axum::Json(serde_json::json!({"status": "ok"})) })) .nest("/tori/api/auth", api::auth::router(state.clone())) .nest("/tori/api", api::router(state.clone()) .layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth)) ) .nest("/api/obj", api::obj::router(obj_root.clone())) .route("/api/obj/", axum::routing::get({ let r = obj_root; move || api::obj::root_listing(r) })) .nest("/ws/tori/workers", { let pool = ws_pool; let agent_mgr_for_ws = agent_mgr.clone(); let broadcast_fn: Arc tokio::sync::broadcast::Sender + Send + Sync> = Arc::new(move |pid: &str| agent_mgr_for_ws.get_broadcast_sender(pid)); ws_worker::router(worker_mgr, pool, broadcast_fn) }) .nest("/ws/tori", ws::router(agent_mgr)) .nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html"))) .route("/", axum::routing::get(|| async { axum::response::Redirect::permanent("/tori/") })) .layer(CorsLayer::permissive()); let addr = format!("{}:{}", &config.server.host, config.server.port); tracing::info!("Tori server listening on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await?; axum::serve(listener, app).await?; Ok(()) } async fn resume_workflows(pool: SqlitePool, agent_mgr: Arc) { let rows: Vec<(String, String, String)> = match sqlx::query_as( "SELECT w.id, w.project_id, w.requirement FROM workflows w \ JOIN projects p ON w.project_id = p.id \ WHERE w.status IN ('pending', 'planning', 'executing') \ AND p.deleted = 0 \ ORDER BY w.created_at ASC" ) .fetch_all(&pool) .await { Ok(r) => r, Err(e) => { tracing::error!("Failed to query incomplete workflows: {}", e); return; } }; if rows.is_empty() { tracing::info!("No incomplete workflows to resume"); return; } tracing::info!("Resuming {} incomplete workflow(s)", rows.len()); for (workflow_id, project_id, requirement) in rows { tracing::info!("Resuming workflow {} (project {})", workflow_id, project_id); agent_mgr.send_event(&project_id, agent::AgentEvent::NewRequirement { workflow_id, requirement, template_id: None, }).await; } }