fix(search): batch FTS hit lookups into one query (N+1) (#3909)

_search_fts ran the FTS MATCH query, then looked up each hit's full row with its
own db.query(...).filter(id == message_id).first() inside a loop, so a search
returning N hits issued N extra SELECTs. Fetch all hit rows in a single IN(...)
query via _fetch_messages_by_id and reassemble results in hit (relevance) order.

Adds tests/test_session_search_batch_fetch.py asserting a single batched query
(and no query for empty input). Existing session-search tests stay green.
This commit is contained in:
Mazen Tamer Salah
2026-06-11 17:31:54 +03:00
committed by GitHub
parent bfac1d55d6
commit f941db29d3
2 changed files with 78 additions and 11 deletions
+23 -11
View File
@@ -214,6 +214,24 @@ def _search_like(
return _rows_to_results(db, shaped, query, context_messages) return _rows_to_results(db, shaped, query, context_messages)
def _fetch_messages_by_id(db, message_ids):
"""Fetch (message, session_name) for many message ids in a single query.
The FTS search returns a list of hit ids; fetching each row on its own was an
N+1 query (one SELECT per hit). Batch them with one IN(...) query and return
a lookup so the caller can reassemble results in hit (relevance) order.
"""
if not message_ids:
return {}
rows = (
db.query(DBChatMessage, DBSession.name)
.join(DBSession, DBChatMessage.session_id == DBSession.id)
.filter(DBChatMessage.id.in_(message_ids))
.all()
)
return {msg.id: (msg, session_name) for msg, session_name in rows}
def _search_fts( def _search_fts(
db, db,
query: str, query: str,
@@ -267,19 +285,13 @@ def _search_fts(
if not hits: if not hits:
return None return None
by_id = _fetch_messages_by_id(db, [hit[0] for hit in hits])
rows = [] rows = []
for hit in hits: for hit in hits:
message_id = hit[0] found = by_id.get(hit[0])
snippet = hit[1] or "" if found:
row = ( msg, session_name = found
db.query(DBChatMessage, DBSession.name) rows.append((msg, session_name, hit[1] or ""))
.join(DBSession, DBChatMessage.session_id == DBSession.id)
.filter(DBChatMessage.id == message_id)
.first()
)
if row:
msg, session_name = row
rows.append((msg, session_name, snippet))
return _rows_to_results(db, rows, query, context_messages) return _rows_to_results(db, rows, query, context_messages)
+55
View File
@@ -0,0 +1,55 @@
"""FTS session search must fetch hit rows in one query, not one per hit.
_search_fts looked up each FTS hit's full row with its own
db.query(...).filter(id == message_id).first(), an N+1 query. The lookup is now
a single batched IN(...) query via _fetch_messages_by_id.
"""
from src.session_search import _fetch_messages_by_id
class _Msg:
def __init__(self, mid):
self.id = mid
class _Query:
def __init__(self, rows, calls):
self._rows = rows
self._calls = calls
def join(self, *a, **k):
return self
def filter(self, *a, **k):
return self
def all(self):
self._calls["all"] += 1
return self._rows
class _DB:
def __init__(self, rows):
self._rows = rows
self.calls = {"query": 0, "all": 0}
def query(self, *a, **k):
self.calls["query"] += 1
return _Query(self._rows, self.calls)
def test_batches_into_single_query():
    """Two requested ids are served by exactly one query()/all() round trip."""
    first = (_Msg("m1"), "Session One")
    second = (_Msg("m2"), "Session Two")
    fake_db = _DB([first, second])

    lookup = _fetch_messages_by_id(fake_db, ["m1", "m2"])

    # One query for all hits, not one per hit.
    assert fake_db.calls["query"] == 1
    assert fake_db.calls["all"] == 1
    # The lookup is keyed by message id and holds (message, session_name).
    assert lookup["m1"][1] == "Session One"
    assert lookup["m2"][0].id == "m2"
def test_empty_ids_does_no_query():
    """An empty id list short-circuits to ``{}`` without touching the DB."""
    fake_db = _DB([])

    result = _fetch_messages_by_id(fake_db, [])

    assert result == {}
    assert fake_db.calls["query"] == 0