Files
nuonuo/experiments/exp06_biohash.py
Fam Zheng d923aa1e31 NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs.
Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
2026-04-07 10:37:24 +01:00

382 lines
14 KiB
Python

"""Experiment 6: BioHash — Learnable Fly Algorithm.
Replace random projection with learned projection trained via contrastive loss
on real sentence embeddings. The key insight from Dasgupta 2017 (Science):
random projection + WTA already preserves neighborhoods. Learning the projection
should make it even better.
Training objective:
- Positive pairs (similar sentences): maximize Jaccard overlap of sparse codes
- Negative pairs (different sentences): minimize overlap
Since WTA is not differentiable, we use a soft relaxation during training
(Gumbel-softmax or straight-through estimator) and hard WTA at test time.
"""
import sys
import time
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
DEVICE = "cuda"
RESULTS_DIR = Path(__file__).parent.parent / "doc"
def winner_take_all(x, k):
    """Binarize *x* by keeping only the top-k activations along the last dim.

    Returns a tensor shaped like ``x`` with 1.0 at the positions of the k
    largest values and 0.0 everywhere else (the Kenyon-cell WTA step).
    """
    winners = x.topk(k, dim=-1).indices
    code = torch.zeros_like(x)
    return code.scatter_(-1, winners, 1.0)
def jaccard(a, b):
    """Mean Jaccard similarity between batches of binary vectors.

    Computed along the last dimension, then averaged over the batch.
    A zero union is clamped to 1 so two all-zero codes score 0.0.
    """
    both = (a * b).sum(dim=-1)
    either = ((a + b) > 0).float().sum(dim=-1)
    ratio = both / torch.clamp(either, min=1)
    return ratio.mean().item()
def soft_topk(x, k, temperature=1.0):
    """Straight-through top-k: hard WTA forward, softmax gradients backward.

    The returned *value* equals ``winner_take_all(x, k)`` exactly (the
    relaxed term cancels itself out), while gradients flow through the
    scaled-softmax relaxation, since the hard code carries no grad.
    """
    binary = winner_take_all(x, k)
    # Softmax scaled by k so the relaxed code also sums to k per row.
    relaxed = k * torch.softmax(x / temperature, dim=-1)
    return binary + relaxed - relaxed.detach()
class BioHash(nn.Module):
    """Learnable Fly Hash with WTA sparsification.

    Mirrors the fruit-fly olfactory circuit:
    - Projection neurons (PN): input -> high-dim expansion (learned here,
      replacing the classic random projection)
    - Kenyon cells (KC): top-k winner-take-all -> sparse binary code
    """

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        self.code_dim = code_dim
        # Learned expansion layer standing in for the fly's random matrix.
        self.proj = nn.Linear(input_dim, code_dim, bias=False)
        # Start at the same weight scale as a random fly projection.
        nn.init.normal_(self.proj.weight, std=input_dim ** -0.5)

    def forward(self, x, soft=False, temperature=1.0):
        """Map [batch, input_dim] embeddings to [batch, code_dim] sparse codes.

        With ``soft=True`` the straight-through relaxation is returned so
        the projection can be trained; otherwise hard WTA is applied.
        """
        activations = self.proj(x)
        if soft:
            return soft_topk(activations, self.k, temperature)
        return winner_take_all(activations, self.k)

    def encode_hard(self, x):
        """Inference-time encoding: hard WTA with no autograd graph."""
        with torch.no_grad():
            return winner_take_all(self.proj(x), self.k)
class RandomFlyHash(nn.Module):
    """Baseline: the original (untrained) random Fly algorithm."""

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        # Fixed random projection; a buffer so it follows .to(device)
        # but is never updated by an optimizer.
        scale = 1.0 / input_dim ** 0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)

    def encode_hard(self, x):
        """Project and sparsify with hard WTA (no gradients)."""
        with torch.no_grad():
            return winner_take_all(x @ self.proj, self.k)
