"""nocmem — Memory service for NOC.

Wraps NuoNuo's HippocampalMemory as an HTTP API.
Auto-recall on every user message, async ingest after LLM response.
"""

import asyncio
import os
import re
import time
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from openai import OpenAI

from nuonuo.hippocampus import HippocampalMemory

logger = logging.getLogger("nocmem")

# ── config ──────────────────────────────────────────────────────────

EMBED_MODEL = os.environ.get("NOCMEM_EMBED_MODEL", "all-MiniLM-L6-v2")
EMBED_DIM = int(os.environ.get("NOCMEM_EMBED_DIM", "384"))
DEVICE = os.environ.get("NOCMEM_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = Path(os.environ.get("NOCMEM_DATA_DIR", "./data"))
CHECKPOINT = DATA_DIR / "hippocampus.pt"
SAVE_INTERVAL = int(os.environ.get("NOCMEM_SAVE_INTERVAL", "10"))  # save every N stores
HOPFIELD_BETA = float(os.environ.get("NOCMEM_HOPFIELD_BETA", "32.0"))
HOPFIELD_TOP_K = int(os.environ.get("NOCMEM_HOPFIELD_TOP_K", "10"))
COS_SIM_THRESHOLD = float(os.environ.get("NOCMEM_COS_SIM_THRESHOLD", "0.35"))

# LLM for memory extraction (optional)
LLM_ENDPOINT = os.environ.get("NOCMEM_LLM_ENDPOINT", "")
LLM_MODEL = os.environ.get("NOCMEM_LLM_MODEL", "gemma4:12b")
LLM_API_KEY = os.environ.get("NOCMEM_LLM_API_KEY", "unused")

# Extracted memories rated below this importance are dropped at ingest time.
MIN_STORE_IMPORTANCE = 0.3

# One "CUE: ... | TARGET: ... | IMPORTANCE: ..." record of the extraction LLM's output.
_MEMORY_LINE_RE = re.compile(
    r"CUE:\s*(.+?)\s*\|\s*TARGET:\s*(.+?)\s*\|\s*IMPORTANCE:\s*([\d.]+)"
)

# Chinese interrogatives that mark a question even without a question mark.
_QUESTION_PREFIXES = ("怎么", "什么", "哪", "为什么", "如何", "多少", "几")

# Tech keywords — English and Chinese. They are matched against LOWERCASED
# text, so every entry here must itself be lowercase.
_TECH_KEYWORDS = (
    "deploy", "config", "bug", "fix", "error", "database", "server",
    "api", "port", "token", "password", "version", "install", "upgrade",
    "部署", "配置", "错误", "数据库", "服务器", "端口", "密码", "版本",
    "安装", "升级", "模型", "工具", "代码", "项目", "优化", "性能", "内存",
    "gpu", "vllm", "docker", "k8s", "git", "编译", "测试",
)

_EN_PREFIXES = ("can you ", "please ", "i need to ", "how do i ", "how to ", "what is ", "what's ")
_CN_PREFIXES = ("帮我看看", "帮我", "请问", "我想知道", "能不能", "怎么样", "看下", "看看")
_EN_SWAPS = {
    "slow": "performance issues", "fix": "resolve", "deploy": "release",
    "error": "issue", "bug": "problem", "database": "DB", "server": "machine",
}
_CN_SWAPS = {
    "数据库": "DB", "服务器": "机器", "部署": "上线", "配置": "设置",
    "性能": "速度", "优化": "改进", "工具": "tool", "项目": "project",
}

# ── globals ─────────────────────────────────────────────────────────

encoder: SentenceTransformer = None  # type: ignore
hippocampus: HippocampalMemory = None  # type: ignore
llm_client = None  # optional
_stores_since_save = 0


def embed(text: str) -> torch.Tensor:
    """Encode one string into a normalized embedding tensor on DEVICE."""
    return encoder.encode(
        [text], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE
    )[0]


def embed_batch(texts: list[str]) -> list[torch.Tensor]:
    """Encode several strings in one encoder call; returns one tensor per input."""
    if not texts:
        return []
    t = encoder.encode(
        texts, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE
    )
    # Iterating a tensor yields its slices along dim 0, i.e. one row per text.
    return list(t)


def _save_checkpoint() -> None:
    """Write the hippocampus state to CHECKPOINT, creating DATA_DIR if needed."""
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    hippocampus.save(str(CHECKPOINT))


def maybe_save():
    """Count one mutation and checkpoint once SAVE_INTERVAL mutations accumulate."""
    global _stores_since_save
    _stores_since_save += 1
    if _stores_since_save >= SAVE_INTERVAL:
        _stores_since_save = 0
        _save_checkpoint()
        logger.info("checkpoint saved: %s", CHECKPOINT)


# ── lifespan ────────────────────────────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the encoder, memory and optional LLM client; checkpoint on shutdown."""
    global encoder, hippocampus, llm_client

    logger.info("loading embedding model: %s (device=%s)", EMBED_MODEL, DEVICE)
    encoder = SentenceTransformer(EMBED_MODEL, device=DEVICE)

    if CHECKPOINT.exists():
        logger.info("loading checkpoint: %s", CHECKPOINT)
        hippocampus = HippocampalMemory.load(str(CHECKPOINT), device=DEVICE)
        logger.info("loaded %d memories", len(hippocampus.memories))
    else:
        logger.info("no checkpoint found, starting fresh")
        hippocampus = HippocampalMemory(
            embed_dim=EMBED_DIM,
            beta=HOPFIELD_BETA,
            hopfield_top_k=HOPFIELD_TOP_K,
            device=DEVICE,
        )

    if LLM_ENDPOINT:
        try:
            client = OpenAI(base_url=LLM_ENDPOINT, api_key=LLM_API_KEY, timeout=5.0)
            client.models.list()  # probe the endpoint before trusting it
            llm_client = client
            logger.info("LLM client connected: %s", LLM_ENDPOINT)
        except Exception as e:
            # Best effort: without an LLM the heuristic extractors are used.
            logger.warning("LLM client unavailable: %s", e)

    try:
        yield
    finally:
        # save on shutdown — in `finally` so an exception or cancellation
        # thrown in at the yield point does not lose unsaved memories
        _save_checkpoint()
        logger.info("shutdown: checkpoint saved")


app = FastAPI(title="nocmem", lifespan=lifespan)


# ── models ──────────────────────────────────────────────────────────

class RecallRequest(BaseModel):
    """Query for /recall."""

    text: str
    top_k: int = Field(default=5, ge=1, le=20)
    hops: int = Field(default=2, ge=1, le=5)
    min_similarity: float = Field(default=0.0, ge=0.0, le=1.0)


class RecallResponse(BaseModel):
    """Formatted recall result; `memories` is empty when nothing qualified."""

    memories: str
    count: int
    latency_ms: float


class IngestRequest(BaseModel):
    """One conversation turn to mine for memories."""

    user_msg: str
    assistant_msg: str


class IngestResponse(BaseModel):
    """Number of memories stored from the turn."""

    stored: int


class StoreRequest(BaseModel):
    """Explicit cue/target pair for /store."""

    cue: str
    target: str
    importance: float = Field(default=0.5, ge=0.0, le=1.0)


class StoreResponse(BaseModel):
    """Identifier returned by the memory store."""

    memory_id: int


# ── endpoints ───────────────────────────────────────────────────────

@app.post("/recall", response_model=RecallResponse)
async def recall(req: RecallRequest):
    """Return stored memories related to `req.text` as a bullet list.

    Runs a single-hop recall, optionally extends it with a multi-hop chain,
    filters by `req.min_similarity`, and formats the surviving targets.
    """
    # NOTE(review): embedding and recall run synchronously inside this
    # coroutine, while /ingest mutates the memory from a worker thread —
    # confirm whether HippocampalMemory tolerates that interleaving.
    t0 = time.monotonic()

    query_emb = embed(req.text)

    # pre-filter: check if anything in memory is actually similar enough
    cue_mat = hippocampus._get_cue_matrix()
    # numel() guard: .max() raises on a tensor with no elements
    if cue_mat is not None and cue_mat.numel() > 0 and COS_SIM_THRESHOLD > 0:
        cos_sims = query_emb @ cue_mat.T
        max_cos_sim = cos_sims.max().item()
        if max_cos_sim < COS_SIM_THRESHOLD:
            # nothing in memory is similar enough — don't hallucinate
            return RecallResponse(
                memories="", count=0, latency_ms=(time.monotonic() - t0) * 1000
            )

    # single-hop
    results = hippocampus.recall(query_emb, top_k=req.top_k)

    # multi-hop chain from top result
    chain_results = []
    if results and req.hops > 1:
        chain = hippocampus.recall_chain(query_emb, hops=req.hops)
        # add chain results not already in single-hop
        seen_ids = {r.memory_id for r in results}
        for cr in chain:
            if cr.memory_id not in seen_ids:
                chain_results.append(cr)
                seen_ids.add(cr.memory_id)

    all_results = results + chain_results
    elapsed = (time.monotonic() - t0) * 1000

    lines = []
    for r in all_results:
        if r.similarity < req.min_similarity:
            continue
        meta = r.metadata
        text = meta.get("target", meta.get("text", ""))
        if not text:
            continue
        hop_tag = f" (联想 hop={r.hop_distance})" if r.hop_distance > 1 else ""
        lines.append(f"- {text}{hop_tag}")

    if not lines:
        return RecallResponse(memories="", count=0, latency_ms=elapsed)

    formatted = "[以下是可能相关的历史记忆,仅供参考。请优先关注用户当前的消息。]\n" + "\n".join(lines)
    return RecallResponse(memories=formatted, count=len(lines), latency_ms=elapsed)


@app.post("/ingest", response_model=IngestResponse)
async def ingest(req: IngestRequest):
    """Extract memories from one conversation turn and store them.

    The extraction (LLM call plus embedding) runs in a worker thread so the
    event loop stays free.
    """
    extracted = await asyncio.to_thread(_extract_and_store, req.user_msg, req.assistant_msg)
    return IngestResponse(stored=extracted)


@dataclass
class ExtractedMemory:
    """A cue/target pair pulled out of a conversation turn."""

    cue: str
    target: str
    importance: float = 0.5


def _parse_importance(raw: str) -> float:
    """Convert the regex's importance capture to a float clamped to [0, 1].

    `[\\d.]+` also captures a sentence-ending dot ("0.8.") or dots alone; the
    trailing dots are removed first and anything still unparsable falls back
    to the ExtractedMemory default of 0.5.
    """
    try:
        value = float(raw.rstrip("."))
    except ValueError:
        return 0.5
    return min(1.0, max(0.0, value))


def _extract_memories_llm(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
    """Ask the LLM for memorable facts in a turn and parse its line format.

    Falls back to `_extract_memories_heuristic` when the request fails or the
    response carries no text. Lines that do not match the expected record
    format are ignored; a line that is exactly "NONE" ends parsing.
    """
    prompt = (
        "From this conversation turn, extract key facts worth remembering for future conversations.\n"
        "For each fact, provide a \"cue\" (what would trigger recalling this) and a \"target\" (the fact itself).\n"
        "Rate importance 0-1 (1 = critical fact, 0 = trivial).\n\n"
        f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
        "Output format (one per line):\nCUE: | TARGET: | IMPORTANCE: <0-1>\n\n"
        "Only extract genuinely useful facts. If nothing worth remembering, output NONE."
    )
    try:
        resp = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=512,
        )
        result = resp.choices[0].message.content
    except Exception as e:
        logger.warning("LLM memory extraction failed, using heuristic: %s", e)
        return _extract_memories_heuristic(user_msg, assistant_msg)

    if not result:
        # content can be None; treat it like a failed call
        return _extract_memories_heuristic(user_msg, assistant_msg)

    memories = []
    for line in result.strip().split("\n"):
        line = line.strip()
        if line == "NONE":
            break
        # search (not match) so leading bullets / numbering do not drop a record
        m = _MEMORY_LINE_RE.search(line)
        if m:
            memories.append(
                ExtractedMemory(
                    m.group(1).strip(),
                    m.group(2).strip(),
                    _parse_importance(m.group(3)),
                )
            )
    return memories


def _extract_memories_heuristic(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
    """Rule-based extraction used when no LLM is available.

    Produces at most one memory: a Q&A pair when the user asked a question
    and got a non-trivial answer, otherwise a generic pair when the turn
    mentions a tech keyword.
    """
    memories = []

    # detect questions — English ("?") and Chinese ("？" or an interrogative prefix)
    has_question = (
        "?" in user_msg
        or "？" in user_msg
        or user_msg.strip().startswith(_QUESTION_PREFIXES)
    )

    # count meaningful length: for Chinese, use character count
    assistant_long_enough = len(assistant_msg) > 20

    if has_question and assistant_long_enough:
        cue = user_msg.strip().rstrip("?？").strip()
        if cue:  # a message of only question marks has no usable cue
            memories.append(ExtractedMemory(
                cue=cue,
                target=assistant_msg[:300],
                importance=0.6,
            ))

    combined = (user_msg + " " + assistant_msg).lower()
    user_meaningful = len(user_msg) >= 8  # characters, not words
    if any(kw in combined for kw in _TECH_KEYWORDS) and user_meaningful:
        if not memories:  # avoid duplicate with Q&A extraction
            memories.append(ExtractedMemory(
                cue=user_msg[:150],
                target=assistant_msg[:300],
                importance=0.5,
            ))

    return memories


def _capitalize_first(s: str) -> str:
    """Uppercase only the first character, leaving the rest untouched."""
    return s[:1].upper() + s[1:]


def _generate_paraphrases_heuristic(text: str, n: int = 3) -> list[str]:
    """Produce up to `n` cheap rewrites of `text` without an LLM.

    Variants come from dropping a leading filler phrase and from swapping a
    few English / Chinese terms for synonyms. Duplicates and variants equal
    to the input are skipped.
    """
    # Strip once so prefix offsets computed on the lowercase copy line up
    # with the text that is actually sliced.
    text = text.strip()
    text_lower = text.lower()
    variants = []

    def add(candidate: str) -> None:
        if candidate and candidate != text and candidate not in variants:
            variants.append(candidate)

    # English prefixes
    for pfx in _EN_PREFIXES:
        if text_lower.startswith(pfx):
            add(text[len(pfx):].strip())

    # Chinese prefixes
    for pfx in _CN_PREFIXES:
        if text.startswith(pfx):
            add(text[len(pfx):].strip())

    # synonym swaps — English (lowercase form, then capitalized form)
    for old, new in _EN_SWAPS.items():
        if old in text_lower:
            # _capitalize_first keeps acronyms intact ("DB", not "Db")
            add(text.replace(old, new).replace(old.capitalize(), _capitalize_first(new)))

    # synonym swaps — Chinese
    for old, new in _CN_SWAPS.items():
        if old in text:
            add(text.replace(old, new))

    return variants[:n]


def _generate_paraphrases_llm(text: str, n: int = 3) -> list[str]:
    """Ask the LLM for up to `n` paraphrases of `text`, one per output line.

    Falls back to `_generate_paraphrases_heuristic` when the request fails or
    the response carries no text.
    """
    prompt = f"Generate {n} different paraphrases of this text. Each should convey the same meaning but use different words. One per line, no numbering.\n\nText: {text}"
    try:
        resp = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.8,
            max_tokens=256,
        )
        result = resp.choices[0].message.content
    except Exception as e:
        logger.warning("LLM paraphrase generation failed, using heuristic: %s", e)
        return _generate_paraphrases_heuristic(text, n)

    if not result:
        return _generate_paraphrases_heuristic(text, n)

    candidates = (line.strip() for line in result.strip().split("\n"))
    # more than 3 characters: drops blank lines and stray fragments
    return [c for c in candidates if len(c) > 3][:n]


def _extract_and_store(user_msg: str, assistant_msg: str) -> int:
    """Extract memories from a turn, embed them with cue paraphrases, and store.

    Returns the number of memories stored. Blocking; intended to run in a
    worker thread.
    """
    if llm_client:
        memories = _extract_memories_llm(user_msg, assistant_msg)
    else:
        memories = _extract_memories_heuristic(user_msg, assistant_msg)

    if not memories:
        return 0

    stored = 0
    for mem in memories:
        if mem.importance < MIN_STORE_IMPORTANCE:
            continue

        cue_emb = embed(mem.cue)
        target_emb = embed(mem.target)

        if llm_client:
            paraphrases = _generate_paraphrases_llm(mem.cue, n=3)
        else:
            paraphrases = _generate_paraphrases_heuristic(mem.cue, n=3)

        # embed_batch already returns [] for an empty list
        variant_embs = embed_batch(paraphrases)

        hippocampus.store(
            cue_emb, target_emb,
            cue_variants=variant_embs,
            metadata={"cue": mem.cue, "target": mem.target, "importance": mem.importance},
            timestamp=time.time(),
        )
        stored += 1

    if stored > 0:
        # counts once per ingested turn, not once per stored memory
        maybe_save()
        logger.info("ingested %d memories from conversation turn", stored)

    return stored


@app.post("/store", response_model=StoreResponse)
async def store_direct(req: StoreRequest):
    """Direct store — bypass LLM extraction, for manual/testing use."""
    cue_emb = embed(req.cue)
    target_emb = embed(req.target)
    mid = hippocampus.store(
        cue_emb, target_emb,
        metadata={"cue": req.cue, "target": req.target, "importance": req.importance},
        timestamp=time.time(),
    )
    maybe_save()
    return StoreResponse(memory_id=mid)


@app.get("/stats")
async def stats():
    """Return the memory's own stats plus device, model and checkpoint info."""
    s = hippocampus.stats()
    s["device"] = DEVICE
    s["embedding_model"] = EMBED_MODEL
    s["checkpoint"] = str(CHECKPOINT)
    s["checkpoint_exists"] = CHECKPOINT.exists()
    return s


@app.delete("/memory/{memory_id}")
async def forget(memory_id: int):
    """Remove one memory by id and count the change toward the next checkpoint."""
    hippocampus.forget(memory_id)
    maybe_save()
    return {"deleted": memory_id}