- nocmem Python service (mem/): FastAPI wrapper around NuoNuo's Hopfield-Hebbian memory, with /recall, /ingest, /store, /stats endpoints - NOC integration: auto recall after user message (injected as system msg), async ingest after LLM response (fire-and-forget) - Recall: cosine pre-filter (threshold 0.35) + Hopfield attention (β=32), top_k=3, KV-cache friendly (appended after user msg, not in system prompt) - Ingest: LLM extraction + paraphrase augmentation, heuristic fallback - Wired into main.rs, life.rs (agent done), http.rs (api chat) - Config: optional `nocmem.endpoint` in config.yaml - Includes benchmarks: LongMemEval (R@5=94.0%), efficiency, noise vs scale - Design doc: doc/nocmem.md
179 lines
6.1 KiB
Python
"""Does recall noise decrease as memory count grows?
|
|
|
|
At various scales, measure:
|
|
1. Recall accuracy (R@3) for relevant queries
|
|
2. Max cosine similarity for irrelevant queries
|
|
3. Separation gap between relevant and irrelevant
|
|
|
|
If nocmem works well at scale, the gap should widen — relevant queries
|
|
should score much higher than irrelevant ones as the memory pool grows.
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import torch
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer
|
|
from nuonuo.hippocampus import HippocampalMemory
|
|
|
|
DEVICE = "cuda"
|
|
EMBED_DIM = 384
|
|
DATA_FILE = "benchmarks/longmemeval.json"
|
|
|
|
IRRELEVANT_QUERIES = [
|
|
"今天天气怎么样",
|
|
"你喜欢吃什么",
|
|
"嗨",
|
|
"讲个笑话",
|
|
"明天会下雨吗",
|
|
"你觉得猫可爱还是狗可爱",
|
|
"人生的意义是什么",
|
|
"帮我写一首诗",
|
|
"地球到月球有多远",
|
|
"如何学会游泳",
|
|
]
|
|
|
|
BETA_CONFIGS = [16.0, 32.0, 64.0]
|
|
SCALES = [50, 200, 500, 1000, 3000]
|
|
|
|
|
|
def main():
    """Run the noise-vs-scale benchmark.

    For each Hopfield beta in BETA_CONFIGS and each memory-pool size in
    SCALES, build a HippocampalMemory over the first N conversation chunks
    and measure:

      * R@3 recall accuracy for relevant (benchmark) queries,
      * max cosine similarity for relevant vs. irrelevant queries, and the
        separation gap between the two,
      * the top-1 Hopfield attention weight for both query groups.

    One result table is printed per beta value, followed by a legend.
    """
    print("noise vs scale benchmark\n")
    print("loading encoder...")
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

    def emb(text):
        # Single-text convenience wrapper; returns one normalized embedding.
        return encoder.encode([text], convert_to_tensor=True,
                              normalize_embeddings=True, device=DEVICE)[0]

    def emb_batch(texts):
        # Batched encode; returns an (N, EMBED_DIM) normalized tensor.
        return encoder.encode(texts, convert_to_tensor=True,
                              normalize_embeddings=True, device=DEVICE,
                              batch_size=256, show_progress_bar=False)

    # load data
    print("loading data...")
    with open(DATA_FILE) as f:
        data = json.load(f)

    # collect unique chunks with their source question index
    all_chunks = []  # (text, question_idx, session_id)
    seen = set()
    for qi, item in enumerate(data):
        for sid, sess in zip(item["haystack_session_ids"], item["haystack_sessions"]):
            # Sessions alternate user/assistant turns; pair them two-by-two.
            for i in range(0, len(sess) - 1, 2):
                key = (sid, i)
                if key in seen:
                    continue
                seen.add(key)
                user = sess[i]["content"]
                asst = sess[i + 1]["content"] if i + 1 < len(sess) else ""
                text = f"{user}\n{asst}"[:1000]  # cap chunk length
                all_chunks.append((text, qi, sid))
    print(f" {len(all_chunks)} unique chunks")

    # pre-embed irrelevant queries
    irrel_embs = [emb(q) for q in IRRELEVANT_QUERIES]

    # collect relevant queries: for each of the first 100 questions we know
    # which session(s) contain the answer.  No filtering is done here —
    # questions whose answer sessions fall outside a given subset are
    # skipped per-scale in the loop below.
    relevant_queries = []
    for item in data[:100]:
        answer_sids = set(item["answer_session_ids"])
        relevant_queries.append((item["question"], answer_sids))
    rel_query_embs = emb_batch([q for q, _ in relevant_queries])

    print(f" {len(relevant_queries)} relevant queries")
    print(f" {len(IRRELEVANT_QUERIES)} irrelevant queries")

    # filter scales to what we have
    scales = [s for s in SCALES if s <= len(all_chunks)]

    for beta in BETA_CONFIGS:
        print(f"\n{'='*70}")
        print(f" β = {beta}")
        print(f"{'='*70}")
        print(f"{'Scale':>7} | {'R@3':>6} | {'Rel maxcos':>10} {'Irrel maxcos':>12} {'Gap':>8} | {'Rel attn':>9} {'Irrel attn':>11}")
        print("-" * 80)

        for n in scales:
            subset = all_chunks[:n]
            texts = [c[0] for c in subset]
            sids = [c[2] for c in subset]
            # Hoisted out of the per-query loop: invariant for a given scale.
            subset_sids = set(sids)

            # embed and build memory
            embeddings = emb_batch(texts)
            hip = HippocampalMemory(
                embed_dim=EMBED_DIM, beta=beta, hopfield_top_k=10, device=DEVICE,
            )
            for i in range(n):
                hip.store(embeddings[i], embeddings[i],
                          metadata={"session_id": sids[i]})

            cue_mat = hip._get_cue_matrix()

            # --- relevant queries ---
            rel_max_cos = []
            rel_top_attn = []
            hits = 0
            tested = 0

            for qi, (_question, answer_sids) in enumerate(relevant_queries):
                qe = rel_query_embs[qi]

                # only score questions whose answer session is in this subset
                if not (answer_sids & subset_sids):
                    continue
                tested += 1

                # cosine sim
                cos_sims = qe @ cue_mat.T
                rel_max_cos.append(cos_sims.max().item())

                # recall
                results = hip.recall(qe, top_k=3)
                top_attn = results[0].similarity if results else 0
                rel_top_attn.append(top_attn)

                recalled_sids = {r.metadata["session_id"] for r in results}
                if answer_sids & recalled_sids:
                    hits += 1

            r3 = hits / tested * 100 if tested > 0 else 0
            avg_rel_cos = np.mean(rel_max_cos) if rel_max_cos else 0
            avg_rel_attn = np.mean(rel_top_attn) if rel_top_attn else 0

            # --- irrelevant queries ---
            irrel_max_cos = []
            irrel_top_attn = []
            for qe in irrel_embs:
                cos_sims = qe @ cue_mat.T
                irrel_max_cos.append(cos_sims.max().item())

                results = hip.recall(qe, top_k=3)
                top_attn = results[0].similarity if results else 0
                irrel_top_attn.append(top_attn)

            avg_irrel_cos = np.mean(irrel_max_cos)
            avg_irrel_attn = np.mean(irrel_top_attn)

            gap = avg_rel_cos - avg_irrel_cos

            print(f"{n:>7,} | {r3:>5.1f}% | {avg_rel_cos:>10.3f} {avg_irrel_cos:>12.3f} {gap:>8.3f} | {avg_rel_attn:>8.0%} {avg_irrel_attn:>10.0%}")

            # free GPU memory before the next (larger) scale
            del hip
            torch.cuda.empty_cache()

    # Legend (user-facing output, kept verbatim).
    print("\n── 解读 ──")
    print("Rel maxcos: 相关查询的最大余弦相似度(越高越好)")
    print("Irrel maxcos: 无关查询的最大余弦相似度(越低越好)")
    print("Gap: 两者之差(越大越好 = 越容易区分)")
    print("Rel attn: 相关查询 top1 的 Hopfield attention 权重")
    print("Irrel attn: 无关查询 top1 的 Hopfield attention 权重(越低 = 越少噪音)")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|