- nocmem Python service (mem/): FastAPI wrapper around NuoNuo's Hopfield-Hebbian memory, with /recall, /ingest, /store, /stats endpoints - NOC integration: auto recall after user message (injected as system msg), async ingest after LLM response (fire-and-forget) - Recall: cosine pre-filter (threshold 0.35) + Hopfield attention (β=32), top_k=3, KV-cache friendly (appended after user msg, not in system prompt) - Ingest: LLM extraction + paraphrase augmentation, heuristic fallback - Wired into main.rs, life.rs (agent done), http.rs (api chat) - Config: optional `nocmem.endpoint` in config.yaml - Includes benchmarks: LongMemEval (R@5=94.0%), efficiency, noise vs scale - Design doc: doc/nocmem.md
280 lines · 10 KiB · Python
"""Test nocmem with real conversation data from NOC's SQLite database.
|
||
|
||
Extracts conversation turns, ingests them, then tests recall with
|
||
realistic queries that a user would actually ask.
|
||
"""
|
||
|
||
import sys
|
||
import time
|
||
import sqlite3
|
||
import requests
|
||
|
||
BASE = "http://127.0.0.1:9820"
|
||
DB_PATH = "/data/src/noc/noc.db"
|
||
|
||
PASS = 0
|
||
FAIL = 0
|
||
|
||
|
||
def test(name, fn):
    """Run one test callable, print a ✓/✗ line, and bump the counters.

    AssertionError is reported as a normal failure; any other exception
    is reported as an EXCEPTION failure.
    """
    global PASS, FAIL
    try:
        fn()
    except AssertionError as err:
        print(f" ✗ {name}: {err}")
        FAIL += 1
    except Exception as err:
        print(f" ✗ {name}: EXCEPTION {err}")
        FAIL += 1
    else:
        print(f" ✓ {name}")
        PASS += 1
|
||
|
||
|
||
# ── step 1: extract conversation turns from SQLite ──────────────────
|
||
|
||
def extract_turns(db_path=None):
    """Extract (user_msg, assistant_msg) pairs from the NOC message log.

    Walks the ``messages`` table in id order, pairing each qualifying user
    message with the next assistant reply. Skips agent outputs, upload /
    voice placeholders, very short messages, NULL content, and padded or
    degenerate assistant replies.

    Args:
        db_path: path to the SQLite database; defaults to DB_PATH so
            existing no-argument callers are unaffected.

    Returns:
        List of (user_msg, assistant_msg) tuples in conversation order.
    """
    if db_path is None:
        db_path = DB_PATH
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT role, content FROM messages ORDER BY id"
        ).fetchall()
    finally:
        # ensure the connection is released even if the query fails
        conn.close()

    turns = []
    i = 0
    while i < len(rows) - 1:
        role, content = rows[i]
        # skip non-user messages, NULL/very short messages, agent outputs,
        # and upload / voice placeholders
        if (
            role != "user"
            or not content
            or len(content) < 5
            or content.startswith("[Agent ")
            or content.startswith("[用户上传")
            or content.startswith("[语音消息]")
        ):
            i += 1
            continue
        # find the next assistant reply
        j = i + 1
        while j < len(rows) and rows[j][0] != "assistant":
            j += 1
        if j < len(rows):
            assistant_content = rows[j][1]
            # drop NULL, too-short, or padded/degenerate replies
            if (
                assistant_content
                and len(assistant_content) > 10
                and "<pad>" not in assistant_content
            ):
                turns.append((content, assistant_content))
        i = j + 1

    return turns
|
||
|
||
|
||
# ── step 2: ingest all turns ───────────────────────────────────────
|
||
|
||
def ingest_turns(turns):
    """Push each (user, assistant) turn to /ingest; return total memories stored."""
    total_stored = 0
    for user_msg, assistant_msg in turns:
        payload = {
            "user_msg": user_msg,
            "assistant_msg": assistant_msg,
        }
        resp = requests.post(f"{BASE}/ingest", json=payload)
        if resp.status_code == 200:
            total_stored += resp.json().get("stored", 0)
    return total_stored
