Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
318 lines · 12 KiB · Python
"""Experiment 7c: Hopfield in embedding space (no WTA codes for retrieval).
|
|
|
|
Key insight: WTA codes distort semantic distance. Hopfield attention works
|
|
better directly on continuous embeddings where cosine similarity is meaningful.
|
|
|
|
WTA codes are only needed for Hebbian multi-hop (W matrix).
|
|
For single-hop retrieval, embedding-space Hopfield is strictly better.
|
|
|
|
Test:
|
|
1. Embedding-space Hopfield at scale (1K-10K)
|
|
2. Hard semantic distractors
|
|
3. Embedding-space multi-hop (no WTA needed?)
|
|
4. Compare code-space vs embedding-space
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
DEVICE = "cuda"
|
|
|
|
|
|
class EmbeddingHopfield:
    """Modern Hopfield network operating directly on embeddings.

    No WTA codes, no pattern separation — pure softmax attention
    over stored embedding patterns. This is essentially transformer
    cross-attention with stored memories as K/V.
    """

    def __init__(self, beta=16.0):
        # beta is the inverse temperature of the softmax; higher values
        # sharpen attention toward the single best-matching stored cue.
        self.beta = beta
        self.cue_embs = []     # Keys: one embedding per stored memory
        self.target_embs = []  # Values: embedding recalled for each key
        self.metadata = []     # Optional per-memory dict (e.g. source text)

    def learn(self, cue_emb, target_emb, meta=None):
        """Store one cue → target association.

        Tensors are detached so stored memories never keep an autograd
        graph alive.
        """
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        self.metadata.append(meta or {})

    def _matrices(self):
        """Stack stored memories into [N, dim] key/value matrices.

        Raises ValueError on an empty store (previously this crashed with
        an opaque torch.stack RuntimeError).
        """
        if not self.cue_embs:
            raise ValueError("recall called before any memories were learned")
        return torch.stack(self.cue_embs), torch.stack(self.target_embs)

    def _settle(self, xi, cue_mat, steps):
        """Iterate softmax attention so the query settles on a cue attractor."""
        for _ in range(steps):
            scores = self.beta * (xi @ cue_mat.T)  # [N]
            attn = torch.softmax(scores, dim=0)
            xi = attn @ cue_mat  # [dim] — weighted average of cues
            xi = nn.functional.normalize(xi, dim=0)
        return xi

    def _associate(self, xi, cue_mat, target_mat):
        """Map a settled cue state to its associated target embedding.

        Returns (normalized target, attention weights over memories).
        """
        scores = self.beta * (xi @ cue_mat.T)
        attn = torch.softmax(scores, dim=0)
        target = attn @ target_mat
        return nn.functional.normalize(target, dim=0), attn

    def recall(self, query_emb, steps=3):
        """Iterative Hopfield retrieval in embedding space.

        Step 1: query settles to nearest cue attractor via softmax attention
        Step 2: settled query → associated target via softmax attention

        Returns (normalized target embedding, attention over memories).
        Raises ValueError if no memories have been stored.
        """
        cue_mat, target_mat = self._matrices()
        xi = self._settle(query_emb, cue_mat, steps)
        return self._associate(xi, cue_mat, target_mat)

    def recall_multihop(self, query_emb, hops=2, steps_per_hop=3):
        """Multi-hop in embedding space.

        Settle to cue → get target → use target as next query.
        Returns a list of (target, attn) tuples, one per hop.
        Raises ValueError if no memories have been stored.
        """
        cue_mat, target_mat = self._matrices()
        xi = query_emb
        results = []
        for _ in range(hops):
            xi = self._settle(xi, cue_mat, steps_per_hop)
            target, attn = self._associate(xi, cue_mat, target_mat)
            results.append((target, attn))
            xi = target  # chain: recalled target becomes the next hop's query
        return results
|
|
|
|
|
|
def load_model():
    """Load the MiniLM sentence encoder onto DEVICE.

    The import is deferred so merely importing this module stays cheap.
    """
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
|
|
|
|
|
|
def cosine(a, b):
    """Scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
