NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
368
experiments/exp02e_noise_tolerance.py
Normal file
368
experiments/exp02e_noise_tolerance.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Experiment 2e: Noise-tolerant retrieval.
|
||||
|
||||
Problem: WTA pattern separation is brittle to noise in cue embeddings.
|
||||
Real use case requires retrieving from semantically similar (not identical) cues.
|
||||
|
||||
Approaches to test:
|
||||
1. Soft-WTA: Use softmax temperature instead of hard top-k
|
||||
2. Multi-probe: Multiple noisy retrievals + voting
|
||||
3. Coarse-to-fine: Nearest-neighbor in embedding space → exact Hebbian recall
|
||||
4. Learned similarity-preserving hash: train the separator to be noise-robust
|
||||
5. Wider k: trade capacity for noise robustness
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
# Prefer the GPU but fall back to CPU so the experiment runs anywhere
# (previously hard-coded to "cuda", which crashed on CPU-only machines).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Results are written next to the other experiment write-ups in ../doc.
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
||||
|
||||
|
||||
def cosine(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    Zero-norm inputs yield 0.0 instead of NaN.
    """
    if a.norm() == 0:
        return 0.0
    if b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
|
||||
|
||||
|
||||
def winner_take_all(x, k):
    """Binary k-WTA code: 1.0 at the k largest entries along the last dim, else 0.0."""
    winners = x.topk(k, dim=-1).indices
    code = torch.zeros_like(x)
    code.scatter_(-1, winners, 1.0)
    return code
|
||||
|
||||
|
||||
class SoftWTASeparator(nn.Module):
    """Soft winner-take-all via temperature-scaled softmax over a random projection.

    Produces soft sparse codes instead of hard binary ones — more robust to
    noisy cues, at the cost of discrimination between stored patterns.
    """

    def __init__(self, input_dim, code_dim, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Fixed (non-trained) random projection, scaled ~1/sqrt(input_dim).
        self.register_buffer(
            'proj', torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5))

    def forward(self, x):
        logits = x @ self.proj
        # Low temperature → sharper / sparser code; high temperature → more spread.
        return torch.softmax(logits / self.temperature, dim=-1)
|
||||
|
||||
|
||||
class MultiProbeSeparator(nn.Module):
    """Multiple random projections; retrieve a code from all, majority vote.

    A unit is active in the output code only when more than half of the
    probes' k-WTA codes agree on it.
    """

    def __init__(self, input_dim, code_dim, k_active, num_probes=8):
        super().__init__()
        self.k_active = k_active
        self.num_probes = num_probes
        # Multiple fixed random projections: [num_probes, input_dim, code_dim].
        projs = torch.randn(num_probes, input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('projs', projs)

    def forward(self, x):
        """Return the binary majority-vote code across all probes."""
        # One batched matmul instead of a Python loop over probes: [P, code_dim].
        h = x @ self.projs
        # Per-probe k-WTA codes, summed into per-unit vote counts.
        votes = winner_take_all(h, self.k_active).sum(dim=0)
        # Active only if a strict majority of probes agree.
        threshold = self.num_probes / 2
        return (votes > threshold).float()
|
||||
|
||||
|
||||
class CoarseToFineMemory(nn.Module):
    """Coarse: nearest-neighbor in embedding space. Fine: exact Hebbian recall.

    This is the most practical approach: the Hebbian matrix provides the
    association storage, while retrieval is bootstrapped by embedding
    similarity to the nearest stored cue.
    """

    def __init__(self, input_dim, code_dim=16384, k_active=20):
        super().__init__()
        self.code_dim = code_dim
        self.k_active = k_active

        # Fixed random projections for cue/target pattern separation.
        # Created on the default device so `.to(device)` works uniformly —
        # previously these were hard-coded to the module-level DEVICE, which
        # broke CPU-only runs and defeated `.to()`.
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('target_proj', target_proj)

        # Hebbian association matrix; updated in-place, never by gradient.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

        # Raw cue embeddings kept for the coarse nearest-neighbor lookup.
        self.cue_store = []

    def separate(self, x, proj):
        """Project x through `proj` and binarize with k-winner-take-all."""
        h = x @ proj
        return winner_take_all(h, self.k_active)

    def learn(self, cue, target):
        """Store one cue→target association via a Hebbian outer-product update."""
        self.cue_store.append(cue.detach().clone())
        cue_code = self.separate(cue, self.proj)
        target_code = self.separate(target, self.target_proj)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, query):
        """Coarse: find nearest stored cue. Fine: Hebbian recall from it."""
        if not self.cue_store:
            # Nothing stored yet — empty code on whatever device the memory lives on.
            return torch.zeros(self.code_dim, device=self.W.device)

        # Coarse stage: cosine nearest neighbor over all stored cues.
        cue_matrix = torch.stack(self.cue_store)  # [N, dim]
        sims = nn.functional.cosine_similarity(
            query.unsqueeze(0), cue_matrix, dim=-1)  # [N]
        best_cue = self.cue_store[sims.argmax()]

        # Fine stage: exact Hebbian recall using the retrieved (noise-free) cue.
        cue_code = self.separate(best_cue, self.proj)
        raw = self.W @ cue_code
        return winner_take_all(raw, self.k_active)
|
||||
|
||||
|
||||
def test_approach(name, mem_class, num_pairs=100, noise_levels=None, **kwargs):
    """Generic test harness.

    Stores `num_pairs` random unit-norm cue→target pairs in a fresh memory,
    then probes recall with increasingly noisy cues. Returns, per noise
    level, the mean cosine similarity to the reference target code and the
    exact-recall rate (CosSim > 0.99).
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]

    input_dim = 768
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    # Accept either a class (instantiate with kwargs, move to DEVICE) or a
    # ready-made instance.
    mem = mem_class(**kwargs).to(DEVICE) if not isinstance(mem_class, nn.Module) else mem_class

    # Learn
    for i in range(num_pairs):
        mem.learn(cues[i], targets[i])

    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        for i in range(num_pairs):
            noisy_cue = cues[i] + torch.randn_like(cues[i]) * noise_std
            noisy_cue = nn.functional.normalize(noisy_cue, dim=0)

            recalled = mem.recall(noisy_cue)

            # Compute the reference target code the same way the memory does.
            # Guard with callable(): some memories expose target_separator as
            # a None-valued property (e.g. WiderKMemory), which previously
            # crashed here with "TypeError: 'NoneType' object is not callable".
            target_separator = getattr(mem, 'target_separator', None)
            if callable(target_separator):
                target_code = target_separator(targets[i])
            elif hasattr(mem, 'target_proj'):
                target_code = winner_take_all(targets[i] @ mem.target_proj, mem.k_active)
            else:
                target_code = targets[i]

            cs = cosine(recalled, target_code)
            correct_sims.append(cs)

        mc = np.mean(correct_sims)
        exact = np.mean([s > 0.99 for s in correct_sims])
        results[noise_std] = {"mean_cos": mc, "exact_rate": exact}
        print(f" {name}: noise={noise_std:.2f} → CosSim={mc:.4f}, Exact={exact:.2%}")

    return results
