Files
noc/mem/server.py
Fam Zheng 7000ccda0f add nocmem: auto memory recall + ingest via NuoNuo hippocampal network
- nocmem Python service (mem/): FastAPI wrapper around NuoNuo's
  Hopfield-Hebbian memory, with /recall, /ingest, /store, /stats endpoints
- NOC integration: auto recall after user message (injected as system msg),
  async ingest after LLM response (fire-and-forget)
- Recall: cosine pre-filter (threshold 0.35) + Hopfield attention (β=32),
  top_k=3, KV-cache friendly (appended after user msg, not in system prompt)
- Ingest: LLM extraction + paraphrase augmentation, heuristic fallback
- Wired into main.rs, life.rs (agent done), http.rs (api chat)
- Config: optional `nocmem.endpoint` in config.yaml
- Includes benchmarks: LongMemEval (R@5=94.0%), efficiency, noise vs scale
- Design doc: doc/nocmem.md
2026-04-11 12:24:48 +01:00

387 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""nocmem — Memory service for NOC.
Wraps NuoNuo's HippocampalMemory as an HTTP API.
Auto-recall on every user message, async ingest after LLM response.
"""
import asyncio
import os
import re
import time
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from nuonuo.hippocampus import HippocampalMemory
logger = logging.getLogger("nocmem")
# ── config ──────────────────────────────────────────────────────────
# All settings are environment-driven with sensible defaults.
EMBED_MODEL = os.environ.get("NOCMEM_EMBED_MODEL", "all-MiniLM-L6-v2")  # sentence-transformers model name
EMBED_DIM = int(os.environ.get("NOCMEM_EMBED_DIM", "384"))  # must match EMBED_MODEL's output width
DEVICE = os.environ.get("NOCMEM_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = Path(os.environ.get("NOCMEM_DATA_DIR", "./data"))
CHECKPOINT = DATA_DIR / "hippocampus.pt"  # persisted memory state
SAVE_INTERVAL = int(os.environ.get("NOCMEM_SAVE_INTERVAL", "10"))  # save every N stores (see maybe_save)
HOPFIELD_BETA = float(os.environ.get("NOCMEM_HOPFIELD_BETA", "32.0"))  # Hopfield attention inverse temperature
HOPFIELD_TOP_K = int(os.environ.get("NOCMEM_HOPFIELD_TOP_K", "10"))  # candidates kept by Hopfield attention
COS_SIM_THRESHOLD = float(os.environ.get("NOCMEM_COS_SIM_THRESHOLD", "0.35"))  # recall pre-filter floor; 0 disables
# LLM for memory extraction (optional)
LLM_ENDPOINT = os.environ.get("NOCMEM_LLM_ENDPOINT", "")  # empty => heuristic extraction only
LLM_MODEL = os.environ.get("NOCMEM_LLM_MODEL", "gemma4:12b")
LLM_API_KEY = os.environ.get("NOCMEM_LLM_API_KEY", "unused")  # local OpenAI-compatible endpoints typically ignore this
# ── globals ─────────────────────────────────────────────────────────
# Populated inside lifespan() at startup; None until the app has started.
encoder: SentenceTransformer = None  # type: ignore  # sentence embedding model
hippocampus: HippocampalMemory = None  # type: ignore  # NuoNuo associative memory store
llm_client = None  # optional  # OpenAI-compatible client; None => heuristic extraction/paraphrasing
_stores_since_save = 0  # maybe_save() calls since the last checkpoint write
def embed(text: str) -> torch.Tensor:
    """Encode a single string into one normalized embedding vector."""
    vectors = encoder.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    return vectors[0]
def embed_batch(texts: list[str]) -> list[torch.Tensor]:
    """Encode many strings at once; returns one normalized tensor per input."""
    if not texts:
        return []
    matrix = encoder.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    # iterating a 2-D tensor yields its rows, one embedding per input text
    return list(matrix)
def maybe_save():
    """Checkpoint the hippocampus once every SAVE_INTERVAL invocations.

    Each call counts as one "store event"; nothing is written to disk
    until the counter reaches SAVE_INTERVAL, then it resets.
    """
    global _stores_since_save
    _stores_since_save += 1
    if _stores_since_save < SAVE_INTERVAL:
        return
    _stores_since_save = 0
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    hippocampus.save(str(CHECKPOINT))
    logger.info("checkpoint saved: %s", CHECKPOINT)
# ── lifespan ────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App startup/shutdown hook.

    Startup: load the embedding model, restore the hippocampus from the
    checkpoint (or create a fresh one), and probe the optional LLM
    endpoint. Shutdown: write a final checkpoint.
    """
    global encoder, hippocampus, llm_client
    logger.info("loading embedding model: %s (device=%s)", EMBED_MODEL, DEVICE)
    encoder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
    if CHECKPOINT.exists():
        logger.info("loading checkpoint: %s", CHECKPOINT)
        hippocampus = HippocampalMemory.load(str(CHECKPOINT), device=DEVICE)
        logger.info("loaded %d memories", len(hippocampus.memories))
    else:
        logger.info("no checkpoint found, starting fresh")
        hippocampus = HippocampalMemory(
            embed_dim=EMBED_DIM, beta=HOPFIELD_BETA,
            hopfield_top_k=HOPFIELD_TOP_K, device=DEVICE,
        )
    if LLM_ENDPOINT:
        try:
            # probe with models.list() so a dead endpoint degrades to
            # heuristic extraction now instead of failing on first ingest
            client = OpenAI(base_url=LLM_ENDPOINT, api_key=LLM_API_KEY, timeout=5.0)
            client.models.list()
            llm_client = client
            logger.info("LLM client connected: %s", LLM_ENDPOINT)
        except Exception as e:
            logger.warning("LLM client unavailable: %s", e)
    yield
    # save on shutdown
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    hippocampus.save(str(CHECKPOINT))
    logger.info("shutdown: checkpoint saved")
