Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
382 lines · 14 KiB · Python
"""Experiment 6: BioHash — Learnable Fly Algorithm.
|
|
|
|
Replace random projection with learned projection trained via contrastive loss
|
|
on real sentence embeddings. The key insight from Dasgupta 2017 (Science):
|
|
random projection + WTA already preserves neighborhoods. Learning the projection
|
|
should make it even better.
|
|
|
|
Training objective:
|
|
- Positive pairs (similar sentences): maximize Jaccard overlap of sparse codes
|
|
- Negative pairs (different sentences): minimize overlap
|
|
|
|
Since WTA is not differentiable, we use a soft relaxation during training
|
|
(Gumbel-softmax or straight-through estimator) and hard WTA at test time.
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def winner_take_all(x, k):
    """Binarize x along the last dim, keeping only its k largest entries.

    Returns a tensor shaped like x with 1.0 at the winning positions and
    0.0 everywhere else.
    """
    winners = x.topk(k, dim=-1).indices
    return torch.zeros_like(x).scatter(-1, winners, 1.0)
|
|
|
|
|
|
def jaccard(a, b):
    """Mean Jaccard similarity between corresponding rows of two binary tensors."""
    both = torch.sum(a * b, dim=-1)
    either = torch.sum((a + b).gt(0).float(), dim=-1)
    # clamp avoids 0/0 when both rows are all-zero (yields similarity 0).
    per_row = both / either.clamp(min=1)
    return per_row.mean().item()
|
|
|
|
|
|
def soft_topk(x, k, temperature=1.0):
    """Differentiable WTA via a straight-through estimator.

    The forward value is exactly the hard top-k binary code; gradients flow
    through a temperature-scaled softmax (multiplied by k so its total mass
    matches the k active units of the hard code).
    """
    relaxed = k * torch.softmax(x / temperature, dim=-1)
    hard = winner_take_all(x, k)
    # Value == hard (the relaxed terms cancel numerically); grad == d(relaxed)/dx.
    return hard + relaxed - relaxed.detach()
|
|
|
|
|
|
class BioHash(nn.Module):
    """Learnable Fly Hash with WTA sparsification.

    Mirrors the fruit-fly olfactory circuit:
    - Projection neurons (PN): input -> high-dim expansion (learned here,
      replacing the classic random matrix)
    - Kenyon cells (KC): top-k winner-take-all -> sparse binary code
    """

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        self.code_dim = code_dim
        # Learned expansion layer standing in for the fly's random projection.
        self.proj = nn.Linear(input_dim, code_dim, bias=False)
        # Initialize at the same scale as a random fly projection.
        nn.init.normal_(self.proj.weight, std=1.0 / input_dim**0.5)

    def forward(self, x, soft=False, temperature=1.0):
        """Encode [batch, input_dim] normalized embeddings.

        Returns [batch, code_dim] sparse binary codes; with soft=True the
        straight-through relaxation is used so gradients can flow (training).
        """
        expanded = self.proj(x)  # [batch, code_dim]
        if not soft:
            return winner_take_all(expanded, self.k)
        return soft_topk(expanded, self.k, temperature)

    def encode_hard(self, x):
        """Inference-time encoding: hard WTA, no autograd graph."""
        with torch.no_grad():
            return winner_take_all(self.proj(x), self.k)
|
|
|
|
|
|
class RandomFlyHash(nn.Module):
    """Baseline: the original random (untrained) Fly projection."""

    def __init__(self, input_dim=384, code_dim=16384, k=50):
        super().__init__()
        self.k = k
        # Fixed Gaussian projection; a buffer so it follows .to(device)
        # but never receives gradients.
        scale = 1.0 / input_dim**0.5
        self.register_buffer('proj', torch.randn(input_dim, code_dim) * scale)

    def encode_hard(self, x):
        """Project and binarize with hard top-k WTA (no autograd graph)."""
        with torch.no_grad():
            return winner_take_all(x @ self.proj, self.k)
|
|
|
|
|
|
def generate_training_data(model, n_pairs=5000, noise_std=0.3):
    """Encode a pool of synthetic ops-flavored sentences for training.

    The caller draws positive pairs (same sentence + noise, simulating a
    paraphrase) and negative pairs (different sentences) from the returned
    embeddings.

    NOTE(review): noise_std is accepted but never used here — the noise is
    injected later in train_biohash; confirm whether this parameter should
    be removed or wired in.
    """
    # Diverse training sentences built from template x subject x modifier.
    templates = [
        "The {} is having {} issues",
        "We need to {} the {} system",
        "The {} team is working on {}",
        "There's a bug in the {} {}",
        "Let's deploy {} to {}",
        "The {} performance is {}",
        "How do I configure {}?",
        "The {} logs show {}",
        "We should monitor the {} {}",
        "The {} needs {} upgrade",
    ]
    subjects = ["database", "API", "server", "frontend", "backend",
                "auth", "cache", "queue", "storage", "network",
                "deployment", "monitoring", "logging", "testing", "CI/CD"]
    modifiers = ["critical", "minor", "performance", "security", "timeout",
                 "memory", "disk", "CPU", "latency", "throughput"]

    sentences = [t.format(s, m)
                 for t in templates
                 for s in subjects
                 for m in modifiers]

    np.random.shuffle(sentences)
    # Cap the pool; twice n_pairs is enough to form pairs.
    sentences = sentences[:n_pairs * 2]

    # Encode once, normalized, on the shared device.
    return model.encode(sentences, convert_to_tensor=True,
                        normalize_embeddings=True, device=DEVICE,
                        batch_size=256)
|
|
|
|
|
|
def train_biohash(model, code_dim=16384, k=50, epochs=100, batch_size=256,
                  lr=1e-3, noise_std=0.3, margin=0.2):
    """Train BioHash with contrastive loss on sentence embeddings.

    Args:
        model: SentenceTransformer used to produce the training embeddings.
        code_dim: width of the sparse code (Kenyon-cell layer).
        k: number of active units per code (WTA sparsity).
        epochs: number of optimization steps (one sampled batch per epoch).
        batch_size: contrastive batch size.
        lr: Adam learning rate, cosine-annealed over `epochs`.
        noise_std: std of Gaussian noise added to anchors to form positives.
        margin: negative overlaps at or below this margin incur no loss.

    Returns:
        The trained BioHash module (on DEVICE).
    """
    embed_dim = model.get_sentence_embedding_dimension()
    hasher = BioHash(embed_dim, code_dim, k).to(DEVICE)
    optimizer = optim.Adam(hasher.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    print(f"Training BioHash: code={code_dim}, k={k}, noise={noise_std}")

    # Generate training embeddings
    embs = generate_training_data(model, n_pairs=5000)

    for epoch in range(epochs):
        # Sample batch
        idx = torch.randperm(embs.shape[0])[:batch_size]
        anchor = embs[idx]

        # Positive: add noise (simulate paraphrase), re-normalized to unit length
        pos = nn.functional.normalize(
            anchor + torch.randn_like(anchor) * noise_std, dim=-1)

        # Negative: random different embeddings.
        # NOTE(review): sampled independently of `idx`, so a negative can
        # occasionally coincide with its anchor — rare, treated as noise.
        neg_idx = torch.randperm(embs.shape[0])[:batch_size]
        neg = embs[neg_idx]

        # Forward with STE (hard WTA value, soft gradients)
        code_anchor = hasher(anchor, soft=True, temperature=0.5)
        code_pos = hasher(pos, soft=True, temperature=0.5)
        code_neg = hasher(neg, soft=True, temperature=0.5)

        # Jaccard-like loss (differentiable via STE); dividing by k puts
        # both overlaps in [0, 1].
        # Positive overlap: maximize
        pos_overlap = (code_anchor * code_pos).sum(dim=-1) / k
        # Negative overlap: minimize (with margin)
        neg_overlap = (code_anchor * code_neg).sum(dim=-1) / k

        loss = -pos_overlap.mean() + torch.relu(neg_overlap - margin).mean()

        optimizer.zero_grad()
        loss.backward()
        # Clip to stabilize the STE gradients.
        nn.utils.clip_grad_norm_(hasher.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        if (epoch + 1) % 20 == 0:
            # Eval with hard WTA (the inference-time code path)
            with torch.no_grad():
                h_anchor = hasher.encode_hard(anchor)
                h_pos = hasher.encode_hard(pos)
                h_neg = hasher.encode_hard(neg)
                j_pos = jaccard(h_anchor, h_pos)
                j_neg = jaccard(h_anchor, h_neg)
            print(f" Epoch {epoch+1}: loss={loss.item():.4f}, "
                  f"Jaccard_pos={j_pos:.4f}, Jaccard_neg={j_neg:.4f}, "
                  f"gap={j_pos-j_neg:.4f}")

    return hasher
|
|
|
|
|
|
def evaluate_recall(hasher, model, label=""):
    """Test associative recall with this hasher.

    Stores ten cue->target pairs in a Hebbian matrix W (sum of outer products
    of their sparse codes), then measures:
    - exact recall: querying W with the original cue codes
    - paraphrase recall: querying with reworded cues
    - code overlap: normalized overlap of matched vs mismatched cue/paraphrase codes

    Returns (exact_accuracy, paraphrase_accuracy, mean_positive_overlap).

    Improvement over the original: the cue embeddings are encoded exactly once
    (previously encode_hard ran twice on cue_embs[:1] just to read code_dim/k,
    then a third time on the full batch).
    """
    # Memory pairs: (cue, target)
    pairs = [
        ("What's the weather like today?", "User prefers to check weather every morning"),
        ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
        ("The database is slow again", "Missing index on users table caused slowdown"),
        ("I need to fix the auth bug", "Auth uses JWT tokens with 24h expiry in Redis"),
        ("The API returns 500 errors", "Last 500 was OOM in the Python worker"),
        ("Let's set up monitoring", "Prometheus + Grafana on OCI cluster"),
        ("The tests are failing", "CI needs postgres service container"),
        ("Memory usage is too high", "Known leak in websocket handler"),
        ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
        ("Log files are too large", "Logs rotate daily, 30 days retention, shipped to Loki"),
    ]
    # paraphrases[i] rewords pairs[i][0]; recall is scored by index match.
    paraphrases = [
        "How's the weather outside?",
        "We should push the new release",
        "DB performance is terrible",
        "There's a login bug to fix",
        "Getting internal server errors",
        "We need better observability",
        "CI tests keep breaking",
        "The service is using too much RAM",
        "Help me with Docker configuration",
        "Logs are eating up disk space",
    ]

    cue_embs = model.encode([p[0] for p in pairs], convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in pairs], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(paraphrases, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)

    # Encode once and derive the code geometry from the codes themselves,
    # so any hasher exposing encode_hard() works.
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)
    para_codes = hasher.encode_hard(para_embs)
    code_dim = cue_codes.shape[-1]
    k = int(cue_codes[0].sum().item())

    # Build Hebbian memory: W accumulates target (x) cue outer products.
    W = torch.zeros(code_dim, code_dim, device=DEVICE)
    for i in range(len(pairs)):
        W += torch.outer(target_codes[i], cue_codes[i])

    def _count_correct(query_codes):
        # One-shot recall: project the query through W, re-sparsify with WTA,
        # and count it correct when the nearest stored target is the matching one.
        correct = 0
        for i in range(query_codes.shape[0]):
            recalled = winner_take_all(W @ query_codes[i], k)
            sims = nn.functional.cosine_similarity(
                recalled.unsqueeze(0), target_codes, dim=-1)
            if sims.argmax().item() == i:
                correct += 1
        return correct

    # Test exact recall
    exact_correct = _count_correct(cue_codes)
    # Test paraphrase recall
    para_correct = _count_correct(para_codes)

    # Code overlap analysis: matched cue/paraphrase vs a mismatched one.
    pos_overlaps = []
    neg_overlaps = []
    for i in range(len(pairs)):
        # Positive: cue vs its own paraphrase
        pos_overlaps.append((cue_codes[i] * para_codes[i]).sum().item() / k)
        # Negative: cue vs the next pair's paraphrase (deterministic "other")
        j = (i + 1) % len(pairs)
        neg_overlaps.append((cue_codes[i] * para_codes[j]).sum().item() / k)

    n = len(pairs)
    print(f" {label}: Exact={exact_correct}/{n}, Para={para_correct}/{n}, "
          f"CodeOverlap: pos={np.mean(pos_overlaps):.3f}, "
          f"neg={np.mean(neg_overlaps):.3f}, "
          f"gap={np.mean(pos_overlaps)-np.mean(neg_overlaps):.3f}")

    return exact_correct / n, para_correct / n, np.mean(pos_overlaps)
|
|
|
|
|
|
def evaluate_at_scale(hasher, model, n_background, label=""):
    """Paraphrase recall with distractor memories mixed in (the real challenge)."""
    pairs = [
        ("The database is slow", "Check missing indexes on users table"),
        ("Deploy to production", "Use blue-green via GitHub Actions"),
        ("Server crashed", "Check logs, likely OOM in Python worker"),
        ("Fix the auth bug", "JWT tokens with 24h expiry in Redis"),
        ("API returns 500", "OOM in Python worker process"),
    ]
    paraphrases = [
        "DB performance terrible",
        "Push the new release",
        "Server is down",
        "Login bug needs fixing",
        "Getting 500 errors from API",
    ]

    # Synthetic distractors stored in the same Hebbian matrix as the real pairs.
    bg_sentences = [f"Background task {i} about topic {i%20}" for i in range(n_background)]
    bg_targets = [f"Background detail {i} with info {i%10}" for i in range(n_background)]

    all_cues = [cue for cue, _ in pairs] + bg_sentences
    all_targets = [tgt for _, tgt in pairs] + bg_targets

    def _embed(texts, **kwargs):
        return model.encode(texts, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE, **kwargs)

    cue_embs = _embed(all_cues, batch_size=256)
    target_embs = _embed(all_targets, batch_size=256)
    para_embs = _embed(paraphrases)

    # Build the memory: one outer product per stored association.
    cue_codes = hasher.encode_hard(cue_embs)
    target_codes = hasher.encode_hard(target_embs)

    code_dim = cue_codes.shape[-1]
    k = int(cue_codes[0].sum().item())
    W = torch.zeros(code_dim, code_dim, device=DEVICE)
    for cue_code, target_code in zip(cue_codes, target_codes):
        W += torch.outer(target_code, cue_code)

    # Score paraphrase recall against the real targets only.
    para_codes = hasher.encode_hard(para_embs)
    real_targets = target_codes[:len(pairs)]
    correct = 0
    for i, query in enumerate(para_codes):
        recalled = winner_take_all(W @ query, k)
        sims = nn.functional.cosine_similarity(
            recalled.unsqueeze(0), real_targets, dim=-1)
        correct += int(sims.argmax().item() == i)

    n = len(paraphrases)
    print(f" {label} (bg={n_background}): Para={correct}/{n} ({correct/n:.0%})")
    return correct / n
|
|
|
|
|
|
def main():
    """Compare random Fly hash against trained BioHash across noise and k sweeps."""
    banner = "=" * 60
    print(banner)
    print("Experiment 6: BioHash — Learnable Fly Algorithm")
    print(banner)

    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)

    # Baseline: the untrained random projection (current approach)
    print("\n=== Baseline: Random Fly Hash ===")
    random_hasher = RandomFlyHash(384, 16384, 50).to(DEVICE)
    evaluate_recall(random_hasher, model, "Random")
    for n_bg in (0, 100, 500):
        evaluate_at_scale(random_hasher, model, n_bg, "Random")

    # Trained BioHash at two positive-pair noise levels
    print("\n=== Training BioHash ===")
    for noise_std in (0.2, 0.5):
        print(f"\n--- noise_std={noise_std} ---")
        trained = train_biohash(model, code_dim=16384, k=50,
                                epochs=200, noise_std=noise_std, lr=1e-3)
        evaluate_recall(trained, model, f"BioHash(noise={noise_std})")
        for n_bg in (0, 100, 500):
            evaluate_at_scale(trained, model, n_bg, f"BioHash(noise={noise_std})")

    # Sparsity sweep: how does the number of active units affect recall?
    print("\n=== BioHash: k sweep ===")
    for k in (20, 50, 100, 200):
        trained = train_biohash(model, code_dim=16384, k=k,
                                epochs=200, noise_std=0.3, lr=1e-3)
        evaluate_recall(trained, model, f"BioHash(k={k})")
        evaluate_at_scale(trained, model, 500, f"BioHash(k={k})")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|