diff --git a/src/research_handler.py b/src/research_handler.py index b996f089f..b3af3b8e5 100644 --- a/src/research_handler.py +++ b/src/research_handler.py @@ -390,7 +390,6 @@ class ResearchHandler: def get_status(self, session_id: str) -> Optional[dict]: """Get current research status for a session.""" - avg = self.get_avg_duration() if session_id in self._active_tasks: entry = self._active_tasks[session_id] result = { @@ -399,6 +398,14 @@ class ResearchHandler: "query": entry["query"], "started_at": entry["started_at"], } + # avg_duration is a historical figure over completed reports on + # disk; get_avg_duration() globs and JSON-parses the whole research + # dir, so compute it at most once per active stream (memoized on the + # entry) instead of on every ~1s SSE poll. The disk branch below + # never used it, so it no longer pays that cost at all. + if "_avg_duration" not in entry: + entry["_avg_duration"] = self.get_avg_duration() + avg = entry["_avg_duration"] if avg is not None: result["avg_duration"] = round(avg, 1) return result diff --git a/tests/test_research_status_avg_duration.py b/tests/test_research_status_avg_duration.py new file mode 100644 index 000000000..d44c63242 --- /dev/null +++ b/tests/test_research_status_avg_duration.py @@ -0,0 +1,41 @@ +"""get_status must not rescan the whole research dir on every SSE poll. + +get_avg_duration() globs and JSON-parses every file under the research data dir. +get_status() called it unconditionally on each poll, including for sessions that +are not active (the common case while a client polls a finished report). It is +now computed only for active sessions and memoized on the entry. +""" +from src.research_handler import ResearchHandler + + +def _handler(): + h = ResearchHandler.__new__(ResearchHandler) + h._active_tasks = {} + return h + + +def test_inactive_session_does_not_compute_avg(monkeypatch): + h = _handler() + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 5.0)[1]) + # Unknown session, no disk file -> None, and no expensive avg scan. + assert h.get_status("missing-session") is None + assert calls == [] + + +def test_active_session_memoizes_avg(monkeypatch): + h = _handler() + h._active_tasks["s1"] = { + "status": "running", "progress": {}, "query": "q", "started_at": 0, + } + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 12.0)[1]) + + r1 = h.get_status("s1") + r2 = h.get_status("s1") + r3 = h.get_status("s1") + + assert r1["avg_duration"] == 12.0 + assert r2["avg_duration"] == 12.0 and r3["avg_duration"] == 12.0 + # Computed once across many polls, not once per poll. + assert len(calls) == 1