Files
nuonuo/experiments/exp02b_stdp_v2.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

193 lines
6.4 KiB
Python

"""Experiment 2b: STDP Associative Recall (v2 - fixed learning).
v1 failed completely because W=0 → no spikes → no STDP updates (a chicken-and-egg problem).
v2 fixes this with teacher-forced STDP: directly use (cue, target) as (pre, post).
Also tests DirectAssociativeMemory (simple outer-product Hebbian) as baseline.
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.memory import STDPMemoryNetwork, DirectAssociativeMemory
DEVICE = "cuda"  # all experiments assume a CUDA device; no CPU fallback is provided
RESULTS_DIR = Path(__file__).parent.parent / "doc"  # JSON results land next to the docs
def spike_cosine(a, b):
    """Cosine similarity of two spike tensors, compared as mean firing rates.

    2D inputs of shape (steps, neurons) are collapsed to per-neuron rates by
    averaging over time; 1D inputs are used as-is. Returns 0.0 whenever either
    vector has zero norm (cosine would be undefined there).
    """
    rate_a = a.mean(dim=0) if a.dim() == 2 else a
    rate_b = b.mean(dim=0) if b.dim() == 2 else b
    if rate_a.norm() == 0 or rate_b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(rate_a.unsqueeze(0), rate_b.unsqueeze(0))
    return sim.item()
def vec_cosine(a, b):
    """Cosine similarity of two 1D tensors; 0.0 if either has zero norm."""
    if min(a.norm(), b.norm()) == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def gen_spikes(num_steps, num_neurons, fr=0.05, device="cuda"):
    """Random Bernoulli spike raster of shape (num_steps, num_neurons).

    Each entry fires (1.0) independently with probability `fr`, else 0.0.
    """
    uniform_field = torch.rand(num_steps, num_neurons, device=device)
    return (uniform_field < fr).float()
def test_stdp_v2(num_neurons, num_steps, num_pairs, fr, num_pres, a_plus):
    """Train and score the v2 teacher-forced STDP network on random pairs.

    Builds `num_pairs` random (cue, target) spike rasters, teaches each
    association `num_pres` times, then measures recall quality as the mean
    cosine similarity to the correct target vs all wrong targets. Returns a
    result dict for the experiment log.
    """
    net = STDPMemoryNetwork(
        num_neurons=num_neurons, a_plus=a_plus, a_minus=a_plus*1.2,
        w_init_std=0.01
    ).to(DEVICE)
    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # Training: each pair is presented teacher-forced num_pres times.
    start = time.time()
    for cue, target in zip(cues, targets):
        net.learn_association(cue, target, num_presentations=num_pres)
    learn_t = time.time() - start

    # Recall: compare each recalled pattern against its own target and
    # against every other target (discrimination = correct - wrong).
    correct_sims, wrong_sims = [], []
    for i, cue in enumerate(cues):
        recalled = net.recall(cue)
        correct_sims.append(spike_cosine(recalled, targets[i]))
        wrong_sims.extend(
            spike_cosine(recalled, other)
            for j, other in enumerate(targets) if j != i
        )
    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()
    print(f" STDP: pairs={num_pairs}, pres={num_pres}, A+={a_plus:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.4f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.1f}s")
    return {"method": "stdp_v2", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "a_plus": a_plus, "num_pres": num_pres}
def test_direct_hebbian(num_neurons, num_steps, num_pairs, fr, lr):
    """Train and score the direct outer-product Hebbian memory.

    Same protocol as the STDP test, except recall yields a continuous rate
    vector, so targets are reduced to mean firing rates before comparison.
    Returns a result dict for the experiment log.
    """
    net = DirectAssociativeMemory(num_neurons=num_neurons, lr=lr).to(DEVICE)
    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # Training: one-shot Hebbian storage of every (cue, target) pair.
    start = time.time()
    for cue, target in zip(cues, targets):
        net.learn(cue, target)
    learn_t = time.time() - start

    # Recall: compare each recalled rate vector against its own target's
    # mean rate and against every other target's mean rate.
    correct_sims, wrong_sims = [], []
    for i, cue in enumerate(cues):
        recalled = net.recall(cue)  # continuous vector
        correct_sims.append(vec_cosine(recalled, targets[i].mean(dim=0)))
        wrong_sims.extend(
            vec_cosine(recalled, other.mean(dim=0))
            for j, other in enumerate(targets) if j != i
        )
    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()
    print(f" Hebbian: pairs={num_pairs}, lr={lr:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.6f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.3f}s")
    return {"method": "direct_hebbian", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "lr": lr}
def main():
    """Run experiment 2b: Direct Hebbian baseline (Part A) vs STDP v2 (Part B).

    Sweeps pair counts, learning rates, A+ values, and presentation counts,
    writes all results to doc/exp02b_results.json, and prints the best
    configuration (by discrimination score) for each method.
    """
    print("=" * 60)
    print("Experiment 2b: STDP v2 + Direct Hebbian")
    print("=" * 60)
    results = []
    N = 2048   # neurons
    S = 64     # time steps per spike raster
    FR = 0.05  # firing rate

    # --- Part A: Direct Hebbian (baseline) ---
    print("\n=== Part A: Direct Hebbian Memory ===")
    print("\nA1: Scaling pairs (lr=0.5)")
    for n in [1, 5, 10, 20, 50, 100]:
        r = test_direct_hebbian(N, S, n, FR, lr=0.5)
        results.append({**r, "test": f"hebb_pairs_{n}"})
    print("\nA2: Learning rate sweep (10 pairs)")
    for lr in [0.01, 0.1, 0.5, 1.0, 5.0]:
        r = test_direct_hebbian(N, S, 10, FR, lr=lr)
        results.append({**r, "test": f"hebb_lr_{lr}"})

    # --- Part B: STDP v2 ---
    print("\n=== Part B: STDP v2 (teacher-forced) ===")
    print("\nB1: Sanity check - single pair")
    r = test_stdp_v2(N, S, 1, FR, num_pres=5, a_plus=0.01)
    results.append({**r, "test": "stdp_single"})
    print("\nB2: A+ sweep (10 pairs, 5 presentations)")
    for ap in [0.001, 0.005, 0.01, 0.05, 0.1]:
        r = test_stdp_v2(N, S, 10, FR, num_pres=5, a_plus=ap)
        results.append({**r, "test": f"stdp_ap_{ap}"})
    print("\nB3: Presentation count (10 pairs, A+=0.01)")
    for pres in [1, 3, 5, 10, 20]:
        r = test_stdp_v2(N, S, 10, FR, num_pres=pres, a_plus=0.01)
        results.append({**r, "test": f"stdp_pres_{pres}"})
    print("\nB4: Scaling pairs (A+=0.01, 5 presentations)")
    for n in [1, 5, 10, 20, 50]:
        r = test_stdp_v2(N, S, n, FR, num_pres=5, a_plus=0.01)
        results.append({**r, "test": f"stdp_pairs_{n}"})

    # Save (create the results directory if a fresh checkout lacks it;
    # default=float coerces numpy scalars so json can serialize them)
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02b_results.json", "w") as f:
        json.dump(results, f, indent=2, default=float)

    # Best from each method (highest discrimination = correct - wrong)
    print("\n" + "=" * 60)
    hebb_best = max([r for r in results if r["method"] == "direct_hebbian"],
                    key=lambda x: x["disc"], default=None)
    stdp_best = max([r for r in results if r["method"] == "stdp_v2"],
                    key=lambda x: x["disc"], default=None)
    # BUG FIX: the original implicitly concatenated two f-strings with no
    # separator, printing e.g. "Best Hebbian: hebb_pairs_100Correct=...".
    if hebb_best:
        print(f"Best Hebbian: {hebb_best['test']} | "
              f"Correct={hebb_best['correct']:.4f}, Disc={hebb_best['disc']:.4f}")
    if stdp_best:
        print(f"Best STDP: {stdp_best['test']} | "
              f"Correct={stdp_best['correct']:.4f}, Disc={stdp_best['disc']:.4f}")
if __name__ == "__main__":
    main()