Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
376 lines · 13 KiB · Python
"""Experiment 7b: Deep dive into Hopfield memory.

Hopfield crushed it at 1000 bg (100% para recall). Now stress test:

1. Scale to 5K, 10K, 20K memories — does softmax attention hold up?
2. Multi-hop: can we chain through Hopfield? (A→B→C)
3. Latency: O(N) attention — how slow at 20K?
4. β optimization: find sweet spot
5. Memory: storing all patterns explicitly — how much VRAM?
6. Mixed difficulty: semantically similar distractors (not just random bg)
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
# Fall back to CPU when no GPU is present so the script stays runnable anywhere.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of 1-D tensors *a* and *b* as a float.

    Degenerate zero-norm inputs yield 0.0 instead of NaN.
    """
    if min(a.norm(), b.norm()) == 0:
        return 0.0
    return nn.functional.cosine_similarity(a[None, :], b[None, :]).item()
|
|
|
|
|
|
def winner_take_all(x, k):
    """Binary top-k sparsification along the last dim.

    Returns a tensor shaped like *x* with 1.0 at the k largest entries of
    the last dimension and 0.0 everywhere else.
    """
    mask = torch.zeros_like(x)
    topk_idx = x.topk(k, dim=-1).indices
    return mask.scatter_(-1, topk_idx, 1.0)
|
|
|
|
|
|
class HopfieldMemory:
    """Modern-Hopfield associative memory over sparse winner-take-all codes.

    Embeddings are projected into a high-dimensional code space by a fixed
    random matrix and sparsified with winner-take-all (k active units) for
    pattern separation.  Recall settles the query onto the nearest stored
    cue attractor via softmax attention (inverse temperature ``beta``) and
    then reads out the associated target.
    """

    def __init__(self, input_dim, code_dim=16384, k=50, beta=16.0):
        """
        Args:
            input_dim: dimensionality of incoming embeddings.
            code_dim: dimensionality of the sparse code space.
            k: number of active units kept by winner-take-all.
            beta: softmax inverse temperature (sharper attention when larger).
        """
        self.k = k
        self.code_dim = code_dim
        self.beta = beta
        # Fixed random projection, scaled ~1/sqrt(input_dim) so that code
        # pre-activations keep roughly unit variance.
        self.proj = (torch.randn(input_dim, code_dim, device=DEVICE)
                     * (1.0 / input_dim**0.5))
        self.cue_codes = []      # sparse WTA codes of stored cues
        self.target_codes = []   # sparse WTA codes of stored targets
        self.cue_embs = []       # raw cue embeddings (embedding-space recall)
        self.target_embs = []    # raw target embeddings

    def sep(self, x):
        """Pattern-separate an embedding into a sparse binary code."""
        return winner_take_all(x @ self.proj, self.k)

    def learn(self, cue_emb, target_emb):
        """Store one cue→target association (codes and raw embeddings)."""
        self.cue_codes.append(self.sep(cue_emb))
        self.target_codes.append(self.sep(target_emb))
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())

    def _get_matrices(self):
        """Stack stored codes into (N, code_dim) cue / target matrices."""
        return torch.stack(self.cue_codes), torch.stack(self.target_codes)

    def _settle(self, xi, cue_mat, steps):
        """Run `steps` Hopfield updates, snapping back to a sparse code each step."""
        for _ in range(steps):
            attn = torch.softmax(self.beta * (xi @ cue_mat.T), dim=0)
            xi = winner_take_all(attn @ cue_mat, self.k)
        return xi

    def _associate(self, xi, cue_mat, target_mat):
        """One attention read: match xi against cues, return the WTA'd target mix."""
        attn = torch.softmax(self.beta * (xi @ cue_mat.T), dim=0)
        return winner_take_all(attn @ target_mat, self.k)

    def recall(self, query_emb, steps=3):
        """Settle the query onto a cue attractor, then read out its target code.

        Returns None when the memory is empty — consistent with
        recall_embedding_space — instead of raising from torch.stack([]).
        """
        if not self.cue_codes:
            return None
        cue_mat, target_mat = self._get_matrices()
        xi = self._settle(self.sep(query_emb), cue_mat, steps)
        # Final association
        return self._associate(xi, cue_mat, target_mat)

    def recall_multihop(self, query_emb, hops=2, steps_per_hop=3):
        """Multi-hop: settle to cue → get target → use target as next cue.

        Returns a list with one recalled target code per hop (empty list
        when the memory holds nothing).
        """
        if not self.cue_codes:
            return []
        cue_mat, target_mat = self._get_matrices()

        xi = self.sep(query_emb)
        results = []

        for hop in range(hops):
            # Settle to nearest cue attractor, then associate cue → target.
            xi = self._settle(xi, cue_mat, steps_per_hop)
            target = self._associate(xi, cue_mat, target_mat)
            results.append(target)

            # Next hop: use target as new query
            xi = target

        return results

    def recall_embedding_space(self, query_emb, steps=3):
        """Hopfield attention in raw embedding space (no WTA codes).

        Might be better for noise tolerance since embeddings are continuous.
        Returns None when the memory is empty.
        """
        if not self.cue_embs:
            return None

        cue_mat = torch.stack(self.cue_embs)
        target_mat = torch.stack(self.target_embs)

        xi = query_emb
        for _ in range(steps):
            attn = torch.softmax(self.beta * (xi @ cue_mat.T), dim=0)
            xi = attn @ cue_mat

        # Final: read out the target mixture (continuous, no sparsification).
        attn = torch.softmax(self.beta * (xi @ cue_mat.T), dim=0)
        return attn @ target_mat
