- Switch extraction prompt to qa-style (80% recall vs 60% baseline) - Semicolon-separated cues in extraction become paraphrase variants - Add import_claude.py to bulk-import Claude Code conversation history - Fix LLM model name in systemd service, add logging basicConfig
401 lines
15 KiB
Python
401 lines
15 KiB
Python
"""nocmem — Memory service for NOC.
|
||
|
||
Wraps NuoNuo's HippocampalMemory as an HTTP API.
|
||
Auto-recall on every user message, async ingest after LLM response.
|
||
"""
|
||
|
||
import asyncio
|
||
import os
|
||
import re
|
||
import time
|
||
import logging
|
||
from contextlib import asynccontextmanager
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import torch
|
||
from fastapi import FastAPI
|
||
from pydantic import BaseModel, Field
|
||
from sentence_transformers import SentenceTransformer
|
||
from openai import OpenAI
|
||
|
||
from nuonuo.hippocampus import HippocampalMemory
|
||
|
||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nocmem")

# ── config ──────────────────────────────────────────────────────────

# Sentence-transformers model used for all cue/target embeddings.
EMBED_MODEL = os.environ.get("NOCMEM_EMBED_MODEL", "all-MiniLM-L6-v2")
# Embedding dimensionality; must match EMBED_MODEL's output (384 for the default model).
EMBED_DIM = int(os.environ.get("NOCMEM_EMBED_DIM", "384"))
DEVICE = os.environ.get("NOCMEM_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = Path(os.environ.get("NOCMEM_DATA_DIR", "./data"))
CHECKPOINT = DATA_DIR / "hippocampus.pt"
SAVE_INTERVAL = int(os.environ.get("NOCMEM_SAVE_INTERVAL", "10"))  # save every N stores
# Hopfield retrieval parameters, passed through to HippocampalMemory.
HOPFIELD_BETA = float(os.environ.get("NOCMEM_HOPFIELD_BETA", "32.0"))
HOPFIELD_TOP_K = int(os.environ.get("NOCMEM_HOPFIELD_TOP_K", "10"))
# Recall pre-filter: if no stored cue reaches this cosine similarity to the
# query, /recall returns nothing instead of forcing a bad match.
COS_SIM_THRESHOLD = float(os.environ.get("NOCMEM_COS_SIM_THRESHOLD", "0.35"))

# LLM for memory extraction (optional)
LLM_ENDPOINT = os.environ.get("NOCMEM_LLM_ENDPOINT", "")
LLM_MODEL = os.environ.get("NOCMEM_LLM_MODEL", "gemma4:12b")
LLM_API_KEY = os.environ.get("NOCMEM_LLM_API_KEY", "unused")

# ── globals ─────────────────────────────────────────────────────────
# Populated during lifespan() startup; None until the app has started.

encoder: SentenceTransformer = None  # type: ignore
hippocampus: HippocampalMemory = None  # type: ignore
llm_client = None  # optional; stays None when LLM_ENDPOINT is unset or unreachable
_stores_since_save = 0  # counter driving maybe_save()'s periodic checkpointing
|
||
|
||
|
||
def embed(text: str) -> torch.Tensor:
    """Encode one string into a normalized embedding tensor."""
    vectors = encoder.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    return vectors[0]
|
||
|
||
|
||
def embed_batch(texts: list[str]) -> list[torch.Tensor]:
    """Encode a batch of strings; returns one normalized tensor per input."""
    if not texts:
        return []
    stacked = encoder.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    # Iterating a 2-D tensor yields its rows.
    return list(stacked)
|
||
|
||
|
||
def maybe_save():
    """Checkpoint the hippocampus to disk once every SAVE_INTERVAL stores."""
    global _stores_since_save
    _stores_since_save += 1
    if _stores_since_save < SAVE_INTERVAL:
        return
    _stores_since_save = 0
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    hippocampus.save(str(CHECKPOINT))
    logger.info("checkpoint saved: %s", CHECKPOINT)
|
||
|
||
|
||
# ── lifespan ────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load models and state on startup, checkpoint on shutdown.

    Startup: load the embedding model, restore the hippocampus from the
    checkpoint if one exists (otherwise start fresh), and probe the optional
    extraction LLM endpoint. Shutdown: persist a final checkpoint.
    """
    global encoder, hippocampus, llm_client

    logger.info("loading embedding model: %s (device=%s)", EMBED_MODEL, DEVICE)
    encoder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

    if CHECKPOINT.exists():
        logger.info("loading checkpoint: %s", CHECKPOINT)
        hippocampus = HippocampalMemory.load(str(CHECKPOINT), device=DEVICE)
        logger.info("loaded %d memories", len(hippocampus.memories))
    else:
        logger.info("no checkpoint found, starting fresh")
        hippocampus = HippocampalMemory(
            embed_dim=EMBED_DIM, beta=HOPFIELD_BETA,
            hopfield_top_k=HOPFIELD_TOP_K, device=DEVICE,
        )

    if LLM_ENDPOINT:
        try:
            # Probe with a cheap models.list() call; only keep the client if
            # the endpoint actually responds within the timeout.
            client = OpenAI(base_url=LLM_ENDPOINT, api_key=LLM_API_KEY, timeout=5.0)
            client.models.list()
            llm_client = client
            logger.info("LLM client connected: %s", LLM_ENDPOINT)
        except Exception as e:
            # The service still works without the LLM — extraction falls back
            # to the heuristic path, so only warn.
            logger.warning("LLM client unavailable: %s", e)

    yield

    # save on shutdown
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    hippocampus.save(str(CHECKPOINT))
    logger.info("shutdown: checkpoint saved")
|
||
|
||
|
||
# ASGI application; startup/shutdown is driven by the lifespan context manager.
app = FastAPI(title="nocmem", lifespan=lifespan)
|
||
|
||
|
||
# ── models ──────────────────────────────────────────────────────────
|
||
|
||
class RecallRequest(BaseModel):
    """Body for POST /recall."""

    text: str  # query text, embedded and matched against stored cues
    top_k: int = Field(default=5, ge=1, le=20)  # max single-hop results
    hops: int = Field(default=2, ge=1, le=5)  # >1 enables associative chain recall
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0)  # drop results below this
|
||
|
||
class RecallResponse(BaseModel):
    """Formatted result of POST /recall."""

    memories: str  # pre-formatted bullet list; "" when nothing relevant was found
    count: int  # number of memory lines included in `memories`
    latency_ms: float  # server-side handling time
|
||
|
||
class IngestRequest(BaseModel):
    """One conversation turn to mine for memories (POST /ingest)."""

    user_msg: str
    assistant_msg: str
|
||
|
||
class IngestResponse(BaseModel):
    """Result of POST /ingest."""

    stored: int  # number of memories actually stored from this turn
|
||
|
||
class StoreRequest(BaseModel):
    """Body for POST /store — direct store, bypassing extraction."""

    cue: str  # retrieval trigger text
    target: str  # content to recall when the cue matches
    importance: float = Field(default=0.5, ge=0.0, le=1.0)
|
||
|
||
class StoreResponse(BaseModel):
    """Result of POST /store."""

    memory_id: int  # id assigned by the hippocampus, usable with DELETE /memory/{id}
|
||
|
||
|
||
# ── endpoints ───────────────────────────────────────────────────────
|
||
|
||
@app.post("/recall", response_model=RecallResponse)
|
||
async def recall(req: RecallRequest):
|
||
t0 = time.monotonic()
|
||
|
||
query_emb = embed(req.text)
|
||
|
||
# pre-filter: check if anything in memory is actually similar enough
|
||
cue_mat = hippocampus._get_cue_matrix()
|
||
if cue_mat is not None and COS_SIM_THRESHOLD > 0:
|
||
cos_sims = query_emb @ cue_mat.T
|
||
max_cos_sim = cos_sims.max().item()
|
||
if max_cos_sim < COS_SIM_THRESHOLD:
|
||
# nothing in memory is similar enough — don't hallucinate
|
||
return RecallResponse(memories="", count=0, latency_ms=(time.monotonic() - t0) * 1000)
|
||
|
||
# single-hop
|
||
results = hippocampus.recall(query_emb, top_k=req.top_k)
|
||
|
||
# multi-hop chain from top result
|
||
chain_results = []
|
||
if results and req.hops > 1:
|
||
chain = hippocampus.recall_chain(query_emb, hops=req.hops)
|
||
# add chain results not already in single-hop
|
||
seen_ids = {r.memory_id for r in results}
|
||
for cr in chain:
|
||
if cr.memory_id not in seen_ids:
|
||
chain_results.append(cr)
|
||
seen_ids.add(cr.memory_id)
|
||
|
||
all_results = results + chain_results
|
||
elapsed = (time.monotonic() - t0) * 1000
|
||
|
||
if not all_results:
|
||
return RecallResponse(memories="", count=0, latency_ms=elapsed)
|
||
|
||
lines = []
|
||
for r in all_results:
|
||
if r.similarity < req.min_similarity:
|
||
continue
|
||
meta = r.metadata
|
||
text = meta.get("target", meta.get("text", ""))
|
||
if not text:
|
||
continue
|
||
hop_tag = f" (联想 hop={r.hop_distance})" if r.hop_distance > 1 else ""
|
||
lines.append(f"- {text}{hop_tag}")
|
||
|
||
if not lines:
|
||
return RecallResponse(memories="", count=0, latency_ms=elapsed)
|
||
|
||
formatted = "[以下是可能相关的历史记忆,仅供参考。请优先关注用户当前的消息。]\n" + "\n".join(lines)
|
||
return RecallResponse(memories=formatted, count=len(lines), latency_ms=elapsed)
|
||
|
||
|
||
@app.post("/ingest", response_model=IngestResponse)
|
||
async def ingest(req: IngestRequest):
|
||
extracted = await asyncio.to_thread(_extract_and_store, req.user_msg, req.assistant_msg)
|
||
return IngestResponse(stored=extracted)
|
||
|
||
|
||
@dataclass
|
||
class ExtractedMemory:
|
||
cue: str
|
||
target: str
|
||
importance: float = 0.5
|
||
|
||
|
||
def _extract_memories_llm(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
|
||
prompt = (
|
||
'你是一个记忆提取器。把这段对话变成若干个"问答对"——未来有人问这个问题时,能直接给出答案。\n\n'
|
||
"要求:\n"
|
||
"- 问题要自然,像人真的会这么问\n"
|
||
"- 答案要具体完整,包含关键细节(名称、数字、地址等)\n"
|
||
"- 同一个事实可以从不同角度提问\n"
|
||
"- 每条 CUE 提供 2-3 个不同的触发短语,用分号分隔\n\n"
|
||
"格式(每行一条):\n"
|
||
"CUE: <提问方式1>; <提问方式2>; <提问方式3> | TARGET: <完整的回答> | IMPORTANCE: <0-1>\n\n"
|
||
f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
|
||
"没有值得记住的则输出 NONE。"
|
||
)
|
||
try:
|
||
resp = llm_client.chat.completions.create(
|
||
model=LLM_MODEL, messages=[{"role": "user", "content": prompt}],
|
||
temperature=0.3, max_tokens=512,
|
||
)
|
||
result = resp.choices[0].message.content
|
||
except Exception:
|
||
return _extract_memories_heuristic(user_msg, assistant_msg)
|
||
|
||
memories = []
|
||
for line in result.strip().split("\n"):
|
||
if line.strip() == "NONE":
|
||
break
|
||
m = re.match(r"CUE:\s*(.+?)\s*\|\s*TARGET:\s*(.+?)\s*\|\s*IMPORTANCE:\s*([\d.]+)", line)
|
||
if m:
|
||
memories.append(ExtractedMemory(m.group(1).strip(), m.group(2).strip(), float(m.group(3))))
|
||
return memories
|
||
|
||
|
||
def _extract_memories_heuristic(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
|
||
memories = []
|
||
# detect questions — English and Chinese
|
||
has_question = "?" in user_msg or "?" in user_msg or any(
|
||
user_msg.strip().startswith(q) for q in ["怎么", "什么", "哪", "为什么", "如何", "多少", "几"]
|
||
)
|
||
# count meaningful length: for Chinese, use character count
|
||
assistant_long_enough = len(assistant_msg) > 20
|
||
if has_question and assistant_long_enough:
|
||
cue = user_msg.rstrip("??").strip()
|
||
memories.append(ExtractedMemory(
|
||
cue=cue, target=assistant_msg[:300], importance=0.6,
|
||
))
|
||
# tech keywords — English and Chinese
|
||
tech_keywords = [
|
||
"deploy", "config", "bug", "fix", "error", "database", "server",
|
||
"api", "port", "token", "password", "version", "install", "upgrade",
|
||
"部署", "配置", "错误", "数据库", "服务器", "端口", "密码", "版本",
|
||
"安装", "升级", "模型", "工具", "代码", "项目", "优化", "性能",
|
||
"内存", "GPU", "vllm", "docker", "k8s", "git", "编译", "测试",
|
||
]
|
||
combined = (user_msg + " " + assistant_msg).lower()
|
||
user_meaningful = len(user_msg) >= 8 # characters, not words
|
||
if any(kw in combined for kw in tech_keywords) and user_meaningful:
|
||
if not memories: # avoid duplicate with Q&A extraction
|
||
memories.append(ExtractedMemory(
|
||
cue=user_msg[:150], target=assistant_msg[:300], importance=0.5,
|
||
))
|
||
return memories
|
||
|
||
|
||
def _generate_paraphrases_heuristic(text: str, n: int = 3) -> list[str]:
|
||
variants = []
|
||
text_lower = text.lower().strip()
|
||
# English prefixes
|
||
for pfx in ["can you ", "please ", "i need to ", "how do i ", "how to ", "what is ", "what's "]:
|
||
if text_lower.startswith(pfx):
|
||
stripped = text[len(pfx):].strip()
|
||
if stripped:
|
||
variants.append(stripped)
|
||
# Chinese prefixes
|
||
for pfx in ["帮我看看", "帮我", "请问", "我想知道", "能不能", "怎么样", "看下", "看看"]:
|
||
if text.startswith(pfx):
|
||
stripped = text[len(pfx):].strip()
|
||
if stripped:
|
||
variants.append(stripped)
|
||
# synonym swaps — English
|
||
en_swaps = {"slow": "performance issues", "fix": "resolve", "deploy": "release",
|
||
"error": "issue", "bug": "problem", "database": "DB", "server": "machine"}
|
||
for old, new in en_swaps.items():
|
||
if old in text_lower:
|
||
variant = text.replace(old, new).replace(old.capitalize(), new.capitalize())
|
||
if variant != text and variant not in variants:
|
||
variants.append(variant)
|
||
# synonym swaps — Chinese
|
||
cn_swaps = {"数据库": "DB", "服务器": "机器", "部署": "上线", "配置": "设置",
|
||
"性能": "速度", "优化": "改进", "工具": "tool", "项目": "project"}
|
||
for old, new in cn_swaps.items():
|
||
if old in text:
|
||
variant = text.replace(old, new)
|
||
if variant != text and variant not in variants:
|
||
variants.append(variant)
|
||
return variants[:n]
|
||
|
||
|
||
def _generate_paraphrases_llm(text: str, n: int = 3) -> list[str]:
    """Ask the LLM for *n* paraphrases; fall back to the heuristic on any failure."""
    prompt = (
        f"Generate {n} different paraphrases of this text. "
        "Each should convey the same meaning but use different words. "
        "One per line, no numbering."
        f"\n\nText: {text}"
    )
    try:
        completion = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.8,
            max_tokens=256,
        )
        raw = completion.choices[0].message.content
        # keep non-trivial lines only (drops blanks and stray punctuation)
        variants = []
        for raw_line in raw.strip().split("\n"):
            candidate = raw_line.strip()
            if candidate and len(candidate) > 3:
                variants.append(candidate)
        return variants[:n]
    except Exception:
        return _generate_paraphrases_heuristic(text, n)
|
||
|
||
|
||
def _extract_and_store(user_msg: str, assistant_msg: str) -> int:
    """Extract memories from one conversation turn and store them.

    Returns the number of memories stored. Runs on a worker thread (via
    asyncio.to_thread from /ingest), so blocking calls are fine here.
    """
    # LLM extraction when available; it falls back to the heuristic internally
    # on error, so this branch only selects the preferred path.
    if llm_client:
        memories = _extract_memories_llm(user_msg, assistant_msg)
    else:
        memories = _extract_memories_heuristic(user_msg, assistant_msg)

    if not memories:
        return 0

    stored = 0
    for mem in memories:
        if mem.importance < 0.3:
            # below the keep threshold — not worth a memory slot
            continue

        # split semicolon-separated cues into primary + variants
        cue_parts = [p.strip() for p in mem.cue.split(";") if p.strip()]
        primary_cue = cue_parts[0] if cue_parts else mem.cue
        inline_variants = cue_parts[1:] if len(cue_parts) > 1 else []

        cue_emb = embed(primary_cue)
        target_emb = embed(mem.target)

        # inline variants from semicolon cues (already in the extraction)
        variant_embs = embed_batch(inline_variants) if inline_variants else []

        # additionally generate paraphrases if no inline variants
        if not inline_variants:
            if llm_client:
                paraphrases = _generate_paraphrases_llm(primary_cue, n=3)
            else:
                paraphrases = _generate_paraphrases_heuristic(primary_cue, n=3)
            variant_embs = embed_batch(paraphrases) if paraphrases else []

        # metadata keeps the raw cue (with semicolons) so /recall can show it
        hippocampus.store(
            cue_emb, target_emb,
            cue_variants=variant_embs if variant_embs else None,
            metadata={"cue": mem.cue, "target": mem.target, "importance": mem.importance},
            timestamp=time.time(),
        )
        stored += 1

    if stored > 0:
        maybe_save()
        logger.info("ingested %d memories from conversation turn", stored)

    return stored
|
||
|
||
|
||
@app.post("/store", response_model=StoreResponse)
|
||
async def store_direct(req: StoreRequest):
|
||
"""Direct store — bypass LLM extraction, for manual/testing use."""
|
||
cue_emb = embed(req.cue)
|
||
target_emb = embed(req.target)
|
||
mid = hippocampus.store(
|
||
cue_emb, target_emb,
|
||
metadata={"cue": req.cue, "target": req.target, "importance": req.importance},
|
||
timestamp=time.time(),
|
||
)
|
||
maybe_save()
|
||
return StoreResponse(memory_id=mid)
|
||
|
||
|
||
@app.get("/stats")
|
||
async def stats():
|
||
s = hippocampus.stats()
|
||
s["device"] = DEVICE
|
||
s["embedding_model"] = EMBED_MODEL
|
||
s["checkpoint"] = str(CHECKPOINT)
|
||
s["checkpoint_exists"] = CHECKPOINT.exists()
|
||
return s
|
||
|
||
|
||
@app.delete("/memory/{memory_id}")
async def forget(memory_id: int):
    """Delete one memory by id (persisted via the periodic save counter)."""
    hippocampus.forget(memory_id)
    maybe_save()
    response = {"deleted": memory_id}
    return response
|