Files
nuonuo/experiments/exp16_longmemeval_gemma.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

322 lines
11 KiB
Python

"""LongMemEval benchmark with Gemma 4 post-retrieval reasoning.
Previous result: 36% with retrieval-only.
This version: retrieve top-K memories → Gemma 4 reads them and answers the question.
Improvements:
1. No truncation of assistant responses (store full text, chunked)
2. Store user statements as memories (preference extraction)
3. Include timestamps in stored memories
4. Post-retrieval: Gemma 4 synthesizes answer from recalled memories
"""
import sys
import json
import time
import re
import requests
from pathlib import Path
from collections import Counter
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.hippocampus import HippocampalMemory
DEVICE = "cuda"  # embedding-model device; assumes a CUDA GPU is available
OLLAMA_URL = "http://localhost:11434/api/chat"  # local Ollama chat-completions endpoint
def gemma_chat(messages, max_tokens=256, temperature=0):
    """Call Gemma 4 via Ollama's /api/chat endpoint.

    Args:
        messages: list of {"role": ..., "content": ...} chat messages.
        max_tokens: generation cap, passed as Ollama's "num_predict" option.
        temperature: sampling temperature (0 = deterministic/greedy).

    Returns:
        The assistant's reply text, or None on any transport or parse
        failure (callers treat None as "Gemma unavailable" and fall back).
    """
    try:
        resp = requests.post(OLLAMA_URL, json={
            "model": "gemma4:31b",
            "messages": messages,
            "stream": False,
            "think": False,
            "options": {"num_predict": max_tokens, "temperature": temperature},
        }, timeout=60)
        return resp.json()["message"]["content"]
    # Narrowed from a blanket `except Exception`: network/timeout failures
    # (RequestException), malformed JSON bodies (ValueError), a response
    # missing the expected keys (KeyError), or a non-serializable payload
    # (TypeError) all mean "no usable answer" here.
    except (requests.RequestException, ValueError, KeyError, TypeError):
        return None
def load_model():
    """Load the MiniLM-L6 sentence encoder onto the configured device.

    Imported lazily so the script can report Gemma availability before
    paying the sentence-transformers startup cost.
    """
    from sentence_transformers import SentenceTransformer
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return encoder
def emb_batch(model, texts):
    """Encode *texts* and return a list of normalized embedding tensors,
    one per input text (empty list for empty input)."""
    if not texts:
        return []
    encoded = model.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
        batch_size=64,
    )
    # Split the (N, dim) tensor into N row tensors.
    return [encoded[idx] for idx in range(len(encoded))]
def extract_memories_v2(session, session_date=""):
    """Extract (cue, target, metadata) memory triples from one chat session.

    Improvements over v1:
    - Assistant responses are chunked (overlapping 500-char windows, step
      400) instead of truncated.
    - Sufficiently long user statements are stored verbatim as self-cueing
      memories.
    - Simple regex-based preference extraction ("I like/prefer/... X").
    - Every memory carries the session date for temporal reasoning.

    Args:
        session: list of {"role": ..., "content": ...} turns.
        session_date: timestamp string stored in each memory's metadata.

    Returns:
        List of (cue_text, target_text, metadata_dict) tuples.
    """
    # Hoisted out of the per-turn loop and precompiled: the original
    # rebuilt this pattern list for every user turn.
    pref_patterns = [
        re.compile(r"I (?:like|love|prefer|enjoy|use|usually|always|often)\s+(.{5,80})", re.IGNORECASE),
        re.compile(r"my (?:favorite|preferred)\s+(.{5,80})", re.IGNORECASE),
        re.compile(r"I'm (?:a fan of|into|interested in)\s+(.{5,80})", re.IGNORECASE),
    ]
    memories = []  # list of (cue_text, target_text, extra_metadata)
    for i, turn in enumerate(session):
        if turn["role"] != "user":
            continue
        content = turn["content"].strip()
        # Pair this user turn with the next assistant response, if any.
        assistant_text = ""
        for j in range(i + 1, len(session)):
            if session[j]["role"] == "assistant":
                assistant_text = session[j]["content"].strip()
                break
        if len(content) > 10 and len(assistant_text) > 10:
            if len(assistant_text) > 500:
                # Overlapping 500-char windows (step 400) keep context
                # across chunk boundaries; drop sub-50-char tails and
                # store at most 5 chunks per response.
                chunks = []
                for start in range(0, len(assistant_text), 400):
                    chunk = assistant_text[start:start + 500]
                    if len(chunk) > 50:
                        chunks.append(chunk)
                for ci, chunk in enumerate(chunks[:5]):
                    memories.append((
                        content,
                        chunk,
                        {"date": session_date, "chunk": ci}
                    ))
            else:
                memories.append((content, assistant_text, {"date": session_date}))
        # The user's own statement as a self-cueing memory.
        if len(content) > 20:
            memories.append((content, content, {"date": session_date, "type": "user_statement"}))
        # Preference extraction: index the whole matched phrase as the cue.
        for pat in pref_patterns:
            match = pat.search(content)
            if match:
                pref_text = match.group(0)
                memories.append((
                    f"user preference: {pref_text}",
                    content,
                    {"date": session_date, "type": "preference"}
                ))
    return memories