|
|
|
|
|
|
def load_model():
    """Load the MiniLM-L6 sentence encoder on the target device."""
    from sentence_transformers import SentenceTransformer
    encoder = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return encoder
|
|
|
|
|
|
def test_scale(model, n_background_list, beta=16.0):
    """Test Hopfield at different scales.

    Stores 5 cue→target pairs plus n_bg synthetic background memories, then
    checks whether paraphrased cues still recall the right target — in both
    WTA-code space and raw embedding space — and reports per-query latency
    and an approximate VRAM footprint.

    Args:
        model: sentence encoder exposing .encode() and
            .get_sentence_embedding_dimension().
        n_background_list: background-memory counts to sweep over.
        beta: softmax inverse temperature for the Hopfield memory.
    """
    print(f"\n=== Scale Test (β={beta}) ===")

    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
    ]
    # paraphrases[i] is a reworded version of pairs[i][0].
    paraphrases = [
        "How's the weather outside?",
        "We should push the new release",
        "DB performance is terrible",
        "There's a login bug to fix",
        "Getting internal server errors",
    ]

    embed_dim = model.get_sentence_embedding_dimension()
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    for n_bg in n_background_list:
        mem = HopfieldMemory(embed_dim, code_dim=8192, k=50, beta=beta)

        # Store test pairs
        for i in range(len(pairs)):
            mem.learn(cue_embs[i], target_embs[i])

        # Store background
        if n_bg > 0:
            # More diverse background sentences
            bg_cues = []
            bg_targets = []
            topics = ["server", "database", "API", "frontend", "backend",
                      "cache", "queue", "network", "storage", "auth"]
            for i in range(n_bg):
                t = topics[i % len(topics)]
                bg_cues.append(f"The {t} system has issue number {i}")
                bg_targets.append(f"Issue {i} for {t} requires attention from team {i%5}")

            # Encode in chunks of 256 to bound peak GPU memory.
            for start in range(0, n_bg, 256):
                end = min(start + 256, n_bg)
                bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                                  normalize_embeddings=True, device=DEVICE)
                bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                                  normalize_embeddings=True, device=DEVICE)
                for j in range(bc.shape[0]):
                    mem.learn(bc[j], bt[j])

        # Test
        target_codes = torch.stack([mem.sep(t) for t in target_embs])

        # Paraphrase recall
        t0 = time.time()
        para_correct = 0
        for i in range(len(paraphrases)):
            recalled = mem.recall(para_embs[i])
            sims = nn.functional.cosine_similarity(recalled.unsqueeze(0), target_codes, dim=-1)
            if sims.argmax().item() == i:
                para_correct += 1
        recall_time = (time.time() - t0) / len(paraphrases) * 1000

        # Also test in embedding space
        para_correct_emb = 0
        for i in range(len(paraphrases)):
            recalled_emb = mem.recall_embedding_space(para_embs[i])
            sims = nn.functional.cosine_similarity(recalled_emb.unsqueeze(0), target_embs, dim=-1)
            if sims.argmax().item() == i:
                para_correct_emb += 1

        n = len(paraphrases)
        total_mem = len(mem.cue_codes)
        # Use mem.code_dim rather than a hard-coded 8192 so the estimate
        # stays correct if the code dimensionality above is changed.
        vram = total_mem * mem.code_dim * 4 * 2 / 1024**2  # codes + embs approx
        print(f" N={total_mem:>6}: Code={para_correct}/{n} ({para_correct/n:.0%}), "
              f"Emb={para_correct_emb}/{n} ({para_correct_emb/n:.0%}), "
              f"time={recall_time:.1f}ms, ~VRAM={vram:.0f}MB")

        del mem
        torch.cuda.empty_cache()
