Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
322 lines
11 KiB
Python
"""LongMemEval benchmark with Gemma 4 post-retrieval reasoning.
|
|
|
|
Previous result: 36% with retrieval-only.
|
|
This version: retrieve top-K memories → Gemma 4 reads them and answers the question.
|
|
|
|
Improvements:
|
|
1. No truncation of assistant responses (store full text, chunked)
|
|
2. Store user statements as memories (preference extraction)
|
|
3. Include timestamps in stored memories
|
|
4. Post-retrieval: Gemma 4 synthesizes answer from recalled memories
|
|
"""
|
|
|
|
import sys
|
|
import json
|
|
import time
|
|
import re
|
|
import requests
|
|
from pathlib import Path
|
|
from collections import Counter
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
from nuonuo.hippocampus import HippocampalMemory
|
|
|
|
DEVICE = "cuda"
|
|
OLLAMA_URL = "http://localhost:11434/api/chat"
|
|
|
|
|
|
def gemma_chat(messages, max_tokens=256, temperature=0):
    """Call Gemma 4 via the local Ollama chat endpoint.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` chat messages.
        max_tokens: Generation cap, passed as Ollama's ``num_predict``.
        temperature: Sampling temperature (0 = deterministic).

    Returns:
        The assistant's reply text, or ``None`` on any network / HTTP /
        response-shape error. Callers treat ``None`` as "model
        unavailable" and fall back to raw retrieval.
    """
    payload = {
        "model": "gemma4:31b",
        "messages": messages,
        "stream": False,
        "think": False,
        "options": {"num_predict": max_tokens, "temperature": temperature},
    }
    try:
        resp = requests.post(OLLAMA_URL, json=payload, timeout=60)
        # Surface non-200 responses explicitly instead of relying on an
        # incidental KeyError when the error body lacks "message".
        resp.raise_for_status()
        return resp.json()["message"]["content"]
    except (requests.RequestException, KeyError, ValueError, TypeError):
        # Deliberate best-effort: the benchmark degrades gracefully.
        return None
|
|
|
|
def load_model():
    """Load the MiniLM-L6 sentence encoder on the configured device."""
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return encoder
|
|
|
|
def emb_batch(model, texts):
    """Encode *texts* into a list of L2-normalized embedding tensors.

    Returns an empty list for empty input so callers can iterate safely.
    """
    if not texts:
        return []
    matrix = model.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
        batch_size=64,
    )
    # Iterating a 2-D tensor yields its rows, one embedding per text.
    return list(matrix)
|
|
|
|
def extract_memories_v2(session, session_date=""):
    """Turn one conversation session into (cue, target, metadata) triples.

    Compared with v1: assistant responses are kept in full (chunked
    instead of truncated), user statements become standalone memories,
    the session date is recorded, and explicit user preferences are
    extracted with dedicated cues.
    """
    triples = []  # (cue_text, target_text, extra_metadata)

    for idx, turn in enumerate(session):
        if turn["role"] != "user":
            continue
        user_text = turn["content"].strip()

        # Pair this user turn with the next assistant reply, if any.
        reply = next(
            (t["content"].strip() for t in session[idx + 1:]
             if t["role"] == "assistant"),
            "",
        )

        if len(user_text) > 10 and len(reply) > 10:
            if len(reply) > 500:
                # Overlapping 500-char windows every 400 chars; tiny
                # trailing fragments (<=50 chars) are dropped.
                windows = [
                    reply[pos:pos + 500]
                    for pos in range(0, len(reply), 400)
                    if len(reply[pos:pos + 500]) > 50
                ]
                for win_idx, window in enumerate(windows[:5]):  # cap at 5 chunks
                    triples.append((
                        user_text,
                        window,
                        {"date": session_date, "chunk": win_idx},
                    ))
            else:
                triples.append((user_text, reply, {"date": session_date}))

        # The user's own words are memories too.
        if len(user_text) > 20:
            triples.append((
                user_text,
                user_text,
                {"date": session_date, "type": "user_statement"},
            ))

        # Explicit preference statements get a dedicated retrieval cue.
        # Note: every pattern that matches contributes its own memory.
        for pattern in (
            r"I (?:like|love|prefer|enjoy|use|usually|always|often)\s+(.{5,80})",
            r"my (?:favorite|preferred)\s+(.{5,80})",
            r"I'm (?:a fan of|into|interested in)\s+(.{5,80})",
        ):
            hit = re.search(pattern, user_text, re.IGNORECASE)
            if hit:
                triples.append((
                    f"user preference: {hit.group(0)}",
                    user_text,
                    {"date": session_date, "type": "preference"},
                ))

    return triples
