- nocmem Python service (mem/): FastAPI wrapper around NuoNuo's Hopfield-Hebbian memory, with /recall, /ingest, /store, /stats endpoints - NOC integration: auto recall after user message (injected as system msg), async ingest after LLM response (fire-and-forget) - Recall: cosine pre-filter (threshold 0.35) + Hopfield attention (β=32), top_k=3, KV-cache friendly (appended after user msg, not in system prompt) - Ingest: LLM extraction + paraphrase augmentation, heuristic fallback - Wired into main.rs, life.rs (agent done), http.rs (api chat) - Config: optional `nocmem.endpoint` in config.yaml - Includes benchmarks: LongMemEval (R@5=94.0%), efficiency, noise vs scale - Design doc: doc/nocmem.md
391 lines
14 KiB
Python
391 lines
14 KiB
Python
"""nocmem API integration tests.

Run with: uv run python test_api.py
Requires nocmem server running on localhost:9820.
"""
|
||
|
||
import sys
|
||
import time
|
||
import requests
|
||
|
||
# Base URL of the nocmem server that every test talks to.
BASE = "http://127.0.0.1:9820"
# Running pass/fail tallies, updated by test() and reported by main().
PASS = FAIL = 0
|
||
|
||
|
||
def test(name: str, fn):
    """Run one test callable and record the outcome.

    Calls ``fn()`` with no arguments. On success prints a check mark and
    increments the module-level ``PASS`` counter. Any exception is caught,
    reported with a cross, and counted in ``FAIL``; exceptions other than
    ``AssertionError`` are labelled ``EXCEPTION`` in the output.
    """
    global PASS, FAIL
    try:
        fn()
        print(f" ✓ {name}")
        PASS += 1
    except Exception as exc:
        # Assertion failures are expected test failures; anything else is
        # flagged so crashes are distinguishable from failed checks.
        label = "" if isinstance(exc, AssertionError) else "EXCEPTION "
        print(f" ✗ {name}: {label}{exc}")
        FAIL += 1
|
||
|
||
|
||
def assert_eq(a, b, msg=""):
    """Fail with ``AssertionError`` unless ``a == b``.

    Args:
        a: The actual value.
        b: The expected value.
        msg: Optional context appended to the failure message in parentheses.

    Raises:
        AssertionError: If the values differ. Raised explicitly rather than
            through an ``assert`` statement so the check still runs when
            Python is started with ``-O``.
    """
    if a == b:
        return
    suffix = f" ({msg})" if msg else ""
    raise AssertionError(f"expected {b!r}, got {a!r}{suffix}")
|
||
|
||
|
||
def assert_gt(a, b, msg=""):
    """Fail with ``AssertionError`` unless ``a > b`` (strictly greater).

    Args:
        a: The actual value.
        b: The lower bound that ``a`` must exceed.
        msg: Optional context appended to the failure message in parentheses.

    Raises:
        AssertionError: If ``a`` is not greater than ``b``. Raised explicitly
            rather than through an ``assert`` statement so the check still
            runs when Python is started with ``-O``.
    """
    if a > b:
        return
    suffix = f" ({msg})" if msg else ""
    raise AssertionError(f"expected > {b!r}, got {a!r}{suffix}")
|
||
|
||
|
||
def assert_in(needle, haystack, msg=""):
    """Fail with ``AssertionError`` unless ``needle in haystack``.

    Args:
        needle: The item or substring to look for.
        haystack: Any container supporting the ``in`` operator.
        msg: Optional context appended to the failure message in parentheses.

    Raises:
        AssertionError: If ``needle`` is not found. Raised explicitly rather
            than through an ``assert`` statement so the check still runs when
            Python is started with ``-O``.
    """
    if needle in haystack:
        return
    suffix = f" ({msg})" if msg else ""
    raise AssertionError(f"{needle!r} not in {haystack!r}{suffix}")
|
||
|
||
|
||
# ── health check ────────────────────────────────────────────────────
|
||
|
||
def check_server():
    """Return True if GET /stats succeeds within 3 seconds, else False.

    Any exception (connection failure, timeout, HTTP error status) is
    treated as "server not reachable".
    """
    try:
        requests.get(f"{BASE}/stats", timeout=3).raise_for_status()
    except Exception:
        return False
    return True
|
||
|
||
|
||
# ── test: stats on empty db ─────────────────────────────────────────
|
||
|
||
def test_stats_empty():
    """GET /stats answers 200 and reports the expected fields."""
    resp = requests.get(f"{BASE}/stats")
    assert_eq(resp.status_code, 200)
    stats = resp.json()
    assert "num_memories" in stats
    assert "device" in stats
    assert_eq(stats["embedding_model"], "all-MiniLM-L6-v2")
