Files
noc/mem/test_real_data.py
Fam Zheng 7000ccda0f add nocmem: auto memory recall + ingest via NuoNuo hippocampal network
- nocmem Python service (mem/): FastAPI wrapper around NuoNuo's
  Hopfield-Hebbian memory, with /recall, /ingest, /store, /stats endpoints
- NOC integration: auto recall after user message (injected as system msg),
  async ingest after LLM response (fire-and-forget)
- Recall: cosine pre-filter (threshold 0.35) + Hopfield attention (β=32),
  top_k=3, KV-cache friendly (appended after user msg, not in system prompt)
- Ingest: LLM extraction + paraphrase augmentation, heuristic fallback
- Wired into main.rs, life.rs (agent done), http.rs (api chat)
- Config: optional `nocmem.endpoint` in config.yaml
- Includes benchmarks: LongMemEval (R@5=94.0%), efficiency, noise vs scale
- Design doc: doc/nocmem.md
2026-04-11 12:24:48 +01:00

280 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Test nocmem with real conversation data from NOC's SQLite database.
Extracts conversation turns, ingests them, then tests recall with
realistic queries that a user would actually ask.
"""
import os
import sqlite3
import sys
import time

import requests
# Location of the nocmem HTTP service and the NOC SQLite database to read.
# Both default to the original hard-coded values and can be overridden via
# the environment so the script runs outside the original machine layout.
BASE = os.environ.get("NOCMEM_BASE", "http://127.0.0.1:9820")
DB_PATH = os.environ.get("NOC_DB_PATH", "/data/src/noc/noc.db")
# Global pass/fail counters, updated by test() and reported by main().
PASS = 0
FAIL = 0
def test(name, fn):
    """Run one check, print its outcome and update the PASS/FAIL counters.

    Args:
        name: Label printed next to the result.
        fn: Zero-argument callable; the check passes if it returns without
            raising.

    Returns:
        True if ``fn`` completed without raising, False otherwise. Existing
        callers ignore the return value.

    An ``AssertionError`` is reported as a failed check. Any other exception
    is reported as ``EXCEPTION`` together with its type. Neither propagates,
    so one broken check cannot abort the rest of the run.
    """
    global PASS, FAIL
    try:
        fn()
    except AssertionError as e:
        # A bare `assert x` has an empty message; avoid printing "name: ".
        print(f" [FAIL] {name}: {str(e) or 'assertion failed'}")
        FAIL += 1
        return False
    except Exception as e:
        # Include the type: e.g. KeyError('count') alone would print "'count'".
        print(f" [FAIL] {name}: EXCEPTION {type(e).__name__}: {e}")
        FAIL += 1
        return False
    # Success path kept outside the try so a failing print is not miscounted.
    print(f" [PASS] {name}")
    PASS += 1
    return True
# ── step 1: extract conversation turns from SQLite ──────────────────
def extract_turns(db_path=None):
    """Extract (user_msg, assistant_msg) pairs from the NOC SQLite database.

    Rows of the ``messages`` table are walked in ``id`` order. Each
    qualifying user message is paired with the first ``assistant`` row that
    follows it; rows in between are skipped.

    A user message is ignored if it is shorter than 5 characters or starts
    with an agent-output marker (``"[Agent "``), a user-upload marker
    (``"[用户上传"``) or a voice-message marker (``"[语音消息]"``). A pair is
    kept only if the assistant reply is longer than 10 characters and does
    not contain ``"<pad>"``. NULL ``content`` values are treated as empty
    strings, so such rows are skipped instead of crashing.

    Args:
        db_path: Path to the SQLite database. Defaults to the module-level
            ``DB_PATH`` when omitted.

    Returns:
        A list of ``(user_content, assistant_content)`` string tuples.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        rows = conn.execute(
            "SELECT role, content FROM messages ORDER BY id"
        ).fetchall()
    finally:
        # sqlite3's own context manager does not close; close explicitly so
        # a failing query does not leak the connection.
        conn.close()

    skip_prefixes = ("[Agent ", "[用户上传", "[语音消息]")
    turns = []
    i = 0
    # The last row can never start a pair (it has no following reply).
    while i < len(rows) - 1:
        role, content = rows[i]
        content = content or ""
        # skip non-user messages, agent outputs, uploads/voice, very short messages
        if role != "user" or len(content) < 5 or content.startswith(skip_prefixes):
            i += 1
            continue
        # find the next assistant reply
        j = i + 1
        while j < len(rows) and rows[j][0] != "assistant":
            j += 1
        if j < len(rows):
            assistant_content = rows[j][1] or ""
            if len(assistant_content) > 10 and "<pad>" not in assistant_content:
                turns.append((content, assistant_content))
        # Resume after the consumed reply (or past the end if none was found).
        i = j + 1
    return turns
