Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
210 lines · 7.4 KiB · Python
"""Experiment 2c: Pattern separation + improved associative recall.

Key insight from 2b: random spike patterns have too much overlap,
causing catastrophic interference in associative memory.

Fix: Implement pattern separation (like dentate gyrus in hippocampus):
1. Winner-take-all: only top-k neurons fire → guaranteed sparse, minimal overlap
2. Random sparse projection: patterns projected through sparse random matrix
3. Scale up neurons to improve signal-to-noise ratio (capacity ∝ N/P)

Also test: direct Hebbian in rate-space (skip spike conversion entirely)
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def cosine(a, b):
|
|
if a.norm() == 0 or b.norm() == 0:
|
|
return 0.0
|
|
return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
|
|
|
|
|
|
def winner_take_all(x, k):
    """Binary mask of the top-k entries along the last dimension.

    Returns a tensor shaped like `x` with 1.0 at the positions of the k
    largest values per last-dim slice and 0.0 elsewhere. Note: the hard
    scatter of constants is NOT differentiable w.r.t. `x` — this is a
    non-learned sparsification step, not a soft top-k.
    """
    # Only the positions matter; the top-k values themselves are discarded.
    _, topk_idx = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    out.scatter_(-1, topk_idx, 1.0)  # Binary: active or not
    return out
|
|
|
|
|
|
class PatternSeparator(nn.Module):
    """Dentate gyrus analog: maps dense inputs to sparse, near-binary codes.

    A fixed (non-learned) random projection expands the input into a larger
    code space; winner-take-all then keeps only the k_active strongest units,
    yielding sparse codes with minimal pairwise overlap.
    """

    def __init__(self, input_dim, code_dim, k_active):
        super().__init__()
        self.k_active = k_active
        # Fixed random projection, scaled by 1/sqrt(input_dim); registered as
        # a buffer so .to(device) moves it but no gradients are tracked.
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)

    def forward(self, x):
        """Map x ([input_dim]) to a sparse binary code ([code_dim])."""
        projected = x @ self.proj
        return winner_take_all(projected, self.k_active)
|
|
|
|
|
|
class HebbianMemory(nn.Module):
    """Heteroassociative memory with pattern separation.

    Cues and targets are first sparsified through two independent random
    projections (different codes for the same vector), then associated via
    a one-shot outer-product Hebbian update on a code_dim x code_dim matrix.
    """

    def __init__(self, input_dim, code_dim=8192, k_active=50, lr=1.0):
        super().__init__()
        # Cue-side separator is constructed first, then the target-side one,
        # so each draws a distinct random projection.
        self.separator = PatternSeparator(input_dim, code_dim, k_active)
        self.code_dim = code_dim
        self.lr = lr
        self.target_separator = PatternSeparator(input_dim, code_dim, k_active)
        # Association matrix: separated_cue → separated_target. Updated
        # manually in learn(), never by backprop.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """One-shot store of a (cue, target) pair; both are [dim] continuous vectors."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        # Classic Hebbian outer-product update.
        self.W.data += self.lr * torch.outer(target_code, cue_code)

    def recall(self, cue, k_recall=50):
        """Recall the separated target code for `cue`, cleaned up by WTA."""
        raw = self.W @ self.separator(cue)
        return winner_take_all(raw, k_recall)

    def recall_continuous(self, cue):
        """Raw (un-thresholded) recall activation, e.g. for cosine scoring."""
        return self.W @ self.separator(cue)
|
|
|
|
|
|
def test_hebbian_with_separation(input_dim, code_dim, k_active, num_pairs, lr):
    """Store random (cue, target) pairs and score recall quality in code space.

    Reports mean cosine similarity of each recall to its true target code
    ("Correct"), to other targets' codes ("Wrong"), and their gap ("Disc").
    """
    mem = HebbianMemory(input_dim, code_dim, k_active, lr).to(DEVICE)

    def rand_unit():
        return nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)

    # Random unit vectors stand in for embedding memories.
    cues = [rand_unit() for _ in range(num_pairs)]
    targets = [rand_unit() for _ in range(num_pairs)]

    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    # Evaluate recall in code space (after separation).
    correct_sims = []
    wrong_sims = []
    for i, cue in enumerate(cues):
        recalled = mem.recall(cue, k_recall=k_active)
        correct_sims.append(cosine(recalled, mem.target_separator(targets[i])))
        # Bounded number of non-matching comparisons, for speed.
        for j in range(min(num_pairs, 20)):
            if j != i:
                wrong_sims.append(cosine(recalled, mem.target_separator(targets[j])))

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0

    print(f" code={code_dim}, k={k_active}, pairs={num_pairs}, lr={lr:.2f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}")

    return {"correct": mc, "wrong": mw, "disc": mc - mw,
            "code_dim": code_dim, "k_active": k_active,
            "num_pairs": num_pairs, "lr": lr}