# FastAPI application; lifespan() handles model loading and checkpoint persistence.
app = FastAPI(title="nocmem", lifespan=lifespan)
# ── models ──────────────────────────────────────────────────────────
class RecallRequest(BaseModel):
    """Body for POST /recall."""
    text: str  # query text (typically the latest user message)
    top_k: int = Field(default=5, ge=1, le=20)  # max single-hop memories returned
    hops: int = Field(default=2, ge=1, le=5)  # associative chain depth; >1 enables multi-hop
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0)  # drop results below this similarity
class RecallResponse(BaseModel):
    """Formatted recall result."""
    memories: str  # bullet list ready for prompt injection; empty when nothing relevant
    count: int  # number of memory lines included
    latency_ms: float  # end-to-end recall latency
class IngestRequest(BaseModel):
    """One conversation turn to mine for memories (POST /ingest)."""
    user_msg: str
    assistant_msg: str
class IngestResponse(BaseModel):
    """Result of POST /ingest."""
    stored: int  # number of memories written for this turn
class StoreRequest(BaseModel):
    """Body for POST /store (manual storage, no LLM extraction)."""
    cue: str  # recall trigger
    target: str  # the fact to store
    importance: float = Field(default=0.5, ge=0.0, le=1.0)
class StoreResponse(BaseModel):
    """Result of POST /store."""
    memory_id: int
