NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025).

Architecture:
- Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle)
- Multi-hop: Hebbian W matrix with WTA pattern separation
- 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency
- 4ms latency @ 20K memories, ~1GB VRAM

Key findings:
- Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian)
- WTA pattern separation enables 20K+ capacity
- Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this
- MiniLM-L6 is optimal (discrimination gap > absolute similarity)
- Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark
- SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
316
experiments/exp03_consolidation.py
Normal file
316
experiments/exp03_consolidation.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Experiment 3: Sleep Consolidation Effects.
|
||||
|
||||
Test questions:
|
||||
1. Does consolidation (replay + homeostasis) help or hurt recall?
|
||||
2. Does replay with noise improve noise tolerance?
|
||||
3. How does pruning affect capacity?
|
||||
4. Multi-night scenario: learn day 1, consolidate, learn day 2, consolidate.
|
||||
Do day 1 memories survive?
|
||||
5. Selective consolidation: replay important memories more → priority memory
|
||||
"""
|
||||
|
||||
import json
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

# Make the project's src/ importable before pulling in local modules.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from nuonuo.consolidation import MemoryConsolidator, winner_take_all
|
||||
|
||||
DEVICE = "cuda"
|
||||
RESULTS_DIR = Path(__file__).parent.parent / "doc"
|
||||
|
||||
|
||||
def cosine(a, b):
    """Return the cosine similarity of two 1-D tensors as a float.

    Zero vectors have no direction, so either operand having zero norm
    yields 0.0 instead of a NaN from the 0/0 division.
    """
    if a.norm() == 0 or b.norm() == 0:
        return 0.0
    sim = nn.functional.cosine_similarity(a[None, :], b[None, :])
    return sim.item()
|
||||
|
||||
|
||||
class TestableMemory:
    """Memory with consolidation support for testing.

    Hebbian hetero-associative memory: cues and targets are pattern-separated
    through two independent random projections plus winner-take-all, and bound
    together by outer-product updates to a dense weight matrix W.
    """

    def __init__(self, input_dim=768, code_dim=16384, k=20):
        self.k = k
        self.code_dim = code_dim
        scale = 1.0 / input_dim**0.5
        # Independent random projections for cue vs. target codes.
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        self.target_proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        # Hebbian weights; updated manually, so autograd stays off.
        self.W = nn.Parameter(torch.zeros(code_dim, code_dim, device=DEVICE),
                              requires_grad=False)
        self.consolidator = MemoryConsolidator(code_dim, k)

    def sep(self, x):
        """Sparse cue code: project then keep the top-k winners."""
        return winner_take_all(x @ self.proj, self.k)

    def sep_target(self, x):
        """Sparse target code via the separate target projection."""
        return winner_take_all(x @ self.target_proj, self.k)

    def learn(self, cue, target, record=True):
        """Hebbian outer-product write; optionally record the pair for replay."""
        cue_code = self.sep(cue)
        target_code = self.sep_target(target)
        self.W.data += torch.outer(target_code, cue_code)
        if record:
            self.consolidator.record(cue_code, target_code)

    def recall(self, cue):
        """Single associative read: encode cue, read out W, re-sparsify."""
        return winner_take_all(self.W @ self.sep(cue), self.k)

    def test_recall(self, cues, targets, noise_std=0.0):
        """Return (mean cosine, fraction of recalls with cosine > 0.5)."""
        scores = []
        for idx, cue in enumerate(cues):
            if noise_std > 0:
                # Perturb the cue, then renormalize to the unit sphere.
                cue = nn.functional.normalize(
                    cue + torch.randn_like(cue) * noise_std, dim=0)
            scores.append(cosine(self.recall(cue), self.sep_target(targets[idx])))
        return np.mean(scores), np.mean([s > 0.5 for s in scores])

    def consolidate(self, **kwargs):
        """Run one sleep cycle (replay/homeostasis/pruning) on W in place."""
        return self.consolidator.consolidate(
            self.W, self.proj, self.target_proj, **kwargs)
|
||||
|
||||
|
||||
def gen_memories(n, dim=768):
    """Return (cues, targets): two lists of n random unit vectors on DEVICE."""
    def _unit_vectors():
        return [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
                for _ in range(n)]
    # Cues are drawn first, then targets — keeps RNG consumption order stable.
    cues = _unit_vectors()
    targets = _unit_vectors()
    return cues, targets
|
||||
|
||||
|
||||
def test_basic_consolidation():
    """Does replay + homeostasis help?

    For two memory loads, learn random pairs, measure baseline recall, then
    run consolidation at several epoch counts — each on a fresh clone so
    every setting starts from the identical learned state.
    """
    print("=== Test 1: Basic Consolidation Effect ===")

    for n_pairs in [100, 500]:
        mem = TestableMemory()
        cues, targets = gen_memories(n_pairs)

        # Learn
        for i in range(n_pairs):
            mem.learn(cues[i], targets[i])

        # Before consolidation
        cos_before, rate_before = mem.test_recall(cues, targets)
        w_norm_before = mem.W.data.norm().item()

        print(f"\n {n_pairs} pairs:")
        print(f" Before: CosSim={cos_before:.4f}, Rate={rate_before:.2%}, "
              f"W_norm={w_norm_before:.2f}")

        # Consolidation with different settings
        for epochs in [1, 3, 5, 10]:
            # Clone memory for each test: share the projections (so codes
            # match), copy W and the replay buffer so consolidating the
            # clone cannot mutate the source state.
            mem_test = TestableMemory()
            mem_test.W.data.copy_(mem.W.data)
            mem_test.proj = mem.proj
            mem_test.target_proj = mem.target_proj
            mem_test.consolidator.replay_buffer = list(mem.consolidator.replay_buffer)

            stats = mem_test.consolidate(
                num_epochs=epochs, homeostasis_factor=0.95, prune_threshold=0.001)
            cos_after, rate_after = mem_test.test_recall(cues, targets)

            print(f" After (epochs={epochs}): CosSim={cos_after:.4f}, "
                  f"Rate={rate_after:.2%}, "
                  f"W_norm={stats['final_w_norm']:.2f}, "
                  f"Sparsity={stats['final_sparsity']:.2%}")
|
||||
|
||||
|
||||
def test_noisy_replay():
    """Does replay with noise improve noise tolerance?

    Trains one baseline memory, then consolidates clones with increasing
    replay noise and measures recall at several test-time noise levels.
    """
    print("\n=== Test 2: Noisy Replay for Robustness ===")

    n_pairs = 100
    mem_base = TestableMemory()
    cues, targets = gen_memories(n_pairs)

    for cue, target in zip(cues, targets):
        mem_base.learn(cue, target)

    # Test at different noise levels
    test_noises = [0.0, 0.05, 0.1, 0.2]

    def report(memory, header):
        # Shared recall sweep used by the baseline and every consolidated clone.
        print(header)
        for ns in test_noises:
            cos, rate = memory.test_recall(cues, targets, noise_std=ns)
            print(f" test_noise={ns:.2f}: CosSim={cos:.4f}, Rate={rate:.2%}")

    # No consolidation (baseline)
    report(mem_base, "\n No consolidation:")

    # Consolidation with different replay noise
    for replay_noise in [0.0, 0.1, 0.5, 1.0]:
        mem_test = TestableMemory()
        mem_test.W.data.copy_(mem_base.W.data)
        mem_test.proj = mem_base.proj
        mem_test.target_proj = mem_base.target_proj
        mem_test.consolidator.replay_buffer = list(mem_base.consolidator.replay_buffer)

        mem_test.consolidate(num_epochs=5, replay_noise=replay_noise,
                             homeostasis_factor=0.95)

        report(mem_test, f"\n Consolidated (replay_noise={replay_noise}):")
|
||||
|
||||
|
||||
def test_multi_night():
    """Multi-night scenario: learn, consolidate, learn more.

    Do old memories survive? Each "day" adds 100 fresh memories; each
    "night" consolidates and then trims the replay buffer, so older
    memories compete with newer ones for replay and weight budget.
    """
    print("\n=== Test 3: Multi-Night Memory Survival ===")

    mem = TestableMemory()

    # Day 1: Learn 100 memories
    cues_d1, targets_d1 = gen_memories(100)
    for i in range(100):
        mem.learn(cues_d1[i], targets_d1[i])

    cos_d1, _ = mem.test_recall(cues_d1, targets_d1)
    print(f" After Day 1 (100 memories): CosSim={cos_d1:.4f}")

    # Night 1: Consolidate
    stats = mem.consolidate(num_epochs=5, homeostasis_factor=0.95)
    cos_d1_after, _ = mem.test_recall(cues_d1, targets_d1)
    print(f" After Night 1 consolidation: CosSim={cos_d1_after:.4f}, "
          f"W_norm={stats['final_w_norm']:.2f}")
    # Keep only 30% of the replay buffer — presumably the retained fraction
    # is chosen by the consolidator's own policy (TODO confirm which entries
    # survive selective_clear).
    mem.consolidator.selective_clear(keep_fraction=0.3)

    # Day 2: Learn 100 more memories
    cues_d2, targets_d2 = gen_memories(100)
    for i in range(100):
        mem.learn(cues_d2[i], targets_d2[i])

    cos_d1_mid, _ = mem.test_recall(cues_d1, targets_d1)
    cos_d2_mid, _ = mem.test_recall(cues_d2, targets_d2)
    print(f" After Day 2 (100 more): Day1={cos_d1_mid:.4f}, Day2={cos_d2_mid:.4f}")

    # Night 2: Consolidate (with day 1 carryover + day 2)
    stats = mem.consolidate(num_epochs=5, homeostasis_factor=0.95)
    cos_d1_final, _ = mem.test_recall(cues_d1, targets_d1)
    cos_d2_final, _ = mem.test_recall(cues_d2, targets_d2)
    print(f" After Night 2: Day1={cos_d1_final:.4f}, Day2={cos_d2_final:.4f}, "
          f"W_norm={stats['final_w_norm']:.2f}")

    # Continue for 5 more days
    for day in range(3, 8):
        mem.consolidator.selective_clear(keep_fraction=0.3)
        cues_new, targets_new = gen_memories(100)
        for i in range(100):
            mem.learn(cues_new[i], targets_new[i])
        mem.consolidate(num_epochs=5, homeostasis_factor=0.95)

        cos_d1_now, _ = mem.test_recall(cues_d1, targets_d1)
        cos_d2_now, _ = mem.test_recall(cues_d2, targets_d2)
        cos_new, _ = mem.test_recall(cues_new, targets_new)
        w_norm = mem.W.data.norm().item()
        # Sparsity measured at the 0.001 magnitude threshold (matches the
        # prune_threshold used in Test 1).
        sparsity = (mem.W.data.abs() < 0.001).float().mean().item()
        print(f" After Day {day}: Day1={cos_d1_now:.4f}, Day2={cos_d2_now:.4f}, "
              f"Latest={cos_new:.4f}, W_norm={w_norm:.1f}, Sparsity={sparsity:.2%}")
|
||||
|
||||
|
||||
def test_priority_replay():
    """Test selective consolidation: replay important memories more.

    Important memories get 5x the replay-buffer entries of unimportant
    ones; after consolidation with strong homeostasis, compare how much
    each group's recall changed.
    """
    print("\n=== Test 4: Priority Replay ===")

    mem = TestableMemory()

    # 50 "important" memories (replay 5x)
    cues_imp, targets_imp = gen_memories(50)
    for cue, target in zip(cues_imp, targets_imp):
        mem.learn(cue, target)
        # Record extra copies for priority replay
        cue_code = mem.sep(cue)
        target_code = mem.sep_target(target)
        for _ in range(4):  # 4 extra = 5x total
            mem.consolidator.record(cue_code, target_code)

    # 50 "unimportant" memories (replay 1x, normal)
    cues_unimp, targets_unimp = gen_memories(50)
    for cue, target in zip(cues_unimp, targets_unimp):
        mem.learn(cue, target)

    cos_imp_before, _ = mem.test_recall(cues_imp, targets_imp)
    cos_unimp_before, _ = mem.test_recall(cues_unimp, targets_unimp)
    print(f" Before consolidation: Important={cos_imp_before:.4f}, "
          f"Unimportant={cos_unimp_before:.4f}")

    # Consolidate with strong homeostasis (will decay unimportant more)
    mem.consolidate(num_epochs=10, homeostasis_factor=0.90)

    cos_imp_after, _ = mem.test_recall(cues_imp, targets_imp)
    cos_unimp_after, _ = mem.test_recall(cues_unimp, targets_unimp)
    print(f" After consolidation: Important={cos_imp_after:.4f}, "
          f"Unimportant={cos_unimp_after:.4f}")

    delta_imp = cos_imp_after - cos_imp_before
    delta_unimp = cos_unimp_after - cos_unimp_before
    print(f" Priority effect: Δimportant={delta_imp:+.4f}, "
          f"Δunimportant={delta_unimp:+.4f}")
|
||||
|
||||
|
||||
def test_forgetting_curve():
    """Measure memory decay over multiple consolidation cycles without replay.

    First runs ten "nights" of pure decay (homeostasis + pruning applied
    directly to W), then repeats on a second memory that consolidates with
    replay each night, for comparison.
    """
    print("\n=== Test 5: Forgetting Curve ===")

    mem = TestableMemory()
    cues, targets = gen_memories(100)

    for cue, target in zip(cues, targets):
        mem.learn(cue, target)

    cos0, _ = mem.test_recall(cues, targets)
    print(f" Day 0: CosSim={cos0:.4f}")

    # Simulate nights with homeostasis but NO replay
    for night in range(1, 11):
        # Only homeostasis + pruning, no replay: scale W down, then zero
        # every weight whose magnitude fell below the prune threshold.
        mem.W.data *= 0.95
        keep = mem.W.data.abs() >= 0.001
        mem.W.data *= keep.float()

        cos, rate = mem.test_recall(cues, targets)
        w_norm = mem.W.data.norm().item()
        print(f" Night {night:2d} (no replay): CosSim={cos:.4f}, "
              f"Rate={rate:.2%}, W_norm={w_norm:.2f}")

    # Same but WITH replay
    print("\n --- With replay ---")
    mem2 = TestableMemory()
    # Reuse the same projections so both runs see identical codes.
    mem2.proj = mem.proj
    mem2.target_proj = mem.target_proj

    for cue, target in zip(cues, targets):
        mem2.learn(cue, target)

    for night in range(1, 11):
        mem2.consolidate(num_epochs=1, homeostasis_factor=0.95)

        cos, rate = mem2.test_recall(cues, targets)
        w_norm = mem2.W.data.norm().item()
        print(f" Night {night:2d} (with replay): CosSim={cos:.4f}, "
              f"Rate={rate:.2%}, W_norm={w_norm:.2f}")
|
||||
|
||||
|
||||
def main():
    """Run all five consolidation experiments in order."""
    banner = "=" * 60
    print(banner)
    print("Experiment 3: Sleep Consolidation")
    print(banner)

    for experiment in (test_basic_consolidation,
                       test_noisy_replay,
                       test_multi_night,
                       test_priority_replay,
                       test_forgetting_curve):
        experiment()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user