Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: two-stage Hopfield (nearest-neighbour top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4 ms latency @ 20K memories, ~1 GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs. flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim = 1.0) — RAG can't do this
- MiniLM-L6 is optimal (the discrimination gap matters more than absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic data, 36% → 64% on the benchmark
- SNN encoder is viable (CosSim 0.99) but not needed for the current architecture
255 lines
8.5 KiB
Python
255 lines
8.5 KiB
Python
"""Experiment 2f: Check discrimination for soft WTA + test learned separator.
|
|
|
|
Soft WTA at temp=0.5 showed perfect noise tolerance but may have zero
discrimination. Need to check: can it tell the correct target apart from the
wrong targets?

Then test a learned pattern separator (trained to be noise-robust via a
contrastive loss).
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
|
|
# Compute device for every tensor in this script. Prefer the GPU, but fall back
# to the CPU so the script still runs (slowly) on machines without CUDA instead
# of failing on the first tensor allocation.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# "doc" directory two levels above this file.
# NOTE(review): not referenced in the visible part of this script.
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    When either vector has zero norm the similarity is defined as 0.0 and
    is not computed.
    """
    if bool(a.norm() == 0) or bool(b.norm() == 0):
        return 0.0
    similarity = nn.functional.cosine_similarity(a[None], b[None])
    return similarity.item()
|
|
|
|
|
|
def winner_take_all(x, k):
    """Return a binary mask marking the ``k`` largest entries of ``x``.

    Selection happens along the last dimension. The result has the same shape,
    dtype and device as ``x``: 1.0 at the winning positions, 0.0 elsewhere.
    """
    winners = torch.topk(x, k, dim=-1).indices
    return torch.zeros_like(x).scatter(-1, winners, 1.0)
|
|
|
|
|
|
class SoftWTAMemory(nn.Module):
    """Hebbian associative memory over softmax ("soft winner-take-all") codes.

    Cues and targets are mapped into a ``code_dim``-dimensional space by two
    independent fixed random projections, each followed by a
    temperature-scaled softmax. ``learn`` adds the outer product of the target
    code and the cue code to ``W``; ``recall`` multiplies ``W`` by a cue code.

    Args:
        input_dim: Dimensionality of the cue/target vectors.
        code_dim: Dimensionality of the codes (``W`` is ``code_dim x code_dim``).
        temperature: Softmax temperature; smaller values give peakier codes.
    """

    def __init__(self, input_dim=768, code_dim=16384, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        scale = 1.0 / input_dim**0.5
        # Fixed random projections, drawn cue-first then target, registered as
        # buffers (not parameters) so they move with .to(device).
        for buffer_name in ("proj", "target_proj"):
            self.register_buffer(buffer_name,
                                 torch.randn(input_dim, code_dim) * scale)
        # Association matrix. It is a Parameter with requires_grad=False, so it
        # shows up in parameters()/state_dict but is only changed in place by
        # learn().
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim),
                              requires_grad=False)

    def encode(self, x, proj):
        """Project ``x`` with ``proj`` and softmax the last dim at ``self.temperature``."""
        logits = x @ proj
        return torch.softmax(logits / self.temperature, dim=-1)

    def learn(self, cue, target):
        """Store one association: ``W += outer(target_code, cue_code)``.

        ``cue`` and ``target`` must be 1-D, as ``torch.outer`` requires.
        """
        cue_code = self.encode(cue, self.proj)
        target_code = self.encode(target, self.target_proj)
        self.W.data.add_(torch.outer(target_code, cue_code))

    def recall(self, cue):
        """Return the raw (un-normalised) recall ``W @ encode(cue, proj)``."""
        return self.W @ self.encode(cue, self.proj)