|
||
|
||
|
||
# ── test: recall on empty db ────────────────────────────────────────
|
||
|
||
def test_recall_empty():
    """Recall against an empty store returns an empty string and count 0."""
    resp = requests.post(f"{BASE}/recall", json={"text": "hello"})
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_eq(body["memories"], "")
    assert_eq(body["count"], 0)
|
||
|
||
|
||
# ── test: direct store ──────────────────────────────────────────────
|
||
|
||
# Memory ids returned by /store; the store tests append to this list.
stored_ids: list = []
|
||
|
||
def test_store_single():
    """Store one cue/target pair and remember the returned memory id."""
    payload = {
        "cue": "what port does postgres run on",
        "target": "PostgreSQL runs on port 5432",
        "importance": 0.8,
    }
    resp = requests.post(f"{BASE}/store", json=payload)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert "memory_id" in body
    stored_ids.append(body["memory_id"])
|
||
|
||
|
||
def test_store_multiple():
    """Store five more memories, checking each POST /store returns 200."""
    # (cue, target, importance) triples; later recall tests query these.
    fixtures = (
        ("what is the database password", "The DB password is stored in /etc/secrets/db.env", 0.9),
        ("how to deploy the app", "Run make deploy-hera to deploy to the suite VPS via SSH", 0.7),
        ("what timezone is Fam in", "Fam is in London, UK timezone (Europe/London, GMT/BST)", 0.6),
        ("which embedding model works best", "all-MiniLM-L6-v2 has the best gap metric for hippocampal memory", 0.8),
        ("what GPU does the server have", "The server has an NVIDIA RTX 4090 with 24GB VRAM", 0.7),
    )
    for cue, target, importance in fixtures:
        payload = {"cue": cue, "target": target, "importance": importance}
        resp = requests.post(f"{BASE}/store", json=payload)
        assert_eq(resp.status_code, 200)
        stored_ids.append(resp.json()["memory_id"])
|
||
|
||
|
||
# ── test: exact recall ──────────────────────────────────────────────
|
||
|
||
def test_recall_exact():
    """Querying with the stored cue verbatim must surface the postgres port."""
    query = {"text": "what port does postgres run on", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0, "should recall at least 1")
    assert_in("5432", body["memories"], "should mention port 5432")
|
||
|
||
|
||
# ── test: paraphrase recall ─────────────────────────────────────────
|
||
|
||
def test_recall_paraphrase():
    """A reworded query (not the stored cue) must still find port 5432."""
    query = {"text": "which port is postgresql listening on", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0, "paraphrase should still recall")
    assert_in("5432", body["memories"])
|
||
|
||
|
||
def test_recall_different_wording():
    """A query sharing few words with the cue must still find port 5432."""
    query = {"text": "database connection port number", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0, "different wording should recall")
    assert_in("5432", body["memories"])
|
||
|
||
|
||
# ── test: recall relevance ──────────────────────────────────────────
|
||
|
||
def test_recall_deployment():
    """A deployment question must recall text mentioning "deploy"."""
    query = {"text": "how do I deploy to production", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0)
    # Case-insensitive match on the recalled text.
    recalled = body["memories"].lower()
    assert_in("deploy", recalled)
|
||
|
||
|
||
def test_recall_timezone():
    """A location question must recall the memory mentioning London."""
    query = {"text": "where is Fam located", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0)
    assert_in("London", body["memories"])
|
||
|
||
|
||
def test_recall_gpu():
    """A hardware question must recall the memory mentioning the 4090."""
    query = {"text": "what hardware does the server have", "top_k": 3}
    resp = requests.post(f"{BASE}/recall", json=query)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    assert_gt(body["count"], 0)
    assert_in("4090", body["memories"])
|
||
|
||
|
||
# ── test: top_k ─────────────────────────────────────────────────────
|
||
|
||
def test_recall_top_k_1():
    """With top_k=1 the reported count must be exactly 1."""
    query = {"text": "postgres port", "top_k": 1}
    body = requests.post(f"{BASE}/recall", json=query).json()
    assert_eq(body["count"], 1, "top_k=1 should return exactly 1")
|
||
|
||
|
||
def test_recall_top_k_all():
    """With a generous top_k=20, a broad query must return something."""
    query = {"text": "tell me everything", "top_k": 20}
    body = requests.post(f"{BASE}/recall", json=query).json()
    assert_gt(body["count"], 0, "should recall something")