|
|
|
|
|
|
def test_overlap_analysis(code_dim, k_active, num_patterns):
    """Measure how orthogonal the separated codes of random inputs actually are.

    Returns the mean and max pairwise cosine similarity over all unordered
    pairs of separated codes.
    """
    sep = PatternSeparator(768, code_dim, k_active).to(DEVICE)

    codes = []
    for _ in range(num_patterns):
        vec = nn.functional.normalize(torch.randn(768, device=DEVICE), dim=0)
        codes.append(sep(vec))

    # Pairwise cosine similarity over all i < j pairs.
    sims = [cosine(codes[i], codes[j])
            for i in range(num_patterns)
            for j in range(i + 1, num_patterns)]

    mean_sim = np.mean(sims)
    max_sim = np.max(sims)
    print(f" code={code_dim}, k={k_active}: mean_overlap={mean_sim:.4f}, max_overlap={max_sim:.4f}")
    return {"mean_overlap": mean_sim, "max_overlap": max_sim}
|
|
|
|
|
|
def main():
    """Run the full experiment 2c sweep and write results to RESULTS_DIR."""
    print("=" * 60)
    print("Experiment 2c: Pattern Separation + Hebbian Memory")
    print("=" * 60)

    results = []

    # Part 1: overlap analysis — how orthogonal are separated patterns?
    print("\n=== Part 1: Pattern overlap after separation ===")
    for dim in (2048, 4096, 8192, 16384):
        for k in (20, 50, 100):
            ov = test_overlap_analysis(dim, k, 100)
            results.append({"test": "overlap", "code_dim": dim, "k": k, **ov})

    # Part 2: associative recall with separation.
    print("\n=== Part 2: Recall with pattern separation ===")

    print("\n-- Scaling pairs --")
    for n in (1, 5, 10, 20, 50, 100, 200, 500):
        results.append({"test": f"sep_pairs_{n}",
                        **test_hebbian_with_separation(768, 8192, 50, n, lr=1.0)})

    print("\n-- Code dimension sweep (100 pairs) --")
    for dim in (2048, 4096, 8192, 16384):
        results.append({"test": f"sep_codedim_{dim}",
                        **test_hebbian_with_separation(768, dim, 50, 100, lr=1.0)})

    print("\n-- Sparsity sweep (100 pairs, code=8192) --")
    for k in (10, 20, 50, 100, 200):
        results.append({"test": f"sep_k_{k}",
                        **test_hebbian_with_separation(768, 8192, k, 100, lr=1.0)})

    print("\n-- Capacity test: find the breaking point (code=16384, k=20) --")
    for n in (10, 50, 100, 200, 500, 1000, 2000):
        results.append({"test": f"cap_{n}",
                        **test_hebbian_with_separation(768, 16384, 20, n, lr=1.0)})

    # Persist everything; default=float converts numpy scalars for JSON.
    with open(RESULTS_DIR / "exp02c_results.json", "w") as f:
        json.dump(results, f, indent=2, default=float)

    # Summarize the capacity curve from the recall results.
    capacity_rows = [r for r in results
                     if r.get("disc") is not None and "cap_" in r.get("test", "")]
    if capacity_rows:
        print("\n=== Capacity curve (code=16384, k=20) ===")
        print(f"{'Pairs':>6} {'Correct':>8} {'Wrong':>8} {'Disc':>8}")
        for r in capacity_rows:
            print(f"{r['num_pairs']:>6} {r['correct']:>8.4f} {r['wrong']:>8.4f} {r['disc']:>8.4f}")
|
|
|
|
|
|
# Entry point: run the full experiment sweep when invoked as a script.
if __name__ == "__main__":
    main()
|