# Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments
# (16 iterations), validated on LongMemEval (ICLR 2025).
# Architecture:
#   - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
#   - Multi-hop: Hebbian W matrix with WTA pattern separation
#   - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
#   - 4ms latency @ 20K memories, ~1GB VRAM
# Key findings:
#   - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
#   - WTA pattern separation enables 20K+ capacity
#   - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
#   - MiniLM-L6 is optimal (discrimination gap > absolute similarity)
#   - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
#   - SNN encoder viable (CosSim 0.99) but not needed for current architecture
"""Experiment P1: Better embedding models.
|
|
|
|
MiniLM (22M) has weak paraphrase similarity for many pairs.
|
|
Test: BGE-small (33M), BGE-base (109M), and E5-small (33M).
|
|
Skip large models (330M+) due to VRAM budget with Hebbian W.
|
|
|
|
Measure:
|
|
1. Paraphrase pair cosine similarity (gap between same/diff pairs)
|
|
2. Recall accuracy with Hopfield at 2K background
|
|
3. Encoding speed
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
# Prefer the GPU, but fall back to CPU so the experiment still runs
# (more slowly) on machines without CUDA instead of crashing at encode time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
# Test pairs (same as exp07e)
# Each tuple is (conversational cue, factual target memory): the memory
# system stores the pair and must recall the target when probed with the
# cue or a paraphrase of it.
PAIRS = [
    ("What's the weather like today?", "User checks weather every morning"),
    ("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
    ("The database is slow again", "Missing index on users table"),
    ("I need to fix the authentication bug", "JWT tokens with 24h expiry in Redis"),
    ("The API returns 500 errors", "OOM in the Python worker"),
    ("Let's set up monitoring", "Prometheus + Grafana on OCI"),
    ("Tests failing in CI", "CI needs postgres service container"),
    ("Memory usage too high", "Leak in websocket handler"),
    ("Help with Docker setup", "docker-compose for dev, k3s for prod"),
    ("Log files too large", "Logs rotate daily, shipped to Loki"),
    ("How to add caching?", "Redis available at redis.internal:6379"),
    ("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
    ("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
    ("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
    ("Optimize search", "Elasticsearch v8, recently upgraded"),
    ("Backup the database", "Daily 3am UTC cron to S3"),
    ("Configure reverse proxy", "Traefik, not nginx"),
    ("Team meeting schedule", "Standup 10am London, Mon-Fri"),
    ("Learn a new programming language", "User has Python+Go, new to systems"),
    ("Review my pull request", "User prefers small PRs with clear commits"),
]

# One paraphrase per PAIRS entry, in the same order: PARAPHRASES[i]
# rephrases PAIRS[i][0]. Used as recall probes, so the index alignment
# with PAIRS is load-bearing for the accuracy scoring below.
PARAPHRASES = [
    "How's the weather?", "Ship the release", "DB performance terrible",
    "Fix the login issue", "Server errors everywhere", "Need observability",
    "CI tests breaking", "Service using too much RAM", "Docker config help",
    "Logs eating disk space", "Want to add a cache layer", "Website too slow",
    "Payment code needs rework", "Provision a new machine", "Search is slow",
    "Need a DB backup", "Proxy configuration", "When's the standup?",
    "Want to learn Rust", "Check my pull request",
]
|
|
|
|
|
|
def winner_take_all(x, k):
    """Return a binary mask keeping only the k largest entries of x.

    Operates along the last dimension: winning positions get 1.0,
    everything else 0.0 (WTA pattern separation).
    """
    mask = torch.zeros_like(x)
    winners = x.topk(k, dim=-1).indices
    return mask.scatter_(-1, winners, 1.0)
|
|
|
|
|
|
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, as a Python float."""
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
|
|
|
|
|
|
class TwoStageHopfield:
    """Associative memory: coarse nearest-neighbour shortlist, then
    modern-Hopfield settling over the shortlist.

    Stage 1 picks the top-k stored cues by dot-product similarity with the
    query; stage 2 runs a few softmax-attention update steps on the cue
    manifold before reading out the associated targets.
    """

    def __init__(self, embed_dim, beta=16.0, top_k=20):
        self.beta = beta        # inverse temperature of the softmax readout
        self.top_k = top_k      # shortlist size for stage 1
        self.cue_embs = []
        self.target_embs = []

    def learn(self, cue_emb, target_emb):
        """Store one (cue, target) embedding pair."""
        self.cue_embs.append(cue_emb.detach())
        self.target_embs.append(target_emb.detach())

    def recall(self, query_emb, steps=3):
        """Recall the (unit-norm) target associated with query_emb."""
        cues = torch.stack(self.cue_embs)
        targets = torch.stack(self.target_embs)

        # Stage 1: shortlist the k most similar stored cues.
        k = min(self.top_k, len(self.cue_embs))
        shortlist = (query_emb @ cues.T).topk(k).indices
        cand_cues = cues[shortlist]
        cand_targets = targets[shortlist]

        # Stage 2: settle the state toward the nearest stored cue pattern.
        state = query_emb
        for _ in range(steps):
            attn = torch.softmax(self.beta * (state @ cand_cues.T), dim=0)
            state = nn.functional.normalize(attn @ cand_cues, dim=0)

        # Read out targets using the attention of the settled state.
        attn = torch.softmax(self.beta * (state @ cand_cues.T), dim=0)
        return nn.functional.normalize(attn @ cand_targets, dim=0)