def check_answer(answer_text, expected_answer):
    """Heuristically decide whether a generated answer matches the gold one.

    Matching is lenient, tried in order:
    1. "Unanswerable" gold answers (phrased "...did not mention...") pass
       when the generation admits ignorance ("no information", ...).
    2. Direct case-insensitive substring match.
    3. At least half of the gold answer's significant words (>3 chars)
       appear in the generation.
    4. Any number from the gold answer appears in the generation
       (temporal "how many days" style questions).

    Args:
        answer_text: model-generated answer (may be None/empty).
        expected_answer: gold answer from the benchmark (any type; str()'d).

    Returns:
        True if the answer is judged correct, else False.
    """
    if not answer_text:
        return False
    expected = str(expected_answer).lower().strip()
    generated = answer_text.lower().strip()
    # Unanswerable questions. ("not mention" already covers the previously
    # separately-tested "did not mention".)
    if "not mention" in expected:
        neg_phrases = ["don't have", "no information", "not mentioned",
                       "cannot find", "don't recall", "no record", "not available"]
        return any(p in generated for p in neg_phrases)
    # Direct substring match
    if expected in generated:
        return True
    # Key-word match: at least half of the significant words present
    # (the original comment claimed 60% but the code has always used 0.5).
    answer_words = [w for w in expected.split() if len(w) > 3]
    if answer_words:
        matches = sum(1 for w in answer_words if w in generated)
        if matches >= max(1, len(answer_words) * 0.5):
            return True
    # Number matching (for temporal questions)
    expected_nums = re.findall(r'\d+', expected)
    generated_nums = re.findall(r'\d+', generated)
    if expected_nums and any(n in generated_nums for n in expected_nums):
        return True
    return False
