mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-18 10:45:31 -04:00
356 lines
10 KiB
Python
356 lines
10 KiB
Python
"""Shared session transcript search for UI and agent tools."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any, Iterable
|
|
|
|
from sqlalchemy import text
|
|
|
|
from core.database import ChatMessage as DBChatMessage
|
|
from core.database import Session as DBSession
|
|
from core.database import SessionLocal
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
SEARCH_ROLES = ("user", "assistant")
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class SessionSearchResult:
|
|
message_id: str
|
|
session_id: str
|
|
session_name: str
|
|
role: str
|
|
content: str
|
|
content_snippet: str
|
|
timestamp: str | None
|
|
context_before: list[dict[str, Any]]
|
|
context_after: list[dict[str, Any]]
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
return {
|
|
"message_id": self.message_id,
|
|
"session_id": self.session_id,
|
|
"session_name": self.session_name,
|
|
"role": self.role,
|
|
"content_snippet": self.content_snippet,
|
|
"timestamp": self.timestamp,
|
|
"context_before": self.context_before,
|
|
"context_after": self.context_after,
|
|
}
|
|
|
|
|
|
def _iso(value: datetime | None) -> str | None:
|
|
return value.isoformat() if value else None
|
|
|
|
|
|
def _message_to_context(msg: DBChatMessage) -> dict[str, Any]:
|
|
return {
|
|
"message_id": msg.id,
|
|
"role": msg.role,
|
|
"content": msg.content or "",
|
|
"timestamp": _iso(msg.timestamp),
|
|
}
|
|
|
|
|
|
def _escape_like(value: str) -> str:
|
|
return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
|
|
|
|
|
def _snippet(content: str, query: str, radius: int = 60) -> str:
|
|
content = content or ""
|
|
query = query or ""
|
|
if not query:
|
|
return content[: radius * 2]
|
|
|
|
idx = content.lower().find(query.lower())
|
|
if idx == -1:
|
|
return content[: radius * 2]
|
|
|
|
start = max(0, idx - radius)
|
|
end = min(len(content), idx + len(query) + radius)
|
|
return ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
|
|
|
|
|
|
def _sanitize_fts_query(query: str) -> str | None:
|
|
"""Convert free text into a conservative FTS5 MATCH query.
|
|
|
|
User input can contain FTS5 operators or punctuation that raises
|
|
sqlite3.OperationalError. For transcript search we do not need advanced
|
|
syntax in v1, so keep only words and balanced quoted phrases.
|
|
"""
|
|
parts: list[str] = []
|
|
for match in re.finditer(r'"([^"]+)"|[\w][\w._-]*', query, flags=re.UNICODE):
|
|
phrase = match.group(1)
|
|
if phrase is not None:
|
|
phrase = phrase.strip()
|
|
if phrase:
|
|
parts.append('"' + phrase.replace('"', '""') + '"')
|
|
continue
|
|
|
|
token = match.group(0).strip("._-")
|
|
if not token:
|
|
continue
|
|
if any(ch in token for ch in "._-"):
|
|
parts.append('"' + token.replace('"', '""') + '"')
|
|
else:
|
|
parts.append(token)
|
|
|
|
if not parts:
|
|
return None
|
|
return " ".join(parts)
|
|
|
|
|
|
def _is_sqlite_session(db) -> bool:
|
|
try:
|
|
bind = db.get_bind()
|
|
return getattr(getattr(bind, "dialect", None), "name", None) == "sqlite"
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def _has_fts_table(db) -> bool:
|
|
if not _is_sqlite_session(db):
|
|
return False
|
|
try:
|
|
row = db.execute(
|
|
text("SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_messages_fts' LIMIT 1")
|
|
).first()
|
|
return row is not None
|
|
except Exception as e:
|
|
logger.debug("chat_messages_fts availability check failed: %s", e)
|
|
return False
|
|
|
|
|
|
def _owner_filter(query, owner: str | None, include_legacy_owner: bool):
|
|
if owner is None:
|
|
return query.filter(DBSession.owner.is_(None))
|
|
if not include_legacy_owner:
|
|
return query.filter(DBSession.owner == owner)
|
|
return query.filter((DBSession.owner == owner) | (DBSession.owner.is_(None)))
|
|
|
|
|
|
def _context_for_message(db, msg: DBChatMessage, count: int) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
|
if count <= 0 or not msg.timestamp:
|
|
return [], []
|
|
|
|
before_rows = (
|
|
db.query(DBChatMessage)
|
|
.filter(
|
|
DBChatMessage.session_id == msg.session_id,
|
|
DBChatMessage.role.in_(SEARCH_ROLES),
|
|
DBChatMessage.timestamp < msg.timestamp,
|
|
)
|
|
.order_by(DBChatMessage.timestamp.desc())
|
|
.limit(count)
|
|
.all()
|
|
)
|
|
after_rows = (
|
|
db.query(DBChatMessage)
|
|
.filter(
|
|
DBChatMessage.session_id == msg.session_id,
|
|
DBChatMessage.role.in_(SEARCH_ROLES),
|
|
DBChatMessage.timestamp > msg.timestamp,
|
|
)
|
|
.order_by(DBChatMessage.timestamp.asc())
|
|
.limit(count)
|
|
.all()
|
|
)
|
|
before = [_message_to_context(row) for row in reversed(before_rows)]
|
|
after = [_message_to_context(row) for row in after_rows]
|
|
return before, after
|
|
|
|
|
|
def _rows_to_results(db, rows: Iterable[tuple[DBChatMessage, str, str]], query: str, context_messages: int) -> list[SessionSearchResult]:
|
|
results: list[SessionSearchResult] = []
|
|
for msg, session_name, snippet in rows:
|
|
before, after = _context_for_message(db, msg, context_messages)
|
|
content = msg.content or ""
|
|
results.append(
|
|
SessionSearchResult(
|
|
message_id=msg.id,
|
|
session_id=msg.session_id,
|
|
session_name=session_name or "Untitled",
|
|
role=msg.role,
|
|
content=content,
|
|
content_snippet=snippet or _snippet(content, query),
|
|
timestamp=_iso(msg.timestamp),
|
|
context_before=before,
|
|
context_after=after,
|
|
)
|
|
)
|
|
return results
|
|
|
|
|
|
def _search_like(
|
|
db,
|
|
query: str,
|
|
limit: int,
|
|
owner: str | None,
|
|
include_archived: bool,
|
|
context_messages: int,
|
|
restrict_owner: bool,
|
|
include_legacy_owner: bool,
|
|
) -> list[SessionSearchResult]:
|
|
safe_q = _escape_like(query)
|
|
q = (
|
|
db.query(DBChatMessage, DBSession.name)
|
|
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
|
.filter(
|
|
DBChatMessage.content.ilike(f"%{safe_q}%", escape="\\"),
|
|
DBChatMessage.role.in_(SEARCH_ROLES),
|
|
)
|
|
)
|
|
if not include_archived:
|
|
q = q.filter(DBSession.archived == False)
|
|
if restrict_owner:
|
|
q = _owner_filter(q, owner, include_legacy_owner)
|
|
rows = q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
|
|
shaped = ((msg, session_name, _snippet(msg.content or "", query)) for msg, session_name in rows)
|
|
return _rows_to_results(db, shaped, query, context_messages)
|
|
|
|
|
|
def _search_fts(
|
|
db,
|
|
query: str,
|
|
limit: int,
|
|
owner: str | None,
|
|
include_archived: bool,
|
|
context_messages: int,
|
|
restrict_owner: bool,
|
|
include_legacy_owner: bool,
|
|
) -> list[SessionSearchResult] | None:
|
|
fts_query = _sanitize_fts_query(query)
|
|
if not fts_query or not _has_fts_table(db):
|
|
return None
|
|
|
|
archived_clause = "" if include_archived else "AND s.archived = 0"
|
|
if not restrict_owner:
|
|
owner_clause = ""
|
|
elif owner is None:
|
|
owner_clause = "AND s.owner IS NULL"
|
|
elif not include_legacy_owner:
|
|
owner_clause = "AND s.owner = :owner"
|
|
else:
|
|
owner_clause = "AND (s.owner = :owner OR s.owner IS NULL)"
|
|
params: dict[str, Any] = {"fts_query": fts_query, "limit": limit}
|
|
if restrict_owner and owner is not None:
|
|
params["owner"] = owner
|
|
|
|
sql = text(
|
|
f"""
|
|
SELECT
|
|
m.id AS message_id,
|
|
snippet(chat_messages_fts, 0, '', '', '...', 24) AS content_snippet
|
|
FROM chat_messages_fts
|
|
JOIN chat_messages m ON m.id = chat_messages_fts.message_id
|
|
JOIN sessions s ON s.id = m.session_id
|
|
WHERE chat_messages_fts MATCH :fts_query
|
|
{archived_clause}
|
|
{owner_clause}
|
|
AND m.role IN ('user', 'assistant')
|
|
ORDER BY bm25(chat_messages_fts), m.timestamp DESC
|
|
LIMIT :limit
|
|
"""
|
|
)
|
|
|
|
try:
|
|
hits = db.execute(sql, params).fetchall()
|
|
except Exception as e:
|
|
logger.debug("FTS session search failed; falling back to LIKE: %s", e)
|
|
return None
|
|
|
|
if not hits:
|
|
return None
|
|
|
|
rows = []
|
|
for hit in hits:
|
|
message_id = hit[0]
|
|
snippet = hit[1] or ""
|
|
row = (
|
|
db.query(DBChatMessage, DBSession.name)
|
|
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
|
.filter(DBChatMessage.id == message_id)
|
|
.first()
|
|
)
|
|
if row:
|
|
msg, session_name = row
|
|
rows.append((msg, session_name, snippet))
|
|
return _rows_to_results(db, rows, query, context_messages)
|
|
|
|
|
|
def search_session_messages(
|
|
query: str,
|
|
limit: int = 20,
|
|
owner: str | None = None,
|
|
include_archived: bool = False,
|
|
context_messages: int = 1,
|
|
restrict_owner: bool = True,
|
|
include_legacy_owner: bool = True,
|
|
db=None,
|
|
) -> list[SessionSearchResult]:
|
|
"""Search session transcripts using FTS5 when available.
|
|
|
|
`owner=None` is deliberately treated as legacy/null-owner scope rather
|
|
than global access.
|
|
"""
|
|
query = (query or "").strip()
|
|
if not query:
|
|
return []
|
|
|
|
limit = max(1, min(int(limit or 20), 100))
|
|
context_messages = max(0, min(int(context_messages or 0), 3))
|
|
|
|
owns_db = db is None
|
|
if owns_db:
|
|
db = SessionLocal()
|
|
try:
|
|
fts_results = _search_fts(
|
|
db,
|
|
query,
|
|
limit,
|
|
owner,
|
|
include_archived,
|
|
context_messages,
|
|
restrict_owner,
|
|
include_legacy_owner,
|
|
)
|
|
if fts_results is not None:
|
|
like_results = _search_like(
|
|
db,
|
|
query,
|
|
limit,
|
|
owner,
|
|
include_archived,
|
|
context_messages,
|
|
restrict_owner,
|
|
include_legacy_owner,
|
|
)
|
|
merged: list[SessionSearchResult] = []
|
|
seen: set[str] = set()
|
|
for result in [*fts_results, *like_results]:
|
|
if result.message_id in seen:
|
|
continue
|
|
seen.add(result.message_id)
|
|
merged.append(result)
|
|
if len(merged) >= limit:
|
|
break
|
|
return merged
|
|
return _search_like(
|
|
db,
|
|
query,
|
|
limit,
|
|
owner,
|
|
include_archived,
|
|
context_messages,
|
|
restrict_owner,
|
|
include_legacy_owner,
|
|
)
|
|
finally:
|
|
if owns_db:
|
|
db.close()
|