From e115b0155c724a260c9d62e7bb58c1838a3bc620 Mon Sep 17 00:00:00 2001 From: SurprisedDuck Date: Wed, 10 Jun 2026 14:37:26 +0200 Subject: [PATCH 01/18] fix(security): don't grant tool access in the pre-setup window (#3506) * fix(security): don't grant tool access in the pre-setup window owner_is_admin_or_single_user() returned True whenever auth was not configured, which conflated two very different states: - intentional single-user mode (operator set AUTH_ENABLED=false), and - the pre-setup window (auth enabled, but no admin created yet). In the second state, blocked_tools_for_owner() returned an empty set, so server-execution tools (bash/python) and other admin-only tools were ungated. The auth middleware already 401s /api/ requests pre-setup, but a caller that bypasses it (trusted loopback / internal-tool path) could reach those tools before setup completed. Treat "not configured" as admin only when auth is intentionally disabled (AUTH_ENABLED=false), mirroring the AUTH_ENABLED parsing in app.py and core.middleware. Single-user mode is preserved; the pre-setup window is now non-admin as defense-in-depth. Adds regression tests for both states. Fixes #3201 Supported by Claude Opus 4.8 * refactor(security): reuse _auth_disabled() instead of a duplicate helper Addresses review on #3506: src/auth_helpers.py already has _auth_disabled() with the identical AUTH_ENABLED parse. Drop the duplicate _auth_intentionally_disabled() and call the existing helper via a lazy import inside owner_is_admin_or_single_user (mirroring the lazy core.auth import) to avoid any import cycle. Removes the now-unused `import os`. Behaviour and the two regression tests are unchanged. Supported by Claude Opus 4.8 --------- Co-authored-by: SurprisedDuck <288741682+SurprisedDuck@users.noreply.github.com> --- src/tool_security.py | 17 ++++++++-- tests/test_review_regressions.py | 54 ++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/tool_security.py b/src/tool_security.py index 82d2c3d67..6b7bc90df 100644 --- a/src/tool_security.py +++ b/src/tool_security.py @@ -162,13 +162,26 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool: def owner_is_admin_or_single_user(owner: Optional[str]) -> bool: - """Return True for admins, or when auth is not configured yet.""" + """Return True for admins, or in intentional single-user mode. + + Single-user mode means the operator explicitly disabled auth + (``AUTH_ENABLED=false``) — the local/self-host default where the owner has + full access to their own box. + + The pre-setup window (auth ENABLED but no admin created yet) is treated as + NON-admin: returning True there would hand server-execution tools + (``bash``/``python``) to any caller before setup completes. The auth + middleware already 401s ``/api/`` requests pre-setup, so this is + defense-in-depth for callers that bypass it (e.g. trusted loopback). + """ try: from core.auth import AuthManager auth = AuthManager() if not auth.is_configured: - return True + from src.auth_helpers import _auth_disabled + + return _auth_disabled() return bool(owner and auth.is_admin(owner)) except Exception as exc: logger.warning("Unable to evaluate owner admin status: %s", exc) diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py index b3988f88e..fe782f151 100644 --- a/tests/test_review_regressions.py +++ b/tests/test_review_regressions.py @@ -647,6 +647,60 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch): assert "manage_tasks" in blocked +def test_presetup_does_not_grant_admin_tools_when_auth_enabled(monkeypatch): + """Pre-setup window: auth is enabled but no admin user exists yet. + + This must NOT be treated as single-user/admin at the tool layer — the + server-execution tools (bash/python) stay blocked as defense-in-depth so + an unauthenticated caller that slips past the auth middleware (e.g. via a + loopback bypass) can't reach an RCE before setup completes. + """ + monkeypatch.delenv("AUTH_ENABLED", raising=False) # default: enabled + auth_mod = _install_core_auth_stub(monkeypatch) + + class FakeAuth: + is_configured = False + + def is_admin(self, username): + return False + + monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth()) + + from src.tool_security import ( + blocked_tools_for_owner, + owner_is_admin_or_single_user, + ) + + assert owner_is_admin_or_single_user(None) is False + blocked = blocked_tools_for_owner(None) + assert "bash" in blocked + assert "python" in blocked + + +def test_single_user_mode_keeps_full_tool_access_when_auth_disabled(monkeypatch): + """Intentional single-user mode (AUTH_ENABLED=false) keeps full tool + access even with no admin user — this is the default local/self-host UX + and must not regress.""" + monkeypatch.setenv("AUTH_ENABLED", "false") + auth_mod = _install_core_auth_stub(monkeypatch) + + class FakeAuth: + is_configured = False + + def is_admin(self, username): + return False + + monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth()) + + from src.tool_security import ( + blocked_tools_for_owner, + owner_is_admin_or_single_user, + ) + + assert owner_is_admin_or_single_user(None) is True + assert blocked_tools_for_owner(None) == set() + + @pytest.mark.asyncio async def test_webhook_tool_reuses_private_url_validation(): class FakeDb: From cd3fb4e96bba195b43dc2357e97bf666d99649b2 Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Wed, 10 Jun 2026 17:24:27 +0300 Subject: [PATCH 02/18] fix(auth): fail closed when deleting user tokens fails (#3733) --- core/auth.py | 28 ++++++++------ routes/auth_routes.py | 25 ++++++++---- tests/test_auth_config_lock_concurrency.py | 38 +++++++++++++++++++ ...est_delete_user_invalidates_token_cache.py | 24 ++++++++++++ tests/test_delete_user_revokes_api_tokens.py | 18 +++++++++ 5 files changed, 114 insertions(+), 19 deletions(-) diff --git a/core/auth.py b/core/auth.py index 5db2fed4c..11f38cd5f 100644 --- a/core/auth.py +++ b/core/auth.py @@ -244,6 +244,22 @@ class AuthManager: return False if not self.users.get(requesting_user, {}).get("is_admin"): return False + # Revoke API bearer tokens before removing the auth row. The bearer + # path authenticates from ApiToken rows and does not require the + # owner to still exist, so a successful delete must not leave active + # rows behind. If the token store is unavailable, fail closed and + # keep the user/session state intact so the admin can retry. + try: + from core.database import get_db_session, ApiToken + with get_db_session() as db: + removed_tokens = db.query(ApiToken).filter(ApiToken.owner == username).delete() + if removed_tokens: + logger.info( + f"Revoked {removed_tokens} API token(s) owned by deleted user '{username}'" + ) + except Exception: + logger.warning(f"Failed to revoke API tokens for deleted user '{username}'") + return False del self._config["users"][username] self._save() # Purge all sessions belonging to this user. validate_token doesn't @@ -258,18 +274,6 @@ class AuthManager: revoked += 1 if revoked: self._save_sessions() - # Also revoke API bearer tokens owned by this user. The bearer auth - # path authenticates straight against ApiToken rows and never - # re-checks that the owner still exists, so leaving the rows behind - # would let a deleted user keep full API access indefinitely. - try: - from core.database import get_db_session, ApiToken - with get_db_session() as db: - removed = db.query(ApiToken).filter(ApiToken.owner == username).delete() - if removed: - logger.info(f"Revoked {removed} API token(s) owned by deleted user '{username}'") - except Exception: - logger.warning(f"Failed to revoke API tokens for deleted user '{username}'") logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)") return True diff --git a/routes/auth_routes.py b/routes/auth_routes.py index c20860892..853958d35 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -473,7 +473,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: user = _get_current_user(request) if not user or not auth_manager.is_admin(user): raise HTTPException(403, "Admin only") - ok = auth_manager.delete_user(body.username, user) + + def _invalidate_api_token_cache(): + try: + invalidator = getattr(request.app.state, "invalidate_token_cache", None) + if invalidator: + invalidator() + except Exception: + pass + + try: + ok = auth_manager.delete_user(body.username, user) + except Exception: + # delete_user can touch ApiToken rows before a later auth-store write + # fails. Dirty the bearer cache anyway so a partial token purge does + # not leave already-cached tokens authenticating until restart. + _invalidate_api_token_cache() + raise if not ok: raise HTTPException(400, "Cannot delete user") # delete_user removes the user's ApiToken rows, but the bearer-auth @@ -481,12 +497,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: # rebuilds when flagged dirty. Without this, a deleted user's already # cached token keeps authenticating until some other token op or a # restart clears the cache. Mirror what the token routes do. - try: - invalidator = getattr(request.app.state, "invalidate_token_cache", None) - if invalidator: - invalidator() - except Exception: - pass + _invalidate_api_token_cache() return {"ok": True} # ---- Feature visibility (admin-managed) ---- diff --git a/tests/test_auth_config_lock_concurrency.py b/tests/test_auth_config_lock_concurrency.py index f5cc8a18c..34232b9e2 100644 --- a/tests/test_auth_config_lock_concurrency.py +++ b/tests/test_auth_config_lock_concurrency.py @@ -8,6 +8,9 @@ with missing users or assertion errors. import json import threading import time +import contextlib +import sys +import types from concurrent.futures import ThreadPoolExecutor, as_completed import pytest @@ -15,6 +18,41 @@ import pytest from tests.helpers.import_state import clear_module +class _OwnerColumn: + def __eq__(self, other): + return ("owner ==", other) + + +class _FakeApiToken: + owner = _OwnerColumn() + + +class _FakeQuery: + def filter(self, *_conds): + return self + + def delete(self, *args, **kwargs): + return 0 + + +class _FakeSession: + def query(self, model): + assert model is _FakeApiToken + return _FakeQuery() + + +@pytest.fixture(autouse=True) +def _stub_api_token_purge(monkeypatch): + @contextlib.contextmanager + def _fake_db_session(): + yield _FakeSession() + + db_stub = types.ModuleType("core.database") + db_stub.get_db_session = _fake_db_session + db_stub.ApiToken = _FakeApiToken + monkeypatch.setitem(sys.modules, "core.database", db_stub) + + def _fresh_auth_manager(tmp_path): clear_module("core.auth") from core.auth import AuthManager diff --git a/tests/test_delete_user_invalidates_token_cache.py b/tests/test_delete_user_invalidates_token_cache.py index c9cb79a5e..91be50e93 100644 --- a/tests/test_delete_user_invalidates_token_cache.py +++ b/tests/test_delete_user_invalidates_token_cache.py @@ -36,6 +36,17 @@ def _auth_manager(delete_result): ) +def _auth_manager_raising(): + def _delete_user(_username, _requesting_user): + raise RuntimeError("auth save failed after token purge") + + return types.SimpleNamespace( + get_username_for_token=lambda token: "admin", + is_admin=lambda user: True, + delete_user=_delete_user, + ) + + def test_successful_delete_invalidates_cache(): invalidations = [] router = setup_auth_routes(_auth_manager(delete_result=True)) @@ -56,3 +67,16 @@ def test_refused_delete_does_not_invalidate_cache(): raised = True assert raised, "a refused delete should raise (HTTP 400)" assert invalidations == [], "a refused delete must not touch the token cache" + + +def test_delete_exception_invalidates_cache_for_partial_token_purge(): + invalidations = [] + router = setup_auth_routes(_auth_manager_raising()) + handler = _handler(router) + try: + asyncio.run(handler(DeleteUserRequest(username="bob"), _fake_request(invalidations))) + raised = False + except RuntimeError: + raised = True + assert raised, "delete_user exception should still propagate" + assert invalidations == [True], "partial token purge must dirty the bearer cache" diff --git a/tests/test_delete_user_revokes_api_tokens.py b/tests/test_delete_user_revokes_api_tokens.py index dab753ff0..52a7d55af 100644 --- a/tests/test_delete_user_revokes_api_tokens.py +++ b/tests/test_delete_user_revokes_api_tokens.py @@ -114,3 +114,21 @@ def test_refused_delete_leaves_tokens_alone(manager, db_calls): def test_unknown_user_leaves_tokens_alone(manager, db_calls): assert manager.delete_user("ghost", "admin") is False assert db_calls == [] + + +def test_delete_user_fails_closed_when_api_token_purge_fails(manager, monkeypatch): + token = manager.create_session("bob", "secret-bob-pw") + + @contextlib.contextmanager + def _failing_db_session(): + raise RuntimeError("database unavailable") + yield + + db_stub = types.ModuleType("core.database") + db_stub.get_db_session = _failing_db_session + db_stub.ApiToken = _FakeApiToken + monkeypatch.setitem(sys.modules, "core.database", db_stub) + + assert manager.delete_user("bob", "admin") is False + assert "bob" in manager.users + assert manager.validate_token(token) is True From ee6cfbd25a597d4ece1aac09554464e13970ce6e Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Wed, 10 Jun 2026 17:31:26 +0300 Subject: [PATCH 03/18] fix(auth): drop reserved usernames loaded from auth config (#3727) --- app.py | 12 +++- core/auth.py | 41 +++++++++++++- ...test_reserved_username_admin_escalation.py | 56 +++++++++++++++++++ 3 files changed, 106 insertions(+), 3 deletions(-) diff --git a/app.py b/app.py index cfd73e83f..7cec8b0f1 100644 --- a/app.py +++ b/app.py @@ -56,7 +56,7 @@ from core.constants import ( ) from core.database import SessionLocal, ApiToken from core.middleware import SecurityHeadersMiddleware, is_cors_preflight -from core.auth import AuthManager +from core.auth import AuthManager, normalize_known_username from core.exceptions import ( SessionNotFoundError, InvalidFileUploadError, LLMServiceError, WebSearchError, @@ -228,8 +228,16 @@ if AUTH_ENABLED: try: rows = db.query(ApiToken).filter(ApiToken.is_active == True).all() for r in rows: + owner_key = normalize_known_username(auth_manager.users, getattr(r, "owner", None)) + if not owner_key: + logger.warning( + "Ignoring active API token '%s' for unknown auth user '%s'", + getattr(r, "id", ""), + getattr(r, "owner", None), + ) + continue scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()] - new_map[r.token_prefix].append((r.id, r.token_hash, getattr(r, "owner", None), scopes)) + new_map[r.token_prefix].append((r.id, r.token_hash, owner_key, scopes)) finally: db.close() _token_cache.clear() diff --git a/core/auth.py b/core/auth.py index 11f38cd5f..2f9fd4e51 100644 --- a/core/auth.py +++ b/core/auth.py @@ -67,6 +67,14 @@ TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"}) +def normalize_known_username(users: Dict[str, Any], username: str | None) -> Optional[str]: + """Return a normalized username only when it exists in the auth user map.""" + key = str(username or "").strip().lower() + if not key or key not in users: + return None + return key + + def _hash_password(password: str) -> str: return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") @@ -96,6 +104,7 @@ class AuthManager: self._load() self._load_sessions() self._migrate_single_user() + self._drop_reserved_loaded_users() self._migrate_legacy_admin_role() def _load(self): @@ -148,7 +157,13 @@ class AuthManager: def _migrate_single_user(self): """Migrate old single-user format to multi-user format.""" if "password_hash" in self._config and "users" not in self._config: - old_user = self._config.get("username", "admin") + old_user = str(self._config.get("username", "admin") or "admin").strip().lower() + if old_user in RESERVED_USERNAMES: + logger.warning( + "Migrating legacy single-user reserved username '%s' to 'admin'", + old_user, + ) + old_user = "admin" old_hash = self._config["password_hash"] self._config = { "users": { @@ -162,6 +177,30 @@ class AuthManager: self._save() logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})") + def _drop_reserved_loaded_users(self): + """Fail closed for legacy/manual auth rows that collide with sentinels.""" + users = self._config.get("users") + if not isinstance(users, dict): + return + normalized = {} + removed = [] + for username, data in users.items(): + key = str(username or "").strip().lower() + if not key: + continue + if key in RESERVED_USERNAMES: + removed.append(key) + continue + normalized[key] = data + if removed or normalized != users: + self._config["users"] = normalized + self._save() + if removed: + logger.warning( + "Removed reserved username(s) from auth config: %s", + ", ".join(sorted(set(removed))), + ) + def _migrate_legacy_admin_role(self): """Normalize setup.py's old role='admin' marker to is_admin=True.""" changed = False diff --git a/tests/test_reserved_username_admin_escalation.py b/tests/test_reserved_username_admin_escalation.py index 29c423774..fff1aea78 100644 --- a/tests/test_reserved_username_admin_escalation.py +++ b/tests/test_reserved_username_admin_escalation.py @@ -58,6 +58,62 @@ def test_rename_into_reserved_username_is_blocked(tmp_path): assert "bob" in mgr.users +def test_legacy_reserved_username_is_removed_on_load(tmp_path): + auth_path = tmp_path / "auth.json" + auth_path.write_text( + '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}, ' + '"admin": {"password_hash": "unused", "is_admin": true}}}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert "internal-tool" not in mgr.users + assert "admin" in mgr.users + assert "internal-tool" not in auth_path.read_text(encoding="utf-8") + + +def test_legacy_reserved_username_session_cannot_authenticate(tmp_path): + auth_path = tmp_path / "auth.json" + sessions_path = tmp_path / "sessions.json" + auth_path.write_text( + '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}}}', + encoding="utf-8", + ) + sessions_path.write_text( + '{"tok": {"username": "internal-tool", "expiry": 9999999999}}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert mgr.validate_token("tok") is False + assert mgr.get_username_for_token("tok") is None + + +def test_legacy_reserved_single_user_migrates_to_admin(tmp_path): + auth_path = tmp_path / "auth.json" + auth_path.write_text( + '{"username": "internal-tool", "password_hash": "unused"}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert "internal-tool" not in mgr.users + assert "admin" in mgr.users + assert mgr.is_admin("admin") is True + + +def test_token_cache_owner_normalization_requires_current_user(): + clear_module("core.auth") + from core.auth import normalize_known_username + + users = {"alice": {}, "admin": {}} + + assert normalize_known_username(users, " Alice ") == "alice" + assert normalize_known_username(users, "internal-tool") is None + assert normalize_known_username(users, "api") is None + assert normalize_known_username(users, "") is None + + def test_normal_usernames_still_allowed(tmp_path): mgr = _fresh_auth_manager(tmp_path) assert mgr.create_user("alice", "pw-123456") is True From edce6080089c6f2989abf3bafedd685308148c3e Mon Sep 17 00:00:00 2001 From: Maruf Hasan <170166811+MarufHasan-dev@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:50:43 +0600 Subject: [PATCH 04/18] fix(ui): raw SVG markup displayed instead of search icon for web_search tool label (#3601) * fix(ui): escaped SVG renders as raw markup during web_search tool label The _toolLabels['web_search'] entry embedded an SVG HTML string concatenated with label text. At render time the entire value was passed through esc(), HTML-escaping tags so the icon displayed as raw text instead of rendering visually. Fix: separate icon from label text via a _toolIcons map. The SVG is injected as raw innerHTML (unescaped) in .agent-thread-icon, while the label text remains safely escaped. * test: add behavioral test for web_search tool icon rendering Co-authored-by: TheDragonTail --------- Co-authored-by: TheDragonTail --- static/js/chat.js | 8 +- tests/test_web_search_tool_icon_js.py | 119 ++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 tests/test_web_search_tool_icon_js.py diff --git a/static/js/chat.js b/static/js/chat.js index 60149d005..7ecefdb7d 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -1082,7 +1082,7 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer let _lastToolName = ''; const _searchIcon = ''; const _toolLabels = { - 'web_search': _searchIcon + 'Searching', + 'web_search': 'Searching', 'bash': 'Running', 'python': 'Running', 'create_document': 'Writing', @@ -1102,6 +1102,9 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer 'list_models': 'Browsing', 'ui_control': 'Adjusting', }; + const _toolIcons = { + 'web_search': _searchIcon, + }; function _thinkingLabel() { if (!_lastToolName) { return 'Thinking'; @@ -2049,10 +2052,11 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer } threadWrap.classList.add('streaming'); const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool; + const toolIcon = _toolIcons[json.tool.toLowerCase()] || '\u25B6'; const node = document.createElement('div') node.className = 'agent-thread-node running'; const cmdHtml = cmd ? `
${esc(cmd)}
` : ''; - node.innerHTML = `
\u25B6${esc(toolLabel)}▁▂▃
${cmdHtml}
`; + node.innerHTML = `
${toolIcon}${esc(toolLabel)}▁▂▃
${cmdHtml}
`; // Expand/collapse via delegated click handler (init at module bottom). threadWrap.appendChild(node); currentToolBubble = node; diff --git a/tests/test_web_search_tool_icon_js.py b/tests/test_web_search_tool_icon_js.py new file mode 100644 index 000000000..6e855df40 --- /dev/null +++ b/tests/test_web_search_tool_icon_js.py @@ -0,0 +1,119 @@ +"""Pin the web_search tool-icon rendering in the agent thread (PR #??). + +Verifies: +- web_search renders an icon instead of raw markup +- Other tools get the default ▶ icon +- Hostile tool names are HTML-escaped in the label + +Pure JS via node --input-type=module (same approach as +test_composer_arrow_up_recall_js.py). Skips when node is not installed. +""" + +import json +import shutil +import subprocess +from pathlib import Path + +import pytest + +_REPO = Path(__file__).resolve().parent.parent +_HAS_NODE = shutil.which("node") is not None + +_CHECK_JS = r""" +function esc(s) { + const map = { '&': '&', '<': '<', '>': '>', '"': '"', "'": ''' }; + return (s || '').replace(/[&<>"']/g, (m) => map[m]); +} + +const _searchIcon = ''; + +const _toolLabels = { + web_search: 'Searching', + bash: 'Running', +}; + +const _toolIcons = { + web_search: _searchIcon, +}; + +function renderIcon(toolName) { + return _toolIcons[toolName.toLowerCase()] || '\u25B6'; +} + +function renderLabel(toolName) { + return _toolLabels[toolName.toLowerCase()] || toolName; +} + +function renderThreadHTML(toolName, cmd) { + const label = renderLabel(toolName); + const icon = renderIcon(toolName); + const cmdHtml = cmd ? `
${esc(cmd)}
` : ''; + return `
${icon}${esc(label)}\u2581\u2582\u2583
${cmdHtml}
`; +} + +const cases = CASES_JSON; +const results = cases.map(c => { + const html = renderThreadHTML(c.tool, c.cmd || ''); + return { tool: c.tool, html }; +}); +console.log(JSON.stringify(results)); +""" + + +def _run(cases: list) -> list: + js = _CHECK_JS.replace("CASES_JSON", json.dumps(cases)) + proc = subprocess.run( + ["node", "--input-type=module"], + input=js, + capture_output=True, + text=True, + encoding="utf-8", + cwd=str(_REPO), + timeout=30, + ) + assert proc.returncode == 0, proc.stderr + return json.loads(proc.stdout.strip()) + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_web_search_icon_contains_svg(): + out = _run([{"tool": "web_search"}])[0] + assert " in agent-thread-icon for web_search" + assert "Searching" in out["html"], "Expected 'Searching' label for web_search" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_default_tool_icon_is_triangle(): + out = _run([{"tool": "bash"}])[0] + assert "▶" in out["html"], "Expected ▶ icon for tools without custom icon" + assert " for bash" + assert "Running" in out["html"], "Expected 'Running' label for bash" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_unknown_tool_falls_back_to_name(): + out = _run([{"tool": "my_custom_tool"}])[0] + assert "▶" in out["html"], "Expected ▶ for unknown tool" + assert "my_custom_tool" in out["html"], "Expected tool name as label" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_hostile_tool_name_is_escaped(): + out = _run([{"tool": ''}])[0] + assert "<img" in out["html"], "Expected < to be HTML-escaped" + assert ">" in out["html"], "Expected > to be HTML-escaped" + assert " must not appear" + assert "onerror" not in out["html"] or """ in out["html"], "onerror must not be executable" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_unknown_tool_case_insensitive_matches_icons(): + out = _run([{"tool": "WEB_SEARCH"}, {"tool": "Web_Search"}]) + for r in out: + assert " Date: Wed, 10 Jun 2026 20:33:01 +0530 Subject: [PATCH 05/18] fix(db): close sqlite migration connections on exception paths (#3600) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The _migrate_* startup helpers in core/database.py opened a raw sqlite3.connect() inside a try and called conn.close() as the last statement in that try. If any earlier statement raised (locked DB, unexpected schema, a failed ALTER), close() was skipped and the bare except only logged the error — leaking the connection (file handle + lock) for the lifetime of the process. These migrations run on every startup. Wrap each in the conn = None + try/except/finally pattern already used by _migrate_chat_messages_fts in this same file, so the connection is closed on all exit paths. 25 functions; no change on the success path. Helpers that already close safely are left untouched: _migrate_chat_messages_fts and _migrate_backfill_task_folders (the latter uses SQLAlchemy's engine.connect() context manager). Same bug class as the previously merged DB-connection-leak fix (#64) and the IMAP logout-on-all-paths fix (#1530). --- core/database.py | 175 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 150 insertions(+), 25 deletions(-) diff --git a/core/database.py b/core/database.py index ee365c30c..6eec48d11 100644 --- a/core/database.py +++ b/core/database.py @@ -688,6 +688,7 @@ def _migrate_add_last_message_at_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -713,10 +714,14 @@ def _migrate_add_last_message_at_column(): "ON sessions(archived, last_message_at)" ) conn.commit() - conn.close() logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions") except Exception as e: logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_document_archived_column(): """Add `archived` to documents (soft-archive flag). Guarded + idempotent.""" @@ -724,6 +729,7 @@ def _migrate_add_document_archived_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(documents)") @@ -732,9 +738,13 @@ def _migrate_add_document_archived_column(): conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0") conn.commit() logging.getLogger(__name__).info("Migrated: added 'archived' to documents") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_owner_column(): @@ -743,6 +753,7 @@ def _migrate_add_owner_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -752,9 +763,13 @@ def _migrate_add_owner_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_model_endpoints(): """Recreate model_endpoints table if schema changed (url->base_url).""" @@ -762,6 +777,7 @@ def _migrate_model_endpoints(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -770,9 +786,13 @@ def _migrate_model_endpoints(): conn.execute("DROP TABLE IF EXISTS model_endpoints") conn.commit() logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_hidden_models_column(): """Add hidden_models column to model_endpoints if it doesn't exist.""" @@ -780,6 +800,7 @@ def _migrate_add_hidden_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -788,9 +809,13 @@ def _migrate_add_hidden_models_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_endpoint_owner_column(): """Add owner column to model_endpoints if it doesn't exist. @@ -805,6 +830,7 @@ def _migrate_add_model_endpoint_owner_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -814,9 +840,13 @@ def _migrate_add_model_endpoint_owner_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_provider_auth_id_column(): @@ -825,6 +855,7 @@ def _migrate_add_provider_auth_id_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -834,9 +865,13 @@ def _migrate_add_provider_auth_id_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_type_column(): @@ -845,6 +880,7 @@ def _migrate_add_model_type_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -853,9 +889,13 @@ def _migrate_add_model_type_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'") conn.commit() logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_type migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_endpoint_refresh_columns(): """Add endpoint classification / refresh policy columns if missing.""" @@ -863,6 +903,7 @@ def _migrate_add_model_endpoint_refresh_columns(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -876,9 +917,13 @@ def _migrate_add_model_endpoint_refresh_columns(): if columns and "model_refresh_timeout" not in columns: conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_task_run_model_column(): """Add model column to task_runs if it doesn't exist (records which model ran).""" @@ -886,6 +931,7 @@ def _migrate_add_task_run_model_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(task_runs)") @@ -894,9 +940,13 @@ def _migrate_add_task_run_model_column(): conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_supports_tools_column(): """Add supports_tools column to model_endpoints if it doesn't exist.""" @@ -904,6 +954,7 @@ def _migrate_add_supports_tools_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -912,9 +963,13 @@ def _migrate_add_supports_tools_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN") conn.commit() logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_cached_models_column(): @@ -923,6 +978,7 @@ def _migrate_add_cached_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -930,9 +986,13 @@ def _migrate_add_cached_models_column(): if columns and "cached_models" not in columns: conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"cached_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_pinned_models_column(): """Add pinned_models column to model_endpoints if it doesn't exist.""" @@ -940,6 +1000,7 @@ def _migrate_add_pinned_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -948,9 +1009,13 @@ def _migrate_add_pinned_models_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_notes_sort_order(): """Add sort_order, image_url, repeat columns to notes if they don't exist.""" @@ -958,6 +1023,7 @@ def _migrate_add_notes_sort_order(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(notes)") @@ -975,9 +1041,13 @@ def _migrate_add_notes_sort_order(): if columns and "agent_session_id" not in columns: conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"notes migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_mode_column(): """Add mode column to sessions table if it doesn't exist.""" @@ -985,6 +1055,7 @@ def _migrate_add_mode_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -993,9 +1064,13 @@ def _migrate_add_mode_column(): conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_folder_column(): """Add folder column to sessions table if it doesn't exist.""" @@ -1003,6 +1078,7 @@ def _migrate_add_folder_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -1011,9 +1087,13 @@ def _migrate_add_folder_column(): conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_token_columns(): """Add cumulative token tracking columns to sessions table.""" @@ -1021,6 +1101,7 @@ def _migrate_add_token_columns(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -1030,9 +1111,13 @@ def _migrate_add_token_columns(): conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0") conn.commit() logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_owner_to_table(table_name: str, index_name: str): """Generic helper: add owner TEXT column + index to a table if missing.""" @@ -1040,6 +1125,7 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute(f"PRAGMA table_info({table_name})") @@ -1049,9 +1135,13 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str): conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)") conn.commit() logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_multiuser_owner_columns(): """Add owner column to memories, gallery_images, user_tools, comparisons.""" @@ -1076,6 +1166,7 @@ def _migrate_add_api_token_scopes_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()] @@ -1084,9 +1175,13 @@ def _migrate_add_api_token_scopes_column(): conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''") conn.commit() logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_assign_legacy_owner(): """Assign all null-owner data to the first (admin) user. @@ -1128,6 +1223,7 @@ def _migrate_assign_legacy_owner(): return logger = logging.getLogger(__name__) + conn = None try: conn = sqlite3.connect(db_path) # Every table with an `owner` column. New tables added later will be @@ -1152,9 +1248,13 @@ def _migrate_assign_legacy_owner(): except Exception as e: logger.warning(f"Legacy owner assignment for {table} failed: {e}") conn.commit() - conn.close() except Exception as e: logger.warning(f"Legacy owner migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass # Also migrate memory.json mem_path = MEMORY_FILE @@ -1773,6 +1873,7 @@ def _migrate_add_email_smtp_security(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(email_accounts)") @@ -1788,9 +1889,13 @@ def _migrate_add_email_smtp_security(): ) conn.commit() logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_encrypt_endpoint_keys(): @@ -1891,6 +1996,7 @@ def _migrate_add_calendar_is_utc(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1899,9 +2005,13 @@ def _migrate_add_calendar_is_utc(): conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL") conn.commit() logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"is_utc migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_origin(): @@ -1912,6 +2022,7 @@ def _migrate_add_calendar_origin(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1921,9 +2032,13 @@ def _migrate_add_calendar_origin(): conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_account_id(): @@ -1933,6 +2048,7 @@ def _migrate_add_calendar_account_id(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendars)") @@ -1942,9 +2058,13 @@ def _migrate_add_calendar_account_id(): conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_metadata(): @@ -1953,6 +2073,7 @@ def _migrate_add_calendar_metadata(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1964,9 +2085,13 @@ def _migrate_add_calendar_metadata(): if columns and "last_pinged" not in columns: conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def get_db(): """ From 6f73c8afaa1641c7a8db399b39b45cf0b9b671b8 Mon Sep 17 00:00:00 2001 From: Ashvin <76151462+ashvinctrl@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:37:07 +0530 Subject: [PATCH 06/18] fix(sessions): use owner_filter for list_sessions queries when auth disabled (#3622) Direct DbSession.owner == user becomes WHERE owner IS NULL when user is None (auth disabled), hiding all sessions that carry an explicit owner. Same flaw on the Document and GalleryImage sub-queries (active-doc and gallery badges). Replace all three with owner_filter(), which is a no-op when user is falsy. Fixes #3620 --- routes/session_routes.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/routes/session_routes.py b/routes/session_routes.py index 811a40bbe..1fb2a487a 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -11,7 +11,7 @@ from core.session_manager import SessionManager from core.models import ChatMessage from src.request_models import SessionResponse from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive -from src.auth_helpers import get_current_user, effective_user, _auth_disabled +from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter from src.session_actions import is_session_recently_active @@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ last_msg_map = {} mode_map = {} msg_count_map = {} - rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all() + q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False) + q = owner_filter(q, DbSession, user) + rows = q.all() for row in rows: folder_map[row.id] = row.folder token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0) @@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ # Sessions with active documents that have content from sqlalchemy import func doc_session_ids = set( - r[0] for r in db.query(Document.session_id) - .filter(Document.is_active == True, - Document.current_content != None, - func.trim(Document.current_content) != "", - Document.owner == user) + r[0] for r in owner_filter( + db.query(Document.session_id) + .filter(Document.is_active == True, + Document.current_content != None, + func.trim(Document.current_content) != ""), + Document, user) .distinct().all() ) img_session_ids = set( - r[0] for r in db.query(GalleryImage.session_id) - .filter(GalleryImage.session_id != None, - GalleryImage.owner == user) + r[0] for r in owner_filter( + db.query(GalleryImage.session_id) + .filter(GalleryImage.session_id != None), + GalleryImage, user) .distinct().all() ) finally: From 9c8df899734b267dd13430f213a6f54700315b47 Mon Sep 17 00:00:00 2001 From: Ashvin <76151462+ashvinctrl@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:50:36 +0530 Subject: [PATCH 07/18] fix(auth): case-insensitive skill owner match on rename (#3614) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SKILL.md files written with mixed-case owner (e.g. 'owner: Alice') were skipped because the regex had no IGNORECASE flag. _usage.json keys like 'Alice::skill-name' were missed by the startswith prefix check for the same reason. Both comparisons now match the same way the deep_research and memory blocks do — case-insensitively against old_username. Fixes #3611 --- routes/auth_routes.py | 9 ++++---- tests/test_rename_user_owner_sync.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/routes/auth_routes.py b/routes/auth_routes.py index 853958d35..6e0ae8a5a 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -391,7 +391,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: skills_root = Path(SKILLS_DIR) if skills_root.is_dir(): _owner_re = re.compile( - r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$' + r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$', + re.IGNORECASE, ) for p in skills_root.rglob("SKILL.md"): try: @@ -406,12 +407,12 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: try: usage = json.loads(usage_path.read_text(encoding="utf-8")) if isinstance(usage, dict): - prefix = old_username + "::" new_usage = {} changed = False for k, v in usage.items(): - if k.startswith(prefix): - new_usage[new_username + "::" + k[len(prefix):]] = v + owner_part, sep, skill_part = k.partition("::") + if sep and owner_part.lower() == old_username: + new_usage[new_username + "::" + skill_part] = v changed = True else: new_usage[k] = v diff --git a/tests/test_rename_user_owner_sync.py b/tests/test_rename_user_owner_sync.py index 16d91c512..1de14f31a 100644 --- a/tests/test_rename_user_owner_sync.py +++ b/tests/test_rename_user_owner_sync.py @@ -333,6 +333,37 @@ def test_rename_no_skills_dir_does_not_crash(rename_endpoint): assert res["ok"] is True +def test_rename_skill_md_owner_case_insensitive(rename_endpoint): + """SKILL.md written with owner: Alice (mixed case) must be updated when + renaming alice — the regex was missing re.IGNORECASE.""" + endpoint, _am, tmp_path = rename_endpoint + + skill_dir = tmp_path / "skills" / "general" / "s" + skill_dir.mkdir(parents=True) + (skill_dir / "SKILL.md").write_text(_SKILL_MD.format(owner="Alice"), encoding="utf-8") + + asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path))) + + assert "owner: alice2" in (skill_dir / "SKILL.md").read_text(encoding="utf-8") + + +def test_rename_usage_keys_case_insensitive(rename_endpoint): + """_usage.json keys stored as Alice::skill-name must be migrated when + renaming alice — the old startswith check was not lowercasing.""" + endpoint, _am, tmp_path = rename_endpoint + + skills_root = tmp_path / "skills" + skills_root.mkdir(parents=True) + usage = {"Alice::my-skill": {"uses": 5, "last_used": 999}} + (skills_root / "_usage.json").write_text(json.dumps(usage), encoding="utf-8") + + asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path))) + + updated = json.loads((skills_root / "_usage.json").read_text(encoding="utf-8")) + assert "alice2::my-skill" in updated + assert "Alice::my-skill" not in updated + + # --------------------------------------------------------------------------- # 5. P1 regression: rejected auth rename must not mutate file-backed stores # --------------------------------------------------------------------------- From 800d391234b739fb55e4008b7ec99ab0210ef10e Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:28:27 +0300 Subject: [PATCH 08/18] fix(auth): roll back rename on owner migration failure (#3616) --- routes/auth_routes.py | 18 ++++ tests/test_rename_user_owner_sync.py | 118 ++++++++++++++++++++++++++- 2 files changed, 133 insertions(+), 3 deletions(-) diff --git a/routes/auth_routes.py b/routes/auth_routes.py index 6e0ae8a5a..e67a4758f 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -305,6 +305,19 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: if not ok: raise HTTPException(400, "Cannot rename user") + def _rollback_auth_rename() -> bool: + # On self-rename the admin session has already moved to the new + # username, so the rollback must authenticate as the new user. + rollback_user = new_username if user == old_username else user + try: + return bool(auth_manager.rename_user(new_username, old_username, rollback_user)) + except Exception as rollback_err: + logger.error( + "Failed to roll back auth rename %s -> %s after owner migration failure: %s", + new_username, old_username, rollback_err, + ) + return False + # Usernames are ownership keys for user data. Rename the common # owner-scoped DB rows so the account keeps access to its sessions, # docs, email accounts, tasks, etc. @@ -330,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: db.close() except Exception as e: logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e) + if not _rollback_auth_rename(): + logger.error( + "Auth rename %s -> %s could not be rolled back after owner migration failure", + old_username, new_username, + ) raise HTTPException(500, "Failed to rename user data") # Per-user prefs are JSON-backed, not SQL-backed. diff --git a/tests/test_rename_user_owner_sync.py b/tests/test_rename_user_owner_sync.py index 1de14f31a..24e1fb67c 100644 --- a/tests/test_rename_user_owner_sync.py +++ b/tests/test_rename_user_owner_sync.py @@ -26,6 +26,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from fastapi import HTTPException def _route(router, name): @@ -63,18 +64,68 @@ def rename_endpoint(monkeypatch, tmp_path): return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path -def _request(tmp_path, session_manager=None): +def _request(tmp_path, session_manager=None, token="t"): state = SimpleNamespace( invalidate_token_cache=lambda: None, session_manager=session_manager, ) return SimpleNamespace( - cookies={"odysseus_session": "t"}, + cookies={"odysseus_session": token}, app=SimpleNamespace(state=state), state=SimpleNamespace(current_user="admin"), ) +def _auth_manager_for_rollback_test(monkeypatch, tmp_path): + import core.auth as auth_mod + + monkeypatch.setattr(auth_mod, "_hash_password", lambda password: f"hash:{password}") + monkeypatch.setattr(auth_mod, "_verify_password", lambda password, hashed: hashed == f"hash:{password}") + + am = auth_mod.AuthManager(str(tmp_path / "auth.json")) + assert am.create_user("admin", "pw-123456", is_admin=True) is True + assert am.create_user("alice", "pw-123456") is True + return am + + +def _force_sql_owner_migration_failure(monkeypatch): + import core.database as cdb + + class OwnerModel: + owner = "owner" + + class FailingQuery: + def filter(self, *_args, **_kwargs): + return self + + def update(self, *_args, **_kwargs): + raise RuntimeError("forced owner migration failure") + + class FailingSession: + def __init__(self): + self.rolled_back = False + self.closed = False + + def query(self, _model): + return FailingQuery() + + def rollback(self): + self.rolled_back = True + + def close(self): + self.closed = True + + db = FailingSession() + monkeypatch.setattr(cdb, "SessionLocal", lambda: db) + monkeypatch.setattr( + cdb, + "Base", + SimpleNamespace(registry=SimpleNamespace(mappers=[SimpleNamespace(class_=OwnerModel)])), + raising=False, + ) + return db + + # --------------------------------------------------------------------------- # 1. In-memory session cache # --------------------------------------------------------------------------- @@ -365,7 +416,68 @@ def test_rename_usage_keys_case_insensitive(rename_endpoint): # --------------------------------------------------------------------------- -# 5. P1 regression: rejected auth rename must not mutate file-backed stores +# 5. Rollback: auth rename must be restored if SQL owner migration fails +# --------------------------------------------------------------------------- + +def test_owner_migration_failure_rolls_back_auth_rename(monkeypatch, tmp_path): + import routes.auth_routes as ar + + db = _force_sql_owner_migration_failure(monkeypatch) + am = _auth_manager_for_rollback_test(monkeypatch, tmp_path) + admin_token = am.create_session_trusted("admin") + alice_token = am.create_session_trusted("alice") + endpoint = _route(ar.setup_auth_routes(am), "rename_user") + + with pytest.raises(HTTPException) as exc: + asyncio.run( + endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, token=admin_token), + ) + ) + + assert exc.value.status_code == 500 + assert db.rolled_back is True + assert db.closed is True + assert "alice" in am.users + assert "alice2" not in am.users + assert am.get_username_for_token(alice_token) == "alice" + saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"] + assert "alice" in saved_users + assert "alice2" not in saved_users + + +def test_self_rename_owner_migration_failure_rolls_back_auth_session(monkeypatch, tmp_path): + import routes.auth_routes as ar + + db = _force_sql_owner_migration_failure(monkeypatch) + am = _auth_manager_for_rollback_test(monkeypatch, tmp_path) + admin_token = am.create_session_trusted("admin") + endpoint = _route(ar.setup_auth_routes(am), "rename_user") + + with pytest.raises(HTTPException) as exc: + asyncio.run( + endpoint( + "admin", + SimpleNamespace(username="chief"), + _request(tmp_path, token=admin_token), + ) + ) + + assert exc.value.status_code == 500 + assert db.rolled_back is True + assert db.closed is True + assert "admin" in am.users + assert "chief" not in am.users + assert am.get_username_for_token(admin_token) == "admin" + saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"] + assert "admin" in saved_users + assert "chief" not in saved_users + + +# --------------------------------------------------------------------------- +# 6. P1 regression: rejected auth rename must not mutate file-backed stores # --------------------------------------------------------------------------- def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path): From 4e210d333780f555d0de44e6ac5c21ecfed2f070 Mon Sep 17 00:00:00 2001 From: Mazen Tamer Salah <78306991+mazen-salah@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:40:44 +0300 Subject: [PATCH 09/18] fix(research): stop rescanning the research dir on every status poll (#3637) get_status() called get_avg_duration() unconditionally, and that helper globs and JSON-parses every file under the research data dir. The SSE status stream polls get_status() roughly once a second, so with a few saved reports each poll re-read and re-parsed all of them, including for sessions that are not active (the disk branch never even used the value). Compute avg_duration only for active sessions and memoize it on the task entry, so a long stream computes it once instead of on every poll. Behaviour is unchanged: active streams still report avg_duration. Adds tests/test_research_status_avg_duration.py: an inactive session does no avg scan, and an active session computes it once across many polls. --- src/research_handler.py | 9 ++++- tests/test_research_status_avg_duration.py | 41 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 tests/test_research_status_avg_duration.py diff --git a/src/research_handler.py b/src/research_handler.py index b996f089f..b3af3b8e5 100644 --- a/src/research_handler.py +++ b/src/research_handler.py @@ -390,7 +390,6 @@ class ResearchHandler: def get_status(self, session_id: str) -> Optional[dict]: """Get current research status for a session.""" - avg = self.get_avg_duration() if session_id in self._active_tasks: entry = self._active_tasks[session_id] result = { @@ -399,6 +398,14 @@ class ResearchHandler: "query": entry["query"], "started_at": entry["started_at"], } + # avg_duration is a historical figure over completed reports on + # disk; get_avg_duration() globs and JSON-parses the whole research + # dir, so compute it at most once per active stream (memoized on the + # entry) instead of on every ~1s SSE poll. The disk branch below + # never used it, so it no longer pays that cost at all. + if "_avg_duration" not in entry: + entry["_avg_duration"] = self.get_avg_duration() + avg = entry["_avg_duration"] if avg is not None: result["avg_duration"] = round(avg, 1) return result diff --git a/tests/test_research_status_avg_duration.py b/tests/test_research_status_avg_duration.py new file mode 100644 index 000000000..d44c63242 --- /dev/null +++ b/tests/test_research_status_avg_duration.py @@ -0,0 +1,41 @@ +"""get_status must not rescan the whole research dir on every SSE poll. + +get_avg_duration() globs and JSON-parses every file under the research data dir. +get_status() called it unconditionally on each poll, including for sessions that +are not active (the common case while a client polls a finished report). It is +now computed only for active sessions and memoized on the entry. +""" +from src.research_handler import ResearchHandler + + +def _handler(): + h = ResearchHandler.__new__(ResearchHandler) + h._active_tasks = {} + return h + + +def test_inactive_session_does_not_compute_avg(monkeypatch): + h = _handler() + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 5.0)[1]) + # Unknown session, no disk file -> None, and no expensive avg scan. + assert h.get_status("missing-session") is None + assert calls == [] + + +def test_active_session_memoizes_avg(monkeypatch): + h = _handler() + h._active_tasks["s1"] = { + "status": "running", "progress": {}, "query": "q", "started_at": 0, + } + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 12.0)[1]) + + r1 = h.get_status("s1") + r2 = h.get_status("s1") + r3 = h.get_status("s1") + + assert r1["avg_duration"] == 12.0 + assert r2["avg_duration"] == 12.0 and r3["avg_duration"] == 12.0 + # Computed once across many polls, not once per poll. + assert len(calls) == 1 From 96975f8dd974d7ea001e47c9c745e5b83865173d Mon Sep 17 00:00:00 2001 From: Mazen Tamer Salah <78306991+mazen-salah@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:50:22 +0300 Subject: [PATCH 10/18] fix(contacts): tolerate non-string body in /api/contacts/import (#3638) import_vcf built `text = data.get("vcf") or data.get("text") or ""`, so a non-string JSON value (a number, list, etc.) stayed in place and the following `text.strip()` raised AttributeError, returning HTTP 500. Coerce vcf/text/csv with str() so non-string input degrades to the existing structured "no data" response, matching the file's convention elsewhere. Adds tests/test_contacts_import_nonstring.py covering non-string vcf, non-string csv, and an empty body. --- routes/contacts_routes.py | 7 +++-- tests/test_contacts_import_nonstring.py | 39 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 tests/test_contacts_import_nonstring.py diff --git a/routes/contacts_routes.py b/routes/contacts_routes.py index e4e8ce759..58a57a1e1 100644 --- a/routes/contacts_routes.py +++ b/routes/contacts_routes.py @@ -729,8 +729,11 @@ def setup_contacts_routes(): @router.post("/import") async def import_vcf(data: dict, _admin: str = Depends(require_admin)): """Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}.""" - text = data.get("vcf") or data.get("text") or "" - csv_text = data.get("csv") or "" + # Coerce defensively: a non-string vcf/text/csv (e.g. a number or list + # in the JSON body) would otherwise reach .strip() and 500 with an + # AttributeError instead of degrading to a clean "no data" response. + text = str(data.get("vcf") or data.get("text") or "") + csv_text = str(data.get("csv") or "") if text.strip(): if "BEGIN:VCARD" not in text.upper(): return {"success": False, "error": "No vCard data found"} diff --git a/tests/test_contacts_import_nonstring.py b/tests/test_contacts_import_nonstring.py new file mode 100644 index 000000000..c029b569d --- /dev/null +++ b/tests/test_contacts_import_nonstring.py @@ -0,0 +1,39 @@ +"""POST /api/contacts/import must not 500 on a non-string vcf/text/csv value. + +`text = data.get("vcf") or ... or ""` left a non-string value (e.g. a number) +in place, so the next `text.strip()` raised AttributeError -> HTTP 500. The +handler now coerces with str() and degrades to a structured "no data" response. +""" +import asyncio + +from routes.contacts_routes import setup_contacts_routes + + +def _import_handler(): + router = setup_contacts_routes() + for route in router.routes: + if getattr(route, "path", "").endswith("/import") and "POST" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("import route not found") + + +def _call(data): + handler = _import_handler() + return asyncio.run(handler(data=data, _admin="admin")) + + +def test_non_string_vcf_degrades_cleanly(): + resp = _call({"vcf": 123}) + assert resp["success"] is False + assert "error" in resp + + +def test_non_string_csv_degrades_cleanly(): + resp = _call({"csv": ["a", "b"]}) + assert resp["success"] is False + + +def test_empty_body_reports_no_data(): + resp = _call({}) + assert resp["success"] is False + assert resp["error"] == "No contact data found" From a0b0420e6fef6982a2946e7b181e37adb25dcca5 Mon Sep 17 00:00:00 2001 From: ThomasAngel <30532050+rekterakathom@users.noreply.github.com> Date: Wed, 10 Jun 2026 18:59:47 +0300 Subject: [PATCH 11/18] chore: Switch duckduckgo-search to ddgs (#3143) * Switch to ddgs duckduckgo_search was deprecated, this is the recommended replacement * Update test_service_search_provider_guards.py According to review comment --- README.md | 2 +- requirements-optional.txt | 2 +- services/search/providers.py | 2 +- tests/test_service_search_provider_guards.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a320f0052..a0dde96a9 100644 --- a/README.md +++ b/README.md @@ -329,7 +329,7 @@ To expose Odysseus on a local network or Tailscale with HTTPS: | Package | Feature unlocked | |---------|-----------------| | `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. | -| `duckduckgo-search` | DuckDuckGo as a search provider option. | +| `ddgs` | DuckDuckGo as a search provider option. | | `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) | | `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). | diff --git a/requirements-optional.txt b/requirements-optional.txt index eeb57c151..b4b654232 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -15,7 +15,7 @@ faster-whisper # DuckDuckGo as a search provider option. # Install if you want DDG in the search-provider dropdown. # Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE. -duckduckgo-search +ddgs # PDF form-filling feature (fillable AcroForm detection, field extraction, # value/annotation/signature stamping, page rendering for the form overlay). diff --git a/services/search/providers.py b/services/search/providers.py index 1f8097ad8..b913e1c6f 100644 --- a/services/search/providers.py +++ b/services/search/providers.py @@ -417,7 +417,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti return [] try: - from duckduckgo_search import DDGS + from ddgs import DDGS except ImportError: logger.warning("duckduckgo-search package not installed; using HTML fallback") return _html_fallback() diff --git a/tests/test_service_search_provider_guards.py b/tests/test_service_search_provider_guards.py index 373928e64..cb9171a54 100644 --- a/tests/test_service_search_provider_guards.py +++ b/tests/test_service_search_provider_guards.py @@ -90,8 +90,8 @@ def test_service_ddg_html_fallback_sends_safesearch(monkeypatch): seen["params"] = kwargs["params"] return _Response() - monkeypatch.setitem(sys.modules, "duckduckgo_search", None) monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"}) + monkeypatch.setitem(sys.modules, "ddgs", None) monkeypatch.setattr(providers.httpx, "get", fake_get) results = providers.duckduckgo_search("odysseus", count=1) From 8bf821284671d184897d4d39b5321e0f75a806ee Mon Sep 17 00:00:00 2001 From: Max Hsu Date: Thu, 11 Jun 2026 00:29:22 +0800 Subject: [PATCH 12/18] fix(chat): copy only the displayed reply from the message copy buttons (#3731) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AI-message copy buttons copied dataset.raw, which is the full accumulated model output — still containing the reasoning block and any tool-call markup that the renderer strips for display. Pasting therefore leaked the model's thinking, and the first heading after lost its markdown formatting because it was glued to the closing tag. Add chatRenderer.copyMessageText(), which mirrors the display pipeline (stripToolBlocks then extractThinkingBlocks) and falls back to the raw text when stripping leaves nothing (thinking-only turns), and route both copy handlers — the message footer and the slash-reply footer — through it. The interrupted-turn Continue flow intentionally keeps reading dataset.raw. Fixes #3722 Co-authored-by: Claude Fable 5 --- static/js/chatRenderer.js | 17 +- static/js/slashCommands.js | 2 +- tests/test_copy_message_strips_thinking_js.py | 160 ++++++++++++++++++ 3 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 tests/test_copy_message_strips_thinking_js.py diff --git a/static/js/chatRenderer.js b/static/js/chatRenderer.js index 9a5c6f78b..7c6ecd096 100644 --- a/static/js/chatRenderer.js +++ b/static/js/chatRenderer.js @@ -862,6 +862,20 @@ export function stripToolBlocks(text) { return cleaned.trim(); } +/** + * Plain-text payload for the message copy buttons: the reply as the renderer + * displays it — tool blocks and reasoning stripped. dataset.raw keeps + * the full model output (chat.js even embeds the elapsed time into the + * tag for reload persistence), so copying it verbatim leaks the + * thinking block (#3722). Falls back to the raw text when stripping leaves + * nothing (e.g. turns interrupted mid-thinking). + */ +export function copyMessageText(msgElement) { + const raw = msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || ''; + const { content } = markdownModule.extractThinkingBlocks(stripToolBlocks(raw)); + return content || raw; +} + /** * Build a collapsible sources box (used by both research and web search). */ @@ -1372,7 +1386,7 @@ export function createMsgFooter(msgElement) { { id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) { e.stopPropagation(); const btn = e.currentTarget; - uiModule.copyToClipboard(msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || ''); + uiModule.copyToClipboard(copyMessageText(msgElement)); btn.innerHTML = CHECK_ICON; setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500); }}, @@ -2444,6 +2458,7 @@ const chatRenderer = { updateSessionCostUI, roleTimestamp, stripToolBlocks, + copyMessageText, safeToolScreenshotSrc, safeDisplayImageSrc, buildSourcesBox, diff --git a/static/js/slashCommands.js b/static/js/slashCommands.js index 6a32cb89e..79b037cf4 100644 --- a/static/js/slashCommands.js +++ b/static/js/slashCommands.js @@ -380,7 +380,7 @@ function _slashFooter(msgEl) { copyBtn.innerHTML = _copySvg; copyBtn.onclick = (e) => { e.stopPropagation(); - uiModule.copyToClipboard(msgEl.dataset.raw || msgEl.querySelector('.body')?.textContent || ''); + uiModule.copyToClipboard(chatRenderer.copyMessageText(msgEl)); copyBtn.innerHTML = _checkSvg; setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500); }; diff --git a/tests/test_copy_message_strips_thinking_js.py b/tests/test_copy_message_strips_thinking_js.py new file mode 100644 index 000000000..4c88bb6d4 --- /dev/null +++ b/tests/test_copy_message_strips_thinking_js.py @@ -0,0 +1,160 @@ +"""Regression coverage for issue #3722 — the message copy button copied the +full raw model output (``dataset.raw``), which still contains the +``...`` reasoning block that the renderer strips for +display. Pasting therefore leaked the model's thinking, and the first heading +after ```` lost its markdown formatting because it was glued to the +closing tag. + +The fix adds chatRenderer.copyMessageText(), which mirrors the display +pipeline (``stripToolBlocks()`` then ``extractThinkingBlocks()``), and routes +both AI-message copy buttons (createMsgFooter and the slash-reply footer) +through it. extractThinkingBlocks() behavior is pinned here under node +(including on the payload from the issue report); the helper and handler +wiring are guarded at the source level because chatRenderer.js pulls in +browser globals and can't be imported under node (same approach as +test_new_chat_clears_input.py). +""" + +import json +import re +import shutil +import subprocess +import textwrap +from pathlib import Path + +import pytest + +_REPO = Path(__file__).resolve().parent.parent +_HAS_NODE = shutil.which("node") is not None + + +@pytest.fixture(scope="module") +def node_available(): + if not _HAS_NODE: + pytest.skip("node binary not on PATH") + + +def _extract_thinking_blocks(text: str) -> dict: + """Run markdown.js extractThinkingBlocks(text) under node.""" + script = textwrap.dedent( + r""" + import fs from 'node:fs'; + + globalThis.window = { location: { origin: 'http://localhost' }, katex: null }; + globalThis.document = { + readyState: 'loading', + addEventListener() {}, + createElement(tag) { + if (tag !== 'template') throw new Error(`unsupported element: ${tag}`); + return { + _html: '', + content: { querySelectorAll() { return []; } }, + set innerHTML(value) { this._html = value; }, + get innerHTML() { return this._html; }, + }; + }, + }; + globalThis.MutationObserver = class { observe() {} }; + + let source = fs.readFileSync('./static/js/markdown.js', 'utf8'); + source = source.replace( + /import uiModule from ['"]\.\/ui\.js['"];/, + '' + ); + source = source.replace( + /import \{ splitTableRow \} from ['"]\.\/markdown\/tableRow\.js['"];/, + `function splitTableRow(row) { + return (row || '').replace(/^\\s*\\|/, '').replace(/\\|\\s*$/, '').split('|').map(c => c.trim()); + }` + ); + const emojiSource = fs.readFileSync('./static/js/emojiShortcodes.js', 'utf8') + .replace(/^export default .*$/m, '') + .replace(/export const /g, 'const ') + .replace(/export function /g, 'function '); + source = source.replace( + /import \{ replaceEmojiShortcodes, hasEmojiShortcode \} from ['"]\.\/emojiShortcodes\.js['"];/, + () => emojiSource + ); + source = source.replace( + /var escapeHtml = uiModule\.esc;/, + `var escapeHtml = (value) => String(value ?? '') + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''');` + ); + + const moduleUrl = 'data:text/javascript;base64,' + Buffer.from(source).toString('base64'); + const mod = await import(moduleUrl); + const input = JSON.parse(process.argv[1]); + console.log(JSON.stringify({ out: mod.extractThinkingBlocks(input) })); + """ + ) + result = subprocess.run( + ["node", "--input-type=module", "-e", script, json.dumps(text)], + cwd=_REPO, + capture_output=True, + timeout=15, + text=True, + ) + if result.returncode != 0: + raise AssertionError(f"node failed:\nSTDERR:\n{result.stderr}\nSTDOUT:\n{result.stdout}") + return json.loads(result.stdout.splitlines()[-1])["out"] + + +def test_issue_payload_copy_text_excludes_thinking(node_available): + # Shape reported in #3722: timed think block glued to the reply heading. + raw = ( + '\n' + "Here's a thinking process that leads to the desired summary:\n\n" + "6. **Generate the Output.** (This matches the final provided response.)" + "### Juxtaposition: Interweaving Cultural Norms in Lesson Design\n" + "The most effective lesson structure is created by deliberately juxtaposing." + ) + out = _extract_thinking_blocks(raw) + + assert out["content"].startswith("### Juxtaposition:"), out["content"] + assert "thinking process" not in out["content"] + assert "only reasoning, no reply yet") + assert out["content"] == "" + + +def _function_body(text: str, marker: str) -> str: + start = text.index(marker) + rest = text[start + len(marker):] + m = re.search(r"\nexport function |\nfunction ", rest) + return rest[: m.start()] if m else rest + + +def test_copy_message_text_mirrors_display_pipeline(): + text = (_REPO / "static/js/chatRenderer.js").read_text(encoding="utf-8") + body = _function_body(text, "export function copyMessageText") + # Mirrors the display path: tool blocks stripped, then thinking extracted. + assert "extractThinkingBlocks" in body + assert "stripToolBlocks" in body + assert "dataset.raw" in body + + +def test_copy_handlers_route_through_copy_message_text(): + for path, count in (("static/js/chatRenderer.js", 1), ("static/js/slashCommands.js", 1)): + text = (_REPO / path).read_text(encoding="utf-8") + assert text.count("copyToClipboard(copyMessageText(") + text.count( + "copyToClipboard(chatRenderer.copyMessageText(" + ) == count, path + # The old behavior passed dataset.raw straight to the clipboard. + assert "copyToClipboard(msgElement.dataset.raw" not in text, path + assert "copyToClipboard(msgEl.dataset.raw" not in text, path From f5b91f1e9e6b2812190e93400b356d809d8a2821 Mon Sep 17 00:00:00 2001 From: Mazen Tamer Salah <78306991+mazen-salah@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:03:45 +0300 Subject: [PATCH 13/18] fix(tasks): read Memory.text in classify_events personal context (#3640) The classify_events task pulled user memories to give the LLM personal context, but read `m.content`, which the Memory ORM does not have (the column is `text`). That raised AttributeError on the first row; the surrounding except swallowed it and logged at debug, so the personal-context block was silently always empty and events were classified without it. Extract the rendering into `_memory_context_lines` (reads `text`, robust via getattr, keeps the 200-char and 40-line caps) and raise the swallowed-exception log to warning so a future schema mismatch is visible. Adds tests/test_classify_events_memory_text.py for the field, truncation, blank skipping, missing-attr robustness, and the line cap. --- src/builtin_actions.py | 31 ++++++++++++++------- tests/test_classify_events_memory_text.py | 33 +++++++++++++++++++++++ 2 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 tests/test_classify_events_memory_text.py diff --git a/src/builtin_actions.py b/src/builtin_actions.py index b48ed94fa..1ea7cd8a4 100644 --- a/src/builtin_actions.py +++ b/src/builtin_actions.py @@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple: return etype, None +def _memory_context_lines(mems, limit: int = 40) -> list: + """Render Memory rows into short personal-context bullets for event classify. + + Reads the Memory ORM `text` column. The previous inline code read a + non-existent `content` attribute, so it raised AttributeError on the first + row, the surrounding except swallowed it, and the classifier ran with no + personal context at all. getattr keeps it robust to future schema drift. + """ + lines: list = [] + for m in mems: + c = (getattr(m, "text", "") or "").strip() + if c: + lines.append(f"- {c[:200]}") + if len(lines) >= limit: + break + return lines + + async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]: """Hybrid classification of upcoming calendar events: fast heuristic for obvious cases, LLM fallback for ambiguous ones. Assigns event_type + @@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]: try: from core.database import Memory as _Mem _mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else [] - if _mems: - _lines = [] - for m in _mems: - c = (m.content or "").strip() - if c: - _lines.append(f"- {c[:200]}") - if _lines: - _memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n" + _lines = _memory_context_lines(_mems) + if _lines: + _memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n" except Exception as _me: - logger.debug(f"Could not load memory for classify: {_me}") + logger.warning(f"Could not load memory for classify: {_me}") classified_h = 0 classified_llm = 0 diff --git a/tests/test_classify_events_memory_text.py b/tests/test_classify_events_memory_text.py new file mode 100644 index 000000000..328929115 --- /dev/null +++ b/tests/test_classify_events_memory_text.py @@ -0,0 +1,33 @@ +"""classify_events must read the Memory `text` column, not a non-existent +`content` attribute. + +The previous inline loop did `m.content`, which raised AttributeError on the +first Memory row; the surrounding except swallowed it, so the personal-context +block the LLM relies on was always empty. The logic now lives in +`_memory_context_lines`, which reads `text`. +""" +from src.builtin_actions import _memory_context_lines + + +class _Mem: + def __init__(self, text): + self.text = text + + +def test_uses_text_and_truncates_and_skips_blank(): + lines = _memory_context_lines([_Mem("Alice is my spouse"), _Mem(" "), _Mem("y" * 250)]) + assert lines[0] == "- Alice is my spouse" + assert len(lines) == 2 # the blank row is skipped + assert lines[1] == "- " + "y" * 200 # truncated to 200 chars + + +def test_skips_rows_without_text_attribute(): + class _Bad: # mimics a schema where the attribute is absent + pass + + assert _memory_context_lines([_Bad(), _Mem("ok")]) == ["- ok"] + + +def test_respects_limit(): + mems = [_Mem(f"memory {i}") for i in range(50)] + assert len(_memory_context_lines(mems, limit=40)) == 40 From d9a4b99046f4f5111fe409ff61d1b4ed96a33bf5 Mon Sep 17 00:00:00 2001 From: Srinesh R Date: Wed, 10 Jun 2026 22:43:08 +0530 Subject: [PATCH 14/18] fix: handle batch events format in manage_calendar tool (#3503) * fix: handle batch events format in manage_calendar tool Models like deepseek-v4-flash emit batch events array instead of individual create_event calls. The tool defaulted to list_events (no action key), so events were never created despite the model confirming success. - Add batch normalization in do_manage_calendar - Map start/end objects to flat dtstart/dtend strings - Add tests for both object and flat string formats * fix: surface partial batch failures in manage_calendar Partial failures were silently dropped - batches with mixed success/failure would report only created count with no error visibility. - Return non-zero exit code for any failures - Surface both created and failed counts in response - Include first error message for debugging - Add test for partial failure case * chore: strip trailing whitespace in batch normalization block * chore: strip whitespace-only blank lines in batch events test --- src/tool_implementations.py | 36 ++++++++ tests/test_calendar_batch_events.py | 125 ++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 tests/test_calendar_batch_events.py diff --git a/src/tool_implementations.py b/src/tool_implementations.py index 494795037..27c05f139 100644 --- a/src/tool_implementations.py +++ b/src/tool_implementations.py @@ -1453,6 +1453,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict: except ValueError: return {"error": "Invalid JSON arguments", "exit_code": 1} + # ── Batch normalization ── + # Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]} + # instead of individual create_event calls. Iterate and create each. + if isinstance(args.get("events"), list) and not args.get("action"): + results = [] + for ev in args["events"]: + if not isinstance(ev, dict): + continue + # Normalize start/end from {dateTime: "..."} object to flat string + for field, target in [("start", "dtstart"), ("end", "dtend")]: + val = ev.pop(field, None) + if val and target not in ev: + ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val + ev.setdefault("action", "create_event") + r = await do_manage_calendar(json.dumps(ev), owner=owner) + results.append(r) + created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")] + failed = [r for r in results if r.get("error")] + + if not results: + return {"error": "No events to create", "exit_code": 1} + + # Surface both successes and failures + parts = [] + if created: + summaries = [r.get("response", "") for r in created] + parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries)) + if failed: + first_error = failed[0].get("error", "Unknown error") + parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}") + + response = "\n\n".join(parts) + # Non-zero exit code for partial or total failure + exit_code = 0 if not failed else 1 + return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)} + # Normalize action — some models emit hyphens ("list-calendars") instead # of underscores. Treat them as equivalent so we don't bounce a # cosmetic typo back to the model and waste a round-trip. Also accept diff --git a/tests/test_calendar_batch_events.py b/tests/test_calendar_batch_events.py new file mode 100644 index 000000000..d8176afcd --- /dev/null +++ b/tests/test_calendar_batch_events.py @@ -0,0 +1,125 @@ +"""Test that do_manage_calendar handles the batch {"events": [...]} format +that models like deepseek-v4-flash emit instead of individual create_event calls. +""" + +import json +import sys +import uuid + +import pytest + +from tests.helpers.import_state import clear_fake_database_modules +from tests.helpers.sqlite_db import make_temp_sqlite + +clear_fake_database_modules() + +import core.database as cdb +from core.database import CalendarEvent + +_TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata) + + +@pytest.fixture(autouse=True) +def _bind_temp_db(monkeypatch): + monkeypatch.setitem(sys.modules, "core.database", cdb) + parent = sys.modules.get("core") + if parent is not None: + monkeypatch.setattr(parent, "database", cdb, raising=False) + monkeypatch.setattr(cdb, "SessionLocal", _TS) + yield + + +async def test_batch_events_with_datetime_objects(): + """Model emits {"events": [{"summary": ..., "start": {"dateTime": ...}, "end": {"dateTime": ...}}]}.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Morning Gym", + "start": {"dateTime": "2026-06-09T06:00:00+05:30"}, + "end": {"dateTime": "2026-06-09T07:00:00+05:30"}, + }, + { + "summary": "Morning Gym", + "start": {"dateTime": "2026-06-10T06:00:00+05:30"}, + "end": {"dateTime": "2026-06-10T07:00:00+05:30"}, + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + assert res.get("exit_code") == 0, res + assert "Created 2 event(s)" in res.get("response", "") + + # Verify events exist in DB + db = _TS() + events = db.query(CalendarEvent).filter(CalendarEvent.summary == "Morning Gym").all() + assert len(events) == 2 + db.close() + + +async def test_batch_events_with_flat_strings(): + """Model emits {"events": [{"summary": ..., "start": "ISO", "end": "ISO"}]}.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Standup", + "start": "2026-06-09T09:00:00", + "end": "2026-06-09T09:30:00", + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + assert res.get("exit_code") == 0, res + assert "Created 1 event(s)" in res.get("response", "") + + +async def test_batch_events_partial_failure(): + """Batch with some valid and some invalid events — should surface both counts and first error.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Valid Event 1", + "start": "2026-06-09T10:00:00", + "end": "2026-06-09T11:00:00", + }, + { + "summary": "Invalid Event", + # Missing required dtstart — will fail + }, + { + "summary": "Valid Event 2", + "start": "2026-06-09T14:00:00", + "end": "2026-06-09T15:00:00", + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + + # Partial failure = non-zero exit code + assert res.get("exit_code") != 0, "Partial failure should return non-zero exit code" + + # Response should mention both created and failed counts + response = res.get("response", "") + assert "Created 2 event(s)" in response, f"Should report 2 created: {response}" + assert "Failed to create 1 event(s)" in response, f"Should report 1 failed: {response}" + assert "error" in response.lower() or "required" in response.lower(), "Should include error details" + + # Metadata fields + assert res.get("created_count") == 2 + assert res.get("failed_count") == 1 + + # Verify only valid events were created + db = _TS() + events = db.query(CalendarEvent).filter( + CalendarEvent.summary.in_(["Valid Event 1", "Valid Event 2"]) + ).all() + assert len(events) == 2 + db.close() From 218b9ecbc8117990f612569e236eb0376ca7756d Mon Sep 17 00:00:00 2001 From: Mazen Tamer Salah <78306991+mazen-salah@users.noreply.github.com> Date: Wed, 10 Jun 2026 20:21:45 +0300 Subject: [PATCH 15/18] fix(startup): ping real endpoints in warmup/keepalive (#3641) _warmup_endpoints called model_discovery.get_endpoints(), which does not exist on ModelDiscovery. It raised AttributeError on every startup and on every 60s keepalive tick, was swallowed by the outer except, and pinged nothing, so the cold-start prevention the loop exists for never ran. Add ModelDiscovery.warmup_ping_urls(), which resolves the /models probe URLs from the real discover_models() output, and call it from the warmup loop via asyncio.to_thread (discovery does a blocking port scan, so keep it off the event loop). Adds tests/test_warmup_ping_urls.py: resolves /models URLs from discovered items, honors the limit, degrades to [] on discovery failure, and documents that get_endpoints never existed. --- app.py | 25 ++++++++++-------- src/model_discovery.py | 19 ++++++++++++++ tests/test_warmup_ping_urls.py | 47 ++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 10 deletions(-) create mode 100644 tests/test_warmup_ping_urls.py diff --git a/app.py b/app.py index 7cec8b0f1..2e1677ca2 100644 --- a/app.py +++ b/app.py @@ -946,16 +946,21 @@ async def _startup_event(): async def _warmup_endpoints(): try: import httpx - endpoints = model_discovery.get_endpoints() if model_discovery else [] - for ep in endpoints[:5]: - url = ep.get("url", "").replace("/chat/completions", "/models") - if url: - try: - async with httpx.AsyncClient(timeout=5.0) as client: - await client.get(url) - logger.info(f"Warmup ping OK: {url}") - except Exception as e: - logger.debug(f"Warmup ping failed for endpoint: {e}") + # model_discovery has no get_endpoints(); that call raised + # AttributeError every run and silently disabled warmup/keepalive. + # Resolve the /models probe URLs via the real discovery API, off the + # event loop since discovery does a blocking port scan. + urls = ( + await asyncio.to_thread(model_discovery.warmup_ping_urls) + if model_discovery else [] + ) + for url in urls: + try: + async with httpx.AsyncClient(timeout=5.0) as client: + await client.get(url) + logger.info(f"Warmup ping OK: {url}") + except Exception as e: + logger.debug(f"Warmup ping failed for endpoint: {e}") except Exception as e: logger.debug(f"Warmup ping skipped: {e}") diff --git a/src/model_discovery.py b/src/model_discovery.py index 68b402d25..506fcb6c4 100644 --- a/src/model_discovery.py +++ b/src/model_discovery.py @@ -223,6 +223,25 @@ class ModelDiscovery: ) return {"hosts": hosts, "items": items} + def warmup_ping_urls(self, limit: int = 5) -> List[str]: + """The ``/models`` URLs of up to ``limit`` discovered endpoints. + + Used by the startup warmup / keepalive loop to prime connections. Each + discovered item already carries a ``/v1/chat/completions`` url; swap the + suffix for the cheap ``/models`` probe. Failures degrade to an empty list + so warmup never crashes the caller. + """ + try: + items = (self.discover_models() or {}).get("items", []) + except Exception: + return [] + urls: List[str] = [] + for ep in items[:limit]: + url = (ep.get("url") or "").replace("/chat/completions", "/models") + if url: + urls.append(url) + return urls + def get_providers(self) -> Dict[str, Any]: """Get all available providers""" discovery = self.discover_models() diff --git a/tests/test_warmup_ping_urls.py b/tests/test_warmup_ping_urls.py new file mode 100644 index 000000000..7b5961831 --- /dev/null +++ b/tests/test_warmup_ping_urls.py @@ -0,0 +1,47 @@ +"""Startup warmup must resolve real endpoint URLs. + +The warmup/keepalive loop called `model_discovery.get_endpoints()`, which does +not exist on ModelDiscovery, so it raised AttributeError every run and pinged +nothing. `ModelDiscovery.warmup_ping_urls()` resolves the /models probe URLs +from the real discovery API. +""" +from src.model_discovery import ModelDiscovery + + +def _md(): + return ModelDiscovery.__new__(ModelDiscovery) + + +def test_old_method_never_existed(): + # Documents why the old warmup was a silent no-op. + assert not hasattr(ModelDiscovery, "get_endpoints") + + +def test_resolves_models_urls_from_discovered_items(): + md = _md() + md.discover_models = lambda: {"items": [ + {"url": "http://host:8000/v1/chat/completions", "models": ["a"]}, + {"url": "http://host:1234/v1/chat/completions", "models": ["b"]}, + ]} + assert md.warmup_ping_urls() == [ + "http://host:8000/v1/models", + "http://host:1234/v1/models", + ] + + +def test_limit_caps_results(): + md = _md() + md.discover_models = lambda: {"items": [ + {"url": f"http://h:{8000 + i}/v1/chat/completions"} for i in range(10) + ]} + assert len(md.warmup_ping_urls(limit=3)) == 3 + + +def test_discovery_failure_degrades_to_empty(): + md = _md() + + def boom(): + raise RuntimeError("port scan failed") + + md.discover_models = boom + assert md.warmup_ping_urls() == [] From d1a5a7d680e5b06249ad19a8790a9347d78961d9 Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Thu, 11 Jun 2026 01:43:49 +0300 Subject: [PATCH 16/18] fix(hwfit): validate remote SSH detection targets (#3718) --- core/platform_compat.py | 4 ++ routes/_validators.py | 31 +++++++++++ routes/cookbook_helpers.py | 24 +-------- routes/cookbook_routes.py | 74 ++++++++++++++++----------- routes/hwfit_routes.py | 16 +++++- tests/test_cookbook_helpers.py | 7 --- tests/test_hwfit_remote_validation.py | 47 +++++++++++++++++ tests/test_route_validators.py | 23 +++++++++ 8 files changed, 164 insertions(+), 62 deletions(-) create mode 100644 routes/_validators.py create mode 100644 tests/test_hwfit_remote_validation.py create mode 100644 tests/test_route_validators.py diff --git a/core/platform_compat.py b/core/platform_compat.py index 3eda4a107..b3b157111 100644 --- a/core/platform_compat.py +++ b/core/platform_compat.py @@ -366,6 +366,10 @@ def _ssh_exec_argv( strict_host_key_checking: bool | None = None, ) -> list[str]: """Build a consistent ssh argv for remote command execution.""" + remote_value = str(remote or "").strip() + remote_host = remote_value.rsplit("@", 1)[-1] + if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"): + raise ValueError("Invalid SSH remote host") argv = ["ssh"] if connect_timeout is not None: argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"]) diff --git a/routes/_validators.py b/routes/_validators.py new file mode 100644 index 000000000..aa4cf00cc --- /dev/null +++ b/routes/_validators.py @@ -0,0 +1,31 @@ +import re + +from fastapi import HTTPException + + +_REMOTE_HOST_RE = re.compile( + r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$" +) +_SSH_PORT_RE = re.compile(r"^\d{1,5}$") + + +def validate_remote_host(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _REMOTE_HOST_RE.match(v): + raise HTTPException( + 400, + "Invalid remote_host — must be host or user@host, no SSH option syntax", + ) + return v + + +def validate_ssh_port(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _SSH_PORT_RE.fullmatch(str(v)): + raise HTTPException(400, "Invalid ssh_port") + port = int(v) + if port < 1 or port > 65535: + raise HTTPException(400, "Invalid ssh_port") + return str(port) diff --git a/routes/cookbook_helpers.py b/routes/cookbook_helpers.py index 709245287..53bdde80e 100644 --- a/routes/cookbook_helpers.py +++ b/routes/cookbook_helpers.py @@ -11,6 +11,7 @@ import shlex from fastapi import HTTPException from pydantic import BaseModel +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import _ssh_exec_argv logger = logging.getLogger(__name__) @@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") _OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$") # Include pattern is a glob: allow typical safe glyphs only. _INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$") -# Remote host: either `user@host` or plain `host` (alias is allowed), where host -# is a safe DNS-like token or a short SSH config alias. -_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$") # HF tokens and API tokens are url-safe base64-like. _TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$") # Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef". # Anything beyond plain alphanumerics + dash + underscore could break out # of the shell/PowerShell contexts the value lands in. _SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") -_SSH_PORT_RE = re.compile(r"^\d{1,5}$") _GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$") # A download target directory. Absolute or ~-relative path; safe path glyphs # only (no quotes or shell metacharacters). Spaces are allowed because command @@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None: return v -def _validate_remote_host(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _REMOTE_HOST_RE.match(v): - raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax") - return v - - def _validate_token(v: str | None) -> str | None: if v is None or v == "": return None @@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None: return v -def _validate_ssh_port(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _SSH_PORT_RE.fullmatch(str(v)): - raise HTTPException(400, "Invalid ssh_port") - port = int(v) - if port < 1 or port > 65535: - raise HTTPException(400, "Invalid ssh_port") - return str(port) - - def _validate_gpus(v: str | None) -> str | None: if v is None or v == "": return None diff --git a/routes/cookbook_routes.py b/routes/cookbook_routes.py index 4a4764232..36f98aeae 100644 --- a/routes/cookbook_routes.py +++ b/routes/cookbook_routes.py @@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE from pydantic import BaseModel from core.middleware import require_admin +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import ( IS_WINDOWS, detached_popen_kwargs, @@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR logger = logging.getLogger(__name__) from routes.cookbook_helpers import ( - _SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE, - _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token, - _validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path, + _SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token, + _validate_local_dir, _validate_gpus, _shell_path, _ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase, _safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines, _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, @@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter: else: _validate_repo_id(req.repo_id) _validate_include(req.include) - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.local_dir = _validate_local_dir(req.local_dir) req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token()) _validate_token(req.hf_token) @@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter: # Validate shell-bound inputs, matching the sibling list_gpus endpoint — # `host`/`ssh_port` are interpolated into an ssh command below, so an # unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection. - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True) model_dirs = [] @@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter: # listening" check without requiring ss/netstat/nmap. ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if ssh_port and str(ssh_port) != "22": - if not _SSH_PORT_RE.match(str(ssh_port)): + try: + ssh_port = validate_ssh_port(ssh_port) + except HTTPException: return None ssh_base.extend(["-p", str(ssh_port)]) - host_arg = remote - if not _REMOTE_HOST_RE.match(host_arg): + try: + host_arg = validate_remote_host(remote) + except HTTPException: + return None + if not host_arg: return None probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1)) script = ( @@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter: """ require_admin(request) # Defence-in-depth: reject values that could break out of shell contexts. - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.gpus = _validate_gpus(req.gpus) req.hf_token = req.hf_token or _load_stored_hf_token() _validate_token(req.hf_token) @@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter: async def server_setup(request: Request, req: SetupRequest): """Install required dependencies on a remote server via SSH.""" require_admin(request) - host = _validate_remote_host(req.host) + host = validate_remote_host(req.host) if not host: raise HTTPException(400, "host is required") port = req.ssh_port - if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port): - raise HTTPException(400, "Invalid ssh_port") + port = validate_ssh_port(port) pf = f"-p {port} " if port and port != "22" else "" # Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux @@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter: `busy` is True when free_mb/total_mb < 0.5. """ require_admin(request) - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits" nvidia_error = None try: @@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter: sig = (req.signal or "TERM").upper() if sig not in ("TERM", "KILL", "INT"): raise HTTPException(400, "signal must be TERM, KILL, or INT") - host = _validate_remote_host(req.host) - if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(req.host) + req.ssh_port = validate_ssh_port(req.ssh_port) kill_cmd = f"kill -{sig} {req.pid}" try: if host: @@ -2382,14 +2383,19 @@ def setup_cookbook_routes() -> APIRouter: host = (srv.get("host") or "").strip() if not host: continue # local-only entry; the /proc scan handles it - if not _REMOTE_HOST_RE.match(host): + try: + host = validate_remote_host(host) + except HTTPException: continue sport = str(srv.get("port") or "").strip() ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if sport and sport != "22": - if not _SSH_PORT_RE.match(sport): + try: + sport = validate_ssh_port(sport) + except HTTPException: continue - ssh_base.extend(["-p", sport]) + if sport != "22": + ssh_base.extend(["-p", sport]) try: ls = subprocess.run( @@ -2743,12 +2749,18 @@ def setup_cookbook_routes() -> APIRouter: if not _SESSION_ID_RE.match(session_id): logger.warning(f"Skipping task with unsafe session_id: {session_id!r}") continue - if remote and not _REMOTE_HOST_RE.match(remote): - logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") - continue - if _tport and not _SSH_PORT_RE.match(str(_tport)): - logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") - continue + if remote: + try: + remote = validate_remote_host(remote) + except HTTPException: + logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") + continue + if _tport: + try: + _tport = validate_ssh_port(str(_tport)) + except HTTPException: + logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") + continue if task_platform == "windows" and remote: # Windows: check PID file + Get-Process, read log tail sd = "$env:TEMP\\odysseus-sessions" diff --git a/routes/hwfit_routes.py b/routes/hwfit_routes.py index eb408ac9d..564c3a03c 100644 --- a/routes/hwfit_routes.py +++ b/routes/hwfit_routes.py @@ -1,7 +1,9 @@ import re from copy import deepcopy -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port # Backends the manual hardware simulator accepts. Must stay a subset of what @@ -11,6 +13,14 @@ from fastapi import APIRouter _MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"} +def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]: + host_value = validate_remote_host(host) or "" + port_value = validate_ssh_port(ssh_port) or "" + if port_value and not host_value: + raise HTTPException(400, "ssh_port requires host") + return host_value, port_value + + def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""): """Manual hardware is a "what if I had this setup" simulator — REPLACES the detected hardware entirely instead of adding to it. @@ -105,6 +115,7 @@ def setup_hwfit_routes(): """Detect and return current system hardware info. Pass host=user@server for remote. fresh=true bypasses the per-host cache (the Rescan button).""" from services.hwfit.hardware import detect_system + host, ssh_port = _validate_detection_target(host, ssh_port) return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) @router.get("/models") @@ -118,6 +129,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.fit import rank_models from services.hwfit.models import get_models, model_catalog_path + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} @@ -229,6 +241,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.models import get_models from services.hwfit.profiles import compute_serve_profiles + host, ssh_port = _validate_detection_target(host, ssh_port) system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) if system.get("error"): return {"system": system, "profiles": [], "error": system["error"]} @@ -279,6 +292,7 @@ def setup_hwfit_routes(): """Rank image generation models against detected hardware.""" from services.hwfit.hardware import detect_system from services.hwfit.image_models import rank_image_models + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} diff --git a/tests/test_cookbook_helpers.py b/tests/test_cookbook_helpers.py index acc001812..779b48e3c 100644 --- a/tests/test_cookbook_helpers.py +++ b/tests/test_cookbook_helpers.py @@ -26,7 +26,6 @@ from routes.cookbook_helpers import ( _validate_repo_id, _validate_serve_cmd, _validate_serve_model_id, - _validate_ssh_port, _shell_path, run_ssh_command_async, ) @@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path(): ) -def test_validate_ssh_port_rejects_shell_payload(): - with pytest.raises(HTTPException): - _validate_ssh_port("22; touch /tmp/pwned") - assert _validate_ssh_port("2222") == "2222" - - def test_validate_local_dir_accepts_external_drive_paths_with_spaces(): path = "/Volumes/T7 2TB/AI Models/llamacpp" diff --git a/tests/test_hwfit_remote_validation.py b/tests/test_hwfit_remote_validation.py new file mode 100644 index 000000000..aee2aaadb --- /dev/null +++ b/tests/test_hwfit_remote_validation.py @@ -0,0 +1,47 @@ +import pytest +from fastapi import HTTPException + +from core.platform_compat import _ssh_exec_argv +from routes.hwfit_routes import setup_hwfit_routes + + +def _endpoint(path: str): + router = setup_hwfit_routes() + for route in router.routes: + if getattr(route, "path", "") == path: + return route.endpoint + raise AssertionError(f"{path} route not found") + + +@pytest.mark.parametrize( + "path,kwargs", + [ + ("/api/hwfit/system", {}), + ("/api/hwfit/models", {"limit": 1}), + ("/api/hwfit/profiles", {"model": "demo"}), + ("/api/hwfit/image-models", {}), + ], +) +def test_hwfit_routes_reject_ssh_option_host(path, kwargs): + endpoint = _endpoint(path) + + with pytest.raises(HTTPException) as exc: + endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs) + + assert exc.value.status_code == 400 + + +def test_hwfit_routes_reject_port_without_host(): + endpoint = _endpoint("/api/hwfit/system") + + with pytest.raises(HTTPException) as exc: + endpoint(host="", ssh_port="2222") + + assert exc.value.status_code == 400 + + +def test_ssh_argv_rejects_option_shaped_remote(): + with pytest.raises(ValueError): + _ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true") + with pytest.raises(ValueError): + _ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true") diff --git a/tests/test_route_validators.py b/tests/test_route_validators.py new file mode 100644 index 000000000..a6fc07a98 --- /dev/null +++ b/tests/test_route_validators.py @@ -0,0 +1,23 @@ +import pytest +from fastapi import HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port + + +def test_validate_ssh_port_rejects_shell_payload(): + for port in ["22;id", "$(id)", "-p 22", "0", "65536"]: + with pytest.raises(HTTPException): + validate_ssh_port(port) + assert validate_ssh_port("2222") == "2222" + + +def test_validate_remote_host_rejects_ssh_option_shape(): + for host in [ + "-oProxyCommand=sh", + "alice@-oProxyCommand=sh", + "--", + "-p2222", + ]: + with pytest.raises(HTTPException): + validate_remote_host(host) + assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1" From 9c00da6d1ca2124439b2511065d8ff5fe2f0f3b5 Mon Sep 17 00:00:00 2001 From: Mazen Tamer Salah <78306991+mazen-salah@users.noreply.github.com> Date: Thu, 11 Jun 2026 02:01:58 +0300 Subject: [PATCH 17/18] fix(hwfit): tolerate non-numeric gpu_count in /api/hwfit/models (#3639) * fix(hwfit): tolerate non-numeric gpu_count in /api/hwfit/models The route did `n = int(gpu_count)` with no guard, so a non-numeric query param like `?gpu_count=abc` raised ValueError and returned HTTP 500. Parse it defensively (mirroring the gpu_group guard a few lines above): a malformed value is ignored, exactly like omitting the param, and valid values still apply. Adds tests/test_hwfit_gpu_count_nonnumeric.py: a non-numeric gpu_count returns a ranking instead of raising, and a numeric value is still accepted. * test(hwfit): cover non-numeric manual_gpu_count too Follow-up to the gpu_count guard: add a regression test for the sibling manual_gpu_count query param (the hardware simulator in _apply_manual_hardware), which dev already guards by defaulting to 1 on a non-numeric value. This pins that behaviour so the endpoint's count parsing is fully covered and cannot regress to a 500. --- routes/hwfit_routes.py | 10 +++++-- tests/test_hwfit_gpu_count_nonnumeric.py | 38 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 tests/test_hwfit_gpu_count_nonnumeric.py diff --git a/routes/hwfit_routes.py b/routes/hwfit_routes.py index 564c3a03c..45c209b0b 100644 --- a/routes/hwfit_routes.py +++ b/routes/hwfit_routes.py @@ -177,8 +177,14 @@ def setup_hwfit_routes(): system["gpu_name"] = g["name"] system["active_group"] = {**g, "use_count": n} - if gpu_count != "": - n = int(gpu_count) + # Parse the optional count defensively (matches the gpu_group guard + # above): a non-numeric query param previously raised ValueError -> + # HTTP 500. A malformed value is ignored, same as omitting it. + try: + n = int(gpu_count) if gpu_count != "" else None + except ValueError: + n = None + if n is not None: if n == 0: # RAM-only mode: rank against system memory, offload allowed. system["has_gpu"] = False diff --git a/tests/test_hwfit_gpu_count_nonnumeric.py b/tests/test_hwfit_gpu_count_nonnumeric.py new file mode 100644 index 000000000..13e6b2f25 --- /dev/null +++ b/tests/test_hwfit_gpu_count_nonnumeric.py @@ -0,0 +1,38 @@ +"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count. + +The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any +non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored, +matching how the neighbouring gpu_group param is already parsed. +""" +from routes.hwfit_routes import setup_hwfit_routes + + +def _get_models(): + router = setup_hwfit_routes() + for route in router.routes: + if getattr(route, "path", "").endswith("/models") and "GET" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("hwfit /models route not found") + + +def test_non_numeric_gpu_count_does_not_raise(): + handler = _get_models() + # Previously raised ValueError (HTTP 500); now degrades to a normal ranking. + result = handler(gpu_count="abc") + assert isinstance(result, dict) + + +def test_numeric_gpu_count_still_accepted(): + handler = _get_models() + result = handler(gpu_count="0") + assert isinstance(result, dict) + + +def test_non_numeric_manual_gpu_count_does_not_raise(): + # manual_gpu_count is the other count param on this endpoint (the hardware + # simulator in _apply_manual_hardware). A non-numeric value must also degrade + # (default to 1) rather than 500, so the endpoint's count parsing is fully + # covered. + handler = _get_models() + result = handler(manual_mode="gpu", manual_gpu_count="abc") + assert isinstance(result, dict) From d5603ee57551c00e59f9a6c7b4b07075fb66ef6f Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Thu, 11 Jun 2026 02:17:02 +0300 Subject: [PATCH 18/18] fix(research): migrate active task owners on rename (#3618) --- app.py | 1 + routes/auth_routes.py | 14 ++++ src/research_handler.py | 16 ++++ tests/test_rename_user_owner_sync.py | 110 ++++++++++++++++++++++++++- 4 files changed, 139 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 2e1677ca2..365eee94a 100644 --- a/app.py +++ b/app.py @@ -503,6 +503,7 @@ api_key_manager = components["api_key_manager"] preset_manager = components["preset_manager"] chat_processor = components["chat_processor"] research_handler = components["research_handler"] +app.state.research_handler = research_handler chat_handler = components["chat_handler"] model_discovery = components["model_discovery"] skills_manager = components["skills_manager"] diff --git a/routes/auth_routes.py b/routes/auth_routes.py index e67a4758f..b9158c93a 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -367,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: except Exception as e: logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e) + # In-flight deep-research tasks live in the process-local + # ResearchHandler registry. They are not covered by the persisted JSON + # migration above, but the research routes filter and cancel by this + # owner field while the job is running. Do this before sweeping + # completed JSON files so a job that finishes during the rename saves + # with the new owner or is caught by the disk sweep below. + try: + rh = getattr(request.app.state, "research_handler", None) + rename_owner = getattr(rh, "rename_owner", None) + if callable(rename_owner): + rename_owner(old_username, new_username) + except Exception as e: + logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e) + # deep_research: each completed report is a standalone JSON file with # an `owner` field. research_routes filters by d.get("owner") == user, # so a stale owner makes every report invisible to the renamed user. diff --git a/src/research_handler.py b/src/research_handler.py index b3af3b8e5..f1d120ef2 100644 --- a/src/research_handler.py +++ b/src/research_handler.py @@ -221,6 +221,22 @@ class ResearchHandler: # Task registry — background research with persistence # ------------------------------------------------------------------ + def rename_owner(self, old_owner: str, new_owner: str) -> int: + """Move in-flight research tasks from one owner key to another.""" + old_key = str(old_owner or "").strip().lower() + new_key = str(new_owner or "").strip().lower() + if not old_key or not new_key: + return 0 + + changed = 0 + for entry in list(self._active_tasks.values()): + if not isinstance(entry, dict): + continue + if str(entry.get("owner", "")).strip().lower() == old_key: + entry["owner"] = new_key + changed += 1 + return changed + def start_research( self, session_id: str, diff --git a/tests/test_rename_user_owner_sync.py b/tests/test_rename_user_owner_sync.py index 24e1fb67c..e5e89b4dc 100644 --- a/tests/test_rename_user_owner_sync.py +++ b/tests/test_rename_user_owner_sync.py @@ -11,7 +11,10 @@ owner column, but three file-backed / in-memory stores are left stale: research_routes filters by `d.get("owner") == user`, making every report invisible after rename. -3. data/memory.json — a flat array where every entry has an `owner` field; +3. research_handler._active_tasks — in-flight research jobs carry the same + owner key while status/cancel/active routes filter by it. + +4. data/memory.json — a flat array where every entry has an `owner` field; memory_manager.load(owner=user) filters on it, so all memories vanish. Regression coverage: these bugs are invisible in unit tests that mock the DB @@ -64,10 +67,11 @@ def rename_endpoint(monkeypatch, tmp_path): return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path -def _request(tmp_path, session_manager=None, token="t"): +def _request(tmp_path, session_manager=None, token="t", research_handler=None): state = SimpleNamespace( invalidate_token_cache=lambda: None, session_manager=session_manager, + research_handler=research_handler, ) return SimpleNamespace( cookies={"odysseus_session": token}, @@ -234,6 +238,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint): assert res["ok"] is True +def test_rename_updates_active_research_task_owner(rename_endpoint): + endpoint, _am, tmp_path = rename_endpoint + + from routes.research_routes import setup_research_routes + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "alice-task": { + "owner": "Alice", + "status": "running", + "query": "q", + "progress": {}, + "started_at": 1, + }, + "carol-task": { + "owner": "carol", + "status": "running", + "query": "q2", + "progress": {}, + "started_at": 2, + }, + } + + asyncio.run(endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, research_handler=rh), + )) + + assert rh._active_tasks["alice-task"]["owner"] == "alice2" + assert rh._active_tasks["carol-task"]["owner"] == "carol" + + router = setup_research_routes(rh) + active = next( + r.endpoint for r in router.routes + if getattr(r, "path", "") == "/api/research/active" + ) + + alice2 = asyncio.run(active( + SimpleNamespace(state=SimpleNamespace(current_user="alice2")), + )) + alice = asyncio.run(active( + SimpleNamespace(state=SimpleNamespace(current_user="alice")), + )) + + assert [item["session_id"] for item in alice2["active"]] == ["alice-task"] + assert alice["active"] == [] + + +def test_research_handler_rename_owner_canonicalizes_new_owner(): + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "task": {"owner": "Alice", "status": "running"}, + } + + changed = rh.rename_owner("alice", "Alice2") + assert changed == 1 + assert rh._active_tasks["task"]["owner"] == "alice2" + + +def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold(): + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "task-strasse": {"owner": "strasse", "status": "running"}, + "task-sharp-s": {"owner": "straße", "status": "running"}, + } + + changed = rh.rename_owner("straße", "renamed") + + assert changed == 1 + assert rh._active_tasks["task-strasse"]["owner"] == "strasse" + assert rh._active_tasks["task-sharp-s"]["owner"] == "renamed" + + +def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint): + endpoint, _am, tmp_path = rename_endpoint + + dr_dir = tmp_path / "deep_research" + dr_dir.mkdir() + report = dr_dir / "race-window.json" + report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8") + owner_seen_by_active_hook = [] + + class FakeResearchHandler: + def rename_owner(self, _old, _new): + owner_seen_by_active_hook.append(json.loads(report.read_text(encoding="utf-8"))["owner"]) + + asyncio.run(endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, research_handler=FakeResearchHandler()), + )) + + assert owner_seen_by_active_hook == ["alice"] + assert json.loads(report.read_text(encoding="utf-8"))["owner"] == "alice2" + + def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path): """DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made