Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4 ms latency @ 20K memories, ~1 GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim = 1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture

259 lines · 9.4 KiB · Python
"""Experiment P4: Memory Lifecycle Management.
|
||
|
||
Questions:
|
||
1. What's worth storing? (not everything in a conversation is a "memory")
|
||
2. When to forget? (access-based decay, age-based decay, capacity pressure)
|
||
3. Can we merge similar memories? (deduplification / compression)
|
||
4. Importance scoring: how to prioritize during recall and forgetting?
|
||
|
||
Strategy: implement and test each mechanism, measure impact on recall quality.
|
||
"""
|
||
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
from collections import Counter
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||
from nuonuo.hippocampus import HippocampalMemory
|
||
|
||
DEVICE = "cuda"
|
||
|
||
|
||
def cosine(a, b):
    """Cosine similarity between two 1-D tensors, as a Python float."""
    sim = nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
    return sim.item()
def load_model():
    """Construct the MiniLM sentence encoder on DEVICE (lazy import so the
    module can be inspected without sentence-transformers installed)."""
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE)
    return model
def emb(model, text):
    """Encode a single string into an L2-normalized embedding tensor on DEVICE."""
    batch = model.encode(
        [text],
        convert_to_tensor=True,
        normalize_embeddings=True,
        device=DEVICE,
    )
    return batch[0]
def test_deduplication(model):
    """Detect near-duplicate memories by cue similarity and merge each group,
    keeping the entry with the most informative (longest) target."""
    print("=== Test 1: Deduplication ===\n")

    mem = HippocampalMemory(embed_dim=384)

    # Cue/target pairs; several are deliberate near-duplicates of each other.
    memories = [
        ("The database is slow", "Check missing indexes"),
        ("Database is really slow today", "Check missing indexes on users table"),  # near-dup
        ("DB performance is terrible", "Look at index usage"),  # near-dup
        ("Deploy to production", "Use blue-green deployment"),
        ("Push to prod", "Blue-green deployment via GitHub Actions"),  # near-dup
        ("The API returns 500 errors", "Check for OOM in Python worker"),
        ("Getting 500 errors from API", "Python worker might be OOM"),  # near-dup
        ("Set up monitoring", "Prometheus + Grafana"),
        ("We need better observability", "Set up Prometheus and Grafana"),  # near-dup
    ]

    for cue, target in memories:
        mem.store(emb(model, cue), emb(model, target),
                  metadata={"cue": cue, "target": target})

    print(f" Before dedup: {len(mem.memories)} memories")

    # Greedy single-pass grouping: each not-yet-grouped entry seeds a group
    # and absorbs every later entry whose cue similarity exceeds 0.7.
    entries = list(mem.memories.values())
    groups = []
    used = set()
    for idx, anchor in enumerate(entries):
        if idx in used:
            continue
        group = [idx]
        for later in range(idx + 1, len(entries)):
            if later in used:
                continue
            if cosine(anchor.cue_embedding, entries[later].cue_embedding) > 0.7:
                group.append(later)
                used.add(later)
        groups.append(group)
        used.add(idx)

    print(f" Found {len(groups)} groups (from {len(entries)} memories):")
    for group in groups:
        if len(group) > 1:
            cues = [entries[k].metadata.get("cue", "?") for k in group]
            print(f" Group ({len(group)}): {[c[:30] for c in cues]}")

    # Merge policy: within each multi-member group, keep the entry whose
    # target text is longest (assumed most informative); forget the rest.
    to_remove = []
    for group in groups:
        if len(group) > 1:
            best = max(group, key=lambda k: len(entries[k].metadata.get("target", "")))
            to_remove.extend(entries[k].memory_id for k in group if k != best)

    for mid in to_remove:
        mem.forget(mid)

    print(f" After dedup: {len(mem.memories)} memories")
    print(f" Removed {len(to_remove)} duplicates")
def test_importance_scoring(model):
    """Score simulated conversation turns with a keyword heuristic and check
    whether each score falls into its expected importance band."""
    print("\n=== Test 2: Importance Scoring ===\n")

    # (user, assistant, expected_importance)
    conversations = [
        ("Hi there!", "Hello! How can I help?", "low"),
        ("What's the weather?", "It's sunny today.", "low"),
        ("The production database crashed at 3am", "Emergency: restore from latest backup at s3://backups/db-latest.sql", "high"),
        ("What time is it?", "It's 3:45 PM.", "low"),
        ("The auth service JWT secret was compromised", "Rotate secret immediately: kubectl set env deployment/auth JWT_SECRET=new_value", "critical"),
        ("Deploy the hotfix", "Deployed via GitHub Actions, monitor Grafana for 30 min", "high"),
        ("Thanks for your help", "You're welcome!", "low"),
    ]

    def score_importance(user_msg, assistant_msg):
        """Heuristic importance in [0.3, 1.0]: length + keyword hits + question."""
        combined = (user_msg + assistant_msg).lower()  # hoisted: scanned per keyword
        critical_words = ["crash", "emergency", "compromised", "secret", "password",
                          "production", "outage", "down", "data loss"]
        high_words = ["deploy", "config", "fix", "bug", "error", "migrate",
                      "backup", "restore", "rollback"]

        score = 0.3  # base
        # Longer answers suggest complexity worth remembering.
        if len(assistant_msg.split()) > 15:
            score += 0.2
        for w in critical_words:
            if w in combined:
                score += 0.3
        for w in high_words:
            if w in combined:
                score += 0.1
        # Questions suggest retrievable info.
        if "?" in user_msg:
            score += 0.1
        return min(score, 1.0)

    for user, assistant, expected in conversations:
        score = score_importance(user, assistant)
        in_band = (
            (expected == "low" and score < 0.5)
            or (expected == "high" and 0.5 <= score < 0.8)
            or (expected == "critical" and score >= 0.8)
        )
        status = "✓" if in_band else "✗"
        should_store = score >= 0.4  # storage cutoff
        print(f" {status} [{score:.2f}] {'STORE' if should_store else 'SKIP ':>5} "
              f"({expected:>8}) '{user[:40]}...'")
