"""Experiment 2e: Noise-tolerant retrieval.

Problem: WTA pattern separation is brittle to noise in cue embeddings.
Real use case requires retrieving from semantically similar (not identical)
cues.

Approaches to test:
1. Soft-WTA: Use softmax temperature instead of hard top-k
2. Multi-probe: Multiple noisy retrievals + voting
3. Coarse-to-fine: Nearest-neighbor in embedding space → exact Hebbian recall
4. Learned similarity-preserving hash: train the separator to be noise-robust
   (not implemented in this script; ``main`` runs approaches 1, 2, 3 and 5)
5. Wider k: trade capacity for noise robustness

Results (mean cosine similarity per method and noise level) are written to
``../doc/exp02e_results.json`` relative to this script's directory and
summarized on stdout.
"""
import sys
import time
import json
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np

# Fall back to CPU so the module can be imported and its classes used on
# machines without a GPU; on CUDA machines this is still "cuda".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESULTS_DIR = Path(__file__).parent.parent / "doc"


def cosine(a, b):
    """Return the cosine similarity of tensors ``a`` and ``b`` as a float.

    Returns 0.0 when either vector has zero norm (e.g. an all-zero code).
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


def winner_take_all(x, k):
    """Return a binary code with 1.0 at the ``k`` largest entries of ``x``.

    Top-k is taken along the last dimension; all other entries are 0.0.
    """
    _, topk_idx = x.topk(k, dim=-1)
    out = torch.zeros_like(x)
    out.scatter_(-1, topk_idx, 1.0)
    return out


class SoftWTASeparator(nn.Module):
    """Soft winner-take-all using temperature-scaled softmax.

    Instead of hard binary codes, produces soft sparse codes.
    More robust to noise but reduces discrimination.
    """

    def __init__(self, input_dim, code_dim, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)

    def forward(self, x):
        """Project ``x`` and return ``softmax(x @ proj / temperature)``."""
        h = x @ self.proj
        # Soft WTA: high temp → more spread, low temp → more sparse
        return torch.softmax(h / self.temperature, dim=-1)


class MultiProbeSeparator(nn.Module):
    """Multiple random projections, retrieve from all, majority vote.

    NOTE(review): each probe uses an *independent* projection, so unit j of
    one probe is unrelated to unit j of another. With k_active << code_dim
    their top-k sets almost never overlap, and the strict-majority threshold
    in ``forward`` then yields an all-zero code for ``num_probes >= 2``.
    MultiProbeMemory consequently stores nothing and reports CosSim 0.0.
    Voting is only meaningful over correlated probes (e.g. noisy copies of
    one cue through one projection) — TODO decide on the intended design.
    """

    def __init__(self, input_dim, code_dim, k_active, num_probes=8):
        super().__init__()
        self.k_active = k_active
        self.num_probes = num_probes
        # Multiple random projections: [num_probes, input_dim, code_dim]
        projs = torch.randn(num_probes, input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('projs', projs)

    def forward(self, x):
        """Return the binary code of units active in more than half the probes.

        ``x`` is assumed to be a single 1-D embedding (the vote buffer is 1-D).
        """
        votes = torch.zeros(self.projs.shape[2], device=x.device)
        for i in range(self.num_probes):
            h = x @ self.projs[i]
            code = winner_take_all(h, self.k_active)
            votes += code
        # Threshold: active if majority of probes agree
        threshold = self.num_probes / 2
        return (votes > threshold).float()


class CoarseToFineMemory(nn.Module):
    """Coarse: nearest-neighbor in embedding space.
    Fine: exact Hebbian recall from nearest stored cue.

    This is the most practical approach: SNN/Hebbian provides the
    association storage, but retrieval is bootstrapped by embedding
    similarity.
    """

    def __init__(self, input_dim, code_dim=16384, k_active=20):
        super().__init__()
        self.code_dim = code_dim
        self.k_active = k_active
        proj = torch.randn(input_dim, code_dim, device=DEVICE) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim, device=DEVICE) * (1.0 / input_dim**0.5)
        self.register_buffer('target_proj', target_proj)
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim, device=DEVICE),
                              requires_grad=False)
        # Raw cue embeddings for the nearest-neighbor step. A plain list, so it
        # is neither moved by .to() nor saved in state_dict().
        self.cue_store = []

    def separate(self, x, proj):
        """Return the k_active-hot WTA code of ``x @ proj``."""
        h = x @ proj
        return winner_take_all(h, self.k_active)

    def learn(self, cue, target):
        """Remember ``cue`` and add the Hebbian outer product target⊗cue to W."""
        self.cue_store.append(cue.detach().clone())
        cue_code = self.separate(cue, self.proj)
        target_code = self.separate(target, self.target_proj)
        self.W.data += torch.outer(target_code, cue_code)

    def recall(self, query):
        """Coarse: find nearest stored cue. Fine: Hebbian recall.

        Returns an all-zero code of length code_dim if nothing was learned.
        """
        if not self.cue_store:
            return torch.zeros(self.code_dim, device=DEVICE)
        # Nearest neighbor by cosine similarity over all stored cues
        cue_matrix = torch.stack(self.cue_store)  # [N, dim]
        sims = nn.functional.cosine_similarity(
            query.unsqueeze(0), cue_matrix, dim=-1)  # [N]
        best_cue = self.cue_store[int(sims.argmax())]
        # Exact Hebbian recall with nearest cue (not the noisy query)
        cue_code = self.separate(best_cue, self.proj)
        raw = self.W @ cue_code
        return winner_take_all(raw, self.k_active)


def test_approach(name, mem_class, num_pairs=100, noise_levels=None, **kwargs):
    """Generic test harness: store random pairs, then recall from noisy cues.

    Args:
        name: label used in the printed progress lines.
        mem_class: a memory class (built as ``mem_class(**kwargs)`` and moved
            to DEVICE) or an already-built ``nn.Module`` instance (used as-is;
            kwargs are then only consulted for ``input_dim``).
        num_pairs: number of random unit-norm (cue, target) pairs to store.
        noise_levels: Gaussian noise stds added to each cue before recall; the
            noisy cue is re-normalized. Defaults to [0, .1, .2, .5, 1, 2].
        **kwargs: constructor arguments. ``input_dim`` (default 768) also sets
            the dimensionality of the generated cues and targets.

    Returns:
        {noise_std: {"mean_cos": ..., "exact_rate": ...}} where exact_rate is
        the fraction of recalls with cosine > 0.99 to the target code.
    """
    if noise_levels is None:
        noise_levels = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]
    input_dim = kwargs.get("input_dim", 768)
    cues = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
            for _ in range(num_pairs)]
    targets = [nn.functional.normalize(torch.randn(input_dim, device=DEVICE), dim=0)
               for _ in range(num_pairs)]

    if isinstance(mem_class, nn.Module):
        mem = mem_class
    else:
        mem = mem_class(**kwargs).to(DEVICE)

    # Choose how the reference target code is computed. WiderKMemory exposes a
    # `target_separator` property that is None, so mere presence of the
    # attribute is not enough — it has to be callable.
    target_separator = getattr(mem, "target_separator", None)
    if callable(target_separator):
        target_code_fn = target_separator
    elif hasattr(mem, "target_proj"):
        def target_code_fn(t):
            return winner_take_all(t @ mem.target_proj, mem.k_active)
    else:
        def target_code_fn(t):
            return t

    # Learn
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    results = {}
    for noise_std in noise_levels:
        correct_sims = []
        for cue, target in zip(cues, targets):
            noisy_cue = cue + torch.randn_like(cue) * noise_std
            noisy_cue = nn.functional.normalize(noisy_cue, dim=0)
            recalled = mem.recall(noisy_cue)
            correct_sims.append(cosine(recalled, target_code_fn(target)))
        mc = np.mean(correct_sims)
        exact = np.mean([s > 0.99 for s in correct_sims])
        results[noise_std] = {"mean_cos": mc, "exact_rate": exact}
        print(f" {name}: noise={noise_std:.2f} → CosSim={mc:.4f}, Exact={exact:.2%}")
    return results


# The name starts with "test_" but this is an experiment harness, not a pytest
# test; stop pytest from collecting it (it would fail looking for fixtures).
test_approach.__test__ = False


# --- Approach-specific memory classes ---

class SoftWTAMemory(nn.Module):
    """Hebbian memory over soft (softmax) cue/target codes; recall is unthresholded."""

    def __init__(self, input_dim=768, code_dim=16384, temperature=0.1):
        super().__init__()
        self.separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.target_separator = SoftWTASeparator(input_dim, code_dim, temperature)
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Add the outer product of the soft target and cue codes to W."""
        cc = self.separator(cue)
        tc = self.target_separator(target)
        self.W.data += torch.outer(tc, cc)

    def recall(self, cue):
        """Return the raw (non-WTA) readout ``W @ code(cue)``."""
        cc = self.separator(cue)
        return self.W @ cc


