# Hopfield + Hebbian hybrid memory system for LLMs.
# Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).
#
# Architecture:
# - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
# - Multi-hop: Hebbian W matrix with WTA pattern separation
# - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
# - 4ms latency @ 20K memories, ~1GB VRAM
#
# Key findings:
# - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
# - WTA pattern separation enables 20K+ capacity
# - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
# - MiniLM-L6 is optimal (discrimination gap > absolute similarity)
# - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
# - SNN encoder viable (CosSim 0.99) but not needed for current architecture
"""Experiment 2b: STDP Associative Recall (v2 - fixed learning).
|
|
|
|
v1 failed completely because W=0 → no spikes → no STDP updates (chicken-egg).
|
|
v2 fixes this with teacher-forced STDP: directly use (cue, target) as (pre, post).
|
|
|
|
Also tests DirectAssociativeMemory (simple outer-product Hebbian) as baseline.
|
|
"""
|
|
|
|
import json
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.memory import STDPMemoryNetwork, DirectAssociativeMemory
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def spike_cosine(a, b):
    """Cosine similarity between two spike trains, compared as rate vectors.

    2-D inputs of shape (time_steps, neurons) are collapsed to per-neuron
    mean firing rates before comparison; 1-D inputs are used as-is.
    Returns 0.0 when either vector has zero norm (an all-silent train
    carries no direction, and this avoids a divide-by-zero).
    """
    rate_a = a.mean(dim=0) if a.dim() == 2 else a
    rate_b = b.mean(dim=0) if b.dim() == 2 else b
    if rate_a.norm() == 0 or rate_b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(
        rate_a.unsqueeze(0), rate_b.unsqueeze(0)
    )
    return sim.item()
|
|
|
|
|
|
def vec_cosine(a, b):
    """Cosine similarity of two 1D vectors (0.0 if either is all-zero)."""
    # Guard against a zero-norm vector, which would make cosine undefined.
    if a.norm().item() == 0.0 or b.norm().item() == 0.0:
        return 0.0
    return nn.functional.cosine_similarity(
        a.unsqueeze(0), b.unsqueeze(0)
    ).item()
