diff --git a/src/config.rs b/src/config.rs index c561036..ed53af4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,8 +10,6 @@ pub struct Config { #[serde(default)] pub backend: BackendConfig, #[serde(default)] - pub whisper_url: Option, - #[serde(default)] pub gitea: Option, #[serde(default)] pub nocmem: Option, diff --git a/src/display.rs b/src/display.rs index fa55679..6cf617f 100644 --- a/src/display.rs +++ b/src/display.rs @@ -192,37 +192,49 @@ pub fn build_user_content( format!("{text}\n\n[scratch]\n{scratch}") }; - // collect media data (images + videos) + // collect media data (images + videos + audio) + #[derive(PartialEq)] + enum MediaKind { Image, Video, Audio } let mut media_parts: Vec = Vec::new(); + tracing::info!("build_user_content: {} media files", media.len()); for path in media { - let (mime, is_video) = match path + tracing::info!(" media file: {:?}, ext={:?}, exists={}", path, path.extension(), path.exists()); + let (mime, kind) = match path .extension() .and_then(|e| e.to_str()) .map(|e| e.to_lowercase()) .as_deref() { - Some("jpg" | "jpeg") => ("image/jpeg", false), - Some("png") => ("image/png", false), - Some("gif") => ("image/gif", false), - Some("webp") => ("image/webp", false), - Some("mp4") => ("video/mp4", true), - Some("webm") => ("video/webm", true), - Some("mov") => ("video/quicktime", true), + Some("jpg" | "jpeg") => ("image/jpeg", MediaKind::Image), + Some("png") => ("image/png", MediaKind::Image), + Some("gif") => ("image/gif", MediaKind::Image), + Some("webp") => ("image/webp", MediaKind::Image), + Some("mp4") => ("video/mp4", MediaKind::Video), + Some("webm") => ("video/webm", MediaKind::Video), + Some("mov") => ("video/quicktime", MediaKind::Video), + Some("ogg" | "oga" | "opus") => ("audio/ogg", MediaKind::Audio), + Some("wav") => ("audio/wav", MediaKind::Audio), + Some("mp3") => ("audio/mpeg", MediaKind::Audio), + Some("flac") => ("audio/flac", MediaKind::Audio), + Some("m4a") => ("audio/mp4", MediaKind::Audio), _ => continue, }; if let Ok(data) = std::fs::read(path) { let b64 = base64::engine::general_purpose::STANDARD.encode(&data); let data_url = format!("data:{mime};base64,{b64}"); - if is_video { - media_parts.push(serde_json::json!({ + match kind { + MediaKind::Video => media_parts.push(serde_json::json!({ "type": "video_url", "video_url": {"url": data_url} - })); - } else { - media_parts.push(serde_json::json!({ + })), + MediaKind::Audio => media_parts.push(serde_json::json!({ + "type": "audio_url", + "audio_url": {"url": data_url} + })), + MediaKind::Image => media_parts.push(serde_json::json!({ "type": "image_url", "image_url": {"url": data_url} - })); + })), } } } diff --git a/src/main.rs b/src/main.rs index 3e004ef..79946c8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,7 +204,7 @@ async fn handle_inner( ) -> Result<()> { let mut uploaded: Vec = Vec::new(); let mut download_errors: Vec = Vec::new(); - let mut transcriptions: Vec = Vec::new(); + let transcriptions: Vec = Vec::new(); if let Some(doc) = msg.document() { let name = doc.file_name.as_deref().unwrap_or("file"); @@ -228,20 +228,7 @@ async fn handle_inner( let fallback = format!("audio_{}.ogg", Local::now().format("%H%M%S")); let name = audio.file_name.as_deref().unwrap_or(&fallback); match download_tg_file(bot, &audio.file.id, name).await { - Ok(p) => { - if let Some(url) = &config.whisper_url { - match transcribe_audio(url, &p).await { - Ok(t) if !t.is_empty() => transcriptions.push(t), - Ok(_) => uploaded.push(p), - Err(e) => { - warn!("transcribe failed: {e:#}"); - uploaded.push(p); - } - } - } else { - uploaded.push(p); - } - } + Ok(p) => uploaded.push(p), Err(e) => download_errors.push(format!("audio: {e:#}")), } } @@ -249,20 +236,7 @@ async fn handle_inner( if let Some(voice) = msg.voice() { let name = format!("voice_{}.ogg", Local::now().format("%H%M%S")); match download_tg_file(bot, &voice.file.id, &name).await { - Ok(p) => { - if let Some(url) = &config.whisper_url { - match transcribe_audio(url, &p).await { - Ok(t) if !t.is_empty() => transcriptions.push(t), - Ok(_) => uploaded.push(p), - Err(e) => { - warn!("transcribe failed: {e:#}"); - uploaded.push(p); - } - } - } else { - uploaded.push(p); - } - } + Ok(p) => uploaded.push(p), Err(e) => download_errors.push(format!("voice: {e:#}")), } } @@ -502,8 +476,17 @@ fn build_prompt( parts.push(format!("[语音消息] {t}")); } + // only mention files that won't be sent as multimodal content + let multimodal_exts = ["jpg", "jpeg", "png", "gif", "webp", "mp4", "webm", "mov", + "ogg", "oga", "opus", "wav", "mp3", "flac", "m4a"]; for f in uploaded { - parts.push(format!("[用户上传了文件: {}]", f.display())); + let is_media = f.extension() + .and_then(|e| e.to_str()) + .map(|e| multimodal_exts.contains(&e.to_lowercase().as_str())) + .unwrap_or(false); + if !is_media { + parts.push(format!("[用户上传了文件: {}]", f.display())); + } } for e in errors { @@ -517,24 +500,3 @@ fn build_prompt( parts.join("\n") } -async fn transcribe_audio(whisper_url: &str, file_path: &Path) -> Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(60)) - .build()?; - let url = format!("{}/v1/audio/transcriptions", whisper_url.trim_end_matches('/')); - let file_bytes = tokio::fs::read(file_path).await?; - let file_name = file_path - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("audio.ogg") - .to_string(); - let part = reqwest::multipart::Part::bytes(file_bytes) - .file_name(file_name) - .mime_str("audio/ogg")?; - let form = reqwest::multipart::Form::new() - .part("file", part) - .text("model", "base"); - let resp = client.post(&url).multipart(form).send().await?.error_for_status()?; - let json: serde_json::Value = resp.json().await?; - Ok(json["text"].as_str().unwrap_or("").to_string()) -} diff --git a/src/stream.rs b/src/stream.rs index fd431d9..86af17f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -71,6 +71,20 @@ pub async fn run_openai_with_tools( messages.len(), tools.as_array().map(|a| a.len()).unwrap_or(0)); + // log last user message structure for debugging + if let Some(last) = messages.last() { + let content = &last["content"]; + if content.is_array() { + let types: Vec<&str> = content.as_array().unwrap() + .iter() + .filter_map(|v| v["type"].as_str()) + .collect(); + info!("last user content: multimodal {:?}", types); + } else if let Some(s) = content.as_str() { + info!("last user content: text ({} chars)", s.len()); + } + } + let resp_raw = client .post(&url) .header("Authorization", format!("Bearer {api_key}")) @@ -230,6 +244,10 @@ pub async fn run_openai_with_tools( let _ = output.finalize(&cleaned).await; } + // log successful API call + let req_json = serde_json::to_string(&body).unwrap_or_default(); + state.log_api(sid, &req_json, &cleaned, 200).await; + return Ok(cleaned); } }