diff --git a/src/session_search.py b/src/session_search.py index 23088ca5c..98ddbc757 100644 --- a/src/session_search.py +++ b/src/session_search.py @@ -214,6 +214,24 @@ def _search_like( return _rows_to_results(db, shaped, query, context_messages) +def _fetch_messages_by_id(db, message_ids): + """Fetch (message, session_name) for many message ids in a single query. + + The FTS search returns a list of hit ids; fetching each row on its own was an + N+1 query (one SELECT per hit). Batch them with one IN(...) query and return + a lookup so the caller can reassemble results in hit (relevance) order. + """ + if not message_ids: + return {} + rows = ( + db.query(DBChatMessage, DBSession.name) + .join(DBSession, DBChatMessage.session_id == DBSession.id) + .filter(DBChatMessage.id.in_(message_ids)) + .all() + ) + return {msg.id: (msg, session_name) for msg, session_name in rows} + + def _search_fts( db, query: str, @@ -267,19 +285,13 @@ def _search_fts( if not hits: return None + by_id = _fetch_messages_by_id(db, [hit[0] for hit in hits]) rows = [] for hit in hits: - message_id = hit[0] - snippet = hit[1] or "" - row = ( - db.query(DBChatMessage, DBSession.name) - .join(DBSession, DBChatMessage.session_id == DBSession.id) - .filter(DBChatMessage.id == message_id) - .first() - ) - if row: - msg, session_name = row - rows.append((msg, session_name, snippet)) + found = by_id.get(hit[0]) + if found: + msg, session_name = found + rows.append((msg, session_name, hit[1] or "")) return _rows_to_results(db, rows, query, context_messages) diff --git a/tests/test_session_search_batch_fetch.py b/tests/test_session_search_batch_fetch.py new file mode 100644 index 000000000..144e393d5 --- /dev/null +++ b/tests/test_session_search_batch_fetch.py @@ -0,0 +1,55 @@ +"""FTS session search must fetch hit rows in one query, not one per hit. + +_search_fts looked up each FTS hit's full row with its own +db.query(...).filter(id == message_id).first(), an N+1 query. The lookup is now +a single batched IN(...) query via _fetch_messages_by_id. +""" +from src.session_search import _fetch_messages_by_id + + +class _Msg: + def __init__(self, mid): + self.id = mid + + +class _Query: + def __init__(self, rows, calls): + self._rows = rows + self._calls = calls + + def join(self, *a, **k): + return self + + def filter(self, *a, **k): + return self + + def all(self): + self._calls["all"] += 1 + return self._rows + + +class _DB: + def __init__(self, rows): + self._rows = rows + self.calls = {"query": 0, "all": 0} + + def query(self, *a, **k): + self.calls["query"] += 1 + return _Query(self._rows, self.calls) + + +def test_batches_into_single_query(): + rows = [(_Msg("m1"), "Session One"), (_Msg("m2"), "Session Two")] + db = _DB(rows) + out = _fetch_messages_by_id(db, ["m1", "m2"]) + # One query for all hits, not one per hit. + assert db.calls["query"] == 1 + assert db.calls["all"] == 1 + assert out["m1"][1] == "Session One" + assert out["m2"][0].id == "m2" + + +def test_empty_ids_does_no_query(): + db = _DB([]) + assert _fetch_messages_by_id(db, []) == {} + assert db.calls["query"] == 0