Files
nuonuo/experiments/exp02e_noise_tolerance.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

369 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Experiment 2e: Noise-tolerant retrieval.
Problem: WTA pattern separation is brittle to noise in cue embeddings.
Real use case requires retrieving from semantically similar (not identical) cues.
Approaches to test:
1. Soft-WTA: Use softmax temperature instead of hard top-k
2. Multi-probe: Multiple noisy retrievals + voting
3. Coarse-to-fine: Nearest-neighbor in embedding space → exact Hebbian recall
4. Learned similarity-preserving hash: train the separator to be noise-robust
5. Wider k: trade capacity for noise robustness
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
DEVICE = "cuda"
RESULTS_DIR = Path(__file__).parent.parent / "doc"
def cosine(a, b):
    """Scalar cosine similarity of two 1-D tensors; zero vectors give 0.0."""
    # cosine_similarity clamps the denominator with eps; short-circuit zero
    # vectors explicitly so they report 0.0 rather than a near-zero artifact.
    if min(a.norm().item(), b.norm().item()) == 0.0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def winner_take_all(x, k):
    """Binarize x along the last dim: 1.0 at the k largest entries, 0.0 elsewhere.

    Works on both 1-D codes and batched [..., code_dim] inputs.
    """
    mask = torch.zeros_like(x)
    winners = torch.topk(x, k, dim=-1).indices
    mask.scatter_(-1, winners, 1.0)
    return mask
class SoftWTASeparator(nn.Module):
    """Soft winner-take-all pattern separator.

    Projects the input through a fixed random matrix and applies a
    temperature-scaled softmax, producing a soft sparse code instead of a
    hard binary one. More robust to noise, at some cost in discrimination.
    """
    def __init__(self, input_dim, code_dim, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Fixed (non-trainable) random projection, 1/sqrt(d) scaled.
        self.register_buffer(
            'proj', torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5))

    def forward(self, x):
        logits = x @ self.proj
        # Low temperature -> sharper (sparser) code; high -> more spread out.
        return torch.softmax(logits / self.temperature, dim=-1)
class MultiProbeSeparator(nn.Module):
    """Multiple random projections; retrieve from all and majority-vote.

    Each probe yields a hard k-active WTA code; an output unit is active
    only when a strict majority of probes agree on it.
    """
    def __init__(self, input_dim, code_dim, k_active, num_probes=8):
        super().__init__()
        self.k_active = k_active
        self.num_probes = num_probes
        # One independent random projection per probe: [P, input_dim, code_dim].
        projs = torch.randn(num_probes, input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('projs', projs)

    def forward(self, x):
        """Return the majority-vote binary code across all probes.

        Vectorized over probes: a 1-D x matmul'd with the [P, D, C] buffer
        broadcasts to [P, C], and winner_take_all binarizes each row along
        the last dim — same result as the previous per-probe Python loop,
        in a single batched matmul.
        """
        h = x @ self.projs                            # [num_probes, code_dim]
        codes = winner_take_all(h, self.k_active)     # per-probe binary codes
        votes = codes.sum(dim=0)                      # per-unit vote counts
        # Active only with a strict majority of probes (ties excluded).
        return (votes > self.num_probes / 2).float()
class CoarseToFineMemory(nn.Module):
    """Coarse-to-fine associative memory.

    Coarse: nearest-neighbor search over raw stored cue embeddings.
    Fine: exact Hebbian recall keyed by the nearest stored cue.

    The most practical approach: the Hebbian W stores the associations,
    while retrieval is bootstrapped by embedding-space similarity.
    """
    def __init__(self, input_dim, code_dim=16384, k_active=20):
        super().__init__()
        self.code_dim = code_dim
        self.k_active = k_active
        # FIX: buffers/parameters were previously created with a hard-coded
        # device=DEVICE ("cuda"), which crashed on CPU-only machines and was
        # redundant — callers move the whole module with .to(device).
        scale = 1.0 / input_dim ** 0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)
        self.register_buffer('target_proj', torch.randn(input_dim, code_dim) * scale)
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)
        # Raw cue embeddings kept for the coarse nearest-neighbor lookup.
        # NOTE: a plain list is not moved by Module.to(); cues stay on the
        # device they were passed in on.
        self.cue_store = []

    def separate(self, x, proj):
        """Project x through proj and binarize to a k-active sparse code."""
        return winner_take_all(x @ proj, self.k_active)

    def learn(self, cue, target):
        """Store the raw cue and bind cue code -> target code in W (Hebbian outer product)."""
        self.cue_store.append(cue.detach().clone())
        cue_code = self.separate(cue, self.proj)
        target_code = self.separate(target, self.target_proj)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, query):
        """Coarse NN over stored cues, then exact Hebbian recall from the winner.

        Returns the all-zero code when nothing has been stored yet.
        """
        if not self.cue_store:
            # Use the module's own device rather than the global DEVICE.
            return torch.zeros(self.code_dim, device=self.proj.device)
        cue_matrix = torch.stack(self.cue_store)                  # [N, dim]
        sims = nn.functional.cosine_similarity(
            query.unsqueeze(0), cue_matrix, dim=-1)               # [N]
        # .item() makes the list index an int explicitly (previously a 0-dim tensor).
        best_cue = self.cue_store[int(sims.argmax().item())]
        # Exact Hebbian recall with the nearest stored cue.
        cue_code = self.separate(best_cue, self.proj)
        return winner_take_all(self.W @ cue_code, self.k_active)