# ── endpoints ───────────────────────────────────────────────────────
@app.post("/recall", response_model=RecallResponse)
async def recall(req: RecallRequest):
    """Recall memories relevant to req.text.

    Pipeline: embed query → cosine pre-filter against all stored cues →
    single-hop Hopfield recall → optional multi-hop associative chain →
    format hits as a bullet list for prompt injection.
    """
    t0 = time.monotonic()
    query_emb = embed(req.text)
    # pre-filter: check if anything in memory is actually similar enough
    # NOTE(review): relies on a private HippocampalMemory method — confirm it is stable API
    cue_mat = hippocampus._get_cue_matrix()
    if cue_mat is not None and COS_SIM_THRESHOLD > 0:
        # embeddings are normalized, so the dot product is cosine similarity
        cos_sims = query_emb @ cue_mat.T
        max_cos_sim = cos_sims.max().item()
        if max_cos_sim < COS_SIM_THRESHOLD:
            # nothing in memory is similar enough — don't hallucinate
            return RecallResponse(memories="", count=0, latency_ms=(time.monotonic() - t0) * 1000)
    # single-hop
    results = hippocampus.recall(query_emb, top_k=req.top_k)
    # multi-hop chain from top result
    chain_results = []
    if results and req.hops > 1:
        chain = hippocampus.recall_chain(query_emb, hops=req.hops)
        # add chain results not already in single-hop
        seen_ids = {r.memory_id for r in results}
        for cr in chain:
            if cr.memory_id not in seen_ids:
                chain_results.append(cr)
                seen_ids.add(cr.memory_id)
    all_results = results + chain_results
    elapsed = (time.monotonic() - t0) * 1000
    if not all_results:
        return RecallResponse(memories="", count=0, latency_ms=elapsed)
    lines = []
    for r in all_results:
        if r.similarity < req.min_similarity:
            continue
        meta = r.metadata
        # prefer the stored "target" fact; fall back to a raw "text" field if present
        text = meta.get("target", meta.get("text", ""))
        if not text:
            continue
        # tag associative (multi-hop) hits so the consumer can weigh them lower
        hop_tag = f" (联想 hop={r.hop_distance})" if r.hop_distance > 1 else ""
        lines.append(f"- {text}{hop_tag}")
    if not lines:
        return RecallResponse(memories="", count=0, latency_ms=elapsed)
    # header (in Chinese) tells the model these are optional context, not instructions
    formatted = "[以下是可能相关的历史记忆,仅供参考。请优先关注用户当前的消息。]\n" + "\n".join(lines)
    return RecallResponse(memories=formatted, count=len(lines), latency_ms=elapsed)
@app.post("/ingest", response_model=IngestResponse)
async def ingest(req: IngestRequest):
    """Mine one conversation turn for memories (extraction runs in a worker thread)."""
    stored_count = await asyncio.to_thread(
        _extract_and_store, req.user_msg, req.assistant_msg
    )
    return IngestResponse(stored=stored_count)
@dataclass
class ExtractedMemory:
    """One fact distilled from a conversation turn."""
    cue: str  # trigger phrase used as the recall key
    target: str  # the fact itself (what gets injected on recall)
    importance: float = 0.5  # 0..1; values below 0.3 are discarded before storing