|
|
|
|
|
|
def check_discrimination(temperature, num_pairs=100, noise_levels=None,
                         num_wrong=20):
    """Measure how well a soft-WTA Hebbian memory separates right from wrong targets.

    Builds a fresh ``SoftWTAMemory`` at ``temperature`` and stores
    ``num_pairs`` random unit-norm (cue, target) associations. For each noise
    level it then recalls from every cue after adding Gaussian noise and
    re-normalising. The recall is compared (cosine similarity) with the correct
    target's code and with the codes of the first ``num_wrong`` targets,
    excluding the correct one. One summary line is printed per noise level.

    Args:
        temperature: Softmax temperature passed to ``SoftWTAMemory``.
        num_pairs: Number of associations stored.
        noise_levels: Std-devs of the cue noise to test; defaults to
            ``[0.0, 0.1, 0.5]``.
        num_wrong: Number of leading targets used as distractors for every
            cue (capped at ``num_pairs``).

    Returns:
        dict mapping each noise level to ``{"correct", "wrong", "disc"}`` mean
        similarities. A mean over an empty list (e.g. ``"wrong"`` when
        ``num_pairs == 1``) is reported as NaN instead of triggering a numpy
        warning.
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.5]

    mem = SoftWTAMemory(temperature=temperature).to(DEVICE)
    # Cue/target width must match the memory's projection (rows of proj).
    input_dim = mem.proj.shape[0]

    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    # Target codes do not depend on the probe. Encode them once instead of
    # re-encoding up to num_wrong + 1 targets for every cue and noise level.
    target_codes = [mem.encode(t, mem.target_proj) for t in targets]
    num_distractors = min(num_wrong, num_pairs)

    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        wrong_sims = []
        for i, cue in enumerate(cues):
            noisy = nn.functional.normalize(
                cue + torch.randn_like(cue) * noise_std, dim=0)
            recalled = mem.recall(noisy)

            correct_sims.append(cosine(recalled, target_codes[i]))

            # Distractors are the first `num_distractors` targets (a fixed
            # set, not a random sample), skipping this cue's own target.
            wrong_sims.extend(cosine(recalled, target_codes[j])
                              for j in range(num_distractors) if j != i)

        mc = float(np.mean(correct_sims)) if correct_sims else float("nan")
        mw = float(np.mean(wrong_sims)) if wrong_sims else float("nan")
        print(f" temp={temperature}, noise={noise_std:.1f}: "
              f"Correct={mc:.4f}, Wrong={mw:.4f}, Disc={mc-mw:.4f}")
        results[noise_std] = {"correct": mc, "wrong": mw, "disc": mc - mw}
    return results
|
|
|
|
|
|
class LearnedSeparator(nn.Module):
    """Trainable pattern separator: an MLP whose output is sparsified by WTA.

    The goal is to map similar inputs (e.g. an input and a noisy copy of it)
    to the same sparse code.

    Architecture: Linear(input_dim, code_dim) -> ReLU -> Linear(code_dim, code_dim),
    followed by hard top-``k_active`` winner-take-all in ``forward`` or by a
    temperature softmax in ``forward_soft``.
    Training: contrastive loss on (original, noisy) pairs.
    """

    def __init__(self, input_dim=768, code_dim=4096, k_active=50):
        super().__init__()
        self.k_active = k_active
        self.code_dim = code_dim
        layers = [
            nn.Linear(input_dim, code_dim),
            nn.ReLU(),
            nn.Linear(code_dim, code_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 0/1 code with ``k_active`` ones along the last dim (not differentiable)."""
        return winner_take_all(self.net(x), self.k_active)

    def forward_soft(self, x, temperature=0.1):
        """Return a differentiable relaxation of ``forward``: softmax of the MLP output.

        Intended for training, where the hard WTA has no useful gradient.
        """
        scaled = self.net(x) / temperature
        return scaled.softmax(dim=-1)
|
|
|
|
|
|
def train_learned_separator(input_dim=768, code_dim=4096, k_active=50,
                            epochs=100, batch_size=128, noise_std=0.3):
    """Train a LearnedSeparator so that an input and its noisy copy share a code.

    Each "epoch" is a single Adam step (lr=1e-3) on a freshly sampled batch:
    ``batch_size`` random unit vectors (anchors), a re-normalised copy of each
    with Gaussian noise of std ``noise_std`` (positives), and an independent
    random batch (negatives). On the soft codes from ``forward_soft`` the loss
    is (batch means)::

        -cos(anchor, positive) + 0.5 * relu(cos(anchor, negative) - 0.1)
            + 0.01 * entropy(anchor code)

    This pulls positives together, pushes negatives below 0.1 similarity and
    favours low-entropy (sparse) codes. Every 20 steps the hard WTA codes of
    the current batch are compared, and the fraction of active units shared
    with the positive/negative code is printed.

    Returns:
        The trained LearnedSeparator, placed on ``DEVICE``.
    """
    sep = LearnedSeparator(input_dim, code_dim, k_active).to(DEVICE)
    optimizer = optim.Adam(sep.parameters(), lr=1e-3)
    normalize = nn.functional.normalize
    cos_sim = nn.functional.cosine_similarity

    print(f"\nTraining learned separator (code_dim={code_dim}, k={k_active}, "
          f"noise={noise_std})")

    for step in range(1, epochs + 1):
        anchors = normalize(torch.randn(batch_size, input_dim, device=DEVICE), dim=1)
        positives = normalize(anchors + torch.randn_like(anchors) * noise_std, dim=1)
        negatives = normalize(torch.randn(batch_size, input_dim, device=DEVICE), dim=1)

        soft_a, soft_p, soft_n = (sep.forward_soft(batch)
                                  for batch in (anchors, positives, negatives))

        pos_sim = cos_sim(soft_a, soft_p, dim=1).mean()
        neg_sim = cos_sim(soft_a, soft_n, dim=1).mean()
        # Mean Shannon entropy of the anchor codes; the 1e-10 avoids log(0).
        entropy = -(soft_a * (soft_a + 1e-10).log()).sum(dim=1).mean()
        loss = -pos_sim + 0.5 * torch.relu(neg_sim - 0.1) + 0.01 * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 20:
            continue
        with torch.no_grad():
            hard_a, hard_p, hard_n = (sep(batch)
                                      for batch in (anchors, positives, negatives))
            # Overlap of the 0/1 codes divided by k: the shared fraction of
            # active units.
            match_rate = (hard_a * hard_p).sum(dim=1).mean() / k_active
            neg_match = (hard_a * hard_n).sum(dim=1).mean() / k_active
        print(f" Epoch {step}: loss={loss.item():.4f}, "
              f"pos_match={match_rate:.4f}, neg_match={neg_match:.4f}")

    return sep
