NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments
(16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
306
experiments/exp13_snn_hopfield.py
Normal file
306
experiments/exp13_snn_hopfield.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Experiment P5: SNN-native Hopfield (spike-based attention).
|
||||
|
||||
Goal: Implement Hopfield-like attractor dynamics using LIF neurons.
|
||||
|
||||
The connection: Hopfield softmax attention with inverse temperature β
|
||||
is equivalent to a Boltzmann distribution at temperature 1/β.
|
||||
In SNN terms: β maps to membrane time constant / threshold ratio.
|
||||
|
||||
Approach: Replace softmax(β * q @ K^T) @ V with:
|
||||
1. Encode query as spike train
|
||||
2. Feed through recurrent LIF network with stored patterns as synaptic weights
|
||||
3. Network settles to attractor (nearest stored pattern)
|
||||
4. Read out associated target
|
||||
|
||||
This is closer to biological CA3 recurrent dynamics.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import snntorch as snn
|
||||
import numpy as np
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
def cosine(a, b):
    """Return the scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
|
||||
|
||||
|
||||
class SNNHopfield(nn.Module):
    """Spike-based Hopfield network.

    Architecture:
    - Input layer: converts query embedding to current injection
    - Recurrent layer: LIF neurons with Hopfield-like connection weights
    - Readout: spike rates → attention weights → target embedding

    The recurrent weights are set (not trained) based on stored patterns,
    making this a "configured" SNN, not a "trained" one.
    """

    def __init__(self, dim, beta=0.9, threshold=1.0, num_steps=50):
        """Configure the network.

        Args:
            dim: Embedding dimensionality (one LIF neuron per dimension).
            beta: LIF membrane decay factor (snntorch `Leaky` beta).
            threshold: LIF firing threshold.
            num_steps: Simulation steps per recall (first half driven by
                the query current, second half free-running).
        """
        super().__init__()
        self.dim = dim
        self.num_steps = num_steps
        self.beta_lif = beta  # LIF membrane decay
        self.threshold = threshold

        self.lif = snn.Leaky(beta=beta, threshold=threshold)

        # Stored (cue, target) embedding pairs; recurrent weights are
        # rebuilt from these on every recall.
        self.cue_patterns = []
        self.target_patterns = []

    def store(self, cue_emb, target_emb):
        """Store one (cue → target) association; detached to drop grads."""
        self.cue_patterns.append(cue_emb.detach())
        self.target_patterns.append(target_emb.detach())

    def _build_weights(self):
        """Build Hopfield-like recurrent weights from stored patterns.

        W_ij = Σ_μ (pattern_μ_i * pattern_μ_j) / N
        This creates attractor states at each stored pattern.
        """
        if not self.cue_patterns:
            return torch.zeros(self.dim, self.dim, device=DEVICE)

        patterns = torch.stack(self.cue_patterns)  # [N_patterns, dim]
        W = patterns.T @ patterns / len(self.cue_patterns)  # [dim, dim]
        # Remove diagonal (no self-connections, like biological networks)
        W.fill_diagonal_(0)
        return W

    def _settle(self, query_emb):
        """Run the LIF attractor dynamics for a query.

        First half of the steps: external query current + recurrent drive.
        Second half: recurrent drive only (free running, settles to an
        attractor). Returns the mean spike rate per neuron, shape [dim].
        """
        W = self._build_weights()

        mem = torch.zeros(self.dim, device=DEVICE)
        spike_counts = torch.zeros(self.dim, device=DEVICE)

        # Constant input current from query (scaled to help reach threshold)
        input_current = query_emb * 2.0

        for step in range(self.num_steps):
            # Total current: external input + recurrent (membrane-coupled)
            if step < self.num_steps // 2:
                total_current = input_current + W @ (mem / self.threshold)
            else:
                total_current = W @ (mem / self.threshold)

            spk, mem = self.lif(total_current, mem)
            spike_counts += spk

        # Spike rates as the settled representation
        return spike_counts / self.num_steps

    def recall(self, query_emb):
        """Spike-based attractor dynamics with a soft readout.

        1. Inject query as constant current
        2. Let network settle via recurrent dynamics
        3. Read spike rates → find nearest stored pattern → get target

        Returns:
            (recalled_embedding, best_index), or (None, None) when the
            store is empty.
        """
        if not self.cue_patterns:
            return None, None

        spike_rates = self._settle(query_emb)

        # Find nearest stored pattern by spike rate similarity
        cue_mat = torch.stack(self.cue_patterns)
        sims = nn.functional.cosine_similarity(
            spike_rates.unsqueeze(0), cue_mat, dim=-1)

        # Softmax attention based on similarity (hybrid: spike settle + soft readout)
        attn = torch.softmax(sims * 16.0, dim=0)
        target_mat = torch.stack(self.target_patterns)
        recalled = attn @ target_mat
        recalled = nn.functional.normalize(recalled, dim=0)

        best_idx = sims.argmax().item()
        return recalled, best_idx

    def recall_pure_spike(self, query_emb):
        """Fully spike-based recall (no softmax at readout).

        Returns:
            (target_embedding, best_index), or (None, None) when the
            store is empty. (Previously an empty store raised here,
            unlike `recall`, which guarded against it.)
        """
        if not self.cue_patterns:
            return None, None

        spike_rates = self._settle(query_emb)

        # Pure spike readout: direct cosine similarity (no softmax)
        cue_mat = torch.stack(self.cue_patterns)
        sims = nn.functional.cosine_similarity(
            spike_rates.unsqueeze(0), cue_mat, dim=-1)
        best_idx = sims.argmax().item()
        return self.target_patterns[best_idx], best_idx
|
||||
|
||||
|
||||
def load_model():
    """Load the MiniLM sentence encoder on the target device."""
    # Imported lazily so the module can be inspected without the dependency.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
|
||||
|
||||
|
||||
def emb(model, text):
    """Encode `text` into a single L2-normalized embedding tensor."""
    vectors = model.encode(
        [text], convert_to_tensor=True,
        normalize_embeddings=True, device=DEVICE)
    return vectors[0]
|
||||
|
||||
|
||||
def test_basic(model):
    """Basic SNN Hopfield recall.

    Sweeps num_steps × LIF beta and reports exact-cue and paraphrase
    recall accuracy for each configuration.
    """
    print("=== Test 1: Basic SNN Hopfield ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    # Index-aligned with `pairs`: paraphrase i should recall pair i.
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]

    for num_steps in [20, 50, 100, 200]:
        for beta in [0.8, 0.9, 0.95]:
            net = SNNHopfield(384, beta=beta, num_steps=num_steps).to(DEVICE)

            for cue, target in pairs:
                net.store(emb(model, cue), emb(model, target))

            # Exact recall: query with the stored cue itself.
            # (Was `recalled, idx = ...`; recalled was unused — use `_`
            # like the sibling test functions.)
            correct = 0
            for i, (cue, target) in enumerate(pairs):
                _, idx = net.recall(emb(model, cue))
                if idx == i:
                    correct += 1

            # Paraphrase recall: query with a reworded cue.
            para_correct = 0
            for i, para in enumerate(paraphrases):
                _, idx = net.recall(emb(model, para))
                if idx == i:
                    para_correct += 1

            n = len(pairs)
            print(f" steps={num_steps:>3}, β={beta}: "
                  f"Exact={correct}/{n}, Para={para_correct}/{n}")
|
||||
|
||||
|
||||
def test_comparison(model):
    """Compare SNN Hopfield vs standard Hopfield.

    Both networks store the same five (cue, target) pairs and are queried
    with index-aligned paraphrases; reports accuracy and per-query latency.
    """
    print("\n=== Test 2: SNN vs Standard Hopfield ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    # Index-aligned with `pairs`: paraphrase i should recall pair i.
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]

    # --- SNN Hopfield: spike-based settle + soft readout ---
    snn_net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)
    for cue, target in pairs:
        snn_net.store(emb(model, cue), emb(model, target))

    snn_correct = 0
    t0 = time.time()
    for i, para in enumerate(paraphrases):
        _, idx = snn_net.recall(emb(model, para))
        if idx == i:
            snn_correct += 1
    # NOTE: timing includes the embedding call, so this is end-to-end
    # query latency (ms), not just network settling.
    snn_time = (time.time() - t0) / len(paraphrases) * 1000

    # --- Standard Hopfield (softmax attention over stored cues) ---
    cue_embs = [emb(model, p[0]) for p in pairs]
    target_embs = [emb(model, p[1]) for p in pairs]
    cue_mat = torch.stack(cue_embs)
    target_mat = torch.stack(target_embs)

    std_correct = 0
    t0 = time.time()
    for i, para in enumerate(paraphrases):
        q = emb(model, para)
        xi = q
        # Three settle iterations: softmax attention with inverse
        # temperature β=16, re-projecting onto the cue manifold and
        # renormalizing each step.
        for _ in range(3):
            scores = 16.0 * (xi @ cue_mat.T)
            attn = torch.softmax(scores, dim=0)
            xi = attn @ cue_mat
            xi = nn.functional.normalize(xi, dim=0)
        # Final attention distribution over cues decides the recalled index.
        scores = 16.0 * (xi @ cue_mat.T)
        attn = torch.softmax(scores, dim=0)
        best = attn.argmax().item()
        if best == i:
            std_correct += 1
    std_time = (time.time() - t0) / len(paraphrases) * 1000

    n = len(paraphrases)
    print(f" SNN Hopfield: {snn_correct}/{n} ({snn_correct/n:.0%}), {snn_time:.1f}ms/query")
    print(f" Standard Hopfield: {std_correct}/{n} ({std_correct/n:.0%}), {std_time:.1f}ms/query")
|
||||
|
||||
|
||||
def test_with_background(model):
    """SNN Hopfield with background noise."""
    print("\n=== Test 3: SNN Hopfield with Background ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
    ]
    paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors"]

    for n_bg in [0, 10, 50]:
        net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)

        # Real associations, then n_bg distractor memories to crowd
        # the attractor landscape.
        for cue, target in pairs:
            net.store(emb(model, cue), emb(model, target))
        for i in range(n_bg):
            net.store(
                emb(model, f"Background task {i} about topic {i%5}"),
                emb(model, f"Background detail {i}"),
            )

        # Paraphrase i must still recall pair i despite the distractors.
        correct = sum(
            1 for i, para in enumerate(paraphrases)
            if net.recall(emb(model, para))[1] == i
        )

        n = len(paraphrases)

        # Single-query latency probe (includes embedding time).
        t0 = time.time()
        net.recall(emb(model, paraphrases[0]))
        dt = (time.time() - t0) * 1000

        print(f" bg={n_bg:>3}: Para={correct}/{n} ({correct/n:.0%}), "
              f"latency={dt:.1f}ms, "
              f"W_size={net.dim**2*4/1024/1024:.0f}MB")
|
||||
|
||||
|
||||
def main():
    """Run all three SNN Hopfield experiments in sequence."""
    banner = "=" * 60
    print(banner)
    print("Experiment P5: SNN-native Hopfield")
    print(banner)

    model = load_model()
    for experiment in (test_basic, test_comparison, test_with_background):
        experiment(model)
|
||||
|
||||
|
||||
# Allow importing this module without running the experiments.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user