|
|
|
|
|
|
def evaluate_model(model_name):
    """Full evaluation of one embedding model.

    Loads `model_name` with sentence-transformers on DEVICE, then measures:
      1. paraphrase-pair cosine similarity: mean same-pair vs mean
         cross-pair similarity, and their gap (the discrimination metric);
      2. encoding speed on a 100-sentence batch;
      3. recall accuracy of TwoStageHopfield over the 20 test pairs with
         2000 synthetic background memories as distractors.

    Returns a dict of summary metrics for the comparison table in main().
    NOTE(review): the VRAM readout assumes a CUDA device — on CPU it
    reports 0.
    """
    # Imported lazily so the module can be imported without
    # sentence-transformers installed.
    from sentence_transformers import SentenceTransformer

    print(f"\n--- {model_name} ---")
    t0 = time.time()
    model = SentenceTransformer(model_name, device=DEVICE)
    load_time = time.time() - t0
    embed_dim = model.get_sentence_embedding_dimension()
    print(f" Dim: {embed_dim}, Load: {load_time:.1f}s")

    # 1. Paraphrase similarity gap
    cue_texts = [p[0] for p in PAIRS]
    cue_embs = model.encode(cue_texts, convert_to_tensor=True,
                            normalize_embeddings=True, device=DEVICE)
    para_embs = model.encode(PARAPHRASES, convert_to_tensor=True,
                             normalize_embeddings=True, device=DEVICE)
    target_embs = model.encode([p[1] for p in PAIRS], convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE)

    # same_sims[i]: cue i vs its own paraphrase (relies on PAIRS and
    # PARAPHRASES being index-aligned); diff_sims: cue i vs every other
    # paraphrase — the distractor distribution.
    same_sims = [cosine(cue_embs[i], para_embs[i]) for i in range(len(PAIRS))]
    diff_sims = []
    for i in range(len(PAIRS)):
        for j in range(len(PAIRS)):
            if i != j:
                diff_sims.append(cosine(cue_embs[i], para_embs[j]))

    mean_same = np.mean(same_sims)
    mean_diff = np.mean(diff_sims)
    min_same = np.min(same_sims)
    gap = mean_same - mean_diff  # discrimination gap — the headline metric

    print(f" Similarity: same={mean_same:.3f} (min={min_same:.3f}), "
          f"diff={mean_diff:.3f}, gap={gap:.3f}")

    # Show worst pairs (the three lowest same-pair similarities)
    worst_idx = np.argsort(same_sims)[:3]
    for idx in worst_idx:
        print(f" Worst: {same_sims[idx]:.3f} '{cue_texts[idx][:30]}...' ↔ '{PARAPHRASES[idx][:30]}...'")

    # 2. Encoding speed
    texts_100 = [f"Test sentence number {i} about various topics" for i in range(100)]
    t0 = time.time()
    model.encode(texts_100, convert_to_tensor=True, device=DEVICE)
    speed = 100 / (time.time() - t0)
    print(f" Speed: {speed:.0f} sentences/s")

    # 3. Recall with 2K background
    mem = TwoStageHopfield(embed_dim, beta=16.0, top_k=20)
    for i in range(len(PAIRS)):
        mem.learn(cue_embs[i], target_embs[i])

    # Background distractor memories: synthetic infra-flavored cue/target
    # pairs, intentionally close in style to the real cues.
    bg_cues = [f"The {['server','db','api','fe','be','cache'][i%6]} has issue {i}"
               for i in range(2000)]
    bg_targets = [f"Fix issue {i}" for i in range(2000)]
    bg_cue_embs = model.encode(bg_cues, convert_to_tensor=True,
                               normalize_embeddings=True, device=DEVICE, batch_size=256)
    bg_target_embs = model.encode(bg_targets, convert_to_tensor=True,
                                  normalize_embeddings=True, device=DEVICE, batch_size=256)
    for i in range(2000):
        mem.learn(bg_cue_embs[i], bg_target_embs[i])

    # Score: a probe counts as correct when the recalled vector is closest
    # to the target of its own pair among the 20 real targets.
    correct = 0
    for i in range(len(PARAPHRASES)):
        recalled = mem.recall(para_embs[i])
        all_sims = [cosine(recalled, target_embs[j]) for j in range(len(PAIRS))]
        if np.argmax(all_sims) == i:
            correct += 1

    n = len(PARAPHRASES)
    print(f" Recall (20 pairs + 2K bg): {correct}/{n} ({correct/n:.0%})")

    # VRAM currently allocated by tensors (MB)
    vram = torch.cuda.memory_allocated() / 1024**2
    print(f" VRAM: {vram:.0f} MB")

    # Free the model before the next candidate is loaded.
    del model, mem
    torch.cuda.empty_cache()

    return {
        "model": model_name, "dim": embed_dim,
        "same_sim": mean_same, "diff_sim": mean_diff, "gap": gap,
        "min_same": min_same, "speed": speed,
        "recall": correct / n, "vram_mb": vram,
    }
