NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
381
experiments/exp06_biohash.py
Normal file
381
experiments/exp06_biohash.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Experiment 6: BioHash — Learnable Fly Algorithm.
|
||||
|
||||
Replace random projection with learned projection trained via contrastive loss
|
||||
on real sentence embeddings. The key insight from Dasgupta 2017 (Science):
|
||||
random projection + WTA already preserves neighborhoods. Learning the projection
|
||||
should make it even better.
|
||||
|
||||
Training objective:
|
||||
- Positive pairs (similar sentences): maximize Jaccard overlap of sparse codes
|
||||
- Negative pairs (different sentences): minimize overlap
|
||||
|
||||
Since WTA is not differentiable, we use a soft relaxation during training
|
||||
(Gumbel-softmax or straight-through estimator) and hard WTA at test time.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
|
||||
DEVICE = "cuda"
|
||||
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
||||
|
||||
|
||||
def winner_take_all(x, k):
    """Binarize `x` by keeping only the top-k activations per row.

    Returns a tensor of the same shape containing 1.0 at the k largest
    positions along the last dimension and 0.0 everywhere else.
    """
    topk_idx = x.topk(k, dim=-1).indices
    code = torch.zeros_like(x)
    return code.scatter_(-1, topk_idx, 1.0)
|
||||
|
||||
|
||||
def jaccard(a, b):
    """Mean Jaccard similarity (|a AND b| / |a OR b|) of binary vectors.

    Operates along the last dimension and averages over the batch;
    returns a Python float. Empty unions are clamped to 1 so all-zero
    inputs score 0 instead of NaN.
    """
    both = torch.sum(a * b, dim=-1)
    either = torch.sum((a + b).gt(0).float(), dim=-1)
    ratio = both / either.clamp(min=1)
    return ratio.mean().item()
|
||||
|
||||
|
||||
def soft_topk(x, k, temperature=1.0):
    """Differentiable top-k via straight-through estimation.

    The returned value equals the hard winner-take-all code; the gradient
    flows through a k-scaled softmax relaxation instead.
    """
    # Hard forward pass: binary top-k mask over the last dimension.
    idx = x.topk(k, dim=-1).indices
    hard = torch.zeros_like(x).scatter_(-1, idx, 1.0)
    # Soft surrogate, used only for the backward pass.
    relaxed = k * torch.softmax(x / temperature, dim=-1)
    # STE: value == hard, gradient == d(relaxed)/dx.
    return hard + (relaxed - relaxed.detach())
|
||||
|
||||
|
||||
class BioHash(nn.Module):
    """Learnable Fly Hash: dense embedding -> sparse k-hot binary code.

    Architecture mirrors the fruit fly olfactory circuit:
    - Projection neurons (PN): learned linear map to a high-dim space
      (replaces the classic random projection)
    - Kenyon cells (KC): top-k winner-take-all -> sparse binary code
    """

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        self.code_dim = code_dim
        # Learned projection standing in for the random PN->KC wiring.
        self.proj = nn.Linear(input_dim, code_dim, bias=False)
        # Start at the same scale a random fly projection would use.
        nn.init.normal_(self.proj.weight, std=1.0 / input_dim**0.5)

    def _hard_wta(self, h):
        """Binary top-k mask over the last dimension of `h`."""
        idx = h.topk(self.k, dim=-1).indices
        return torch.zeros_like(h).scatter_(-1, idx, 1.0)

    def forward(self, x, soft=False, temperature=1.0):
        """Encode [batch, input_dim] embeddings as [batch, code_dim] codes.

        With soft=True a straight-through estimator is used: the returned
        value is the hard WTA code while the gradient flows through a
        k-scaled softmax of the projection.
        """
        h = self.proj(x)
        hard = self._hard_wta(h)
        if not soft:
            return hard
        relaxed = self.k * torch.softmax(h / temperature, dim=-1)
        return hard + (relaxed - relaxed.detach())

    def encode_hard(self, x):
        """Inference-path encoding: hard WTA, no gradient tracking."""
        with torch.no_grad():
            return self._hard_wta(self.proj(x))
