Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
257 lines · 10 KiB · Python
"""Experiment 4b: Fix multi-hop for real embeddings.
|
|
|
|
Problem: exp04 used separate projections for cues and targets,
|
|
so target codes lived in a different space from cue codes.
|
|
Multi-hop requires: recalled_target_code CAN be used as next cue_code.
|
|
|
|
Fix: Use a SINGLE projection for everything.
|
|
W maps from code_space → code_space.
|
|
W @ sep(A) ≈ sep(B) when we learned (A, B).
|
|
Then W @ sep(B) ≈ sep(C) if we also learned (B, C).
|
|
|
|
Also: retest paraphrase recall with single projection and various code_dim/k.
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
DEVICE = "cuda"
|
|
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
|
|
|
|
|
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, as a float.

    Returns 0.0 if either vector has zero norm (cosine is undefined there,
    and this avoids a divide-by-zero in the similarity kernel).
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def winner_take_all(x, k):
    """Binarize x along the last dim: 1.0 at the k largest entries, 0.0 elsewhere."""
    winners = x.topk(k, dim=-1).indices
    mask = torch.zeros_like(x)
    # scatter_ returns self, so this is both the write and the result.
    return mask.scatter_(-1, winners, 1.0)
class UnifiedHebbianMemory:
    """Hebbian memory with single unified projection.

    Cues and targets share the same code space → multi-hop works: a
    recalled target code can be fed straight back in as the next cue code.

    Attributes:
        proj: fixed random projection (input_dim, code_dim) shared by cues
            and targets — the single code space.
        W: Hebbian association matrix mapping code_space → code_space.
        cue_store / target_store: raw embeddings, in learn() order.
        target_codes: sep(target) cached at learn() time (proj never changes,
            so recomputing per query would give identical codes).
        metadata: {"cue": ..., "target": ...} text per learned pair.
    """

    def __init__(self, input_dim, code_dim=16384, k=20, device=None):
        """Create an empty memory.

        Args:
            input_dim: dimensionality of the input embeddings.
            code_dim: size of the sparse code space.
            k: number of active units per code (WTA sparsity).
            device: torch device for all tensors; defaults to the module-level
                DEVICE when None (backward compatible with the old behavior).
        """
        self.k = k
        self.code_dim = code_dim
        self.device = DEVICE if device is None else device
        # 1/sqrt(input_dim) scaling keeps projected activations O(1).
        self.proj = (torch.randn(input_dim, code_dim, device=self.device)
                     * (1.0 / input_dim**0.5))
        self.W = torch.zeros(code_dim, code_dim, device=self.device)
        self.cue_store = []
        self.target_store = []
        self.target_codes = []
        self.metadata = []

    def sep(self, x):
        """Pattern-separate an embedding into a k-sparse binary code."""
        return winner_take_all(x @ self.proj, self.k)

    def learn(self, cue_emb, target_emb, cue_text="", target_text=""):
        """Hebbian outer-product update so that W @ sep(cue) ≈ sep(target)."""
        cc = self.sep(cue_emb)
        tc = self.sep(target_emb)
        self.W += torch.outer(tc, cc)
        self.cue_store.append(cue_emb.detach().clone())
        self.target_store.append(target_emb.detach().clone())
        # Cache the target code now — find_nearest_target reuses it.
        self.target_codes.append(tc)
        self.metadata.append({"cue": cue_text, "target": target_text})

    def recall(self, query_emb, hops=1):
        """Follow the associative chain for `hops` steps.

        Each hop re-sparsifies W @ code with WTA, so the output of one hop
        is a valid cue code for the next. Returns a k-sparse code.
        """
        code = self.sep(query_emb)
        for _ in range(hops):
            raw = self.W @ code
            code = winner_take_all(raw, self.k)
        return code

    def recall_coarse_to_fine(self, query_emb):
        """NN lookup → exact Hebbian recall.

        Snaps the query to the nearest stored cue embedding first, then does
        one exact Hebbian hop from that cue's code.

        Returns:
            (recalled k-sparse code, index of the nearest stored cue).

        Raises:
            ValueError: if nothing has been learned yet.
        """
        if not self.cue_store:
            raise ValueError("recall_coarse_to_fine called on empty memory")
        cue_matrix = torch.stack(self.cue_store)
        sims = nn.functional.cosine_similarity(
            query_emb.unsqueeze(0), cue_matrix, dim=-1)
        best_idx = sims.argmax()
        code = self.sep(self.cue_store[best_idx])
        raw = self.W @ code
        return winner_take_all(raw, self.k), best_idx.item()

    def find_nearest_target(self, recalled_code, top_n=3):
        """Rank stored targets by code-space similarity to `recalled_code`.

        Uses the codes cached at learn() time (same projection — identical
        to re-projecting, without the per-query O(n) matmuls).

        Returns:
            List of (index, similarity, metadata) for the top_n matches,
            best first.
        """
        sims = [cosine(recalled_code, tc) for tc in self.target_codes]
        sorted_idx = np.argsort(sims)[::-1]
        return [(int(i), sims[i], self.metadata[i]) for i in sorted_idx[:top_n]]
# (cue, target) pairs: the cue is something a user might say, the target is
# the stored memory the system should recall for it. Index-aligned with
# PARAPHRASED_QUERIES — query i is a paraphrase of cue i.
MEMORY_PAIRS = [
    ("What's the weather like today?", "User prefers to check weather every morning"),
    ("Let's deploy the new version", "The deployment pipeline uses GitHub Actions with k3s"),
    ("The database is slow again", "Last time DB was slow it was because of missing index on users table"),
    ("Can you review my pull request?", "User prefers small PRs with clear commit messages"),
    ("I need to fix the authentication bug", "Auth service uses JWT tokens with 24h expiry stored in Redis"),
    ("Let's set up monitoring", "Prometheus + Grafana stack is already running on the OCI cluster"),
    ("The API is returning 500 errors", "Last 500 error was caused by OOM in the Python worker"),
    ("I want to learn Rust", "User has strong Python and Go background, new to systems programming"),
    ("Schedule a meeting with the team", "Team standup is at 10am London time, Mon-Fri"),
    ("How do I configure nginx?", "The project uses Traefik as reverse proxy, not nginx"),
    ("The tests are failing in CI", "CI runs on Gitea Actions, tests need postgres service container"),
    ("Let's optimize the search function", "Search uses Elasticsearch, recently upgraded to v8"),
    ("I need to backup the database", "Backups run daily at 3am UTC via cron job to S3"),
    ("The memory usage is too high", "Python service has a known memory leak in the websocket handler"),
    ("Can you help with the Docker setup?", "Project uses docker-compose for local dev, k3s for production"),
    ("I want to add caching", "Redis is already available at redis.internal:6379"),
    ("The frontend is loading slowly", "CDN is CloudFlare, assets should be cached with 1h TTL"),
    ("Let's refactor the payment module", "Payment uses Stripe API, webhook handler is in payments/webhook.py"),
    ("I need to set up a new server", "Standard setup: Ubuntu 22.04, Docker, Tailscale, monitoring agent"),
    ("The log files are too large", "Logs rotate daily, kept for 30 days, shipped to Loki"),
]
# Paraphrases of the MEMORY_PAIRS cues, index-aligned: a recall for query i
# is counted correct when it retrieves target i.
PARAPHRASED_QUERIES = [
    "How's the weather outside?",
    "We should push the new release",
    "The DB performance is terrible",
    "Please look at my code changes",
    "There's a login bug I need to fix",
    "We need better observability",
    "Getting internal server errors from the API",
    "I'm interested in learning a new language like Rust",
    "Need to organize a team meeting",
    "How to set up nginx as a web server?",
    "CI tests keep breaking",
    "The search feature needs to be faster",
    "How do I create a database backup?",
    "The service is using too much RAM",
    "Help me with Docker configuration",
    "I want to implement caching for the API",
    "The website is really slow",
    "The payment system needs restructuring",
    "Setting up a fresh Linux server",
    "Logs are eating up disk space",
]
def load_model():
    """Load the all-MiniLM-L6-v2 sentence encoder on the configured device."""
    # Imported lazily so the module can be inspected without the dependency.
    from sentence_transformers import SentenceTransformer
    return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
def embed_texts(model, texts):
    """Encode `texts` into L2-normalized embedding tensors on DEVICE."""
    return model.encode(
        texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
def test_multihop(model):
    """Multi-hop with unified projection.

    Phase 1: one memory per chain (no cross-chain interference).
    Phase 2: all chains learned into a single memory (interference test).
    A hop counts as correct when the recalled code matches the true target's
    code with cosine > 0.5.
    """
    print("\n=== Multi-hop (unified projection) ===")

    chains = [
        ["What's the weather?", "I usually check weather before going out",
         "My favorite coffee shop is around the corner", "They have great latte art"],
        ["Let's review the code", "The code review found a memory leak",
         "Memory leaks often cause OOM kills", "We need to add memory limits to k8s pods"],
        ["Deploy to production", "Production uses blue-green deployment",
         "The blue environment is currently active", "Switch DNS to green when ready"],
        ["The server crashed", "Check the error logs first",
         "Logs show out of memory error", "Need to increase pod memory limit"],
    ]

    embed_dim = model.get_sentence_embedding_dimension()

    # Batch-encode each chain ONCE (one encoder call per chain instead of one
    # per sentence) and reuse the embeddings for both phases below — the
    # original encoded every sentence individually and then re-encoded all
    # chains a second time for the shared-memory phase.
    all_chain_embs = [embed_texts(model, chain) for chain in chains]

    for chain, chain_embs in zip(chains, all_chain_embs):
        # Separate memory per chain to avoid cross-chain interference
        mem = UnifiedHebbianMemory(embed_dim, code_dim=8192, k=20)

        # Learn consecutive pairs
        for i in range(len(chain) - 1):
            mem.learn(chain_embs[i], chain_embs[i+1], chain[i], chain[i+1])

        print(f"\n Chain: {' → '.join([c[:20]+'...' for c in chain])}")
        for hops in range(1, len(chain)):
            recalled = mem.recall(chain_embs[0], hops=hops)
            target_code = mem.sep(chain_embs[hops])
            sim = cosine(recalled, target_code)
            status = "✓" if sim > 0.5 else "✗"
            print(f" {status} {hops} hop(s): → '{chain[hops][:30]}...' sim={sim:.4f}")

    # Test multi-hop with all chains in ONE memory
    print("\n --- All chains in ONE memory ---")
    mem_all = UnifiedHebbianMemory(embed_dim, code_dim=16384, k=20)

    for chain, embs in zip(chains, all_chain_embs):
        for i in range(len(chain) - 1):
            mem_all.learn(embs[i], embs[i+1], chain[i], chain[i+1])

    for ci, chain in enumerate(chains):
        for hops in range(1, len(chain)):
            recalled = mem_all.recall(all_chain_embs[ci][0], hops=hops)
            target_code = mem_all.sep(all_chain_embs[ci][hops])
            sim = cosine(recalled, target_code)
            status = "✓" if sim > 0.5 else "✗"
            print(f" {status} Chain{ci+1} {hops}hop: → '{chain[hops][:30]}...' sim={sim:.4f}")
def test_paraphrase_with_configs(model):
    """Test paraphrase recall with different code_dim/k configs.

    For each (code_dim, k) config, learns all MEMORY_PAIRS and then checks
    whether each paraphrased query recalls its index-aligned target, under
    both direct Hebbian recall and coarse-to-fine recall.
    """
    print("\n=== Paraphrase Recall: Config Sweep ===")

    embed_dim = model.get_sentence_embedding_dimension()
    cue_embs = embed_texts(model, [pair[0] for pair in MEMORY_PAIRS])
    target_embs = embed_texts(model, [pair[1] for pair in MEMORY_PAIRS])
    para_embs = embed_texts(model, PARAPHRASED_QUERIES)

    configs = [
        (4096, 20), (8192, 20), (16384, 20), (32768, 20),
        (16384, 10), (16384, 50), (16384, 100),
    ]

    n = len(PARAPHRASED_QUERIES)
    for code_dim, k in configs:
        mem = UnifiedHebbianMemory(embed_dim, code_dim, k)
        for (cue_text, target_text), ce, te in zip(MEMORY_PAIRS, cue_embs, target_embs):
            mem.learn(ce, te, cue_text, target_text)

        # Count correct recalls (paraphrase i must retrieve target i)
        # under both retrieval modes.
        direct_correct = 0
        coarse_correct = 0
        for i in range(n):
            # Direct: one Hebbian hop straight from the paraphrased query.
            hit = mem.find_nearest_target(mem.recall(para_embs[i]), top_n=1)
            direct_correct += hit[0][0] == i

            # Coarse-to-fine: snap to the nearest stored cue first.
            code_cf, _ = mem.recall_coarse_to_fine(para_embs[i])
            hit_cf = mem.find_nearest_target(code_cf, top_n=1)
            coarse_correct += hit_cf[0][0] == i

        print(f" code={code_dim:>5}, k={k:>3}: "
              f"Direct={direct_correct}/{n} ({direct_correct/n:.0%}), "
              f"Coarse={coarse_correct}/{n} ({coarse_correct/n:.0%})")
def main():
    """Run both experiment phases: multi-hop chains, then the config sweep."""
    banner = "=" * 60
    print(banner)
    print("Experiment 4b: Multi-hop Fix + Config Sweep")
    print(banner)

    model = load_model()
    test_multihop(model)
    test_paraphrase_with_configs(model)
if __name__ == "__main__":
|
|
main()
|