|
|
|
|
|
|
def test_multihop(model):
    """Multi-hop through Hopfield memory.

    Learns consecutive-pair associations for each 4-sentence chain, then
    checks whether recall_multihop can walk the chain from its first
    sentence alone — first with each chain in its own memory, then with
    all chains plus 200 background memories sharing one memory.
    """
    print("\n=== Multi-hop Test ===")

    chains = [
        ["What's the weather?", "I check weather before going out",
         "My coffee shop is around the corner", "They have great latte art"],
        ["Let's review the code", "Code review found a memory leak",
         "Memory leaks cause OOM kills", "Need memory limits in k8s"],
        ["Deploy to production", "Production uses blue-green deploy",
         "Blue environment is active", "Switch DNS to green when ready"],
    ]

    embed_dim = model.get_sentence_embedding_dimension()

    for chain in chains:
        # Fresh memory per chain: isolates each chain from the others.
        mem = HopfieldMemory(embed_dim, code_dim=8192, k=50, beta=16.0)

        chain_embs = [model.encode([t], convert_to_tensor=True,
                                   normalize_embeddings=True, device=DEVICE)[0]
                      for t in chain]

        # Learn consecutive pairs
        for i in range(len(chain) - 1):
            mem.learn(chain_embs[i], chain_embs[i+1])

        # Multi-hop recall
        target_codes = [mem.sep(e) for e in chain_embs]

        results = mem.recall_multihop(chain_embs[0], hops=len(chain)-1)

        print(f"\n Chain: {' → '.join([c[:20]+'...' for c in chain])}")
        for hop_idx, recalled in enumerate(results):
            # Hop h should land on chain element h+1.
            target = target_codes[hop_idx + 1]
            sim = cosine(recalled, target)
            status = "✓" if sim > 0.5 else "✗"
            print(f" {status} hop {hop_idx+1}: → '{chain[hop_idx+1][:30]}...' sim={sim:.3f}")

    # Multi-hop with background noise
    print("\n --- Multi-hop with 200 background memories ---")
    mem = HopfieldMemory(embed_dim, code_dim=8192, k=50, beta=16.0)

    # Store all chains
    all_chain_embs = []
    for chain in chains:
        embs = [model.encode([t], convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)[0]
                for t in chain]
        all_chain_embs.append(embs)
        for i in range(len(chain) - 1):
            mem.learn(embs[i], embs[i+1])

    # Add background
    bg = [f"Background sentence number {i}" for i in range(200)]
    bg_embs = model.encode(bg, convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE)
    # 199 consecutive-pair associations from the 200 background sentences.
    for i in range(199):
        mem.learn(bg_embs[i], bg_embs[i+1])

    for ci, chain in enumerate(chains):
        target_codes = [mem.sep(e) for e in all_chain_embs[ci]]
        results = mem.recall_multihop(all_chain_embs[ci][0], hops=len(chain)-1)

        for hop_idx, recalled in enumerate(results):
            target = target_codes[hop_idx + 1]
            sim = cosine(recalled, target)
            status = "✓" if sim > 0.5 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")
