diff --git a/src/config.rs b/src/config.rs index ed53af4..c561036 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,6 +10,8 @@ pub struct Config { #[serde(default)] pub backend: BackendConfig, #[serde(default)] + pub whisper_url: Option<String>, + #[serde(default)] pub gitea: Option<GiteaConfig>, #[serde(default)] pub nocmem: Option<NocmemConfig>, diff --git a/src/display.rs b/src/display.rs index 6cf617f..fa55679 100644 --- a/src/display.rs +++ b/src/display.rs @@ -192,49 +192,37 @@ pub fn build_user_content( format!("{text}\n\n[scratch]\n{scratch}") }; - // collect media data (images + videos + audio) - #[derive(PartialEq)] - enum MediaKind { Image, Video, Audio } + // collect media data (images + videos) let mut media_parts: Vec<serde_json::Value> = Vec::new(); - tracing::info!("build_user_content: {} media files", media.len()); for path in media { - tracing::info!(" media file: {:?}, ext={:?}, exists={}", path, path.extension(), path.exists()); let (mime, is_video) = match path .extension() .and_then(|e| e.to_str()) .map(|e| e.to_lowercase()) .as_deref() { - Some("jpg" | "jpeg") => ("image/jpeg", MediaKind::Image), - Some("png") => ("image/png", MediaKind::Image), - Some("gif") => ("image/gif", MediaKind::Image), - Some("webp") => ("image/webp", MediaKind::Image), - Some("mp4") => ("video/mp4", MediaKind::Video), - Some("webm") => ("video/webm", MediaKind::Video), - Some("mov") => ("video/quicktime", MediaKind::Video), - Some("ogg" | "oga" | "opus") => ("audio/ogg", MediaKind::Audio), - Some("wav") => ("audio/wav", MediaKind::Audio), - Some("mp3") => ("audio/mpeg", MediaKind::Audio), - Some("flac") => ("audio/flac", MediaKind::Audio), - Some("m4a") => ("audio/mp4", MediaKind::Audio), + Some("jpg" | "jpeg") => ("image/jpeg", false), + Some("png") => ("image/png", false), + Some("gif") => ("image/gif", false), + Some("webp") => ("image/webp", false), + Some("mp4") => ("video/mp4", true), + Some("webm") => ("video/webm", true), + Some("mov") => ("video/quicktime", true), _ =>
continue, }; if let Ok(data) = std::fs::read(path) { let b64 = base64::engine::general_purpose::STANDARD.encode(&data); let data_url = format!("data:{mime};base64,{b64}"); - match kind { - MediaKind::Video => media_parts.push(serde_json::json!({ + if is_video { + media_parts.push(serde_json::json!({ "type": "video_url", "video_url": {"url": data_url} - })), - MediaKind::Audio => media_parts.push(serde_json::json!({ - "type": "audio_url", - "audio_url": {"url": data_url} - })), - MediaKind::Image => media_parts.push(serde_json::json!({ + })); + } else { + media_parts.push(serde_json::json!({ "type": "image_url", "image_url": {"url": data_url} - })), + })); } } } diff --git a/src/main.rs b/src/main.rs index 79946c8..3e004ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,7 +204,7 @@ async fn handle_inner( ) -> Result<()> { let mut uploaded: Vec<PathBuf> = Vec::new(); let mut download_errors: Vec<String> = Vec::new(); - let transcriptions: Vec<String> = Vec::new(); + let mut transcriptions: Vec<String> = Vec::new(); if let Some(doc) = msg.document() { let name = doc.file_name.as_deref().unwrap_or("file"); @@ -228,7 +228,20 @@ async fn handle_inner( let fallback = format!("audio_{}.ogg", Local::now().format("%H%M%S")); let name = audio.file_name.as_deref().unwrap_or(&fallback); match download_tg_file(bot, &audio.file.id, name).await { - Ok(p) => uploaded.push(p), + Ok(p) => { + if let Some(url) = &config.whisper_url { + match transcribe_audio(url, &p).await { + Ok(t) if !t.is_empty() => transcriptions.push(t), + Ok(_) => uploaded.push(p), + Err(e) => { + warn!("transcribe failed: {e:#}"); + uploaded.push(p); + } + } + } else { + uploaded.push(p); + } + } Err(e) => download_errors.push(format!("audio: {e:#}")), } } @@ -236,7 +249,20 @@ async fn handle_inner( if let Some(voice) = msg.voice() { let name = format!("voice_{}.ogg", Local::now().format("%H%M%S")); match download_tg_file(bot, &voice.file.id, &name).await { - Ok(p) => uploaded.push(p), + Ok(p) => { + if let Some(url) = &config.whisper_url {
+ match transcribe_audio(url, &p).await { + Ok(t) if !t.is_empty() => transcriptions.push(t), + Ok(_) => uploaded.push(p), + Err(e) => { + warn!("transcribe failed: {e:#}"); + uploaded.push(p); + } + } + } else { + uploaded.push(p); + } + } Err(e) => download_errors.push(format!("voice: {e:#}")), } } @@ -476,17 +502,8 @@ fn build_prompt( parts.push(format!("[语音消息] {t}")); } - // only mention files that won't be sent as multimodal content - let multimodal_exts = ["jpg", "jpeg", "png", "gif", "webp", "mp4", "webm", "mov", - "ogg", "oga", "opus", "wav", "mp3", "flac", "m4a"]; for f in uploaded { - let is_media = f.extension() - .and_then(|e| e.to_str()) - .map(|e| multimodal_exts.contains(&e.to_lowercase().as_str())) - .unwrap_or(false); - if !is_media { - parts.push(format!("[用户上传了文件: {}]", f.display())); - } + parts.push(format!("[用户上传了文件: {}]", f.display())); } for e in errors { @@ -500,3 +517,24 @@ parts.join("\n") } +async fn transcribe_audio(whisper_url: &str, file_path: &Path) -> Result<String> { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(60)) + .build()?; + let url = format!("{}/v1/audio/transcriptions", whisper_url.trim_end_matches('/')); + let file_bytes = tokio::fs::read(file_path).await?; + let file_name = file_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("audio.ogg") + .to_string(); + let part = reqwest::multipart::Part::bytes(file_bytes) + .file_name(file_name) + .mime_str("audio/ogg")?; + let form = reqwest::multipart::Form::new() + .part("file", part) + .text("model", "base"); + let resp = client.post(&url).multipart(form).send().await?.error_for_status()?; + let json: serde_json::Value = resp.json().await?; + Ok(json["text"].as_str().unwrap_or("").to_string()) +} diff --git a/src/stream.rs b/src/stream.rs index 86af17f..fd431d9 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -71,20 +71,6 @@ pub async fn run_openai_with_tools( messages.len(), tools.as_array().map(|a|
a.len()).unwrap_or(0)); - // log last user message structure for debugging - if let Some(last) = messages.last() { - let content = &last["content"]; - if content.is_array() { - let types: Vec<&str> = content.as_array().unwrap() - .iter() - .filter_map(|v| v["type"].as_str()) - .collect(); - info!("last user content: multimodal {:?}", types); - } else if let Some(s) = content.as_str() { - info!("last user content: text ({} chars)", s.len()); - } - } - let resp_raw = client .post(&url) .header("Authorization", format!("Bearer {api_key}")) @@ -244,10 +230,6 @@ pub async fn run_openai_with_tools( let _ = output.finalize(&cleaned).await; } - // log successful API call - let req_json = serde_json::to_string(&body).unwrap_or_default(); - state.log_api(sid, &req_json, &cleaned, 200).await; - return Ok(cleaned); } }