class MultiProbeMemory(nn.Module):
    """Hebbian memory over MultiProbeSeparator codes, k_active-WTA on recall.

    NOTE(review): see MultiProbeSeparator — with num_probes >= 2 its codes are
    almost always all-zero, so this memory effectively stores nothing.
    """

    def __init__(self, input_dim=768, code_dim=8192, k_active=20, num_probes=16):
        super().__init__()
        self.separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.target_separator = MultiProbeSeparator(input_dim, code_dim, k_active, num_probes)
        self.k_active = k_active
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Add the outer product of the target and cue vote codes to W."""
        cc = self.separator(cue)
        tc = self.target_separator(target)
        self.W.data += torch.outer(tc, cc)

    def recall(self, cue):
        """Return the k_active-WTA of ``W @ code(cue)``."""
        cc = self.separator(cue)
        raw = self.W @ cc
        return winner_take_all(raw, self.k_active)


class WiderKMemory(nn.Module):
    """Just use wider k — simple and might work."""

    def __init__(self, input_dim=768, code_dim=16384, k_active=200):
        super().__init__()
        self.k_active = k_active
        proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', proj)
        target_proj = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('target_proj', target_proj)
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim), requires_grad=False)

    def learn(self, cue, target):
        """Add the outer product of the k-hot target and cue codes to W."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        tc = winner_take_all(target @ self.target_proj, self.k_active)
        self.W.data += torch.outer(tc, cc)

    def recall(self, cue):
        """Return the k_active-WTA of ``W @ code(cue)``."""
        cc = winner_take_all(cue @ self.proj, self.k_active)
        raw = self.W @ cc
        return winner_take_all(raw, self.k_active)

    @property
    def target_separator(self):
        """Always None: target codes come from ``target_proj`` + WTA instead.

        Callers must check for None before calling this.
        """
        return None  # handled differently


