mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
299 lines
10 KiB
Python
299 lines
10 KiB
Python
from datetime import datetime, timedelta
|
|
import asyncio
|
|
import sqlite3
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from core.database import Base
|
|
from core.database import ChatMessage as DbChatMessage
|
|
from core.database import Session as DbSession
|
|
from src.session_search import SessionSearchResult, search_session_messages
|
|
|
|
|
|
def _db(with_fts=True):
    """Return an ORM session bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created. When *with_fts*
    is true the ``chat_messages_fts`` FTS5 virtual table is created as well,
    through the session's own connection.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    session_factory = sessionmaker(bind=engine)
    db = session_factory()
    if not with_fts:
        return db

    create_fts_sql = """
        CREATE VIRTUAL TABLE chat_messages_fts USING fts5(
            content,
            message_id UNINDEXED,
            session_id UNINDEXED,
            role UNINDEXED
        )
        """
    db.connection().exec_driver_sql(create_fts_sql)
    return db
|
|
|
|
|
|
def _add_session(db, sid, owner="alice", archived=False, name=None):
    """Stage a ``DbSession`` row on *db* (the caller commits).

    The session name falls back to *sid* when *name* is falsy; endpoint,
    model and message count are fixed placeholder values.
    """
    display_name = name or sid
    session_row = DbSession(
        id=sid,
        name=display_name,
        endpoint_url="http://example.test",
        model="test-model",
        owner=owner,
        archived=archived,
        message_count=0,
    )
    db.add(session_row)
|
|
|
|
|
|
def _add_message(db, sid, mid, role, content, when):
    """Stage a chat message on *db* and mirror it into the FTS table if present.

    The ORM row is always added; the ``chat_messages_fts`` insert only runs
    when ``_has_fts`` reports that the virtual table exists.
    """
    message_row = DbChatMessage(
        id=mid,
        session_id=sid,
        role=role,
        content=content,
        timestamp=when,
    )
    db.add(message_row)

    if not _has_fts(db):
        return

    insert_sql = "INSERT INTO chat_messages_fts(content, message_id, session_id, role) VALUES (?, ?, ?, ?)"
    db.connection().exec_driver_sql(insert_sql, (content, mid, sid, role))
|
|
|
|
|
|
def _has_fts(db):
    """Return True when a ``chat_messages_fts`` table is listed in ``sqlite_master``."""
    probe_sql = "SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_messages_fts'"
    row = db.connection().exec_driver_sql(probe_sql).first()
    return row is not None
|
|
|
|
|
|
def test_session_search_uses_fts_and_returns_context():
    """With the FTS table present, a phrase query returns the hit plus neighbours.

    Only the middle message matches; the result must carry the session name,
    the preceding and following messages as context, and a snippet that
    contains the query term.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        one_minute = timedelta(minutes=1)
        _add_session(db, "s1", owner="alice", name="Jazz planning")
        _add_message(db, "s1", "m1", "user", "Before context about music", start)
        _add_message(db, "s1", "m2", "assistant", "We talked about modal jazz theory", start + one_minute)
        _add_message(db, "s1", "m3", "user", "After context about tasks", start + 2 * one_minute)
        db.commit()

        hits = search_session_messages("modal jazz", owner="alice", db=db)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["m2"]

        top = hits[0]
        assert top.session_name == "Jazz planning"
        assert top.context_before[0]["message_id"] == "m1"
        assert top.context_after[0]["message_id"] == "m3"
        assert "modal" in top.content_snippet.lower()
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_escapes_like_wildcards_in_fallback():
    """Without the FTS table, ``_`` in the query must be matched literally.

    ``foo_bar`` must find the message containing ``foo_bar`` but not the one
    containing ``fooXbar``, which an unescaped ``_`` wildcard would also match.
    """
    db = _db(with_fts=False)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "s1", owner="alice")
        _add_message(db, "s1", "literal", "user", "The literal token is foo_bar.", start)
        _add_message(
            db,
            "s1",
            "wild",
            "user",
            "The wildcard-looking token is fooXbar.",
            start + timedelta(minutes=1),
        )
        db.commit()

        hits = search_session_messages("foo_bar", owner="alice", db=db)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["literal"]
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_owner_scope_includes_legacy_and_excludes_other_users():
    """An owner-scoped search sees the owner's and ownerless sessions only.

    Three sessions (alice, ownerless "legacy", bob) hold the same text;
    searching as alice must return alice's and the legacy message, not bob's.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        owners_by_session = [("alice", "alice"), ("legacy", None), ("bob", "bob")]
        for session_id, session_owner in owners_by_session:
            _add_session(db, session_id, owner=session_owner)
        for minute, (session_id, _owner) in enumerate(owners_by_session):
            _add_message(
                db,
                session_id,
                f"m-{session_id}",
                "user",
                "shared recall target",
                start + timedelta(minutes=minute),
            )
        db.commit()

        hits = search_session_messages("shared recall target", owner="alice", db=db)

        visible_ids = {hit.message_id for hit in hits}
        assert visible_ids == {"m-alice", "m-legacy"}
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_can_exclude_legacy_rows_for_authenticated_ui_scope():
    """``include_legacy_owner=False`` restricts results to the exact owner.

    The ownerless "legacy" session holds the same text as alice's session but
    must not appear in the results.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "alice", owner="alice")
        _add_session(db, "legacy", owner=None)
        _add_message(db, "alice", "m-alice", "user", "exact owner target", start)
        _add_message(db, "legacy", "m-legacy", "user", "exact owner target", start + timedelta(minutes=1))
        db.commit()

        search_kwargs = {
            "owner": "alice",
            "include_legacy_owner": False,
            "db": db,
        }
        hits = search_session_messages("exact owner target", **search_kwargs)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["m-alice"]
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_ownerless_call_only_sees_legacy_rows():
    """Searching with ``owner=None`` returns only messages from ownerless sessions."""
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "alice", owner="alice")
        _add_session(db, "legacy", owner=None)
        _add_message(db, "alice", "m-alice", "user", "ownerless search target", start)
        _add_message(
            db,
            "legacy",
            "m-legacy",
            "user",
            "ownerless search target",
            start + timedelta(minutes=1),
        )
        db.commit()

        hits = search_session_messages("ownerless search target", owner=None, db=db)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["m-legacy"]
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_falls_back_to_like_when_fts_has_no_substring_hits():
    """A query that only occurs inside a longer word is still found.

    ``identifier`` appears only as part of ``customidentifier``; the message
    must be returned and the snippet must contain the query text.
    """
    db = _db(with_fts=True)
    try:
        _add_session(db, "s1", owner="alice")
        _add_message(
            db,
            "s1",
            "m1",
            "user",
            "We discussed customidentifier routing.",
            datetime(2026, 1, 1, 12, 0, 0),
        )
        db.commit()

        hits = search_session_messages("identifier", owner="alice", db=db)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["m1"]
        snippet = hits[0].content_snippet
        assert "identifier" in snippet
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_merges_like_substring_hits_with_fts_hits():
    """Whole-word and in-word occurrences of the query are both returned.

    One message contains ``identifier`` as a standalone token, the other only
    inside ``customidentifier``; the result set must include both.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "s1", owner="alice")
        _add_message(db, "s1", "m-token", "user", "The identifier token is standalone.", start)
        _add_message(
            db,
            "s1",
            "m-substring",
            "assistant",
            "We also discussed customidentifier routing.",
            start + timedelta(minutes=1),
        )
        db.commit()

        hits = search_session_messages("identifier", owner="alice", db=db)

        matched_ids = {hit.message_id for hit in hits}
        assert matched_ids == {"m-token", "m-substring"}
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_can_preserve_unrestricted_no_auth_route_scope():
    """``restrict_owner=False`` with ``owner=None`` searches across all owners.

    Both the admin-owned and the ownerless session messages must be returned.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "owned", owner="admin")
        _add_session(db, "legacy", owner=None)
        _add_message(db, "owned", "m-owned", "user", "no auth search target", start)
        _add_message(db, "legacy", "m-legacy", "user", "no auth search target", start + timedelta(minutes=1))
        db.commit()

        search_kwargs = {
            "owner": None,
            "restrict_owner": False,
            "db": db,
        }
        hits = search_session_messages("no auth search target", **search_kwargs)

        matched_ids = {hit.message_id for hit in hits}
        assert matched_ids == {"m-owned", "m-legacy"}
    finally:
        db.close()
