use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; use serde::{Deserialize, Serialize}; /// Information reported by a worker on registration. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkerInfo { pub name: String, pub cpu: String, pub memory: String, pub gpu: String, pub os: String, pub kernel: String, } /// A registered worker with a channel for sending scripts to execute. struct Worker { pub info: WorkerInfo, pub tx: tokio::sync::mpsc::Sender, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkerRequest { pub job_id: String, pub script: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WorkerResult { pub job_id: String, pub exit_code: i32, pub stdout: String, pub stderr: String, } /// Manages all connected workers. pub struct WorkerManager { workers: RwLock>, /// Pending job results, keyed by job_id. results: RwLock>>, } impl WorkerManager { pub fn new() -> Arc { Arc::new(Self { workers: RwLock::new(HashMap::new()), results: RwLock::new(HashMap::new()), }) } /// Register a new worker. Returns a receiver for job requests. pub async fn register( &self, name: String, info: WorkerInfo, ) -> tokio::sync::mpsc::Receiver { let (tx, rx) = tokio::sync::mpsc::channel(16); tracing::info!("Worker registered: {} (cpu={}, mem={}, gpu={}, os={}, kernel={})", name, info.cpu, info.memory, info.gpu, info.os, info.kernel); self.workers.write().await.insert(name, Worker { info, tx }); rx } /// Remove a worker. pub async fn unregister(&self, name: &str) { self.workers.write().await.remove(name); tracing::info!("Worker unregistered: {}", name); } /// List all connected workers. pub async fn list(&self) -> Vec<(String, WorkerInfo)> { self.workers .read() .await .iter() .map(|(name, w)| (name.clone(), w.info.clone())) .collect() } /// Submit a script to a worker and wait for the result. pub async fn execute( &self, worker_name: &str, script: &str, timeout_secs: u64, ) -> Result { let job_id = uuid::Uuid::new_v4().to_string(); // Find the worker and send the request let tx = { let workers = self.workers.read().await; let worker = workers .get(worker_name) .ok_or_else(|| format!("Worker '{}' not found", worker_name))?; worker.tx.clone() }; let (result_tx, result_rx) = tokio::sync::oneshot::channel(); self.results.write().await.insert(job_id.clone(), result_tx); let req = WorkerRequest { job_id: job_id.clone(), script: script.to_string(), }; tx.send(req).await.map_err(|_| { format!("Worker '{}' disconnected", worker_name) })?; // Wait for result with timeout match tokio::time::timeout( std::time::Duration::from_secs(timeout_secs), result_rx, ) .await { Ok(Ok(result)) => Ok(result), Ok(Err(_)) => Err("Worker channel closed unexpectedly".into()), Err(_) => { self.results.write().await.remove(&job_id); Err(format!("Execution timed out after {}s", timeout_secs)) } } } /// Called when a worker sends back a result. pub async fn report_result(&self, result: WorkerResult) { if let Some(tx) = self.results.write().await.remove(&result.job_id) { let _ = tx.send(result); } } }