# ── step 2: ingest all turns ───────────────────────────────────────
def ingest_turns(turns):
    """Ingest conversation turns through the nocmem ``/ingest`` endpoint.

    Args:
        turns: Iterable of ``(user_msg, assistant_msg)`` string pairs, such as
            the output of ``extract_turns``.

    Returns:
        Total number of memories the server reported as stored, i.e. the sum
        of the ``stored`` field (default 0) over all HTTP 200 responses.

    Non-200 responses do not abort the run. They are counted and a warning
    is printed, so a misbehaving server does not silently yield an empty
    memory store.
    """
    total_stored = 0
    failures = 0
    last_status = None
    # One session reuses the connection across the many sequential POSTs.
    with requests.Session() as session:
        for user_msg, assistant_msg in turns:
            r = session.post(f"{BASE}/ingest", json={
                "user_msg": user_msg,
                "assistant_msg": assistant_msg,
            })
            if r.status_code == 200:
                total_stored += r.json().get("stored", 0)
            else:
                failures += 1
                last_status = r.status_code
    if failures:
        print(f" warning: {failures} /ingest request(s) failed (last HTTP {last_status})")
    return total_stored
# ── step 3: also store some key facts directly ─────────────────────
def store_key_facts():
    """Store hand-written facts directly through the ``/store`` endpoint.

    These are facts the heuristic extraction might miss, and the recall
    checks below rely on them. Each fact is a ``cue``/``target`` pair with an
    ``importance`` weight.

    Returns:
        Number of facts the server accepted (HTTP 200). Rejected facts are
        reported with their cue and status code instead of being dropped
        silently.
    """
    facts = [
        {"cue": "bot的名字叫什么", "target": "bot的名字叫小乖是Fam给取的", "importance": 0.9},
        {"cue": "有哪些工具可以用", "target": "工具有: fam_todo(飞书待办), send_file(发文件), spawn_agent/agent_status/kill_agent(子代理管理), run_shell, run_python, update_memory, update_inner_state, gen_voice", "importance": 0.8},
        {"cue": "vLLM在5090上的性能", "target": "RTX 5090上vLLM跑gemma模型只有4.8 tok/s需要切换到awq_marlin量化来提升速度", "importance": 0.8},
        {"cue": "repo-vis项目是什么", "target": "repo-vis是一个用Rust后端+Three.js前端的3D代码库可视化工具目标支持Linux内核级别的大型仓库和Pico VR", "importance": 0.8},
        {"cue": "repo-vis的性能瓶颈", "target": "Linux内核79K文件量级下SQLite 1GB上限和O(n)全量反序列化是瓶颈需要n-ary tree按需合并优化", "importance": 0.9},
        {"cue": "明天的待办事项", "target": "最紧迫的是emblem scanner的AI Chat和KB部分最高优先级然后是曲面二维码识读优化信息收集", "importance": 0.7},
        {"cue": "后端切换到了什么", "target": "NOC后端从原来的方案切换到了vLLM速度变快了", "importance": 0.7},
        {"cue": "home目录下有多少log文件", "target": "home目录及子目录下共有960个.log文件", "importance": 0.5},
    ]
    stored = 0
    for fact in facts:
        r = requests.post(f"{BASE}/store", json=fact)
        if r.status_code == 200:
            stored += 1
        else:
            # Later recall checks depend on these facts; make a rejection visible.
            print(f" warning: /store rejected {fact['cue']!r} (HTTP {r.status_code})")
    return stored