def test_approach(name, mem_class, num_pairs=100, noise_levels=None, **kwargs):
    """Generic noise-robustness harness.

    Stores num_pairs random (cue, target) associations in a memory built
    from mem_class (either a class, instantiated here with **kwargs, or an
    already-built nn.Module instance), then measures recall quality under
    Gaussian cue noise at each level in noise_levels.

    Returns {noise_std: {"mean_cos": ..., "exact_rate": ...}}.
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]
    input_dim = 768
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]
    # Accept either a memory class or a ready instance.
    mem = mem_class if isinstance(mem_class, nn.Module) else mem_class(**kwargs).to(DEVICE)
    # Learn all pairs.
    for i in range(num_pairs):
        mem.learn(cues[i], targets[i])
    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        for i in range(num_pairs):
            noisy_cue = cues[i] + torch.randn_like(cues[i]) * noise_std
            noisy_cue = nn.functional.normalize(noisy_cue, dim=0)
            recalled = mem.recall(noisy_cue)
            # Pick the reference target code for this memory type.
            # BUG FIX: this used hasattr(), which treated WiderKMemory's
            # None-returning target_separator property as present and then
            # called None(...), raising TypeError. getattr + None check
            # falls through to the target_proj branch instead.
            target_separator = getattr(mem, 'target_separator', None)
            if target_separator is not None:
                target_code = target_separator(targets[i])
            elif hasattr(mem, 'target_proj'):
                target_code = winner_take_all(targets[i] @ mem.target_proj, mem.k_active)
            else:
                target_code = targets[i]
            correct_sims.append(cosine(recalled, target_code))
        mc = np.mean(correct_sims)
        exact = np.mean([s > 0.99 for s in correct_sims])
        results[noise_std] = {"mean_cos": mc, "exact_rate": exact}
        print(f" {name}: noise={noise_std:.2f} → CosSim={mc:.4f}, Exact={exact:.2%}")
    return results
# --- Approach-specific memory classes ---
class SoftWTAMemory(nn.Module):
    """Hebbian associative memory over soft-WTA codes for cue and target."""
    def __init__(self, input_dim=768, code_dim=16384, temperature=0.1):
        super().__init__()
        self.separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.target_separator = SoftWTASeparator(input_dim, code_dim, temperature)
        # Non-trainable association matrix, updated Hebbian-style in learn().
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Accumulate the outer product target_code x cue_code into W."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Soft-code the cue and read out through W (no re-binarization)."""
        return self.W @ self.separator(cue)
class MultiProbeMemory(nn.Module):
    """Hebbian associative memory whose codes come from multi-probe voting."""
    def __init__(self, input_dim=768, code_dim=8192, k_active=20, num_probes=16):
        super().__init__()
        self.separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.target_separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.k_active = k_active
        # Non-trainable association matrix, updated Hebbian-style in learn().
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Bind the voted cue code to the voted target code in W."""
        cue_code = self.separator(cue)
        target_code = self.target_separator(target)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, cue):
        """Read out through W, then re-sparsify to exactly k_active units."""
        readout = self.W @ self.separator(cue)
        return winner_take_all(readout, self.k_active)
class WiderKMemory(nn.Module):
    """Just use wider k — trade capacity for noise robustness."""
    def __init__(self, input_dim=768, code_dim=16384, k_active=200):
        super().__init__()
        self.k_active = k_active
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim ** 0.5)
        self.register_buffer('target_proj', target_proj)
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Hebbian binding of wide-k cue code and target code."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        tc = winner_take_all(target @ self.target_proj, self.k_active)
        self.W.data += torch.outer(tc, cc)

    def recall(self, cue):
        """Code the cue, read out through W, re-binarize to k_active units."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        raw = self.W @ cc
        return winner_take_all(raw, self.k_active)

    @property
    def target_separator(self):
        """Callable producing the reference target code for this memory.

        BUG FIX: this previously returned None, but the generic harness
        checks hasattr(mem, 'target_separator') — which is True for a
        property returning None — and then calls it, crashing with
        TypeError. Returning a real callable keeps that dispatch working
        and produces the same code as the target_proj fallback branch.
        """
        return lambda t: winner_take_all(t @ self.target_proj, self.k_active)
