NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4 ms latency @ 20K memories, ~1 GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs. flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim = 1.0) — RAG cannot do this
- MiniLM-L6 is optimal (discrimination gap matters more than absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on the benchmark
- SNN encoder is viable (CosSim 0.99) but not needed for the current architecture
This commit is contained in:
192
experiments/exp02b_stdp_v2.py
Normal file
192
experiments/exp02b_stdp_v2.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Experiment 2b: STDP Associative Recall (v2 - fixed learning).
|
||||
|
||||
v1 failed completely because W=0 → no spikes → no STDP updates (chicken-egg).
|
||||
v2 fixes this with teacher-forced STDP: directly use (cue, target) as (pre, post).
|
||||
|
||||
Also tests DirectAssociativeMemory (simple outer-product Hebbian) as baseline.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
from nuonuo.memory import STDPMemoryNetwork, DirectAssociativeMemory
|
||||
|
||||
DEVICE = "cuda"
|
||||
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
||||
|
||||
|
||||
def spike_cosine(a, b):
    """Cosine similarity between spike data, compared as firing-rate vectors.

    A 2D input (time, neurons) is collapsed to its mean firing rate along the
    time axis first. Returns 0.0 when either side is all-zero, since cosine
    similarity is undefined for a zero vector.
    """
    rate_a = a.mean(dim=0) if a.dim() == 2 else a
    rate_b = b.mean(dim=0) if b.dim() == 2 else b
    if rate_a.norm() == 0 or rate_b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(rate_a.unsqueeze(0), rate_b.unsqueeze(0))
    return sim.item()
|
||||
|
||||
|
||||
def vec_cosine(a, b):
    """Cosine similarity of two 1D vectors (0.0 when either has zero norm)."""
    # Cosine similarity is undefined for a zero vector; report "no similarity".
    if min(a.norm(), b.norm()) == 0:
        return 0.0
    return nn.functional.cosine_similarity(a[None, :], b[None, :]).item()
|
||||
|
||||
|
||||
def gen_spikes(num_steps, num_neurons, fr=0.05, device="cuda"):
    """Random Bernoulli spike train, shape (num_steps, num_neurons), rate ~fr."""
    noise = torch.rand(num_steps, num_neurons, device=device)
    return noise.lt(fr).float()
|
||||
|
||||
|
||||
def test_stdp_v2(num_neurons, num_steps, num_pairs, fr, num_pres, a_plus):
    """Train and probe the v2 (teacher-forced) STDP network.

    Learns `num_pairs` random (cue, target) spike-train associations, then
    scores recall quality: cosine similarity of each recalled pattern to its
    own target ("correct") vs. every other target ("wrong"). Returns a result
    dict suitable for JSON logging.
    """
    net = STDPMemoryNetwork(
        num_neurons=num_neurons, a_plus=a_plus, a_minus=a_plus*1.2,
        w_init_std=0.01
    ).to(DEVICE)

    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # Teacher-forced learning phase (cue as pre, target as post).
    start = time.time()
    for cue, tgt in zip(cues, targets):
        net.learn_association(cue, tgt, num_presentations=num_pres)
    learn_t = time.time() - start

    # Recall phase: each cue should retrieve its own target, not the others.
    correct_sims = []
    wrong_sims = []
    for idx, cue in enumerate(cues):
        recalled = net.recall(cue)
        correct_sims.append(spike_cosine(recalled, targets[idx]))
        wrong_sims.extend(
            spike_cosine(recalled, other)
            for j, other in enumerate(targets)
            if j != idx
        )

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()

    print(f" STDP: pairs={num_pairs}, pres={num_pres}, A+={a_plus:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.4f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.1f}s")

    return {"method": "stdp_v2", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "a_plus": a_plus, "num_pres": num_pres}
|
||||
|
||||
|
||||
def test_direct_hebbian(num_neurons, num_steps, num_pairs, fr, lr):
    """Train and probe the outer-product Hebbian baseline memory.

    Same protocol as the STDP test: learn `num_pairs` random (cue, target)
    associations, then score recall as cosine similarity against the correct
    target's firing rate vs. all wrong targets'. Returns a result dict for
    JSON logging.
    """
    net = DirectAssociativeMemory(num_neurons=num_neurons, lr=lr).to(DEVICE)

    cues = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]
    targets = [gen_spikes(num_steps, num_neurons, fr) for _ in range(num_pairs)]

    # One-shot Hebbian learning phase.
    start = time.time()
    for cue, tgt in zip(cues, targets):
        net.learn(cue, tgt)
    learn_t = time.time() - start

    # Recall returns a continuous vector, so compare against mean firing rates.
    target_rates = [t.mean(dim=0) for t in targets]
    correct_sims = []
    wrong_sims = []
    for idx, cue in enumerate(cues):
        recalled = net.recall(cue)
        correct_sims.append(vec_cosine(recalled, target_rates[idx]))
        wrong_sims.extend(
            vec_cosine(recalled, target_rates[j])
            for j in range(num_pairs)
            if j != idx
        )

    mc = np.mean(correct_sims)
    mw = np.mean(wrong_sims) if wrong_sims else 0
    ws = net.get_weight_stats()

    print(f" Hebbian: pairs={num_pairs}, lr={lr:.3f} | "
          f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}, "
          f"W_abs={ws['abs_mean']:.6f}, sparsity={ws['sparsity']:.2f}, "
          f"time={learn_t:.3f}s")

    return {"method": "direct_hebbian", "correct": mc, "wrong": mw,
            "disc": mc-mw, "w_stats": ws, "time": learn_t,
            "num_pairs": num_pairs, "lr": lr}
