Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments
(16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture

340 lines · 13 KiB · Python
"""Experiment 7d: Two-stage retrieval for scale.
|
|
|
|
Problem: Embedding Hopfield degrades at 10K+ (80%).
|
|
Fix: Pre-filter with approximate NN (top-K), then Hopfield settle on candidates.
|
|
|
|
This is O(N) for pre-filter (can be O(log N) with FAISS) + O(K) for Hopfield.
|
|
Also: test adaptive β based on attention entropy (low entropy = confident).
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
DEVICE = "cuda"
|
|
|
|
|
|
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, returned as a Python float."""
    lhs = a.unsqueeze(0)
    rhs = b.unsqueeze(0)
    sim = nn.functional.cosine_similarity(lhs, rhs)
    return sim.item()
|
|
|
|
|
|
class TwoStageHopfield:
    """Pre-filter + Hopfield settle.

    Stage 1: cosine NN → top-K candidates (fast, O(N) or O(log N) with index)
    Stage 2: Hopfield attention over K candidates (precise, O(K))
    """

    def __init__(self, beta=16.0, top_k=50):
        self.beta = beta        # inverse temperature for the Hopfield softmax
        self.top_k = top_k      # candidate pool size for stage 2
        self.cue_embs = []      # stored 1-D cue embeddings
        self.target_embs = []   # stored 1-D target embeddings, parallel to cues
        self._cue_matrix = None     # cached [N, dim] stack for batch NN
        self._target_matrix = None  # cached [N, dim] stack (fix: was restacked every recall)

    def learn(self, cue_emb, target_emb):
        """Store one cue→target association and invalidate the matrix caches."""
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())
        self._cue_matrix = None  # Invalidate cache
        self._target_matrix = None

    def _get_cue_matrix(self):
        # Lazily stack cues into a single [N, dim] matrix for fast matmul NN.
        if self._cue_matrix is None:
            self._cue_matrix = torch.stack(self.cue_embs)
        return self._cue_matrix

    def _get_target_matrix(self):
        # Cached like the cue matrix: the original restacked the targets on
        # every recall, an O(N·dim) copy that dominated at 20K memories.
        if self._target_matrix is None:
            self._target_matrix = torch.stack(self.target_embs)
        return self._target_matrix

    def recall(self, query_emb, steps=3):
        """Two-stage recall.

        Args:
            query_emb: 1-D query embedding (assumed L2-normalized, as all
                stored embeddings are — TODO confirm at call sites).
            steps: Hopfield settle iterations over the candidate cues.

        Returns:
            Tuple of (normalized recalled target embedding,
            global index of the winning memory,
            attention weights over the K candidates).
        """
        cue_mat = self._get_cue_matrix()
        target_mat = self._get_target_matrix()
        N = cue_mat.shape[0]

        # Stage 1: Fast NN pre-filter — one matmul over all stored cues.
        k = min(self.top_k, N)
        sims = query_emb @ cue_mat.T  # [N]
        top_sims, top_indices = sims.topk(k)

        # Stage 2: Hopfield settle on candidates only.
        cand_cues = cue_mat[top_indices]  # [K, dim]
        cand_targets = target_mat[top_indices]  # [K, dim]

        # Iterated softmax attention (modern-Hopfield update) sharpens the
        # state toward the nearest stored cue; renormalize each step so the
        # dot products stay cosine similarities.
        xi = query_emb
        for _ in range(steps):
            scores = self.beta * (xi @ cand_cues.T)
            attn = torch.softmax(scores, dim=0)
            xi = attn @ cand_cues
            xi = nn.functional.normalize(xi, dim=0)

        # Final association: read out the targets with the settled state.
        scores = self.beta * (xi @ cand_cues.T)
        attn = torch.softmax(scores, dim=0)
        target = attn @ cand_targets

        # Map the winning candidate back to its global memory index.
        best_local = attn.argmax().item()
        best_global = top_indices[best_local].item()

        return nn.functional.normalize(target, dim=0), best_global, attn

    def recall_multihop(self, query_emb, hops=2, steps=3):
        """Multi-hop: each hop does two-stage retrieval, feeding the recalled
        target back in as the next hop's query.

        Returns a list of (target embedding, global index) pairs, one per hop.
        """
        xi = query_emb
        results = []
        for _ in range(hops):
            target, idx, attn = self.recall(xi, steps=steps)
            results.append((target, idx))
            xi = target  # Use target as next query
        return results