|
|
|
|
|
|
def main():
    """Evaluate each candidate embedding model and print a summary table."""
    print("=" * 60)
    print("Experiment P1: Embedding Model Comparison")
    print("=" * 60)

    candidates = [
        "all-MiniLM-L6-v2",        # Baseline, 22M, dim=384
        "BAAI/bge-small-en-v1.5",  # 33M, dim=384
        "BAAI/bge-base-en-v1.5",   # 109M, dim=768
        "intfloat/e5-small-v2",    # 33M, dim=384
    ]

    results = []
    for name in candidates:
        # One failing model (download error, OOM) should not abort the rest.
        try:
            results.append(evaluate_model(name))
        except Exception as e:
            print(f" ERROR: {e}")

    # Summary table
    print("\n" + "=" * 80)
    print("SUMMARY")
    print(f"{'Model':<30} {'Dim':>4} {'SameSim':>8} {'Gap':>6} "
          f"{'MinSim':>7} {'Recall':>7} {'Speed':>6} {'VRAM':>6}")
    print("-" * 80)
    for row in results:
        print(f"{row['model']:<30} {row['dim']:>4} {row['same_sim']:>8.3f} "
              f"{row['gap']:>6.3f} {row['min_same']:>7.3f} "
              f"{row['recall']:>6.0%} {row['speed']:>5.0f}/s {row['vram_mb']:>5.0f}MB")


if __name__ == "__main__":
    main()
|