From 69ad06ca5be0951cc87910031a86bd5f2ba1c0f9 Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Wed, 4 Mar 2026 11:47:03 +0000 Subject: [PATCH] improve: enhance KB search with better embedding and chunking --- scripts/embed.py | 66 +++++++++++++++++++++++++++++++++++++--------- src/kb.rs | 68 +++++++++++++++++++++++++++--------------------- 2 files changed, 92 insertions(+), 42 deletions(-) diff --git a/scripts/embed.py b/scripts/embed.py index 886251e..01bbaa3 100644 --- a/scripts/embed.py +++ b/scripts/embed.py @@ -1,26 +1,66 @@ #!/usr/bin/env python3 -"""Generate embeddings for text chunks. Reads JSON from stdin, writes JSON to stdout. +"""Embedding HTTP server. Loads model once at startup, serves requests on port 8199. -Input: {"texts": ["text1", "text2", ...]} -Output: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...], ...]} +POST /embed {"texts": ["text1", "text2", ...]} +Response: {"embeddings": [[0.1, 0.2, ...], ...]} + +GET /health -> 200 OK """ import json import sys +from http.server import HTTPServer, BaseHTTPRequestHandler from sentence_transformers import SentenceTransformer MODEL_NAME = "all-MiniLM-L6-v2" +PORT = 8199 -def main(): - data = json.loads(sys.stdin.read()) - texts = data["texts"] +# Load model once at startup +print(f"Loading model {MODEL_NAME}...", flush=True) +model = SentenceTransformer(MODEL_NAME) +print(f"Model loaded, serving on port {PORT}", flush=True) - if not texts: - print(json.dumps({"embeddings": []})) - return - model = SentenceTransformer(MODEL_NAME) - embeddings = model.encode(texts, normalize_embeddings=True) - print(json.dumps({"embeddings": embeddings.tolist()})) +class EmbedHandler(BaseHTTPRequestHandler): + def do_POST(self): + length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(length) + data = json.loads(body) + texts = data.get("texts", []) + + if not texts: + result = {"embeddings": []} + else: + embeddings = model.encode(texts, normalize_embeddings=True) + result = {"embeddings": 
embeddings.tolist()} + + resp = json.dumps(result).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(resp))) + self.end_headers() + self.wfile.write(resp) + + def do_GET(self): + self.send_response(200) + self.send_header("Content-Type", "text/plain") + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, format, *args): + # Suppress per-request logs + pass + if __name__ == "__main__": - main() + import socket + # Dual-stack: listen on both IPv4 and IPv6 + class DualStackHTTPServer(HTTPServer): + address_family = socket.AF_INET6 + def server_bind(self): + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + super().server_bind() + server = DualStackHTTPServer(("::", PORT), EmbedHandler) + try: + server.serve_forever() + except KeyboardInterrupt: + pass diff --git a/src/kb.rs b/src/kb.rs index bf8ba4c..345fd5d 100644 --- a/src/kb.rs +++ b/src/kb.rs @@ -1,6 +1,5 @@ use anyhow::Result; use sqlx::sqlite::SqlitePool; -use std::process::Stdio; const TOP_K: usize = 5; @@ -30,21 +29,34 @@ impl KbManager { /// Re-index a single article: delete its old chunks, chunk the content, embed, store pub async fn index(&self, article_id: &str, content: &str) -> Result<()> { - // Delete only this article's chunks - sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?") - .bind(article_id) - .execute(&self.pool) - .await?; + self.index_batch(&[(article_id.to_string(), content.to_string())]).await + } - let chunks = split_chunks(content); - if chunks.is_empty() { + /// Batch re-index multiple articles in one embedding call (avoids repeated model loading). 
+ pub async fn index_batch(&self, articles: &[(String, String)]) -> Result<()> { + // Collect all chunks with their article_id + let mut all_chunks: Vec<(String, Chunk)> = Vec::new(); // (article_id, chunk) + for (article_id, content) in articles { + sqlx::query("DELETE FROM kb_chunks WHERE article_id = ?") + .bind(article_id) + .execute(&self.pool) + .await?; + + let chunks = split_chunks(content); + for chunk in chunks { + all_chunks.push((article_id.clone(), chunk)); + } + } + + if all_chunks.is_empty() { return Ok(()); } - let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect(); + // Single embedding call for all chunks + let texts: Vec<String> = all_chunks.iter().map(|(_, c)| c.content.clone()).collect(); let embeddings = compute_embeddings(&texts).await?; - for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) { + for ((article_id, chunk), embedding) in all_chunks.iter().zip(embeddings.into_iter()) { let vec_bytes = embedding_to_bytes(&embedding); sqlx::query( "INSERT INTO kb_chunks (id, article_id, title, content, embedding) VALUES (?, ?, ?, ?, ?)", @@ -58,7 +70,7 @@ impl KbManager { .await?; } - tracing::info!("KB indexed article {}: {} chunks", article_id, chunks.len()); + tracing::info!("KB indexed {} articles, {} total chunks", articles.len(), all_chunks.len()); Ok(()) } @@ -138,30 +150,28 @@ impl KbManager { } } -/// Call Python script to compute embeddings +/// Call embedding HTTP server async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> { + let embed_url = std::env::var("TORI_EMBED_URL") + .unwrap_or_else(|_| "http://127.0.0.1:8199".to_string()); + let client = reqwest::Client::new(); let input = serde_json::json!({ "texts": texts }); - let mut child = tokio::process::Command::new("/app/venv/bin/python") - .arg("/app/scripts/embed.py") - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; + let resp = client + .post(format!("{}/embed", embed_url)) + .json(&input) + 
.timeout(std::time::Duration::from_secs(300)) + .send() + .await + .map_err(|e| anyhow::anyhow!("Embedding server request failed (is embed.py running?): {}", e))?; - if let Some(mut stdin) = child.stdin.take() { - use tokio::io::AsyncWriteExt; - stdin.write_all(input.to_string().as_bytes()).await?; + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("Embedding server error {}: {}", status, body); } - let output = child.wait_with_output().await?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - anyhow::bail!("Embedding script failed: {}", stderr); - } - - let result: serde_json::Value = serde_json::from_slice(&output.stdout)?; + let result: serde_json::Value = resp.json().await?; let embeddings: Vec<Vec<f32>> = result["embeddings"] .as_array() .ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?