use serde::{Deserialize, Serialize}; use crate::LlmConfig; pub struct LlmClient { client: reqwest::Client, config: LlmConfig, } #[derive(Debug, Serialize)] struct ChatRequest { model: String, messages: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] tools: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: String, #[serde(default, skip_serializing_if = "Option::is_none")] pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, } impl ChatMessage { pub fn system(content: &str) -> Self { Self { role: "system".into(), content: Some(content.into()), tool_calls: None, tool_call_id: None } } pub fn user(content: &str) -> Self { Self { role: "user".into(), content: Some(content.into()), tool_calls: None, tool_call_id: None } } pub fn tool_result(tool_call_id: &str, content: &str) -> Self { Self { role: "tool".into(), content: Some(content.into()), tool_calls: None, tool_call_id: Some(tool_call_id.into()) } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { #[serde(rename = "type")] pub tool_type: String, pub function: ToolFunction, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolFunction { pub name: String, pub description: String, pub parameters: serde_json::Value, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { pub id: String, #[serde(rename = "type")] pub call_type: String, pub function: ToolCallFunction, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCallFunction { pub name: String, pub arguments: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Usage { #[serde(default)] pub prompt_tokens: u32, #[serde(default)] pub completion_tokens: u32, #[serde(default)] pub total_tokens: u32, } #[derive(Debug, Deserialize)] pub struct ChatResponse { pub choices: Vec, pub usage: Option, } #[derive(Debug, Deserialize)] pub struct ChatChoice { pub message: ChatMessage, #[allow(dead_code)] pub finish_reason: Option, } impl LlmClient { pub fn new(config: &LlmConfig) -> Self { Self { client: reqwest::Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .expect("Failed to build HTTP client"), config: config.clone(), } } /// Simple chat without tools — returns content string pub async fn chat(&self, messages: Vec) -> anyhow::Result { let resp = self.chat_with_tools(messages, &[]).await?; Ok(resp.choices.into_iter().next() .and_then(|c| c.message.content) .unwrap_or_default()) } /// Chat with tool definitions — returns full response for tool-calling loop. /// Retries up to 3 times with exponential backoff on transient errors. pub async fn chat_with_tools(&self, messages: Vec, tools: &[Tool]) -> anyhow::Result { let url = format!("{}/chat/completions", self.config.base_url); let max_retries = 3u32; let mut last_err = None; let tools_vec = tools.to_vec(); for attempt in 0..max_retries { if attempt > 0 { let delay = std::time::Duration::from_secs(2u64.pow(attempt)); tracing::warn!("LLM retry #{} after {}s", attempt, delay.as_secs()); tokio::time::sleep(delay).await; } tracing::debug!("LLM request to {} model={} messages={} tools={} attempt={}", url, self.config.model, messages.len(), tools_vec.len(), attempt + 1); let result = self.client .post(&url) .header("Authorization", format!("Bearer {}", self.config.api_key)) .json(&ChatRequest { model: self.config.model.clone(), messages: messages.clone(), tools: tools_vec.clone(), }) .send() .await; let http_resp = match result { Ok(r) => r, Err(e) => { tracing::warn!("LLM request error (attempt {}): {}", attempt + 1, e); last_err = Some(anyhow::anyhow!("{}", e)); continue; } }; let status = http_resp.status(); if status.is_server_error() || status.as_u16() == 429 { let body = http_resp.text().await.unwrap_or_default(); tracing::warn!("LLM API error {} (attempt {}): {}", status, attempt + 1, &body[..body.len().min(200)]); last_err = Some(anyhow::anyhow!("LLM API error {}: {}", status, body)); continue; } if !status.is_success() { let body = http_resp.text().await.unwrap_or_default(); tracing::error!("LLM API error {}: {}", status, &body[..body.len().min(500)]); anyhow::bail!("LLM API error {}: {}", status, body); } let body = http_resp.text().await?; let resp: ChatResponse = serde_json::from_str(&body).map_err(|e| { tracing::error!("LLM response parse error: {}. Body: {}", e, &body[..body.len().min(500)]); anyhow::anyhow!("Failed to parse LLM response: {}", e) })?; return Ok(resp); } Err(last_err.unwrap_or_else(|| anyhow::anyhow!("LLM call failed after {} retries", max_retries))) } }