def _random_unit_batch(n, dim=768):
    """Return n random unit vectors (L2-normalized Gaussian) on DEVICE."""
    return [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(n)]


def _eval_noise_grid(mem, cues, targets, noise_levels, target_code_fn):
    """Train mem on (cue, target) pairs, then score recall under cue noise.

    target_code_fn(mem, target) produces the reference code a perfect
    recall should match. Prints per-level CosSim and returns
    {noise_std: mean_cos}.
    """
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)
    results = {}
    for ns in noise_levels:
        sims = []
        for cue, target in zip(cues, targets):
            noisy = nn.functional.normalize(cue + torch.randn_like(cue) * ns, dim=0)
            recalled = mem.recall(noisy)
            sims.append(cosine(recalled, target_code_fn(mem, target)))
        mc = np.mean(sims)
        print(f" noise={ns:.2f}: CosSim={mc:.4f}")
        results[ns] = mc
    return results


def main():
    """Run all four noise-tolerance approaches, save JSON results, print a summary."""
    print("=" * 60)
    print("Experiment 2e: Noise-Tolerant Retrieval")
    print("=" * 60)
    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
    num_pairs = 100
    all_results = {}

    # Reference-code functions shared by the sections below.
    def soft_target(m, t):
        # Memories exposing a target_separator module/callable.
        return m.target_separator(t)

    def hard_target(m, t):
        # Memories exposing a raw target projection + k_active.
        return winner_take_all(t @ m.target_proj, m.k_active)

    # 1. Soft WTA
    print("\n=== 1. Soft WTA ===")
    for temp in [0.01, 0.05, 0.1, 0.5]:
        print(f"\n-- temperature={temp} --")
        mem = SoftWTAMemory(temperature=temp).to(DEVICE)
        all_results[f"soft_wta_t{temp}"] = _eval_noise_grid(
            mem, _random_unit_batch(num_pairs), _random_unit_batch(num_pairs),
            noise_levels, soft_target)

    # 2. Multi-probe
    print("\n=== 2. Multi-Probe ===")
    for n_probes in [4, 8, 16, 32]:
        print(f"\n-- probes={n_probes} --")
        mem = MultiProbeMemory(num_probes=n_probes).to(DEVICE)
        all_results[f"multiprobe_{n_probes}"] = _eval_noise_grid(
            mem, _random_unit_batch(num_pairs), _random_unit_batch(num_pairs),
            noise_levels, soft_target)

    # 3. Coarse-to-fine
    print("\n=== 3. Coarse-to-Fine (NN + Hebbian) ===")
    mem = CoarseToFineMemory(768).to(DEVICE)
    all_results["coarse_to_fine"] = _eval_noise_grid(
        mem, _random_unit_batch(num_pairs), _random_unit_batch(num_pairs),
        noise_levels, hard_target)

    # 4. Wider k
    print("\n=== 4. Wider K ===")
    for k in [50, 100, 200, 500, 1000]:
        print(f"\n-- k={k} --")
        mem = WiderKMemory(k_active=k).to(DEVICE)
        all_results[f"wider_k_{k}"] = _eval_noise_grid(
            mem, _random_unit_batch(num_pairs), _random_unit_batch(num_pairs),
            noise_levels, hard_target)

    # Persist raw numbers for the writeup (keys stringified, values plain floats).
    serializable = {name: {str(ns): float(v) for ns, v in res.items()}
                    for name, res in all_results.items()}
    with open(RESULTS_DIR / "exp02e_results.json", "w") as f:
        json.dump(serializable, f, indent=2)

    # Summary table
    print("\n" + "=" * 80)
    print("SUMMARY: CosSim at each noise level")
    print(f"{'Method':<25}", end="")
    for ns in noise_levels:
        print(f" σ={ns:.2f}", end="")
    print()
    print("-" * 80)
    for method, res in all_results.items():
        print(f"{method:<25}", end="")
        for ns in noise_levels:
            print(f" {res.get(ns, 0):>5.3f}", end="")
        print()
# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()