nuonuo/experiments/exp02_stdp_recall.py
d923aa1e31 Fam Zheng, 2026-04-07 10:37:24 +01:00
NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle; sketched below)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM
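
A minimal sketch of the single-hop two-stage retrieval (illustrative names,
not the repo's API; assumes unit-normalized memory embeddings M [n, d] and
a unit-normalized cue q [d]):

import torch

def two_stage_retrieve(q, M, top_k=20, beta=8.0, settle_steps=3):
    """Stage 1: top-k NN prefilter; Stage 2: modern-Hopfield softmax settle."""
    idx = (M @ q).topk(top_k).indices              # coarse candidate filter
    C = M[idx]                                     # [top_k, d] candidates
    state = q
    for _ in range(settle_steps):
        attn = torch.softmax(beta * (C @ state), dim=0)  # attention over candidates
        state = attn @ C                           # settle toward stored pattern
        state = state / state.norm()
    return idx[(C @ state).argmax()], state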

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity (see the sketch after this list)
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
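
The capacity finding leans on sparsification. A minimal sketch of WTA pattern
separation plus the Hebbian outer-product store, assuming dense d-dimensional
patterns (helper names are hypothetical, not the repo's API):

import torch

def wta(x, k=32):
    """k-winners-take-all: keep the k strongest units, zero the rest."""
    out = torch.zeros_like(x)
    top = x.topk(k).indices
    out[top] = x[top]
    return out

def hebbian_associate(W, cue, target, k=32):
    """Outer-product Hebbian store on WTA-sparsified patterns; near-orthogonal
    sparse codes are what limit interference at 20K+ associations."""
    # recall: target_hat = wta(W @ cue, k); multi-hop chains feed it back in
    return W + torch.outer(wta(target, k), wta(cue, k))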

"""Experiment 2: STDP Associative Recall.
Core question: Can STDP learn associations between spike patterns,
such that presenting a cue recalls the correct target?
Test protocol:
1. Generate N pairs of (cue, target) spike patterns
2. Train STDP network on all pairs
3. Present each cue and measure similarity between recall and correct target
4. Measure interference: does recall of pair K degrade after learning pair K+1?
This is the make-or-break experiment for the whole approach.
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from nuonuo.memory import STDPMemoryNetwork

# Fall back to CPU when CUDA is unavailable so the script still runs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"


def spike_similarity(a, b):
    """Cosine similarity between two spike trains (flattened)."""
    a_flat = a.flatten().float()
    b_flat = b.flatten().float()
    if a_flat.norm() == 0 or b_flat.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(
        a_flat.unsqueeze(0), b_flat.unsqueeze(0)
    ).item()


def firing_rate_similarity(a, b):
    """Similarity based on per-neuron firing rates."""
    fr_a = a.float().mean(dim=0)
    fr_b = b.float().mean(dim=0)
    if fr_a.norm() == 0 or fr_b.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(
        fr_a.unsqueeze(0), fr_b.unsqueeze(0)
    ).item()


def generate_spike_pattern(num_steps, num_neurons, firing_rate=0.05, device="cuda"):
    """Generate a random sparse spike pattern."""
    return (torch.rand(num_steps, num_neurons, device=device) < firing_rate).float()
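

# ---------------------------------------------------------------------------
# For orientation: a_plus / a_minus below parameterize a pair-based STDP rule.
# This sketch is illustrative only (it is NOT the STDPMemoryNetwork
# implementation), assuming standard exponential eligibility traces:
# a post-spike potentiates by a_plus * pre_trace (pre fired just before post),
# a pre-spike depresses by a_minus * post_trace (post fired just before pre).
def _stdp_step_sketch(w, pre_spk, post_spk, pre_tr, post_tr,
                      a_plus=0.005, a_minus=0.006, tau=20.0):
    """One timestep of trace-based STDP. Shapes: w [N_pre, N_post],
    spike and trace vectors [N_pre] / [N_post]."""
    decay = float(np.exp(-1.0 / tau))
    pre_tr = pre_tr * decay + pre_spk                  # recent pre activity
    post_tr = post_tr * decay + post_spk               # recent post activity
    w = w + a_plus * torch.outer(pre_tr, post_spk)     # LTP: pre-before-post
    w = w - a_minus * torch.outer(pre_spk, post_tr)    # LTD: post-before-pre
    return w, pre_tr, post_tr
# ---------------------------------------------------------------------------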


def run_recall_test(num_neurons, num_steps, num_pairs, firing_rate,
                    num_presentations, a_plus, a_minus):
    """Test associative recall with given parameters."""
    print(f" neurons={num_neurons}, steps={num_steps}, pairs={num_pairs}, "
          f"FR={firing_rate}, pres={num_presentations}, "
          f"A+={a_plus}, A-={a_minus}")
    net = STDPMemoryNetwork(
        num_neurons=num_neurons,
        a_plus=a_plus,
        a_minus=a_minus,
    ).to(DEVICE)
    # Generate pattern pairs
    cues = []
    targets = []
    for _ in range(num_pairs):
        cue = generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE)
        target = generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE)
        cues.append(cue)
        targets.append(target)
    # Learn all pairs
    t0 = time.time()
    for i in range(num_pairs):
        net.learn_association(cues[i], targets[i], num_presentations=num_presentations)
    learn_time = time.time() - t0
    # Test recall
    correct_sims = []
    wrong_sims = []
    for i in range(num_pairs):
        recalled = net.recall(cues[i], num_recall_steps=num_steps)
        # Similarity to correct target
        correct_sim = firing_rate_similarity(recalled, targets[i])
        correct_sims.append(correct_sim)
        # Similarity to wrong targets (average)
        wrong_sim_list = []
        for j in range(num_pairs):
            if j != i:
                wrong_sim_list.append(firing_rate_similarity(recalled, targets[j]))
        if wrong_sim_list:
            wrong_sims.append(np.mean(wrong_sim_list))
    mean_correct = np.mean(correct_sims)
    mean_wrong = np.mean(wrong_sims) if wrong_sims else 0
    discrimination = mean_correct - mean_wrong
    w_stats = net.get_weight_stats()
    # Firing rate of the last recalled pattern (a representative sample)
    recall_fr = recalled.mean().item() if len(correct_sims) > 0 else 0
    print(f" Correct sim: {mean_correct:.4f}, Wrong sim: {mean_wrong:.4f}, "
          f"Discrimination: {discrimination:.4f}")
    print(f" Recall FR: {recall_fr:.4f}, W stats: mean={w_stats['abs_mean']:.4f}, "
          f"sparsity={w_stats['sparsity']:.2f}")
    print(f" Learn time: {learn_time:.1f}s")
    return {
        "num_neurons": num_neurons,
        "num_steps": num_steps,
        "num_pairs": num_pairs,
        "firing_rate": firing_rate,
        "num_presentations": num_presentations,
        "a_plus": a_plus,
        "a_minus": a_minus,
        "mean_correct_sim": mean_correct,
        "mean_wrong_sim": mean_wrong,
        "discrimination": discrimination,
        "correct_sims": correct_sims,
        "recall_firing_rate": recall_fr,
        "weight_stats": w_stats,
        "learn_time": learn_time,
    }


def main():
    print("=" * 60)
    print("Experiment 2: STDP Associative Recall")
    print("=" * 60)
    results = []

    # Test 1: Baseline — can it learn even 1 pair?
    print("\n--- Test 1: Single pair (sanity check) ---")
    r = run_recall_test(
        num_neurons=2048, num_steps=64, num_pairs=1,
        firing_rate=0.05, num_presentations=5,
        a_plus=0.005, a_minus=0.006,
    )
    results.append({**r, "test": "single_pair"})

    # Test 2: Vary number of pairs
    print("\n--- Test 2: Scaling pairs ---")
    for n_pairs in [5, 10, 20, 50]:
        r = run_recall_test(
            num_neurons=2048, num_steps=64, num_pairs=n_pairs,
            firing_rate=0.05, num_presentations=5,
            a_plus=0.005, a_minus=0.006,
        )
        results.append({**r, "test": f"pairs_{n_pairs}"})

    # Test 3: Vary STDP learning rates
    print("\n--- Test 3: STDP learning rate sweep ---")
    for a_plus in [0.001, 0.005, 0.01, 0.05]:
        r = run_recall_test(
            num_neurons=2048, num_steps=64, num_pairs=10,
            firing_rate=0.05, num_presentations=5,
            a_plus=a_plus, a_minus=a_plus * 1.2,
        )
        results.append({**r, "test": f"lr_{a_plus}"})

    # Test 4: Vary firing rate
    print("\n--- Test 4: Firing rate sweep ---")
    for fr in [0.02, 0.05, 0.10, 0.20]:
        r = run_recall_test(
            num_neurons=2048, num_steps=64, num_pairs=10,
            firing_rate=fr, num_presentations=5,
            a_plus=0.005, a_minus=0.006,
        )
        results.append({**r, "test": f"fr_{fr}"})

    # Test 5: More presentations
    print("\n--- Test 5: Presentation count ---")
    for n_pres in [1, 3, 5, 10, 20]:
        r = run_recall_test(
            num_neurons=2048, num_steps=64, num_pairs=10,
            firing_rate=0.05, num_presentations=n_pres,
            a_plus=0.005, a_minus=0.006,
        )
        results.append({**r, "test": f"pres_{n_pres}"})

    # Test 6: Wider network
    print("\n--- Test 6: Network width ---")
    for neurons in [1024, 2048, 4096, 8192]:
        r = run_recall_test(
            num_neurons=neurons, num_steps=64, num_pairs=10,
            firing_rate=0.05, num_presentations=5,
            a_plus=0.005, a_minus=0.006,
        )
        results.append({**r, "test": f"width_{neurons}"})

    # Save results
    for r in results:
        r["correct_sims"] = [float(x) for x in r["correct_sims"]]
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02_results.json", "w") as f:
        json.dump(results, f, indent=2, default=float)

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Test':<15} {'Correct':>8} {'Wrong':>8} {'Discrim':>8} {'RecallFR':>8}")
    print("-" * 50)
    for r in results:
        print(f"{r['test']:<15} {r['mean_correct_sim']:>8.4f} "
              f"{r['mean_wrong_sim']:>8.4f} {r['discrimination']:>8.4f} "
              f"{r['recall_firing_rate']:>8.4f}")


if __name__ == "__main__":
    main()
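
# To reproduce (assuming the repo layout implied by the sys.path hack above,
# run from the repository root):
#   python nuonuo/experiments/exp02_stdp_recall.py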