def generate_training_data(model, n_pairs=5000, noise_std=0.3):
    """Encode a pool of synthetic sentences for contrastive training.

    Builds template x subject x modifier sentences, shuffles them, and
    returns their normalized embeddings on DEVICE.

    Args:
        model: sentence encoder exposing ``encode`` (SentenceTransformer-like).
        n_pairs: number of contrastive pairs; at most ``2 * n_pairs``
            sentences are encoded.
        noise_std: UNUSED — kept only for interface compatibility. The
            "paraphrase" positive is created by adding noise in the caller
            (see ``train_biohash``), not in this function.

    Returns:
        [n, embed_dim] tensor of L2-normalized sentence embeddings.
    """
    templates = [
        "The {} is having {} issues",
        "We need to {} the {} system",
        "The {} team is working on {}",
        "There's a bug in the {} {}",
        "Let's deploy {} to {}",
        "The {} performance is {}",
        "How do I configure {}?",
        "The {} logs show {}",
        "We should monitor the {} {}",
        "The {} needs {} upgrade",
    ]
    subjects = ["database", "API", "server", "frontend", "backend",
                "auth", "cache", "queue", "storage", "network",
                "deployment", "monitoring", "logging", "testing", "CI/CD"]
    modifiers = ["critical", "minor", "performance", "security", "timeout",
                 "memory", "disk", "CPU", "latency", "throughput"]
    # Cartesian product of all templates/subjects/modifiers
    # (str.format ignores surplus args for single-slot templates).
    sentences = [t.format(s, m)
                 for t in templates for s in subjects for m in modifiers]
    np.random.shuffle(sentences)
    sentences = sentences[:n_pairs * 2]  # enough raw material for pairs
    return model.encode(sentences, convert_to_tensor=True,
                        normalize_embeddings=True, device=DEVICE,
                        batch_size=256)
