Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
307 lines · 10 KiB · Python
"""Experiment P5: SNN-native Hopfield (spike-based attention).
|
|
|
|
Goal: Implement Hopfield-like attractor dynamics using LIF neurons.
|
|
|
|
The connection: Hopfield softmax attention with inverse temperature β
|
|
is equivalent to a Boltzmann distribution at temperature 1/β.
|
|
In SNN terms: β maps to membrane time constant / threshold ratio.
|
|
|
|
Approach: Replace softmax(β * q @ K^T) @ V with:
|
|
1. Encode query as spike train
|
|
2. Feed through recurrent LIF network with stored patterns as synaptic weights
|
|
3. Network settles to attractor (nearest stored pattern)
|
|
4. Read out associated target
|
|
|
|
This is closer to biological CA3 recurrent dynamics.
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import snntorch as snn
|
|
import numpy as np
|
|
|
|
# Prefer GPU when available; fall back to CPU so the script still runs
# on machines without CUDA (previously hard-coded to "cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def cosine(a, b):
    """Scalar cosine similarity between two 1-D tensors."""
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
|
|
|
|
|
|
class SNNHopfield(nn.Module):
    """Spike-based Hopfield network.

    Architecture:
    - Input layer: converts query embedding to current injection
    - Recurrent layer: LIF neurons with Hopfield-like connection weights
    - Readout: spike rates -> attention weights -> target embedding

    The recurrent weights are set (not trained) based on stored patterns,
    making this a "configured" SNN, not a "trained" one.
    """

    def __init__(self, dim, beta=0.9, threshold=1.0, num_steps=50):
        """
        Args:
            dim: Embedding dimensionality (one LIF neuron per dimension).
            beta: LIF membrane decay constant (0 < beta < 1).
            threshold: Spiking threshold of the LIF neurons.
            num_steps: Number of simulation timesteps per recall.
        """
        super().__init__()
        self.dim = dim
        self.num_steps = num_steps
        self.beta_lif = beta  # LIF membrane decay
        self.threshold = threshold

        self.lif = snn.Leaky(beta=beta, threshold=threshold)

        # Stored (cue, target) embedding pairs; recurrent weights are
        # rebuilt from cues at every recall.
        self.cue_patterns = []
        self.target_patterns = []

    def store(self, cue_emb, target_emb):
        """Store one cue -> target association (detached from autograd)."""
        self.cue_patterns.append(cue_emb.detach())
        self.target_patterns.append(target_emb.detach())

    def _build_weights(self):
        """Build Hopfield-like recurrent weights from stored patterns.

        W_ij = Σ_μ (pattern_μ_i * pattern_μ_j) / N
        This creates attractor states at each stored pattern.
        """
        if not self.cue_patterns:
            return torch.zeros(self.dim, self.dim, device=DEVICE)

        patterns = torch.stack(self.cue_patterns)  # [N_patterns, dim]
        W = patterns.T @ patterns / len(self.cue_patterns)  # [dim, dim]
        # Remove diagonal (no self-connections, like biological networks)
        W.fill_diagonal_(0)
        return W

    def _run_dynamics(self, query_emb):
        """Run the LIF attractor simulation and return mean spike rates.

        Phase 1 (first half of steps): the query drives the network as a
        constant external current plus recurrent feedback.
        Phase 2 (second half): external drive is removed so the network
        free-runs and settles toward an attractor.

        Returns a [dim] tensor of per-neuron mean firing rates.
        """
        W = self._build_weights()
        mem = torch.zeros(self.dim, device=DEVICE)
        spike_counts = torch.zeros(self.dim, device=DEVICE)
        input_current = query_emb * 2.0  # scale to help reach threshold
        half = self.num_steps // 2

        for step in range(self.num_steps):
            # Total current: external input (phase 1 only) + recurrent drive.
            recurrent = W @ (mem / self.threshold)
            total_current = input_current + recurrent if step < half else recurrent
            spk, mem = self.lif(total_current, mem)
            spike_counts += spk

        return spike_counts / self.num_steps

    def _rate_similarities(self, spike_rates):
        """Cosine similarity of the settled spike rates to every stored cue."""
        cue_mat = torch.stack(self.cue_patterns)
        return nn.functional.cosine_similarity(
            spike_rates.unsqueeze(0), cue_mat, dim=-1)

    def recall(self, query_emb):
        """Spike-based recall with a soft (softmax) readout.

        1. Inject query as constant current
        2. Let network settle via recurrent dynamics
        3. Read spike rates -> find nearest stored pattern -> get target

        Returns:
            (recalled_embedding, best_index), or (None, None) when no
            patterns are stored.
        """
        # Guard before running the simulation (previously the full LIF
        # loop ran even with an empty store, then returned None anyway).
        if not self.cue_patterns:
            return None, None

        spike_rates = self._run_dynamics(query_emb)
        sims = self._rate_similarities(spike_rates)

        # Softmax attention based on similarity (hybrid: spike settle + soft readout)
        attn = torch.softmax(sims * 16.0, dim=0)
        target_mat = torch.stack(self.target_patterns)
        recalled = nn.functional.normalize(attn @ target_mat, dim=0)

        best_idx = sims.argmax().item()
        return recalled, best_idx

    def recall_pure_spike(self, query_emb):
        """Fully spike-based recall (no softmax at readout).

        Returns:
            (target_embedding, best_index), or (None, None) when no
            patterns are stored (fix: previously crashed on torch.stack
            of an empty list).
        """
        if not self.cue_patterns:
            return None, None

        spike_rates = self._run_dynamics(query_emb)
        # Pure spike readout: direct cosine similarity (no softmax)
        sims = self._rate_similarities(spike_rates)
        best_idx = sims.argmax().item()
        return self.target_patterns[best_idx], best_idx
|
|
|
|
|
|
def load_model():
    """Load the MiniLM sentence encoder on the target device."""
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
|
|
|
|
|
|
def emb(model, text):
    """Encode one string into a single normalized embedding tensor."""
    encoded = model.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    return encoded[0]
|
|
|
|
|
|
def test_basic(model):
    """Sweep LIF hyperparameters and report exact vs paraphrase recall."""
    print("=== Test 1: Basic SNN Hopfield ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]
    n = len(pairs)

    for num_steps in [20, 50, 100, 200]:
        for beta in [0.8, 0.9, 0.95]:
            net = SNNHopfield(384, beta=beta, num_steps=num_steps).to(DEVICE)
            for cue, target in pairs:
                net.store(emb(model, cue), emb(model, target))

            # Exact recall: query with the stored cue itself.
            correct = sum(
                1 for i, (cue, _) in enumerate(pairs)
                if net.recall(emb(model, cue))[1] == i)

            # Paraphrase recall: query with a reworded cue.
            para_correct = sum(
                1 for i, para in enumerate(paraphrases)
                if net.recall(emb(model, para))[1] == i)

            print(f" steps={num_steps:>3}, β={beta}: "
                  f"Exact={correct}/{n}, Para={para_correct}/{n}")
|
|
|
|
|
|
def test_comparison(model):
    """Head-to-head: spike-based settle vs softmax-attention Hopfield."""
    print("\n=== Test 2: SNN vs Standard Hopfield ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
        ("Set up monitoring", "Prometheus and Grafana"),
        ("Tests failing in CI", "Need postgres container"),
    ]
    paraphrases = ["DB is crawling", "Ship the release",
                   "Getting 500 errors", "Need observability", "CI broken"]
    n = len(paraphrases)

    # --- SNN Hopfield ---
    snn_net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)
    for cue, target in pairs:
        snn_net.store(emb(model, cue), emb(model, target))

    snn_correct = 0
    t0 = time.time()
    for i, para in enumerate(paraphrases):
        _, idx = snn_net.recall(emb(model, para))
        snn_correct += int(idx == i)
    snn_time = (time.time() - t0) / n * 1000  # ms per query incl. encoding

    # --- Standard Hopfield (softmax attention, 3 settle iterations) ---
    cue_embs = [emb(model, p[0]) for p in pairs]
    target_embs = [emb(model, p[1]) for p in pairs]
    cue_mat = torch.stack(cue_embs)
    target_mat = torch.stack(target_embs)

    std_correct = 0
    t0 = time.time()
    for i, para in enumerate(paraphrases):
        xi = emb(model, para)
        for _ in range(3):
            # One Hopfield update: attend over cues, renormalize state.
            attn = torch.softmax(16.0 * (xi @ cue_mat.T), dim=0)
            xi = nn.functional.normalize(attn @ cue_mat, dim=0)
        final_attn = torch.softmax(16.0 * (xi @ cue_mat.T), dim=0)
        std_correct += int(final_attn.argmax().item() == i)
    std_time = (time.time() - t0) / n * 1000

    print(f" SNN Hopfield: {snn_correct}/{n} ({snn_correct/n:.0%}), {snn_time:.1f}ms/query")
    print(f" Standard Hopfield: {std_correct}/{n} ({std_correct/n:.0%}), {std_time:.1f}ms/query")
|
|
|
|
|
|
def test_with_background(model):
    """Recall accuracy and latency as distractor memories are added."""
    print("\n=== Test 3: SNN Hopfield with Background ===\n")

    pairs = [
        ("The database is slow", "Check missing indexes"),
        ("Deploy to production", "Use blue-green deployment"),
        ("The API returns 500", "Check for OOM in worker"),
    ]
    paraphrases = ["DB is crawling", "Ship the release", "Getting 500 errors"]
    n = len(paraphrases)

    for n_bg in [0, 10, 50]:
        net = SNNHopfield(384, beta=0.9, num_steps=100).to(DEVICE)
        for cue, target in pairs:
            net.store(emb(model, cue), emb(model, target))

        # Distractor memories over a small set of rotating topics.
        for i in range(n_bg):
            net.store(
                emb(model, f"Background task {i} about topic {i%5}"),
                emb(model, f"Background detail {i}"),
            )

        correct = 0
        for i, para in enumerate(paraphrases):
            _, idx = net.recall(emb(model, para))
            correct += int(idx == i)

        # Time one warm recall for the latency figure.
        t0 = time.time()
        net.recall(emb(model, paraphrases[0]))
        dt = (time.time() - t0) * 1000
        print(f" bg={n_bg:>3}: Para={correct}/{n} ({correct/n:.0%}), "
              f"latency={dt:.1f}ms, "
              f"W_size={net.dim**2*4/1024/1024:.0f}MB")
|
|
|
|
|
|
def main():
    """Run all experiment P5 tests in sequence."""
    banner = "=" * 60
    print(banner)
    print("Experiment P5: SNN-native Hopfield")
    print(banner)

    model = load_model()
    for test in (test_basic, test_comparison, test_with_background):
        test(model)


if __name__ == "__main__":
    main()
|