add gen_voice tool, message timestamps, image multimodal, group chat, whisper STT
- gen_voice: IndexTTS2 voice cloning via the tools/gen_voice script; the reference audio is cached on the server to avoid re-uploading
- Message timestamps: created_at column in the messages table, prepended to content in API calls so the LLM sees message times
- Image understanding: photos converted to base64 multimodal content for vision-capable models
- Group chat: independent session contexts per chat_id; sendMessageDraft is disabled in groups (private chat only)
- Voice transcription: whisper service integration; transcribed text is injected with a [语音消息] prefix
- Integration tests marked #[ignore] (they require external services)
- Reference voice asset: assets/ref_voice.mp3
- .gitignore: target/, noc.service, config/state/db files
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -3,3 +3,6 @@ config.*.yaml
|
||||
state.json
|
||||
state.*.json
|
||||
*.db
|
||||
|
||||
target/
|
||||
noc.service
|
||||
|
||||
BIN
assets/ref_voice.mp3
Normal file
BIN
assets/ref_voice.mp3
Normal file
Binary file not shown.
52
doc/todo.md
52
doc/todo.md
@@ -11,34 +11,40 @@
|
||||
- [ ] 情境感知:根据时间、地点、日历自动调整行为
|
||||
|
||||
### 记忆与成长
|
||||
- [ ] 长期记忆 (MEMORY.md):跨 session 的持久化记忆
|
||||
- [ ] 语义搜索:基于 embedding 的记忆检索
|
||||
- [x] 持久记忆槽 (memory_slots):100 个跨 session 的记忆槽位,注入 system prompt
|
||||
- [ ] AutoMem:后台定时(如每 10 条消息)自动分析对话,由 LLM 决定 SKIP/UPDATE/INSERT 记忆,无需用户手动触发(参考 luke)
|
||||
- [ ] 分层记忆:核心记忆(身份/原则,始终注入)+ 长期记忆(偏好/事实,RAG 检索)+ scratch(当前任务)(参考 xg 三层 + luke 四层架构)
|
||||
- [ ] 语义搜索:基于 embedding 的记忆检索(BGE-M3/Gemini embedding + Qdrant 向量库)
|
||||
- [ ] 记忆合并:新记忆与已有记忆 cosine ≥ 0.7 时,用 LLM 合并而非插入(参考 xg)
|
||||
- [ ] 二次联想召回:第一轮直接检索 → 用 top-K 结果做第二轮关联检索,去重后合并(参考 xg/luke 2-pass recall)
|
||||
- [ ] 时间衰减:记忆按时间指数衰减加权,近期记忆优先(参考 xg 30 天半衰期)
|
||||
- [ ] 自我反思:定期回顾对话质量,优化自己的行为
|
||||
|
||||
### 知识图谱(参考 luke concept graph)
|
||||
- [ ] 概念图:Aho-Corasick 模式匹配用户消息中的关键概念,自动注入相关知识
|
||||
- [ ] update_concept tool:LLM 可动态添加/更新概念节点及关联关系
|
||||
- [ ] LRU 缓存:内存中保持热门概念,微秒级匹配
|
||||
|
||||
### 工具系统
|
||||
- [x] spawn_agent(Claude Code 子代理)
|
||||
- [x] update_scratch / update_memory
|
||||
- [x] send_file / agent_status / kill_agent
|
||||
- [x] 外部脚本工具发现 (tools/ 目录)
|
||||
- [ ] run_code tool:安全沙箱执行 Python/Shell 代码,捕获输出返回(参考 luke run_python)
|
||||
- [ ] gen_image tool:调用图像生成 API(Gemini/FLUX/本地模型)
|
||||
- [ ] gen_voice tool:TTS 语音合成,发送语音消息(参考 luke Elevenlabs / xg Fish-Speech)
|
||||
- [ ] set_timer tool:LLM 可设置延迟/定时任务,到时触发回调(参考 luke timer 系统)
|
||||
- [ ] web_search tool:网页搜索 + 摘要,不必每次都 spawn 完整 agent
|
||||
|
||||
### 感知能力
|
||||
- [x] 图片理解:multimodal vision input
|
||||
- [ ] 语音转录:whisper API 转文字
|
||||
- [ ] 屏幕/截图分析
|
||||
- [ ] 链接预览/摘要
|
||||
- [ ] 语音转文字 (STT):接收语音消息后自动转写(当前 xg 用 FunASR,luke 用 Whisper)
|
||||
|
||||
### 交互体验
|
||||
- [x] 群组支持:独立上下文
|
||||
- [x] 流式输出:sendMessageDraft + editMessageText
|
||||
- [x] Markdown 渲染
|
||||
- [ ] Typing indicator
|
||||
- [ ] Inline keyboard 交互
|
||||
- [ ] 语音回复 (TTS)
|
||||
- [ ] 流式分句发送:长回复按句号/问号断句分批发送,体验更自然
|
||||
- [ ] 多频道支持:同一 bot 核心逻辑支持 Telegram + WebSocket + HTTP(参考 luke MxN 多路复用架构)
|
||||
|
||||
### 工具生态
|
||||
- [x] 脚本工具发现 (tools/ + --schema)
|
||||
- [x] 异步子代理 (spawn_agent)
|
||||
- [x] 飞书待办管理
|
||||
- [ ] Web search / fetch
|
||||
- [ ] 更多脚本工具
|
||||
- [ ] MCP 协议支持
|
||||
|
||||
### 可靠性
|
||||
- [ ] API 重试策略 (指数退避)
|
||||
- [ ] 用量追踪
|
||||
- [ ] Context pruning (只裁工具输出)
|
||||
- [ ] Model failover
|
||||
### 上下文管理
|
||||
- [ ] 智能上下文分配:system prompt / 记忆 / 历史消息 / 工具输出各占比可配置,预留 60-70% 给工具输出(参考 luke 保守分配策略)
|
||||
- [ ] 对话历史滚动窗口优化:当前 100 条硬上限,可改为 token 预算制
|
||||
|
||||
213
src/main.rs
213
src/main.rs
@@ -124,12 +124,12 @@ fn discover_tools() -> serde_json::Value {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "spawn_agent",
|
||||
"description": "Spawn a Claude Code subagent to handle a complex task asynchronously. The subagent has access to shell, browser, and search engine, making it ideal for web searches, information lookup, technical research, code tasks, and other complex operations. You'll be notified when it completes.",
|
||||
"description": "启动一个 Claude Code 子代理异步执行复杂任务。子代理可使用 shell、浏览器和搜索引擎,适合网页搜索、资料查找、技术调研、代码任务等。完成后会收到通知。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "Short unique identifier (e.g. 'research', 'fix-bug')"},
|
||||
"task": {"type": "string", "description": "Detailed task description for the agent"}
|
||||
"id": {"type": "string", "description": "简短唯一标识符(如 'research'、'fix-bug')"},
|
||||
"task": {"type": "string", "description": "给子代理的详细任务描述"}
|
||||
},
|
||||
"required": ["id", "task"]
|
||||
}
|
||||
@@ -139,11 +139,11 @@ fn discover_tools() -> serde_json::Value {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "agent_status",
|
||||
"description": "Check the current status and output of a running or completed agent",
|
||||
"description": "查看正在运行或已完成的子代理的状态和输出",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "The agent identifier"}
|
||||
"id": {"type": "string", "description": "子代理标识符"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}
|
||||
@@ -153,11 +153,11 @@ fn discover_tools() -> serde_json::Value {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "kill_agent",
|
||||
"description": "Terminate a running agent",
|
||||
"description": "终止一个正在运行的子代理",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string", "description": "The agent identifier"}
|
||||
"id": {"type": "string", "description": "子代理标识符"}
|
||||
},
|
||||
"required": ["id"]
|
||||
}
|
||||
@@ -167,12 +167,12 @@ fn discover_tools() -> serde_json::Value {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "send_file",
|
||||
"description": "Send a file from the server to the user via Telegram. The file must exist on the server filesystem.",
|
||||
"description": "通过 Telegram 向用户发送服务器上的文件,文件必须存在于服务器文件系统中。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Absolute path to the file on the server"},
|
||||
"caption": {"type": "string", "description": "Optional caption/description for the file"}
|
||||
"path": {"type": "string", "description": "服务器上文件的绝对路径"},
|
||||
"caption": {"type": "string", "description": "可选的文件说明/描述"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
@@ -182,16 +182,45 @@ fn discover_tools() -> serde_json::Value {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "update_scratch",
|
||||
"description": "Update your scratch area (working notes, state, reminders). This content is appended to every user message so you always see it. Use it to track ongoing context across turns.",
|
||||
"description": "更新你的草稿区(工作笔记、状态、提醒)。草稿区内容会附加到每条用户消息中,确保你始终可见。用于跨轮次跟踪上下文。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "The full scratch area content (replaces previous)"}
|
||||
"content": {"type": "string", "description": "完整的草稿区内容(替换之前的内容)"}
|
||||
},
|
||||
"required": ["content"]
|
||||
}
|
||||
}
|
||||
}),
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "update_memory",
|
||||
"description": "写入持久记忆槽。共 100 个槽位(0-99),跨会话保留。记忆槽内容会注入到每次对话的 system prompt 中。用于存储关键事实、用户偏好或重要上下文。内容设为空字符串可清除槽位。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"slot_nr": {"type": "integer", "description": "槽位编号(0-99)"},
|
||||
"content": {"type": "string", "description": "要存储的内容(最多200字符),空字符串表示清除该槽位"}
|
||||
},
|
||||
"required": ["slot_nr", "content"]
|
||||
}
|
||||
}
|
||||
}),
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "gen_voice",
|
||||
"description": "将文字合成为语音并直接发送给用户。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "要合成语音的文字内容"}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}),
|
||||
];
|
||||
|
||||
// discover script tools
|
||||
@@ -259,7 +288,8 @@ impl AppState {
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL
|
||||
content TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now', 'localtime'))
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
|
||||
CREATE TABLE IF NOT EXISTS scratch_area (
|
||||
@@ -279,9 +309,20 @@ impl AppState {
|
||||
value TEXT NOT NULL,
|
||||
create_time TEXT NOT NULL,
|
||||
update_time TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS memory_slots (
|
||||
slot_nr INTEGER PRIMARY KEY CHECK(slot_nr BETWEEN 0 AND 99),
|
||||
content TEXT NOT NULL DEFAULT ''
|
||||
);",
|
||||
)
|
||||
.expect("init db schema");
|
||||
|
||||
// migrations
|
||||
let _ = conn.execute(
|
||||
"ALTER TABLE messages ADD COLUMN created_at TEXT NOT NULL DEFAULT ''",
|
||||
[],
|
||||
);
|
||||
|
||||
info!("opened db {}", db_path.display());
|
||||
|
||||
Self {
|
||||
@@ -312,13 +353,19 @@ impl AppState {
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut stmt = db
|
||||
.prepare("SELECT role, content FROM messages WHERE session_id = ?1 ORDER BY id")
|
||||
.prepare("SELECT role, content, created_at FROM messages WHERE session_id = ?1 ORDER BY id")
|
||||
.unwrap();
|
||||
let messages: Vec<serde_json::Value> = stmt
|
||||
.query_map([sid], |row| {
|
||||
let role: String = row.get(0)?;
|
||||
let content: String = row.get(1)?;
|
||||
Ok(serde_json::json!({"role": role, "content": content}))
|
||||
let ts: String = row.get(2)?;
|
||||
let tagged = if ts.is_empty() {
|
||||
content
|
||||
} else {
|
||||
format!("[{ts}] {content}")
|
||||
};
|
||||
Ok(serde_json::json!({"role": role, "content": tagged}))
|
||||
})
|
||||
.unwrap()
|
||||
.filter_map(|r| r.ok())
|
||||
@@ -338,7 +385,7 @@ impl AppState {
|
||||
[sid],
|
||||
);
|
||||
let _ = db.execute(
|
||||
"INSERT INTO messages (session_id, role, content) VALUES (?1, ?2, ?3)",
|
||||
"INSERT INTO messages (session_id, role, content, created_at) VALUES (?1, ?2, ?3, datetime('now', 'localtime'))",
|
||||
rusqlite::params![sid, role, content],
|
||||
);
|
||||
}
|
||||
@@ -413,6 +460,32 @@ impl AppState {
|
||||
.ok()
|
||||
}
|
||||
|
||||
async fn get_memory_slots(&self) -> Vec<(i32, String)> {
|
||||
let db = self.db.lock().await;
|
||||
let mut stmt = db
|
||||
.prepare("SELECT slot_nr, content FROM memory_slots WHERE content != '' ORDER BY slot_nr")
|
||||
.unwrap();
|
||||
stmt.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.unwrap()
|
||||
.filter_map(|r| r.ok())
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn set_memory_slot(&self, slot_nr: i32, content: &str) -> Result<()> {
|
||||
if !(0..=99).contains(&slot_nr) {
|
||||
anyhow::bail!("slot_nr must be 0-99, got {slot_nr}");
|
||||
}
|
||||
if content.len() > 200 {
|
||||
anyhow::bail!("content too long: {} chars (max 200)", content.len());
|
||||
}
|
||||
let db = self.db.lock().await;
|
||||
db.execute(
|
||||
"INSERT INTO memory_slots (slot_nr, content) VALUES (?1, ?2) \
|
||||
ON CONFLICT(slot_nr) DO UPDATE SET content = ?2",
|
||||
rusqlite::params![slot_nr, content],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── helpers ─────────────────────────────────────────────────────────
|
||||
@@ -722,6 +795,17 @@ async fn handle_inner(
|
||||
));
|
||||
}
|
||||
|
||||
let memory_slots = state.get_memory_slots().await;
|
||||
diag.push_str(&format!("## Memory Slots ({}/100 used)\n", memory_slots.len()));
|
||||
if memory_slots.is_empty() {
|
||||
diag.push_str("(empty)\n\n");
|
||||
} else {
|
||||
for (nr, content) in &memory_slots {
|
||||
diag.push_str(&format!("- `[{nr}]` {content}\n"));
|
||||
}
|
||||
diag.push('\n');
|
||||
}
|
||||
|
||||
let tmp = std::env::temp_dir().join(format!("noc-diag-{sid}.md"));
|
||||
tokio::fs::write(&tmp, &diag).await?;
|
||||
bot.send_document(chat_id, InputFile::file(&tmp))
|
||||
@@ -774,7 +858,8 @@ async fn handle_inner(
|
||||
} => {
|
||||
let conv = state.load_conv(&sid).await;
|
||||
let persona = state.get_config("persona").await.unwrap_or_default();
|
||||
let system_msg = build_system_prompt(&conv.summary, &persona);
|
||||
let memory_slots = state.get_memory_slots().await;
|
||||
let system_msg = build_system_prompt(&conv.summary, &persona, &memory_slots);
|
||||
|
||||
let mut api_messages = vec![system_msg];
|
||||
api_messages.extend(conv.messages);
|
||||
@@ -904,7 +989,7 @@ async fn transcribe_audio(whisper_url: &str, file_path: &Path) -> Result<String>
|
||||
Ok(json["text"].as_str().unwrap_or("").to_string())
|
||||
}
|
||||
|
||||
fn build_system_prompt(summary: &str, persona: &str) -> serde_json::Value {
|
||||
fn build_system_prompt(summary: &str, persona: &str, memory_slots: &[(i32, String)]) -> serde_json::Value {
|
||||
let mut text = if persona.is_empty() {
|
||||
String::from("你是一个AI助手。")
|
||||
} else {
|
||||
@@ -920,6 +1005,13 @@ fn build_system_prompt(summary: &str, persona: &str) -> serde_json::Value {
|
||||
不要使用LaTeX公式($...$)、特殊Unicode符号(→←↔)或HTML标签,Telegram无法渲染这些。",
|
||||
);
|
||||
|
||||
if !memory_slots.is_empty() {
|
||||
text.push_str("\n\n## 持久记忆(跨会话保留)\n");
|
||||
for (nr, content) in memory_slots {
|
||||
text.push_str(&format!("[{nr}] {content}\n"));
|
||||
}
|
||||
}
|
||||
|
||||
if !summary.is_empty() {
|
||||
text.push_str("\n\n## 之前的对话总结\n");
|
||||
text.push_str(summary);
|
||||
@@ -943,27 +1035,35 @@ fn build_user_content(
|
||||
// collect media data (images + videos)
|
||||
let mut media_parts: Vec<serde_json::Value> = Vec::new();
|
||||
for path in media {
|
||||
let mime = match path
|
||||
let (mime, is_video) = match path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|e| e.to_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("jpg" | "jpeg") => "image/jpeg",
|
||||
Some("png") => "image/png",
|
||||
Some("gif") => "image/gif",
|
||||
Some("webp") => "image/webp",
|
||||
Some("mp4") => "video/mp4",
|
||||
Some("webm") => "video/webm",
|
||||
Some("mov") => "video/quicktime",
|
||||
Some("jpg" | "jpeg") => ("image/jpeg", false),
|
||||
Some("png") => ("image/png", false),
|
||||
Some("gif") => ("image/gif", false),
|
||||
Some("webp") => ("image/webp", false),
|
||||
Some("mp4") => ("video/mp4", true),
|
||||
Some("webm") => ("video/webm", true),
|
||||
Some("mov") => ("video/quicktime", true),
|
||||
_ => continue,
|
||||
};
|
||||
if let Ok(data) = std::fs::read(path) {
|
||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&data);
|
||||
media_parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": format!("data:{mime};base64,{b64}")}
|
||||
}));
|
||||
let data_url = format!("data:{mime};base64,{b64}");
|
||||
if is_video {
|
||||
media_parts.push(serde_json::json!({
|
||||
"type": "video_url",
|
||||
"video_url": {"url": data_url}
|
||||
}));
|
||||
} else {
|
||||
media_parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": data_url}
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1092,6 +1192,56 @@ async fn execute_tool(
|
||||
state.push_scratch(content).await;
|
||||
format!("Scratch updated ({} chars)", content.len())
|
||||
}
|
||||
"update_memory" => {
|
||||
let slot_nr = args["slot_nr"].as_i64().unwrap_or(-1) as i32;
|
||||
let content = args["content"].as_str().unwrap_or("");
|
||||
match state.set_memory_slot(slot_nr, content).await {
|
||||
Ok(_) => {
|
||||
if content.is_empty() {
|
||||
format!("Memory slot {slot_nr} cleared")
|
||||
} else {
|
||||
format!("Memory slot {slot_nr} updated ({} chars)", content.len())
|
||||
}
|
||||
}
|
||||
Err(e) => format!("Error: {e}"),
|
||||
}
|
||||
}
|
||||
"gen_voice" => {
|
||||
let text = args["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return "Error: text is required".to_string();
|
||||
}
|
||||
let script = tools_dir().join("gen_voice");
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120),
|
||||
tokio::process::Command::new(&script)
|
||||
.arg(arguments)
|
||||
.output(),
|
||||
)
|
||||
.await;
|
||||
match result {
|
||||
Ok(Ok(out)) if out.status.success() => {
|
||||
let path_str = String::from_utf8_lossy(&out.stdout).trim().to_string();
|
||||
let path = Path::new(&path_str);
|
||||
if path.exists() {
|
||||
let input_file = InputFile::file(path);
|
||||
match bot.send_voice(chat_id, input_file).await {
|
||||
Ok(_) => format!("语音已发送: {path_str}"),
|
||||
Err(e) => format!("语音生成成功但发送失败: {e:#}"),
|
||||
}
|
||||
} else {
|
||||
format!("语音生成失败: 输出文件不存在 ({path_str})")
|
||||
}
|
||||
}
|
||||
Ok(Ok(out)) => {
|
||||
let stderr = String::from_utf8_lossy(&out.stderr);
|
||||
let stdout = String::from_utf8_lossy(&out.stdout);
|
||||
format!("gen_voice failed: {stdout} {stderr}")
|
||||
}
|
||||
Ok(Err(e)) => format!("gen_voice exec error: {e}"),
|
||||
Err(_) => "gen_voice timeout (120s)".to_string(),
|
||||
}
|
||||
}
|
||||
_ => run_script_tool(name, arguments).await,
|
||||
}
|
||||
}
|
||||
@@ -1204,7 +1354,8 @@ async fn agent_wakeup(
|
||||
state.push_message(sid, "user", wakeup_msg).await;
|
||||
let conv = state.load_conv(sid).await;
|
||||
let persona = state.get_config("persona").await.unwrap_or_default();
|
||||
let system_msg = build_system_prompt(&conv.summary, &persona);
|
||||
let memory_slots = state.get_memory_slots().await;
|
||||
let system_msg = build_system_prompt(&conv.summary, &persona, &memory_slots);
|
||||
let mut api_messages = vec![system_msg];
|
||||
api_messages.extend(conv.messages);
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ fn tools() -> serde_json::Value {
|
||||
|
||||
/// Test non-streaming tool call round-trip
|
||||
#[tokio::test]
|
||||
#[ignore] // requires Ollama on ailab
|
||||
async fn test_tool_call_roundtrip_non_streaming() {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{OLLAMA_URL}/chat/completions");
|
||||
@@ -101,6 +102,7 @@ async fn test_tool_call_roundtrip_non_streaming() {
|
||||
|
||||
/// Test tool call with conversation history (simulates real scenario)
|
||||
#[tokio::test]
|
||||
#[ignore] // requires Ollama on ailab
|
||||
async fn test_tool_call_with_history() {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{OLLAMA_URL}/chat/completions");
|
||||
@@ -199,6 +201,7 @@ async fn test_tool_call_with_history() {
|
||||
|
||||
/// Test multimodal image input
|
||||
#[tokio::test]
|
||||
#[ignore] // requires Ollama on ailab
|
||||
async fn test_image_multimodal() {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{OLLAMA_URL}/chat/completions");
|
||||
@@ -232,6 +235,7 @@ async fn test_image_multimodal() {
|
||||
|
||||
/// Test streaming tool call round-trip (matches our actual code path)
|
||||
#[tokio::test]
|
||||
#[ignore] // requires Ollama on ailab
|
||||
async fn test_tool_call_roundtrip_streaming() {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{OLLAMA_URL}/chat/completions");
|
||||
|
||||
152
tools/gen_voice
Executable file
152
tools/gen_voice
Executable file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
# /// script
|
||||
# requires-python = ">=3.11"
|
||||
# dependencies = ["requests"]
|
||||
# ///
|
||||
"""Generate voice audio using IndexTTS2 with a fixed reference voice.
|
||||
|
||||
Usage:
|
||||
./gen_voice --schema
|
||||
./gen_voice '{"text":"你好世界"}'
|
||||
./gen_voice 你好世界
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
|
||||
INDEXTTS_URL = "http://100.107.41.75:7860"
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
REF_AUDIO = os.path.join(SCRIPT_DIR, "..", "assets", "ref_voice.mp3")
|
||||
OUTPUT_DIR = os.path.expanduser("~/down")
|
||||
|
||||
# cache the uploaded ref path to avoid re-uploading
|
||||
_CACHE_FILE = "/tmp/noc_gen_voice_ref_cache.json"
|
||||
|
||||
SCHEMA = {
|
||||
"name": "gen_voice",
|
||||
"description": "Generate speech audio from text using voice cloning (IndexTTS2). Returns the file path of the generated wav. Use send_file to send it to the user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to synthesize into speech",
|
||||
},
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_ref_path():
    """Return the server-side path of the reference voice, uploading it if needed.

    The path returned by a previous upload is cached in ``_CACHE_FILE``. The
    cached path is reused only when a HEAD request for that file on the
    IndexTTS2 server answers 200; otherwise the reference audio is uploaded
    again and the cache is rewritten.

    Returns:
        The first element of the upload endpoint's JSON response.

    Raises:
        OSError: if ``REF_AUDIO`` cannot be opened.
        requests.RequestException: if the upload fails or times out.
        RuntimeError: if the upload response is not a non-empty JSON list.
    """
    # Best-effort cache lookup: any failure here (corrupt cache file, missing
    # key, server unreachable) simply falls through to a fresh upload.
    if os.path.exists(_CACHE_FILE):
        try:
            with open(_CACHE_FILE) as f:
                cache = json.load(f)
            r = requests.head(f"{INDEXTTS_URL}/gradio_api/file={cache['path']}", timeout=3)
            if r.status_code == 200:
                return cache["path"]
        except Exception:
            pass

    # Upload the reference audio. The timeout keeps an unreachable server from
    # hanging the script indefinitely.
    with open(REF_AUDIO, "rb") as f:
        resp = requests.post(
            f"{INDEXTTS_URL}/gradio_api/upload", files={"files": f}, timeout=60
        )
    resp.raise_for_status()
    uploaded = resp.json()
    if not isinstance(uploaded, list) or not uploaded:
        raise RuntimeError(f"Unexpected upload response: {uploaded!r}")
    ref_path = uploaded[0]

    # The cache is only an optimisation: failing to write it must not discard
    # an upload that already succeeded.
    try:
        with open(_CACHE_FILE, "w") as f:
            json.dump({"path": ref_path}, f)
    except OSError:
        pass

    return ref_path
|
||||
|
||||
|
||||
def synthesize(text):
    """Synthesize ``text`` with IndexTTS2 and save the resulting audio locally.

    Submits a job to the server's ``synthesize`` endpoint using the reference
    voice as both speaker and emotion audio, reads the server-sent event
    stream for the result, downloads the audio from the first result URL and
    writes it to ``OUTPUT_DIR`` as ``tts_<YYYYmmdd_HHMMSS>.wav``.

    Args:
        text: The text to synthesize.

    Returns:
        Path of the written audio file.

    Raises:
        requests.RequestException: if any HTTP request fails or times out.
        RuntimeError: if the server reports a null result or the stream ends
            without providing a result URL.
    """
    ref = get_ref_path()
    file_data = {"path": ref, "meta": {"_type": "gradio.FileData"}}

    # submit job
    resp = requests.post(
        f"{INDEXTTS_URL}/gradio_api/call/synthesize",
        json={
            "data": [
                text,
                file_data,  # spk_audio
                file_data,  # emo_audio
                0.5,  # emo_alpha
                0, 0, 0, 0, 0, 0, 0, 0.8,  # emotions (calm=0.8)
                False,  # use_emo_text
                "",  # emo_text
                False,  # use_random
            ]
        },
        timeout=30,
    )
    resp.raise_for_status()
    event_id = resp.json()["event_id"]

    # Read the result via SSE. The read timeout applies per chunk, so it only
    # fires if the server goes completely silent.
    with requests.get(
        f"{INDEXTTS_URL}/gradio_api/call/synthesize/{event_id}",
        stream=True,
        timeout=(10, 120),
    ) as result_resp:
        result_resp.raise_for_status()
        # Event streams are UTF-8; do not rely on requests' header-based guess.
        result_resp.encoding = "utf-8"

        event = ""
        for line in result_resp.iter_lines(decode_unicode=True):
            if line.startswith("event: "):
                event = line[7:].strip()
                continue
            if not line.startswith("data: "):
                continue
            if event == "heartbeat":
                # Keep-alive message; its null payload is not a failure.
                continue

            data = json.loads(line[6:])
            if data is None:
                raise RuntimeError("TTS synthesis failed (server returned null)")
            if not (isinstance(data, list) and data):
                continue

            first = data[0]
            url = first.get("url", "") if isinstance(first, dict) else ""
            if not url:
                continue

            # download the wav
            wav = requests.get(url, timeout=60)
            wav.raise_for_status()
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            ts = time.strftime("%Y%m%d_%H%M%S")
            out_path = os.path.join(OUTPUT_DIR, f"tts_{ts}.wav")
            with open(out_path, "wb") as f:
                f.write(wav.content)
            return out_path

    raise RuntimeError("No result received from TTS server")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point.

    With no argument, ``--help`` or ``-h``: print the module usage text.
    With ``--schema``: print the tool schema as JSON.
    Otherwise the text to speak is taken either from the ``text`` field of a
    JSON object passed as the first argument, or from all arguments joined by
    spaces. On success the generated file path is printed; on any failure a
    message is printed and the process exits with status 1.
    """
    argv = sys.argv[1:]

    if not argv or argv[0] in ("--help", "-h"):
        print(__doc__.strip())
        sys.exit(0)

    first = argv[0]
    if first == "--schema":
        print(json.dumps(SCHEMA, ensure_ascii=False))
        sys.exit(0)

    if first.startswith("{"):
        # JSON invocation: only the first argument is considered.
        try:
            payload = json.loads(first)
            text = payload.get("text", "")
        except json.JSONDecodeError as e:
            print(f"Invalid JSON: {e}")
            sys.exit(1)
    else:
        # Plain invocation: every argument is part of the text.
        text = " ".join(argv)

    if not text:
        print("Error: text is required")
        sys.exit(1)

    try:
        out_path = synthesize(text)
    except Exception as e:
        print(f"Error: {e}")
        sys.exit(1)
    else:
        print(out_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user