NuoNuo: Hippocampal memory module prototype

Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (the discrimination gap between correct and distractor memories matters more than absolute similarity scores)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
2026-04-07 10:37:24 +01:00
commit d923aa1e31
65 changed files with 13148 additions and 0 deletions

View File

@@ -0,0 +1,368 @@
"""Experiment 2e: Noise-tolerant retrieval.
Problem: WTA pattern separation is brittle to noise in cue embeddings.
Real use case requires retrieving from semantically similar (not identical) cues.
Approaches to test:
1. Soft-WTA: Use softmax temperature instead of hard top-k
2. Multi-probe: Multiple noisy retrievals + voting
3. Coarse-to-fine: Nearest-neighbor in embedding space → exact Hebbian recall
4. Learned similarity-preserving hash: train the separator to be noise-robust
5. Wider k: trade capacity for noise robustness
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
DEVICE = "cuda"  # all experiments assume a CUDA-capable GPU is available
RESULTS_DIR = Path(__file__).parent.parent / "doc"  # JSON results written to ../doc
def cosine(a, b):
    """Cosine similarity between two 1-D tensors as a float; 0.0 if either is all-zero."""
    if a.norm() != 0 and b.norm() != 0:
        return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
    return 0.0
def winner_take_all(x, k):
    """Binary sparse code: 1.0 at the k largest entries of x (last dim), 0.0 elsewhere."""
    code = torch.zeros_like(x)
    code.scatter_(-1, x.topk(k, dim=-1).indices, 1.0)
    return code
class SoftWTASeparator(nn.Module):
    """Soft winner-take-all pattern separator.

    Projects the input through a fixed random matrix and applies a
    temperature-scaled softmax, producing a soft sparse code rather than a
    hard binary one. More robust to noise, less discriminative.
    """

    def __init__(self, input_dim, code_dim, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Fixed (untrained) random projection, registered so .to(device) moves it.
        self.register_buffer(
            'proj', torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5))

    def forward(self, x):
        # Low temperature sharpens toward hard WTA; high temperature spreads mass.
        logits = (x @ self.proj) / self.temperature
        return torch.softmax(logits, dim=-1)
class MultiProbeSeparator(nn.Module):
    """Multiple random projections; retrieve from all and majority-vote.

    Each probe projects the input through its own fixed random matrix and
    takes a hard top-k (winner-take-all) code; a unit is active in the final
    code only if a strict majority of probes selected it. Voting across
    independent projections averages out noise in the cue embedding.
    """

    def __init__(self, input_dim, code_dim, k_active, num_probes=8):
        """
        Args:
            input_dim: dimensionality of the input embedding.
            code_dim: width of each probe's sparse code.
            k_active: number of active units per probe (top-k).
            num_probes: number of independent random projections.
        """
        super().__init__()
        self.k_active = k_active
        self.num_probes = num_probes
        # One fixed random projection per probe, stacked: [P, input_dim, code_dim].
        projs = torch.randn(num_probes, input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('projs', projs)

    def forward(self, x):
        """Return the binary majority-vote code for a 1-D input embedding.

        All probes are evaluated in one batched matmul + one batched topk
        (the previous per-probe Python loop did the same work P times slower).
        """
        # (input_dim,) @ [P, input_dim, code_dim] -> [P, code_dim] (batched matmul).
        h = x @ self.projs
        # Per-probe hard WTA codes, computed in a single topk/scatter.
        codes = torch.zeros_like(h)
        codes.scatter_(-1, h.topk(self.k_active, dim=-1).indices, 1.0)
        votes = codes.sum(dim=0)
        # Active only where a strict majority of probes agree.
        return (votes > self.num_probes / 2).float()
class CoarseToFineMemory(nn.Module):
    """Coarse-to-fine associative memory.

    Coarse stage: nearest-neighbor search over the raw stored cue embeddings.
    Fine stage: exact Hebbian recall keyed by the nearest stored cue.

    The most practical approach tested here: Hebbian outer-product storage
    provides the association, while retrieval is bootstrapped by embedding
    similarity, so a noisy cue snaps to its closest stored cue first.
    """

    def __init__(self, input_dim, code_dim=16384, k_active=20):
        """
        Args:
            input_dim: dimensionality of cue/target embeddings.
            code_dim: width of the sparse codes (and of the W matrix).
            k_active: number of active units per WTA code.
        """
        super().__init__()
        self.code_dim = code_dim
        self.k_active = k_active
        # Fixed random projections. Registered as buffers so the caller's
        # .to(device) moves them; previously they were created directly on the
        # hard-coded global CUDA device, which broke CPU-only use and was
        # redundant with .to(DEVICE).
        scale = 1.0 / input_dim**0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)
        self.register_buffer('target_proj', torch.randn(input_dim, code_dim) * scale)
        # Hebbian weight matrix, updated in place only (never by gradient).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)
        # Raw cue embeddings kept for the coarse nearest-neighbor stage.
        self.cue_store = []

    @staticmethod
    def _wta(x, k):
        # Hard winner-take-all: binary code with 1.0 at the top-k entries.
        out = torch.zeros_like(x)
        out.scatter_(-1, x.topk(k, dim=-1).indices, 1.0)
        return out

    def separate(self, x, proj):
        """Project x through proj and return its k-active binary WTA code."""
        return self._wta(x @ proj, self.k_active)

    def learn(self, cue, target):
        """Store one cue→target association (Hebbian outer-product update)."""
        self.cue_store.append(cue.detach().clone())
        cue_code = self.separate(cue, self.proj)
        target_code = self.separate(target, self.target_proj)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, query):
        """Coarse NN lookup of the nearest stored cue, then exact Hebbian recall."""
        if not self.cue_store:
            # Nothing stored yet: all-zero code on the module's own device
            # (not the global DEVICE, which the module may not live on).
            return torch.zeros(self.code_dim, device=self.proj.device)
        cue_matrix = torch.stack(self.cue_store)  # [N, dim]
        sims = nn.functional.cosine_similarity(
            query.unsqueeze(0), cue_matrix, dim=-1)  # [N]
        best_cue = self.cue_store[sims.argmax()]
        # Exact Hebbian recall using the clean stored cue, not the noisy query.
        cue_code = self.separate(best_cue, self.proj)
        return self._wta(self.W @ cue_code, self.k_active)
def test_approach(name, mem_class, num_pairs=100, noise_levels=None, **kwargs):
    """Generic harness: learn random pairs, then measure recall under cue noise.

    Returns {noise_std: {"mean_cos": ..., "exact_rate": ...}} and prints one
    line per noise level.
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]
    input_dim = 768
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    # Accept either a class (instantiate with kwargs) or a ready-made module.
    if isinstance(mem_class, nn.Module):
        mem = mem_class
    else:
        mem = mem_class(**kwargs).to(DEVICE)
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)
    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        for cue, target in zip(cues, targets):
            noisy_cue = cue + torch.randn_like(cue) * noise_std
            recalled = mem.recall(nn.functional.normalize(noisy_cue, dim=0))
            # Reference code: prefer an explicit target separator, else a raw
            # target projection, else compare against the embedding itself.
            if hasattr(mem, 'target_separator'):
                reference = mem.target_separator(target)
            elif hasattr(mem, 'target_proj'):
                reference = winner_take_all(target @ mem.target_proj, mem.k_active)
            else:
                reference = target
            correct_sims.append(cosine(recalled, reference))
        mc = np.mean(correct_sims)
        exact = np.mean([s > 0.99 for s in correct_sims])
        results[noise_std] = {"mean_cos": mc, "exact_rate": exact}
        print(f"  {name}: noise={noise_std:.2f} → CosSim={mc:.4f}, Exact={exact:.2%}")
    return results