|
|
|
|
def check_answer(answer_text, expected_answer):
    """Heuristically decide whether *answer_text* contains *expected_answer*.

    Matching cascade: abstention phrases for unanswerable questions,
    direct substring, keyword overlap (at least half of the words longer
    than 3 chars), then shared digits for temporal/counting answers.
    """
    if not answer_text:
        return False

    want = str(expected_answer).lower().strip()
    got = answer_text.lower().strip()

    # Unanswerable questions: accept any explicit abstention phrasing.
    if "did not mention" in want or "not mention" in want:
        abstentions = ("don't have", "no information", "not mentioned",
                       "cannot find", "don't recall", "no record",
                       "not available")
        return any(phrase in got for phrase in abstentions)

    # Exact substring is the strongest positive signal.
    if want in got:
        return True

    # Keyword overlap: half of the long (>3 char) expected words present.
    keywords = [w for w in want.split() if len(w) > 3]
    if keywords:
        present = sum(1 for w in keywords if w in got)
        if present >= max(1, len(keywords) * 0.5):
            return True

    # Digit overlap handles temporal / counting answers.
    want_nums = re.findall(r'\d+', want)
    got_nums = re.findall(r'\d+', got)
    return bool(want_nums) and any(n in got_nums for n in want_nums)
|
|
|
|
def run_benchmark(model, oracle, max_questions=None):
    """Full benchmark with Gemma 4 reasoning.

    For each oracle question: rebuild a HippocampalMemory from the
    question's haystack sessions, recall the top memories for the
    question embedding, then ask Gemma to synthesize an answer from the
    recalled context. Also tracks a retrieval-only baseline (would the
    answer have been found in the raw retrieved text?).

    Args:
        model: Sentence-transformer encoder (from load_model()).
        oracle: List of LongMemEval question entries.
        max_questions: Optional cap on how many questions to run.

    Returns:
        (results_by_type, total_by_type, retrieval_hits, total_time)
        — Counters keyed by question type, plus wall-clock seconds
        spent in recall + generation.
    """
    if max_questions:
        oracle = oracle[:max_questions]

    results_by_type = Counter()
    total_by_type = Counter()
    retrieval_hits = Counter()
    gemma_calls = 0  # NOTE(review): counted but not returned/reported per run
    gemma_errors = 0
    total_time = 0

    for qi, entry in enumerate(oracle):
        qtype = entry["question_type"]
        question = entry["question"]
        answer = entry["answer"]
        sessions = entry["haystack_sessions"]
        dates = entry.get("haystack_dates", [""] * len(sessions))

        # Counted even if the question is later skipped for lack of memories.
        total_by_type[qtype] += 1

        # Build memory
        # NOTE(review): HippocampalMemory API (store/recall) is defined in
        # nuonuo.hippocampus — behavior assumed from usage here.
        mem = HippocampalMemory(embed_dim=384)
        all_texts = []  # (cue, target, meta)

        for si, session in enumerate(sessions):
            date = dates[si] if si < len(dates) else ""
            pairs = extract_memories_v2(session, session_date=date)
            all_texts.extend(pairs)

        if not all_texts:
            continue

        cue_texts = [t[0] for t in all_texts]
        target_texts = [t[1] for t in all_texts]
        metas = [t[2] for t in all_texts]

        cue_embs = emb_batch(model, cue_texts)
        target_embs = emb_batch(model, target_texts)

        # Store each cue→target pair; metadata keeps (truncated) source
        # text so retrieved results can be rendered without re-lookup.
        for i in range(len(cue_texts)):
            mem.store(cue_embs[i], target_embs[i],
                      metadata={"cue": cue_texts[i][:200],
                                "target": target_texts[i][:500],
                                **metas[i]})

        # Recall
        t0 = time.time()  # timing covers recall + Gemma call, not indexing
        q_emb = model.encode([question], convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)[0]
        results = mem.recall(q_emb, top_k=10)

        # Check if retrieval alone has the answer
        retrieved_texts = []
        for r in results:
            retrieved_texts.append(r.metadata.get("target", ""))
            retrieved_texts.append(r.metadata.get("cue", ""))

        retrieval_hit = check_answer("\n".join(retrieved_texts), answer)
        if retrieval_hit:
            retrieval_hits[qtype] += 1

        # Build context for Gemma
        context_parts = []
        for r in results[:7]:  # top 7 memories
            date = r.metadata.get("date", "")
            target = r.metadata.get("target", "")
            if date:
                context_parts.append(f"[{date}] {target}")
            else:
                context_parts.append(target)

        context = "\n".join(context_parts)

        # Ask Gemma to answer based on recalled memories
        prompt = f"""You are answering a question based on recalled memories from past conversations with the user. The memories may contain dates and timestamps — use them for temporal reasoning.

Memories:
{context}

Question: {question}

Instructions:
- Answer based ONLY on the memories above
- For "which came first" questions, compare the dates in the memories
- For "how many days" questions, calculate from the dates mentioned
- For counting questions, count all relevant items across memories
- Extract specific facts (names, numbers, places) directly from the memories
- Be concise (1-2 sentences)
- If genuinely no relevant information exists, say "Not mentioned in our conversations"

Answer:"""

        gemma_answer = gemma_chat([{"role": "user", "content": prompt}],
                                  max_tokens=100)
        gemma_calls += 1
        if gemma_answer is None:
            gemma_errors += 1
            gemma_answer = "\n".join(retrieved_texts[:3])  # fallback

        total_time += time.time() - t0

        hit = check_answer(gemma_answer, answer)
        if hit:
            results_by_type[qtype] += 1

        # Print the first few questions plus early misses for debugging.
        if qi < 3 or (not hit and qi < 20):
            status = "✓" if hit else "✗"
            ret_status = "R✓" if retrieval_hit else "R✗"
            print(f"  {status} {ret_status} [{qtype[:12]:>12}] Q: {question[:55]}...")
            print(f"     Expected: {str(answer)[:60]}...")
            if gemma_answer:
                print(f"     Gemma: {gemma_answer[:80]}...")

        # Free the per-question memory store before building the next one.
        del mem

        if (qi + 1) % 25 == 0:
            print(f"  ... {qi+1}/{len(oracle)} ({total_time:.0f}s, "
                  f"{gemma_errors} Gemma errors)")

    return results_by_type, total_by_type, retrieval_hits, total_time
