feat: add Google OAuth, remote worker system, and file browser
- Google OAuth login with JWT session cookies; per-user project isolation
- Remote worker registration via WebSocket; `execute_on_worker` / `list_workers` agent tools
- File browser UI in the workflow view; file upload/download API
- Deploy script switched to a local build; added `tori.euphon.cloud` ingress
This commit is contained in:
63
src/agent.rs
63
src/agent.rs
@@ -10,6 +10,7 @@ use crate::llm::{LlmClient, ChatMessage, Tool, ToolFunction};
|
||||
use crate::exec::LocalExecutor;
|
||||
use crate::template::{self, LoadedTemplate};
|
||||
use crate::tools::ExternalToolManager;
|
||||
use crate::worker::WorkerManager;
|
||||
use crate::LlmConfig;
|
||||
|
||||
use crate::state::{AgentState, AgentPhase, Artifact, Step, StepStatus, StepResult, StepResultStatus, check_scratchpad_size};
|
||||
@@ -80,6 +81,7 @@ pub struct AgentManager {
|
||||
template_repo: Option<crate::TemplateRepoConfig>,
|
||||
kb: Option<Arc<crate::kb::KbManager>>,
|
||||
jwt_private_key_path: Option<String>,
|
||||
pub worker_mgr: Arc<WorkerManager>,
|
||||
}
|
||||
|
||||
impl AgentManager {
|
||||
@@ -89,6 +91,7 @@ impl AgentManager {
|
||||
template_repo: Option<crate::TemplateRepoConfig>,
|
||||
kb: Option<Arc<crate::kb::KbManager>>,
|
||||
jwt_private_key_path: Option<String>,
|
||||
worker_mgr: Arc<WorkerManager>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
agents: RwLock::new(HashMap::new()),
|
||||
@@ -100,6 +103,7 @@ impl AgentManager {
|
||||
template_repo,
|
||||
kb,
|
||||
jwt_private_key_path,
|
||||
worker_mgr,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -755,6 +759,19 @@ fn build_step_tools() -> Vec<Tool> {
|
||||
})),
|
||||
tool_kb_search(),
|
||||
tool_kb_read(),
|
||||
make_tool("list_workers", "列出所有已注册的远程 worker 节点及其硬件/软件信息(CPU、内存、GPU、OS、内核)。", serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
})),
|
||||
make_tool("execute_on_worker", "在指定的远程 worker 上执行脚本。脚本以 bash 执行。可以通过 HTTP 访问项目文件:GET/POST /api/obj/{project_id}/files/{path}", serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"worker": { "type": "string", "description": "Worker 名称(从 list_workers 获取)" },
|
||||
"script": { "type": "string", "description": "要执行的 bash 脚本内容" },
|
||||
"timeout": { "type": "integer", "description": "超时秒数(默认 300)", "default": 300 }
|
||||
},
|
||||
"required": ["worker", "script"]
|
||||
})),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1500,6 +1517,52 @@ async fn run_step_loop(
|
||||
step_chat_history.push(ChatMessage::tool_result(&tc.id, &result));
|
||||
}
|
||||
|
||||
"list_workers" => {
|
||||
let _ = broadcast_tx.send(WsMessage::ActivityUpdate {
|
||||
workflow_id: workflow_id.to_string(),
|
||||
activity: format!("步骤 {} — 列出 Workers", step_order),
|
||||
});
|
||||
let workers = mgr.worker_mgr.list().await;
|
||||
let result = if workers.is_empty() {
|
||||
"没有已注册的 worker。".to_string()
|
||||
} else {
|
||||
let items: Vec<String> = workers.iter().map(|(name, info)| {
|
||||
format!("- {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
|
||||
name, info.cpu, info.memory, info.gpu, info.os, info.kernel)
|
||||
}).collect();
|
||||
format!("已注册的 workers:\n{}", items.join("\n"))
|
||||
};
|
||||
log_execution(pool, broadcast_tx, workflow_id, step_order, "list_workers", "", &result, "done").await;
|
||||
step_chat_history.push(ChatMessage::tool_result(&tc.id, &result));
|
||||
}
|
||||
|
||||
"execute_on_worker" => {
|
||||
let worker_name = args.get("worker").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let script = args.get("script").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let timeout = args.get("timeout").and_then(|v| v.as_u64()).unwrap_or(300);
|
||||
let _ = broadcast_tx.send(WsMessage::ActivityUpdate {
|
||||
workflow_id: workflow_id.to_string(),
|
||||
activity: format!("步骤 {} — 在 {} 上执行脚本", step_order, worker_name),
|
||||
});
|
||||
let result = match mgr.worker_mgr.execute(worker_name, script, timeout).await {
|
||||
Ok(wr) => {
|
||||
let mut out = String::new();
|
||||
out.push_str(&format!("exit_code: {}\n", wr.exit_code));
|
||||
if !wr.stdout.is_empty() {
|
||||
out.push_str(&format!("stdout:\n{}\n", truncate_str(&wr.stdout, 8192)));
|
||||
}
|
||||
if !wr.stderr.is_empty() {
|
||||
out.push_str(&format!("stderr:\n{}\n", truncate_str(&wr.stderr, 4096)));
|
||||
}
|
||||
out
|
||||
}
|
||||
Err(e) => format!("Error: {}", e),
|
||||
};
|
||||
let status = if result.starts_with("Error:") { "failed" } else { "done" };
|
||||
log_execution(pool, broadcast_tx, workflow_id, step_order, "execute_on_worker", &tc.function.arguments, &result, status).await;
|
||||
step_chat_history.push(ChatMessage::tool_result(&tc.id, &result));
|
||||
}
|
||||
|
||||
// External tools
|
||||
name if external_tools.as_ref().is_some_and(|e| e.has_tool(name)) => {
|
||||
let _ = broadcast_tx.send(WsMessage::ActivityUpdate {
|
||||
|
||||
314
src/api/auth.rs
314
src/api/auth.rs
@@ -1,29 +1,51 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
extract::{Query, Request, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::post,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Redirect, Response},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use jsonwebtoken::{encode, EncodingKey, Header, Algorithm};
|
||||
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Algorithm, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
// Session cookie holding the signed JWT.
const COOKIE_NAME: &str = "tori_session";
// Short-lived cookie carrying the OAuth CSRF nonce set by /login.
const CSRF_COOKIE: &str = "tori_session_csrf";
// Cookies are valid for the whole site.
const COOKIE_PATH: &str = "/";
// Session lifetime: 7 days, in seconds.
const SESSION_SECS: i64 = 7 * 86400;
|
||||
|
||||
/// Runtime configuration for Google OAuth login + JWT session cookies.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Google OAuth client ID.
    pub google_client_id: String,
    /// Google OAuth client secret (used only in the server-side code exchange).
    pub google_client_secret: String,
    /// HMAC secret used to sign and verify session JWTs.
    pub jwt_secret: String,
    /// Externally reachable base URL; used to build the OAuth redirect URI.
    pub public_url: String,
}
|
||||
|
||||
/// Claims carried in the session JWT (stored in the `tori_session` cookie).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Stable user id, formatted as "google:<google-sub>" by the callback.
    pub sub: String,
    /// Email address reported by Google at login time.
    pub email: String,
    /// Expiry as a Unix timestamp in seconds.
    pub exp: i64,
}
|
||||
|
||||
// --- EC key token generation (for agent/API use) ---
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TokenResponse {
|
||||
struct EcTokenResponse {
|
||||
token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenRequest {
|
||||
/// Subject claim (e.g. "oseng", "tori-agent")
|
||||
struct EcTokenRequest {
|
||||
#[serde(default = "default_sub")]
|
||||
sub: String,
|
||||
/// Token validity in seconds (default: 300)
|
||||
#[serde(default = "default_ttl")]
|
||||
ttl_secs: u64,
|
||||
}
|
||||
@@ -37,7 +59,7 @@ fn default_ttl() -> u64 {
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Claims {
|
||||
struct EcClaims {
|
||||
sub: String,
|
||||
iat: usize,
|
||||
exp: usize,
|
||||
@@ -45,7 +67,7 @@ struct Claims {
|
||||
|
||||
async fn generate_token(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<TokenRequest>,
|
||||
Json(body): Json<EcTokenRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let privkey_pem = match &state.config.jwt_private_key {
|
||||
Some(path) => match std::fs::read_to_string(path) {
|
||||
@@ -69,7 +91,7 @@ async fn generate_token(
|
||||
};
|
||||
|
||||
let now = chrono::Utc::now().timestamp() as usize;
|
||||
let claims = Claims {
|
||||
let claims = EcClaims {
|
||||
sub: body.sub,
|
||||
iat: now,
|
||||
exp: now + body.ttl_secs as usize,
|
||||
@@ -77,7 +99,7 @@ async fn generate_token(
|
||||
|
||||
let header = Header::new(Algorithm::ES256);
|
||||
match encode(&header, &claims, &key) {
|
||||
Ok(token) => Json(TokenResponse {
|
||||
Ok(token) => Json(EcTokenResponse {
|
||||
token,
|
||||
expires_in: body.ttl_secs,
|
||||
}).into_response(),
|
||||
@@ -88,8 +110,276 @@ async fn generate_token(
|
||||
}
|
||||
}
|
||||
|
||||
// --- Google OAuth ---
|
||||
|
||||
/// Auth routes: OAuth login/callback, session introspection, logout, and
/// EC-key token minting. (Mount point appears to be `/tori/api/auth` based
/// on the redirect URI built in `login` — confirm against main router.)
pub fn router(state: Arc<AppState>) -> Router {
    Router::new()
        .route("/login", get(login))
        .route("/callback", get(callback))
        .route("/me", get(me))
        .route("/logout", post(logout))
        .route("/token", post(generate_token))
        .with_state(state)
}
|
||||
|
||||
fn build_cookie(name: &str, value: String, max_age_secs: i64) -> Cookie<'static> {
|
||||
let mut c = Cookie::new(name.to_owned(), value);
|
||||
c.set_path(COOKIE_PATH);
|
||||
c.set_http_only(true);
|
||||
c.set_secure(true);
|
||||
c.set_same_site(SameSite::Lax);
|
||||
c.set_max_age(Some(time::Duration::seconds(max_age_secs)));
|
||||
c
|
||||
}
|
||||
|
||||
fn clear_cookie(name: &str) -> Cookie<'static> {
|
||||
build_cookie(name, String::new(), 0)
|
||||
}
|
||||
|
||||
/// GET /login — start the Google OAuth authorization-code flow.
/// Stores a random CSRF nonce in a short-lived cookie and redirects the
/// browser to Google's consent screen. Returns 503 when auth is disabled.
async fn login(State(state): State<Arc<AppState>>) -> Response {
    let auth = match &state.auth {
        Some(a) => a,
        None => return (StatusCode::SERVICE_UNAVAILABLE, "Auth not configured").into_response(),
    };

    // Random nonce echoed back through the OAuth `state` parameter and
    // verified in `callback`.
    let csrf = uuid::Uuid::new_v4().to_string();
    let redirect_uri = format!("{}/tori/api/auth/callback", auth.public_url);
    let url = format!(
        "https://accounts.google.com/o/oauth2/v2/auth?\
        client_id={}&redirect_uri={}&response_type=code&\
        scope=openid%20email%20profile&access_type=online&state={}",
        pct_encode(&auth.google_client_id),
        pct_encode(&redirect_uri),
        pct_encode(&csrf),
    );

    // The CSRF cookie lives 5 minutes — enough to finish the round trip.
    let jar = CookieJar::new().add(build_cookie(CSRF_COOKIE, csrf, 300));
    (jar, Redirect::temporary(&url)).into_response()
}
|
||||
|
||||
/// Query parameters Google sends to the OAuth callback.
#[derive(Deserialize)]
struct CallbackParams {
    /// Authorization code to exchange for tokens.
    code: String,
    /// Echo of the CSRF nonce set by `/login`.
    state: String,
}

/// Subset of Google's token-endpoint response we use.
#[derive(Deserialize)]
struct TokenResponse {
    id_token: Option<String>,
}

/// Fields read from the Google id_token payload.
#[derive(Deserialize)]
struct GoogleUserInfo {
    /// Google's stable account identifier.
    sub: String,
    email: String,
    #[serde(default)]
    name: String,
    #[serde(default)]
    picture: String,
}

/// GET /callback — second leg of the OAuth flow. Verifies the CSRF state,
/// exchanges the authorization code for tokens, upserts the user row, then
/// issues the session JWT cookie and redirects back into the app.
async fn callback(
    State(state): State<Arc<AppState>>,
    jar: CookieJar,
    Query(params): Query<CallbackParams>,
) -> Response {
    let auth = match &state.auth {
        Some(a) => a,
        None => return (StatusCode::SERVICE_UNAVAILABLE, "Auth not configured").into_response(),
    };

    // CSRF check: `state` must echo the nonce stored by /login.
    match jar.get(CSRF_COOKIE) {
        Some(c) if c.value() == params.state => {}
        _ => return (StatusCode::BAD_REQUEST, "Invalid state parameter").into_response(),
    }

    // Must exactly match the redirect_uri sent in /login.
    let redirect_uri = format!("{}/tori/api/auth/callback", auth.public_url);

    // Exchange the one-time authorization code for tokens.
    let client = reqwest::Client::new();
    let token_res = client
        .post("https://oauth2.googleapis.com/token")
        .form(&[
            ("code", params.code.as_str()),
            ("client_id", &auth.google_client_id),
            ("client_secret", &auth.google_client_secret),
            ("redirect_uri", &redirect_uri),
            ("grant_type", "authorization_code"),
        ])
        .send()
        .await;

    let token_body: TokenResponse = match token_res {
        Ok(r) if r.status().is_success() => match r.json().await {
            Ok(t) => t,
            Err(e) => return (StatusCode::BAD_GATEWAY, format!("Token parse error: {}", e)).into_response(),
        },
        Ok(r) => {
            let body = r.text().await.unwrap_or_default();
            tracing::error!("Google token exchange failed: {}", body);
            return (StatusCode::BAD_GATEWAY, "Google token exchange failed").into_response();
        }
        Err(e) => return (StatusCode::BAD_GATEWAY, format!("Token request failed: {}", e)).into_response(),
    };

    let id_token = match token_body.id_token {
        Some(t) => t,
        None => return (StatusCode::BAD_GATEWAY, "No id_token in response").into_response(),
    };

    // Decode id_token payload (no verification needed - just received from Google over HTTPS)
    let user_info = match decode_google_id_token(&id_token) {
        Some(u) => u,
        None => return (StatusCode::BAD_GATEWAY, "Failed to decode id_token").into_response(),
    };

    // Upsert user. The result is deliberately ignored — login proceeds even
    // if the profile row fails to persist.
    let user_id = format!("google:{}", user_info.sub);
    let _ = sqlx::query(
        "INSERT INTO users (id, email, name, picture)
         VALUES (?, ?, ?, ?)
         ON CONFLICT(id) DO UPDATE SET
           email = excluded.email,
           name = excluded.name,
           picture = excluded.picture,
           last_login_at = datetime('now')"
    )
    .bind(&user_id)
    .bind(&user_info.email)
    .bind(&user_info.name)
    .bind(&user_info.picture)
    .execute(&state.db.pool)
    .await;

    tracing::info!("User logged in: {} ({})", user_info.email, user_id);

    // Sign the session JWT (default header => HS256, keyed on jwt_secret).
    let exp = chrono::Utc::now().timestamp() + SESSION_SECS;
    let claims = Claims {
        sub: user_id,
        email: user_info.email,
        exp,
    };
    let token = match encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(auth.jwt_secret.as_bytes()),
    ) {
        Ok(t) => t,
        Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("JWT error: {}", e)).into_response(),
    };

    // Set the session cookie, clear the one-shot CSRF cookie, bounce home.
    let jar = CookieJar::new()
        .add(build_cookie(COOKIE_NAME, token, SESSION_SECS))
        .add(clear_cookie(CSRF_COOKIE));
    (jar, Redirect::temporary("/tori/")).into_response()
}
|
||||
|
||||
/// GET /me — return the logged-in user's profile row, or 401 when the
/// session is missing/invalid or the user row no longer exists.
async fn me(State(state): State<Arc<AppState>>, jar: CookieJar) -> Response {
    let auth = match &state.auth {
        Some(a) => a,
        None => return (StatusCode::SERVICE_UNAVAILABLE, "Auth not configured").into_response(),
    };

    let claims = match extract_claims(&jar, &auth.jwt_secret) {
        Some(c) => c,
        None => return StatusCode::UNAUTHORIZED.into_response(),
    };

    // JSON shape returned by this endpoint only.
    #[derive(Serialize)]
    struct UserInfo {
        id: String,
        email: String,
        name: String,
        picture: String,
    }

    // Any DB error collapses to "no user" (→ 401) via .ok().flatten().
    let user: Option<UserInfo> = sqlx::query_as::<_, (String, String, String, String)>(
        "SELECT id, email, name, picture FROM users WHERE id = ?"
    )
    .bind(&claims.sub)
    .fetch_optional(&state.db.pool)
    .await
    .ok()
    .flatten()
    .map(|(id, email, name, picture)| UserInfo { id, email, name, picture });

    match user {
        Some(u) => Json(u).into_response(),
        None => StatusCode::UNAUTHORIZED.into_response(),
    }
}
|
||||
|
||||
async fn logout(jar: CookieJar) -> impl IntoResponse {
|
||||
(jar.add(clear_cookie(COOKIE_NAME)), StatusCode::OK)
|
||||
}
|
||||
|
||||
// --- Middleware ---
|
||||
|
||||
pub async fn require_auth(
|
||||
State(state): State<Arc<AppState>>,
|
||||
jar: CookieJar,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let auth = match &state.auth {
|
||||
Some(a) => a,
|
||||
None => return next.run(req).await, // auth not configured, pass through
|
||||
};
|
||||
|
||||
match extract_claims(&jar, &auth.jwt_secret) {
|
||||
Some(claims) => {
|
||||
req.extensions_mut().insert(claims);
|
||||
next.run(req).await
|
||||
}
|
||||
None => StatusCode::UNAUTHORIZED.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
fn extract_claims(jar: &CookieJar, jwt_secret: &str) -> Option<Claims> {
|
||||
let token = jar.get(COOKIE_NAME)?.value().to_string();
|
||||
let key = DecodingKey::from_secret(jwt_secret.as_bytes());
|
||||
let mut validation = Validation::default();
|
||||
validation.validate_exp = true;
|
||||
decode::<Claims>(&token, &key, &validation)
|
||||
.ok()
|
||||
.map(|d| d.claims)
|
||||
}
|
||||
|
||||
fn decode_google_id_token(id_token: &str) -> Option<GoogleUserInfo> {
|
||||
let parts: Vec<&str> = id_token.split('.').collect();
|
||||
if parts.len() != 3 {
|
||||
return None;
|
||||
}
|
||||
let padded = match parts[1].len() % 4 {
|
||||
2 => format!("{}==", parts[1]),
|
||||
3 => format!("{}=", parts[1]),
|
||||
_ => parts[1].to_string(),
|
||||
};
|
||||
let payload = base64_decode_url_safe(&padded)?;
|
||||
serde_json::from_slice(&payload).ok()
|
||||
}
|
||||
|
||||
fn base64_decode_url_safe(input: &str) -> Option<Vec<u8>> {
|
||||
let standard = input.replace('-', "+").replace('_', "/");
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD.decode(&standard).ok()
|
||||
}
|
||||
|
||||
/// Percent-encode a string per RFC 3986: unreserved characters
/// (ALPHA / DIGIT / "-" / "_" / "." / "~") pass through, every other
/// byte of the UTF-8 encoding becomes %XX.
fn pct_encode(s: &str) -> String {
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push_str(&format!("%{:02X}", byte));
        }
    }
    encoded
}
|
||||
|
||||
258
src/api/files.rs
Normal file
258
src/api/files.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use axum::{
|
||||
extract::{Multipart, Path},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
/// Root directory for per-project workspaces: the container path when it
/// exists, otherwise the relative local-development path.
fn workspace_root() -> &'static str {
    const CONTAINER_ROOT: &str = "/app/data/workspaces";
    if std::path::Path::new(CONTAINER_ROOT).is_dir() {
        CONTAINER_ROOT
    } else {
        "data/workspaces"
    }
}
|
||||
|
||||
fn resolve_path(project_id: &str, rel: &str) -> Result<PathBuf, Response> {
|
||||
let base = PathBuf::from(workspace_root()).join(project_id);
|
||||
let full = base.join(rel);
|
||||
// Prevent path traversal
|
||||
if rel.contains("..") {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid path").into_response());
|
||||
}
|
||||
Ok(full)
|
||||
}
|
||||
|
||||
/// One directory entry in a file-browser listing response.
#[derive(Serialize)]
struct FileEntry {
    /// File or directory name (lossy UTF-8 of the OS name).
    name: String,
    /// True when the entry is a directory.
    is_dir: bool,
    /// Size in bytes from filesystem metadata.
    size: u64,
}
|
||||
|
||||
/// File-browser routes: list/upload/mkdir at the project workspace root,
/// plus get/upload/rename/delete/mkdir on nested paths.
pub fn router() -> Router {
    Router::new()
        .route(
            "/projects/{id}/files",
            get(list_root).post(upload_root).patch(mkdir_root),
        )
        .route(
            "/projects/{id}/files/{*path}",
            get(get_file)
                .post(upload_file)
                .put(rename_file)
                .delete(delete_file)
                .patch(mkdir),
        )
}
|
||||
|
||||
async fn list_dir(dir: PathBuf) -> Result<Json<Vec<FileEntry>>, Response> {
|
||||
let mut entries = Vec::new();
|
||||
// Return empty list if directory doesn't exist yet
|
||||
let mut rd = match tokio::fs::read_dir(&dir).await {
|
||||
Ok(rd) => rd,
|
||||
Err(_) => return Ok(Json(entries)),
|
||||
};
|
||||
while let Ok(Some(e)) = rd.next_entry().await {
|
||||
let meta = match e.metadata().await {
|
||||
Ok(m) => m,
|
||||
Err(_) => continue,
|
||||
};
|
||||
entries.push(FileEntry {
|
||||
name: e.file_name().to_string_lossy().to_string(),
|
||||
is_dir: meta.is_dir(),
|
||||
size: meta.len(),
|
||||
});
|
||||
}
|
||||
entries.sort_by(|a, b| {
|
||||
b.is_dir.cmp(&a.is_dir).then(a.name.cmp(&b.name))
|
||||
});
|
||||
Ok(Json(entries))
|
||||
}
|
||||
|
||||
async fn list_root(Path(project_id): Path<String>) -> Result<Json<Vec<FileEntry>>, Response> {
|
||||
let dir = resolve_path(&project_id, "")?;
|
||||
list_dir(dir).await
|
||||
}
|
||||
|
||||
/// GET /projects/{id}/files/{*path} — serve a file as a download, or a
/// directory listing when `path` points at a directory.
async fn get_file(
    Path((project_id, file_path)): Path<(String, String)>,
) -> Response {
    let full = match resolve_path(&project_id, &file_path) {
        Ok(p) => p,
        Err(e) => return e,
    };

    // If it's a directory, list contents instead of serving bytes.
    if full.is_dir() {
        return match list_dir(full).await {
            Ok(j) => j.into_response(),
            Err(e) => e,
        };
    }

    // Otherwise serve the file. The whole file is read into memory —
    // fine for workspace-sized files, not for very large ones.
    match tokio::fs::read(&full).await {
        Ok(bytes) => {
            let mime = mime_guess::from_path(&full)
                .first_or_octet_stream()
                .to_string();
            let filename = full
                .file_name()
                .and_then(|n| n.to_str())
                .unwrap_or("file");
            (
                [
                    (axum::http::header::CONTENT_TYPE, mime),
                    (
                        // Force download under the original file name.
                        axum::http::header::CONTENT_DISPOSITION,
                        format!("attachment; filename=\"{}\"", filename),
                    ),
                ],
                bytes,
            )
                .into_response()
        }
        Err(_) => (StatusCode::NOT_FOUND, "File not found").into_response(),
    }
}
|
||||
|
||||
/// Write every file field of a multipart request into `rel_dir` inside the
/// project workspace, creating the directory as needed. Responds with a
/// JSON count of files written.
async fn do_upload(project_id: &str, rel_dir: &str, mut multipart: Multipart) -> Response {
    let dir = match resolve_path(project_id, rel_dir) {
        Ok(p) => p,
        Err(e) => return e,
    };
    if let Err(e) = tokio::fs::create_dir_all(&dir).await {
        return (StatusCode::INTERNAL_SERVER_ERROR, format!("mkdir failed: {}", e)).into_response();
    }

    let mut count = 0u32;
    // NOTE(review): `while let Ok(Some(..))` ends the loop silently on a
    // malformed part instead of reporting the error — confirm intended.
    while let Ok(Some(field)) = multipart.next_field().await {
        // Fields without a filename (plain form values) are skipped.
        let filename: String = match field.file_name() {
            Some(f) => f.to_string(),
            None => continue,
        };
        // Client-supplied name must remain a single path component.
        if filename.contains("..") || filename.contains('/') {
            return (StatusCode::BAD_REQUEST, "Invalid filename").into_response();
        }
        // Buffers the entire part in memory before writing to disk.
        let data = match field.bytes().await {
            Ok(d) => d,
            Err(e) => return (StatusCode::BAD_REQUEST, format!("Read error: {}", e)).into_response(),
        };
        let dest = dir.join(&filename);
        if let Err(e) = tokio::fs::write(&dest, &data).await {
            return (StatusCode::INTERNAL_SERVER_ERROR, format!("Write error: {}", e)).into_response();
        }
        count += 1;
    }

    Json(serde_json::json!({ "uploaded": count })).into_response()
}
|
||||
|
||||
async fn upload_root(
|
||||
Path(project_id): Path<String>,
|
||||
multipart: Multipart,
|
||||
) -> Response {
|
||||
do_upload(&project_id, "", multipart).await
|
||||
}
|
||||
|
||||
async fn upload_file(
|
||||
Path((project_id, file_path)): Path<(String, String)>,
|
||||
multipart: Multipart,
|
||||
) -> Response {
|
||||
do_upload(&project_id, &file_path, multipart).await
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RenameInput {
|
||||
new_name: String,
|
||||
}
|
||||
|
||||
async fn rename_file(
|
||||
Path((project_id, file_path)): Path<(String, String)>,
|
||||
Json(input): Json<RenameInput>,
|
||||
) -> Response {
|
||||
if input.new_name.contains("..") || input.new_name.contains('/') {
|
||||
return (StatusCode::BAD_REQUEST, "Invalid new name").into_response();
|
||||
}
|
||||
|
||||
let src = match resolve_path(&project_id, &file_path) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let dst = src
|
||||
.parent()
|
||||
.unwrap_or(&src)
|
||||
.join(&input.new_name);
|
||||
|
||||
match tokio::fs::rename(&src, &dst).await {
|
||||
Ok(()) => Json(serde_json::json!({ "ok": true })).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Rename failed: {}", e)).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct MkdirInput {
|
||||
name: String,
|
||||
}
|
||||
|
||||
async fn mkdir(
|
||||
Path((project_id, file_path)): Path<(String, String)>,
|
||||
Json(input): Json<MkdirInput>,
|
||||
) -> Response {
|
||||
if input.name.contains("..") || input.name.contains('/') {
|
||||
return (StatusCode::BAD_REQUEST, "Invalid directory name").into_response();
|
||||
}
|
||||
let parent = match resolve_path(&project_id, &file_path) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return e,
|
||||
};
|
||||
let dir = parent.join(&input.name);
|
||||
match tokio::fs::create_dir_all(&dir).await {
|
||||
Ok(()) => Json(serde_json::json!({ "ok": true })).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("mkdir failed: {}", e)).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn mkdir_root(
|
||||
Path(project_id): Path<String>,
|
||||
Json(input): Json<MkdirInput>,
|
||||
) -> Response {
|
||||
if input.name.contains("..") || input.name.contains('/') {
|
||||
return (StatusCode::BAD_REQUEST, "Invalid directory name").into_response();
|
||||
}
|
||||
let parent = match resolve_path(&project_id, "") {
|
||||
Ok(p) => p,
|
||||
Err(e) => return e,
|
||||
};
|
||||
let dir = parent.join(&input.name);
|
||||
match tokio::fs::create_dir_all(&dir).await {
|
||||
Ok(()) => Json(serde_json::json!({ "ok": true })).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("mkdir failed: {}", e)).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_file(
|
||||
Path((project_id, file_path)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
let full = match resolve_path(&project_id, &file_path) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let result = if full.is_dir() {
|
||||
tokio::fs::remove_dir_all(&full).await
|
||||
} else {
|
||||
tokio::fs::remove_file(&full).await
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(()) => Json(serde_json::json!({ "ok": true })).into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("Delete failed: {}", e)).into_response(),
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
mod auth;
|
||||
pub mod auth;
|
||||
mod chat;
|
||||
mod files;
|
||||
mod kb;
|
||||
pub mod obj;
|
||||
mod projects;
|
||||
mod settings;
|
||||
mod timers;
|
||||
mod workers;
|
||||
mod workflows;
|
||||
|
||||
use std::sync::Arc;
|
||||
@@ -13,7 +15,7 @@ use axum::{
|
||||
extract::{Path, State, Request},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{get, any},
|
||||
routing::any,
|
||||
Json, Router,
|
||||
};
|
||||
|
||||
@@ -34,8 +36,8 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
.merge(kb::router(state.clone()))
|
||||
.merge(settings::router(state.clone()))
|
||||
.merge(chat::router(state.clone()))
|
||||
.merge(auth::router(state.clone()))
|
||||
.route("/projects/{id}/files/{*path}", get(serve_project_file))
|
||||
.merge(workers::router(state.clone()))
|
||||
.merge(files::router())
|
||||
.route("/projects/{id}/app/{*path}", any(proxy_to_service).with_state(state.clone()))
|
||||
.route("/projects/{id}/app/", any(proxy_to_service_root).with_state(state))
|
||||
}
|
||||
@@ -103,40 +105,6 @@ async fn proxy_impl(
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_project_file(
|
||||
Path((project_id, file_path)): Path<(String, String)>,
|
||||
) -> Response {
|
||||
let full_path = std::path::PathBuf::from("/app/data/workspaces")
|
||||
.join(&project_id)
|
||||
.join(&file_path);
|
||||
|
||||
// Prevent path traversal
|
||||
if file_path.contains("..") {
|
||||
return (StatusCode::BAD_REQUEST, "Invalid path").into_response();
|
||||
}
|
||||
|
||||
match tokio::fs::read(&full_path).await {
|
||||
Ok(bytes) => {
|
||||
// Render markdown files as HTML
|
||||
if full_path.extension().is_some_and(|e| e == "md") {
|
||||
let md = String::from_utf8_lossy(&bytes);
|
||||
let html = render_markdown_page(&md, &file_path);
|
||||
return (
|
||||
[(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8".to_string())],
|
||||
html,
|
||||
).into_response();
|
||||
}
|
||||
let mime = mime_guess::from_path(&full_path)
|
||||
.first_or_octet_stream()
|
||||
.to_string();
|
||||
(
|
||||
[(axum::http::header::CONTENT_TYPE, mime)],
|
||||
bytes,
|
||||
).into_response()
|
||||
}
|
||||
Err(_) => (StatusCode::NOT_FOUND, "File not found").into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_markdown_page(markdown: &str, title: &str) -> String {
|
||||
use pulldown_cmark::{Parser, Options, html};
|
||||
|
||||
@@ -4,10 +4,16 @@ use axum::{
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use axum::http::Extensions;
|
||||
use serde::Deserialize;
|
||||
use crate::AppState;
|
||||
use crate::db::Project;
|
||||
use super::{ApiResult, db_err};
|
||||
use super::auth::Claims;
|
||||
|
||||
fn owner_id(ext: &Extensions) -> &str {
|
||||
ext.get::<Claims>().map(|c| c.sub.as_str()).unwrap_or("")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateProject {
|
||||
@@ -31,25 +37,33 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
|
||||
async fn list_projects(
|
||||
State(state): State<Arc<AppState>>,
|
||||
ext: Extensions,
|
||||
) -> ApiResult<Vec<Project>> {
|
||||
sqlx::query_as::<_, Project>("SELECT * FROM projects WHERE deleted = 0 ORDER BY updated_at DESC")
|
||||
.fetch_all(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
let uid = owner_id(&ext);
|
||||
sqlx::query_as::<_, Project>(
|
||||
"SELECT * FROM projects WHERE deleted = 0 AND (owner_id = ? OR owner_id = '') ORDER BY updated_at DESC"
|
||||
)
|
||||
.bind(uid)
|
||||
.fetch_all(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
}
|
||||
|
||||
async fn create_project(
|
||||
State(state): State<Arc<AppState>>,
|
||||
ext: Extensions,
|
||||
Json(input): Json<CreateProject>,
|
||||
) -> ApiResult<Project> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let uid = owner_id(&ext);
|
||||
sqlx::query_as::<_, Project>(
|
||||
"INSERT INTO projects (id, name, description) VALUES (?, ?, ?) RETURNING *"
|
||||
"INSERT INTO projects (id, name, description, owner_id) VALUES (?, ?, ?, ?) RETURNING *"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(&input.name)
|
||||
.bind(&input.description)
|
||||
.bind(uid)
|
||||
.fetch_one(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
@@ -58,55 +72,71 @@ async fn create_project(
|
||||
|
||||
async fn get_project(
|
||||
State(state): State<Arc<AppState>>,
|
||||
ext: Extensions,
|
||||
Path(id): Path<String>,
|
||||
) -> ApiResult<Option<Project>> {
|
||||
sqlx::query_as::<_, Project>("SELECT * FROM projects WHERE id = ?")
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
let uid = owner_id(&ext);
|
||||
sqlx::query_as::<_, Project>(
|
||||
"SELECT * FROM projects WHERE id = ? AND (owner_id = ? OR owner_id = '')"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(uid)
|
||||
.fetch_optional(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
}
|
||||
|
||||
async fn update_project(
|
||||
State(state): State<Arc<AppState>>,
|
||||
ext: Extensions,
|
||||
Path(id): Path<String>,
|
||||
Json(input): Json<UpdateProject>,
|
||||
) -> ApiResult<Option<Project>> {
|
||||
let uid = owner_id(&ext);
|
||||
if let Some(name) = &input.name {
|
||||
sqlx::query("UPDATE projects SET name = ?, updated_at = datetime('now') WHERE id = ?")
|
||||
sqlx::query("UPDATE projects SET name = ?, updated_at = datetime('now') WHERE id = ? AND (owner_id = ? OR owner_id = '')")
|
||||
.bind(name)
|
||||
.bind(&id)
|
||||
.bind(uid)
|
||||
.execute(&state.db.pool)
|
||||
.await
|
||||
.map_err(db_err)?;
|
||||
}
|
||||
if let Some(desc) = &input.description {
|
||||
sqlx::query("UPDATE projects SET description = ?, updated_at = datetime('now') WHERE id = ?")
|
||||
sqlx::query("UPDATE projects SET description = ?, updated_at = datetime('now') WHERE id = ? AND (owner_id = ? OR owner_id = '')")
|
||||
.bind(desc)
|
||||
.bind(&id)
|
||||
.bind(uid)
|
||||
.execute(&state.db.pool)
|
||||
.await
|
||||
.map_err(db_err)?;
|
||||
}
|
||||
sqlx::query_as::<_, Project>("SELECT * FROM projects WHERE id = ?")
|
||||
.bind(&id)
|
||||
.fetch_optional(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
sqlx::query_as::<_, Project>(
|
||||
"SELECT * FROM projects WHERE id = ? AND (owner_id = ? OR owner_id = '')"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(uid)
|
||||
.fetch_optional(&state.db.pool)
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(db_err)
|
||||
}
|
||||
|
||||
async fn delete_project(
|
||||
State(state): State<Arc<AppState>>,
|
||||
ext: Extensions,
|
||||
Path(id): Path<String>,
|
||||
) -> ApiResult<bool> {
|
||||
// Soft delete: mark as deleted in DB
|
||||
let result = sqlx::query("UPDATE projects SET deleted = 1, updated_at = datetime('now') WHERE id = ? AND deleted = 0")
|
||||
.bind(&id)
|
||||
.execute(&state.db.pool)
|
||||
.await
|
||||
.map_err(db_err)?;
|
||||
let uid = owner_id(&ext);
|
||||
let result = sqlx::query(
|
||||
"UPDATE projects SET deleted = 1, updated_at = datetime('now') WHERE id = ? AND deleted = 0 AND (owner_id = ? OR owner_id = '')"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(uid)
|
||||
.execute(&state.db.pool)
|
||||
.await
|
||||
.map_err(db_err)?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Ok(Json(false));
|
||||
|
||||
17
src/api/workers.rs
Normal file
17
src/api/workers.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use std::sync::Arc;
|
||||
use axum::{extract::State, routing::get, Json, Router};
|
||||
|
||||
use crate::AppState;
|
||||
use crate::worker::WorkerInfo;
|
||||
|
||||
async fn list_workers(State(state): State<Arc<AppState>>) -> Json<Vec<WorkerInfo>> {
|
||||
let workers = state.agent_mgr.worker_mgr.list().await;
|
||||
let entries: Vec<WorkerInfo> = workers.into_iter().map(|(_, info)| info).collect();
|
||||
Json(entries)
|
||||
}
|
||||
|
||||
pub fn router(state: Arc<AppState>) -> Router {
|
||||
Router::new()
|
||||
.route("/workers", get(list_workers))
|
||||
.with_state(state)
|
||||
}
|
||||
22
src/db.rs
22
src/db.rs
@@ -73,6 +73,13 @@ impl Database {
|
||||
.execute(&self.pool)
|
||||
.await;
|
||||
|
||||
// Migration: add owner_id column to projects
|
||||
let _ = sqlx::query(
|
||||
"ALTER TABLE projects ADD COLUMN owner_id TEXT NOT NULL DEFAULT ''"
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await;
|
||||
|
||||
// KB tables
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS kb_articles (
|
||||
@@ -215,6 +222,19 @@ impl Database {
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
picture TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
last_login_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)"
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS step_artifacts (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -242,6 +262,8 @@ pub struct Project {
|
||||
pub updated_at: String,
|
||||
#[serde(default)]
|
||||
pub deleted: bool,
|
||||
#[serde(default)]
|
||||
pub owner_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
|
||||
38
src/main.rs
38
src/main.rs
@@ -8,7 +8,9 @@ pub mod state;
|
||||
mod template;
|
||||
mod timer;
|
||||
mod tools;
|
||||
mod worker;
|
||||
mod ws;
|
||||
mod ws_worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
use axum::Router;
|
||||
@@ -22,6 +24,7 @@ pub struct AppState {
|
||||
pub agent_mgr: Arc<agent::AgentManager>,
|
||||
pub kb: Option<Arc<kb::KbManager>>,
|
||||
pub obj_root: String,
|
||||
pub auth: Option<api::auth::AuthConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
@@ -102,12 +105,15 @@ async fn main() -> anyhow::Result<()> {
|
||||
template::ensure_repo_ready(repo_cfg).await;
|
||||
}
|
||||
|
||||
let worker_mgr = worker::WorkerManager::new();
|
||||
|
||||
let agent_mgr = agent::AgentManager::new(
|
||||
database.pool.clone(),
|
||||
config.llm.clone(),
|
||||
config.template_repo.clone(),
|
||||
kb_arc.clone(),
|
||||
config.jwt_private_key.clone(),
|
||||
worker_mgr.clone(),
|
||||
);
|
||||
|
||||
timer::start_timer_runner(database.pool.clone(), agent_mgr.clone());
|
||||
@@ -117,21 +123,51 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let obj_root = std::env::var("OBJ_ROOT").unwrap_or_else(|_| "/data/obj".to_string());
|
||||
|
||||
let auth_config = match (
|
||||
std::env::var("GOOGLE_CLIENT_ID"),
|
||||
std::env::var("GOOGLE_CLIENT_SECRET"),
|
||||
) {
|
||||
(Ok(client_id), Ok(client_secret)) => {
|
||||
let jwt_secret = std::env::var("JWT_SECRET")
|
||||
.unwrap_or_else(|_| uuid::Uuid::new_v4().to_string());
|
||||
let public_url = std::env::var("PUBLIC_URL")
|
||||
.unwrap_or_else(|_| "https://tori.euphon.cloud".to_string());
|
||||
tracing::info!("Google OAuth enabled (public_url={})", public_url);
|
||||
Some(api::auth::AuthConfig {
|
||||
google_client_id: client_id,
|
||||
google_client_secret: client_secret,
|
||||
jwt_secret,
|
||||
public_url,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!("GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET not set, auth disabled");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
db: database,
|
||||
config: config.clone(),
|
||||
agent_mgr: agent_mgr.clone(),
|
||||
kb: kb_arc,
|
||||
obj_root: obj_root.clone(),
|
||||
auth: auth_config,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/tori/api", api::router(state))
|
||||
// Auth routes are public
|
||||
.nest("/tori/api/auth", api::auth::router(state.clone()))
|
||||
// Protected API routes
|
||||
.nest("/tori/api", api::router(state.clone())
|
||||
.layer(axum::middleware::from_fn_with_state(state.clone(), api::auth::require_auth))
|
||||
)
|
||||
.nest("/api/obj", api::obj::router(obj_root.clone()))
|
||||
.route("/api/obj/", axum::routing::get({
|
||||
let r = obj_root;
|
||||
move || api::obj::root_listing(r)
|
||||
}))
|
||||
.nest("/ws/tori/workers", ws_worker::router(worker_mgr))
|
||||
.nest("/ws/tori", ws::router(agent_mgr))
|
||||
.nest_service("/tori", ServeDir::new("web/dist").fallback(ServeFile::new("web/dist/index.html")))
|
||||
.route("/", axum::routing::get(|| async {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
- read_file / write_file / list_files:文件操作
|
||||
- start_service / stop_service:管理后台服务
|
||||
- kb_search / kb_read:搜索和读取知识库
|
||||
- list_workers:列出已注册的远程 worker 节点及其硬件/软件信息
|
||||
- execute_on_worker(worker, script, timeout):在远程 worker 上执行脚本
|
||||
- update_scratchpad:记录本步骤内的中间状态(步骤结束后丢弃,精华写进 summary)
|
||||
- ask_user:向用户提问,暂停执行等待用户回复
|
||||
- step_done:**完成当前步骤时必须调用**,提供本步骤的工作摘要
|
||||
@@ -32,4 +34,22 @@
|
||||
- 后台服务访问:/api/projects/{project_id}/app/(启动命令需监听 0.0.0.0:$PORT)
|
||||
- 【重要】应用通过反向代理访问,前端 HTML/JS 中的 fetch/XHR 请求必须使用相对路径(如 fetch('todos')),绝对不能用 / 开头的路径(如 fetch('/todos')),否则会 404
|
||||
|
||||
## 远程 Worker
|
||||
|
||||
可以通过 `list_workers` 查看所有已注册的远程 worker,然后用 `execute_on_worker` 在指定 worker 上执行脚本。适用于需要特定硬件(如 GPU)或在远程环境执行任务的场景。
|
||||
|
||||
**重要**:
|
||||
- 在 worker 上执行脚本时,可以通过 obj API 访问项目文件:
|
||||
- 下载文件:`curl https://tori.euphon.cloud/api/obj/{project_id}/files/{path}`
|
||||
- 上传文件:`curl -X POST -F 'files=@output.txt' https://tori.euphon.cloud/api/obj/{project_id}/files/`
|
||||
- Python 脚本会自动通过 `uv run --script` 执行,支持 PEP 723 内联依赖声明:
|
||||
```python
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = ["requests", "pandas"]
|
||||
# ///
|
||||
import requests, pandas as pd
|
||||
...
|
||||
```
|
||||
|
||||
请使用中文回复。
|
||||
|
||||
133
src/worker.rs
Normal file
133
src/worker.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Information reported by a worker on registration.
///
/// All fields are free-form strings supplied by the worker itself in its
/// `register` message; the server stores and displays them as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerInfo {
    // Worker name; also used as the key in the WorkerManager registry.
    pub name: String,
    // Hardware description, as reported by the worker.
    pub cpu: String,
    pub memory: String,
    pub gpu: String,
    // Operating system and kernel version strings.
    pub os: String,
    pub kernel: String,
}
|
||||
|
||||
/// A registered worker with a channel for sending scripts to execute.
struct Worker {
    // Hardware/software description captured at registration time.
    pub info: WorkerInfo,
    // Sending half of the job queue; the matching receiver is held by the
    // WebSocket handler, which forwards each request to the remote process.
    pub tx: tokio::sync::mpsc::Sender<WorkerRequest>,
}
|
||||
|
||||
/// A single script-execution job dispatched to a worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerRequest {
    // Correlation id; the worker echoes it back in the matching WorkerResult.
    pub job_id: String,
    // Script body to execute on the worker.
    pub script: String,
}
|
||||
|
||||
/// Outcome of a script execution, sent back by the worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerResult {
    // Must match the job_id of the WorkerRequest this result answers.
    pub job_id: String,
    // Process exit code of the executed script.
    pub exit_code: i32,
    // Captured standard output / standard error of the script.
    pub stdout: String,
    pub stderr: String,
}
|
||||
|
||||
/// Manages all connected workers.
pub struct WorkerManager {
    // Registered workers keyed by worker name.
    workers: RwLock<HashMap<String, Worker>>,
    /// Pending job results, keyed by job_id. Each entry holds the oneshot
    /// sender that `execute` is waiting on; `report_result` resolves it.
    results: RwLock<HashMap<String, tokio::sync::oneshot::Sender<WorkerResult>>>,
}
|
||||
|
||||
impl WorkerManager {
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
workers: RwLock::new(HashMap::new()),
|
||||
results: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Register a new worker. Returns a receiver for job requests.
|
||||
pub async fn register(
|
||||
&self,
|
||||
name: String,
|
||||
info: WorkerInfo,
|
||||
) -> tokio::sync::mpsc::Receiver<WorkerRequest> {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(16);
|
||||
tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})",
|
||||
name, info.cpu, info.memory, info.gpu, info.os, info.kernel);
|
||||
self.workers.write().await.insert(name, Worker { info, tx });
|
||||
rx
|
||||
}
|
||||
|
||||
/// Remove a worker.
|
||||
pub async fn unregister(&self, name: &str) {
|
||||
self.workers.write().await.remove(name);
|
||||
tracing::info!("Worker unregistered: {}", name);
|
||||
}
|
||||
|
||||
/// List all connected workers.
|
||||
pub async fn list(&self) -> Vec<(String, WorkerInfo)> {
|
||||
self.workers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, w)| (name.clone(), w.info.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Submit a script to a worker and wait for the result.
|
||||
pub async fn execute(
|
||||
&self,
|
||||
worker_name: &str,
|
||||
script: &str,
|
||||
timeout_secs: u64,
|
||||
) -> Result<WorkerResult, String> {
|
||||
let job_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Find the worker and send the request
|
||||
let tx = {
|
||||
let workers = self.workers.read().await;
|
||||
let worker = workers
|
||||
.get(worker_name)
|
||||
.ok_or_else(|| format!("Worker '{}' not found", worker_name))?;
|
||||
worker.tx.clone()
|
||||
};
|
||||
|
||||
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
|
||||
self.results.write().await.insert(job_id.clone(), result_tx);
|
||||
|
||||
let req = WorkerRequest {
|
||||
job_id: job_id.clone(),
|
||||
script: script.to_string(),
|
||||
};
|
||||
|
||||
tx.send(req).await.map_err(|_| {
|
||||
format!("Worker '{}' disconnected", worker_name)
|
||||
})?;
|
||||
|
||||
// Wait for result with timeout
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
result_rx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(result)) => Ok(result),
|
||||
Ok(Err(_)) => Err("Worker channel closed unexpectedly".into()),
|
||||
Err(_) => {
|
||||
self.results.write().await.remove(&job_id);
|
||||
Err(format!("Execution timed out after {}s", timeout_secs))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Called when a worker sends back a result.
|
||||
pub async fn report_result(&self, result: WorkerResult) {
|
||||
if let Some(tx) = self.results.write().await.remove(&result.job_id) {
|
||||
let _ = tx.send(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
104
src/ws_worker.rs
Normal file
104
src/ws_worker.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{State, WebSocketUpgrade, ws::{Message, WebSocket}},
|
||||
response::Response,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::worker::{WorkerInfo, WorkerManager, WorkerResult};
|
||||
|
||||
pub fn router(mgr: Arc<WorkerManager>) -> Router {
|
||||
Router::new()
|
||||
.route("/", get(ws_handler))
|
||||
.with_state(mgr)
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(mgr): State<Arc<WorkerManager>>,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| handle_worker_socket(socket, mgr))
|
||||
}
|
||||
|
||||
/// Messages a worker may send over the WebSocket, discriminated by the
/// JSON `type` field.
#[derive(Deserialize)]
#[serde(tag = "type")]
enum WorkerMessage {
    // Initial handshake carrying the worker's self-reported info.
    #[serde(rename = "register")]
    Register { info: WorkerInfo },
    // Completion report for a previously dispatched job.
    #[serde(rename = "result")]
    Result(WorkerResult),
}
|
||||
|
||||
async fn handle_worker_socket(socket: WebSocket, mgr: Arc<WorkerManager>) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// First message must be registration
|
||||
let (name, mut job_rx) = loop {
|
||||
match receiver.next().await {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
match serde_json::from_str::<WorkerMessage>(&text) {
|
||||
Ok(WorkerMessage::Register { info }) => {
|
||||
let name = info.name.clone();
|
||||
let rx = mgr.register(name.clone(), info).await;
|
||||
// Ack
|
||||
let ack = serde_json::json!({ "type": "registered", "name": &name });
|
||||
let _ = sender.send(Message::Text(ack.to_string().into())).await;
|
||||
break (name, rx);
|
||||
}
|
||||
_ => {
|
||||
let _ = sender.send(Message::Text(
|
||||
r#"{"type":"error","message":"First message must be register"}"#.into(),
|
||||
)).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => return,
|
||||
_ => continue,
|
||||
}
|
||||
};
|
||||
|
||||
// Main loop: forward jobs to worker, receive results
|
||||
let name_clone = name.clone();
|
||||
let mgr_clone = mgr.clone();
|
||||
|
||||
// Task: send jobs from job_rx to the WebSocket
|
||||
let send_task = tokio::spawn(async move {
|
||||
while let Some(req) = job_rx.recv().await {
|
||||
let msg = serde_json::json!({
|
||||
"type": "execute",
|
||||
"job_id": req.job_id,
|
||||
"script": req.script,
|
||||
});
|
||||
if sender.send(Message::Text(msg.to_string().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Task: receive results from the WebSocket
|
||||
let recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = receiver.next().await {
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
if let Ok(WorkerMessage::Result(result)) = serde_json::from_str(&text) {
|
||||
mgr_clone.report_result(result).await;
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = send_task => {},
|
||||
_ = recv_task => {},
|
||||
}
|
||||
|
||||
mgr.unregister(&name_clone).await;
|
||||
}
|
||||
Reference in New Issue
Block a user