|
|
|
|
|
|
def test_scale(model):
    """Scale test with embedding-space Hopfield.

    Stores 10 real cue→target pairs plus n_bg synthetic background
    memories, then checks whether a paraphrased cue still recalls the
    correct target (argmax of cosine similarity over the 10 true targets).
    Sweeps β for the n_bg=0 case only; prints β=32 otherwise.
    """
    print("\n=== Scale Test: Embedding-Space Hopfield ===")

    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
    ]
    paraphrases = [
        "How's the weather outside?",
        "Push the new release",
        "DB performance terrible",
        "Login bug needs fixing",
        "Getting 500 errors",
        "Need better observability",
        "CI tests breaking",
        "Service using too much RAM",
        "Docker config help",
        "Logs eating disk space",
    ]

    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    for n_bg in [0, 100, 500, 1000, 2000, 5000, 10000]:
        for beta in [16, 32, 64]:
            mem = EmbeddingHopfield(beta=beta)

            for i in range(len(pairs)):
                mem.learn(cue_embs[i], target_embs[i])

            if n_bg > 0:
                topics = ["server", "database", "API", "frontend", "backend",
                          "cache", "queue", "network", "storage", "auth",
                          "docker", "kubernetes", "redis", "nginx", "postgres"]
                bg_cues = [f"The {topics[i%len(topics)]} system has issue {i}" for i in range(n_bg)]
                bg_targets = [f"Fix {topics[i%len(topics)]} issue {i} urgently" for i in range(n_bg)]

                # Encode in chunks of 256 to bound peak GPU memory.
                for start in range(0, n_bg, 256):
                    end = min(start + 256, n_bg)
                    bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    for j in range(bc.shape[0]):
                        mem.learn(bc[j], bt[j])

            # Test paraphrase recall
            t0 = time.time()
            correct = 0
            for i in range(len(paraphrases)):
                with torch.no_grad():
                    recalled, _ = mem.recall(para_embs[i])
                # Check if recalled is closest to correct target.
                # (Removed a dead `cosine(recalled, target_embs[i])` call that
                # the original computed here and never used.)
                all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
                if np.argmax(all_sims) == i:
                    correct += 1
            dt = (time.time() - t0) / len(paraphrases) * 1000

            n = len(paraphrases)
            if beta == 32 or n_bg == 0:  # Only print all β for bg=0
                print(f" N={n_bg+len(pairs):>6}, β={beta:>2}: "
                      f"Para={correct}/{n} ({correct/n:.0%}), "
                      f"time={dt:.1f}ms")

            del mem

        if n_bg == 0:
            print()  # separator after β sweep
|
|
|
|
|
|
def test_hard_distractors(model):
    """Semantic distractors in embedding space.

    One true pair plus ten near-duplicate distractor pairs; sweeps β and
    reports whether the query's attention lands on the true target.
    """
    print("\n=== Hard Semantic Distractors (Embedding Hopfield) ===")

    target_pair = ("The database is slow", "Missing index on users table")
    distractors = [
        ("The database crashed completely", "Run database recovery procedure"),
        ("Database needs backup now", "Use pg_dump for PostgreSQL backup"),
        ("The datastore is slow", "Check Redis connection pool settings"),
        ("DB latency is high", "Review query execution plans"),
        ("Database performance degraded", "Check for lock contention"),
        ("SQL queries are slow", "Add composite index on frequently joined columns"),
        ("The cache is slow", "Increase Redis maxmemory setting"),
        ("MongoDB is slow", "Check for collection scans without index"),
        ("The search index is slow", "Rebuild Elasticsearch index"),
        ("Database connection timeout", "Increase pool size in connection config"),
    ]

    query = "DB performance is terrible"

    def encode_one(text):
        # Single normalized embedding on DEVICE.
        return model.encode([text], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)[0]

    cue_emb = encode_one(target_pair[0])
    target_emb = encode_one(target_pair[1])
    q_emb = encode_one(query)

    # Encode each distractor pair ONCE; the original re-encoded every
    # distractor inside each β iteration (5× redundant model calls).
    distractor_embs = [(encode_one(dc), encode_one(dt)) for dc, dt in distractors]

    # Show embedding distances
    print(f"\n Query: '{query}'")
    print(f" Target cue: '{target_pair[0]}' (cos={cosine(q_emb, cue_emb):.3f})")
    for (dc, _), (dc_emb, _) in zip(distractors[:5], distractor_embs[:5]):
        print(f" Distractor: '{dc[:40]}...' (cos={cosine(q_emb, dc_emb):.3f})")

    for beta in [8, 16, 32, 64, 128]:
        mem = EmbeddingHopfield(beta=beta)
        mem.learn(cue_emb, target_emb, {"text": target_pair[1]})

        for (_, dt), (dc_emb, dt_emb) in zip(distractors, distractor_embs):
            mem.learn(dc_emb, dt_emb, {"text": dt})

        recalled, attn = mem.recall(q_emb)
        sim_to_target = cosine(recalled, target_emb)
        top_idx = attn.argmax().item()
        top_attn = attn[top_idx].item()
        all_texts = [target_pair[1]] + [d[1] for d in distractors]

        print(f" β={beta:>3}: sim={sim_to_target:.3f}, "
              f"top_attn={top_attn:.3f} → '{all_texts[top_idx][:40]}...'")