def test_forgetting_strategies(model):
    """Compare FIFO, least-accessed, and importance-based forgetting under a
    hard capacity cap, reporting per-day recall of surviving memories."""
    print("\n=== Test 3: Forgetting Strategies ===\n")

    # 7 days × 10 memories/day = 70 stored, capped at 30 → forced forgetting.
    days = 7
    per_day = 10
    max_capacity = 30

    cue_template = "Day {day} task {i}: {topic}"
    target_template = "Solution for day {day} task {i}"
    topics = ["database", "deploy", "monitoring", "auth", "API",
              "caching", "logging", "testing", "docker", "CI/CD"]

    def run_strategy(strategy_name, forget_fn):
        """Fill memory day by day, evicting via forget_fn when over capacity,
        then measure top-1 recall of whatever survived."""
        mem = HippocampalMemory(embed_dim=384)
        day_memories = {}  # day → list of memory_ids, in store order

        for day in range(1, days + 1):
            stored_ids = []
            for task in range(per_day):
                cue = cue_template.format(day=day, i=task, topic=topics[task])
                target = target_template.format(day=day, i=task)
                mid = mem.store(emb(model, cue), emb(model, target),
                                metadata={"day": day, "task": task},
                                timestamp=float(day))
                stored_ids.append(mid)
            day_memories[day] = stored_ids

            # Evict once the day's additions push us past the cap.
            if len(mem.memories) > max_capacity:
                forget_fn(mem, max_capacity)

        # Recall each surviving memory by its original cue; count top-1 hits.
        day_recall = {}
        for day in range(1, days + 1):
            correct = 0
            total = 0
            for task in range(per_day):
                ids = day_memories[day]
                mid = ids[task] if task < len(ids) else None
                if mid is None or mid not in mem.memories:
                    continue  # forgotten — excluded from the denominator
                cue = cue_template.format(day=day, i=task, topic=topics[task])
                results = mem.recall(emb(model, cue), top_k=1)
                if results and results[0].memory_id == mid:
                    correct += 1
                total += 1
            day_recall[day] = (correct, total)

        print(f" {strategy_name}: {len(mem.memories)} memories surviving")
        for day in range(1, days + 1):
            c, t = day_recall[day]
            pct = f"{c}/{t}" if t > 0 else "0/0"
            print(f" Day {day}: {pct}")

    # Strategy 1: FIFO — evict the oldest timestamps first.
    def forget_fifo(mem, cap):
        by_age = sorted(mem.memories.values(), key=lambda e: e.timestamp)
        excess = len(mem.memories) - cap
        for entry in by_age[:excess]:
            mem.forget(entry.memory_id)

    # Strategy 2: evict the least-accessed entries first (frequency-based;
    # labelled "LRU" in the report strings).
    def forget_lru(mem, cap):
        by_access = sorted(mem.memories.values(), key=lambda e: e.access_count)
        excess = len(mem.memories) - cap
        for entry in by_access[:excess]:
            mem.forget(entry.memory_id)

    # Strategy 3: evict lowest combined score (recency + 0.5 × access count).
    def forget_low_importance(mem, cap):
        by_score = sorted(mem.memories.values(),
                          key=lambda e: e.timestamp + e.access_count * 0.5)
        excess = len(mem.memories) - cap
        for entry in by_score[:excess]:
            mem.forget(entry.memory_id)

    print("(max_capacity=30, 7 days × 10 memories = 70 total)")
    run_strategy("FIFO (oldest first)", forget_fifo)
    print()
    run_strategy("LRU (least accessed)", forget_lru)
    print()
    run_strategy("Importance (recency+access)", forget_low_importance)
def main():
    """Run all three lifecycle experiments with one shared embedding model."""
    banner = "=" * 60
    print(banner)
    print("Experiment P4: Memory Lifecycle")
    print(banner)

    model = load_model()
    test_deduplication(model)
    test_importance_scoring(model)
    test_forgetting_strategies(model)


if __name__ == "__main__":
    main()