|
||
|
||
|
||
# ── test: recall latency ────────────────────────────────────────────
|
||
|
||
def test_recall_latency():
    """The server-reported recall latency must stay under 100ms.

    Only the ``latency_ms`` field from the response is asserted; the
    end-to-end time measured here (including HTTP) is printed for
    information and never fails the test.
    """
    started = time.monotonic()
    resp = requests.post(f"{BASE}/recall", json={"text": "database port"})
    elapsed_ms = (time.monotonic() - started) * 1000
    body = resp.json()
    internal_ms = body["latency_ms"]
    # Latency as reported by the server itself (no HTTP overhead).
    assert internal_ms < 100, f"internal latency {body['latency_ms']:.1f}ms too high"
    # End-to-end figure including HTTP, shown alongside for comparison.
    print(f" (e2e={elapsed_ms:.1f}ms, internal={internal_ms:.1f}ms)")
|
||
|
||
|
||
# ── test: ingest (heuristic, no LLM) ───────────────────────────────
|
||
|
||
def test_ingest_heuristic():
    """POST /ingest with a Q&A exchange must report at least one stored memory."""
    exchange = {
        "user_msg": "What version of Python are we running?",
        "assistant_msg": "We are running Python 3.12.4 on the server, installed via uv.",
    }
    resp = requests.post(f"{BASE}/ingest", json=exchange)
    assert_eq(resp.status_code, 200)
    body = resp.json()
    # At minimum the question/answer pair itself should be extracted.
    assert_gt(body["stored"], 0, "heuristic should extract at least 1 memory")
|
||
|
||
|
||
def test_ingest_then_recall():
    """After ingesting an exchange, its content should be recallable.

    Ingests a Redis TTL question/answer, waits briefly, then checks that a
    related recall query returns text containing the TTL value "3600".
    The ingest response status is checked first so that a failed ingest is
    reported as such instead of surfacing as a confusing recall miss.
    """
    ingest = requests.post(f"{BASE}/ingest", json={
        "user_msg": "What's the Redis cache TTL?",
        "assistant_msg": "The Redis cache TTL is set to 3600 seconds (1 hour) in production.",
    })
    assert_eq(ingest.status_code, 200, "ingest should succeed before recall is checked")

    # Short pause in case the server finishes processing asynchronously.
    time.sleep(0.5)

    recall = requests.post(f"{BASE}/recall", json={
        "text": "redis cache timeout",
        "top_k": 3,
    })
    data = recall.json()
    assert_gt(data["count"], 0, "ingested memory should be recallable")
    assert_in("3600", data["memories"], "should recall the TTL value")
|
||
|
||
|
||
# ── test: forget ────────────────────────────────────────────────────
|
||
|
||
def test_forget():
    """Store a memory, delete it via DELETE /memory/{id}, and verify it is gone."""
    cue = "temporary test memory for deletion"

    # Create a throwaway memory with a distinctive marker word.
    created = requests.post(f"{BASE}/store", json={
        "cue": cue,
        "target": "this should be deleted XYZZY",
    })
    memory_id = created.json()["memory_id"]

    # It must be recallable before deletion.
    before = requests.post(f"{BASE}/recall", json={"text": cue})
    assert_in("XYZZY", before.json()["memories"])

    removed = requests.delete(f"{BASE}/memory/{memory_id}")
    assert_eq(removed.status_code, 200)

    # Recalling the exact cue again must no longer return the marker.
    after = requests.post(f"{BASE}/recall", json={"text": cue})
    remaining = after.json()["memories"]
    if remaining:
        assert "XYZZY" not in remaining, "deleted memory should not appear"
|
||
|
||
|
||
# ── test: format ────────────────────────────────────────────────────
|
||
|
||
def test_recall_format():
    """Non-empty recall output starts with the header and lists items with '- '.

    Nothing is asserted when the recall returns no results.
    """
    body = requests.post(f"{BASE}/recall", json={"text": "postgres port"}).json()
    if not body["count"] > 0:
        return
    text = body["memories"]
    assert text.startswith("[相关记忆]"), "should start with header"
    assert "\n- " in text, "each memory should start with '- '"
