Files
nuonuo/experiments/exp07c_hopfield_embedding.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

318 lines
12 KiB
Python

"""Experiment 7c: Hopfield in embedding space (no WTA codes for retrieval).
Key insight: WTA codes distort semantic distance. Hopfield attention works
better directly on continuous embeddings where cosine similarity is meaningful.
WTA codes are only needed for Hebbian multi-hop (W matrix).
For single-hop retrieval, embedding-space Hopfield is strictly better.
Test:
1. Embedding-space Hopfield at scale (1K-10K)
2. Hard semantic distractors
3. Embedding-space multi-hop (no WTA needed?)
4. Compare code-space vs embedding-space
"""
import sys
import time
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
DEVICE = "cuda"
class EmbeddingHopfield:
    """Modern Hopfield network operating directly on embeddings.

    No WTA codes, no pattern separation — pure softmax attention
    over stored embedding patterns. This is essentially transformer
    cross-attention with stored memories as K/V.
    """

    def __init__(self, beta=16.0):
        # beta is the inverse temperature of the softmax: higher values
        # give sharper (more winner-take-all) attractor dynamics.
        self.beta = beta
        self.cue_embs = []     # Keys
        self.target_embs = []  # Values
        self.metadata = []

    def learn(self, cue_emb, target_emb, meta=None):
        """Store one cue→target association (detached from autograd)."""
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        self.metadata.append(meta or {})

    def _settle(self, xi, cue_mat, steps):
        """Iteratively settle query `xi` onto the nearest cue attractor.

        Each step is one round of softmax attention over the stored cues,
        followed by renormalization so cosine scores stay meaningful.
        """
        for _ in range(steps):
            scores = self.beta * (xi @ cue_mat.T)  # [N]
            attn = torch.softmax(scores, dim=0)
            xi = attn @ cue_mat  # [dim] — weighted average of cues
            xi = nn.functional.normalize(xi, dim=0)
        return xi

    def _associate(self, xi, cue_mat, target_mat):
        """Map a settled cue state onto its associated target.

        Returns (normalized target embedding, attention over stored pairs).
        """
        scores = self.beta * (xi @ cue_mat.T)
        attn = torch.softmax(scores, dim=0)
        target = attn @ target_mat
        return nn.functional.normalize(target, dim=0), attn

    def recall(self, query_emb, steps=3):
        """Iterative Hopfield retrieval in embedding space.

        Step 1: query settles to nearest cue attractor via softmax attention
        Step 2: settled query → associated target via softmax attention

        Returns (normalized target embedding, attention weights [N]).
        """
        cue_mat = torch.stack(self.cue_embs)        # [N, dim]
        target_mat = torch.stack(self.target_embs)  # [N, dim]
        xi = self._settle(query_emb, cue_mat, steps)
        return self._associate(xi, cue_mat, target_mat)

    def recall_multihop(self, query_emb, hops=2, steps_per_hop=3):
        """Multi-hop in embedding space.

        Settle to cue → get target → use target as next query.
        Returns a list of (target, attn) tuples, one per hop.
        """
        cue_mat = torch.stack(self.cue_embs)
        target_mat = torch.stack(self.target_embs)
        xi = query_emb
        results = []
        for _ in range(hops):
            xi = self._settle(xi, cue_mat, steps_per_hop)
            target, attn = self._associate(xi, cue_mat, target_mat)
            results.append((target, attn))
            xi = target  # the recalled target seeds the next hop
        return results