|
|
|
|
|
|
def test_session_search_excludes_archived_by_default():
    """Messages in archived sessions are omitted unless explicitly requested.

    Both sessions belong to alice and hold the same text; only the message in
    the non-archived session may be returned by a default search.
    """
    db = _db(with_fts=True)
    try:
        start = datetime(2026, 1, 1, 12, 0, 0)
        _add_session(db, "active", owner="alice")
        _add_session(db, "archived", owner="alice", archived=True)
        _add_message(db, "active", "m-active", "user", "archive filter target", start)
        _add_message(
            db,
            "archived",
            "m-archived",
            "user",
            "archive filter target",
            start + timedelta(minutes=1),
        )
        db.commit()

        hits = search_session_messages("archive filter target", owner="alice", db=db)

        matched_ids = [hit.message_id for hit in hits]
        assert matched_ids == ["m-active"]
    finally:
        db.close()
|
|
|
|
|
|
def test_chat_messages_fts_migration_backfills_and_tracks_inserts(tmp_path, monkeypatch):
    """The FTS migration indexes existing rows and keeps indexing new ones.

    A bare ``chat_messages`` table with one row is created in a temporary
    database file, ``core.database.DATABASE_URL`` is pointed at it, and
    ``_migrate_chat_messages_fts`` is run. Afterwards the pre-existing row
    must be searchable through ``chat_messages_fts``, and a row inserted
    directly into ``chat_messages`` must become searchable too.
    """
    from core import database as cdb

    db_path = tmp_path / "app.db"

    # Seed the pre-migration schema. Close the handle even if seeding fails so
    # the temporary database file is never left open.
    seed_conn = sqlite3.connect(db_path)
    try:
        seed_conn.executescript(
            """
            CREATE TABLE chat_messages (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL
            );
            INSERT INTO chat_messages(id, session_id, role, content)
            VALUES ('m1', 's1', 'user', 'backfilled transcript search');
            """
        )
    finally:
        seed_conn.close()

    monkeypatch.setattr(cdb, "DATABASE_URL", f"sqlite:///{db_path}")

    cdb._migrate_chat_messages_fts()

    conn = sqlite3.connect(db_path)
    try:
        # The row that existed before the migration must have been indexed.
        backfilled = conn.execute(
            "SELECT message_id FROM chat_messages_fts WHERE chat_messages_fts MATCH 'backfilled'"
        ).fetchall()
        assert backfilled == [("m1",)]

        # A plain insert into the base table must show up in the FTS index.
        conn.execute(
            "INSERT INTO chat_messages(id, session_id, role, content) VALUES (?, ?, ?, ?)",
            ("m2", "s1", "assistant", "triggered transcript search"),
        )
        triggered = conn.execute(
            "SELECT message_id FROM chat_messages_fts WHERE chat_messages_fts MATCH 'triggered'"
        ).fetchall()
        assert triggered == [("m2",)]
    finally:
        conn.close()