|
||||
|
||||
|
||||
# --- Approach-specific memory classes ---
|
||||
|
||||
class SoftWTAMemory(nn.Module):
    """Hebbian memory whose codes come from soft (softmax) winner-take-all."""

    def __init__(self, input_dim=768, code_dim=16384, temperature=0.1):
        super().__init__()
        self.separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.target_separator = SoftWTASeparator(input_dim, code_dim, temperature)
        # Hebbian association matrix; updated in-place, never by gradient.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Bind the target code to the cue code (outer-product Hebbian update)."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Raw associative read-out of the cue's soft code (no re-sparsification)."""
        return self.W @ self.separator(cue)
|
||||
|
||||
|
||||
class MultiProbeMemory(nn.Module):
    """Hebbian memory over majority-vote multi-probe binary codes."""

    def __init__(self, input_dim=768, code_dim=8192, k_active=20, num_probes=16):
        super().__init__()
        self.separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.target_separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.k_active = k_active
        # Hebbian association matrix; updated in-place, never by gradient.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Bind the target code to the cue code (outer-product Hebbian update)."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Associative read-out, re-sparsified back to k_active units."""
        readout = self.W @ self.separator(cue)
        return winner_take_all(readout, self.k_active)
|
||||
|
||||
|
||||
class WiderKMemory(nn.Module):
    """Just use wider k — simple and might work."""

    def __init__(self, input_dim=768, code_dim=16384, k_active=200):
        super().__init__()
        self.k_active = k_active
        # Fixed random projections for cue/target pattern separation.
        scale = 1.0 / input_dim ** 0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)
        self.register_buffer('target_proj', torch.randn(input_dim, code_dim) * scale)
        # Hebbian association matrix; updated in-place, never by gradient.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def _encode(self, x, proj):
        # k-WTA binary code from a fixed random projection.
        return winner_take_all(x @ proj, self.k_active)

    def learn(self, cue, target):
        """Bind the target code to the cue code (outer-product Hebbian update)."""
        self.W.data += torch.outer(self._encode(target, self.target_proj),
                                   self._encode(cue, self.proj))

    def recall(self, cue):
        """Associative read-out, re-sparsified back to k_active units."""
        return winner_take_all(self.W @ self._encode(cue, self.proj), self.k_active)

    @property
    def target_separator(self):
        # Sentinel: this memory has no separator module; callers fall back to
        # projecting targets through target_proj directly.
        return None  # handled differently