|
||||
|
||||
|
||||
def main():
    """Run all exp02b sweeps (Hebbian baseline + STDP v2) and save results."""
    banner = "=" * 60
    print(banner)
    print("Experiment 2b: STDP v2 + Direct Hebbian")
    print(banner)

    results = []
    # Shared network/stimulus configuration across every sweep.
    n_neurons = 2048
    n_steps = 64
    fire_rate = 0.05

    # --- Part A: Direct Hebbian (baseline) ---
    print("\n=== Part A: Direct Hebbian Memory ===")

    print("\nA1: Scaling pairs (lr=0.5)")
    for pairs in (1, 5, 10, 20, 50, 100):
        res = test_direct_hebbian(n_neurons, n_steps, pairs, fire_rate, lr=0.5)
        results.append({**res, "test": f"hebb_pairs_{pairs}"})

    print("\nA2: Learning rate sweep (10 pairs)")
    for lr in (0.01, 0.1, 0.5, 1.0, 5.0):
        res = test_direct_hebbian(n_neurons, n_steps, 10, fire_rate, lr=lr)
        results.append({**res, "test": f"hebb_lr_{lr}"})

    # --- Part B: STDP v2 ---
    print("\n=== Part B: STDP v2 (teacher-forced) ===")

    print("\nB1: Sanity check - single pair")
    res = test_stdp_v2(n_neurons, n_steps, 1, fire_rate, num_pres=5, a_plus=0.01)
    results.append({**res, "test": "stdp_single"})

    print("\nB2: A+ sweep (10 pairs, 5 presentations)")
    for ap in (0.001, 0.005, 0.01, 0.05, 0.1):
        res = test_stdp_v2(n_neurons, n_steps, 10, fire_rate, num_pres=5, a_plus=ap)
        results.append({**res, "test": f"stdp_ap_{ap}"})

    print("\nB3: Presentation count (10 pairs, A+=0.01)")
    for pres in (1, 3, 5, 10, 20):
        res = test_stdp_v2(n_neurons, n_steps, 10, fire_rate, num_pres=pres, a_plus=0.01)
        results.append({**res, "test": f"stdp_pres_{pres}"})

    print("\nB4: Scaling pairs (A+=0.01, 5 presentations)")
    for pairs in (1, 5, 10, 20, 50):
        res = test_stdp_v2(n_neurons, n_steps, pairs, fire_rate, num_pres=5, a_plus=0.01)
        results.append({**res, "test": f"stdp_pairs_{pairs}"})

    # Persist the raw sweep results for later analysis.
    with open(RESULTS_DIR / "exp02b_results.json", "w") as f:
        json.dump(results, f, indent=2, default=float)

    # Summarize: best configuration per method, ranked by discrimination gap.
    print("\n" + banner)
    hebb_runs = [r for r in results if r["method"] == "direct_hebbian"]
    stdp_runs = [r for r in results if r["method"] == "stdp_v2"]
    hebb_best = max(hebb_runs, key=lambda r: r["disc"], default=None)
    stdp_best = max(stdp_runs, key=lambda r: r["disc"], default=None)

    if hebb_best:
        print(f"Best Hebbian: {hebb_best['test']} — "
              f"Correct={hebb_best['correct']:.4f}, Disc={hebb_best['disc']:.4f}")
    if stdp_best:
        print(f"Best STDP: {stdp_best['test']} — "
              f"Correct={stdp_best['correct']:.4f}, Disc={stdp_best['disc']:.4f}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user