|
|
|
|
|
|
def test_search_chats_formats_shared_results(monkeypatch):
    """``do_search_chats`` renders the session name, match and context lines.

    ``session_search.search_session_messages`` is replaced with a stub that
    returns one canned result, so only the formatting of ``out["results"]``
    is checked.
    """
    from src import session_search
    from src.tool_implementations import do_search_chats

    canned_result = SessionSearchResult(
        message_id="m2",
        session_id="s1",
        session_name="Design notes",
        role="assistant",
        content="We discussed session search.",
        content_snippet="We discussed session search.",
        timestamp="2026-01-01T12:00:00",
        context_before=[
            {"message_id": "m1", "role": "user", "content": "Can you find old chats?", "timestamp": None},
        ],
        context_after=[
            {"message_id": "m3", "role": "user", "content": "That helps.", "timestamp": None},
        ],
    )

    def fake_search(query, limit=20, owner=None, include_archived=False, context_messages=1, db=None):
        # Ignores every argument and always returns the single canned result.
        return [canned_result]

    monkeypatch.setattr(session_search, "search_session_messages", fake_search)

    out = asyncio.run(do_search_chats("session search", owner="alice"))

    expected_fragments = (
        "Design notes",
        "Match (assistant): We discussed session search.",
        "Before (user): Can you find old chats?",
        "After (user): That helps.",
    )
    for fragment in expected_fragments:
        assert fragment in out["results"]
|