Files
tori/src/llm.rs
Fam Zheng fa800b1601 feat: step artifacts framework
- Add Artifact type to Step (name, path, artifact_type, description)
- step_done tool accepts optional artifacts parameter
- Save artifacts to step_artifacts DB table
- Display artifacts in frontend PlanSection (tag style)
- Show artifacts in step context for sub-agents and coordinator
- Add LLM client retry with exponential backoff
2026-03-09 12:01:29 +00:00

175 lines
5.8 KiB
Rust

use serde::{Deserialize, Serialize};
use crate::LlmConfig;
/// Client for an OpenAI-compatible chat-completions HTTP API.
///
/// Construct with [`LlmClient::new`]; issue requests via `chat` /
/// `chat_with_tools`.
pub struct LlmClient {
// Underlying HTTP client; timeouts are configured in `new`.
client: reqwest::Client,
// Endpoint settings (base URL, API key, model) — cloned from the caller's config.
config: LlmConfig,
}
/// JSON request body for `POST {base_url}/chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<ChatMessage>,
// Omitted from the serialized JSON entirely when no tools are supplied,
// so plain (tool-less) chats don't send an empty "tools" array.
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<Tool>,
}
/// A single chat message, used both in requests and in parsed responses.
/// Optional fields are skipped when serializing `None` and default to
/// `None` when absent from a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
// Message role; the constructors below build "system", "user" and "tool"
// messages. Responses may carry other roles (e.g. "assistant").
pub role: String,
// Text content, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
// Tool invocations attached to the message, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
// Id of the tool call this message answers (set by `tool_result`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl ChatMessage {
    /// Shared constructor: a text message with the given role and an
    /// optional tool-call id; never carries tool calls of its own.
    fn with_role(role: &str, content: &str, tool_call_id: Option<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: Some(content.to_owned()),
            tool_calls: None,
            tool_call_id,
        }
    }

    /// Build a "system" message.
    pub fn system(content: &str) -> Self {
        Self::with_role("system", content, None)
    }

    /// Build a "user" message.
    pub fn user(content: &str) -> Self {
        Self::with_role("user", content, None)
    }

    /// Build a "tool" message reporting the output of the tool call
    /// identified by `tool_call_id`.
    pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
        Self::with_role("tool", content, Some(tool_call_id.to_owned()))
    }
}
/// A tool definition advertised to the model in a request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
// Serialized as "type" ("type" is a Rust keyword, hence the rename).
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
/// Name, description and parameter spec of a callable tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
// Free-form JSON describing the parameters.
// NOTE(review): presumably a JSON Schema object — confirm against callers.
pub parameters: serde_json::Value,
}
/// A tool invocation requested by the model in a response message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
// Serialized as "type" ("type" is a Rust keyword, hence the rename).
#[serde(rename = "type")]
pub call_type: String,
pub function: ToolCallFunction,
}
/// The function half of a tool call: which tool, and with what arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
pub name: String,
// Arguments as a raw JSON-encoded string; not parsed at this layer.
pub arguments: String,
}
/// Token accounting reported by the API. Every field falls back to 0
/// when missing from the response JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
#[serde(default)]
pub prompt_tokens: u32,
#[serde(default)]
pub completion_tokens: u32,
#[serde(default)]
pub total_tokens: u32,
}
/// Deserialized body of a successful chat-completions response.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
pub choices: Vec<ChatChoice>,
// Token usage; tolerated as absent since not all servers report it.
pub usage: Option<Usage>,
}
/// One completion choice within a response.
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub message: ChatMessage,
// Parsed but not read anywhere in this file yet; kept for diagnostics.
#[allow(dead_code)]
pub finish_reason: Option<String>,
}
impl LlmClient {
pub fn new(config: &LlmConfig) -> Self {
Self {
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.expect("Failed to build HTTP client"),
config: config.clone(),
}
}
/// Simple chat without tools — returns content string
pub async fn chat(&self, messages: Vec<ChatMessage>) -> anyhow::Result<String> {
let resp = self.chat_with_tools(messages, &[]).await?;
Ok(resp.choices.into_iter().next()
.and_then(|c| c.message.content)
.unwrap_or_default())
}
/// Chat with tool definitions — returns full response for tool-calling loop.
/// Retries up to 3 times with exponential backoff on transient errors.
pub async fn chat_with_tools(&self, messages: Vec<ChatMessage>, tools: &[Tool]) -> anyhow::Result<ChatResponse> {
let url = format!("{}/chat/completions", self.config.base_url);
let max_retries = 3u32;
let mut last_err = None;
let tools_vec = tools.to_vec();
for attempt in 0..max_retries {
if attempt > 0 {
let delay = std::time::Duration::from_secs(2u64.pow(attempt));
tracing::warn!("LLM retry #{} after {}s", attempt, delay.as_secs());
tokio::time::sleep(delay).await;
}
tracing::debug!("LLM request to {} model={} messages={} tools={} attempt={}", url, self.config.model, messages.len(), tools_vec.len(), attempt + 1);
let result = self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.json(&ChatRequest {
model: self.config.model.clone(),
messages: messages.clone(),
tools: tools_vec.clone(),
})
.send()
.await;
let http_resp = match result {
Ok(r) => r,
Err(e) => {
tracing::warn!("LLM request error (attempt {}): {}", attempt + 1, e);
last_err = Some(anyhow::anyhow!("{}", e));
continue;
}
};
let status = http_resp.status();
if status.is_server_error() || status.as_u16() == 429 {
let body = http_resp.text().await.unwrap_or_default();
tracing::warn!("LLM API error {} (attempt {}): {}", status, attempt + 1, &body[..body.len().min(200)]);
last_err = Some(anyhow::anyhow!("LLM API error {}: {}", status, body));
continue;
}
if !status.is_success() {
let body = http_resp.text().await.unwrap_or_default();
tracing::error!("LLM API error {}: {}", status, &body[..body.len().min(500)]);
anyhow::bail!("LLM API error {}: {}", status, body);
}
let body = http_resp.text().await?;
let resp: ChatResponse = serde_json::from_str(&body).map_err(|e| {
tracing::error!("LLM response parse error: {}. Body: {}", e, &body[..body.len().min(500)]);
anyhow::anyhow!("Failed to parse LLM response: {}", e)
})?;
return Ok(resp);
}
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("LLM call failed after {} retries", max_retries)))
}
}