|
|
|
|
|
|
def gen_spikes(num_steps, num_neurons, fr=0.05, device=None):
    """Generate a random Bernoulli spike train.

    Each entry fires independently with probability ``fr`` per time step.

    Args:
        num_steps: Number of time steps (rows).
        num_neurons: Number of neurons (columns).
        fr: Per-step firing probability for every neuron.
        device: Torch device for the output.  Defaults to CUDA when
            available, otherwise CPU — the previous hard-coded ``"cuda"``
            default crashed on CPU-only machines.

    Returns:
        A ``(num_steps, num_neurons)`` float tensor of 0.0/1.0 spikes.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return (torch.rand(num_steps, num_neurons, device=device) < fr).float()
|
|
|
|
|
|
def test_stdp_v2(num_neurons, num_steps, num_pairs, fr, num_pres, a_plus):
    """Train and evaluate the v2 (teacher-forced) STDP network.

    Learns ``num_pairs`` random (cue, target) spike-train associations, then
    recalls each cue and scores cosine similarity against its own target
    (correct) and every other target (wrong).  Prints a one-line summary and
    returns the metrics as a dict.
    """
    net = STDPMemoryNetwork(
        num_neurons=num_neurons, a_plus=a_plus, a_minus=a_plus*1.2,
        w_init_std=0.01
    ).to(DEVICE)

    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # Teacher-forced learning phase (timed).
    start = time.time()
    for cue, target in zip(cues, targets):
        net.learn_association(cue, target, num_presentations=num_pres)
    learn_t = time.time() - start

    # Recall phase: each cue is scored against all targets.
    correct_sims, wrong_sims = [], []
    for idx, cue in enumerate(cues):
        recalled = net.recall(cue)
        correct_sims.append(spike_cosine(recalled, targets[idx]))
        wrong_sims.extend(
            spike_cosine(recalled, tgt)
            for j, tgt in enumerate(targets)
            if j != idx
        )

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()

    print(f" STDP: pairs={num_pairs}, pres={num_pres}, A+={a_plus:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.4f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.1f}s")

    return {"method": "stdp_v2", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "a_plus": a_plus, "num_pres": num_pres}
|
|
|
|
|
|
def test_direct_hebbian(num_neurons, num_steps, num_pairs, fr, lr):
    """Train and evaluate the direct outer-product Hebbian baseline.

    Same protocol as the STDP test, except ``recall`` yields a continuous
    rate vector, so spike-train targets are reduced to their per-neuron mean
    firing rates before cosine comparison.
    """
    net = DirectAssociativeMemory(num_neurons=num_neurons, lr=lr).to(DEVICE)

    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # One-shot Hebbian learning phase (timed).
    start = time.time()
    for cue, target in zip(cues, targets):
        net.learn(cue, target)
    learn_t = time.time() - start

    # Recall phase: hoist the rate reduction so each target is averaged once.
    target_rates = [t.mean(dim=0) for t in targets]
    correct_sims, wrong_sims = [], []
    for idx, cue in enumerate(cues):
        recalled = net.recall(cue)  # continuous vector
        correct_sims.append(vec_cosine(recalled, target_rates[idx]))
        wrong_sims.extend(
            vec_cosine(recalled, rate)
            for j, rate in enumerate(target_rates)
            if j != idx
        )

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()

    print(f" Hebbian: pairs={num_pairs}, lr={lr:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.6f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.3f}s")

    return {"method": "direct_hebbian", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "lr": lr}
|
|
|
|
|
|
def main():
    """Run the full experiment grid and save results to JSON.

    Part A sweeps the direct Hebbian baseline (pair count, learning rate);
    Part B sweeps the teacher-forced STDP network (sanity check, A+ values,
    presentation counts, pair scaling).  All runs are collected into one
    results list, dumped to ``doc/exp02b_results.json``, and the best
    configuration per method (by discrimination = correct - wrong) is printed.
    """
    print("=" * 60)
    print("Experiment 2b: STDP v2 + Direct Hebbian")
    print("=" * 60)

    results = []
    # Shared experiment dimensions: neurons, time steps, firing rate.
    N = 2048
    S = 64
    FR = 0.05

    # --- Part A: Direct Hebbian (baseline) ---
    print("\n=== Part A: Direct Hebbian Memory ===")

    print("\nA1: Scaling pairs (lr=0.5)")
    for n in [1, 5, 10, 20, 50, 100]:
        r = test_direct_hebbian(N, S, n, FR, lr=0.5)
        results.append({**r, "test": f"hebb_pairs_{n}"})

    print("\nA2: Learning rate sweep (10 pairs)")
    for lr in [0.01, 0.1, 0.5, 1.0, 5.0]:
        r = test_direct_hebbian(N, S, 10, FR, lr=lr)
        results.append({**r, "test": f"hebb_lr_{lr}"})

    # --- Part B: STDP v2 ---
    print("\n=== Part B: STDP v2 (teacher-forced) ===")

    print("\nB1: Sanity check - single pair")
    r = test_stdp_v2(N, S, 1, FR, num_pres=5, a_plus=0.01)
    results.append({**r, "test": "stdp_single"})

    print("\nB2: A+ sweep (10 pairs, 5 presentations)")
    for ap in [0.001, 0.005, 0.01, 0.05, 0.1]:
        r = test_stdp_v2(N, S, 10, FR, num_pres=5, a_plus=ap)
        results.append({**r, "test": f"stdp_ap_{ap}"})

    print("\nB3: Presentation count (10 pairs, A+=0.01)")
    for pres in [1, 3, 5, 10, 20]:
        r = test_stdp_v2(N, S, 10, FR, num_pres=pres, a_plus=0.01)
        results.append({**r, "test": f"stdp_pres_{pres}"})

    print("\nB4: Scaling pairs (A+=0.01, 5 presentations)")
    for n in [1, 5, 10, 20, 50]:
        r = test_stdp_v2(N, S, n, FR, num_pres=5, a_plus=0.01)
        results.append({**r, "test": f"stdp_pairs_{n}"})

    # Save.  Ensure the output directory exists first — on a fresh checkout
    # doc/ may be missing, and failing here would discard the entire run's
    # results after all experiments have already completed.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02b_results.json", "w") as f:
        # default=float coerces numpy scalars in the results to JSON numbers.
        json.dump(results, f, indent=2, default=float)

    # Best from each method
    print("\n" + "=" * 60)
    hebb_best = max([r for r in results if r["method"] == "direct_hebbian"],
                    key=lambda x: x["disc"], default=None)
    stdp_best = max([r for r in results if r["method"] == "stdp_v2"],
                    key=lambda x: x["disc"], default=None)

    if hebb_best:
        print(f"Best Hebbian: {hebb_best['test']} — "
              f"Correct={hebb_best['correct']:.4f}, Disc={hebb_best['disc']:.4f}")
    if stdp_best:
        print(f"Best STDP: {stdp_best['test']} — "
              f"Correct={stdp_best['correct']:.4f}, Disc={stdp_best['disc']:.4f}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|