|
||
|
||
|
||
# ── step 3: also store some key facts directly ─────────────────────
|
||
|
||
def store_key_facts():
    """Seed critical facts via /store that heuristic extraction might miss.

    Returns the number of facts the server accepted (HTTP 200).
    """
    facts = [
        {"cue": "bot的名字叫什么", "target": "bot的名字叫小乖,是Fam给取的", "importance": 0.9},
        {"cue": "有哪些工具可以用", "target": "工具有: fam_todo(飞书待办), send_file(发文件), spawn_agent/agent_status/kill_agent(子代理管理), run_shell, run_python, update_memory, update_inner_state, gen_voice", "importance": 0.8},
        {"cue": "vLLM在5090上的性能", "target": "RTX 5090上vLLM跑gemma模型只有4.8 tok/s,需要切换到awq_marlin量化来提升速度", "importance": 0.8},
        {"cue": "repo-vis项目是什么", "target": "repo-vis是一个用Rust后端+Three.js前端的3D代码库可视化工具,目标支持Linux内核级别的大型仓库和Pico VR", "importance": 0.8},
        {"cue": "repo-vis的性能瓶颈", "target": "Linux内核79K文件量级下,SQLite 1GB上限和O(n)全量反序列化是瓶颈,需要n-ary tree按需合并优化", "importance": 0.9},
        {"cue": "明天的待办事项", "target": "最紧迫的是emblem scanner的AI Chat和KB部分(最高优先级),然后是曲面二维码识读优化信息收集", "importance": 0.7},
        {"cue": "后端切换到了什么", "target": "NOC后端从原来的方案切换到了vLLM,速度变快了", "importance": 0.7},
        {"cue": "home目录下有多少log文件", "target": "home目录及子目录下共有960个.log文件", "importance": 0.5},
    ]
    accepted = 0
    for fact in facts:
        if requests.post(f"{BASE}/store", json=fact).status_code == 200:
            accepted += 1
    return accepted
|
||
|
||
|
||
# ── step 4: recall tests with realistic queries ────────────────────
|
||
|
||
def test_recall_bot_name():
    """A name query should surface the bot's name (小乖)."""
    data = requests.post(f"{BASE}/recall", json={"text": "你叫什么名字"}).json()
    assert data["count"] > 0, "should recall something"
    assert "小乖" in data["memories"], f"should mention 小乖, got: {data['memories'][:200]}"
|
||
|
||
def test_recall_tools():
    """A tools query should surface at least one known tool keyword."""
    data = requests.post(f"{BASE}/recall", json={"text": "有什么工具可以用"}).json()
    assert data["count"] > 0
    lowered = data["memories"].lower()
    found = any(k in lowered for k in ("tool", "工具", "spawn", "fam_todo"))
    assert found, f"should mention tools, got: {data['memories'][:200]}"