|
|
|
|
def main():
    """Run the full LongMemEval benchmark and print a per-type report."""
    print("=" * 60)
    print("LongMemEval + Gemma 4 Post-Retrieval Reasoning")
    print("=" * 60)

    # Verify Gemma is reachable before doing any expensive work.
    test = gemma_chat([{"role": "user", "content": "Say OK"}], max_tokens=10)
    if test:
        print(f"Gemma 4 connected: '{test.strip()}'")
    else:
        print("ERROR: Gemma 4 not available!")
        return

    model = load_model()

    with open("data/longmemeval_oracle.json") as f:
        oracle = json.load(f)

    # Full benchmark (the old dead `if True:` wrapper is gone, and the
    # hard-coded "500 questions" now reflects the actual dataset size).
    print(f"\n=== Full Benchmark ({len(oracle)} questions) ===\n")
    results, totals, ret_hits, dt = run_benchmark(model, oracle)

    print(f"\n{'='*60}")
    print("FINAL RESULTS (Retrieval + Gemma 4 Reasoning)")
    print(f"{'='*60}")
    overall = sum(results.values())
    total = sum(totals.values())
    ret_overall = sum(ret_hits.values())
    if total == 0:
        # Guard: an empty oracle file would otherwise divide by zero below.
        print("No questions evaluated.")
        return
    print(f"Overall: {overall}/{total} ({overall/total:.0%}) "
          f"[retrieval-only baseline: {ret_overall}/{total} ({ret_overall/total:.0%})]")
    print()
    for qtype in sorted(totals.keys()):
        c = results.get(qtype, 0)
        rc = ret_hits.get(qtype, 0)
        t = totals[qtype]
        bar = "█" * int(c/t * 20) + "░" * (20 - int(c/t * 20))
        print(f"  {qtype:<25}: {c:>3}/{t:<3} ({c/t:>5.1%}) {bar} [ret: {rc/t:.0%}]")
    print(f"\nTime: {dt:.0f}s ({dt/total:.1f}s/question)")
|
|
|
if __name__ == "__main__":
|
|
main()
|