NuoNuo: Hippocampal memory module prototype
Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
This commit is contained in:
194
experiments/exp02g_multihop.py
Normal file
194
experiments/exp02g_multihop.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Experiment 2g: Multi-hop associative recall.
|
||||
|
||||
The unique advantage of Hebbian memory over simple cosine retrieval:
|
||||
If A→B and B→C are learned, can we recall C from A by chaining through B?
|
||||
|
||||
This is impossible with standard RAG (which only does single-hop NN lookup).
|
||||
If this works, it's the strongest argument for the Hebbian approach.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Device for all tensors. Fall back to CPU when CUDA is unavailable so the
# experiment still runs on machines without a GPU (the original hard-coded
# "cuda" and crashed on CPU-only hosts).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def cosine(a, b):
    """Cosine similarity between two 1-D tensors as a Python float.

    Returns 0.0 when either vector has zero norm, where cosine similarity
    is undefined.
    """
    if min(a.norm(), b.norm()) == 0:
        return 0.0
    return nn.functional.cosine_similarity(a[None, :], b[None, :]).item()
|
||||
|
||||
|
||||
def winner_take_all(x, k):
    """Return a binary mask with 1.0 at the k largest entries of the last dim."""
    winners = x.topk(k, dim=-1).indices
    mask = torch.zeros_like(x)
    return mask.scatter_(-1, winners, 1.0)
|
||||
|
||||
|
||||
class HebbianMemory:
    """Simple Hebbian memory for multi-hop tests.

    Dense embeddings are pattern-separated into k-sparse binary codes via a
    fixed random projection plus winner-take-all; associations are stored as
    an outer-product Hebbian weight matrix W (target-code × cue-code).
    """

    def __init__(self, input_dim=768, code_dim=16384, k=20):
        self.k = k
        # Fixed random projection, scaled by 1/sqrt(input_dim).
        scale = 1.0 / input_dim**0.5
        self.proj = torch.randn(input_dim, code_dim, device=DEVICE) * scale
        # Hebbian association matrix, accumulated by learn().
        self.W = torch.zeros(code_dim, code_dim, device=DEVICE)

    def sep(self, x):
        """Pattern-separate a dense vector into a k-sparse binary code."""
        return winner_take_all(x @ self.proj, self.k)

    def learn(self, cue, target):
        """Associate cue → target with one Hebbian outer-product update."""
        cue_code = self.sep(cue)
        target_code = self.sep(target)
        self.W += torch.outer(target_code, cue_code)

    def recall_code(self, code, k=None):
        """One associative step in code space: W @ code followed by WTA."""
        top_k = self.k if k is None else k
        return winner_take_all(self.W @ code, top_k)

    def recall(self, cue):
        """Single-hop recall: separate the dense cue, then one step of W."""
        return self.recall_code(self.sep(cue))

    def multi_hop_recall(self, cue, hops=2):
        """Chain through associations: cue → hop1 → hop2 → ..."""
        code = self.sep(cue)
        for _ in range(hops):
            code = self.recall_code(code)
        return code
|
||||
|
||||
|
||||
def test_chain(chain_length, num_chains, dim=768, code_dim=16384, k=20):
    """Test multi-hop recall along chains of length L.

    Create chains: A₁→A₂→...→Aₗ
    Learn pairs: (A₁,A₂), (A₂,A₃), ..., (Aₗ₋₁,Aₗ)
    Test: given A₁, can we reach A₂, A₃, ..., Aₗ in 1, 2, ... hops?

    Returns a dict mapping hop distance → {"mean_cos", "recall_rate"}.
    """
    mem = HebbianMemory(dim, code_dim, k)

    # Random unit-norm dense vectors, chain_length per chain.
    chains = []
    for _ in range(num_chains):
        chain = [nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)
                 for _ in range(chain_length)]
        chains.append(chain)

    # Learn consecutive pairs for EVERY chain.
    # BUG FIX: the original looped `for i in range(chain_length - 1)` over the
    # leaked loop variable `chain`, which after the creation loop refers only
    # to the *last* chain — so with num_chains > 1 all other chains were never
    # stored and the interference test measured recall of unlearned patterns.
    for chain in chains:
        for i in range(chain_length - 1):
            mem.learn(chain[i], chain[i + 1])

    # Test recall at different hop distances
    results = {}
    for hops in range(1, chain_length):
        correct_sims = []
        for chain in chains:
            start = chain[0]
            target = chain[hops]
            target_code = mem.sep(target)

            recalled = mem.multi_hop_recall(start, hops=hops)
            cs = cosine(recalled, target_code)
            correct_sims.append(cs)

        mc = np.mean(correct_sims)
        # Fraction of chains whose recalled code overlaps the target code
        # by more than cosine 0.5 — treated as a successful recall.
        exact = np.mean([s > 0.5 for s in correct_sims])
        results[hops] = {"mean_cos": mc, "recall_rate": exact}
        print(f" chain_len={chain_length}, chains={num_chains}, "
              f"hops={hops}: CosSim={mc:.4f}, recall>{0.5:.0%}={exact:.2%}")

    return results