def _extract_memories_llm(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
    """Ask the LLM to distill CUE | TARGET | IMPORTANCE facts from one turn.

    Falls back to the heuristic extractor when the LLM call fails or
    returns an empty completion. Lines that don't match the expected
    format are silently skipped; a literal "NONE" line stops parsing.
    """
    prompt = (
        "From this conversation turn, extract key facts worth remembering for future conversations.\n"
        "For each fact, provide a \"cue\" (what would trigger recalling this) and a \"target\" (the fact itself).\n"
        "Rate importance 0-1 (1 = critical fact, 0 = trivial).\n\n"
        f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
        "Output format (one per line):\nCUE: <trigger phrase> | TARGET: <fact> | IMPORTANCE: <0-1>\n\n"
        "Only extract genuinely useful facts. If nothing worth remembering, output NONE."
    )
    try:
        resp = llm_client.chat.completions.create(
            model=LLM_MODEL, messages=[{"role": "user", "content": prompt}],
            temperature=0.3, max_tokens=512,
        )
        result = resp.choices[0].message.content
    except Exception:
        return _extract_memories_heuristic(user_msg, assistant_msg)
    # BUG FIX: message.content may be None (e.g. refusals); previously this
    # crashed on result.strip() outside the try block.
    if not result:
        return _extract_memories_heuristic(user_msg, assistant_msg)
    memories = []
    for line in result.strip().split("\n"):
        if line.strip() == "NONE":
            break
        m = re.match(r"CUE:\s*(.+?)\s*\|\s*TARGET:\s*(.+?)\s*\|\s*IMPORTANCE:\s*([\d.]+)", line)
        if m:
            try:
                # BUG FIX: [\d.]+ can match a bare "." — skip unparsable scores
                importance = float(m.group(3))
            except ValueError:
                continue
            memories.append(ExtractedMemory(m.group(1).strip(), m.group(2).strip(), importance))
    return memories
def _extract_memories_heuristic(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
    """Fallback extractor used when no LLM is configured.

    Stores (a) question→answer pairs and (b) turns mentioning tech keywords.

    BUG FIX: the question test previously contained empty strings
    (``"" in user_msg`` and a ``startswith("")`` candidate) which are
    always true, so every turn with a >20-char answer was stored as a
    Q&A memory. The empty strings look like characters dropped in
    transit (the file carries an ambiguous-Unicode warning); they are
    reconstructed as the full-width question mark "？" — TODO confirm
    against the original commit.
    BUG FIX: "GPU" was matched against a lowercased string and could
    never hit; now "gpu".
    """
    memories = []
    # detect questions — English and Chinese (full-width ？ included)
    has_question = "?" in user_msg or "？" in user_msg or any(
        user_msg.strip().startswith(q) for q in ["怎么", "什么", "为什么", "如何", "多少"]
    )
    # count meaningful length: for Chinese, use character count
    assistant_long_enough = len(assistant_msg) > 20
    if has_question and assistant_long_enough:
        # strip both ASCII and full-width trailing question marks from the cue
        cue = user_msg.rstrip("?？").strip()
        memories.append(ExtractedMemory(
            cue=cue, target=assistant_msg[:300], importance=0.6,
        ))
    # tech keywords — English and Chinese
    tech_keywords = [
        "deploy", "config", "bug", "fix", "error", "database", "server",
        "api", "port", "token", "password", "version", "install", "upgrade",
        "部署", "配置", "错误", "数据库", "服务器", "端口", "密码", "版本",
        "安装", "升级", "模型", "工具", "代码", "项目", "优化", "性能",
        "内存", "gpu", "vllm", "docker", "k8s", "git", "编译", "测试",
    ]
    combined = (user_msg + " " + assistant_msg).lower()
    user_meaningful = len(user_msg) >= 8  # characters, not words
    if any(kw in combined for kw in tech_keywords) and user_meaningful:
        if not memories:  # avoid duplicate with Q&A extraction
            memories.append(ExtractedMemory(
                cue=user_msg[:150], target=assistant_msg[:300], importance=0.5,
            ))
    return memories
def _generate_paraphrases_heuristic(text: str, n: int = 3) -> list[str]:
variants = []
text_lower = text.lower().strip()
# English prefixes
for pfx in ["can you ", "please ", "i need to ", "how do i ", "how to ", "what is ", "what's "]:
if text_lower.startswith(pfx):
stripped = text[len(pfx):].strip()
if stripped:
variants.append(stripped)
# Chinese prefixes
for pfx in ["帮我看看", "帮我", "请问", "我想知道", "能不能", "怎么样", "看下", "看看"]:
if text.startswith(pfx):
stripped = text[len(pfx):].strip()
if stripped:
variants.append(stripped)
# synonym swaps — English
en_swaps = {"slow": "performance issues", "fix": "resolve", "deploy": "release",
"error": "issue", "bug": "problem", "database": "DB", "server": "machine"}
for old, new in en_swaps.items():
if old in text_lower:
variant = text.replace(old, new).replace(old.capitalize(), new.capitalize())
if variant != text and variant not in variants:
variants.append(variant)
# synonym swaps — Chinese
cn_swaps = {"数据库": "DB", "服务器": "机器", "部署": "上线", "配置": "设置",
"性能": "速度", "优化": "改进", "工具": "tool", "项目": "project"}
for old, new in cn_swaps.items():
if old in text:
variant = text.replace(old, new)
if variant != text and variant not in variants:
variants.append(variant)
return variants[:n]
def _generate_paraphrases_llm(text: str, n: int = 3) -> list[str]:
    """LLM-backed paraphrasing; falls back to the heuristic version on any error."""
    prompt = f"Generate {n} different paraphrases of this text. Each should convey the same meaning but use different words. One per line, no numbering.\n\nText: {text}"
    try:
        resp = llm_client.chat.completions.create(
            model=LLM_MODEL, messages=[{"role": "user", "content": prompt}],
            temperature=0.8, max_tokens=256,
        )
        result = resp.choices[0].message.content
        # keep non-trivial lines only, at most n of them
        candidates = (line.strip() for line in result.strip().split("\n"))
        return [c for c in candidates if c and len(c) > 3][:n]
    except Exception:
        return _generate_paraphrases_heuristic(text, n)
def _extract_and_store(user_msg: str, assistant_msg: str) -> int:
    """Extract facts from one turn and write them into the hippocampus.

    Uses the LLM extractor/paraphraser when a client is connected,
    heuristics otherwise. Returns the number of memories stored;
    candidates with importance below 0.3 are skipped.
    """
    extractor = _extract_memories_llm if llm_client else _extract_memories_heuristic
    candidates = extractor(user_msg, assistant_msg)
    if not candidates:
        return 0
    paraphraser = _generate_paraphrases_llm if llm_client else _generate_paraphrases_heuristic
    stored = 0
    for mem in candidates:
        if mem.importance < 0.3:
            continue  # too trivial to keep
        cue_emb = embed(mem.cue)
        target_emb = embed(mem.target)
        # augment the cue with paraphrase embeddings for recall robustness
        variants = paraphraser(mem.cue, n=3)
        variant_embs = embed_batch(variants) if variants else []
        hippocampus.store(
            cue_emb, target_emb,
            cue_variants=variant_embs,
            metadata={"cue": mem.cue, "target": mem.target, "importance": mem.importance},
            timestamp=time.time(),
        )
        stored += 1
    if stored > 0:
        maybe_save()
        logger.info("ingested %d memories from conversation turn", stored)
    return stored
@app.post("/store", response_model=StoreResponse)
async def store_direct(req: StoreRequest):
    """Direct store — bypass LLM extraction, for manual/testing use."""
    # unlike /ingest, no paraphrase augmentation and no importance filtering
    cue_emb = embed(req.cue)
    target_emb = embed(req.target)
    mid = hippocampus.store(
        cue_emb, target_emb,
        metadata={"cue": req.cue, "target": req.target, "importance": req.importance},
        timestamp=time.time(),
    )
    maybe_save()
    return StoreResponse(memory_id=mid)
@app.get("/stats")
async def stats():
    """Memory statistics plus service configuration details."""
    info = hippocampus.stats()
    return {
        **info,
        "device": DEVICE,
        "embedding_model": EMBED_MODEL,
        "checkpoint": str(CHECKPOINT),
        "checkpoint_exists": CHECKPOINT.exists(),
    }
@app.delete("/memory/{memory_id}")
async def forget(memory_id: int):
    """Delete one memory by id."""
    hippocampus.forget(memory_id)
    # NOTE: maybe_save only writes every SAVE_INTERVAL calls, so the
    # deletion may not be durable until a later checkpoint or shutdown.
    maybe_save()
    return {"deleted": memory_id}