Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
222 lines · 7.2 KiB · Python
"""Experiment 2: STDP Associative Recall.
|
|
|
|
Core question: Can STDP learn associations between spike patterns,
|
|
such that presenting a cue recalls the correct target?
|
|
|
|
Test protocol:
|
|
1. Generate N pairs of (cue, target) spike patterns
|
|
2. Train STDP network on all pairs
|
|
3. Present each cue and measure similarity between recall and correct target
|
|
4. Measure interference: does recall of pair K degrade after learning pair K+1?
|
|
|
|
This is the make-or-break experiment for the whole approach.
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
from nuonuo.memory import STDPMemoryNetwork
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def spike_similarity(a, b):
    """Cosine similarity between two spike trains, compared as flat vectors.

    Both trains are flattened across time and neurons, so this measures
    exact spike-for-spike overlap rather than rate similarity.
    Returns 0.0 when either train is completely silent (no direction).
    """
    u = a.flatten().float()
    v = b.flatten().float()
    # A zero vector has no defined cosine direction — treat as dissimilar.
    if u.norm() == 0 or v.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(u.unsqueeze(0), v.unsqueeze(0))
    return sim.item()
|
|
|
|
|
|
def firing_rate_similarity(a, b):
    """Cosine similarity between the per-neuron firing-rate vectors of two trains.

    Averages each train over its time axis (dim 0) first, so two trains that
    fire the same neurons at different times still score as similar.
    Returns 0.0 when either rate vector is all-zero.
    """
    # Mean over time steps -> one average rate per neuron.
    rate_a = a.float().mean(dim=0)
    rate_b = b.float().mean(dim=0)
    if rate_a.norm() == 0 or rate_b.norm() == 0:
        return 0.0
    score = nn.functional.cosine_similarity(
        rate_a.unsqueeze(0), rate_b.unsqueeze(0)
    )
    return score.item()
|
|
|
|
|
|
def generate_spike_pattern(num_steps, num_neurons, firing_rate=0.05, device="cuda"):
    """Draw a random sparse binary spike train of shape (num_steps, num_neurons).

    Each (step, neuron) entry fires independently with probability
    `firing_rate`; the result is a float tensor of 0.0/1.0 values.
    """
    noise = torch.rand(num_steps, num_neurons, device=device)
    return (noise < firing_rate).float()
|
|
|
|
|
|
def run_recall_test(num_neurons, num_steps, num_pairs, firing_rate,
                    num_presentations, a_plus, a_minus):
    """Measure associative recall quality for one parameter setting.

    Trains a fresh STDPMemoryNetwork on `num_pairs` random (cue, target)
    spike-pattern pairs, then replays every cue and scores the recalled
    activity against the correct target versus all other targets using
    firing-rate cosine similarity.  Returns a dict of the configuration
    plus the measured metrics, weight statistics, and learning wall time.
    """
    print(f" neurons={num_neurons}, steps={num_steps}, pairs={num_pairs}, "
          f"FR={firing_rate}, pres={num_presentations}, "
          f"A+={a_plus}, A-={a_minus}")

    net = STDPMemoryNetwork(
        num_neurons=num_neurons,
        a_plus=a_plus,
        a_minus=a_minus,
    ).to(DEVICE)

    # One random sparse (cue, target) pair per association to learn.
    pairs = [
        (generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE),
         generate_spike_pattern(num_steps, num_neurons, firing_rate, DEVICE))
        for _ in range(num_pairs)
    ]
    cues = [cue for cue, _ in pairs]
    targets = [target for _, target in pairs]

    # Train on every association, timing the whole learning pass.
    t0 = time.time()
    for cue, target in pairs:
        net.learn_association(cue, target, num_presentations=num_presentations)
    learn_time = time.time() - t0

    # Recall phase: score each cue's recall against right vs. wrong targets.
    correct_sims = []
    wrong_sims = []
    for i, cue in enumerate(cues):
        recalled = net.recall(cue, num_recall_steps=num_steps)

        # Similarity to the intended target.
        correct_sims.append(firing_rate_similarity(recalled, targets[i]))

        # Average similarity to every *other* target (interference measure).
        others = [firing_rate_similarity(recalled, targets[j])
                  for j in range(num_pairs) if j != i]
        if others:
            wrong_sims.append(np.mean(others))

    mean_correct = np.mean(correct_sims)
    mean_wrong = np.mean(wrong_sims) if wrong_sims else 0
    discrimination = mean_correct - mean_wrong

    w_stats = net.get_weight_stats()
    # NOTE(review): firing rate is sampled from the *last* recalled pattern
    # only — presumably a cheap activity-level diagnostic; confirm intent.
    recall_fr = recalled.mean().item() if len(correct_sims) > 0 else 0

    print(f" Correct sim: {mean_correct:.4f}, Wrong sim: {mean_wrong:.4f}, "
          f"Discrimination: {discrimination:.4f}")
    print(f" Recall FR: {recall_fr:.4f}, W stats: mean={w_stats['abs_mean']:.4f}, "
          f"sparsity={w_stats['sparsity']:.2f}")
    print(f" Learn time: {learn_time:.1f}s")

    return {
        "num_neurons": num_neurons,
        "num_steps": num_steps,
        "num_pairs": num_pairs,
        "firing_rate": firing_rate,
        "num_presentations": num_presentations,
        "a_plus": a_plus,
        "a_minus": a_minus,
        "mean_correct_sim": mean_correct,
        "mean_wrong_sim": mean_wrong,
        "discrimination": discrimination,
        "correct_sims": correct_sims,
        "recall_firing_rate": recall_fr,
        "weight_stats": w_stats,
        "learn_time": learn_time,
    }
|
|
|
|
|
|
def main():
    """Run the full experiment-2 sweep and write results to JSON.

    Six test groups, each varying one knob around a shared baseline
    (2048 neurons, 64 steps, 10 pairs, FR=0.05, 5 presentations,
    A+=0.005, A-=0.006): single-pair sanity check, pair count, STDP
    learning rate, firing rate, presentation count, and network width.
    Results are appended to a flat list, dumped to
    RESULTS_DIR/exp02_results.json, and summarized as a table on stdout.
    """
    print("=" * 60)
    print("Experiment 2: STDP Associative Recall")
    print("=" * 60)

    # Shared baseline configuration; each sweep overrides exactly one knob.
    base = dict(
        num_neurons=2048, num_steps=64, num_pairs=10,
        firing_rate=0.05, num_presentations=5,
        a_plus=0.005, a_minus=0.006,
    )

    results = []

    # Test 1: Baseline — can it learn even 1 pair?
    print("\n--- Test 1: Single pair (sanity check) ---")
    r = run_recall_test(**{**base, "num_pairs": 1})
    results.append({**r, "test": "single_pair"})

    # Test 2: Vary number of pairs (capacity / interference).
    print("\n--- Test 2: Scaling pairs ---")
    for n_pairs in [5, 10, 20, 50]:
        r = run_recall_test(**{**base, "num_pairs": n_pairs})
        results.append({**r, "test": f"pairs_{n_pairs}"})

    # Test 3: Vary STDP learning rates (A- kept at 1.2x A+ for net depression).
    print("\n--- Test 3: STDP learning rate sweep ---")
    for a_plus in [0.001, 0.005, 0.01, 0.05]:
        r = run_recall_test(**{**base, "a_plus": a_plus, "a_minus": a_plus * 1.2})
        results.append({**r, "test": f"lr_{a_plus}"})

    # Test 4: Vary firing rate (pattern sparsity).
    print("\n--- Test 4: Firing rate sweep ---")
    for fr in [0.02, 0.05, 0.10, 0.20]:
        r = run_recall_test(**{**base, "firing_rate": fr})
        results.append({**r, "test": f"fr_{fr}"})

    # Test 5: More presentations per association.
    print("\n--- Test 5: Presentation count ---")
    for n_pres in [1, 3, 5, 10, 20]:
        r = run_recall_test(**{**base, "num_presentations": n_pres})
        results.append({**r, "test": f"pres_{n_pres}"})

    # Test 6: Wider network.
    print("\n--- Test 6: Network width ---")
    for neurons in [1024, 2048, 4096, 8192]:
        r = run_recall_test(**{**base, "num_neurons": neurons})
        results.append({**r, "test": f"width_{neurons}"})

    # Save results.  Coerce per-pair similarities to plain floats up front;
    # `default=float` catches remaining numpy scalars in the nested dicts.
    for r in results:
        r["correct_sims"] = [float(x) for x in r["correct_sims"]]
    # Fix: previously crashed with FileNotFoundError when doc/ did not exist.
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02_results.json", "w") as f:
        json.dump(results, f, indent=2, default=float)

    # Summary table.
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Test':<15} {'Correct':>8} {'Wrong':>8} {'Discrim':>8} {'RecallFR':>8}")
    print("-" * 50)
    for r in results:
        print(f"{r['test']:<15} {r['mean_correct_sim']:>8.4f} "
              f"{r['mean_wrong_sim']:>8.4f} {r['discrimination']:>8.4f} "
              f"{r['recall_firing_rate']:>8.4f}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|