Hopfield + Hebbian hybrid memory system for LLMs. Two nights of experiments (16 iterations), validated on LongMemEval (ICLR 2025). Architecture: - Single-hop: Two-Stage Hopfield (NN top-20 → softmax settle) - Multi-hop: Hebbian W matrix with WTA pattern separation - 64% on LongMemEval (500 questions), retrieval-only, no LLM dependency - 4ms latency @ 20K memories, ~1GB VRAM Key findings: - Hopfield attention solved noise tolerance (20% → 100% vs flat Hebbian) - WTA pattern separation enables 20K+ capacity - Multi-hop associative chains (6 hops, CosSim=1.0) — RAG can't do this - MiniLM-L6 is optimal (discrimination gap > absolute similarity) - Paraphrase cue augmentation: 55% → 100% on synthetic, 36% → 64% on benchmark - SNN encoder viable (CosSim 0.99) but not needed for current architecture
204 lines · 6.8 KiB · Python
"""LLM integration for hippocampal memory.
|
|
|
|
Functions:
|
|
1. extract_memories: Extract (cue, target) pairs from conversation turns
|
|
2. generate_paraphrases: Generate cue variants for augmentation
|
|
3. recall_and_inject: Recall memories and format for context injection
|
|
4. format_recalled_memories: Format RecallResults into prompt text
|
|
|
|
Supports any OpenAI-compatible API. Falls back to simple heuristics when LLM unavailable.
|
|
"""
|
|
|
|
import re
|
|
from typing import Optional
|
|
from dataclasses import dataclass
|
|
|
|
from openai import OpenAI
|
|
|
|
|
|
@dataclass
class ExtractedMemory:
    """One memorable fact pulled out of a conversation turn."""

    cue: str  # trigger phrase that should recall this memory later
    target: str  # the fact itself
    importance: float = 0.5  # 0-1, higher = more worth storing


class LLMClient:
    """Thin wrapper over an OpenAI-compatible chat endpoint.

    The endpoint is probed once at construction time. If the probe fails
    for any reason, ``available`` stays False and every ``chat`` call
    returns None so callers can fall back to heuristics.
    """

    def __init__(self, base_url: str = "https://ste-jarvis.tiktok-row.net/llm/v1",
                 api_key: str = "unused",
                 model: str = "gemma4:12b",
                 timeout: float = 5.0):
        self.model = model
        self.available = False
        try:
            candidate = OpenAI(base_url=base_url, api_key=api_key, timeout=timeout)
            candidate.models.list()  # cheap liveness probe
        except Exception:
            # Endpoint unreachable (or SDK missing) — degrade gracefully.
            self.client = None
        else:
            self.client = candidate
            self.available = True

    def chat(self, messages: list[dict], temperature: float = 0.7,
             max_tokens: int = 512) -> Optional[str]:
        """Run one chat completion; return the text, or None on any failure."""
        if not self.available:
            return None
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Indexing stays inside the try: an empty/odd response also
            # degrades to None rather than raising at the caller.
            return response.choices[0].message.content
        except Exception:
            return None


def extract_memories_llm(client: LLMClient, user_msg: str,
                         assistant_msg: str) -> list[ExtractedMemory]:
    """Use LLM to extract memorable facts from a conversation turn.

    Args:
        client: LLM wrapper; when its ``chat`` returns None we fall back
            to :func:`extract_memories_heuristic`.
        user_msg: the user's message for this turn.
        assistant_msg: the assistant's reply for this turn.

    Returns:
        Parsed ``ExtractedMemory`` items (possibly empty).
    """
    prompt = f"""From this conversation turn, extract key facts worth remembering for future conversations.
For each fact, provide a "cue" (what would trigger recalling this) and a "target" (the fact itself).
Rate importance 0-1 (1 = critical fact, 0 = trivial).

User: {user_msg}
Assistant: {assistant_msg}

Output format (one per line):
CUE: <trigger phrase> | TARGET: <fact> | IMPORTANCE: <0-1>

Only extract genuinely useful facts. If nothing worth remembering, output NONE."""

    result = client.chat([{"role": "user", "content": prompt}], temperature=0.3)
    if not result:
        return extract_memories_heuristic(user_msg, assistant_msg)

    memories = []
    for line in result.strip().split("\n"):
        if line.strip() == "NONE":
            break
        # search (not match): tolerates leading bullets/whitespace models often emit
        match = re.search(r"CUE:\s*(.+?)\s*\|\s*TARGET:\s*(.+?)\s*\|\s*IMPORTANCE:\s*([\d.]+)", line)
        if not match:
            continue
        try:
            importance = float(match.group(3))
        except ValueError:
            # [\d.]+ can match malformed numbers like "0.5." — use neutral default
            importance = 0.5
        memories.append(ExtractedMemory(
            cue=match.group(1).strip(),
            target=match.group(2).strip(),
            # clamp: models occasionally rate outside the requested [0, 1] range
            importance=min(1.0, max(0.0, importance)),
        ))
    return memories


