"""LongMemEval benchmark with Gemma 4 post-retrieval reasoning. Previous result: 36% with retrieval-only. This version: retrieve top-K memories → Gemma 4 reads them and answers the question. Improvements: 1. No truncation of assistant responses (store full text, chunked) 2. Store user statements as memories (preference extraction) 3. Include timestamps in stored memories 4. Post-retrieval: Gemma 4 synthesizes answer from recalled memories """ import sys import json import time import re import requests from pathlib import Path from collections import Counter import torch import torch.nn as nn import numpy as np sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from nuonuo.hippocampus import HippocampalMemory DEVICE = "cuda" OLLAMA_URL = "http://localhost:11434/api/chat" def gemma_chat(messages, max_tokens=256, temperature=0): """Call Gemma 4 via Ollama.""" try: resp = requests.post(OLLAMA_URL, json={ "model": "gemma4:31b", "messages": messages, "stream": False, "think": False, "options": {"num_predict": max_tokens, "temperature": temperature}, }, timeout=60) return resp.json()["message"]["content"] except Exception as e: return None def load_model(): from sentence_transformers import SentenceTransformer return SentenceTransformer("all-MiniLM-L6-v2", device=DEVICE) def emb_batch(model, texts): if not texts: return [] embs = model.encode(texts, convert_to_tensor=True, normalize_embeddings=True, device=DEVICE, batch_size=64) return [embs[i] for i in range(embs.shape[0])] def extract_memories_v2(session, session_date=""): """Improved memory extraction. Changes from v1: - Don't truncate assistant responses; chunk them instead - Extract user preferences ("I like/prefer/use/enjoy X") - Store timestamps """ memories = [] # list of (cue_text, target_text, extra_metadata) for i, turn in enumerate(session): content = turn["content"].strip() role = turn["role"] if role == "user": # Find next assistant response assistant_text = "" for j in range(i + 1, len(session)): if session[j]["role"] == "assistant": assistant_text = session[j]["content"].strip() break if len(content) > 10 and len(assistant_text) > 10: # Chunk long assistant responses (500 char chunks) if len(assistant_text) > 500: chunks = [] for start in range(0, len(assistant_text), 400): chunk = assistant_text[start:start + 500] if len(chunk) > 50: chunks.append(chunk) for ci, chunk in enumerate(chunks[:5]): # max 5 chunks memories.append(( content, chunk, {"date": session_date, "chunk": ci} )) else: memories.append((content, assistant_text, {"date": session_date})) # User's own statements as memories if len(content) > 20: memories.append((content, content, {"date": session_date, "type": "user_statement"})) # Preference extraction pref_patterns = [ r"I (?:like|love|prefer|enjoy|use|usually|always|often)\s+(.{5,80})", r"my (?:favorite|preferred)\s+(.{5,80})", r"I'm (?:a fan of|into|interested in)\s+(.{5,80})", ] for pat in pref_patterns: match = re.search(pat, content, re.IGNORECASE) if match: pref_text = match.group(0) memories.append(( f"user preference: {pref_text}", content, {"date": session_date, "type": "preference"} )) return memories def check_answer(answer_text, expected_answer): """Check if the generated answer contains the expected answer.""" if not answer_text: return False expected = str(expected_answer).lower().strip() generated = answer_text.lower().strip() # Handle unanswerable if "did not mention" in expected or "not mention" in expected: neg_phrases = ["don't have", "no information", "not mentioned", "cannot find", "don't recall", "no record", "not available"] return any(p in generated for p in neg_phrases) # Direct substring match if expected in generated: return True # Key words match (60% of answer words present) answer_words = [w for w in expected.split() if len(w) > 3] if answer_words: matches = sum(1 for w in answer_words if w in generated) if matches >= max(1, len(answer_words) * 0.5): return True # Number matching (for temporal questions) expected_nums = re.findall(r'\d+', expected) generated_nums = re.findall(r'\d+', generated) if expected_nums and any(n in generated_nums for n in expected_nums): return True return False def run_benchmark(model, oracle, max_questions=None): """Full benchmark with Gemma 4 reasoning.""" if max_questions: oracle = oracle[:max_questions] results_by_type = Counter() total_by_type = Counter() retrieval_hits = Counter() gemma_calls = 0 gemma_errors = 0 total_time = 0 for qi, entry in enumerate(oracle): qtype = entry["question_type"] question = entry["question"] answer = entry["answer"] sessions = entry["haystack_sessions"] dates = entry.get("haystack_dates", [""] * len(sessions)) total_by_type[qtype] += 1 # Build memory mem = HippocampalMemory(embed_dim=384) all_texts = [] # (cue, target, meta) for si, session in enumerate(sessions): date = dates[si] if si < len(dates) else "" pairs = extract_memories_v2(session, session_date=date) all_texts.extend(pairs) if not all_texts: continue cue_texts = [t[0] for t in all_texts] target_texts = [t[1] for t in all_texts] metas = [t[2] for t in all_texts] cue_embs = emb_batch(model, cue_texts) target_embs = emb_batch(model, target_texts) for i in range(len(cue_texts)): mem.store(cue_embs[i], target_embs[i], metadata={"cue": cue_texts[i][:200], "target": target_texts[i][:500], **metas[i]}) # Recall t0 = time.time() q_emb = model.encode([question], convert_to_tensor=True, normalize_embeddings=True, device=DEVICE)[0] results = mem.recall(q_emb, top_k=10) # Check if retrieval alone has the answer retrieved_texts = [] for r in results: retrieved_texts.append(r.metadata.get("target", "")) retrieved_texts.append(r.metadata.get("cue", "")) retrieval_hit = check_answer("\n".join(retrieved_texts), answer) if retrieval_hit: retrieval_hits[qtype] += 1 # Build context for Gemma context_parts = [] for r in results[:7]: # top 7 memories date = r.metadata.get("date", "") target = r.metadata.get("target", "") if date: context_parts.append(f"[{date}] {target}") else: context_parts.append(target) context = "\n".join(context_parts) # Ask Gemma to answer based on recalled memories prompt = f"""You are answering a question based on recalled memories from past conversations with the user. The memories may contain dates and timestamps — use them for temporal reasoning. Memories: {context} Question: {question} Instructions: - Answer based ONLY on the memories above - For "which came first" questions, compare the dates in the memories - For "how many days" questions, calculate from the dates mentioned - For counting questions, count all relevant items across memories - Extract specific facts (names, numbers, places) directly from the memories - Be concise (1-2 sentences) - If genuinely no relevant information exists, say "Not mentioned in our conversations" Answer:""" gemma_answer = gemma_chat([{"role": "user", "content": prompt}], max_tokens=100) gemma_calls += 1 if gemma_answer is None: gemma_errors += 1 gemma_answer = "\n".join(retrieved_texts[:3]) # fallback total_time += time.time() - t0 hit = check_answer(gemma_answer, answer) if hit: results_by_type[qtype] += 1 if qi < 3 or (not hit and qi < 20): status = "✓" if hit else "✗" ret_status = "R✓" if retrieval_hit else "R✗" print(f" {status} {ret_status} [{qtype[:12]:>12}] Q: {question[:55]}...") print(f" Expected: {str(answer)[:60]}...") if gemma_answer: print(f" Gemma: {gemma_answer[:80]}...") del mem if (qi + 1) % 25 == 0: print(f" ... {qi+1}/{len(oracle)} ({total_time:.0f}s, " f"{gemma_errors} Gemma errors)") return results_by_type, total_by_type, retrieval_hits, total_time def main(): print("=" * 60) print("LongMemEval + Gemma 4 Post-Retrieval Reasoning") print("=" * 60) # Verify Gemma test = gemma_chat([{"role": "user", "content": "Say OK"}], max_tokens=10) if test: print(f"Gemma 4 connected: '{test.strip()}'") else: print("ERROR: Gemma 4 not available!") return model = load_model() with open("data/longmemeval_oracle.json") as f: oracle = json.load(f) # Full benchmark directly if True: print(f"\n=== Full Benchmark (500 questions) ===\n") results, totals, ret_hits, dt = run_benchmark(model, oracle) print(f"\n{'='*60}") print("FINAL RESULTS (Retrieval + Gemma 4 Reasoning)") print(f"{'='*60}") overall = sum(results.values()) total = sum(totals.values()) ret_overall = sum(ret_hits.values()) print(f"Overall: {overall}/{total} ({overall/total:.0%}) " f"[retrieval-only baseline: {ret_overall}/{total} ({ret_overall/total:.0%})]") print() for qtype in sorted(totals.keys()): c = results.get(qtype, 0) rc = ret_hits.get(qtype, 0) t = totals[qtype] bar = "█" * int(c/t * 20) + "░" * (20 - int(c/t * 20)) print(f" {qtype:<25}: {c:>3}/{t:<3} ({c/t:>5.1%}) {bar} [ret: {rc/t:.0%}]") print(f"\nTime: {dt:.0f}s ({dt/total:.1f}s/question)") if __name__ == "__main__": main()