|
||||
|
||||
|
||||
def test_convergent_chains(dim=768, code_dim=16384, k=20):
    """Test convergent chains: A→C and B→C.

    Both cues are learned against the same target; checks that C is
    recallable from A and from B independently.
    """
    memory = HebbianMemory(dim, code_dim, k)

    def _unit_vec():
        return nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)

    # Convergent pattern: two distinct cues, one shared target.
    a = _unit_vec()
    b = _unit_vec()
    c = _unit_vec()

    memory.learn(a, c)
    memory.learn(b, c)

    target_code = memory.sep(c)

    sim_a = cosine(memory.recall(a), target_code)
    sim_b = cosine(memory.recall(b), target_code)

    print(f" Convergent: A→C sim={sim_a:.4f}, B→C sim={sim_b:.4f}")
    return {"a_to_c": sim_a, "b_to_c": sim_b}
|
||||
|
||||
|
||||
def test_divergent_chains(dim=768, code_dim=16384, k=20):
    """Test divergent chains: A→B and A→C.

    One cue is learned against two targets; measures how the recalled code
    splits between (or interferes across) the two targets.
    """
    memory = HebbianMemory(dim, code_dim, k)

    def _unit_vec():
        return nn.functional.normalize(torch.randn(dim, device=DEVICE), dim=0)

    a = _unit_vec()
    b = _unit_vec()
    c = _unit_vec()

    memory.learn(a, b)
    memory.learn(a, c)

    recalled = memory.recall(a)
    sim_b = cosine(recalled, memory.sep(b))
    sim_c = cosine(recalled, memory.sep(c))

    print(f" Divergent: A→B sim={sim_b:.4f}, A→C sim={sim_c:.4f}")
    return {"a_to_b": sim_b, "a_to_c": sim_c}
|
||||
|
||||
|
||||
def main():
    """Run all multi-hop recall experiments and print their results."""
    banner = "=" * 60
    print(banner)
    print("Experiment 2g: Multi-hop Associative Recall")
    print(banner)

    # Test 1: Simple chains
    print("\n=== Chain recall (single chain) ===")
    for length in (3, 5, 7):
        test_chain(length, num_chains=1)

    # Test 2: Multiple chains (interference between chains)
    print("\n=== Chain recall (multiple chains, interference) ===")
    for n_chains in (1, 5, 10, 50, 100):
        print(f"\n-- {n_chains} chains of length 4 --")
        test_chain(4, num_chains=n_chains)

    # Test 3: Convergent
    print("\n=== Convergent chains (A→C, B→C) ===")
    results = [test_convergent_chains() for _ in range(20)]
    mean_a = np.mean([r["a_to_c"] for r in results])
    mean_b = np.mean([r["b_to_c"] for r in results])
    print(f" Average: A→C={mean_a:.4f}, B→C={mean_b:.4f}")

    # Test 4: Divergent
    print("\n=== Divergent chains (A→B, A→C) ===")
    results = [test_divergent_chains() for _ in range(20)]
    mean_b = np.mean([r["a_to_b"] for r in results])
    mean_c = np.mean([r["a_to_c"] for r in results])
    print(f" Average: A→B={mean_b:.4f}, A→C={mean_c:.4f}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user