Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
221 lines
9.5 KiB
Python
221 lines
9.5 KiB
Python
"""Experiment P2: Auto Paraphrase Generation.
|
|
|
|
LLM gateway down, so test:
|
|
1. Heuristic paraphrase effect on recall (how much does crappy augmentation help?)
|
|
2. "Oracle" paraphrase (hand-crafted) vs heuristic vs none
|
|
3. Design: what makes a good paraphrase for memory augmentation?
|
|
4. Analysis: which failures are fixable by paraphrase vs need better embeddings?
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
from llm import generate_paraphrases_heuristic
|
|
|
|
DEVICE = "cuda"
|
|
|
|
PAIRS = [
|
|
("What's the weather like today?", "User checks weather every morning"),
|
|
("Let's deploy the new version", "Deployment uses GitHub Actions with k3s"),
|
|
("The database is slow again", "Missing index on users table"),
|
|
("I need to fix the authentication bug", "JWT tokens with 24h expiry in Redis"),
|
|
("The API returns 500 errors", "OOM in the Python worker"),
|
|
("Let's set up monitoring", "Prometheus + Grafana on OCI"),
|
|
("Tests failing in CI", "CI needs postgres service container"),
|
|
("Memory usage too high", "Leak in websocket handler"),
|
|
("Help with Docker setup", "docker-compose for dev, k3s for prod"),
|
|
("Log files too large", "Logs rotate daily, shipped to Loki"),
|
|
("How to add caching?", "Redis available at redis.internal:6379"),
|
|
("Frontend loads slowly", "CDN CloudFlare, 1h TTL for assets"),
|
|
("Refactor payment module", "Stripe API, webhook in payments/webhook.py"),
|
|
("Set up new server", "Ubuntu 22.04, Docker, Tailscale, monitoring"),
|
|
("Optimize search", "Elasticsearch v8, recently upgraded"),
|
|
("Backup the database", "Daily 3am UTC cron to S3"),
|
|
("Configure reverse proxy", "Traefik, not nginx"),
|
|
("Team meeting schedule", "Standup 10am London, Mon-Fri"),
|
|
("Learn a new programming language", "User has Python+Go, new to systems"),
|
|
("Review my pull request", "User prefers small PRs with clear commits"),
|
|
]
|
|
|
|
PARAPHRASES = [
|
|
"How's the weather?", "Ship the release", "DB performance terrible",
|
|
"Fix the login issue", "Server errors everywhere", "Need observability",
|
|
"CI tests breaking", "Service using too much RAM", "Docker config help",
|
|
"Logs eating disk space", "Want to add a cache layer", "Website too slow",
|
|
"Payment code needs rework", "Provision a new machine", "Search is slow",
|
|
"Need a DB backup", "Proxy configuration", "When's the standup?",
|
|
"Want to learn Rust", "Check my pull request",
|
|
]
|
|
|
|
# Oracle paraphrases: hand-crafted to cover the semantic gaps
|
|
ORACLE_PARAPHRASES = {
|
|
1: ["Ship the release", "Push to production", "Release the new build", "Deploy new code"],
|
|
3: ["Fix the login issue", "Authentication broken", "Login doesn't work", "Auth bug"],
|
|
4: ["Server errors everywhere", "Getting 500s", "Internal server error", "API is down"],
|
|
5: ["Need observability", "Set up alerts", "Monitor services", "Add monitoring"],
|
|
10: ["Add a cache layer", "Implement caching", "Cache responses"],
|
|
11: ["Website too slow", "Page loads slowly", "Frontend performance bad"],
|
|
12: ["Payment code needs rework", "Refactor payments", "Payment system restructure"],
|
|
13: ["Provision a new machine", "Need a new server", "Set up new box", "New machine setup"],
|
|
14: ["Search is slow", "Search performance", "Optimize search queries"],
|
|
17: ["When's the standup?", "Meeting time?", "Daily sync schedule", "What time is standup?"],
|
|
18: ["Want to learn Rust", "Learning Rust", "Getting into Rust", "Start with Rust"],
|
|
19: ["Check my pull request", "Look at my code", "PR review please", "Review my code changes"],
|
|
}
|
|
|
|
|
|
def cosine(a, b):
|
|
return nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
|
|
|
|
|
|
class TwoStageHopfield:
|
|
def __init__(self, beta=16.0, top_k=20):
|
|
self.beta = beta
|
|
self.top_k = top_k
|
|
self.cue_embs = []
|
|
self.target_embs = []
|
|
self.memory_ids = []
|
|
|
|
def learn(self, cue_emb, target_emb, mid):
|
|
self.cue_embs.append(cue_emb.detach())
|
|
self.target_embs.append(target_emb.detach())
|
|
self.memory_ids.append(mid)
|
|
|
|
def recall(self, query_emb, steps=3):
|
|
cue_mat = torch.stack(self.cue_embs)
|
|
target_mat = torch.stack(self.target_embs)
|
|
K = min(self.top_k, len(self.cue_embs))
|
|
sims = query_emb @ cue_mat.T
|
|
_, top_idx = sims.topk(K)
|
|
cand_cues = cue_mat[top_idx]
|
|
cand_targets = target_mat[top_idx]
|
|
cand_mids = [self.memory_ids[i] for i in top_idx.tolist()]
|
|
|
|
xi = query_emb
|
|
for _ in range(steps):
|
|
scores = self.beta * (xi @ cand_cues.T)
|
|
attn = torch.softmax(scores, dim=0)
|
|
xi = attn @ cand_cues
|
|
xi = nn.functional.normalize(xi, dim=0)
|
|
|
|
scores = self.beta * (xi @ cand_cues.T)
|
|
attn = torch.softmax(scores, dim=0)
|
|
|
|
# Aggregate by memory_id
|
|
mid_scores = {}
|
|
for i, mid in enumerate(cand_mids):
|
|
mid_scores[mid] = mid_scores.get(mid, 0) + attn[i].item()
|
|
|
|
best_mid = max(mid_scores, key=mid_scores.get)
|
|
target = nn.functional.normalize(attn @ cand_targets, dim=0)
|
|
return target, best_mid
|
|
|
|
|
|
def evaluate(model, augmentation_mode, n_background=2000):
|
|
"""Test recall with different augmentation strategies."""
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
cue_embs = model.encode([p[0] for p in PAIRS], convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
target_embs = model.encode([p[1] for p in PAIRS], convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
para_embs = model.encode(PARAPHRASES, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
|
|
mem = TwoStageHopfield(beta=16.0, top_k=20)
|
|
|
|
for i in range(len(PAIRS)):
|
|
mem.learn(cue_embs[i], target_embs[i], mid=i)
|
|
|
|
if augmentation_mode == "heuristic":
|
|
paras = generate_paraphrases_heuristic(PAIRS[i][0], n=3)
|
|
para_e = model.encode(paras, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
for j in range(len(paras)):
|
|
mem.learn(para_e[j], target_embs[i], mid=i)
|
|
|
|
elif augmentation_mode == "oracle":
|
|
if i in ORACLE_PARAPHRASES:
|
|
paras = ORACLE_PARAPHRASES[i]
|
|
para_e = model.encode(paras, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
for j in range(len(paras)):
|
|
mem.learn(para_e[j], target_embs[i], mid=i)
|
|
|
|
elif augmentation_mode == "oracle_all":
|
|
# Oracle for all pairs (3 generic paraphrases each)
|
|
if i in ORACLE_PARAPHRASES:
|
|
paras = ORACLE_PARAPHRASES[i]
|
|
else:
|
|
paras = generate_paraphrases_heuristic(PAIRS[i][0], n=3)
|
|
para_e = model.encode(paras, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)
|
|
for j in range(len(paras)):
|
|
mem.learn(para_e[j], target_embs[i], mid=i)
|
|
|
|
# Background
|
|
if n_background > 0:
|
|
topics = ["server", "db", "api", "fe", "be", "cache"]
|
|
bg_cues = [f"The {topics[i%6]} has issue {i}" for i in range(n_background)]
|
|
bg_targets = [f"Fix issue {i}" for i in range(n_background)]
|
|
bg_c = model.encode(bg_cues, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE, batch_size=256)
|
|
bg_t = model.encode(bg_targets, convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE, batch_size=256)
|
|
for i in range(n_background):
|
|
mem.learn(bg_c[i], bg_t[i], mid=100+i)
|
|
|
|
correct = 0
|
|
failures = []
|
|
for i in range(len(PARAPHRASES)):
|
|
_, best_mid = mem.recall(para_embs[i])
|
|
if best_mid == i:
|
|
correct += 1
|
|
else:
|
|
failures.append((i, best_mid))
|
|
|
|
n = len(PARAPHRASES)
|
|
return correct, n, failures
|
|
|
|
|
|
def main():
|
|
print("=" * 60)
|
|
print("Experiment P2: Auto Paraphrase Analysis")
|
|
print("=" * 60)
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
|
|
|
|
for bg in [0, 500, 2000]:
|
|
print(f"\n=== Background: {bg} ===")
|
|
for mode in ["none", "heuristic", "oracle", "oracle_all"]:
|
|
correct, n, failures = evaluate(model, mode, n_background=bg)
|
|
fail_ids = [f[0] for f in failures]
|
|
print(f" {mode:<15}: {correct}/{n} ({correct/n:.0%})"
|
|
+ (f" | Failures: {fail_ids}" if failures else ""))
|
|
|
|
# Analyze: which failures are fixable?
|
|
print("\n=== Failure Analysis (2K bg, no augmentation) ===")
|
|
correct, n, failures = evaluate(model, "none", 2000)
|
|
cue_texts = [p[0] for p in PAIRS]
|
|
for qi, gi in failures:
|
|
cue_emb = model.encode([cue_texts[qi]], convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)[0]
|
|
para_emb = model.encode([PARAPHRASES[qi]], convert_to_tensor=True,
|
|
normalize_embeddings=True, device=DEVICE)[0]
|
|
sim = cosine(cue_emb, para_emb)
|
|
fixable = qi in ORACLE_PARAPHRASES
|
|
print(f" [{qi}] '{PARAPHRASES[qi][:25]}...' → got [{gi}], "
|
|
f"cue_sim={sim:.3f}, oracle_fix={'✓' if fixable else '✗'}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|