|
|
|
|
|
|
def test_multihop_embedding(model):
    """Multi-hop in pure embedding space.

    Learns A→B→C→D chains, recalls hop-by-hop from the chain head, and
    repeats with 500 linked background memories as distractors.
    """
    print("\n=== Multi-hop (Embedding Space) ===")

    chains = [
        ["What's the weather?", "Check weather before going out",
         "My coffee shop is around the corner", "Great latte art there"],
        ["Review the code", "Found a memory leak in review",
         "Memory leaks cause OOM", "Add memory limits to k8s pods"],
    ]

    def encode_chain(texts):
        # Batch-encode a whole chain in ONE model call (the original made
        # one encode() call per text).
        embs = model.encode(texts, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
        return [embs[i] for i in range(len(texts))]

    for chain in chains:
        mem = EmbeddingHopfield(beta=32)
        chain_embs = encode_chain(chain)

        for i in range(len(chain) - 1):
            mem.learn(chain_embs[i], chain_embs[i+1])

        results = mem.recall_multihop(chain_embs[0], hops=len(chain)-1)

        print(f"\n Chain: {' → '.join([c[:20]+'...' for c in chain])}")
        for hop_idx, (recalled, attn) in enumerate(results):
            target = chain_embs[hop_idx + 1]
            sim = cosine(recalled, target)
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} hop {hop_idx+1}: sim={sim:.3f}")

    # With background
    print("\n --- With 500 background ---")
    mem = EmbeddingHopfield(beta=32)
    all_embs = []
    for chain in chains:
        embs = encode_chain(chain)
        all_embs.append(embs)
        for i in range(len(chain) - 1):
            mem.learn(embs[i], embs[i+1])

    bg = [f"Background about {['coding','devops','ml','infra','data'][i%5]} topic {i}" for i in range(500)]
    bg_embs = model.encode(bg, convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE, batch_size=256)
    # Link consecutive background items into one long distractor chain.
    # (Was a hard-coded `range(499)`; derive from the data instead.)
    for i in range(len(bg_embs) - 1):
        mem.learn(bg_embs[i], bg_embs[i+1])

    for ci, chain in enumerate(chains):
        results = mem.recall_multihop(all_embs[ci][0], hops=len(chain)-1)
        for hop_idx, (recalled, _) in enumerate(results):
            target = all_embs[ci][hop_idx + 1]
            sim = cosine(recalled, target)
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")
|
|
|
|
|
|
def main():
    """Run all Experiment 7c test suites against a freshly loaded encoder."""
    banner = "=" * 60
    print(banner)
    print("Experiment 7c: Embedding-Space Hopfield")
    print(banner)

    model = load_model()
    for suite in (test_scale, test_hard_distractors, test_multihop_embedding):
        suite(model)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|