improve: enhance KB search with better embedding and chunking
This commit is contained in:
@@ -1,26 +1,66 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate embeddings for text chunks. Reads JSON from stdin, writes JSON to stdout.
|
||||
"""Embedding HTTP server. Loads model once at startup, serves requests on port 8199.
|
||||
|
||||
Input: {"texts": ["text1", "text2", ...]}
|
||||
Output: {"embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...], ...]}
|
||||
POST /embed {"texts": ["text1", "text2", ...]}
|
||||
Response: {"embeddings": [[0.1, 0.2, ...], ...]}
|
||||
|
||||
GET /health -> 200 OK
|
||||
"""
|
||||
import json
|
||||
import sys
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
MODEL_NAME = "all-MiniLM-L6-v2"
|
||||
PORT = 8199
|
||||
|
||||
def main():
    """One-shot CLI mode: embed texts from stdin, print embeddings to stdout.

    Reads a JSON object {"texts": ["text1", ...]} from stdin and writes
    {"embeddings": [[0.1, ...], ...]} to stdout.

    Raises:
        KeyError: if the input object has no "texts" key.
        json.JSONDecodeError: if stdin is not valid JSON.
    """
    data = json.loads(sys.stdin.read())
    texts = data["texts"]
    if not texts:
        # Short-circuit before paying the (multi-second) model-load cost
        # when there is nothing to encode.
        print(json.dumps({"embeddings": []}))
        return
    # Load the model only when actually needed for this one-shot run.
    model = SentenceTransformer(MODEL_NAME)
    # normalize_embeddings=True yields unit-length vectors, so cosine
    # similarity downstream reduces to a dot product.
    embeddings = model.encode(texts, normalize_embeddings=True)
    print(json.dumps({"embeddings": embeddings.tolist()}))
|
||||
class EmbedHandler(BaseHTTPRequestHandler):
    """JSON-over-HTTP handler for the embedding service.

    POST (any path): body {"texts": [...]} -> {"embeddings": [[...], ...]};
        malformed JSON bodies get a 400 instead of a dropped connection.
    GET (any path): plain-text liveness probe, always 200 "ok".

    NOTE(review): relies on a module-level ``model`` (a SentenceTransformer
    loaded at startup elsewhere in this file) — confirm it is defined before
    the server starts.
    """

    def _send_json(self, status, payload):
        # Serialize once so Content-Length matches the exact byte count.
        resp = json.dumps(payload).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()
        self.wfile.write(resp)

    def do_POST(self):
        # Read exactly Content-Length bytes; a missing header means no body.
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length)
        try:
            data = json.loads(body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Malformed body: answer 400 rather than raising inside the
            # handler (which would reset the client's connection).
            self._send_json(400, {"error": "invalid JSON body"})
            return
        texts = data.get("texts", [])
        if not texts:
            result = {"embeddings": []}
        else:
            # normalize_embeddings=True returns unit-length vectors.
            embeddings = model.encode(texts, normalize_embeddings=True)
            result = {"embeddings": embeddings.tolist()}
        self._send_json(200, result)

    def do_GET(self):
        # Liveness probe: any GET gets a plain-text 200 "ok".
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write(b"ok")

    def log_message(self, format, *args):
        # Suppress per-request access logs to keep stdout quiet.
        pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import socket

    class DualStackHTTPServer(HTTPServer):
        """HTTPServer that accepts both IPv4 and IPv6 connections.

        Binds an AF_INET6 socket with IPV6_V6ONLY cleared, so IPv4 clients
        reach it via v4-mapped addresses on the same port.
        """

        address_family = socket.AF_INET6

        def server_bind(self):
            # Must clear V6ONLY before the base class binds the socket.
            self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            super().server_bind()

    server = DualStackHTTPServer(("::", PORT), EmbedHandler)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the service; exit quietly.
        pass
    finally:
        # Release the listening socket on any exit path (previously leaked).
        server.server_close()
|
||||
|
||||
Reference in New Issue
Block a user