Files
nuonuo/experiments/exp07d_twostage.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

340 lines
13 KiB
Python

"""Experiment 7d: Two-stage retrieval for scale.
Problem: Embedding Hopfield degrades at 10K+ (80%).
Fix: Pre-filter with approximate NN (top-K), then Hopfield settle on candidates.
This is O(N) for pre-filter (can be O(log N) with FAISS) + O(K) for Hopfield.
Also: test adaptive β based on attention entropy (low entropy = confident).
"""
import sys
import time
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
# Prefer GPU but fall back to CPU so the script still runs on machines
# without CUDA (original hard-coded "cuda", which crashes on CPU-only hosts).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def cosine(a, b):
    """Scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
class TwoStageHopfield:
    """Pre-filter + Hopfield settle.

    Stage 1: cosine NN -> top-K candidates (fast, O(N) or O(log N) with index)
    Stage 2: Hopfield attention over K candidates (precise, O(K))

    Fix vs original: the target matrix was re-stacked with ``torch.stack``
    on every ``recall`` call (O(N) copy per query) while the cue matrix was
    cached; both are now cached and invalidated together on ``learn``.
    """

    def __init__(self, beta=16.0, top_k=50):
        self.beta = beta      # inverse temperature for the Hopfield softmax
        self.top_k = top_k    # candidate count for the stage-1 pre-filter
        self.cue_embs = []
        self.target_embs = []
        self._cue_matrix = None     # cached stacked cues for batch NN
        self._target_matrix = None  # cached stacked targets (symmetric cache)

    def learn(self, cue_emb, target_emb):
        """Store one cue -> target association (detached copies)."""
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        # Invalidate both caches; matrices are rebuilt lazily on next recall.
        self._cue_matrix = None
        self._target_matrix = None

    def _get_cue_matrix(self):
        """Lazily build and cache the [N, dim] cue matrix."""
        if self._cue_matrix is None:
            self._cue_matrix = torch.stack(self.cue_embs)
        return self._cue_matrix

    def _get_target_matrix(self):
        """Lazily build and cache the [N, dim] target matrix."""
        if self._target_matrix is None:
            self._target_matrix = torch.stack(self.target_embs)
        return self._target_matrix

    def recall(self, query_emb, steps=3):
        """Two-stage retrieval for a single query embedding.

        Args:
            query_emb: [dim] tensor; assumed L2-normalized like stored cues.
            steps: Hopfield settle iterations over the candidate cues.

        Returns:
            (normalized recalled target [dim], best global memory index,
             attention weights [K] over the candidates)
        """
        cue_mat = self._get_cue_matrix()
        target_mat = self._get_target_matrix()
        N = cue_mat.shape[0]
        # Stage 1: fast NN pre-filter down to K candidates.
        k = min(self.top_k, N)
        sims = query_emb @ cue_mat.T  # [N]
        _, top_indices = sims.topk(k)
        # Stage 2: Hopfield settle on candidates only.
        cand_cues = cue_mat[top_indices]       # [K, dim]
        cand_targets = target_mat[top_indices] # [K, dim]
        xi = query_emb
        for _ in range(steps):
            attn = torch.softmax(self.beta * (xi @ cand_cues.T), dim=0)
            xi = nn.functional.normalize(attn @ cand_cues, dim=0)
        # Final association: settled state attends over candidate targets.
        attn = torch.softmax(self.beta * (xi @ cand_cues.T), dim=0)
        target = attn @ cand_targets
        # Map the winning candidate back to its global index.
        best_local = attn.argmax().item()
        best_global = top_indices[best_local].item()
        return nn.functional.normalize(target, dim=0), best_global, attn

    def recall_multihop(self, query_emb, hops=2, steps=3):
        """Multi-hop: each hop does two-stage retrieval.

        Each recalled target becomes the query for the next hop.
        Returns a list of (target_embedding, global_index) per hop.
        """
        xi = query_emb
        results = []
        for _ in range(hops):
            target, idx, attn = self.recall(xi, steps=steps)
            results.append((target, idx))
            xi = target  # use target as next query
        return results
def load_model():
    """Load the MiniLM-L6 sentence encoder onto DEVICE."""
    # Imported lazily so the module can be imported without the dependency.
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
def test_scale(model):
    """Scale test comparing pure Hopfield vs two-stage.

    For each background size n_bg and each pre-filter width K, builds a
    fresh memory with 10 real cue->target pairs plus n_bg synthetic
    distractors, then measures paraphrase-recall accuracy and per-query
    latency.

    Args:
        model: sentence encoder exposing ``encode`` (SentenceTransformer API).
    """
    print("\n=== Scale Comparison ===")
    # The 10 genuine cue -> target associations the memory must preserve.
    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
    ]
    # One paraphrased query per pair; recall must route it to the right target.
    paraphrases = [
        "How's the weather outside?",
        "Push the new release",
        "DB performance terrible",
        "Login bug needs fixing",
        "Getting 500 errors",
        "Need better observability",
        "CI tests breaking",
        "Service using too much RAM",
        "Docker config help",
        "Logs eating disk space",
    ]
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    for n_bg in [0, 100, 500, 1000, 5000, 10000, 20000]:
        # Two-stage with different K
        for top_k in [20, 50, 100]:
            # Skip K values wider than the background set (uninformative).
            if n_bg < top_k and n_bg > 0:
                continue
            mem = TwoStageHopfield(beta=16.0, top_k=top_k)
            for i in range(len(pairs)):
                mem.learn(cue_embs[i], target_embs[i])
            if n_bg > 0:
                # Synthetic distractor memories built from rotating infra topics.
                topics = ["server", "database", "API", "frontend", "backend",
                          "cache", "queue", "network", "storage", "auth",
                          "docker", "kubernetes", "redis", "nginx", "postgres"]
                bg_cues = [f"The {topics[i%len(topics)]} system has issue {i}"
                           for i in range(n_bg)]
                bg_targets = [f"Fix {topics[i%len(topics)]} issue {i} urgently"
                              for i in range(n_bg)]
                # Encode in chunks of 256 to bound encoder memory usage.
                for start in range(0, n_bg, 256):
                    end = min(start + 256, n_bg)
                    bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    for j in range(bc.shape[0]):
                        mem.learn(bc[j], bt[j])
            # Test
            t0 = time.time()
            correct = 0
            for i in range(len(paraphrases)):
                with torch.no_grad():
                    recalled, idx, attn = mem.recall(para_embs[i])
                # Correct iff the recalled vector is closest to the true target.
                all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
                if np.argmax(all_sims) == i:
                    correct += 1
            dt = (time.time() - t0) / len(paraphrases) * 1000  # ms per query
            n = len(paraphrases)
            total = len(mem.cue_embs)
            print(f" N={total:>6}, K={top_k:>3}: "
                  f"Para={correct}/{n} ({correct/n:>3.0%}), "
                  f"time={dt:.1f}ms")
            del mem
            torch.cuda.empty_cache()
        if n_bg > 0:
            print()
def test_multihop_at_scale(model):
    """Multi-hop with two-stage at scale.

    Stores three 4-step associative chains plus 500 chained background
    memories, then follows each chain from its head and checks that every
    hop lands on the next link (cosine similarity > 0.7).

    Fixes vs original: the pass/fail markers had been mojibake-stripped
    into two identical empty strings, so the printed status carried no
    information; restored to visible check/cross marks. Chains are also
    encoded in one batch per chain instead of one encode call per step.

    Args:
        model: sentence encoder exposing ``encode`` (SentenceTransformer API).
    """
    print("\n=== Multi-hop Two-Stage (500 bg) ===")
    chains = [
        ["What's the weather?", "Check weather before going out",
         "My coffee shop nearby", "Great latte art"],
        ["Review the code", "Found memory leak", "Leaks cause OOM", "Add k8s limits"],
        ["Deploy to prod", "Blue-green deployment", "Blue is active", "Switch to green"],
    ]
    mem = TwoStageHopfield(beta=16.0, top_k=50)
    all_embs = []
    for chain in chains:
        # Batch-encode the whole chain at once (was one encode per step).
        embs = model.encode(chain, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
        all_embs.append(embs)
        # Link consecutive steps: step i cues step i+1.
        for i in range(len(chain) - 1):
            mem.learn(embs[i], embs[i + 1])
    # Background distractors, also chained so multi-hop traversal has decoys.
    bg = [f"Background about {['code','ops','ml','data','infra'][i%5]} number {i}"
          for i in range(500)]
    bg_embs = model.encode(bg, convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE, batch_size=256)
    for i in range(len(bg_embs) - 1):
        mem.learn(bg_embs[i], bg_embs[i + 1])
    for ci, chain in enumerate(chains):
        results = mem.recall_multihop(all_embs[ci][0], hops=len(chain) - 1)
        for hop_idx, (recalled, idx) in enumerate(results):
            target = all_embs[ci][hop_idx + 1]
            sim = cosine(recalled, target)
            # ✓ when the hop recovered the next chain link, ✗ otherwise.
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")
def test_diverse_queries(model):
    """Larger test set with more diverse queries.

    20 cue->target pairs, 20 paraphrased queries, and 2000 synthetic
    background memories; reports overall accuracy and prints details of
    any misrouted queries.

    Fix vs original: ``print(f" Failures:")`` was an f-string with no
    placeholders (ruff F541) — now a plain string literal.

    Args:
        model: sentence encoder exposing ``encode`` (SentenceTransformer API).
    """
    print("\n=== Diverse Query Test (20 pairs, 2000 bg) ===")
    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
        ("How to add caching?", "Redis available at redis.internal:6379"),
        ("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
        ("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
        ("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
        ("Optimize search", "Elasticsearch v8, recently upgraded"),
        ("Backup the database", "Daily 3am UTC cron to S3"),
        ("Configure reverse proxy", "Traefik, not nginx"),
        ("Team meeting schedule", "Standup 10am London, Mon-Fri"),
        ("Learn a new language", "User has Python+Go, new to systems programming"),
        ("Review my PR", "User prefers small PRs with clear commits"),
    ]
    paraphrases = [
        "How's the weather?",
        "Ship the release",
        "DB is crawling",
        "Fix the login issue",
        "Server errors everywhere",
        "Need observability",
        "CI is broken",
        "Too much RAM usage",
        "Docker help please",
        "Disk full from logs",
        "Want to add a cache layer",
        "Website too slow",
        "Payment code needs rework",
        "Provision a new machine",
        "Search is slow",
        "Need a DB backup",
        "Proxy configuration",
        "When's the standup?",
        "Want to learn Rust",
        "Check my pull request",
    ]
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    mem = TwoStageHopfield(beta=16.0, top_k=50)
    for i in range(len(pairs)):
        mem.learn(cue_embs[i], target_embs[i])
    # 2000 diverse background
    topics = ["server", "database", "API", "frontend", "backend", "cache",
              "queue", "network", "storage", "auth", "docker", "kubernetes",
              "redis", "nginx", "postgres", "python", "golang", "react",
              "terraform", "ansible"]
    actions = ["crashed", "is slow", "needs update", "has bug", "timed out",
               "needs migration", "needs backup", "has leak", "is down", "needs config"]
    bg_cues = [f"The {topics[i%len(topics)]} {actions[i%len(actions)]} (ticket {i})"
               for i in range(2000)]
    bg_targets = [f"Fix {topics[i%len(topics)]} {actions[i%len(actions)]}: see wiki page {i}"
                  for i in range(2000)]
    # Encode in chunks of 256 to bound encoder memory usage.
    for start in range(0, 2000, 256):
        end = min(start + 256, 2000)
        bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                          normalize_embeddings=True, device=DEVICE)
        bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                          normalize_embeddings=True, device=DEVICE)
        for j in range(bc.shape[0]):
            mem.learn(bc[j], bt[j])
    # Test
    correct = 0
    failures = []
    for i in range(len(paraphrases)):
        with torch.no_grad():
            recalled, idx, attn = mem.recall(para_embs[i])
        # Correct iff the recalled vector is closest to the true target.
        all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
        best = np.argmax(all_sims)
        if best == i:
            correct += 1
        else:
            failures.append((i, best, all_sims[i], all_sims[best]))
    n = len(paraphrases)
    print(f" Result: {correct}/{n} ({correct/n:.0%})")
    if failures:
        print(" Failures:")
        for qi, gi, sim_correct, sim_got in failures:
            print(f" Q: '{paraphrases[qi][:30]}...' → got [{gi}] "
                  f"(sim_correct={sim_correct:.3f}, sim_got={sim_got:.3f})")
def main():
    """Run the full experiment 7d suite: scale, multi-hop, diverse queries."""
    banner = "=" * 60
    print(banner)
    print("Experiment 7d: Two-Stage Hopfield")
    print(banner)
    model = load_model()
    for suite in (test_scale, test_multihop_at_scale, test_diverse_queries):
        suite(model)
# Entry-point guard: run the suite only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()