def _random_unit_vectors(num, dim=768):
    """Return ``num`` random unit-norm vectors of length ``dim`` on DEVICE."""
    return [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
            for _ in range(num)]


def _run_trial(mem, num_pairs, noise_levels, target_code_fn, input_dim=768):
    """Store random (cue, target) pairs in ``mem`` and sweep cue noise.

    For each noise level every stored cue is perturbed with Gaussian noise of
    that std, re-normalized and recalled. The mean cosine similarity between
    the recall and ``target_code_fn(target)`` is printed and returned as
    {noise_std: mean_cos}.
    """
    cues = _random_unit_vectors(num_pairs, input_dim)
    targets = _random_unit_vectors(num_pairs, input_dim)
    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    results = {}
    for ns in noise_levels:
        sims = []
        for cue, target in zip(cues, targets):
            noisy = nn.functional.normalize(cue + torch.randn_like(cue) * ns, dim=0)
            recalled = mem.recall(noisy)
            sims.append(cosine(recalled, target_code_fn(target)))
        mc = np.mean(sims)
        print(f" noise={ns:.2f}: CosSim={mc:.4f}")
        results[ns] = mc
    return results


def main():
    """Run all approaches over a noise sweep, save JSON results, print a table."""
    print("=" * 60)
    print("Experiment 2e: Noise-Tolerant Retrieval")
    print("=" * 60)

    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.5, 1.0]
    num_pairs = 100
    all_results = {}

    # Each memory below holds a (code_dim x code_dim) W; `del mem` after each
    # trial frees it before the next one is allocated, so two never coexist.

    # 1. Soft WTA
    print("\n=== 1. Soft WTA ===")
    for temp in [0.01, 0.05, 0.1, 0.5]:
        print(f"\n-- temperature={temp} --")
        mem = SoftWTAMemory(temperature=temp).to(DEVICE)
        all_results[f"soft_wta_t{temp}"] = _run_trial(
            mem, num_pairs, noise_levels, mem.target_separator)
        del mem

    # 2. Multi-probe
    print("\n=== 2. Multi-Probe ===")
    for n_probes in [4, 8, 16, 32]:
        print(f"\n-- probes={n_probes} --")
        mem = MultiProbeMemory(num_probes=n_probes).to(DEVICE)
        all_results[f"multiprobe_{n_probes}"] = _run_trial(
            mem, num_pairs, noise_levels, mem.target_separator)
        del mem

    # 3. Coarse-to-fine
    print("\n=== 3. Coarse-to-Fine (NN + Hebbian) ===")
    mem = CoarseToFineMemory(768).to(DEVICE)
    all_results["coarse_to_fine"] = _run_trial(
        mem, num_pairs, noise_levels,
        lambda t, proj=mem.target_proj, k_act=mem.k_active: winner_take_all(t @ proj, k_act))
    del mem

    # 4. Wider k
    print("\n=== 4. Wider K ===")
    for k in [50, 100, 200, 500, 1000]:
        print(f"\n-- k={k} --")
        mem = WiderKMemory(k_active=k).to(DEVICE)
        all_results[f"wider_k_{k}"] = _run_trial(
            mem, num_pairs, noise_levels,
            lambda t, proj=mem.target_proj, k_act=k: winner_take_all(t @ proj, k_act))
        del mem

    # Save (JSON keys must be strings; numpy floats must become Python floats)
    serializable = {
        method: {str(ns): float(v) for ns, v in res.items()}
        for method, res in all_results.items()
    }
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    with open(RESULTS_DIR / "exp02e_results.json", "w", encoding="utf-8") as f:
        json.dump(serializable, f, indent=2)

    # Summary table
    print("\n" + "=" * 80)
    print("SUMMARY: CosSim at each noise level")
    print(f"{'Method':<25}", end="")
    for ns in noise_levels:
        print(f" σ={ns:.2f}", end="")
    print()
    print("-" * 80)
    for method, res in all_results.items():
        print(f"{method:<25}", end="")
        for ns in noise_levels:
            v = res.get(ns, 0)
            # Width 6 matches the 6-char " σ=x.xx" headers and fits negatives.
            print(f" {v:>6.3f}", end="")
        print()


if __name__ == "__main__":
    main()