ExternalToolManager.discover() now accepts template root dir, detects pyproject.toml and runs `uv sync` to create a venv. Tool invocation and schema discovery inject the venv PATH/VIRTUAL_ENV so template tools can import declared dependencies without manual installation.
175 lines
5.8 KiB
Rust
175 lines
5.8 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use crate::LlmConfig;
|
|
|
|
/// HTTP client for an OpenAI-compatible `/chat/completions` endpoint.
pub struct LlmClient {
    /// Underlying HTTP client; timeouts are configured in [`LlmClient::new`].
    client: reqwest::Client,
    /// Base URL, model name, and API key (cloned from the caller's config).
    config: LlmConfig,
}
|
|
|
|
/// Request body serialized for `POST /chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    // The `tools` key is omitted from the JSON entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<Tool>,
}
|
|
|
|
/// One message in a chat conversation, in OpenAI chat format.
///
/// Used both as request input (inside `ChatRequest.messages`) and as the
/// deserialized reply (inside [`ChatChoice`]), which is why the payload
/// fields are all optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// "system", "user", or "tool" when built via the constructors below;
    /// responses presumably carry "assistant" — not validated here.
    pub role: String,
    /// Message text; may be absent (e.g. an assistant reply that only
    /// carries `tool_calls` — TODO confirm against the provider).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool invocations requested by the model, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `role == "tool"` messages: id of the [`ToolCall`] being answered.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
|
|
|
|
impl ChatMessage {
|
|
pub fn system(content: &str) -> Self {
|
|
Self { role: "system".into(), content: Some(content.into()), tool_calls: None, tool_call_id: None }
|
|
}
|
|
|
|
pub fn user(content: &str) -> Self {
|
|
Self { role: "user".into(), content: Some(content.into()), tool_calls: None, tool_call_id: None }
|
|
}
|
|
|
|
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
|
Self { role: "tool".into(), content: Some(content.into()), tool_calls: None, tool_call_id: Some(tool_call_id.into()) }
|
|
}
|
|
}
|
|
|
|
/// A tool definition advertised to the model (one entry of the request's
/// `tools` array).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Serialized/deserialized as `"type"` (`type` is a Rust keyword);
    /// presumably always "function" — confirm against callers.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: ToolFunction,
}
|
|
|
|
/// Function metadata inside a [`Tool`] definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    pub name: String,
    pub description: String,
    /// Free-form JSON describing the accepted arguments (presumably a JSON
    /// Schema object, per the OpenAI convention — not validated here).
    pub parameters: serde_json::Value,
}
|
|
|
|
/// A tool invocation emitted by the model inside an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call id; echoed back via [`ChatMessage::tool_result`] so the provider
    /// can pair results with calls.
    pub id: String,
    /// Serialized/deserialized as `"type"` (`type` is a Rust keyword).
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: ToolCallFunction,
}
|
|
|
|
/// The function name and raw argument payload of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallFunction {
    pub name: String,
    /// Arguments as a raw string (presumably JSON-encoded; the consumer is
    /// expected to parse it).
    pub arguments: String,
}
|
|
|
|
/// Token accounting reported by the API. Every field falls back to 0 when
/// the provider omits it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default)]
    pub prompt_tokens: u32,
    #[serde(default)]
    pub completion_tokens: u32,
    #[serde(default)]
    pub total_tokens: u32,
}
|
|
|
|
/// Top-level response from `/chat/completions`.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    /// Completion candidates; [`LlmClient::chat`] uses only the first one.
    pub choices: Vec<ChatChoice>,
    /// Token usage, when the provider reports it.
    pub usage: Option<Usage>,
}
|
|
|
|
/// One completion candidate within a [`ChatResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
    pub message: ChatMessage,
    /// Why generation stopped (e.g. "stop", "tool_calls" — per the OpenAI
    /// convention; not inspected anywhere in this file).
    #[allow(dead_code)]
    pub finish_reason: Option<String>,
}
|
|
|
|
impl LlmClient {
|
|
pub fn new(config: &LlmConfig) -> Self {
|
|
Self {
|
|
client: reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(300))
|
|
.connect_timeout(std::time::Duration::from_secs(10))
|
|
.build()
|
|
.expect("Failed to build HTTP client"),
|
|
config: config.clone(),
|
|
}
|
|
}
|
|
|
|
/// Simple chat without tools — returns content string
|
|
pub async fn chat(&self, messages: Vec<ChatMessage>) -> anyhow::Result<String> {
|
|
let resp = self.chat_with_tools(messages, &[]).await?;
|
|
Ok(resp.choices.into_iter().next()
|
|
.and_then(|c| c.message.content)
|
|
.unwrap_or_default())
|
|
}
|
|
|
|
/// Chat with tool definitions — returns full response for tool-calling loop.
|
|
/// Retries up to 3 times with exponential backoff on transient errors.
|
|
pub async fn chat_with_tools(&self, messages: Vec<ChatMessage>, tools: &[Tool]) -> anyhow::Result<ChatResponse> {
|
|
let url = format!("{}/chat/completions", self.config.base_url);
|
|
let max_retries = 3u32;
|
|
let mut last_err = None;
|
|
let tools_vec = tools.to_vec();
|
|
|
|
for attempt in 0..max_retries {
|
|
if attempt > 0 {
|
|
let delay = std::time::Duration::from_secs(2u64.pow(attempt));
|
|
tracing::warn!("LLM retry #{} after {}s", attempt, delay.as_secs());
|
|
tokio::time::sleep(delay).await;
|
|
}
|
|
|
|
tracing::debug!("LLM request to {} model={} messages={} tools={} attempt={}", url, self.config.model, messages.len(), tools_vec.len(), attempt + 1);
|
|
let result = self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.config.api_key))
|
|
.json(&ChatRequest {
|
|
model: self.config.model.clone(),
|
|
messages: messages.clone(),
|
|
tools: tools_vec.clone(),
|
|
})
|
|
.send()
|
|
.await;
|
|
|
|
let http_resp = match result {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
tracing::warn!("LLM request error (attempt {}): {}", attempt + 1, e);
|
|
last_err = Some(anyhow::anyhow!("{}", e));
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let status = http_resp.status();
|
|
if status.is_server_error() || status.as_u16() == 429 {
|
|
let body = http_resp.text().await.unwrap_or_default();
|
|
tracing::warn!("LLM API error {} (attempt {}): {}", status, attempt + 1, &body[..body.len().min(200)]);
|
|
last_err = Some(anyhow::anyhow!("LLM API error {}: {}", status, body));
|
|
continue;
|
|
}
|
|
|
|
if !status.is_success() {
|
|
let body = http_resp.text().await.unwrap_or_default();
|
|
tracing::error!("LLM API error {}: {}", status, &body[..body.len().min(500)]);
|
|
anyhow::bail!("LLM API error {}: {}", status, body);
|
|
}
|
|
|
|
let body = http_resp.text().await?;
|
|
let resp: ChatResponse = serde_json::from_str(&body).map_err(|e| {
|
|
tracing::error!("LLM response parse error: {}. Body: {}", e, &body[..body.len().min(500)]);
|
|
anyhow::anyhow!("Failed to parse LLM response: {}", e)
|
|
})?;
|
|
|
|
return Ok(resp);
|
|
}
|
|
|
|
Err(last_err.unwrap_or_else(|| anyhow::anyhow!("LLM call failed after {} retries", max_retries)))
|
|
}
|
|
}
|