|
|
|
|
|
|
def load_model():
    """Load the MiniLM-L6 sentence encoder onto DEVICE and return it."""
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
|
|
|
|
|
|
def test_scale(model):
    """Scale test comparing pure Hopfield vs two-stage.

    Sweeps the number of background (distractor) memories from 0 to 20K and
    the stage-1 candidate pool size K over {20, 50, 100}, measuring paraphrase
    recall accuracy and per-query latency for each (N, K) combination.

    Args:
        model: sentence encoder with an ``encode`` method
            (presumably a SentenceTransformer — see load_model).
    """
    print("\n=== Scale Comparison ===")

    # 10 ground-truth cue→target memories the retrieval must find.
    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
    ]
    # Queries: one paraphrase per pair, in the same order (index i of a
    # paraphrase must retrieve the target of pairs[i]).
    paraphrases = [
        "How's the weather outside?",
        "Push the new release",
        "DB performance terrible",
        "Login bug needs fixing",
        "Getting 500 errors",
        "Need better observability",
        "CI tests breaking",
        "Service using too much RAM",
        "Docker config help",
        "Logs eating disk space",
    ]

    # Encode cues, targets and queries once, normalized for cosine matmuls.
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    for n_bg in [0, 100, 500, 1000, 5000, 10000, 20000]:
        # Two-stage with different K
        for top_k in [20, 50, 100]:
            # Skip degenerate configs where K exceeds the background size
            # (top-K would just be the whole memory anyway).
            if n_bg < top_k and n_bg > 0:
                continue

            mem = TwoStageHopfield(beta=16.0, top_k=top_k)

            # Store the 10 signal memories first.
            for i in range(len(pairs)):
                mem.learn(cue_embs[i], target_embs[i])

            if n_bg > 0:
                # Synthetic ops-flavored distractors so the background lives
                # in the same embedding neighborhood as the signal pairs.
                topics = ["server", "database", "API", "frontend", "backend",
                          "cache", "queue", "network", "storage", "auth",
                          "docker", "kubernetes", "redis", "nginx", "postgres"]
                bg_cues = [f"The {topics[i%len(topics)]} system has issue {i}"
                           for i in range(n_bg)]
                bg_targets = [f"Fix {topics[i%len(topics)]} issue {i} urgently"
                              for i in range(n_bg)]

                # Encode in chunks of 256 to bound peak GPU memory.
                for start in range(0, n_bg, 256):
                    end = min(start + 256, n_bg)
                    bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                                      normalize_embeddings=True, device=DEVICE)
                    for j in range(bc.shape[0]):
                        mem.learn(bc[j], bt[j])

            # Test: a query counts as correct if the recalled embedding is
            # closest to its own pair's target among all 10 targets.
            t0 = time.time()
            correct = 0
            for i in range(len(paraphrases)):
                with torch.no_grad():
                    recalled, idx, attn = mem.recall(para_embs[i])
                all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
                if np.argmax(all_sims) == i:
                    correct += 1
            # Average wall-clock per query in ms (includes the scoring loop,
            # not just recall — NOTE(review): slight overestimate of latency).
            dt = (time.time() - t0) / len(paraphrases) * 1000

            n = len(paraphrases)
            total = len(mem.cue_embs)
            print(f" N={total:>6}, K={top_k:>3}: "
                  f"Para={correct}/{n} ({correct/n:>3.0%}), "
                  f"time={dt:.1f}ms")

            del mem
            torch.cuda.empty_cache()

        # Blank line between background-size groups for readability.
        if n_bg > 0:
            print()
|
|
|
|
|
|
def test_multihop_at_scale(model):
    """Multi-hop with two-stage at scale.

    Stores three 4-element associative chains plus 500 chained background
    distractors, then follows each chain hop by hop from its first element
    and scores every recalled hop against the true next element.
    """
    print("\n=== Multi-hop Two-Stage (500 bg) ===")

    chains = [
        ["What's the weather?", "Check weather before going out",
         "My coffee shop nearby", "Great latte art"],
        ["Review the code", "Found memory leak", "Leaks cause OOM", "Add k8s limits"],
        ["Deploy to prod", "Blue-green deployment", "Blue is active", "Switch to green"],
    ]

    mem = TwoStageHopfield(beta=16.0, top_k=50)

    # Encode every chain element and store consecutive links as cue→target.
    all_embs = []
    for chain in chains:
        embs = []
        for text in chain:
            emb = model.encode([text], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)[0]
            embs.append(emb)
        all_embs.append(embs)
        for prev, nxt in zip(embs, embs[1:]):
            mem.learn(prev, nxt)

    # Background distractors, linked into one long chain of their own.
    categories = ['code', 'ops', 'ml', 'data', 'infra']
    bg = [f"Background about {categories[i%5]} number {i}"
          for i in range(500)]
    bg_embs = model.encode(bg, convert_to_tensor=True,
                           normalize_embeddings=True, device=DEVICE, batch_size=256)
    for i in range(499):
        mem.learn(bg_embs[i], bg_embs[i+1])

    # Walk each chain; a hop passes when similarity to the true next element
    # exceeds 0.7.
    for ci, chain in enumerate(chains):
        hop_results = mem.recall_multihop(all_embs[ci][0], hops=len(chain)-1)
        for hop_idx, (recalled, idx) in enumerate(hop_results):
            expected = all_embs[ci][hop_idx + 1]
            sim = cosine(recalled, expected)
            status = "✓" if sim > 0.7 else "✗"
            print(f" {status} Chain{ci+1} hop{hop_idx+1}: sim={sim:.3f}")