|
||
|
||
|
||
# ── test: stats after stores ────────────────────────────────────────
|
||
|
||
def test_stats_after():
    """After the store/ingest tests, /stats must report a populated store.

    Checks that ``num_memories`` is positive and that ``num_cue_entries`` is
    at least ``num_memories``. The comparison is ``>=`` to match the stated
    invariant: equality is valid when no extra cue entries were added.
    """
    data = requests.get(f"{BASE}/stats").json()
    assert_gt(data["num_memories"], 0, "should have memories")
    cue_entries = data["num_cue_entries"]
    memories = data["num_memories"]
    assert cue_entries >= memories, (
        f"expected >= {memories!r}, got {cue_entries!r}"
        " (cue entries should >= memories (augmentation from ingest))"
    )
|
||
|
||
|
||
# ── test: edge cases ────────────────────────────────────────────────
|
||
|
||
def test_recall_empty_text():
    """An empty query string must be handled without a server error (200)."""
    resp = requests.post(f"{BASE}/recall", json={"text": ""})
    status = resp.status_code
    assert status == 200
|
||
|
||
|
||
def test_recall_long_text():
    """A 2000-character query ("a " repeated 1000 times) must still return 200."""
    long_query = "a " * 1000
    resp = requests.post(f"{BASE}/recall", json={"text": long_query})
    assert resp.status_code == 200
|
||
|
||
|
||
def test_recall_chinese():
    """A memory stored with Chinese text must be recallable by a Chinese query."""
    # Store a Chinese cue/target pair first.
    memory = {
        "cue": "数据库密码在哪里",
        "target": "数据库密码存在 /etc/secrets/db.env 文件中",
    }
    requests.post(f"{BASE}/store", json=memory)

    resp = requests.post(f"{BASE}/recall", json={"text": "数据库密码"})
    body = resp.json()
    assert_gt(body["count"], 0, "Chinese recall should work")
    assert_in("secrets", body["memories"])
|
||
|
||
|
||
def test_store_validation():
    """POST /store with only a cue (no target) must be rejected with 422."""
    incomplete = {"cue": "only cue"}
    resp = requests.post(f"{BASE}/store", json=incomplete)
    assert_eq(resp.status_code, 422)
|
||
|
||
|
||
# ── run ─────────────────────────────────────────────────────────────
|
||
|
||
def main():
    """Run the whole suite against the server at ``BASE`` and print a summary.

    Exits with status 1 if the server is unreachable or any test failed.
    Tests run in a fixed order because later sections rely on memories
    stored by earlier ones.
    """
    print("nocmem API tests")
    print(f"server: {BASE}\n")

    if not check_server():
        print("ERROR: server not reachable")
        sys.exit(1)

    # Record how many memories exist before the run starts.
    initial = requests.get(f"{BASE}/stats").json()["num_memories"]
    print(f"[initial state: {initial} memories]\n")

    # The empty-db recall check is only meaningful on a fresh store;
    # otherwise a no-op runs in its place.
    recall_empty = test_recall_empty if initial == 0 else (lambda: None)

    # (section header, [(test label, test callable), ...]) in execution order.
    sections = [
        ("── basic ──", [
            ("stats endpoint", test_stats_empty),
            ("recall on empty/existing db", recall_empty),
        ]),
        ("\n── store ──", [
            ("store single memory", test_store_single),
            ("store multiple memories", test_store_multiple),
        ]),
        ("\n── recall accuracy ──", [
            ("exact cue recall", test_recall_exact),
            ("paraphrase recall", test_recall_paraphrase),
            ("different wording recall", test_recall_different_wording),
            ("deployment query", test_recall_deployment),
            ("timezone query", test_recall_timezone),
            ("GPU query", test_recall_gpu),
        ]),
        ("\n── recall params ──", [
            ("top_k=1", test_recall_top_k_1),
            ("top_k=20 (all)", test_recall_top_k_all),
            ("latency < 100ms", test_recall_latency),
            ("format check", test_recall_format),
        ]),
        ("\n── ingest ──", [
            ("heuristic ingest", test_ingest_heuristic),
            ("ingest then recall", test_ingest_then_recall),
        ]),
        ("\n── forget ──", [
            ("store + forget + verify", test_forget),
        ]),
        ("\n── edge cases ──", [
            ("empty text", test_recall_empty_text),
            ("long text", test_recall_long_text),
            ("Chinese text", test_recall_chinese),
            ("validation error", test_store_validation),
        ]),
        ("\n── stats ──", [
            ("stats after stores", test_stats_after),
        ]),
    ]
    for header, cases in sections:
        print(header)
        for label, fn in cases:
            test(label, fn)

    print(f"\n{'='*40}")
    print(f"PASS: {PASS} FAIL: {FAIL}")
    if FAIL:
        sys.exit(1)
    print("All tests passed!")
|
||
|
||
|
||
# Run the suite only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|