|
||||
|
||||
|
||||
class RandomFlyHash(nn.Module):
    """Baseline fly hash: fixed random projection + WTA (no learning)."""

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        # Fixed random wiring; a buffer so .to(device) moves it but the
        # optimizer never sees it.
        weights = torch.randn(input_dim, code_dim) * (1.0 / input_dim**0.5)
        self.register_buffer('proj', weights)

    def encode_hard(self, x):
        """Project and binarize via top-k winner-take-all (no grad)."""
        with torch.no_grad():
            h = x @ self.proj
            idx = h.topk(self.k, dim=-1).indices
            return torch.zeros_like(h).scatter_(-1, idx, 1.0)
|
||||
|
||||
|
||||
def generate_training_data(model, n_pairs=5000, noise_std=0.3):
    """Encode a pool of synthetic sentences for contrastive training.

    Builds template x subject x modifier sentences, shuffles them, and
    returns their normalized embeddings as a single tensor. Positive
    pairs are produced later by the trainer (anchor + Gaussian noise);
    negatives by random pairing.

    NOTE(review): `noise_std` is accepted but unused here — the noise is
    applied in `train_biohash`. Confirm before removing.
    """
    templates = [
        "The {} is having {} issues",
        "We need to {} the {} system",
        "The {} team is working on {}",
        "There's a bug in the {} {}",
        "Let's deploy {} to {}",
        "The {} performance is {}",
        "How do I configure {}?",
        "The {} logs show {}",
        "We should monitor the {} {}",
        "The {} needs {} upgrade",
    ]
    subjects = ["database", "API", "server", "frontend", "backend",
                "auth", "cache", "queue", "storage", "network",
                "deployment", "monitoring", "logging", "testing", "CI/CD"]
    modifiers = ["critical", "minor", "performance", "security", "timeout",
                 "memory", "disk", "CPU", "latency", "throughput"]

    # Every template/subject/modifier combination (str.format silently
    # drops the extra argument for single-slot templates).
    sentences = [t.format(s, m)
                 for t in templates
                 for s in subjects
                 for m in modifiers]

    np.random.shuffle(sentences)
    sentences = sentences[:n_pairs * 2]  # enough for pairs

    # Encode everything in one batched pass on the training device.
    return model.encode(sentences, convert_to_tensor=True,
                        normalize_embeddings=True, device=DEVICE,
                        batch_size=256)