|
||
|
||
def test_recall_vllm():
    """A vLLM query should surface the 5090 performance numbers."""
    data = requests.post(f"{BASE}/recall", json={"text": "vllm性能怎么样"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    assert any(k in mem for k in ("4.8", "5090", "tok")), \
        f"should mention vLLM stats, got: {data['memories'][:200]}"
|
||
|
||
def test_recall_repovis():
    """A repo-vis query should surface its tech stack."""
    data = requests.post(f"{BASE}/recall", json={"text": "repo-vis项目"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    assert any(k in mem for k in ("Rust", "Three", "3D", "可视化")), \
        f"should mention repo-vis tech, got: {mem[:200]}"
|
||
|
||
def test_recall_performance_bottleneck():
    """A 'kernel repo too slow' query should surface the known bottleneck facts."""
    data = requests.post(f"{BASE}/recall", json={"text": "Linux内核代码仓库跑不动"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    assert any(k in mem for k in ("SQLite", "79K", "瓶颈", "n-ary", "内核")), \
        f"should mention bottleneck, got: {mem[:200]}"
|
||
|
||
def test_recall_todo():
    """A todo query should surface the stored priority items."""
    data = requests.post(f"{BASE}/recall", json={"text": "待办事项有哪些"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    lowered = mem.lower()
    assert "emblem" in lowered or "todo" in lowered or "待办" in mem or "scanner" in lowered, \
        f"should mention todos, got: {mem[:200]}"
|
||
|
||
def test_recall_vr():
    """A VR query should surface the Pico / repo-vis VR memory."""
    data = requests.post(f"{BASE}/recall", json={"text": "VR支持"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    assert "Pico" in mem or "VR" in mem or "repo-vis" in mem.lower(), \
        f"should mention VR, got: {mem[:200]}"
|
||
|
||
def test_recall_chinese_natural():
    """Test with natural Chinese conversational query."""
    resp = requests.post(f"{BASE}/recall", json={"text": "之前聊过什么技术话题"})
    assert resp.json()["count"] > 0, "should recall some technical topics"
|
||
|
||
def test_recall_cross_topic():
    """Query that spans multiple memories — should return diverse results."""
    payload = {
        "text": "项目进度和优化",
        "top_k": 5,
    }
    data = requests.post(f"{BASE}/recall", json=payload).json()
    assert data["count"] >= 2, f"should recall multiple memories, got {data['count']}"
|
||
|
||
def test_recall_log_files():
    """A log-file-count query should surface the 960-file fact."""
    data = requests.post(f"{BASE}/recall", json={"text": "日志文件有多少"}).json()
    assert data["count"] > 0
    mem = data["memories"]
    assert "960" in mem or "log" in mem.lower(), \
        f"should mention log files, got: {mem[:200]}"
|
||
|
||
|
||
# ── step 5: multi-hop chain test ──────────────────────────────────
|
||
|
||
def test_multihop_chain():
    """Test if Hebbian chaining connects related memories.

    repo-vis → performance bottleneck → n-ary tree optimization
    """
    payload = {
        "text": "repo-vis",
        "top_k": 3,
        "hops": 3,
    }
    data = requests.post(f"{BASE}/recall", json=payload).json()
    assert data["count"] > 0
    # print chain for inspection
    print(f" chain: {data['memories'][:300]}")
|
||
|
||
|
||
# ── step 6: latency with real data ─────────────────────────────────
|
||
|
||
def test_latency_with_data():
    """Recall latency after loading real data: average must stay under 50 ms."""
    queries = ["工具", "vllm", "项目", "待办", "性能"]
    times = [
        requests.post(f"{BASE}/recall", json={"text": q}).json()["latency_ms"]
        for q in queries
    ]
    avg = sum(times) / len(times)
    print(f" avg latency: {avg:.1f}ms (max: {max(times):.1f}ms)")
    assert avg < 50, f"average latency {avg:.1f}ms too high"
|
||
|
||
|
||
# ── main ────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Load real conversation data into nocmem, then run the recall suites.

    Exits 1 if the server is unreachable or any test fails.
    """
    global PASS, FAIL

    print("nocmem real-data test")
    print(f"server: {BASE}")
    print(f"database: {DB_PATH}\n")

    # check server
    try:
        requests.get(f"{BASE}/stats", timeout=3).raise_for_status()
    except Exception:
        print("ERROR: server not reachable")
        sys.exit(1)

    # extract
    print("── extract ──")
    turns = extract_turns()
    print(f" extracted {len(turns)} conversation turns")

    # ingest
    print("\n── ingest (heuristic, no LLM) ──")
    t0 = time.monotonic()
    ingested = ingest_turns(turns)
    elapsed = time.monotonic() - t0
    print(f" ingested {ingested} memories from {len(turns)} turns ({elapsed:.1f}s)")

    # store key facts
    print("\n── store key facts ──")
    stored = store_key_facts()
    print(f" stored {stored} key facts")

    # stats
    stats = requests.get(f"{BASE}/stats").json()
    print("\n── memory stats ──")
    print(f" memories: {stats['num_memories']}")
    print(f" cue entries: {stats['num_cue_entries']} (aug ratio: {stats['augmentation_ratio']:.1f}x)")
    print(f" W norm: {stats['w_norm']:.1f}")

    # recall tests, driven by a (label, callable) table
    print("\n── recall accuracy (natural language queries) ──")
    recall_cases = [
        ("bot的名字", test_recall_bot_name),
        ("可用工具", test_recall_tools),
        ("vLLM性能", test_recall_vllm),
        ("repo-vis项目", test_recall_repovis),
        ("性能瓶颈", test_recall_performance_bottleneck),
        ("待办事项", test_recall_todo),
        ("VR支持", test_recall_vr),
        ("log文件数量", test_recall_log_files),
        ("自然中文查询", test_recall_chinese_natural),
        ("跨主题召回", test_recall_cross_topic),
    ]
    for label, fn in recall_cases:
        test(label, fn)

    print("\n── multi-hop chain ──")
    test("repo-vis联想链", test_multihop_chain)

    print("\n── latency ──")
    test("平均延迟 < 50ms", test_latency_with_data)

    print(f"\n{'='*50}")
    total = PASS + FAIL
    print(f"PASS: {PASS}/{total} FAIL: {FAIL}/{total}")
    if FAIL:
        sys.exit(1)
    print("All tests passed!")


if __name__ == "__main__":
    main()
|