|
|
|
|
|
|
def test_diverse_queries(model):
    """Larger test set with more diverse queries.

    20 cue→target pairs across mixed topics plus 2000 ops-flavored background
    distractors; reports overall paraphrase-recall accuracy and prints each
    failure with its similarity margin.

    Args:
        model: sentence encoder with an ``encode`` method
            (presumably a SentenceTransformer — see load_model).
    """
    print("\n=== Diverse Query Test (20 pairs, 2000 bg) ===")

    # 20 ground-truth cue→target memories.
    pairs = [
        ("What's the weather like today?", "User checks weather every morning"),
        ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table"),
        ("I need to fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
        ("Tests failing in CI", "CI needs postgres service container"),
        ("Memory usage too high", "Leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files too large", "Logs rotate daily, shipped to Loki"),
        ("How to add caching?", "Redis available at redis.internal:6379"),
        ("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
        ("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
        ("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
        ("Optimize search", "Elasticsearch v8, recently upgraded"),
        ("Backup the database", "Daily 3am UTC cron to S3"),
        ("Configure reverse proxy", "Traefik, not nginx"),
        ("Team meeting schedule", "Standup 10am London, Mon-Fri"),
        ("Learn a new language", "User has Python+Go, new to systems programming"),
        ("Review my PR", "User prefers small PRs with clear commits"),
    ]
    # One paraphrase per pair, index-aligned with pairs.
    paraphrases = [
        "How's the weather?",
        "Ship the release",
        "DB is crawling",
        "Fix the login issue",
        "Server errors everywhere",
        "Need observability",
        "CI is broken",
        "Too much RAM usage",
        "Docker help please",
        "Disk full from logs",
        "Want to add a cache layer",
        "Website too slow",
        "Payment code needs rework",
        "Provision a new machine",
        "Search is slow",
        "Need a DB backup",
        "Proxy configuration",
        "When's the standup?",
        "Want to learn Rust",
        "Check my pull request",
    ]

    # Encode everything once, normalized for cosine matmuls.
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    mem = TwoStageHopfield(beta=16.0, top_k=50)
    for i in range(len(pairs)):
        mem.learn(cue_embs[i], target_embs[i])

    # 2000 diverse background distractors: topic × action product keeps them
    # in the same ops-flavored embedding neighborhood as the signal pairs.
    topics = ["server", "database", "API", "frontend", "backend", "cache",
              "queue", "network", "storage", "auth", "docker", "kubernetes",
              "redis", "nginx", "postgres", "python", "golang", "react",
              "terraform", "ansible"]
    actions = ["crashed", "is slow", "needs update", "has bug", "timed out",
               "needs migration", "needs backup", "has leak", "is down", "needs config"]
    bg_cues = [f"The {topics[i%len(topics)]} {actions[i%len(actions)]} (ticket {i})"
               for i in range(2000)]
    bg_targets = [f"Fix {topics[i%len(topics)]} {actions[i%len(actions)]}: see wiki page {i}"
                  for i in range(2000)]

    # Encode in chunks of 256 to bound peak GPU memory.
    for start in range(0, 2000, 256):
        end = min(start + 256, 2000)
        bc = model.encode(bg_cues[start:end], convert_to_tensor=True,
                          normalize_embeddings=True, device=DEVICE)
        bt = model.encode(bg_targets[start:end], convert_to_tensor=True,
                          normalize_embeddings=True, device=DEVICE)
        for j in range(bc.shape[0]):
            mem.learn(bc[j], bt[j])

    # Test: a query is correct when the recalled embedding is closest to its
    # own pair's target among all 20 targets; otherwise record the confusion.
    correct = 0
    failures = []
    for i in range(len(paraphrases)):
        with torch.no_grad():
            recalled, idx, attn = mem.recall(para_embs[i])
        all_sims = [cosine(recalled, target_embs[j]) for j in range(len(pairs))]
        best = np.argmax(all_sims)
        if best == i:
            correct += 1
        else:
            # (query index, wrong winner, sim to true target, sim to winner)
            failures.append((i, best, all_sims[i], all_sims[best]))

    n = len(paraphrases)
    print(f" Result: {correct}/{n} ({correct/n:.0%})")
    if failures:
        print(f" Failures:")
        for qi, gi, sim_correct, sim_got in failures:
            print(f" Q: '{paraphrases[qi][:30]}...' → got [{gi}] "
                  f"(sim_correct={sim_correct:.3f}, sim_got={sim_got:.3f})")
|
|
|
|
|
|
def main():
    """Run all Experiment 7d stages with one shared encoder instance."""
    banner = "=" * 60
    print(banner)
    print("Experiment 7d: Two-Stage Hopfield")
    print(banner)

    model = load_model()
    for stage in (test_scale, test_multihop_at_scale, test_diverse_queries):
        stage(model)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|