Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
"""Experiment 3b: Consolidation near capacity limits.

With code_dim=16384 and k=20, capacity is so high that consolidation seems
unnecessary. Test with a smaller code_dim (2048), where capacity limits are
lower and consolidation effects should be visible.

Also test: stronger homeostasis to control W_norm growth.
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
from nuonuo.consolidation import MemoryConsolidator, winner_take_all
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def cosine(a, b):
|
|
if a.norm() == 0 or b.norm() == 0:
|
|
return 0.0
|
|
return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
|
|
|
|
|
|
class SmallMemory:
|
|
"""Smaller memory for capacity-limited tests."""
|
|
def __init__(self, input_dim=768, code_dim=2048, k=50):
|
|
self.k = k
|
|
self.code_dim = code_dim
|
|
self.proj = (torch.randn(input_dim, code_dim, device=DEVICE)
|
|
* (1.0 / input_dim**0.5))
|
|
self.target_proj = (torch.randn(input_dim, code_dim, device=DEVICE)
|
|
* (1.0 / input_dim**0.5))
|
|
self.W = nn.Parameter(torch.zeros(code_dim, code_dim, device=DEVICE),
|
|
requires_grad=False)
|
|
self.consolidator = MemoryConsolidator(code_dim, k)
|
|
|
|
def sep(self, x):
|
|
return winner_take_all(x @ self.proj, self.k)
|
|
|
|
def sep_target(self, x):
|
|
return winner_take_all(x @ self.target_proj, self.k)
|
|
|
|
def learn(self, cue, target, record=True):
|
|
cc = self.sep(cue)
|
|
tc = self.sep_target(target)
|
|
self.W.data += torch.outer(tc, cc)
|
|
if record:
|
|
self.consolidator.record(cc, tc)
|
|
|
|
def recall(self, cue):
|
|
cc = self.sep(cue)
|
|
raw = self.W @ cc
|
|
return winner_take_all(raw, self.k)
|
|
|
|
def test_recall(self, cues, targets):
|
|
correct = []
|
|
for i in range(len(cues)):
|
|
recalled = self.recall(cues[i])
|
|
tc = self.sep_target(targets[i])
|
|
correct.append(cosine(recalled, tc))
|
|
return np.mean(correct), np.mean([s > 0.5 for s in correct])
|
|
|
|
def consolidate(self, **kwargs):
|
|
return self.consolidator.consolidate(
|
|
self.W, self.proj, self.target_proj, **kwargs)
|
|
|
|
|
|
def gen_memories(n, dim=768):
|
|
cues = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
|
|
for _ in range(n)]
|
|
targets = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
|
|
for _ in range(n)]
|
|
return cues, targets
|
|
|
|
|
|
def test_capacity_with_consolidation():
|
|
"""Find where small memory breaks and see if consolidation helps."""
|
|
print("=== Capacity with code_dim=2048, k=50 ===")
|
|
|
|
for n_pairs in [50, 100, 200, 500, 1000, 2000]:
|
|
mem_no_consol = SmallMemory()
|
|
mem_with_consol = SmallMemory()
|
|
mem_with_consol.proj = mem_no_consol.proj
|
|
mem_with_consol.target_proj = mem_no_consol.target_proj
|
|
|
|
cues, targets = gen_memories(n_pairs)
|
|
|
|
# Learn in both
|
|
for i in range(n_pairs):
|
|
mem_no_consol.learn(cues[i], targets[i], record=False)
|
|
mem_with_consol.learn(cues[i], targets[i], record=True)
|
|
|
|
cos_no, rate_no = mem_no_consol.test_recall(cues, targets)
|
|
|
|
# Consolidate with strong homeostasis
|
|
mem_with_consol.consolidate(num_epochs=3, homeostasis_factor=0.80,
|
|
prune_threshold=0.01)
|
|
cos_yes, rate_yes = mem_with_consol.test_recall(cues, targets)
|
|
|
|
w_no = mem_no_consol.W.data.norm().item()
|
|
w_yes = mem_with_consol.W.data.norm().item()
|
|
|
|
print(f" N={n_pairs:>5}: "
|
|
f"No_consol: CosSim={cos_no:.4f} Rate={rate_no:.0%} W={w_no:.0f} | "
|
|
f"With_consol: CosSim={cos_yes:.4f} Rate={rate_yes:.0%} W={w_yes:.0f}")
|
|
|
|
|
|
def test_multi_night_at_limit():
|
|
"""7-day scenario near capacity limits."""
|
|
print("\n=== 7-Day Scenario (code_dim=2048, k=50, 200/day) ===")
|
|
|
|
mem = SmallMemory()
|
|
all_cues = []
|
|
all_targets = []
|
|
|
|
for day in range(1, 8):
|
|
cues_today, targets_today = gen_memories(200)
|
|
all_cues.extend(cues_today)
|
|
all_targets.extend(targets_today)
|
|
|
|
for i in range(200):
|
|
mem.learn(cues_today[i], targets_today[i])
|
|
|
|
# Test on all memories so far
|
|
cos_all, rate_all = mem.test_recall(all_cues, all_targets)
|
|
cos_today, rate_today = mem.test_recall(cues_today, targets_today)
|
|
cos_day1, _ = mem.test_recall(all_cues[:200], all_targets[:200])
|
|
|
|
w_norm = mem.W.data.norm().item()
|
|
print(f" Day {day} (total={len(all_cues)}): "
|
|
f"All={cos_all:.4f}({rate_all:.0%}), "
|
|
f"Today={cos_today:.4f}, Day1={cos_day1:.4f}, "
|
|
f"W={w_norm:.0f}")
|
|
|
|
# Night: consolidate
|
|
mem.consolidate(num_epochs=3, homeostasis_factor=0.85,
|
|
prune_threshold=0.01)
|
|
mem.consolidator.selective_clear(keep_fraction=0.3)
|
|
|
|
cos_after, rate_after = mem.test_recall(all_cues, all_targets)
|
|
cos_day1_after, _ = mem.test_recall(all_cues[:200], all_targets[:200])
|
|
w_after = mem.W.data.norm().item()
|
|
print(f" → Night {day}: "
|
|
f"All={cos_after:.4f}({rate_after:.0%}), Day1={cos_day1_after:.4f}, "
|
|
f"W={w_after:.0f}")
|
|
|
|
|
|
def test_homeostasis_sweep():
|
|
"""Find the right homeostasis factor."""
|
|
print("\n=== Homeostasis Factor Sweep (500 pairs, 10 nights) ===")
|
|
|
|
for hf in [1.0, 0.99, 0.95, 0.90, 0.85, 0.80, 0.70]:
|
|
mem = SmallMemory()
|
|
cues, targets = gen_memories(500)
|
|
for i in range(500):
|
|
mem.learn(cues[i], targets[i])
|
|
|
|
for night in range(10):
|
|
mem.consolidate(num_epochs=1, homeostasis_factor=hf)
|
|
|
|
cos, rate = mem.test_recall(cues, targets)
|
|
w = mem.W.data.norm().item()
|
|
sp = (mem.W.data.abs() < 0.01).float().mean().item()
|
|
print(f" hf={hf:.2f}: CosSim={cos:.4f}, Rate={rate:.0%}, "
|
|
f"W_norm={w:.1f}, Sparsity={sp:.2%}")
|
|
|
|
|
|
def main():
|
|
print("=" * 60)
|
|
print("Experiment 3b: Consolidation Under Stress")
|
|
print("=" * 60)
|
|
|
|
test_capacity_with_consolidation()
|
|
test_multi_night_at_limit()
|
|
test_homeostasis_sweep()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|