def load_model():
    """Load the MiniLM-L6 sentence encoder onto DEVICE and return it."""
    from sentence_transformers import SentenceTransformer
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return encoder
def cosine(a, b):
    """Scalar cosine similarity between two 1-D embedding tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
def test_scale(model):
    """Scale test with embedding-space Hopfield.

    Sweeps background-memory count (0–10K) and β, measuring paraphrase
    recall accuracy (the recalled vector must be nearest the correct
    target among all 10 targets) and per-query latency.
    """
    print("\n=== Scale Test: Embedding-Space Hopfield ===")
    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
    ]
    paraphrases = [
        "How's the weather outside?",
        "Push the new release",
        "DB performance terrible",
        "Login bug needs fixing",
        "Getting 500 errors",
        "Need better observability",
        "CI tests breaking",
        "Service using too much RAM",
        "Docker config help",
        "Logs eating disk space",
    ]
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    for n_bg in [0, 100, 500, 1000, 2000, 5000, 10000]:
        for beta in [16, 32, 64]:
            mem = EmbeddingHopfield(beta=beta)
            for cue, tgt in zip(cue_embs, target_embs):
                mem.learn(cue, tgt)
            if n_bg > 0:
                topics = ["server", "database", "API", "frontend", "backend",
                          "cache", "queue", "network", "storage", "auth",
                          "docker", "kubernetes", "redis", "nginx", "postgres"]
                bg_cues = [f"The {topics[i%len(topics)]} system has issue {i}" for i in range(n_bg)]
                bg_targets = [f"Fix {topics[i%len(topics)]} issue {i} urgently" for i in range(n_bg)]
                # Encode background memories in chunks to bound batch size.
                for start in range(0, n_bg, 256):
                    end = min(start + 256, n_bg)
                    bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    for j in range(bc.shape[0]):
                        mem.learn(bc[j], bt[j])
            # Test paraphrase recall
            t0 = time.time()
            correct = 0
            for i in range(len(paraphrases)):
                with torch.no_grad():
                    recalled, _ = mem.recall(para_embs[i])
                # Count a hit only when the recalled vector is closest to the
                # correct target. (A dead `sim = cosine(...)` absolute-similarity
                # computation was removed here — it was never read.)
                all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
                if np.argmax(all_sims) == i:
                    correct += 1
            dt = (time.time() - t0) / len(paraphrases) * 1000
            n = len(paraphrases)
            if beta == 32 or n_bg == 0:  # Only print all β for bg=0
                print(f" N={n_bg+len(pairs):>6}, β={beta:>2}: "
                      f"Para={correct}/{n} ({correct/n:.0%}), "
                      f"time={dt:.1f}ms")
            del mem
        if n_bg == 0:
            print()  # separator after β sweep
def test_hard_distractors(model):
    """Semantic distractors in embedding space.

    Stores one true pair among ten near-synonym distractor pairs and
    sweeps β to see where the Hopfield attention locks onto the right
    attractor instead of blending or picking a distractor.
    """
    print("\n=== Hard Semantic Distractors (Embedding Hopfield) ===")
    target_pair = ("The database is slow", "Missing index on users table")
    distractors = [
        ("The database crashed completely", "Run database recovery procedure"),
        ("Database needs backup now", "Use pg_dump for PostgreSQL backup"),
        ("The datastore is slow", "Check Redis connection pool settings"),
        ("DB latency is high", "Review query execution plans"),
        ("Database performance degraded", "Check for lock contention"),
        ("SQL queries are slow", "Add composite index on frequently joined columns"),
        ("The cache is slow", "Increase Redis maxmemory setting"),
        ("MongoDB is slow", "Check for collection scans without index"),
        ("The search index is slow", "Rebuild Elasticsearch index"),
        ("Database connection timeout", "Increase pool size in connection config"),
    ]
    query = "DB performance is terrible"

    def enc(texts):
        # Shared encoder call so all embeddings use identical settings.
        return model.encode(texts, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)

    cue_emb = enc([target_pair[0]])[0]
    target_emb = enc([target_pair[1]])[0]
    q_emb = enc([query])[0]
    # Encode every distractor ONCE, hoisted out of the β sweep below
    # (previously re-encoded for each of the 5 β values).
    d_cue_embs = enc([d[0] for d in distractors])
    d_tgt_embs = enc([d[1] for d in distractors])
    # Show embedding distances
    print(f"\n Query: '{query}'")
    print(f" Target cue: '{target_pair[0]}' (cos={cosine(q_emb, cue_emb):.3f})")
    for i in range(5):
        dc = distractors[i][0]
        print(f" Distractor: '{dc[:40]}...' (cos={cosine(q_emb, d_cue_embs[i]):.3f})")
    all_texts = [target_pair[1]] + [d[1] for d in distractors]
    for beta in [8, 16, 32, 64, 128]:
        mem = EmbeddingHopfield(beta=beta)
        mem.learn(cue_emb, target_emb, {"text": target_pair[1]})
        for i, (dc, dt) in enumerate(distractors):
            mem.learn(d_cue_embs[i], d_tgt_embs[i], {"text": dt})
        recalled, attn = mem.recall(q_emb)
        sim_to_target = cosine(recalled, target_emb)
        top_idx = attn.argmax().item()
        top_attn = attn[top_idx].item()
        # NOTE(review): the original print ran the attention value straight
        # into the quoted text (likely a lost '→' glyph); separator restored.
        print(f" β={beta:>3}: sim={sim_to_target:.3f}, "
              f"top_attn={top_attn:.3f} → '{all_texts[top_idx][:40]}...'")
def test_multihop_embedding(model):
    """Multi-hop in pure embedding space.

    Stores A→B→C→D chains and checks that recall_multihop walks the
    chain, first in isolation and then with 500 background memories.
    """
    print("\n=== Multi-hop (Embedding Space) ===")
    chains = [
        ["What's the weather?", "Check weather before going out",
         "My coffee shop is around the corner", "Great latte art there"],
        ["Review the code", "Found a memory leak in review",
         "Memory leaks cause OOM", "Add memory limits to k8s pods"],
    ]
    for chain in chains:
        mem = EmbeddingHopfield(beta=32)
        # Batch-encode the whole chain in one call (was one encode per text).
        chain_embs = model.encode(chain, convert_to_tensor=True,
                                  normalize_embeddings=True, device=DEVICE)
        for i in range(len(chain) - 1):
            mem.learn(chain_embs[i], chain_embs[i+1])
        results = mem.recall_multihop(chain_embs[0], hops=len(chain)-1)
        # NOTE(review): join separator restored — the original joined with ''
        # (likely a lost '→' glyph), running the chain items together.
        print(f"\n Chain: {' → '.join([c[:20]+'...' for c in chain])}")
        for hop_idx, (recalled, attn) in enumerate(results):
            target = chain_embs[hop_idx + 1]
            sim = cosine(recalled, target)
            # Pass/fail markers restored: both branches were empty strings
            # (dead conditional), evidently stripped Unicode glyphs.
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} hop {hop_idx+1}: sim={sim:.3f}")
    # With background
    print("\n --- With 500 background ---")
    mem = EmbeddingHopfield(beta=32)
    all_embs = []
    for chain in chains:
        embs = model.encode(chain, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
        all_embs.append(embs)
        for i in range(len(chain) - 1):
            mem.learn(embs[i], embs[i+1])
    bg = [f"Background about {['coding','devops','ml','infra','data'][i%5]} topic {i}" for i in range(500)]
    bg_embs = model.encode(bg, convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE, batch_size=256)
    # Background items are chained head-to-tail so they also form associations.
    for i in range(499):
        mem.learn(bg_embs[i], bg_embs[i+1])
    for ci, chain in enumerate(chains):
        results = mem.recall_multihop(all_embs[ci][0], hops=len(chain)-1)
        for hop_idx, (recalled, _) in enumerate(results):
            target = all_embs[ci][hop_idx + 1]
            sim = cosine(recalled, target)
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")
def main():
    """Run all Experiment 7c test suites against one shared encoder."""
    banner = "=" * 60
    print(banner)
    print("Experiment 7c: Embedding-Space Hopfield")
    print(banner)
    model = load_model()
    for suite in (test_scale, test_hard_distractors, test_multihop_embedding):
        suite(model)


if __name__ == "__main__":
    main()