|
||||
|
||||
|
||||
def train_biohash(model, code_dim=16384, k=50, epochs=100, batch_size=256,
                  lr=1e-3, noise_std=0.3, margin=0.2):
    """Train BioHash with a contrastive overlap loss on sentence embeddings.

    Each epoch draws one batch: anchors from the embedding pool, positives
    as re-normalized noisy anchors (paraphrase stand-ins), and negatives as
    an independent random draw. The loss maximizes anchor/positive code
    overlap and penalizes anchor/negative overlap above `margin`.
    Returns the trained BioHash module.
    """
    embed_dim = model.get_sentence_embedding_dimension()
    hasher = BioHash(embed_dim, code_dim, k).to(DEVICE)
    optimizer = optim.Adam(hasher.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    print(f"Training BioHash: code={code_dim}, k={k}, noise={noise_std}")

    # Fixed pool of training embeddings; batches are re-sampled each epoch.
    embs = generate_training_data(model, n_pairs=5000)

    for epoch in range(epochs):
        # Anchor batch.
        anchor = embs[torch.randperm(embs.shape[0])[:batch_size]]

        # Positive: noisy copy of each anchor, re-normalized to the sphere.
        noisy = anchor + torch.randn_like(anchor) * noise_std
        pos = nn.functional.normalize(noisy, dim=-1)

        # Negative: an independent random batch.
        neg = embs[torch.randperm(embs.shape[0])[:batch_size]]

        # Soft (straight-through) codes so gradients reach the projection.
        code_anchor = hasher(anchor, soft=True, temperature=0.5)
        code_pos = hasher(pos, soft=True, temperature=0.5)
        code_neg = hasher(neg, soft=True, temperature=0.5)

        # Jaccard-like overlaps, normalized by k. Maximize the positive
        # term; hinge the negative term at `margin`.
        pos_overlap = (code_anchor * code_pos).sum(dim=-1) / k
        neg_overlap = (code_anchor * code_neg).sum(dim=-1) / k
        loss = torch.relu(neg_overlap - margin).mean() - pos_overlap.mean()

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(hasher.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if (epoch + 1) % 20 == 0:
            # Periodic check with the hard (inference-time) encoder.
            with torch.no_grad():
                h_anchor = hasher.encode_hard(anchor)
                h_pos = hasher.encode_hard(pos)
                h_neg = hasher.encode_hard(neg)
                j_pos = jaccard(h_anchor, h_pos)
                j_neg = jaccard(h_anchor, h_neg)
            print(f" Epoch {epoch+1}: loss={loss.item():.4f}, "
                  f"Jaccard_pos={j_pos:.4f}, Jaccard_neg={j_neg:.4f}, "
                  f"gap={j_pos-j_neg:.4f}")

    return hasher
|
||||
|
||||
|
||||
def evaluate_recall(hasher, model, label=""):
    """Test associative (Hebbian) recall with this hasher.

    Stores ten cue->target pairs in an outer-product weight matrix over
    the hasher's sparse codes, then measures:
    - exact recall (query with the original cue),
    - paraphrase recall (query with a reworded cue),
    - sparse-code overlap gap between matching and mismatched pairs.

    Returns (exact_accuracy, paraphrase_accuracy, mean_positive_overlap).
    """
    # Memory pairs
    pairs = [
        ("What's the weather like today?", "User prefers to check weather every morning"),
        ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table caused slowdown"),
        ("I need to fix the auth bug", "Auth uses JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "Last 500 was OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI cluster"),
        ("The tests are failing", "CI needs postgres service container"),
        ("Memory usage is too high", "Known leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files are too large", "Logs rotate daily, 30 days retention, shipped to Loki"),
    ]
    # Paraphrases[i] rewords pairs[i][0] and should recall pairs[i][1].
    paraphrases = [
        "How's the weather outside?",
        "We should push the new release",
        "DB performance is terrible",
        "There's a login bug to fix",
        "Getting internal server errors",
        "We need better observability",
        "CI tests keep breaking",
        "The service is using too much RAM",
        "Help me with Docker configuration",
        "Logs are eating up disk space",
    ]

    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    # Encode once and derive the code geometry from the codes themselves.
    # (Previously the first cue was hashed twice just to read code_dim/k,
    # and all cues were hashed again afterwards; this also matches how
    # evaluate_at_scale derives these values.)
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    para_codes = hasher.encode_hard(para_embs)
    k = int(cue_codes[0].sum().item())

    # Hebbian memory: W = sum_i outer(target_i, cue_i), as one matmul.
    W = target_codes.T @ cue_codes

    def _recall_correct(query_codes):
        """Count queries whose recalled code best matches their own target."""
        correct = 0
        for i in range(query_codes.shape[0]):
            recalled = winner_take_all(W @ query_codes[i], k)
            sims = nn.functional.cosine_similarity(
                recalled.unsqueeze(0), target_codes, dim=-1)
            if sims.argmax().item() == i:
                correct += 1
        return correct

    exact_correct = _recall_correct(cue_codes)     # exact-cue recall
    para_correct = _recall_correct(para_codes)     # paraphrase recall

    # Code overlap analysis: matching cue/paraphrase vs. a mismatched one.
    pos_overlaps = []
    neg_overlaps = []
    for i in range(len(pairs)):
        # Positive: cue vs its own paraphrase.
        pos_overlaps.append((cue_codes[i] * para_codes[i]).sum().item() / k)
        # Negative: cue vs the next pair's paraphrase.
        j = (i + 1) % len(pairs)
        neg_overlaps.append((cue_codes[i] * para_codes[j]).sum().item() / k)

    n = len(pairs)
    print(f" {label}: Exact={exact_correct}/{n}, Para={para_correct}/{n}, "
          f"CodeOverlap: pos={np.mean(pos_overlaps):.3f}, "
          f"neg={np.mean(neg_overlaps):.3f}, "
          f"gap={np.mean(pos_overlaps)-np.mean(neg_overlaps):.3f}")

    return exact_correct / n, para_correct / n, np.mean(pos_overlaps)
|
||||
|
||||
|
||||
def evaluate_at_scale(hasher, model, n_background, label=""):
    """Test paraphrase recall while background memories compete.

    Stores five real cue->target pairs plus `n_background` filler pairs
    in one Hebbian matrix, then checks each paraphrase still recalls its
    own target. Returns the paraphrase accuracy as a fraction.
    """
    pairs = [
        ("The database is slow", "Check missing indexes on users table"),
        ("Deploy to production", "Use blue-green via GitHub Actions"),
        ("Server crashed", "Check logs, likely OOM in Python worker"),
        ("Fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("API returns 500", "OOM in Python worker process"),
    ]
    # Paraphrases[i] should still retrieve pairs[i][1].
    paraphrases = [
        "DB performance terrible",
        "Push the new release",
        "Server is down",
        "Login bug needs fixing",
        "Getting 500 errors from API",
    ]

    # Synthetic distractor memories.
    bg_sentences = [f"Background task {i} about topic {i%20}" for i in range(n_background)]
    bg_targets = [f"Background detail {i} with info {i%10}" for i in range(n_background)]

    all_cues = [p[0] for p in pairs] + bg_sentences
    all_targets = [p[1] for p in pairs] + bg_targets

    cue_embs = model.encode(all_cues, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE, batch_size=256)
    target_embs = model.encode(all_targets, convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE, batch_size=256)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    # Hebbian store over real + background pairs.
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    code_dim = cue_codes.shape[-1]
    k = int(cue_codes[0].sum().item())
    W = torch.zeros(code_dim, code_dim, device=DEVICE)
    for cue, tgt in zip(cue_codes, target_codes):
        W += torch.outer(tgt, cue)

    # Paraphrase recall, scored only against the five real targets.
    para_codes = hasher.encode_hard(para_embs)
    correct = 0
    for i, code in enumerate(para_codes):
        recalled = winner_take_all(W @ code, k)
        sims = nn.functional.cosine_similarity(
            recalled.unsqueeze(0), target_codes[:len(pairs)], dim=-1)
        correct += int(sims.argmax().item() == i)

    n = len(paraphrases)
    print(f" {label} (bg={n_background}): Para={correct}/{n} ({correct/n:.0%})")
    return correct / n
|
||||
|
||||
|
||||
def main():
    """Run the full experiment: random baseline, noise sweep, k sweep."""
    print("=" * 60)
    print("Experiment 6: BioHash — Learnable Fly Algorithm")
    print("=" * 60)

    # Imported lazily so the module loads without sentence_transformers.
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

    # Baseline: untrained random projection (the current approach).
    print("\n=== Baseline: Random Fly Hash ===")
    random_hasher = RandomFlyHash(384, 16384, 50).to(DEVICE)
    evaluate_recall(random_hasher, model, "Random")
    for n_bg in [0, 100, 500]:
        evaluate_at_scale(random_hasher, model, n_bg, "Random")

    # Learned projection at two positive-pair noise levels.
    print("\n=== Training BioHash ===")
    for noise_std in [0.2, 0.5]:
        print(f"\n--- noise_std={noise_std} ---")
        hasher = train_biohash(model, code_dim=16384, k=50,
                               epochs=200, noise_std=noise_std, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(noise={noise_std})")
        for n_bg in [0, 100, 500]:
            evaluate_at_scale(hasher, model, n_bg, f"BioHash(noise={noise_std})")

    # Sparsity sweep: how does the WTA budget affect recall?
    print("\n=== BioHash: k sweep ===")
    for k in [20, 50, 100, 200]:
        hasher = train_biohash(model, code_dim=16384, k=k,
                               epochs=200, noise_std=0.3, lr=1e-3)
        evaluate_recall(hasher, model, f"BioHash(k={k})")
        evaluate_at_scale(hasher, model, 500, f"BioHash(k={k})")
|
||||
|
||||
|
||||
# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user