diff --git a/README.md b/README.md index a320f0052..a0dde96a9 100644 --- a/README.md +++ b/README.md @@ -329,7 +329,7 @@ To expose Odysseus on a local network or Tailscale with HTTPS: | Package | Feature unlocked | |---------|-----------------| | `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. | -| `duckduckgo-search` | DuckDuckGo as a search provider option. | +| `ddgs` | DuckDuckGo as a search provider option. | | `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) | | `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). | diff --git a/app.py b/app.py index cfd73e83f..365eee94a 100644 --- a/app.py +++ b/app.py @@ -56,7 +56,7 @@ from core.constants import ( ) from core.database import SessionLocal, ApiToken from core.middleware import SecurityHeadersMiddleware, is_cors_preflight -from core.auth import AuthManager +from core.auth import AuthManager, normalize_known_username from core.exceptions import ( SessionNotFoundError, InvalidFileUploadError, LLMServiceError, WebSearchError, @@ -228,8 +228,16 @@ if AUTH_ENABLED: try: rows = db.query(ApiToken).filter(ApiToken.is_active == True).all() for r in rows: + owner_key = normalize_known_username(auth_manager.users, getattr(r, "owner", None)) + if not owner_key: + logger.warning( + "Ignoring active API token '%s' for unknown auth user '%s'", + getattr(r, "id", ""), + getattr(r, "owner", None), + ) + continue scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()] - new_map[r.token_prefix].append((r.id, r.token_hash, getattr(r, "owner", None), scopes)) + new_map[r.token_prefix].append((r.id, r.token_hash, owner_key, scopes)) finally: db.close() _token_cache.clear() @@ -495,6 +503,7 @@ api_key_manager = components["api_key_manager"] preset_manager = components["preset_manager"] chat_processor = components["chat_processor"] research_handler = components["research_handler"] +app.state.research_handler = research_handler chat_handler = components["chat_handler"] model_discovery = components["model_discovery"] skills_manager = components["skills_manager"] @@ -938,16 +947,21 @@ async def _startup_event(): async def _warmup_endpoints(): try: import httpx - endpoints = model_discovery.get_endpoints() if model_discovery else [] - for ep in endpoints[:5]: - url = ep.get("url", "").replace("/chat/completions", "/models") - if url: - try: - async with httpx.AsyncClient(timeout=5.0) as client: - await client.get(url) - logger.info(f"Warmup ping OK: {url}") - except Exception as e: - logger.debug(f"Warmup ping failed for endpoint: {e}") + # model_discovery has no get_endpoints(); that call raised + # AttributeError every run and silently disabled warmup/keepalive. + # Resolve the /models probe URLs via the real discovery API, off the + # event loop since discovery does a blocking port scan. + urls = ( + await asyncio.to_thread(model_discovery.warmup_ping_urls) + if model_discovery else [] + ) + for url in urls: + try: + async with httpx.AsyncClient(timeout=5.0) as client: + await client.get(url) + logger.info(f"Warmup ping OK: {url}") + except Exception as e: + logger.debug(f"Warmup ping failed for endpoint: {e}") except Exception as e: logger.debug(f"Warmup ping skipped: {e}") diff --git a/core/auth.py b/core/auth.py index 5db2fed4c..2f9fd4e51 100644 --- a/core/auth.py +++ b/core/auth.py @@ -67,6 +67,14 @@ TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"}) +def normalize_known_username(users: Dict[str, Any], username: str | None) -> Optional[str]: + """Return a normalized username only when it exists in the auth user map.""" + key = str(username or "").strip().lower() + if not key or key not in users: + return None + return key + + def _hash_password(password: str) -> str: return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") @@ -96,6 +104,7 @@ class AuthManager: self._load() self._load_sessions() self._migrate_single_user() + self._drop_reserved_loaded_users() self._migrate_legacy_admin_role() def _load(self): @@ -148,7 +157,13 @@ class AuthManager: def _migrate_single_user(self): """Migrate old single-user format to multi-user format.""" if "password_hash" in self._config and "users" not in self._config: - old_user = self._config.get("username", "admin") + old_user = str(self._config.get("username", "admin") or "admin").strip().lower() + if old_user in RESERVED_USERNAMES: + logger.warning( + "Migrating legacy single-user reserved username '%s' to 'admin'", + old_user, + ) + old_user = "admin" old_hash = self._config["password_hash"] self._config = { "users": { @@ -162,6 +177,30 @@ class AuthManager: self._save() logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})") + def _drop_reserved_loaded_users(self): + """Fail closed for legacy/manual auth rows that collide with sentinels.""" + users = self._config.get("users") + if not isinstance(users, dict): + return + normalized = {} + removed = [] + for username, data in users.items(): + key = str(username or "").strip().lower() + if not key: + continue + if key in RESERVED_USERNAMES: + removed.append(key) + continue + normalized[key] = data + if removed or normalized != users: + self._config["users"] = normalized + self._save() + if removed: + logger.warning( + "Removed reserved username(s) from auth config: %s", + ", ".join(sorted(set(removed))), + ) + def _migrate_legacy_admin_role(self): """Normalize setup.py's old role='admin' marker to is_admin=True.""" changed = False @@ -244,6 +283,22 @@ class AuthManager: return False if not self.users.get(requesting_user, {}).get("is_admin"): return False + # Revoke API bearer tokens before removing the auth row. The bearer + # path authenticates from ApiToken rows and does not require the + # owner to still exist, so a successful delete must not leave active + # rows behind. If the token store is unavailable, fail closed and + # keep the user/session state intact so the admin can retry. + try: + from core.database import get_db_session, ApiToken + with get_db_session() as db: + removed_tokens = db.query(ApiToken).filter(ApiToken.owner == username).delete() + if removed_tokens: + logger.info( + f"Revoked {removed_tokens} API token(s) owned by deleted user '{username}'" + ) + except Exception: + logger.warning(f"Failed to revoke API tokens for deleted user '{username}'") + return False del self._config["users"][username] self._save() # Purge all sessions belonging to this user. validate_token doesn't @@ -258,18 +313,6 @@ class AuthManager: revoked += 1 if revoked: self._save_sessions() - # Also revoke API bearer tokens owned by this user. The bearer auth - # path authenticates straight against ApiToken rows and never - # re-checks that the owner still exists, so leaving the rows behind - # would let a deleted user keep full API access indefinitely. - try: - from core.database import get_db_session, ApiToken - with get_db_session() as db: - removed = db.query(ApiToken).filter(ApiToken.owner == username).delete() - if removed: - logger.info(f"Revoked {removed} API token(s) owned by deleted user '{username}'") - except Exception: - logger.warning(f"Failed to revoke API tokens for deleted user '{username}'") logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)") return True diff --git a/core/database.py b/core/database.py index ee365c30c..6eec48d11 100644 --- a/core/database.py +++ b/core/database.py @@ -688,6 +688,7 @@ def _migrate_add_last_message_at_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -713,10 +714,14 @@ def _migrate_add_last_message_at_column(): "ON sessions(archived, last_message_at)" ) conn.commit() - conn.close() logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions") except Exception as e: logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_document_archived_column(): """Add `archived` to documents (soft-archive flag). Guarded + idempotent.""" @@ -724,6 +729,7 @@ def _migrate_add_document_archived_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(documents)") @@ -732,9 +738,13 @@ def _migrate_add_document_archived_column(): conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0") conn.commit() logging.getLogger(__name__).info("Migrated: added 'archived' to documents") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_owner_column(): @@ -743,6 +753,7 @@ def _migrate_add_owner_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -752,9 +763,13 @@ def _migrate_add_owner_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_model_endpoints(): """Recreate model_endpoints table if schema changed (url->base_url).""" @@ -762,6 +777,7 @@ def _migrate_model_endpoints(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -770,9 +786,13 @@ def _migrate_model_endpoints(): conn.execute("DROP TABLE IF EXISTS model_endpoints") conn.commit() logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_hidden_models_column(): """Add hidden_models column to model_endpoints if it doesn't exist.""" @@ -780,6 +800,7 @@ def _migrate_add_hidden_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -788,9 +809,13 @@ def _migrate_add_hidden_models_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_endpoint_owner_column(): """Add owner column to model_endpoints if it doesn't exist. @@ -805,6 +830,7 @@ def _migrate_add_model_endpoint_owner_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -814,9 +840,13 @@ def _migrate_add_model_endpoint_owner_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_provider_auth_id_column(): @@ -825,6 +855,7 @@ def _migrate_add_provider_auth_id_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -834,9 +865,13 @@ def _migrate_add_provider_auth_id_column(): conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_type_column(): @@ -845,6 +880,7 @@ def _migrate_add_model_type_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -853,9 +889,13 @@ def _migrate_add_model_type_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'") conn.commit() logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_type migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_model_endpoint_refresh_columns(): """Add endpoint classification / refresh policy columns if missing.""" @@ -863,6 +903,7 @@ def _migrate_add_model_endpoint_refresh_columns(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -876,9 +917,13 @@ def _migrate_add_model_endpoint_refresh_columns(): if columns and "model_refresh_timeout" not in columns: conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_task_run_model_column(): """Add model column to task_runs if it doesn't exist (records which model ran).""" @@ -886,6 +931,7 @@ def _migrate_add_task_run_model_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(task_runs)") @@ -894,9 +940,13 @@ def _migrate_add_task_run_model_column(): conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_supports_tools_column(): """Add supports_tools column to model_endpoints if it doesn't exist.""" @@ -904,6 +954,7 @@ def _migrate_add_supports_tools_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -912,9 +963,13 @@ def _migrate_add_supports_tools_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN") conn.commit() logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_cached_models_column(): @@ -923,6 +978,7 @@ def _migrate_add_cached_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -930,9 +986,13 @@ def _migrate_add_cached_models_column(): if columns and "cached_models" not in columns: conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"cached_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_pinned_models_column(): """Add pinned_models column to model_endpoints if it doesn't exist.""" @@ -940,6 +1000,7 @@ def _migrate_add_pinned_models_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(model_endpoints)") @@ -948,9 +1009,13 @@ def _migrate_add_pinned_models_column(): conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_notes_sort_order(): """Add sort_order, image_url, repeat columns to notes if they don't exist.""" @@ -958,6 +1023,7 @@ def _migrate_add_notes_sort_order(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(notes)") @@ -975,9 +1041,13 @@ def _migrate_add_notes_sort_order(): if columns and "agent_session_id" not in columns: conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"notes migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_mode_column(): """Add mode column to sessions table if it doesn't exist.""" @@ -985,6 +1055,7 @@ def _migrate_add_mode_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -993,9 +1064,13 @@ def _migrate_add_mode_column(): conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_folder_column(): """Add folder column to sessions table if it doesn't exist.""" @@ -1003,6 +1078,7 @@ def _migrate_add_folder_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -1011,9 +1087,13 @@ def _migrate_add_folder_column(): conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT") conn.commit() logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_token_columns(): """Add cumulative token tracking columns to sessions table.""" @@ -1021,6 +1101,7 @@ def _migrate_add_token_columns(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(sessions)") @@ -1030,9 +1111,13 @@ def _migrate_add_token_columns(): conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0") conn.commit() logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_owner_to_table(table_name: str, index_name: str): """Generic helper: add owner TEXT column + index to a table if missing.""" @@ -1040,6 +1125,7 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute(f"PRAGMA table_info({table_name})") @@ -1049,9 +1135,13 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str): conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)") conn.commit() logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_multiuser_owner_columns(): """Add owner column to memories, gallery_images, user_tools, comparisons.""" @@ -1076,6 +1166,7 @@ def _migrate_add_api_token_scopes_column(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()] @@ -1084,9 +1175,13 @@ def _migrate_add_api_token_scopes_column(): conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''") conn.commit() logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_assign_legacy_owner(): """Assign all null-owner data to the first (admin) user. @@ -1128,6 +1223,7 @@ def _migrate_assign_legacy_owner(): return logger = logging.getLogger(__name__) + conn = None try: conn = sqlite3.connect(db_path) # Every table with an `owner` column. New tables added later will be @@ -1152,9 +1248,13 @@ def _migrate_assign_legacy_owner(): except Exception as e: logger.warning(f"Legacy owner assignment for {table} failed: {e}") conn.commit() - conn.close() except Exception as e: logger.warning(f"Legacy owner migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass # Also migrate memory.json mem_path = MEMORY_FILE @@ -1773,6 +1873,7 @@ def _migrate_add_email_smtp_security(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(email_accounts)") @@ -1788,9 +1889,13 @@ def _migrate_add_email_smtp_security(): ) conn.commit() logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_encrypt_endpoint_keys(): @@ -1891,6 +1996,7 @@ def _migrate_add_calendar_is_utc(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1899,9 +2005,13 @@ def _migrate_add_calendar_is_utc(): conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL") conn.commit() logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"is_utc migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_origin(): @@ -1912,6 +2022,7 @@ def _migrate_add_calendar_origin(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1921,9 +2032,13 @@ def _migrate_add_calendar_origin(): conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_account_id(): @@ -1933,6 +2048,7 @@ def _migrate_add_calendar_account_id(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendars)") @@ -1942,9 +2058,13 @@ def _migrate_add_calendar_account_id(): conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)") conn.commit() logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars") - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def _migrate_add_calendar_metadata(): @@ -1953,6 +2073,7 @@ def _migrate_add_calendar_metadata(): db_path = DATABASE_URL.replace("sqlite:///", "") if not os.path.exists(db_path): return + conn = None try: conn = sqlite3.connect(db_path) cursor = conn.execute("PRAGMA table_info(calendar_events)") @@ -1964,9 +2085,13 @@ def _migrate_add_calendar_metadata(): if columns and "last_pinged" not in columns: conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME") conn.commit() - conn.close() except Exception as e: logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}") + finally: + try: + conn.close() + except Exception: + pass def get_db(): """ diff --git a/core/platform_compat.py b/core/platform_compat.py index 3eda4a107..b3b157111 100644 --- a/core/platform_compat.py +++ b/core/platform_compat.py @@ -366,6 +366,10 @@ def _ssh_exec_argv( strict_host_key_checking: bool | None = None, ) -> list[str]: """Build a consistent ssh argv for remote command execution.""" + remote_value = str(remote or "").strip() + remote_host = remote_value.rsplit("@", 1)[-1] + if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"): + raise ValueError("Invalid SSH remote host") argv = ["ssh"] if connect_timeout is not None: argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"]) diff --git a/requirements-optional.txt b/requirements-optional.txt index eeb57c151..b4b654232 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -15,7 +15,7 @@ faster-whisper # DuckDuckGo as a search provider option. # Install if you want DDG in the search-provider dropdown. # Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE. -duckduckgo-search +ddgs # PDF form-filling feature (fillable AcroForm detection, field extraction, # value/annotation/signature stamping, page rendering for the form overlay). diff --git a/routes/_validators.py b/routes/_validators.py new file mode 100644 index 000000000..aa4cf00cc --- /dev/null +++ b/routes/_validators.py @@ -0,0 +1,31 @@ +import re + +from fastapi import HTTPException + + +_REMOTE_HOST_RE = re.compile( + r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$" +) +_SSH_PORT_RE = re.compile(r"^\d{1,5}$") + + +def validate_remote_host(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _REMOTE_HOST_RE.match(v): + raise HTTPException( + 400, + "Invalid remote_host — must be host or user@host, no SSH option syntax", + ) + return v + + +def validate_ssh_port(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _SSH_PORT_RE.fullmatch(str(v)): + raise HTTPException(400, "Invalid ssh_port") + port = int(v) + if port < 1 or port > 65535: + raise HTTPException(400, "Invalid ssh_port") + return str(port) diff --git a/routes/auth_routes.py b/routes/auth_routes.py index c20860892..b9158c93a 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -305,6 +305,19 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: if not ok: raise HTTPException(400, "Cannot rename user") + def _rollback_auth_rename() -> bool: + # On self-rename the admin session has already moved to the new + # username, so the rollback must authenticate as the new user. + rollback_user = new_username if user == old_username else user + try: + return bool(auth_manager.rename_user(new_username, old_username, rollback_user)) + except Exception as rollback_err: + logger.error( + "Failed to roll back auth rename %s -> %s after owner migration failure: %s", + new_username, old_username, rollback_err, + ) + return False + # Usernames are ownership keys for user data. Rename the common # owner-scoped DB rows so the account keeps access to its sessions, # docs, email accounts, tasks, etc. @@ -330,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: db.close() except Exception as e: logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e) + if not _rollback_auth_rename(): + logger.error( + "Auth rename %s -> %s could not be rolled back after owner migration failure", + old_username, new_username, + ) raise HTTPException(500, "Failed to rename user data") # Per-user prefs are JSON-backed, not SQL-backed. @@ -349,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: except Exception as e: logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e) + # In-flight deep-research tasks live in the process-local + # ResearchHandler registry. They are not covered by the persisted JSON + # migration above, but the research routes filter and cancel by this + # owner field while the job is running. Do this before sweeping + # completed JSON files so a job that finishes during the rename saves + # with the new owner or is caught by the disk sweep below. + try: + rh = getattr(request.app.state, "research_handler", None) + rename_owner = getattr(rh, "rename_owner", None) + if callable(rename_owner): + rename_owner(old_username, new_username) + except Exception as e: + logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e) + # deep_research: each completed report is a standalone JSON file with # an `owner` field. research_routes filters by d.get("owner") == user, # so a stale owner makes every report invisible to the renamed user. @@ -391,7 +423,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: skills_root = Path(SKILLS_DIR) if skills_root.is_dir(): _owner_re = re.compile( - r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$' + r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$', + re.IGNORECASE, ) for p in skills_root.rglob("SKILL.md"): try: @@ -406,12 +439,12 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: try: usage = json.loads(usage_path.read_text(encoding="utf-8")) if isinstance(usage, dict): - prefix = old_username + "::" new_usage = {} changed = False for k, v in usage.items(): - if k.startswith(prefix): - new_usage[new_username + "::" + k[len(prefix):]] = v + owner_part, sep, skill_part = k.partition("::") + if sep and owner_part.lower() == old_username: + new_usage[new_username + "::" + skill_part] = v changed = True else: new_usage[k] = v @@ -473,7 +506,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: user = _get_current_user(request) if not user or not auth_manager.is_admin(user): raise HTTPException(403, "Admin only") - ok = auth_manager.delete_user(body.username, user) + + def _invalidate_api_token_cache(): + try: + invalidator = getattr(request.app.state, "invalidate_token_cache", None) + if invalidator: + invalidator() + except Exception: + pass + + try: + ok = auth_manager.delete_user(body.username, user) + except Exception: + # delete_user can touch ApiToken rows before a later auth-store write + # fails. Dirty the bearer cache anyway so a partial token purge does + # not leave already-cached tokens authenticating until restart. + _invalidate_api_token_cache() + raise if not ok: raise HTTPException(400, "Cannot delete user") # delete_user removes the user's ApiToken rows, but the bearer-auth @@ -481,12 +530,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter: # rebuilds when flagged dirty. Without this, a deleted user's already # cached token keeps authenticating until some other token op or a # restart clears the cache. Mirror what the token routes do. - try: - invalidator = getattr(request.app.state, "invalidate_token_cache", None) - if invalidator: - invalidator() - except Exception: - pass + _invalidate_api_token_cache() return {"ok": True} # ---- Feature visibility (admin-managed) ---- diff --git a/routes/contacts_routes.py b/routes/contacts_routes.py index e4e8ce759..58a57a1e1 100644 --- a/routes/contacts_routes.py +++ b/routes/contacts_routes.py @@ -729,8 +729,11 @@ def setup_contacts_routes(): @router.post("/import") async def import_vcf(data: dict, _admin: str = Depends(require_admin)): """Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}.""" - text = data.get("vcf") or data.get("text") or "" - csv_text = data.get("csv") or "" + # Coerce defensively: a non-string vcf/text/csv (e.g. a number or list + # in the JSON body) would otherwise reach .strip() and 500 with an + # AttributeError instead of degrading to a clean "no data" response. + text = str(data.get("vcf") or data.get("text") or "") + csv_text = str(data.get("csv") or "") if text.strip(): if "BEGIN:VCARD" not in text.upper(): return {"success": False, "error": "No vCard data found"} diff --git a/routes/cookbook_helpers.py b/routes/cookbook_helpers.py index 709245287..53bdde80e 100644 --- a/routes/cookbook_helpers.py +++ b/routes/cookbook_helpers.py @@ -11,6 +11,7 @@ import shlex from fastapi import HTTPException from pydantic import BaseModel +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import _ssh_exec_argv logger = logging.getLogger(__name__) @@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") _OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$") # Include pattern is a glob: allow typical safe glyphs only. _INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$") -# Remote host: either `user@host` or plain `host` (alias is allowed), where host -# is a safe DNS-like token or a short SSH config alias. -_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$") # HF tokens and API tokens are url-safe base64-like. _TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$") # Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef". # Anything beyond plain alphanumerics + dash + underscore could break out # of the shell/PowerShell contexts the value lands in. _SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") -_SSH_PORT_RE = re.compile(r"^\d{1,5}$") _GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$") # A download target directory. Absolute or ~-relative path; safe path glyphs # only (no quotes or shell metacharacters). Spaces are allowed because command @@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None: return v -def _validate_remote_host(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _REMOTE_HOST_RE.match(v): - raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax") - return v - - def _validate_token(v: str | None) -> str | None: if v is None or v == "": return None @@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None: return v -def _validate_ssh_port(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _SSH_PORT_RE.fullmatch(str(v)): - raise HTTPException(400, "Invalid ssh_port") - port = int(v) - if port < 1 or port > 65535: - raise HTTPException(400, "Invalid ssh_port") - return str(port) - - def _validate_gpus(v: str | None) -> str | None: if v is None or v == "": return None diff --git a/routes/cookbook_routes.py b/routes/cookbook_routes.py index 4a4764232..36f98aeae 100644 --- a/routes/cookbook_routes.py +++ b/routes/cookbook_routes.py @@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE from pydantic import BaseModel from core.middleware import require_admin +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import ( IS_WINDOWS, detached_popen_kwargs, @@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR logger = logging.getLogger(__name__) from routes.cookbook_helpers import ( - _SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE, - _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token, - _validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path, + _SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token, + _validate_local_dir, _validate_gpus, _shell_path, _ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase, _safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines, _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, @@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter: else: _validate_repo_id(req.repo_id) _validate_include(req.include) - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.local_dir = _validate_local_dir(req.local_dir) req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token()) _validate_token(req.hf_token) @@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter: # Validate shell-bound inputs, matching the sibling list_gpus endpoint — # `host`/`ssh_port` are interpolated into an ssh command below, so an # unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection. - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True) model_dirs = [] @@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter: # listening" check without requiring ss/netstat/nmap. ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if ssh_port and str(ssh_port) != "22": - if not _SSH_PORT_RE.match(str(ssh_port)): + try: + ssh_port = validate_ssh_port(ssh_port) + except HTTPException: return None ssh_base.extend(["-p", str(ssh_port)]) - host_arg = remote - if not _REMOTE_HOST_RE.match(host_arg): + try: + host_arg = validate_remote_host(remote) + except HTTPException: + return None + if not host_arg: return None probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1)) script = ( @@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter: """ require_admin(request) # Defence-in-depth: reject values that could break out of shell contexts. - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.gpus = _validate_gpus(req.gpus) req.hf_token = req.hf_token or _load_stored_hf_token() _validate_token(req.hf_token) @@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter: async def server_setup(request: Request, req: SetupRequest): """Install required dependencies on a remote server via SSH.""" require_admin(request) - host = _validate_remote_host(req.host) + host = validate_remote_host(req.host) if not host: raise HTTPException(400, "host is required") port = req.ssh_port - if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port): - raise HTTPException(400, "Invalid ssh_port") + port = validate_ssh_port(port) pf = f"-p {port} " if port and port != "22" else "" # Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux @@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter: `busy` is True when free_mb/total_mb < 0.5. """ require_admin(request) - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits" nvidia_error = None try: @@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter: sig = (req.signal or "TERM").upper() if sig not in ("TERM", "KILL", "INT"): raise HTTPException(400, "signal must be TERM, KILL, or INT") - host = _validate_remote_host(req.host) - if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(req.host) + req.ssh_port = validate_ssh_port(req.ssh_port) kill_cmd = f"kill -{sig} {req.pid}" try: if host: @@ -2382,14 +2383,19 @@ def setup_cookbook_routes() -> APIRouter: host = (srv.get("host") or "").strip() if not host: continue # local-only entry; the /proc scan handles it - if not _REMOTE_HOST_RE.match(host): + try: + host = validate_remote_host(host) + except HTTPException: continue sport = str(srv.get("port") or "").strip() ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if sport and sport != "22": - if not _SSH_PORT_RE.match(sport): + try: + sport = validate_ssh_port(sport) + except HTTPException: continue - ssh_base.extend(["-p", sport]) + if sport != "22": + ssh_base.extend(["-p", sport]) try: ls = subprocess.run( @@ -2743,12 +2749,18 @@ def setup_cookbook_routes() -> APIRouter: if not _SESSION_ID_RE.match(session_id): logger.warning(f"Skipping task with unsafe session_id: {session_id!r}") continue - if remote and not _REMOTE_HOST_RE.match(remote): - logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") - continue - if _tport and not _SSH_PORT_RE.match(str(_tport)): - logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") - continue + if remote: + try: + remote = validate_remote_host(remote) + except HTTPException: + logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") + continue + if _tport: + try: + _tport = validate_ssh_port(str(_tport)) + except HTTPException: + logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") + continue if task_platform == "windows" and remote: # Windows: check PID file + Get-Process, read log tail sd = "$env:TEMP\\odysseus-sessions" diff --git a/routes/hwfit_routes.py b/routes/hwfit_routes.py index eb408ac9d..45c209b0b 100644 --- a/routes/hwfit_routes.py +++ b/routes/hwfit_routes.py @@ -1,7 +1,9 @@ import re from copy import deepcopy -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port # Backends the manual hardware simulator accepts. Must stay a subset of what @@ -11,6 +13,14 @@ from fastapi import APIRouter _MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"} +def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]: + host_value = validate_remote_host(host) or "" + port_value = validate_ssh_port(ssh_port) or "" + if port_value and not host_value: + raise HTTPException(400, "ssh_port requires host") + return host_value, port_value + + def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""): """Manual hardware is a "what if I had this setup" simulator — REPLACES the detected hardware entirely instead of adding to it. @@ -105,6 +115,7 @@ def setup_hwfit_routes(): """Detect and return current system hardware info. Pass host=user@server for remote. fresh=true bypasses the per-host cache (the Rescan button).""" from services.hwfit.hardware import detect_system + host, ssh_port = _validate_detection_target(host, ssh_port) return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) @router.get("/models") @@ -118,6 +129,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.fit import rank_models from services.hwfit.models import get_models, model_catalog_path + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} @@ -165,8 +177,14 @@ def setup_hwfit_routes(): system["gpu_name"] = g["name"] system["active_group"] = {**g, "use_count": n} - if gpu_count != "": - n = int(gpu_count) + # Parse the optional count defensively (matches the gpu_group guard + # above): a non-numeric query param previously raised ValueError -> + # HTTP 500. A malformed value is ignored, same as omitting it. + try: + n = int(gpu_count) if gpu_count != "" else None + except ValueError: + n = None + if n is not None: if n == 0: # RAM-only mode: rank against system memory, offload allowed. system["has_gpu"] = False @@ -229,6 +247,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.models import get_models from services.hwfit.profiles import compute_serve_profiles + host, ssh_port = _validate_detection_target(host, ssh_port) system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) if system.get("error"): return {"system": system, "profiles": [], "error": system["error"]} @@ -279,6 +298,7 @@ def setup_hwfit_routes(): """Rank image generation models against detected hardware.""" from services.hwfit.hardware import detect_system from services.hwfit.image_models import rank_image_models + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} diff --git a/routes/session_routes.py b/routes/session_routes.py index 811a40bbe..1fb2a487a 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -11,7 +11,7 @@ from core.session_manager import SessionManager from core.models import ChatMessage from src.request_models import SessionResponse from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive -from src.auth_helpers import get_current_user, effective_user, _auth_disabled +from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter from src.session_actions import is_session_recently_active @@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ last_msg_map = {} mode_map = {} msg_count_map = {} - rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all() + q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False) + q = owner_filter(q, DbSession, user) + rows = q.all() for row in rows: folder_map[row.id] = row.folder token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0) @@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ # Sessions with active documents that have content from sqlalchemy import func doc_session_ids = set( - r[0] for r in db.query(Document.session_id) - .filter(Document.is_active == True, - Document.current_content != None, - func.trim(Document.current_content) != "", - Document.owner == user) + r[0] for r in owner_filter( + db.query(Document.session_id) + .filter(Document.is_active == True, + Document.current_content != None, + func.trim(Document.current_content) != ""), + Document, user) .distinct().all() ) img_session_ids = set( - r[0] for r in db.query(GalleryImage.session_id) - .filter(GalleryImage.session_id != None, - GalleryImage.owner == user) + r[0] for r in owner_filter( + db.query(GalleryImage.session_id) + .filter(GalleryImage.session_id != None), + GalleryImage, user) .distinct().all() ) finally: diff --git a/services/search/providers.py b/services/search/providers.py index 1f8097ad8..b913e1c6f 100644 --- a/services/search/providers.py +++ b/services/search/providers.py @@ -417,7 +417,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti return [] try: - from duckduckgo_search import DDGS + from ddgs import DDGS except ImportError: logger.warning("duckduckgo-search package not installed; using HTML fallback") return _html_fallback() diff --git a/src/builtin_actions.py b/src/builtin_actions.py index b48ed94fa..1ea7cd8a4 100644 --- a/src/builtin_actions.py +++ b/src/builtin_actions.py @@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple: return etype, None +def _memory_context_lines(mems, limit: int = 40) -> list: + """Render Memory rows into short personal-context bullets for event classify. + + Reads the Memory ORM `text` column. The previous inline code read a + non-existent `content` attribute, so it raised AttributeError on the first + row, the surrounding except swallowed it, and the classifier ran with no + personal context at all. getattr keeps it robust to future schema drift. + """ + lines: list = [] + for m in mems: + c = (getattr(m, "text", "") or "").strip() + if c: + lines.append(f"- {c[:200]}") + if len(lines) >= limit: + break + return lines + + async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]: """Hybrid classification of upcoming calendar events: fast heuristic for obvious cases, LLM fallback for ambiguous ones. Assigns event_type + @@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]: try: from core.database import Memory as _Mem _mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else [] - if _mems: - _lines = [] - for m in _mems: - c = (m.content or "").strip() - if c: - _lines.append(f"- {c[:200]}") - if _lines: - _memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n" + _lines = _memory_context_lines(_mems) + if _lines: + _memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n" except Exception as _me: - logger.debug(f"Could not load memory for classify: {_me}") + logger.warning(f"Could not load memory for classify: {_me}") classified_h = 0 classified_llm = 0 diff --git a/src/model_discovery.py b/src/model_discovery.py index 68b402d25..506fcb6c4 100644 --- a/src/model_discovery.py +++ b/src/model_discovery.py @@ -223,6 +223,25 @@ class ModelDiscovery: ) return {"hosts": hosts, "items": items} + def warmup_ping_urls(self, limit: int = 5) -> List[str]: + """The ``/models`` URLs of up to ``limit`` discovered endpoints. + + Used by the startup warmup / keepalive loop to prime connections. Each + discovered item already carries a ``/v1/chat/completions`` url; swap the + suffix for the cheap ``/models`` probe. Failures degrade to an empty list + so warmup never crashes the caller. + """ + try: + items = (self.discover_models() or {}).get("items", []) + except Exception: + return [] + urls: List[str] = [] + for ep in items[:limit]: + url = (ep.get("url") or "").replace("/chat/completions", "/models") + if url: + urls.append(url) + return urls + def get_providers(self) -> Dict[str, Any]: """Get all available providers""" discovery = self.discover_models() diff --git a/src/research_handler.py b/src/research_handler.py index b996f089f..f1d120ef2 100644 --- a/src/research_handler.py +++ b/src/research_handler.py @@ -221,6 +221,22 @@ class ResearchHandler: # Task registry — background research with persistence # ------------------------------------------------------------------ + def rename_owner(self, old_owner: str, new_owner: str) -> int: + """Move in-flight research tasks from one owner key to another.""" + old_key = str(old_owner or "").strip().lower() + new_key = str(new_owner or "").strip().lower() + if not old_key or not new_key: + return 0 + + changed = 0 + for entry in list(self._active_tasks.values()): + if not isinstance(entry, dict): + continue + if str(entry.get("owner", "")).strip().lower() == old_key: + entry["owner"] = new_key + changed += 1 + return changed + def start_research( self, session_id: str, @@ -390,7 +406,6 @@ class ResearchHandler: def get_status(self, session_id: str) -> Optional[dict]: """Get current research status for a session.""" - avg = self.get_avg_duration() if session_id in self._active_tasks: entry = self._active_tasks[session_id] result = { @@ -399,6 +414,14 @@ class ResearchHandler: "query": entry["query"], "started_at": entry["started_at"], } + # avg_duration is a historical figure over completed reports on + # disk; get_avg_duration() globs and JSON-parses the whole research + # dir, so compute it at most once per active stream (memoized on the + # entry) instead of on every ~1s SSE poll. The disk branch below + # never used it, so it no longer pays that cost at all. + if "_avg_duration" not in entry: + entry["_avg_duration"] = self.get_avg_duration() + avg = entry["_avg_duration"] if avg is not None: result["avg_duration"] = round(avg, 1) return result diff --git a/src/tool_implementations.py b/src/tool_implementations.py index 494795037..27c05f139 100644 --- a/src/tool_implementations.py +++ b/src/tool_implementations.py @@ -1453,6 +1453,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict: except ValueError: return {"error": "Invalid JSON arguments", "exit_code": 1} + # ── Batch normalization ── + # Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]} + # instead of individual create_event calls. Iterate and create each. + if isinstance(args.get("events"), list) and not args.get("action"): + results = [] + for ev in args["events"]: + if not isinstance(ev, dict): + continue + # Normalize start/end from {dateTime: "..."} object to flat string + for field, target in [("start", "dtstart"), ("end", "dtend")]: + val = ev.pop(field, None) + if val and target not in ev: + ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val + ev.setdefault("action", "create_event") + r = await do_manage_calendar(json.dumps(ev), owner=owner) + results.append(r) + created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")] + failed = [r for r in results if r.get("error")] + + if not results: + return {"error": "No events to create", "exit_code": 1} + + # Surface both successes and failures + parts = [] + if created: + summaries = [r.get("response", "") for r in created] + parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries)) + if failed: + first_error = failed[0].get("error", "Unknown error") + parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}") + + response = "\n\n".join(parts) + # Non-zero exit code for partial or total failure + exit_code = 0 if not failed else 1 + return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)} + # Normalize action — some models emit hyphens ("list-calendars") instead # of underscores. Treat them as equivalent so we don't bounce a # cosmetic typo back to the model and waste a round-trip. Also accept diff --git a/src/tool_security.py b/src/tool_security.py index 82d2c3d67..6b7bc90df 100644 --- a/src/tool_security.py +++ b/src/tool_security.py @@ -162,13 +162,26 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool: def owner_is_admin_or_single_user(owner: Optional[str]) -> bool: - """Return True for admins, or when auth is not configured yet.""" + """Return True for admins, or in intentional single-user mode. + + Single-user mode means the operator explicitly disabled auth + (``AUTH_ENABLED=false``) — the local/self-host default where the owner has + full access to their own box. + + The pre-setup window (auth ENABLED but no admin created yet) is treated as + NON-admin: returning True there would hand server-execution tools + (``bash``/``python``) to any caller before setup completes. The auth + middleware already 401s ``/api/`` requests pre-setup, so this is + defense-in-depth for callers that bypass it (e.g. trusted loopback). + """ try: from core.auth import AuthManager auth = AuthManager() if not auth.is_configured: - return True + from src.auth_helpers import _auth_disabled + + return _auth_disabled() return bool(owner and auth.is_admin(owner)) except Exception as exc: logger.warning("Unable to evaluate owner admin status: %s", exc) diff --git a/static/js/chat.js b/static/js/chat.js index 60149d005..7ecefdb7d 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -1082,7 +1082,7 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer let _lastToolName = ''; const _searchIcon = ''; const _toolLabels = { - 'web_search': _searchIcon + 'Searching', + 'web_search': 'Searching', 'bash': 'Running', 'python': 'Running', 'create_document': 'Writing', @@ -1102,6 +1102,9 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer 'list_models': 'Browsing', 'ui_control': 'Adjusting', }; + const _toolIcons = { + 'web_search': _searchIcon, + }; function _thinkingLabel() { if (!_lastToolName) { return 'Thinking'; @@ -2049,10 +2052,11 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer } threadWrap.classList.add('streaming'); const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool; + const toolIcon = _toolIcons[json.tool.toLowerCase()] || '\u25B6'; const node = document.createElement('div') node.className = 'agent-thread-node running'; const cmdHtml = cmd ? `
${esc(cmd)}
` : ''; - node.innerHTML = `
\u25B6${esc(toolLabel)}▁▂▃
${cmdHtml}
`; + node.innerHTML = `
${toolIcon}${esc(toolLabel)}▁▂▃
${cmdHtml}
`; // Expand/collapse via delegated click handler (init at module bottom). threadWrap.appendChild(node); currentToolBubble = node; diff --git a/static/js/chatRenderer.js b/static/js/chatRenderer.js index 9a5c6f78b..7c6ecd096 100644 --- a/static/js/chatRenderer.js +++ b/static/js/chatRenderer.js @@ -862,6 +862,20 @@ export function stripToolBlocks(text) { return cleaned.trim(); } +/** + * Plain-text payload for the message copy buttons: the reply as the renderer + * displays it — tool blocks and reasoning stripped. dataset.raw keeps + * the full model output (chat.js even embeds the elapsed time into the + * tag for reload persistence), so copying it verbatim leaks the + * thinking block (#3722). Falls back to the raw text when stripping leaves + * nothing (e.g. turns interrupted mid-thinking). + */ +export function copyMessageText(msgElement) { + const raw = msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || ''; + const { content } = markdownModule.extractThinkingBlocks(stripToolBlocks(raw)); + return content || raw; +} + /** * Build a collapsible sources box (used by both research and web search). */ @@ -1372,7 +1386,7 @@ export function createMsgFooter(msgElement) { { id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) { e.stopPropagation(); const btn = e.currentTarget; - uiModule.copyToClipboard(msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || ''); + uiModule.copyToClipboard(copyMessageText(msgElement)); btn.innerHTML = CHECK_ICON; setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500); }}, @@ -2444,6 +2458,7 @@ const chatRenderer = { updateSessionCostUI, roleTimestamp, stripToolBlocks, + copyMessageText, safeToolScreenshotSrc, safeDisplayImageSrc, buildSourcesBox, diff --git a/static/js/slashCommands.js b/static/js/slashCommands.js index 6a32cb89e..79b037cf4 100644 --- a/static/js/slashCommands.js +++ b/static/js/slashCommands.js @@ -380,7 +380,7 @@ function _slashFooter(msgEl) { copyBtn.innerHTML = _copySvg; copyBtn.onclick = (e) => { e.stopPropagation(); - uiModule.copyToClipboard(msgEl.dataset.raw || msgEl.querySelector('.body')?.textContent || ''); + uiModule.copyToClipboard(chatRenderer.copyMessageText(msgEl)); copyBtn.innerHTML = _checkSvg; setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500); }; diff --git a/tests/test_auth_config_lock_concurrency.py b/tests/test_auth_config_lock_concurrency.py index f5cc8a18c..34232b9e2 100644 --- a/tests/test_auth_config_lock_concurrency.py +++ b/tests/test_auth_config_lock_concurrency.py @@ -8,6 +8,9 @@ with missing users or assertion errors. import json import threading import time +import contextlib +import sys +import types from concurrent.futures import ThreadPoolExecutor, as_completed import pytest @@ -15,6 +18,41 @@ import pytest from tests.helpers.import_state import clear_module +class _OwnerColumn: + def __eq__(self, other): + return ("owner ==", other) + + +class _FakeApiToken: + owner = _OwnerColumn() + + +class _FakeQuery: + def filter(self, *_conds): + return self + + def delete(self, *args, **kwargs): + return 0 + + +class _FakeSession: + def query(self, model): + assert model is _FakeApiToken + return _FakeQuery() + + +@pytest.fixture(autouse=True) +def _stub_api_token_purge(monkeypatch): + @contextlib.contextmanager + def _fake_db_session(): + yield _FakeSession() + + db_stub = types.ModuleType("core.database") + db_stub.get_db_session = _fake_db_session + db_stub.ApiToken = _FakeApiToken + monkeypatch.setitem(sys.modules, "core.database", db_stub) + + def _fresh_auth_manager(tmp_path): clear_module("core.auth") from core.auth import AuthManager diff --git a/tests/test_calendar_batch_events.py b/tests/test_calendar_batch_events.py new file mode 100644 index 000000000..d8176afcd --- /dev/null +++ b/tests/test_calendar_batch_events.py @@ -0,0 +1,125 @@ +"""Test that do_manage_calendar handles the batch {"events": [...]} format +that models like deepseek-v4-flash emit instead of individual create_event calls. +""" + +import json +import sys +import uuid + +import pytest + +from tests.helpers.import_state import clear_fake_database_modules +from tests.helpers.sqlite_db import make_temp_sqlite + +clear_fake_database_modules() + +import core.database as cdb +from core.database import CalendarEvent + +_TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata) + + +@pytest.fixture(autouse=True) +def _bind_temp_db(monkeypatch): + monkeypatch.setitem(sys.modules, "core.database", cdb) + parent = sys.modules.get("core") + if parent is not None: + monkeypatch.setattr(parent, "database", cdb, raising=False) + monkeypatch.setattr(cdb, "SessionLocal", _TS) + yield + + +async def test_batch_events_with_datetime_objects(): + """Model emits {"events": [{"summary": ..., "start": {"dateTime": ...}, "end": {"dateTime": ...}}]}.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Morning Gym", + "start": {"dateTime": "2026-06-09T06:00:00+05:30"}, + "end": {"dateTime": "2026-06-09T07:00:00+05:30"}, + }, + { + "summary": "Morning Gym", + "start": {"dateTime": "2026-06-10T06:00:00+05:30"}, + "end": {"dateTime": "2026-06-10T07:00:00+05:30"}, + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + assert res.get("exit_code") == 0, res + assert "Created 2 event(s)" in res.get("response", "") + + # Verify events exist in DB + db = _TS() + events = db.query(CalendarEvent).filter(CalendarEvent.summary == "Morning Gym").all() + assert len(events) == 2 + db.close() + + +async def test_batch_events_with_flat_strings(): + """Model emits {"events": [{"summary": ..., "start": "ISO", "end": "ISO"}]}.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Standup", + "start": "2026-06-09T09:00:00", + "end": "2026-06-09T09:30:00", + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + assert res.get("exit_code") == 0, res + assert "Created 1 event(s)" in res.get("response", "") + + +async def test_batch_events_partial_failure(): + """Batch with some valid and some invalid events — should surface both counts and first error.""" + from src.tool_implementations import do_manage_calendar + + owner = "tester-" + uuid.uuid4().hex[:6] + payload = { + "events": [ + { + "summary": "Valid Event 1", + "start": "2026-06-09T10:00:00", + "end": "2026-06-09T11:00:00", + }, + { + "summary": "Invalid Event", + # Missing required dtstart — will fail + }, + { + "summary": "Valid Event 2", + "start": "2026-06-09T14:00:00", + "end": "2026-06-09T15:00:00", + }, + ] + } + res = await do_manage_calendar(json.dumps(payload), owner=owner) + + # Partial failure = non-zero exit code + assert res.get("exit_code") != 0, "Partial failure should return non-zero exit code" + + # Response should mention both created and failed counts + response = res.get("response", "") + assert "Created 2 event(s)" in response, f"Should report 2 created: {response}" + assert "Failed to create 1 event(s)" in response, f"Should report 1 failed: {response}" + assert "error" in response.lower() or "required" in response.lower(), "Should include error details" + + # Metadata fields + assert res.get("created_count") == 2 + assert res.get("failed_count") == 1 + + # Verify only valid events were created + db = _TS() + events = db.query(CalendarEvent).filter( + CalendarEvent.summary.in_(["Valid Event 1", "Valid Event 2"]) + ).all() + assert len(events) == 2 + db.close() diff --git a/tests/test_classify_events_memory_text.py b/tests/test_classify_events_memory_text.py new file mode 100644 index 000000000..328929115 --- /dev/null +++ b/tests/test_classify_events_memory_text.py @@ -0,0 +1,33 @@ +"""classify_events must read the Memory `text` column, not a non-existent +`content` attribute. + +The previous inline loop did `m.content`, which raised AttributeError on the +first Memory row; the surrounding except swallowed it, so the personal-context +block the LLM relies on was always empty. The logic now lives in +`_memory_context_lines`, which reads `text`. +""" +from src.builtin_actions import _memory_context_lines + + +class _Mem: + def __init__(self, text): + self.text = text + + +def test_uses_text_and_truncates_and_skips_blank(): + lines = _memory_context_lines([_Mem("Alice is my spouse"), _Mem(" "), _Mem("y" * 250)]) + assert lines[0] == "- Alice is my spouse" + assert len(lines) == 2 # the blank row is skipped + assert lines[1] == "- " + "y" * 200 # truncated to 200 chars + + +def test_skips_rows_without_text_attribute(): + class _Bad: # mimics a schema where the attribute is absent + pass + + assert _memory_context_lines([_Bad(), _Mem("ok")]) == ["- ok"] + + +def test_respects_limit(): + mems = [_Mem(f"memory {i}") for i in range(50)] + assert len(_memory_context_lines(mems, limit=40)) == 40 diff --git a/tests/test_contacts_import_nonstring.py b/tests/test_contacts_import_nonstring.py new file mode 100644 index 000000000..c029b569d --- /dev/null +++ b/tests/test_contacts_import_nonstring.py @@ -0,0 +1,39 @@ +"""POST /api/contacts/import must not 500 on a non-string vcf/text/csv value. + +`text = data.get("vcf") or ... or ""` left a non-string value (e.g. a number) +in place, so the next `text.strip()` raised AttributeError -> HTTP 500. The +handler now coerces with str() and degrades to a structured "no data" response. +""" +import asyncio + +from routes.contacts_routes import setup_contacts_routes + + +def _import_handler(): + router = setup_contacts_routes() + for route in router.routes: + if getattr(route, "path", "").endswith("/import") and "POST" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("import route not found") + + +def _call(data): + handler = _import_handler() + return asyncio.run(handler(data=data, _admin="admin")) + + +def test_non_string_vcf_degrades_cleanly(): + resp = _call({"vcf": 123}) + assert resp["success"] is False + assert "error" in resp + + +def test_non_string_csv_degrades_cleanly(): + resp = _call({"csv": ["a", "b"]}) + assert resp["success"] is False + + +def test_empty_body_reports_no_data(): + resp = _call({}) + assert resp["success"] is False + assert resp["error"] == "No contact data found" diff --git a/tests/test_cookbook_helpers.py b/tests/test_cookbook_helpers.py index acc001812..779b48e3c 100644 --- a/tests/test_cookbook_helpers.py +++ b/tests/test_cookbook_helpers.py @@ -26,7 +26,6 @@ from routes.cookbook_helpers import ( _validate_repo_id, _validate_serve_cmd, _validate_serve_model_id, - _validate_ssh_port, _shell_path, run_ssh_command_async, ) @@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path(): ) -def test_validate_ssh_port_rejects_shell_payload(): - with pytest.raises(HTTPException): - _validate_ssh_port("22; touch /tmp/pwned") - assert _validate_ssh_port("2222") == "2222" - - def test_validate_local_dir_accepts_external_drive_paths_with_spaces(): path = "/Volumes/T7 2TB/AI Models/llamacpp" diff --git a/tests/test_copy_message_strips_thinking_js.py b/tests/test_copy_message_strips_thinking_js.py new file mode 100644 index 000000000..4c88bb6d4 --- /dev/null +++ b/tests/test_copy_message_strips_thinking_js.py @@ -0,0 +1,160 @@ +"""Regression coverage for issue #3722 — the message copy button copied the +full raw model output (``dataset.raw``), which still contains the +``...`` reasoning block that the renderer strips for +display. Pasting therefore leaked the model's thinking, and the first heading +after ```` lost its markdown formatting because it was glued to the +closing tag. + +The fix adds chatRenderer.copyMessageText(), which mirrors the display +pipeline (``stripToolBlocks()`` then ``extractThinkingBlocks()``), and routes +both AI-message copy buttons (createMsgFooter and the slash-reply footer) +through it. extractThinkingBlocks() behavior is pinned here under node +(including on the payload from the issue report); the helper and handler +wiring are guarded at the source level because chatRenderer.js pulls in +browser globals and can't be imported under node (same approach as +test_new_chat_clears_input.py). +""" + +import json +import re +import shutil +import subprocess +import textwrap +from pathlib import Path + +import pytest + +_REPO = Path(__file__).resolve().parent.parent +_HAS_NODE = shutil.which("node") is not None + + +@pytest.fixture(scope="module") +def node_available(): + if not _HAS_NODE: + pytest.skip("node binary not on PATH") + + +def _extract_thinking_blocks(text: str) -> dict: + """Run markdown.js extractThinkingBlocks(text) under node.""" + script = textwrap.dedent( + r""" + import fs from 'node:fs'; + + globalThis.window = { location: { origin: 'http://localhost' }, katex: null }; + globalThis.document = { + readyState: 'loading', + addEventListener() {}, + createElement(tag) { + if (tag !== 'template') throw new Error(`unsupported element: ${tag}`); + return { + _html: '', + content: { querySelectorAll() { return []; } }, + set innerHTML(value) { this._html = value; }, + get innerHTML() { return this._html; }, + }; + }, + }; + globalThis.MutationObserver = class { observe() {} }; + + let source = fs.readFileSync('./static/js/markdown.js', 'utf8'); + source = source.replace( + /import uiModule from ['"]\.\/ui\.js['"];/, + '' + ); + source = source.replace( + /import \{ splitTableRow \} from ['"]\.\/markdown\/tableRow\.js['"];/, + `function splitTableRow(row) { + return (row || '').replace(/^\\s*\\|/, '').replace(/\\|\\s*$/, '').split('|').map(c => c.trim()); + }` + ); + const emojiSource = fs.readFileSync('./static/js/emojiShortcodes.js', 'utf8') + .replace(/^export default .*$/m, '') + .replace(/export const /g, 'const ') + .replace(/export function /g, 'function '); + source = source.replace( + /import \{ replaceEmojiShortcodes, hasEmojiShortcode \} from ['"]\.\/emojiShortcodes\.js['"];/, + () => emojiSource + ); + source = source.replace( + /var escapeHtml = uiModule\.esc;/, + `var escapeHtml = (value) => String(value ?? '') + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, ''');` + ); + + const moduleUrl = 'data:text/javascript;base64,' + Buffer.from(source).toString('base64'); + const mod = await import(moduleUrl); + const input = JSON.parse(process.argv[1]); + console.log(JSON.stringify({ out: mod.extractThinkingBlocks(input) })); + """ + ) + result = subprocess.run( + ["node", "--input-type=module", "-e", script, json.dumps(text)], + cwd=_REPO, + capture_output=True, + timeout=15, + text=True, + ) + if result.returncode != 0: + raise AssertionError(f"node failed:\nSTDERR:\n{result.stderr}\nSTDOUT:\n{result.stdout}") + return json.loads(result.stdout.splitlines()[-1])["out"] + + +def test_issue_payload_copy_text_excludes_thinking(node_available): + # Shape reported in #3722: timed think block glued to the reply heading. + raw = ( + '\n' + "Here's a thinking process that leads to the desired summary:\n\n" + "6. **Generate the Output.** (This matches the final provided response.)" + "### Juxtaposition: Interweaving Cultural Norms in Lesson Design\n" + "The most effective lesson structure is created by deliberately juxtaposing." + ) + out = _extract_thinking_blocks(raw) + + assert out["content"].startswith("### Juxtaposition:"), out["content"] + assert "thinking process" not in out["content"] + assert "only reasoning, no reply yet") + assert out["content"] == "" + + +def _function_body(text: str, marker: str) -> str: + start = text.index(marker) + rest = text[start + len(marker):] + m = re.search(r"\nexport function |\nfunction ", rest) + return rest[: m.start()] if m else rest + + +def test_copy_message_text_mirrors_display_pipeline(): + text = (_REPO / "static/js/chatRenderer.js").read_text(encoding="utf-8") + body = _function_body(text, "export function copyMessageText") + # Mirrors the display path: tool blocks stripped, then thinking extracted. + assert "extractThinkingBlocks" in body + assert "stripToolBlocks" in body + assert "dataset.raw" in body + + +def test_copy_handlers_route_through_copy_message_text(): + for path, count in (("static/js/chatRenderer.js", 1), ("static/js/slashCommands.js", 1)): + text = (_REPO / path).read_text(encoding="utf-8") + assert text.count("copyToClipboard(copyMessageText(") + text.count( + "copyToClipboard(chatRenderer.copyMessageText(" + ) == count, path + # The old behavior passed dataset.raw straight to the clipboard. + assert "copyToClipboard(msgElement.dataset.raw" not in text, path + assert "copyToClipboard(msgEl.dataset.raw" not in text, path diff --git a/tests/test_delete_user_invalidates_token_cache.py b/tests/test_delete_user_invalidates_token_cache.py index c9cb79a5e..91be50e93 100644 --- a/tests/test_delete_user_invalidates_token_cache.py +++ b/tests/test_delete_user_invalidates_token_cache.py @@ -36,6 +36,17 @@ def _auth_manager(delete_result): ) +def _auth_manager_raising(): + def _delete_user(_username, _requesting_user): + raise RuntimeError("auth save failed after token purge") + + return types.SimpleNamespace( + get_username_for_token=lambda token: "admin", + is_admin=lambda user: True, + delete_user=_delete_user, + ) + + def test_successful_delete_invalidates_cache(): invalidations = [] router = setup_auth_routes(_auth_manager(delete_result=True)) @@ -56,3 +67,16 @@ def test_refused_delete_does_not_invalidate_cache(): raised = True assert raised, "a refused delete should raise (HTTP 400)" assert invalidations == [], "a refused delete must not touch the token cache" + + +def test_delete_exception_invalidates_cache_for_partial_token_purge(): + invalidations = [] + router = setup_auth_routes(_auth_manager_raising()) + handler = _handler(router) + try: + asyncio.run(handler(DeleteUserRequest(username="bob"), _fake_request(invalidations))) + raised = False + except RuntimeError: + raised = True + assert raised, "delete_user exception should still propagate" + assert invalidations == [True], "partial token purge must dirty the bearer cache" diff --git a/tests/test_delete_user_revokes_api_tokens.py b/tests/test_delete_user_revokes_api_tokens.py index dab753ff0..52a7d55af 100644 --- a/tests/test_delete_user_revokes_api_tokens.py +++ b/tests/test_delete_user_revokes_api_tokens.py @@ -114,3 +114,21 @@ def test_refused_delete_leaves_tokens_alone(manager, db_calls): def test_unknown_user_leaves_tokens_alone(manager, db_calls): assert manager.delete_user("ghost", "admin") is False assert db_calls == [] + + +def test_delete_user_fails_closed_when_api_token_purge_fails(manager, monkeypatch): + token = manager.create_session("bob", "secret-bob-pw") + + @contextlib.contextmanager + def _failing_db_session(): + raise RuntimeError("database unavailable") + yield + + db_stub = types.ModuleType("core.database") + db_stub.get_db_session = _failing_db_session + db_stub.ApiToken = _FakeApiToken + monkeypatch.setitem(sys.modules, "core.database", db_stub) + + assert manager.delete_user("bob", "admin") is False + assert "bob" in manager.users + assert manager.validate_token(token) is True diff --git a/tests/test_hwfit_gpu_count_nonnumeric.py b/tests/test_hwfit_gpu_count_nonnumeric.py new file mode 100644 index 000000000..13e6b2f25 --- /dev/null +++ b/tests/test_hwfit_gpu_count_nonnumeric.py @@ -0,0 +1,38 @@ +"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count. + +The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any +non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored, +matching how the neighbouring gpu_group param is already parsed. +""" +from routes.hwfit_routes import setup_hwfit_routes + + +def _get_models(): + router = setup_hwfit_routes() + for route in router.routes: + if getattr(route, "path", "").endswith("/models") and "GET" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("hwfit /models route not found") + + +def test_non_numeric_gpu_count_does_not_raise(): + handler = _get_models() + # Previously raised ValueError (HTTP 500); now degrades to a normal ranking. + result = handler(gpu_count="abc") + assert isinstance(result, dict) + + +def test_numeric_gpu_count_still_accepted(): + handler = _get_models() + result = handler(gpu_count="0") + assert isinstance(result, dict) + + +def test_non_numeric_manual_gpu_count_does_not_raise(): + # manual_gpu_count is the other count param on this endpoint (the hardware + # simulator in _apply_manual_hardware). A non-numeric value must also degrade + # (default to 1) rather than 500, so the endpoint's count parsing is fully + # covered. + handler = _get_models() + result = handler(manual_mode="gpu", manual_gpu_count="abc") + assert isinstance(result, dict) diff --git a/tests/test_hwfit_remote_validation.py b/tests/test_hwfit_remote_validation.py new file mode 100644 index 000000000..aee2aaadb --- /dev/null +++ b/tests/test_hwfit_remote_validation.py @@ -0,0 +1,47 @@ +import pytest +from fastapi import HTTPException + +from core.platform_compat import _ssh_exec_argv +from routes.hwfit_routes import setup_hwfit_routes + + +def _endpoint(path: str): + router = setup_hwfit_routes() + for route in router.routes: + if getattr(route, "path", "") == path: + return route.endpoint + raise AssertionError(f"{path} route not found") + + +@pytest.mark.parametrize( + "path,kwargs", + [ + ("/api/hwfit/system", {}), + ("/api/hwfit/models", {"limit": 1}), + ("/api/hwfit/profiles", {"model": "demo"}), + ("/api/hwfit/image-models", {}), + ], +) +def test_hwfit_routes_reject_ssh_option_host(path, kwargs): + endpoint = _endpoint(path) + + with pytest.raises(HTTPException) as exc: + endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs) + + assert exc.value.status_code == 400 + + +def test_hwfit_routes_reject_port_without_host(): + endpoint = _endpoint("/api/hwfit/system") + + with pytest.raises(HTTPException) as exc: + endpoint(host="", ssh_port="2222") + + assert exc.value.status_code == 400 + + +def test_ssh_argv_rejects_option_shaped_remote(): + with pytest.raises(ValueError): + _ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true") + with pytest.raises(ValueError): + _ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true") diff --git a/tests/test_rename_user_owner_sync.py b/tests/test_rename_user_owner_sync.py index 16d91c512..e5e89b4dc 100644 --- a/tests/test_rename_user_owner_sync.py +++ b/tests/test_rename_user_owner_sync.py @@ -11,7 +11,10 @@ owner column, but three file-backed / in-memory stores are left stale: research_routes filters by `d.get("owner") == user`, making every report invisible after rename. -3. data/memory.json — a flat array where every entry has an `owner` field; +3. research_handler._active_tasks — in-flight research jobs carry the same + owner key while status/cancel/active routes filter by it. + +4. data/memory.json — a flat array where every entry has an `owner` field; memory_manager.load(owner=user) filters on it, so all memories vanish. Regression coverage: these bugs are invisible in unit tests that mock the DB @@ -26,6 +29,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from fastapi import HTTPException def _route(router, name): @@ -63,18 +67,69 @@ def rename_endpoint(monkeypatch, tmp_path): return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path -def _request(tmp_path, session_manager=None): +def _request(tmp_path, session_manager=None, token="t", research_handler=None): state = SimpleNamespace( invalidate_token_cache=lambda: None, session_manager=session_manager, + research_handler=research_handler, ) return SimpleNamespace( - cookies={"odysseus_session": "t"}, + cookies={"odysseus_session": token}, app=SimpleNamespace(state=state), state=SimpleNamespace(current_user="admin"), ) +def _auth_manager_for_rollback_test(monkeypatch, tmp_path): + import core.auth as auth_mod + + monkeypatch.setattr(auth_mod, "_hash_password", lambda password: f"hash:{password}") + monkeypatch.setattr(auth_mod, "_verify_password", lambda password, hashed: hashed == f"hash:{password}") + + am = auth_mod.AuthManager(str(tmp_path / "auth.json")) + assert am.create_user("admin", "pw-123456", is_admin=True) is True + assert am.create_user("alice", "pw-123456") is True + return am + + +def _force_sql_owner_migration_failure(monkeypatch): + import core.database as cdb + + class OwnerModel: + owner = "owner" + + class FailingQuery: + def filter(self, *_args, **_kwargs): + return self + + def update(self, *_args, **_kwargs): + raise RuntimeError("forced owner migration failure") + + class FailingSession: + def __init__(self): + self.rolled_back = False + self.closed = False + + def query(self, _model): + return FailingQuery() + + def rollback(self): + self.rolled_back = True + + def close(self): + self.closed = True + + db = FailingSession() + monkeypatch.setattr(cdb, "SessionLocal", lambda: db) + monkeypatch.setattr( + cdb, + "Base", + SimpleNamespace(registry=SimpleNamespace(mappers=[SimpleNamespace(class_=OwnerModel)])), + raising=False, + ) + return db + + # --------------------------------------------------------------------------- # 1. In-memory session cache # --------------------------------------------------------------------------- @@ -183,6 +238,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint): assert res["ok"] is True +def test_rename_updates_active_research_task_owner(rename_endpoint): + endpoint, _am, tmp_path = rename_endpoint + + from routes.research_routes import setup_research_routes + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "alice-task": { + "owner": "Alice", + "status": "running", + "query": "q", + "progress": {}, + "started_at": 1, + }, + "carol-task": { + "owner": "carol", + "status": "running", + "query": "q2", + "progress": {}, + "started_at": 2, + }, + } + + asyncio.run(endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, research_handler=rh), + )) + + assert rh._active_tasks["alice-task"]["owner"] == "alice2" + assert rh._active_tasks["carol-task"]["owner"] == "carol" + + router = setup_research_routes(rh) + active = next( + r.endpoint for r in router.routes + if getattr(r, "path", "") == "/api/research/active" + ) + + alice2 = asyncio.run(active( + SimpleNamespace(state=SimpleNamespace(current_user="alice2")), + )) + alice = asyncio.run(active( + SimpleNamespace(state=SimpleNamespace(current_user="alice")), + )) + + assert [item["session_id"] for item in alice2["active"]] == ["alice-task"] + assert alice["active"] == [] + + +def test_research_handler_rename_owner_canonicalizes_new_owner(): + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "task": {"owner": "Alice", "status": "running"}, + } + + changed = rh.rename_owner("alice", "Alice2") + assert changed == 1 + assert rh._active_tasks["task"]["owner"] == "alice2" + + +def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold(): + from src.research_handler import ResearchHandler + + rh = ResearchHandler.__new__(ResearchHandler) + rh._active_tasks = { + "task-strasse": {"owner": "strasse", "status": "running"}, + "task-sharp-s": {"owner": "straße", "status": "running"}, + } + + changed = rh.rename_owner("straße", "renamed") + + assert changed == 1 + assert rh._active_tasks["task-strasse"]["owner"] == "strasse" + assert rh._active_tasks["task-sharp-s"]["owner"] == "renamed" + + +def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint): + endpoint, _am, tmp_path = rename_endpoint + + dr_dir = tmp_path / "deep_research" + dr_dir.mkdir() + report = dr_dir / "race-window.json" + report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8") + owner_seen_by_active_hook = [] + + class FakeResearchHandler: + def rename_owner(self, _old, _new): + owner_seen_by_active_hook.append(json.loads(report.read_text(encoding="utf-8"))["owner"]) + + asyncio.run(endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, research_handler=FakeResearchHandler()), + )) + + assert owner_seen_by_active_hook == ["alice"] + assert json.loads(report.read_text(encoding="utf-8"))["owner"] == "alice2" + + def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path): """DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made @@ -333,8 +490,100 @@ def test_rename_no_skills_dir_does_not_crash(rename_endpoint): assert res["ok"] is True +def test_rename_skill_md_owner_case_insensitive(rename_endpoint): + """SKILL.md written with owner: Alice (mixed case) must be updated when + renaming alice — the regex was missing re.IGNORECASE.""" + endpoint, _am, tmp_path = rename_endpoint + + skill_dir = tmp_path / "skills" / "general" / "s" + skill_dir.mkdir(parents=True) + (skill_dir / "SKILL.md").write_text(_SKILL_MD.format(owner="Alice"), encoding="utf-8") + + asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path))) + + assert "owner: alice2" in (skill_dir / "SKILL.md").read_text(encoding="utf-8") + + +def test_rename_usage_keys_case_insensitive(rename_endpoint): + """_usage.json keys stored as Alice::skill-name must be migrated when + renaming alice — the old startswith check was not lowercasing.""" + endpoint, _am, tmp_path = rename_endpoint + + skills_root = tmp_path / "skills" + skills_root.mkdir(parents=True) + usage = {"Alice::my-skill": {"uses": 5, "last_used": 999}} + (skills_root / "_usage.json").write_text(json.dumps(usage), encoding="utf-8") + + asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path))) + + updated = json.loads((skills_root / "_usage.json").read_text(encoding="utf-8")) + assert "alice2::my-skill" in updated + assert "Alice::my-skill" not in updated + + # --------------------------------------------------------------------------- -# 5. P1 regression: rejected auth rename must not mutate file-backed stores +# 5. Rollback: auth rename must be restored if SQL owner migration fails +# --------------------------------------------------------------------------- + +def test_owner_migration_failure_rolls_back_auth_rename(monkeypatch, tmp_path): + import routes.auth_routes as ar + + db = _force_sql_owner_migration_failure(monkeypatch) + am = _auth_manager_for_rollback_test(monkeypatch, tmp_path) + admin_token = am.create_session_trusted("admin") + alice_token = am.create_session_trusted("alice") + endpoint = _route(ar.setup_auth_routes(am), "rename_user") + + with pytest.raises(HTTPException) as exc: + asyncio.run( + endpoint( + "alice", + SimpleNamespace(username="alice2"), + _request(tmp_path, token=admin_token), + ) + ) + + assert exc.value.status_code == 500 + assert db.rolled_back is True + assert db.closed is True + assert "alice" in am.users + assert "alice2" not in am.users + assert am.get_username_for_token(alice_token) == "alice" + saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"] + assert "alice" in saved_users + assert "alice2" not in saved_users + + +def test_self_rename_owner_migration_failure_rolls_back_auth_session(monkeypatch, tmp_path): + import routes.auth_routes as ar + + db = _force_sql_owner_migration_failure(monkeypatch) + am = _auth_manager_for_rollback_test(monkeypatch, tmp_path) + admin_token = am.create_session_trusted("admin") + endpoint = _route(ar.setup_auth_routes(am), "rename_user") + + with pytest.raises(HTTPException) as exc: + asyncio.run( + endpoint( + "admin", + SimpleNamespace(username="chief"), + _request(tmp_path, token=admin_token), + ) + ) + + assert exc.value.status_code == 500 + assert db.rolled_back is True + assert db.closed is True + assert "admin" in am.users + assert "chief" not in am.users + assert am.get_username_for_token(admin_token) == "admin" + saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"] + assert "admin" in saved_users + assert "chief" not in saved_users + + +# --------------------------------------------------------------------------- +# 6. P1 regression: rejected auth rename must not mutate file-backed stores # --------------------------------------------------------------------------- def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path): diff --git a/tests/test_research_status_avg_duration.py b/tests/test_research_status_avg_duration.py new file mode 100644 index 000000000..d44c63242 --- /dev/null +++ b/tests/test_research_status_avg_duration.py @@ -0,0 +1,41 @@ +"""get_status must not rescan the whole research dir on every SSE poll. + +get_avg_duration() globs and JSON-parses every file under the research data dir. +get_status() called it unconditionally on each poll, including for sessions that +are not active (the common case while a client polls a finished report). It is +now computed only for active sessions and memoized on the entry. +""" +from src.research_handler import ResearchHandler + + +def _handler(): + h = ResearchHandler.__new__(ResearchHandler) + h._active_tasks = {} + return h + + +def test_inactive_session_does_not_compute_avg(monkeypatch): + h = _handler() + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 5.0)[1]) + # Unknown session, no disk file -> None, and no expensive avg scan. + assert h.get_status("missing-session") is None + assert calls == [] + + +def test_active_session_memoizes_avg(monkeypatch): + h = _handler() + h._active_tasks["s1"] = { + "status": "running", "progress": {}, "query": "q", "started_at": 0, + } + calls = [] + monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 12.0)[1]) + + r1 = h.get_status("s1") + r2 = h.get_status("s1") + r3 = h.get_status("s1") + + assert r1["avg_duration"] == 12.0 + assert r2["avg_duration"] == 12.0 and r3["avg_duration"] == 12.0 + # Computed once across many polls, not once per poll. + assert len(calls) == 1 diff --git a/tests/test_reserved_username_admin_escalation.py b/tests/test_reserved_username_admin_escalation.py index 29c423774..fff1aea78 100644 --- a/tests/test_reserved_username_admin_escalation.py +++ b/tests/test_reserved_username_admin_escalation.py @@ -58,6 +58,62 @@ def test_rename_into_reserved_username_is_blocked(tmp_path): assert "bob" in mgr.users +def test_legacy_reserved_username_is_removed_on_load(tmp_path): + auth_path = tmp_path / "auth.json" + auth_path.write_text( + '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}, ' + '"admin": {"password_hash": "unused", "is_admin": true}}}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert "internal-tool" not in mgr.users + assert "admin" in mgr.users + assert "internal-tool" not in auth_path.read_text(encoding="utf-8") + + +def test_legacy_reserved_username_session_cannot_authenticate(tmp_path): + auth_path = tmp_path / "auth.json" + sessions_path = tmp_path / "sessions.json" + auth_path.write_text( + '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}}}', + encoding="utf-8", + ) + sessions_path.write_text( + '{"tok": {"username": "internal-tool", "expiry": 9999999999}}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert mgr.validate_token("tok") is False + assert mgr.get_username_for_token("tok") is None + + +def test_legacy_reserved_single_user_migrates_to_admin(tmp_path): + auth_path = tmp_path / "auth.json" + auth_path.write_text( + '{"username": "internal-tool", "password_hash": "unused"}', + encoding="utf-8", + ) + mgr = _fresh_auth_manager(tmp_path) + + assert "internal-tool" not in mgr.users + assert "admin" in mgr.users + assert mgr.is_admin("admin") is True + + +def test_token_cache_owner_normalization_requires_current_user(): + clear_module("core.auth") + from core.auth import normalize_known_username + + users = {"alice": {}, "admin": {}} + + assert normalize_known_username(users, " Alice ") == "alice" + assert normalize_known_username(users, "internal-tool") is None + assert normalize_known_username(users, "api") is None + assert normalize_known_username(users, "") is None + + def test_normal_usernames_still_allowed(tmp_path): mgr = _fresh_auth_manager(tmp_path) assert mgr.create_user("alice", "pw-123456") is True diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py index b3988f88e..fe782f151 100644 --- a/tests/test_review_regressions.py +++ b/tests/test_review_regressions.py @@ -647,6 +647,60 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch): assert "manage_tasks" in blocked +def test_presetup_does_not_grant_admin_tools_when_auth_enabled(monkeypatch): + """Pre-setup window: auth is enabled but no admin user exists yet. + + This must NOT be treated as single-user/admin at the tool layer — the + server-execution tools (bash/python) stay blocked as defense-in-depth so + an unauthenticated caller that slips past the auth middleware (e.g. via a + loopback bypass) can't reach an RCE before setup completes. + """ + monkeypatch.delenv("AUTH_ENABLED", raising=False) # default: enabled + auth_mod = _install_core_auth_stub(monkeypatch) + + class FakeAuth: + is_configured = False + + def is_admin(self, username): + return False + + monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth()) + + from src.tool_security import ( + blocked_tools_for_owner, + owner_is_admin_or_single_user, + ) + + assert owner_is_admin_or_single_user(None) is False + blocked = blocked_tools_for_owner(None) + assert "bash" in blocked + assert "python" in blocked + + +def test_single_user_mode_keeps_full_tool_access_when_auth_disabled(monkeypatch): + """Intentional single-user mode (AUTH_ENABLED=false) keeps full tool + access even with no admin user — this is the default local/self-host UX + and must not regress.""" + monkeypatch.setenv("AUTH_ENABLED", "false") + auth_mod = _install_core_auth_stub(monkeypatch) + + class FakeAuth: + is_configured = False + + def is_admin(self, username): + return False + + monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth()) + + from src.tool_security import ( + blocked_tools_for_owner, + owner_is_admin_or_single_user, + ) + + assert owner_is_admin_or_single_user(None) is True + assert blocked_tools_for_owner(None) == set() + + @pytest.mark.asyncio async def test_webhook_tool_reuses_private_url_validation(): class FakeDb: diff --git a/tests/test_route_validators.py b/tests/test_route_validators.py new file mode 100644 index 000000000..a6fc07a98 --- /dev/null +++ b/tests/test_route_validators.py @@ -0,0 +1,23 @@ +import pytest +from fastapi import HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port + + +def test_validate_ssh_port_rejects_shell_payload(): + for port in ["22;id", "$(id)", "-p 22", "0", "65536"]: + with pytest.raises(HTTPException): + validate_ssh_port(port) + assert validate_ssh_port("2222") == "2222" + + +def test_validate_remote_host_rejects_ssh_option_shape(): + for host in [ + "-oProxyCommand=sh", + "alice@-oProxyCommand=sh", + "--", + "-p2222", + ]: + with pytest.raises(HTTPException): + validate_remote_host(host) + assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1" diff --git a/tests/test_service_search_provider_guards.py b/tests/test_service_search_provider_guards.py index 373928e64..cb9171a54 100644 --- a/tests/test_service_search_provider_guards.py +++ b/tests/test_service_search_provider_guards.py @@ -90,8 +90,8 @@ def test_service_ddg_html_fallback_sends_safesearch(monkeypatch): seen["params"] = kwargs["params"] return _Response() - monkeypatch.setitem(sys.modules, "duckduckgo_search", None) monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"}) + monkeypatch.setitem(sys.modules, "ddgs", None) monkeypatch.setattr(providers.httpx, "get", fake_get) results = providers.duckduckgo_search("odysseus", count=1) diff --git a/tests/test_warmup_ping_urls.py b/tests/test_warmup_ping_urls.py new file mode 100644 index 000000000..7b5961831 --- /dev/null +++ b/tests/test_warmup_ping_urls.py @@ -0,0 +1,47 @@ +"""Startup warmup must resolve real endpoint URLs. + +The warmup/keepalive loop called `model_discovery.get_endpoints()`, which does +not exist on ModelDiscovery, so it raised AttributeError every run and pinged +nothing. `ModelDiscovery.warmup_ping_urls()` resolves the /models probe URLs +from the real discovery API. +""" +from src.model_discovery import ModelDiscovery + + +def _md(): + return ModelDiscovery.__new__(ModelDiscovery) + + +def test_old_method_never_existed(): + # Documents why the old warmup was a silent no-op. + assert not hasattr(ModelDiscovery, "get_endpoints") + + +def test_resolves_models_urls_from_discovered_items(): + md = _md() + md.discover_models = lambda: {"items": [ + {"url": "http://host:8000/v1/chat/completions", "models": ["a"]}, + {"url": "http://host:1234/v1/chat/completions", "models": ["b"]}, + ]} + assert md.warmup_ping_urls() == [ + "http://host:8000/v1/models", + "http://host:1234/v1/models", + ] + + +def test_limit_caps_results(): + md = _md() + md.discover_models = lambda: {"items": [ + {"url": f"http://h:{8000 + i}/v1/chat/completions"} for i in range(10) + ]} + assert len(md.warmup_ping_urls(limit=3)) == 3 + + +def test_discovery_failure_degrades_to_empty(): + md = _md() + + def boom(): + raise RuntimeError("port scan failed") + + md.discover_models = boom + assert md.warmup_ping_urls() == [] diff --git a/tests/test_web_search_tool_icon_js.py b/tests/test_web_search_tool_icon_js.py new file mode 100644 index 000000000..6e855df40 --- /dev/null +++ b/tests/test_web_search_tool_icon_js.py @@ -0,0 +1,119 @@ +"""Pin the web_search tool-icon rendering in the agent thread (PR #??). + +Verifies: +- web_search renders an icon instead of raw markup +- Other tools get the default ▶ icon +- Hostile tool names are HTML-escaped in the label + +Pure JS via node --input-type=module (same approach as +test_composer_arrow_up_recall_js.py). Skips when node is not installed. +""" + +import json +import shutil +import subprocess +from pathlib import Path + +import pytest + +_REPO = Path(__file__).resolve().parent.parent +_HAS_NODE = shutil.which("node") is not None + +_CHECK_JS = r""" +function esc(s) { + const map = { '&': '&', '<': '<', '>': '>', '"': '"', "'": ''' }; + return (s || '').replace(/[&<>"']/g, (m) => map[m]); +} + +const _searchIcon = ''; + +const _toolLabels = { + web_search: 'Searching', + bash: 'Running', +}; + +const _toolIcons = { + web_search: _searchIcon, +}; + +function renderIcon(toolName) { + return _toolIcons[toolName.toLowerCase()] || '\u25B6'; +} + +function renderLabel(toolName) { + return _toolLabels[toolName.toLowerCase()] || toolName; +} + +function renderThreadHTML(toolName, cmd) { + const label = renderLabel(toolName); + const icon = renderIcon(toolName); + const cmdHtml = cmd ? `
${esc(cmd)}
` : ''; + return `
${icon}${esc(label)}\u2581\u2582\u2583
${cmdHtml}
`; +} + +const cases = CASES_JSON; +const results = cases.map(c => { + const html = renderThreadHTML(c.tool, c.cmd || ''); + return { tool: c.tool, html }; +}); +console.log(JSON.stringify(results)); +""" + + +def _run(cases: list) -> list: + js = _CHECK_JS.replace("CASES_JSON", json.dumps(cases)) + proc = subprocess.run( + ["node", "--input-type=module"], + input=js, + capture_output=True, + text=True, + encoding="utf-8", + cwd=str(_REPO), + timeout=30, + ) + assert proc.returncode == 0, proc.stderr + return json.loads(proc.stdout.strip()) + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_web_search_icon_contains_svg(): + out = _run([{"tool": "web_search"}])[0] + assert " in agent-thread-icon for web_search" + assert "Searching" in out["html"], "Expected 'Searching' label for web_search" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_default_tool_icon_is_triangle(): + out = _run([{"tool": "bash"}])[0] + assert "▶" in out["html"], "Expected ▶ icon for tools without custom icon" + assert " for bash" + assert "Running" in out["html"], "Expected 'Running' label for bash" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_unknown_tool_falls_back_to_name(): + out = _run([{"tool": "my_custom_tool"}])[0] + assert "▶" in out["html"], "Expected ▶ for unknown tool" + assert "my_custom_tool" in out["html"], "Expected tool name as label" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_hostile_tool_name_is_escaped(): + out = _run([{"tool": ''}])[0] + assert "<img" in out["html"], "Expected < to be HTML-escaped" + assert ">" in out["html"], "Expected > to be HTML-escaped" + assert " must not appear" + assert "onerror" not in out["html"] or """ in out["html"], "onerror must not be executable" + + +@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH") +def test_unknown_tool_case_insensitive_matches_icons(): + out = _run([{"tool": "WEB_SEARCH"}, {"tool": "Web_Search"}]) + for r in out: + assert "