|
||||
|
||||
|
||||
def _make_pairs(num_pairs, dim=768):
    """Build `num_pairs` random unit-norm (cue, target) embedding pairs on DEVICE."""
    cues = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    return cues, targets


def _eval_noise_sweep(mem, cues, targets, noise_levels, target_code_fn):
    """Learn all (cue, target) pairs, then probe recall under each noise level.

    `target_code_fn` maps a raw target embedding to the reference code that
    recall() is compared against (each memory class encodes targets
    differently). Returns {noise_std: mean cosine similarity} and prints one
    line per noise level.
    """
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    results = {}
    for ns in noise_levels:
        sims = []
        for cue, target in zip(cues, targets):
            noisy = nn.functional.normalize(cue + torch.randn_like(cue) * ns, dim=0)
            recalled = mem.recall(noisy)
            sims.append(cosine(recalled, target_code_fn(target)))
        mc = np.mean(sims)
        print(f" noise={ns:.2f}: CosSim={mc:.4f}")
        results[ns] = mc
    return results


def main():
    """Run all four noise-tolerance approaches, save JSON results, print a summary.

    Previously each of the four sections repeated the same ~30-line
    learn/probe loop verbatim; that loop now lives in _eval_noise_sweep.
    """
    print("=" * 60)
    print("Experiment 2e: Noise-Tolerant Retrieval")
    print("=" * 60)

    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
    num_pairs = 100
    all_results = {}

    # 1. Soft WTA — sweep softmax temperature.
    print("\n=== 1. Soft WTA ===")
    for temp in [0.01, 0.05, 0.1, 0.5]:
        print(f"\n-- temperature={temp} --")
        mem = SoftWTAMemory(temperature=temp).to(DEVICE)
        cues, targets = _make_pairs(num_pairs)
        all_results[f"soft_wta_t{temp}"] = _eval_noise_sweep(
            mem, cues, targets, noise_levels, mem.target_separator)

    # 2. Multi-probe — sweep the number of voting probes.
    print("\n=== 2. Multi-Probe ===")
    for n_probes in [4, 8, 16, 32]:
        print(f"\n-- probes={n_probes} --")
        mem = MultiProbeMemory(num_probes=n_probes).to(DEVICE)
        cues, targets = _make_pairs(num_pairs)
        all_results[f"multiprobe_{n_probes}"] = _eval_noise_sweep(
            mem, cues, targets, noise_levels, mem.target_separator)

    # 3. Coarse-to-fine — NN lookup then exact Hebbian recall.
    print("\n=== 3. Coarse-to-Fine (NN + Hebbian) ===")
    mem = CoarseToFineMemory(768).to(DEVICE)
    cues, targets = _make_pairs(num_pairs)
    all_results["coarse_to_fine"] = _eval_noise_sweep(
        mem, cues, targets, noise_levels,
        lambda t: winner_take_all(t @ mem.target_proj, mem.k_active))

    # 4. Wider k — trade capacity for noise robustness.
    print("\n=== 4. Wider K ===")
    for k in [50, 100, 200, 500, 1000]:
        print(f"\n-- k={k} --")
        mem = WiderKMemory(k_active=k).to(DEVICE)
        cues, targets = _make_pairs(num_pairs)
        # Bind mem/k as defaults: a plain closure would late-bind to the last
        # loop iteration's values.
        all_results[f"wider_k_{k}"] = _eval_noise_sweep(
            mem, cues, targets, noise_levels,
            lambda t, m=mem, kk=k: winner_take_all(t @ m.target_proj, kk))

    # Save raw numbers for the write-up (JSON needs string keys, float values).
    serializable = {method: {str(ns): float(v) for ns, v in res.items()}
                    for method, res in all_results.items()}
    with open(RESULTS_DIR / "exp02e_results.json", "w") as f:
        json.dump(serializable, f, indent=2)

    # Summary table
    print("\n" + "=" * 80)
    print("SUMMARY: CosSim at each noise level")
    print(f"{'Method':<25}", end="")
    for ns in noise_levels:
        print(f" σ={ns:.2f}", end="")
    print()
    print("-" * 80)
    for method, res in all_results.items():
        print(f"{method:<25}", end="")
        for ns in noise_levels:
            print(f" {res.get(ns, 0):>5.3f}", end="")
        print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user