def run_benchmark(model, oracle, max_questions=None):
    """Run the LongMemEval benchmark with Gemma 4 post-retrieval reasoning.

    For each question: build a fresh hippocampal memory from its haystack
    sessions, recall the top-K memories for the question embedding, then
    ask Gemma to synthesize an answer from the recalled context. A
    retrieval-only baseline (answer present in retrieved text) is tracked
    alongside the Gemma score.

    Args:
        model: SentenceTransformer used for cue/target/question embeddings.
        oracle: list of LongMemEval question entries.
        max_questions: optional cap on how many questions to run.

    Returns:
        (results_by_type, total_by_type, retrieval_hits, total_time):
        three Counters keyed by question type, plus seconds spent on
        recall + Gemma reasoning.
    """
    if max_questions:
        oracle = oracle[:max_questions]
    results_by_type = Counter()
    total_by_type = Counter()
    retrieval_hits = Counter()  # questions answerable from retrieval alone
    gemma_calls = 0
    gemma_errors = 0
    total_time = 0
    for qi, entry in enumerate(oracle):
        qtype = entry["question_type"]
        question = entry["question"]
        answer = entry["answer"]
        sessions = entry["haystack_sessions"]
        dates = entry.get("haystack_dates", [""] * len(sessions))
        total_by_type[qtype] += 1
        # Build a fresh memory per question (each entry has its own haystack).
        mem = HippocampalMemory(embed_dim=384)
        all_texts = []  # (cue, target, meta)
        for si, session in enumerate(sessions):
            date = dates[si] if si < len(dates) else ""
            pairs = extract_memories_v2(session, session_date=date)
            all_texts.extend(pairs)
        if not all_texts:
            continue
        cue_texts = [t[0] for t in all_texts]
        target_texts = [t[1] for t in all_texts]
        metas = [t[2] for t in all_texts]
        cue_embs = emb_batch(model, cue_texts)
        target_embs = emb_batch(model, target_texts)
        for i in range(len(cue_texts)):
            # Metadata keeps (truncated) text so results are human-readable.
            mem.store(cue_embs[i], target_embs[i],
                      metadata={"cue": cue_texts[i][:200],
                                "target": target_texts[i][:500],
                                **metas[i]})
        # Recall
        t0 = time.time()
        q_emb = model.encode([question], convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)[0]
        results = mem.recall(q_emb, top_k=10)
        # Retrieval-only baseline: does the recalled text already contain
        # the answer, before any LLM reasoning?
        retrieved_texts = []
        for r in results:
            retrieved_texts.append(r.metadata.get("target", ""))
            retrieved_texts.append(r.metadata.get("cue", ""))
        retrieval_hit = check_answer("\n".join(retrieved_texts), answer)
        if retrieval_hit:
            retrieval_hits[qtype] += 1
        # Build context for Gemma, prefixing dates for temporal reasoning.
        context_parts = []
        for r in results[:7]:  # top 7 memories
            date = r.metadata.get("date", "")
            target = r.metadata.get("target", "")
            if date:
                context_parts.append(f"[{date}] {target}")
            else:
                context_parts.append(target)
        context = "\n".join(context_parts)
        # Ask Gemma to answer based on recalled memories
        prompt = f"""You are answering a question based on recalled memories from past conversations with the user. The memories may contain dates and timestamps — use them for temporal reasoning.
Memories:
{context}
Question: {question}
Instructions:
- Answer based ONLY on the memories above
- For "which came first" questions, compare the dates in the memories
- For "how many days" questions, calculate from the dates mentioned
- For counting questions, count all relevant items across memories
- Extract specific facts (names, numbers, places) directly from the memories
- Be concise (1-2 sentences)
- If genuinely no relevant information exists, say "Not mentioned in our conversations"
Answer:"""
        gemma_answer = gemma_chat([{"role": "user", "content": prompt}],
                                  max_tokens=100)
        gemma_calls += 1
        if gemma_answer is None:
            gemma_errors += 1
            gemma_answer = "\n".join(retrieved_texts[:3])  # fallback
        total_time += time.time() - t0
        hit = check_answer(gemma_answer, answer)
        if hit:
            results_by_type[qtype] += 1
        if qi < 3 or (not hit and qi < 20):
            # Fix: both branches previously assigned the empty string
            # (glyphs lost in a copy), making the log unreadable; restore
            # visible pass/fail markers to match ret_status's R✓/R✗.
            status = "✓" if hit else "✗"
            ret_status = "R✓" if retrieval_hit else "R✗"
            print(f" {status} {ret_status} [{qtype[:12]:>12}] Q: {question[:55]}...")
            print(f" Expected: {str(answer)[:60]}...")
            if gemma_answer:
                print(f" Gemma: {gemma_answer[:80]}...")
        del mem
        if (qi + 1) % 25 == 0:
            print(f" ... {qi+1}/{len(oracle)} ({total_time:.0f}s, "
                  f"{gemma_errors} Gemma errors)")
    return results_by_type, total_by_type, retrieval_hits, total_time
def main():
    """Entry point: verify Gemma is up, load the oracle, run the benchmark,
    and print a per-question-type report."""
    print("=" * 60)
    print("LongMemEval + Gemma 4 Post-Retrieval Reasoning")
    print("=" * 60)
    # Verify Gemma is reachable before doing any expensive work.
    test = gemma_chat([{"role": "user", "content": "Say OK"}], max_tokens=10)
    if test:
        print(f"Gemma 4 connected: '{test.strip()}'")
    else:
        print("ERROR: Gemma 4 not available!")
        return
    model = load_model()
    with open("data/longmemeval_oracle.json") as f:
        oracle = json.load(f)
    # Fix: dropped the dead `if True:` wrapper that previously enclosed
    # the benchmark run.
    print(f"\n=== Full Benchmark (500 questions) ===\n")
    results, totals, ret_hits, dt = run_benchmark(model, oracle)
    print(f"\n{'='*60}")
    print("FINAL RESULTS (Retrieval + Gemma 4 Reasoning)")
    print(f"{'='*60}")
    overall = sum(results.values())
    total = sum(totals.values())
    ret_overall = sum(ret_hits.values())
    if total == 0:
        # Guard: an empty oracle would divide by zero below.
        print("No questions were run.")
        return
    print(f"Overall: {overall}/{total} ({overall/total:.0%}) "
          f"[retrieval-only baseline: {ret_overall}/{total} ({ret_overall/total:.0%})]")
    print()
    for qtype in sorted(totals.keys()):
        c = results.get(qtype, 0)
        rc = ret_hits.get(qtype, 0)
        t = totals[qtype]
        # Fix: the bar glyphs were lost in a copy ('"" * n + "" * (20-n)'
        # always yields ""); restore a visible 20-char filled/empty bar.
        filled = int(c / t * 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {qtype:<25}: {c:>3}/{t:<3} ({c/t:>5.1%}) {bar} [ret: {rc/t:.0%}]")
    print(f"\nTime: {dt:.0f}s ({dt/total:.1f}s/question)")
if __name__ == "__main__":
    main()