def train_biohash(model, code_dim=16384, k=50, epochs=100, batch_size=256,
                  lr=1e-3, noise_std=0.3, margin=0.2):
    """Train BioHash with contrastive loss on sentence embeddings.

    Each epoch draws one random batch: anchors, noisy renormalized copies of
    the anchors as positives (simulated paraphrases), and independently
    sampled embeddings as negatives. The loss maximizes anchor/positive code
    overlap and penalizes anchor/negative overlap beyond ``margin``.
    Gradients flow through the straight-through soft-WTA; the periodic
    evaluation uses hard WTA (the inference-time code).

    Args:
        model: sentence encoder (SentenceTransformer-like).
        code_dim: KC layer width (sparse code dimensionality).
        k: number of active units per code.
        epochs: training iterations (one batch per iteration).
        batch_size: contrastive batch size.
        lr: Adam learning rate, cosine-annealed over ``epochs``.
        noise_std: std of the Gaussian noise that creates the positive pair.
        margin: negatives with overlap below this are not penalized.

    Returns:
        The trained BioHash module (on DEVICE).
    """
    embed_dim = model.get_sentence_embedding_dimension()
    hasher = BioHash(embed_dim, code_dim, k).to(DEVICE)
    optimizer = optim.Adam(hasher.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    print(f"Training BioHash: code={code_dim}, k={k}, noise={noise_std}")
    # Generate training embeddings (noise_std is applied below, not inside)
    embs = generate_training_data(model, n_pairs=5000)
    for epoch in range(epochs):
        # Sample batch
        idx = torch.randperm(embs.shape[0])[:batch_size]
        anchor = embs[idx]
        # Positive: add noise (simulate paraphrase), renormalize to unit length
        pos = nn.functional.normalize(
            anchor + torch.randn_like(anchor) * noise_std, dim=-1)
        # Negative: random different embeddings (a negative may rarely
        # coincide with its anchor; acceptable at this pool/batch size)
        neg_idx = torch.randperm(embs.shape[0])[:batch_size]
        neg = embs[neg_idx]
        # Forward with STE (value is the hard code, gradient is soft)
        code_anchor = hasher(anchor, soft=True, temperature=0.5)
        code_pos = hasher(pos, soft=True, temperature=0.5)
        code_neg = hasher(neg, soft=True, temperature=0.5)
        # Jaccard-like loss (differentiable via STE); overlaps are
        # normalized by k so they live in [0, 1]
        # Positive overlap: maximize
        pos_overlap = (code_anchor * code_pos).sum(dim=-1) / k
        # Negative overlap: minimize (with margin)
        neg_overlap = (code_anchor * code_neg).sum(dim=-1) / k
        loss = -pos_overlap.mean() + torch.relu(neg_overlap - margin).mean()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(hasher.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        if (epoch + 1) % 20 == 0:
            # Eval with hard WTA — the code actually used at inference
            with torch.no_grad():
                h_anchor = hasher.encode_hard(anchor)
                h_pos = hasher.encode_hard(pos)
                h_neg = hasher.encode_hard(neg)
                j_pos = jaccard(h_anchor, h_pos)
                j_neg = jaccard(h_anchor, h_neg)
            print(f" Epoch {epoch+1}: loss={loss.item():.4f}, "
                  f"Jaccard_pos={j_pos:.4f}, Jaccard_neg={j_neg:.4f}, "
                  f"gap={j_pos-j_neg:.4f}")
    return hasher
def evaluate_recall(hasher, model, label=""):
    """Test associative recall with this hasher.

    Stores 10 cue->target pairs in a Hebbian outer-product memory over the
    hasher's sparse codes, then measures:
    - exact recall: querying with the original cue codes
    - paraphrase recall: querying with reworded cues (index-aligned)
    Also reports mean code overlap between each cue and its own paraphrase
    (positive) versus the next cue's paraphrase (negative).

    Returns (exact_accuracy, paraphrase_accuracy, mean_positive_overlap).
    """
    # Memory pairs
    pairs = [
        ("What's the weather like today?", "User prefers to check weather every morning"),
        ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table caused slowdown"),
        ("I need to fix the auth bug", "Auth uses JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "Last 500 was OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI cluster"),
        ("The tests are failing", "CI needs postgres service container"),
        ("Memory usage is too high", "Known leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files are too large", "Logs rotate daily, 30 days retention, shipped to Loki"),
    ]
    paraphrases = [
        "How's the weather outside?",
        "We should push the new release",
        "DB performance is terrible",
        "There's a login bug to fix",
        "Getting internal server errors",
        "We need better observability",
        "CI tests keep breaking",
        "The service is using too much RAM",
        "Help me with Docker configuration",
        "Logs are eating up disk space",
    ]
    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    # Encode everything once (previously the cues were encoded twice more
    # just to read off code_dim and k).
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    para_codes = hasher.encode_hard(para_embs)
    k = int(cue_codes[0].sum().item())
    # Hebbian memory: sum of outer products target_i x cue_i, computed as
    # a single matmul instead of a Python loop of rank-1 updates.
    W = target_codes.t() @ cue_codes

    def recall_accuracy(query_codes):
        # Query i is correct when the WTA-cleaned recall WTA(W @ q_i) is
        # most cosine-similar to its own target (argmax over all targets).
        correct = 0
        for i in range(query_codes.shape[0]):
            recalled = winner_take_all(W @ query_codes[i], k)
            sims = nn.functional.cosine_similarity(
                recalled.unsqueeze(0), target_codes, dim=-1)
            if sims.argmax().item() == i:
                correct += 1
        return correct

    exact_correct = recall_accuracy(cue_codes)
    para_correct = recall_accuracy(para_codes)
    # Code overlap analysis: cue vs its own paraphrase (positive) and
    # cue vs the next pair's paraphrase (negative), normalized by k.
    pos_overlaps = []
    neg_overlaps = []
    for i in range(len(pairs)):
        pos_overlaps.append((cue_codes[i] * para_codes[i]).sum().item() / k)
        j = (i + 1) % len(pairs)
        neg_overlaps.append((cue_codes[i] * para_codes[j]).sum().item() / k)
    n = len(pairs)
    print(f" {label}: Exact={exact_correct}/{n}, Para={para_correct}/{n}, "
          f"CodeOverlap: pos={np.mean(pos_overlaps):.3f}, "
          f"neg={np.mean(neg_overlaps):.3f}, "
          f"gap={np.mean(pos_overlaps)-np.mean(neg_overlaps):.3f}")
    return exact_correct / n, para_correct / n, np.mean(pos_overlaps)
def evaluate_at_scale(hasher, model, n_background, label=""):
    """Test with background memories (the real challenge).

    Stores 5 real cue->target pairs plus ``n_background`` synthetic
    distractor pairs in the Hebbian memory, then measures paraphrase
    recall scored against the 5 real targets only.

    Returns the paraphrase recall accuracy in [0, 1].
    """
    pairs = [
        ("The database is slow", "Check missing indexes on users table"),
        ("Deploy to production", "Use blue-green via GitHub Actions"),
        ("Server crashed", "Check logs, likely OOM in Python worker"),
        ("Fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("API returns 500", "OOM in Python worker process"),
    ]
    paraphrases = [
        "DB performance terrible",
        "Push the new release",
        "Server is down",
        "Login bug needs fixing",
        "Getting 500 errors from API",
    ]
    # Background noise
    bg_sentences = [f"Background task {i} about topic {i%20}" for i in range(n_background)]
    bg_targets = [f"Background detail {i} with info {i%10}" for i in range(n_background)]
    all_cues = [p[0] for p in pairs] + bg_sentences
    all_targets = [p[1] for p in pairs] + bg_targets
    cue_embs = model.encode(all_cues, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE, batch_size=256)
    target_embs = model.encode(all_targets, convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE, batch_size=256)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    # Build memory. One matmul replaces n_background + 5 rank-1 updates on
    # a code_dim x code_dim matrix — the loop was O(n * code_dim^2) Python
    # work and dominated runtime at bg=500.
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    k = int(cue_codes[0].sum().item())
    W = target_codes.t() @ cue_codes
    # Test paraphrase recall, scored only against the 5 real targets
    para_codes = hasher.encode_hard(para_embs)
    correct = 0
    for i in range(len(paraphrases)):
        recalled = winner_take_all(W @ para_codes[i], k)
        sims = nn.functional.cosine_similarity(
            recalled.unsqueeze(0), target_codes[:len(pairs)], dim=-1)
        if sims.argmax().item() == i:
            correct += 1
    n = len(paraphrases)
    print(f" {label} (bg={n_background}): Para={correct}/{n} ({correct/n:.0%})")
    return correct / n
def main():
    """Run the full experiment: random baseline, trained BioHash, k sweep."""
    banner = "=" * 60
    print(banner)
    print("Experiment 6: BioHash — Learnable Fly Algorithm")
    print(banner)
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    # Baseline: random projection (current approach)
    print("\n=== Baseline: Random Fly Hash ===")
    random_hasher = RandomFlyHash(384, 16384, 50).to(DEVICE)
    evaluate_recall(random_hasher, model, "Random")
    for n_bg in (0, 100, 500):
        evaluate_at_scale(random_hasher, model, n_bg, "Random")
    # Learned projection at two positive-pair noise levels
    print("\n=== Training BioHash ===")
    for noise_std in (0.2, 0.5):
        print(f"\n--- noise_std={noise_std} ---")
        hasher = train_biohash(model, code_dim=16384, k=50,
                               epochs=200, noise_std=noise_std, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(noise={noise_std})")
        for n_bg in (0, 100, 500):
            evaluate_at_scale(hasher, model, n_bg, f"BioHash(noise={noise_std})")
    # Sparsity sweep over the number of active Kenyon cells
    print("\n=== BioHash: k sweep ===")
    for k in (20, 50, 100, 200):
        hasher = train_biohash(model, code_dim=16384, k=k,
                               epochs=200, noise_std=0.3, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(k={k})")
        evaluate_at_scale(hasher, model, 500, f"BioHash(k={k})")


if __name__ == "__main__":
    main()