improve: enhance KB search with better embedding and chunking

This commit is contained in:
Fam Zheng
2026-03-04 11:47:03 +00:00
parent fe1370230f
commit 69ad06ca5b
2 changed files with 92 additions and 42 deletions

View File

@@ -1,6 +1,5 @@
use anyhow::Result;
use sqlx::sqlite::SqlitePool;
use std::process::Stdio;
const TOP_K: usize = 5;
@@ -30,21 +29,34 @@ impl KbManager {
/// Re-index a single article: delete its old chunks, chunk the content, embed, store
pub async fn index(&self, article_id: &str, content: &str) -> Result<()> {
// Delete only this article's chunks
sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?")
.bind(article_id)
.execute(&self.pool)
.await?;
self.index_batch(&[(article_id.to_string(), content.to_string())]).await
}
let chunks = split_chunks(content);
if chunks.is_empty() {
/// Batch re-index multiple articles in one embedding call (avoids repeated model loading).
pub async fn index_batch(&self, articles: &[(String, String)]) -> Result<()> {
// Collect all chunks with their article_id
let mut all_chunks: Vec<(String, Chunk)> = Vec::new(); // (article_id, chunk)
for (article_id, content) in articles {
sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?")
.bind(article_id)
.execute(&self.pool)
.await?;
let chunks = split_chunks(content);
for chunk in chunks {
all_chunks.push((article_id.clone(), chunk));
}
}
if all_chunks.is_empty() {
return Ok(());
}
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
// Single embedding call for all chunks
let texts: Vec<String> = all_chunks.iter().map(|(_, c)| c.content.clone()).collect();
let embeddings = compute_embeddings(&texts).await?;
for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) {
for ((article_id, chunk), embedding) in all_chunks.iter().zip(embeddings.into_iter()) {
let vec_bytes = embedding_to_bytes(&embedding);
sqlx::query(
"INSERT INTO kb_chunks (id, article_id, title, content, embedding) VALUES (?, ?, ?, ?, ?)",
@@ -58,7 +70,7 @@ impl KbManager {
.await?;
}
tracing::info!("KB indexed article {}: {} chunks", article_id, chunks.len());
tracing::info!("KB indexed {} articles, {} total chunks", articles.len(), all_chunks.len());
Ok(())
}
@@ -138,30 +150,28 @@ impl KbManager {
}
}
/// Call Python script to compute embeddings
/// Call embedding HTTP server
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
let embed_url = std::env::var("TORI_EMBED_URL")
.unwrap_or_else(|_| "http://127.0.0.1:8199".to_string());
let client = reqwest::Client::new();
let input = serde_json::json!({ "texts": texts });
let mut child = tokio::process::Command::new("/app/venv/bin/python")
.arg("/app/scripts/embed.py")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let resp = client
.post(format!("{}/embed", embed_url))
.json(&input)
.timeout(std::time::Duration::from_secs(300))
.send()
.await
.map_err(|e| anyhow::anyhow!("Embedding server request failed (is embed.py running?): {}", e))?;
if let Some(mut stdin) = child.stdin.take() {
use tokio::io::AsyncWriteExt;
stdin.write_all(input.to_string().as_bytes()).await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("Embedding server error {}: {}", status, body);
}
let output = child.wait_with_output().await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("Embedding script failed: {}", stderr);
}
let result: serde_json::Value = serde_json::from_slice(&output.stdout)?;
let result: serde_json::Value = resp.json().await?;
let embeddings: Vec<Vec<f32>> = result["embeddings"]
.as_array()
.ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?