Files
tori/src/main.rs
Fam Zheng decabc0e8a refactor: server no longer runs agent loop or LLM
- Remove agent_loop from server (was ~400 lines) — server dispatches to workers
- AgentManager simplified to pure dispatcher (send_event → worker)
- Remove LLM config requirement from server (workers bring their own via config.yaml)
- Remove process_feedback, build_feedback_tools from server
- Remove chat API endpoint (LLM on workers only)
- Remove service proxy (services run on workers)
- Worker reads LLM config from its own config.yaml
- ws_worker.rs handles WorkerToServer::Update messages (DB + broadcast)
- Verified locally: tori server + tori worker connect and register
2026-04-06 13:18:21 +01:00

215 lines
7.3 KiB
Rust

use std::sync::Arc;
use axum::Router;
use clap::{Parser, Subcommand};
use sqlx::sqlite::SqlitePool;
use tower_http::cors::CorsLayer;
use tower_http::services::{ServeDir, ServeFile};
use tori::{agent, api, db, kb, template, timer, worker, worker_runner, ws, ws_worker};
use tori::{AppState, Config};
// Top-level command-line interface, parsed with `Cli::parse()` in `main`.
// NOTE: plain `//` comments are used deliberately on clap-derive items —
// `///` doc comments would be turned into `--help` text by clap.
#[derive(Parser)]
#[command(name = "tori", about = "Tori AI agent orchestration")]
struct Cli {
// Which mode to run in; see `Command`.
#[command(subcommand)]
command: Command,
}
// Subcommands selected on the command line. `Server` is dispatched to
// `run_server`; `Worker` loads its own config file and calls `worker_runner::run`.
// NOTE: the `///` lines below are clap help strings (runtime text), so they are
// left untouched; explanatory notes use `//`.
#[derive(Subcommand)]
enum Command {
/// Start the API server
Server,
/// Start a worker that connects to the server
Worker {
// Default targets the worker WebSocket route that `run_server` nests at
// `/ws/tori/workers`, on the local host and port 3000.
/// Server WebSocket URL
#[arg(long, env = "TORI_SERVER", default_value = "ws://127.0.0.1:3000/ws/tori/workers")]
server: String,
// When omitted, `main` falls back to the machine hostname, then "worker-1".
/// Worker name
#[arg(long, env = "TORI_WORKER_NAME")]
name: Option<String>,
// Parsed as `Config` in `main`; its `llm` section is required in worker mode.
// NOTE(review): unlike the other options this one has no `env = ...` fallback.
/// Config file path (for LLM settings)
#[arg(long, default_value = "config.yaml")]
config: String,
},
}
/// Entry point: parse the command line and run either the API server or a worker.
///
/// Logging honours `RUST_LOG` when it is set and valid; otherwise it falls back to
/// `tori=debug,tower_http=debug`.
///
/// # Errors
/// In worker mode, returns an error if the config file cannot be read or parsed,
/// or if it has no `llm` section. Errors from `run_server` and
/// `worker_runner::run` are propagated unchanged.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Let operators tune verbosity via RUST_LOG without rebuilding; keep the old
    // directives as the default so behaviour is unchanged when it is unset.
    let filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("tori=debug,tower_http=debug"));
    tracing_subscriber::fmt().with_env_filter(filter).init();

    let cli = Cli::parse();
    match cli.command {
        Command::Server => run_server().await,
        Command::Worker { server, name, config } => {
            let name = name.unwrap_or_else(|| {
                hostname::get()
                    .map(|h| h.to_string_lossy().into_owned())
                    .unwrap_or_else(|_| "worker-1".to_string())
            });
            // Configuration problems are user errors, not bugs: return them
            // instead of panicking so the process exits cleanly with a message.
            let config_str = std::fs::read_to_string(&config)
                .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", config, e))?;
            let cfg: Config = serde_yaml::from_str(&config_str)
                .map_err(|e| anyhow::anyhow!("Failed to parse {}: {}", config, e))?;
            let llm_config = cfg.llm.ok_or_else(|| {
                anyhow::anyhow!("LLM config required in {} for worker mode", config)
            })?;
            worker_runner::run(&server, &name, &llm_config).await
        }
    }
}
async fn run_server() -> anyhow::Result<()> {
let config_str = std::fs::read_to_string("config.yaml")
.expect("Failed to read config.yaml");
let config: Config = serde_yaml::from_str(&config_str)
.expect("Failed to parse config.yaml");
let database = db::Database::new(&config.database.path).await?;
database.migrate().await?;
let kb_arc = match kb::KbManager::new(database.pool.clone()) {
Ok(kb) => {
tracing::info!("KB manager initialized");
Some(Arc::new(kb))
}
Err(e) => {
tracing::warn!("KB manager init failed (will retry on use): {}", e);
None
}
};
if let Some(ref repo_cfg) = config.template_repo {
template::ensure_repo_ready(repo_cfg).await;
}
let worker_mgr = worker::WorkerManager::new();
let agent_mgr = agent::AgentManager::new(
database.pool.clone(),
worker_mgr.clone(),
);
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
resume_workflows(database.pool.clone(), agent_mgr.clone()).await;
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
let auth_config = {
let jwt_secret = std::env::var("JWT_SECRET")
.unwrap_or_else(|_| uuid::Uuid::new_v4().to_string());
let public_url = std::env::var("PUBLIC_URL")
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
if let (Ok(id), Ok(secret)) = (
std::env::var("SSO_CLIENT_ID"),
std::env::var("SSO_CLIENT_SECRET"),
) {
tracing::info!("TikTok SSO enabled (public_url={})", public_url);
Some(api::auth::AuthConfig {
provider: api::auth::OAuthProvider::TikTokSso {
client_id: id,
client_secret: secret,
},
jwt_secret,
public_url,
})
} else if let (Ok(id), Ok(secret)) = (
std::env::var("GOOGLE_CLIENT_ID"),
std::env::var("GOOGLE_CLIENT_SECRET"),
) {
tracing::info!("Google OAuth enabled (public_url={})", public_url);
Some(api::auth::AuthConfig {
provider: api::auth::OAuthProvider::Google {
client_id: id,
client_secret: secret,
},
jwt_secret,
public_url,
})
} else {
tracing::warn!("No OAuth configured");
None
}
};
let ws_pool = database.pool.clone();
let state = Arc::new(AppState {
db: database,
config: config.clone(),
agent_mgr: agent_mgr.clone(),
kb: kb_arc,
obj_root: obj_root.clone(),
auth: auth_config,
});
let app = Router::new()
.route("/tori/api/health", axum::routing::get(|| async {
axum::Json(serde_json::json!({"status": "ok"}))
}))
.nest("/tori/api/auth", api::auth::router(state.clone()))
.nest("/tori/api", api::router(state.clone())
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
)
.nest("/api/obj", api::obj::router(obj_root.clone()))
.route("/api/obj/", axum::routing::get({
let r = obj_root;
move || api::obj::root_listing(r)
}))
.nest("/ws/tori/workers", {
let pool = ws_pool;
let agent_mgr_for_ws = agent_mgr.clone();
let broadcast_fn: Arc<dyn Fn(&str) -> tokio::sync::broadcast::Sender<agent::WsMessage> + Send + Sync> =
Arc::new(move |pid: &str| agent_mgr_for_ws.get_broadcast_sender(pid));
ws_worker::router(worker_mgr, pool, broadcast_fn)
})
.nest("/ws/tori", ws::router(agent_mgr))
.nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")))
.route("/", axum::routing::get(|| async {
axum::response::Redirect::permanent("/tori/")
}))
.layer(CorsLayer::permissive());
let addr = format!("{}:{}", &config.server.host, config.server.port);
tracing::info!("Tori server listening on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn resume_workflows(pool: SqlitePool, agent_mgr: Arc<agent::AgentManager>) {
let rows: Vec<(String, String, String)> = match sqlx::query_as(
"SELECT w.id, w.project_id, w.requirement FROM workflows w \
JOIN projects p ON w.project_id = p.id \
WHERE w.status IN ('pending', 'planning', 'executing') \
AND p.deleted = 0 \
ORDER BY w.created_at ASC"
)
.fetch_all(&pool)
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!("Failed to query incomplete workflows: {}", e);
return;
}
};
if rows.is_empty() {
tracing::info!("No incomplete workflows to resume");
return;
}
tracing::info!("Resuming {} incomplete workflow(s)", rows.len());
for (workflow_id, project_id, requirement) in rows {
tracing::info!("Resuming workflow {} (project {})", workflow_id, project_id);
agent_mgr.send_event(&project_id, agent::AgentEvent::NewRequirement {
workflow_id,
requirement,
template_id: None,
}).await;
}
}