|
|
|
|
|
|
def test_hard_distractors(model):
    """Test with semantically similar distractors (harder than random bg).

    One true cue→target pair competes with ten near-paraphrase distractor
    cues; a β sweep shows how sharp the attention needs to be for a reworded
    query to settle on the correct attractor.

    Args:
        model: sentence encoder exposing .encode() and
            .get_sentence_embedding_dimension().
    """
    print("\n=== Hard Distractors (semantically similar) ===")

    # Target pair
    pairs = [
        ("The database is slow", "Missing index on users table"),
    ]
    # Distractors: similar to cue but different meaning
    distractors_cue = [
        "The database is fast",
        "The database crashed",
        "The database needs backup",
        "The datastore is slow",
        "The DB latency is high",
        "Database performance degraded",
        "SQL queries are slow",
        "The cache is slow",
        "The search index is slow",
        "MongoDB is slow",
    ]
    distractors_target = [
        f"Distractor target {i}" for i in range(len(distractors_cue))
    ]

    query = "DB performance is terrible"

    embed_dim = model.get_sentence_embedding_dimension()

    # Encode once, outside the β sweep: embeddings do not depend on β, so
    # re-encoding every iteration was pure wasted work (same printed output).
    cue_emb = model.encode([pairs[0][0]], convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE)[0]
    target_emb = model.encode([pairs[0][1]], convert_to_tensor=True,
                              normalize_embeddings=True, device=DEVICE)[0]
    dist_cue_embs = model.encode(distractors_cue, convert_to_tensor=True,
                                 normalize_embeddings=True, device=DEVICE)
    dist_target_embs = model.encode(distractors_target, convert_to_tensor=True,
                                    normalize_embeddings=True, device=DEVICE)
    q_emb = model.encode([query], convert_to_tensor=True,
                         normalize_embeddings=True, device=DEVICE)[0]

    for beta in [8.0, 16.0, 32.0, 64.0]:
        mem = HopfieldMemory(embed_dim, code_dim=8192, k=50, beta=beta)

        # Store target, then distractors
        mem.learn(cue_emb, target_emb)
        for i in range(len(distractors_cue)):
            mem.learn(dist_cue_embs[i], dist_target_embs[i])

        # Query
        recalled = mem.recall(q_emb)
        target_code = mem.sep(target_emb)
        sim = cosine(recalled, target_code)

        # Also check which cue got highest attention
        cue_mat = torch.stack(mem.cue_codes)
        q_code = mem.sep(q_emb)
        scores = beta * (q_code @ cue_mat.T)
        attn = torch.softmax(scores, dim=0)
        top_idx = attn.argmax().item()
        top_attn = attn[top_idx].item()

        all_cues = [pairs[0][0]] + distractors_cue
        print(f" β={beta:>4}: sim_to_target={sim:.3f}, "
              f"top_attn={top_attn:.3f} → '{all_cues[top_idx][:30]}...'")
|
|
|
|
|
|
def main():
    """Run the full Experiment 7b stress-test suite."""
    banner = "=" * 60
    print(banner)
    print("Experiment 7b: Hopfield Deep Dive")
    print(banner)

    model = load_model()

    # Scale test
    test_scale(model, [0, 100, 500, 1000, 2000, 5000, 10000], beta=16.0)

    # β sweep at large scale
    print("\n=== β Sweep at N=5000 ===")
    for b in (4, 8, 16, 32, 64):
        test_scale(model, [5000], beta=b)

    # Multi-hop
    test_multihop(model)

    # Hard distractors
    test_hard_distractors(model)
|
|
|
|
|
|
# Entry point: run the full experiment suite when executed as a script.
if __name__ == "__main__":
    main()
|