Switch from fastembed to Python sentence-transformers for embedding
ort (ONNX Runtime) has no prebuilt binaries for aarch64-musl, so replace the in-process fastembed embedder with a Python subprocess running sentence-transformers: scripts/embed.py reads JSON from stdin and writes embeddings to stdout; kb.rs invokes the script via a tokio subprocess; the Dockerfile installs python3 and sentence-transformers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
60
src/kb.rs
60
src/kb.rs
@@ -1,11 +1,10 @@
|
||||
use anyhow::Result;
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use std::sync::Mutex;
|
||||
use std::process::Stdio;
|
||||
|
||||
const TOP_K: usize = 5;
|
||||
|
||||
pub struct KbManager {
|
||||
embedder: Mutex<fastembed::TextEmbedding>,
|
||||
pool: SqlitePool,
|
||||
}
|
||||
|
||||
@@ -25,14 +24,10 @@ struct Chunk {
|
||||
|
||||
impl KbManager {
|
||||
pub fn new(pool: SqlitePool) -> Result<Self> {
|
||||
let embedder = fastembed::TextEmbedding::try_new(
|
||||
fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2)
|
||||
.with_show_download_progress(true),
|
||||
)?;
|
||||
Ok(Self { embedder: Mutex::new(embedder), pool })
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
/// Re-index: chunk the content, embed, store in SQLite
|
||||
/// Re-index: chunk the content, embed via Python, store in SQLite
|
||||
pub async fn index(&self, content: &str) -> Result<()> {
|
||||
// Clear old chunks
|
||||
sqlx::query("DELETE FROM kb_chunks")
|
||||
@@ -45,7 +40,7 @@ impl KbManager {
|
||||
}
|
||||
|
||||
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
|
||||
let embeddings = self.embedder.lock().unwrap().embed(texts, None)?;
|
||||
let embeddings = compute_embeddings(&texts).await?;
|
||||
|
||||
for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) {
|
||||
let vec_bytes = embedding_to_bytes(&embedding);
|
||||
@@ -66,7 +61,7 @@ impl KbManager {
|
||||
|
||||
/// Search KB by query, returns top-k results
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
||||
let query_embeddings = self.embedder.lock().unwrap().embed(vec![query.to_string()], None)?;
|
||||
let query_embeddings = compute_embeddings(&[query.to_string()]).await?;
|
||||
let query_vec = query_embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
@@ -102,7 +97,50 @@ impl KbManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Call Python script to compute embeddings
|
||||
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
|
||||
let input = serde_json::json!({ "texts": texts });
|
||||
|
||||
let mut child = tokio::process::Command::new("python3")
|
||||
.arg("/app/scripts/embed.py")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
stdin.write_all(input.to_string().as_bytes()).await?;
|
||||
}
|
||||
|
||||
let output = child.wait_with_output().await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Embedding script failed: {}", stderr);
|
||||
}
|
||||
|
||||
let result: serde_json::Value = serde_json::from_slice(&output.stdout)?;
|
||||
let embeddings: Vec<Vec<f32>> = result["embeddings"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?
|
||||
.iter()
|
||||
.map(|arr| {
|
||||
arr.as_array()
|
||||
.unwrap_or(&vec![])
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
@@ -131,7 +169,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
|
||||
|
||||
for line in content.lines() {
|
||||
if line.starts_with("## ") {
|
||||
// Save previous chunk
|
||||
let text = current_lines.join("\n").trim().to_string();
|
||||
if !text.is_empty() {
|
||||
chunks.push(Chunk {
|
||||
@@ -150,7 +187,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
|
||||
}
|
||||
}
|
||||
|
||||
// Last chunk
|
||||
let text = current_lines.join("\n").trim().to_string();
|
||||
if !text.is_empty() {
|
||||
chunks.push(Chunk {
|
||||
|
||||
Reference in New Issue
Block a user