# ── step 4: recall tests with realistic queries ────────────────────
def test_recall_bot_name():
    """Asking the bot's name should recall the stored bot-name fact."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "你叫什么名字"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0, "should recall something"
    assert "小乖" in data["memories"], f"should mention 小乖, got: {data['memories'][:200]}"
def test_recall_tools():
    """Asking which tools are available should recall the stored tool list."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "有什么工具可以用"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    m = data["memories"].lower()
    assert any(k in m for k in ("tool", "工具", "spawn", "fam_todo")), \
        f"should mention tools, got: {data['memories'][:200]}"
def test_recall_vllm():
    """A vLLM performance query should recall the RTX 5090 throughput fact."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "vllm性能怎么样"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    assert any(k in data["memories"] for k in ("4.8", "5090", "tok")), \
        f"should mention vLLM stats, got: {data['memories'][:200]}"
def test_recall_repovis():
    """A repo-vis query should recall the project's technology stack."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "repo-vis项目"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    m = data["memories"]
    assert any(k in m for k in ("Rust", "Three", "3D", "可视化")), \
        f"should mention repo-vis tech, got: {m[:200]}"
def test_recall_performance_bottleneck():
    """A paraphrased "Linux kernel repo is too slow" query should recall the bottleneck fact."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "Linux内核代码仓库跑不动"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    m = data["memories"]
    assert any(k in m for k in ("SQLite", "79K", "瓶颈", "n-ary", "内核")), \
        f"should mention bottleneck, got: {m[:200]}"
def test_recall_todo():
    """A to-do query should recall the stored task list."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "待办事项有哪些"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    m = data["memories"]
    low = m.lower()
    assert "emblem" in low or "todo" in low or "待办" in m or "scanner" in low, \
        f"should mention todos, got: {m[:200]}"
def test_recall_vr():
    """A VR-support query should recall the repo-vis / Pico VR fact."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "VR支持"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    m = data["memories"]
    assert "Pico" in m or "VR" in m or "repo-vis" in m.lower(), \
        f"should mention VR, got: {m[:200]}"
def test_recall_chinese_natural():
    """A vague natural-Chinese query ("what tech topics did we discuss") should recall something."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "之前聊过什么技术话题"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0, "should recall some technical topics"
