Switch from fastembed to Python sentence-transformers for embedding
ort (ONNX Runtime) ships no prebuilt binaries for aarch64-musl, so embedding now runs through a Python subprocess using sentence-transformers instead: scripts/embed.py reads JSON on stdin and writes embeddings to stdout; kb.rs invokes the script via a tokio subprocess; the Dockerfile installs python3 and sentence-transformers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
761
Cargo.lock
generated
761
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -26,4 +26,3 @@ uuid = { version = "1", features = ["v4"] }
|
||||
anyhow = "1"
|
||||
mime_guess = "2"
|
||||
nix = { version = "0.29", features = ["signal"] }
|
||||
fastembed = { version = "5", default-features = false, features = ["hf-hub", "hf-hub-rustls-tls", "ort-download-binaries-rustls-tls"] }
|
||||
|
||||
@@ -8,13 +8,15 @@ RUN npm run build
|
||||
|
||||
# Stage 2: Runtime
|
||||
FROM alpine:3.21
|
||||
RUN apk add --no-cache ca-certificates curl bash
|
||||
RUN apk add --no-cache ca-certificates curl bash python3 py3-pip
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
RUN pip3 install --break-system-packages sentence-transformers
|
||||
RUN mkdir -p /app/data/workspaces
|
||||
WORKDIR /app
|
||||
COPY target/aarch64-unknown-linux-musl/release/tori .
|
||||
COPY --from=frontend /app/web/dist ./web/dist/
|
||||
COPY scripts/embed.py ./scripts/
|
||||
COPY config.yaml .
|
||||
|
||||
EXPOSE 3000
|
||||
|
||||
26
scripts/embed.py
Normal file
26
scripts/embed.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python3
"""Generate embeddings for text chunks.

Reads a JSON request from stdin and writes a JSON response to stdout.
Diagnostics go to stderr and the process exits non-zero on failure, so
the caller (src/kb.rs checks the exit status and captures stderr) gets a
clear error message instead of a Python traceback on stdout.

Input:  {"texts": ["text1", "text2", ...]}
Output: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...], ...]}
"""
import json
import sys

# Sentence-transformers model used for all embeddings.
MODEL_NAME = "all-MiniLM-L6-v2"


def main():
    # Validate the request up front; a malformed payload or missing
    # "texts" key is reported on stderr rather than raising a traceback.
    try:
        data = json.loads(sys.stdin.read())
        texts = data["texts"]
    except (ValueError, KeyError) as exc:
        print(f"embed.py: invalid input: {exc}", file=sys.stderr)
        sys.exit(1)

    # Fast path: nothing to embed, so skip loading the model entirely.
    if not texts:
        print(json.dumps({"embeddings": []}))
        return

    # Imported lazily so the empty-input path never pays the (slow)
    # sentence-transformers import cost.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(MODEL_NAME)
    # normalize_embeddings=True returns unit-length vectors.
    embeddings = model.encode(texts, normalize_embeddings=True)
    print(json.dumps({"embeddings": embeddings.tolist()}))


if __name__ == "__main__":
    main()
|
||||
60
src/kb.rs
60
src/kb.rs
@@ -1,11 +1,10 @@
|
||||
use anyhow::Result;
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use std::sync::Mutex;
|
||||
use std::process::Stdio;
|
||||
|
||||
const TOP_K: usize = 5;
|
||||
|
||||
pub struct KbManager {
|
||||
embedder: Mutex<fastembed::TextEmbedding>,
|
||||
pool: SqlitePool,
|
||||
}
|
||||
|
||||
@@ -25,14 +24,10 @@ struct Chunk {
|
||||
|
||||
impl KbManager {
|
||||
pub fn new(pool: SqlitePool) -> Result<Self> {
|
||||
let embedder = fastembed::TextEmbedding::try_new(
|
||||
fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2)
|
||||
.with_show_download_progress(true),
|
||||
)?;
|
||||
Ok(Self { embedder: Mutex::new(embedder), pool })
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
/// Re-index: chunk the content, embed, store in SQLite
|
||||
/// Re-index: chunk the content, embed via Python, store in SQLite
|
||||
pub async fn index(&self, content: &str) -> Result<()> {
|
||||
// Clear old chunks
|
||||
sqlx::query("DELETE FROM kb_chunks")
|
||||
@@ -45,7 +40,7 @@ impl KbManager {
|
||||
}
|
||||
|
||||
let texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
|
||||
let embeddings = self.embedder.lock().unwrap().embed(texts, None)?;
|
||||
let embeddings = compute_embeddings(&texts).await?;
|
||||
|
||||
for (chunk, embedding) in chunks.iter().zip(embeddings.into_iter()) {
|
||||
let vec_bytes = embedding_to_bytes(&embedding);
|
||||
@@ -66,7 +61,7 @@ impl KbManager {
|
||||
|
||||
/// Search KB by query, returns top-k results
|
||||
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
||||
let query_embeddings = self.embedder.lock().unwrap().embed(vec![query.to_string()], None)?;
|
||||
let query_embeddings = compute_embeddings(&[query.to_string()]).await?;
|
||||
let query_vec = query_embeddings
|
||||
.into_iter()
|
||||
.next()
|
||||
@@ -102,7 +97,50 @@ impl KbManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Call Python script to compute embeddings
|
||||
async fn compute_embeddings(texts: &[String]) -> Result<Vec<Vec<f32>>> {
|
||||
let input = serde_json::json!({ "texts": texts });
|
||||
|
||||
let mut child = tokio::process::Command::new("python3")
|
||||
.arg("/app/scripts/embed.py")
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
stdin.write_all(input.to_string().as_bytes()).await?;
|
||||
}
|
||||
|
||||
let output = child.wait_with_output().await?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Embedding script failed: {}", stderr);
|
||||
}
|
||||
|
||||
let result: serde_json::Value = serde_json::from_slice(&output.stdout)?;
|
||||
let embeddings: Vec<Vec<f32>> = result["embeddings"]
|
||||
.as_array()
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid embedding output"))?
|
||||
.iter()
|
||||
.map(|arr| {
|
||||
arr.as_array()
|
||||
.unwrap_or(&vec![])
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
||||
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
@@ -131,7 +169,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
|
||||
|
||||
for line in content.lines() {
|
||||
if line.starts_with("## ") {
|
||||
// Save previous chunk
|
||||
let text = current_lines.join("\n").trim().to_string();
|
||||
if !text.is_empty() {
|
||||
chunks.push(Chunk {
|
||||
@@ -150,7 +187,6 @@ fn split_chunks(content: &str) -> Vec<Chunk> {
|
||||
}
|
||||
}
|
||||
|
||||
// Last chunk
|
||||
let text = current_lines.join("\n").trim().to_string();
|
||||
if !text.is_empty() {
|
||||
chunks.push(Chunk {
|
||||
|
||||
Reference in New Issue
Block a user