def extract_memories_heuristic(user_msg: str, assistant_msg: str) -> list[ExtractedMemory]:
    """Fallback: simple heuristic extraction when LLM unavailable.

    Rules:
    - User question with a substantive answer (> 5 words) → store the answer
    - Technical keyword anywhere in the turn, and a user message of >= 5
      words → store user message as cue, answer as target
    """
    memories = []

    # User asked a question, assistant gave a non-trivial answer
    if "?" in user_msg and len(assistant_msg.split()) > 5:
        memories.append(ExtractedMemory(
            cue=user_msg.rstrip("?").strip(),
            target=assistant_msg[:200],
            importance=0.6,
        ))

    # Technical keywords suggest something worth remembering.
    # Keywords must be lowercase because `combined` is lowercased below
    # (bug fix: the previous "API" entry could never match).
    tech_keywords = ["deploy", "config", "bug", "fix", "error", "database",
                     "server", "api", "port", "token", "password", "version",
                     "install", "upgrade", "migrate", "backup"]
    combined = (user_msg + " " + assistant_msg).lower()
    if any(kw in combined for kw in tech_keywords):
        if len(user_msg.split()) >= 5:
            memories.append(ExtractedMemory(
                cue=user_msg[:100],
                target=assistant_msg[:200],
                importance=0.5,
            ))

    return memories


def generate_paraphrases_llm(client: LLMClient, text: str,
                             n: int = 3) -> list[str]:
    """Ask the LLM for up to ``n`` paraphrases of ``text``.

    Falls back to ``generate_paraphrases_heuristic`` when the LLM is
    unavailable or returns nothing.
    """
    prompt = f"""Generate {n} different paraphrases of this text. Each should convey the same meaning but use different words/phrasing. One per line, no numbering.

Text: {text}"""

    reply = client.chat([{"role": "user", "content": prompt}],
                        temperature=0.8, max_tokens=256)
    if not reply:
        return generate_paraphrases_heuristic(text, n)

    # Keep stripped lines longer than 3 chars (drops blanks and fragments),
    # capped at n.
    candidates = (line.strip() for line in reply.strip().split("\n"))
    kept = [c for c in candidates if len(c) > 3]
    return kept[:n]


def generate_paraphrases_heuristic(text: str, n: int = 3) -> list[str]:
    """Fallback: simple text augmentation when LLM unavailable.

    Strategies:
    - Remove common question/request prefixes
    - Swap known synonyms (both lowercase and Capitalized forms)
    - Reframe short texts as an "issue with <text>" cue

    Returns at most ``n`` variants, deduplicated, each different from ``text``.
    """
    variants = []
    text_lower = text.lower().strip()

    # Remove common prefixes
    prefixes = ["can you ", "please ", "i need to ", "let's ", "we should ",
                "how do i ", "how to ", "i want to ", "help me "]
    for pfx in prefixes:
        if text_lower.startswith(pfx):
            stripped = text[len(pfx):].strip()
            if stripped and stripped not in variants:
                variants.append(stripped)

    # Simple synonym swaps
    swaps = {
        "slow": "performance issues", "fast": "quick", "fix": "resolve",
        "deploy": "release", "error": "issue", "bug": "problem",
        "database": "DB", "server": "machine", "configure": "set up",
    }
    for old, new in swaps.items():
        if old in text_lower:
            variant = text.replace(old, new).replace(old.capitalize(), new.capitalize())
            if variant != text and variant not in variants:
                variants.append(variant)

    # Reframe short texts as a problem-report cue (fix: previous comment
    # claimed a "the X is Y" pattern; also skip empty input and avoid
    # duplicating a variant already collected).
    if text_lower and len(text.split()) <= 8:
        reframed = f"issue with {text_lower}"
        if reframed not in variants:
            variants.append(reframed)

    return variants[:n]


def format_recalled_memories(results: list, max_memories: int = 5) -> str:
    """Format RecallResults into a prompt-ready string.

    Each result must expose a ``metadata`` dict and an integer
    ``hop_distance``. Results whose metadata carries neither a "target"
    nor a "text" key are skipped.

    Args:
        results: recall results, best-first; only the first
            ``max_memories`` are rendered.
        max_memories: cap on rendered entries.

    Returns:
        A "Recalled from memory:" bullet list, or "" when there is
        nothing formattable.
    """
    if not results:
        return ""

    lines = []
    for result in results[:max_memories]:  # unused enumerate index dropped
        meta = result.metadata
        if "target" in meta:
            text = meta["target"]
        elif "text" in meta:
            text = meta["text"]
        else:
            continue  # nothing displayable for this result

        # Flag multi-hop associative recalls so the reader knows the provenance
        hop_info = f" (via {result.hop_distance}-hop association)" if result.hop_distance > 1 else ""
        lines.append(f"- {text}{hop_info}")

    if not lines:
        return ""

    return "Recalled from memory:\n" + "\n".join(lines)
|