# --- Approach-specific memory classes ---
class SoftWTAMemory(nn.Module):
    """Hebbian associative memory over soft-WTA codes for cue and target."""

    def __init__(self, input_dim=768, code_dim=16384, temperature=0.1):
        super().__init__()
        self.separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.target_separator = SoftWTASeparator(input_dim, code_dim, temperature)
        # Hebbian weights, accumulated in place (no gradients).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Outer-product Hebbian update from the soft cue/target codes."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Linear readout: W applied to the soft code of the cue."""
        return self.W @ self.separator(cue)
class MultiProbeMemory(nn.Module):
    """Hebbian associative memory over multi-probe majority-vote codes."""

    def __init__(self, input_dim=768, code_dim=8192, k_active=20, num_probes=16):
        super().__init__()
        self.separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.target_separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.k_active = k_active
        # Hebbian weights, accumulated in place (no gradients).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Outer-product Hebbian update from the two voted codes."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Readout through W, then re-sparsify with a hard top-k."""
        raw = self.W @ self.separator(cue)
        return winner_take_all(raw, self.k_active)
class WiderKMemory(nn.Module):
    """Plain WTA Hebbian memory with a wider k — simple and might work.

    Trades pattern-separation capacity for noise robustness by activating
    more units per code.

    NOTE: the previous `target_separator` property (which returned None) was
    removed. Its mere existence made test_approach's
    `hasattr(mem, 'target_separator')` branch call `None(target)` and crash;
    without it, the harness correctly falls through to the `target_proj`
    branch, which this class supports.
    """

    def __init__(self, input_dim=768, code_dim=16384, k_active=200):
        super().__init__()
        self.k_active = k_active
        scale = 1.0 / input_dim**0.5
        proj = torch.randn(input_dim, code_dim) * scale
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * scale
        self.register_buffer('target_proj', target_proj)
        # Hebbian weights, updated in place only (never by gradient).
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    @staticmethod
    def _wta(x, k):
        # Binary top-k code (same semantics as the module-level winner_take_all).
        out = torch.zeros_like(x)
        out.scatter_(-1, x.topk(k, dim=-1).indices, 1.0)
        return out

    def learn(self, cue, target):
        """Store one association via the Hebbian outer-product rule."""
        cue_code = self._wta(cue @ self.proj, self.k_active)
        target_code = self._wta(target @ self.target_proj, self.k_active)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Recall the target code for a cue and re-sparsify with a hard top-k."""
        cue_code = self._wta(cue @ self.proj, self.k_active)
        return self._wta(self.W @ cue_code, self.k_active)
def _random_unit_vectors(num, dim=768):
    """num random unit-norm embeddings on DEVICE (stand-ins for real sentence embeddings)."""
    return [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(num)]


def _learn_pairs(mem, num_pairs):
    """Create random cue/target pairs, store them in mem, return (cues, targets)."""
    cues = _random_unit_vectors(num_pairs)
    targets = _random_unit_vectors(num_pairs)
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)
    return cues, targets


def _noise_sweep(mem, cues, targets, noise_levels, target_code_fn):
    """Recall every cue at each noise level; print and return {noise: mean CosSim}.

    target_code_fn maps a raw target embedding to the reference code a perfect
    recall should reproduce (approach-specific).
    """
    results = {}
    for ns in noise_levels:
        sims = []
        for cue, target in zip(cues, targets):
            noisy = nn.functional.normalize(cue + torch.randn_like(cue) * ns, dim=0)
            recalled = mem.recall(noisy)
            sims.append(cosine(recalled, target_code_fn(target)))
        mc = np.mean(sims)
        print(f"  noise={ns:.2f}: CosSim={mc:.4f}")
        results[ns] = mc
    return results


def main():
    """Run all four noise-tolerance approaches, save JSON results, print a summary.

    The per-approach learn/evaluate loop was previously copy-pasted four
    times; it now lives in _learn_pairs/_noise_sweep.
    """
    print("=" * 60)
    print("Experiment 2e: Noise-Tolerant Retrieval")
    print("=" * 60)
    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
    num_pairs = 100
    all_results = {}

    # 1. Soft WTA: sweep the softmax temperature.
    print("\n=== 1. Soft WTA ===")
    for temp in [0.01, 0.05, 0.1, 0.5]:
        print(f"\n-- temperature={temp} --")
        mem = SoftWTAMemory(temperature=temp).to(DEVICE)
        cues, targets = _learn_pairs(mem, num_pairs)
        all_results[f"soft_wta_t{temp}"] = _noise_sweep(
            mem, cues, targets, noise_levels, mem.target_separator)

    # 2. Multi-probe: sweep the number of probes.
    print("\n=== 2. Multi-Probe ===")
    for n_probes in [4, 8, 16, 32]:
        print(f"\n-- probes={n_probes} --")
        mem = MultiProbeMemory(num_probes=n_probes).to(DEVICE)
        cues, targets = _learn_pairs(mem, num_pairs)
        all_results[f"multiprobe_{n_probes}"] = _noise_sweep(
            mem, cues, targets, noise_levels, mem.target_separator)

    # 3. Coarse-to-fine: NN lookup bootstraps exact Hebbian recall.
    print("\n=== 3. Coarse-to-Fine (NN + Hebbian) ===")
    mem = CoarseToFineMemory(768).to(DEVICE)
    cues, targets = _learn_pairs(mem, num_pairs)
    all_results["coarse_to_fine"] = _noise_sweep(
        mem, cues, targets, noise_levels,
        lambda t, m=mem: winner_take_all(t @ m.target_proj, m.k_active))

    # 4. Wider k: sweep the number of active units.
    print("\n=== 4. Wider K ===")
    for k_active in [50, 100, 200, 500, 1000]:
        print(f"\n-- k={k_active} --")
        mem = WiderKMemory(k_active=k_active).to(DEVICE)
        cues, targets = _learn_pairs(mem, num_pairs)
        # Bind mem/k as lambda defaults so each sweep keeps its own memory.
        all_results[f"wider_k_{k_active}"] = _noise_sweep(
            mem, cues, targets, noise_levels,
            lambda t, m=mem, k=k_active: winner_take_all(t @ m.target_proj, k))

    # Save (JSON object keys must be strings).
    serializable = {
        method: {str(ns): float(v) for ns, v in res.items()}
        for method, res in all_results.items()
    }
    with open(RESULTS_DIR / "exp02e_results.json", "w") as f:
        json.dump(serializable, f, indent=2)

    # Summary table.
    print("\n" + "=" * 80)
    print("SUMMARY: CosSim at each noise level")
    print(f"{'Method':<25}", end="")
    for ns in noise_levels:
        print(f"  σ={ns:.2f}", end="")
    print()
    print("-" * 80)
    for method, res in all_results.items():
        print(f"{method:<25}", end="")
        for ns in noise_levels:
            print(f"  {res.get(ns, 0):>5.3f}", end="")
        print()


if __name__ == "__main__":
    main()