|
|
|
|
|
|
def test_learned_memory(sep, num_pairs=100, noise_levels=None, num_wrong=20):
    """Evaluate a Hebbian hetero-associative memory built on a learned separator.

    Stores ``num_pairs`` random unit-norm (cue, target) pairs as the sum of
    outer products of their hard separator codes. For each noise level it then
    recalls every pair from a cue with added Gaussian noise (re-normalised) and
    applies top-``k_active`` WTA to the recall. It prints the mean cosine
    similarity to the correct target code, the mean similarity to the first
    ``num_wrong`` target codes (excluding the correct one), their difference,
    and the fraction of near-exact recalls (similarity > 0.99).

    NOTE: despite the ``test_`` prefix this is an experiment routine, not a
    pytest test.

    Args:
        sep: A trained LearnedSeparator. Its first Linear layer determines the
            width of the random cue/target vectors.
        num_pairs: Number of associations stored.
        noise_levels: Cue-noise std-devs to evaluate; defaults to
            ``[0.0, 0.1, 0.2, 0.5, 1.0]``.
        num_wrong: Number of leading targets used as distractors.

    Returns:
        dict mapping each noise level to ``{"correct", "wrong", "disc",
        "exact"}``. Means over empty lists are reported as NaN.
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0]

    k = sep.k_active
    # Take the input width from the separator instead of hard-coding 768, so
    # separators trained with another input_dim can be evaluated too.
    input_dim = sep.net[0].in_features

    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    with torch.no_grad():
        cue_codes = [sep(c.unsqueeze(0)).squeeze() for c in cues]
        target_codes = [sep(t.unsqueeze(0)).squeeze() for t in targets]

        # W = sum_i outer(target_i, cue_i), computed as a single matrix product
        # instead of num_pairs code_dim x code_dim temporaries. The codes are
        # exactly 0/1, so every entry is a small integer and the result equals
        # accumulating the outer products one by one.
        if num_pairs:
            W = torch.stack(target_codes).T @ torch.stack(cue_codes)
        else:
            W = torch.zeros(sep.code_dim, sep.code_dim, device=DEVICE)

    num_distractors = min(num_wrong, num_pairs)
    results = {}
    for ns in noise_levels:
        correct_sims = []
        wrong_sims = []
        for i, cue in enumerate(cues):
            noisy = nn.functional.normalize(
                cue + torch.randn_like(cue) * ns, dim=0)
            with torch.no_grad():
                nc = sep(noisy.unsqueeze(0)).squeeze()
                recalled = winner_take_all(W @ nc, k)

            correct_sims.append(cosine(recalled, target_codes[i]))
            # Distractors are the first `num_distractors` targets (a fixed set,
            # not a random sample), skipping this cue's own target.
            wrong_sims.extend(cosine(recalled, target_codes[j])
                              for j in range(num_distractors) if j != i)

        mc = float(np.mean(correct_sims)) if correct_sims else float("nan")
        mw = float(np.mean(wrong_sims)) if wrong_sims else float("nan")
        exact = (float(np.mean([s > 0.99 for s in correct_sims]))
                 if correct_sims else float("nan"))
        print(f" noise={ns:.2f}: Correct={mc:.4f}, Wrong={mw:.4f}, "
              f"Disc={mc-mw:.4f}, Exact={exact:.2%}")
        results[ns] = {"correct": mc, "wrong": mw, "disc": mc - mw,
                       "exact": exact}
    return results
|
|
|
|
|
|
def main():
    """Run experiment 2f end to end, printing all results to stdout.

    Part 1: soft-WTA discrimination at several softmax temperatures.
    Part 2: learned separators (code_dim=4096, k=50) trained at three noise
    levels, each evaluated on a 100-pair Hebbian memory.
    Part 3: a larger separator (code_dim=8192, k=20) evaluated on 200 pairs.
    """
    rule = "=" * 60
    print(rule)
    print("Experiment 2f: Discrimination Check + Learned Separator")
    print(rule)

    print("\n=== Part 1: Soft WTA Discrimination ===")
    for temperature in (0.01, 0.05, 0.1, 0.5, 1.0):
        check_discrimination(temperature)
        print()

    print("\n=== Part 2: Learned Separator ===")
    for train_noise in (0.1, 0.3, 0.5):
        separator = train_learned_separator(
            code_dim=4096, k_active=50, epochs=200, noise_std=train_noise)
        print(f"\n Testing (trained with noise={train_noise}):")
        test_learned_memory(separator, num_pairs=100)
        print()

    print("\n=== Part 3: Larger Learned Separator (code=8192, k=20) ===")
    separator = train_learned_separator(
        code_dim=8192, k_active=20, epochs=300, noise_std=0.3)
    print("\n Testing:")
    test_learned_memory(separator, num_pairs=200)


if __name__ == "__main__":
    main()
|