def test_recall_cross_topic():
    """A query spanning several topics should return at least two memories (top_k=5)."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={
        "text": "项目进度和优化",
        "top_k": 5,
    }, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] >= 2, f"should recall multiple memories, got {data['count']}"
def test_recall_log_files():
    """A "how many log files" query should recall the stored log-file count."""
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={"text": "日志文件有多少"}, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    mem = data["memories"]
    assert "960" in mem or "log" in mem.lower(), \
        f"should mention log files, got: {mem[:200]}"
# ── step 5: multi-hop chain test ──────────────────────────────────
def test_multihop_chain():
    """Smoke-test multi-hop (Hebbian chaining) recall starting from "repo-vis".

    The hoped-for chain is repo-vis -> performance bottleneck -> n-ary tree
    optimization. This check only asserts that something is recalled with
    ``hops=3``; the first 300 characters are printed for manual inspection.
    """
    # Timeout: a hung server should fail this check, not stall the whole run.
    r = requests.post(f"{BASE}/recall", json={
        "text": "repo-vis",
        "top_k": 3,
        "hops": 3,
    }, timeout=10)
    r.raise_for_status()
    data = r.json()
    assert data["count"] > 0
    # print chain for inspection
    print(f" chain: {data['memories'][:300]}")
# ── step 6: latency with real data ─────────────────────────────────
def test_latency_with_data():
    """Average recall latency over five short queries must stay below 50 ms.

    Uses the ``latency_ms`` value reported in each ``/recall`` response, not
    client-side wall time. Presumably this is measured server-side; confirm
    against the server implementation.
    """
    times = []
    for q in ["工具", "vllm", "项目", "待办", "性能"]:
        # Timeout: a hung server should fail this check, not stall the run.
        r = requests.post(f"{BASE}/recall", json={"text": q}, timeout=10)
        r.raise_for_status()
        times.append(r.json()["latency_ms"])
    avg = sum(times) / len(times)
    print(f" avg latency: {avg:.1f}ms (max: {max(times):.1f}ms)")
    assert avg < 50, f"average latency {avg:.1f}ms too high"
# ── main ────────────────────────────────────────────────────────────
def main():
    """Run the full real-data scenario against a live nocmem server.

    The steps run in this order:
    1. Check that ``/stats`` answers.
    2. Extract conversation turns from the NOC database.
    3. Ingest the turns.
    4. Store the hand-written key facts.
    5. Print memory statistics.
    6. Run the recall, multi-hop and latency checks through ``test()``.

    Exits with status 1 if the server is unreachable or any check failed.
    """
    # PASS/FAIL are only read here (test() updates them), so no `global`.
    print("nocmem real-data test")
    print(f"server: {BASE}")
    print(f"database: {DB_PATH}\n")

    # check server; report why it is unreachable instead of hiding the cause
    try:
        requests.get(f"{BASE}/stats", timeout=3).raise_for_status()
    except requests.RequestException as e:
        print(f"ERROR: server not reachable: {e}")
        sys.exit(1)

    # extract
    print("── extract ──")
    turns = extract_turns()
    print(f" extracted {len(turns)} conversation turns")

    # ingest
    print("\n── ingest (heuristic, no LLM) ──")
    t0 = time.monotonic()
    ingested = ingest_turns(turns)
    elapsed = time.monotonic() - t0
    print(f" ingested {ingested} memories from {len(turns)} turns ({elapsed:.1f}s)")

    # store key facts
    print("\n── store key facts ──")
    stored = store_key_facts()
    print(f" stored {stored} key facts")

    # stats
    r = requests.get(f"{BASE}/stats", timeout=3)
    r.raise_for_status()
    stats = r.json()
    print("\n── memory stats ──")
    print(f" memories: {stats['num_memories']}")
    print(f" cue entries: {stats['num_cue_entries']} (aug ratio: {stats['augmentation_ratio']:.1f}x)")
    print(f" W norm: {stats['w_norm']:.1f}")

    # recall tests, run in a fixed order with their display labels
    recall_checks = [
        ("bot的名字", test_recall_bot_name),
        ("可用工具", test_recall_tools),
        ("vLLM性能", test_recall_vllm),
        ("repo-vis项目", test_recall_repovis),
        ("性能瓶颈", test_recall_performance_bottleneck),
        ("待办事项", test_recall_todo),
        ("VR支持", test_recall_vr),
        ("log文件数量", test_recall_log_files),
        ("自然中文查询", test_recall_chinese_natural),
        ("跨主题召回", test_recall_cross_topic),
    ]
    print("\n── recall accuracy (natural language queries) ──")
    for label, check in recall_checks:
        test(label, check)

    print("\n── multi-hop chain ──")
    test("repo-vis联想链", test_multihop_chain)

    print("\n── latency ──")
    test("平均延迟 < 50ms", test_latency_with_data)

    print(f"\n{'=' * 50}")
    total = PASS + FAIL
    print(f"PASS: {PASS}/{total} FAIL: {FAIL}/{total}")
    if FAIL:
        sys.exit(1)
    print("All tests passed!")


if __name__ == "__main__":
    main()