Merge branch 'dev' into fix/no-scroll-snapping

This commit is contained in:
broken💎shaders
2026-06-11 11:43:53 +08:00
committed by GitHub
40 changed files with 1666 additions and 161 deletions
+1 -1
View File
@@ -329,7 +329,7 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
| Package | Feature unlocked |
|---------|-----------------|
| `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. |
| `duckduckgo-search` | DuckDuckGo as a search provider option. |
| `ddgs` | DuckDuckGo as a search provider option. |
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
+20 -6
View File
@@ -56,7 +56,7 @@ from core.constants import (
)
from core.database import SessionLocal, ApiToken
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
from core.auth import AuthManager
from core.auth import AuthManager, normalize_known_username
from core.exceptions import (
SessionNotFoundError, InvalidFileUploadError,
LLMServiceError, WebSearchError,
@@ -228,8 +228,16 @@ if AUTH_ENABLED:
try:
rows = db.query(ApiToken).filter(ApiToken.is_active == True).all()
for r in rows:
owner_key = normalize_known_username(auth_manager.users, getattr(r, "owner", None))
if not owner_key:
logger.warning(
"Ignoring active API token '%s' for unknown auth user '%s'",
getattr(r, "id", ""),
getattr(r, "owner", None),
)
continue
scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()]
new_map[r.token_prefix].append((r.id, r.token_hash, getattr(r, "owner", None), scopes))
new_map[r.token_prefix].append((r.id, r.token_hash, owner_key, scopes))
finally:
db.close()
_token_cache.clear()
@@ -495,6 +503,7 @@ api_key_manager = components["api_key_manager"]
preset_manager = components["preset_manager"]
chat_processor = components["chat_processor"]
research_handler = components["research_handler"]
app.state.research_handler = research_handler
chat_handler = components["chat_handler"]
model_discovery = components["model_discovery"]
skills_manager = components["skills_manager"]
@@ -938,10 +947,15 @@ async def _startup_event():
async def _warmup_endpoints():
try:
import httpx
endpoints = model_discovery.get_endpoints() if model_discovery else []
for ep in endpoints[:5]:
url = ep.get("url", "").replace("/chat/completions", "/models")
if url:
# model_discovery has no get_endpoints(); that call raised
# AttributeError every run and silently disabled warmup/keepalive.
# Resolve the /models probe URLs via the real discovery API, off the
# event loop since discovery does a blocking port scan.
urls = (
await asyncio.to_thread(model_discovery.warmup_ping_urls)
if model_discovery else []
)
for url in urls:
try:
async with httpx.AsyncClient(timeout=5.0) as client:
await client.get(url)
+56 -13
View File
@@ -67,6 +67,14 @@ TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"})
def normalize_known_username(users: Dict[str, Any], username: str | None) -> Optional[str]:
"""Return a normalized username only when it exists in the auth user map."""
key = str(username or "").strip().lower()
if not key or key not in users:
return None
return key
def _hash_password(password: str) -> str:
    """Hash *password* with bcrypt and return the result as UTF-8 text."""
    raw = password.encode("utf-8")
    # A new salt is generated for every call via bcrypt.gensalt().
    hashed = bcrypt.hashpw(raw, bcrypt.gensalt())
    return hashed.decode("utf-8")
@@ -96,6 +104,7 @@ class AuthManager:
self._load()
self._load_sessions()
self._migrate_single_user()
self._drop_reserved_loaded_users()
self._migrate_legacy_admin_role()
def _load(self):
@@ -148,7 +157,13 @@ class AuthManager:
def _migrate_single_user(self):
"""Migrate old single-user format to multi-user format."""
if "password_hash" in self._config and "users" not in self._config:
old_user = self._config.get("username", "admin")
old_user = str(self._config.get("username", "admin") or "admin").strip().lower()
if old_user in RESERVED_USERNAMES:
logger.warning(
"Migrating legacy single-user reserved username '%s' to 'admin'",
old_user,
)
old_user = "admin"
old_hash = self._config["password_hash"]
self._config = {
"users": {
@@ -162,6 +177,30 @@ class AuthManager:
self._save()
logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})")
def _drop_reserved_loaded_users(self):
    """Fail closed for legacy/manual auth rows that collide with sentinels.

    Rebuilds the ``users`` map of the loaded config with stripped,
    lower-cased keys, skipping blank names and dropping any name found in
    ``RESERVED_USERNAMES``. The config is saved only when the rebuilt map
    differs from the loaded one or a reserved name was dropped.
    """
    loaded = self._config.get("users")
    if not isinstance(loaded, dict):
        # Nothing to clean when there is no multi-user map.
        return

    kept = {}
    dropped = set()
    for raw_name, record in loaded.items():
        name = str(raw_name or "").strip().lower()
        if not name:
            continue
        if name in RESERVED_USERNAMES:
            dropped.add(name)
        else:
            kept[name] = record

    # Persist when a reserved row was removed or any key was re-normalized.
    if dropped or kept != loaded:
        self._config["users"] = kept
        self._save()

    if dropped:
        logger.warning(
            "Removed reserved username(s) from auth config: %s",
            ", ".join(sorted(dropped)),
        )
def _migrate_legacy_admin_role(self):
"""Normalize setup.py's old role='admin' marker to is_admin=True."""
changed = False
@@ -244,6 +283,22 @@ class AuthManager:
return False
if not self.users.get(requesting_user, {}).get("is_admin"):
return False
# Revoke API bearer tokens before removing the auth row. The bearer
# path authenticates from ApiToken rows and does not require the
# owner to still exist, so a successful delete must not leave active
# rows behind. If the token store is unavailable, fail closed and
# keep the user/session state intact so the admin can retry.
try:
from core.database import get_db_session, ApiToken
with get_db_session() as db:
removed_tokens = db.query(ApiToken).filter(ApiToken.owner == username).delete()
if removed_tokens:
logger.info(
f"Revoked {removed_tokens} API token(s) owned by deleted user '{username}'"
)
except Exception:
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
return False
del self._config["users"][username]
self._save()
# Purge all sessions belonging to this user. validate_token doesn't
@@ -258,18 +313,6 @@ class AuthManager:
revoked += 1
if revoked:
self._save_sessions()
# Also revoke API bearer tokens owned by this user. The bearer auth
# path authenticates straight against ApiToken rows and never
# re-checks that the owner still exists, so leaving the rows behind
# would let a deleted user keep full API access indefinitely.
try:
from core.database import get_db_session, ApiToken
with get_db_session() as db:
removed = db.query(ApiToken).filter(ApiToken.owner == username).delete()
if removed:
logger.info(f"Revoked {removed} API token(s) owned by deleted user '{username}'")
except Exception:
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)")
return True
+150 -25
View File
@@ -688,6 +688,7 @@ def _migrate_add_last_message_at_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(sessions)")
@@ -713,10 +714,14 @@ def _migrate_add_last_message_at_column():
"ON sessions(archived, last_message_at)"
)
conn.commit()
conn.close()
logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions")
except Exception as e:
logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_document_archived_column():
"""Add `archived` to documents (soft-archive flag). Guarded + idempotent."""
@@ -724,6 +729,7 @@ def _migrate_add_document_archived_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(documents)")
@@ -732,9 +738,13 @@ def _migrate_add_document_archived_column():
conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'archived' to documents")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_owner_column():
@@ -743,6 +753,7 @@ def _migrate_add_owner_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(sessions)")
@@ -752,9 +763,13 @@ def _migrate_add_owner_column():
conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"Migration check failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_model_endpoints():
"""Recreate model_endpoints table if schema changed (url->base_url)."""
@@ -762,6 +777,7 @@ def _migrate_model_endpoints():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -770,9 +786,13 @@ def _migrate_model_endpoints():
conn.execute("DROP TABLE IF EXISTS model_endpoints")
conn.commit()
logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_hidden_models_column():
"""Add hidden_models column to model_endpoints if it doesn't exist."""
@@ -780,6 +800,7 @@ def _migrate_add_hidden_models_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -788,9 +809,13 @@ def _migrate_add_hidden_models_column():
conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_model_endpoint_owner_column():
"""Add owner column to model_endpoints if it doesn't exist.
@@ -805,6 +830,7 @@ def _migrate_add_model_endpoint_owner_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -814,9 +840,13 @@ def _migrate_add_model_endpoint_owner_column():
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_provider_auth_id_column():
@@ -825,6 +855,7 @@ def _migrate_add_provider_auth_id_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -834,9 +865,13 @@ def _migrate_add_provider_auth_id_column():
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_model_type_column():
@@ -845,6 +880,7 @@ def _migrate_add_model_type_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -853,9 +889,13 @@ def _migrate_add_model_type_column():
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_model_endpoint_refresh_columns():
"""Add endpoint classification / refresh policy columns if missing."""
@@ -863,6 +903,7 @@ def _migrate_add_model_endpoint_refresh_columns():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -876,9 +917,13 @@ def _migrate_add_model_endpoint_refresh_columns():
if columns and "model_refresh_timeout" not in columns:
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER")
conn.commit()
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_task_run_model_column():
"""Add model column to task_runs if it doesn't exist (records which model ran)."""
@@ -886,6 +931,7 @@ def _migrate_add_task_run_model_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(task_runs)")
@@ -894,9 +940,13 @@ def _migrate_add_task_run_model_column():
conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_supports_tools_column():
"""Add supports_tools column to model_endpoints if it doesn't exist."""
@@ -904,6 +954,7 @@ def _migrate_add_supports_tools_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -912,9 +963,13 @@ def _migrate_add_supports_tools_column():
conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_cached_models_column():
@@ -923,6 +978,7 @@ def _migrate_add_cached_models_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -930,9 +986,13 @@ def _migrate_add_cached_models_column():
if columns and "cached_models" not in columns:
conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT")
conn.commit()
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_pinned_models_column():
"""Add pinned_models column to model_endpoints if it doesn't exist."""
@@ -940,6 +1000,7 @@ def _migrate_add_pinned_models_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
@@ -948,9 +1009,13 @@ def _migrate_add_pinned_models_column():
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_notes_sort_order():
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
@@ -958,6 +1023,7 @@ def _migrate_add_notes_sort_order():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(notes)")
@@ -975,9 +1041,13 @@ def _migrate_add_notes_sort_order():
if columns and "agent_session_id" not in columns:
conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT")
conn.commit()
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"notes migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_mode_column():
"""Add mode column to sessions table if it doesn't exist."""
@@ -985,6 +1055,7 @@ def _migrate_add_mode_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(sessions)")
@@ -993,9 +1064,13 @@ def _migrate_add_mode_column():
conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_folder_column():
"""Add folder column to sessions table if it doesn't exist."""
@@ -1003,6 +1078,7 @@ def _migrate_add_folder_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(sessions)")
@@ -1011,9 +1087,13 @@ def _migrate_add_folder_column():
conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_token_columns():
"""Add cumulative token tracking columns to sessions table."""
@@ -1021,6 +1101,7 @@ def _migrate_add_token_columns():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(sessions)")
@@ -1030,9 +1111,13 @@ def _migrate_add_token_columns():
conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0")
conn.commit()
logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_owner_to_table(table_name: str, index_name: str):
"""Generic helper: add owner TEXT column + index to a table if missing."""
@@ -1040,6 +1125,7 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute(f"PRAGMA table_info({table_name})")
@@ -1049,9 +1135,13 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)")
conn.commit()
logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_multiuser_owner_columns():
"""Add owner column to memories, gallery_images, user_tools, comparisons."""
@@ -1076,6 +1166,7 @@ def _migrate_add_api_token_scopes_column():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()]
@@ -1084,9 +1175,13 @@ def _migrate_add_api_token_scopes_column():
conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''")
conn.commit()
logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_assign_legacy_owner():
"""Assign all null-owner data to the first (admin) user.
@@ -1128,6 +1223,7 @@ def _migrate_assign_legacy_owner():
return
logger = logging.getLogger(__name__)
conn = None
try:
conn = sqlite3.connect(db_path)
# Every table with an `owner` column. New tables added later will be
@@ -1152,9 +1248,13 @@ def _migrate_assign_legacy_owner():
except Exception as e:
logger.warning(f"Legacy owner assignment for {table} failed: {e}")
conn.commit()
conn.close()
except Exception as e:
logger.warning(f"Legacy owner migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
# Also migrate memory.json
mem_path = MEMORY_FILE
@@ -1773,6 +1873,7 @@ def _migrate_add_email_smtp_security():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(email_accounts)")
@@ -1788,9 +1889,13 @@ def _migrate_add_email_smtp_security():
)
conn.commit()
logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_encrypt_endpoint_keys():
@@ -1891,6 +1996,7 @@ def _migrate_add_calendar_is_utc():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(calendar_events)")
@@ -1899,9 +2005,13 @@ def _migrate_add_calendar_is_utc():
conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"is_utc migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_calendar_origin():
@@ -1912,6 +2022,7 @@ def _migrate_add_calendar_origin():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(calendar_events)")
@@ -1921,9 +2032,13 @@ def _migrate_add_calendar_origin():
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_calendar_account_id():
@@ -1933,6 +2048,7 @@ def _migrate_add_calendar_account_id():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(calendars)")
@@ -1942,9 +2058,13 @@ def _migrate_add_calendar_account_id():
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
conn.commit()
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def _migrate_add_calendar_metadata():
@@ -1953,6 +2073,7 @@ def _migrate_add_calendar_metadata():
db_path = DATABASE_URL.replace("sqlite:///", "")
if not os.path.exists(db_path):
return
conn = None
try:
conn = sqlite3.connect(db_path)
cursor = conn.execute("PRAGMA table_info(calendar_events)")
@@ -1964,9 +2085,13 @@ def _migrate_add_calendar_metadata():
if columns and "last_pinged" not in columns:
conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME")
conn.commit()
conn.close()
except Exception as e:
logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}")
finally:
try:
conn.close()
except Exception:
pass
def get_db():
"""
+4
View File
@@ -366,6 +366,10 @@ def _ssh_exec_argv(
strict_host_key_checking: bool | None = None,
) -> list[str]:
"""Build a consistent ssh argv for remote command execution."""
remote_value = str(remote or "").strip()
remote_host = remote_value.rsplit("@", 1)[-1]
if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"):
raise ValueError("Invalid SSH remote host")
argv = ["ssh"]
if connect_timeout is not None:
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
+1 -1
View File
@@ -15,7 +15,7 @@ faster-whisper
# DuckDuckGo as a search provider option.
# Install if you want DDG in the search-provider dropdown.
# Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE.
duckduckgo-search
ddgs
# PDF form-filling feature (fillable AcroForm detection, field extraction,
# value/annotation/signature stamping, page rendering for the form overlay).
+31
View File
@@ -0,0 +1,31 @@
import re
from fastapi import HTTPException
_REMOTE_HOST_RE = re.compile(
r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$"
)
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
def validate_remote_host(v: str | None) -> str | None:
    """Validate an SSH remote target given as ``host`` or ``user@host``.

    Args:
        v: The raw value; ``None`` and ``""`` mean "not provided".

    Returns:
        ``None`` when no value was provided, otherwise ``v`` unchanged.

    Raises:
        HTTPException: 400 when ``v`` is not a string or does not match
            ``_REMOTE_HOST_RE`` in full.
    """
    if v is None or v == "":
        return None
    # fullmatch(), not match(): with match() the pattern's trailing ``$`` also
    # matches just before a final newline, so "host\n" would be accepted.
    # Non-string input is rejected with a 400 instead of a TypeError.
    if not isinstance(v, str) or not _REMOTE_HOST_RE.fullmatch(v):
        raise HTTPException(
            400,
            "Invalid remote_host — must be host or user@host, no SSH option syntax",
        )
    return v
def validate_ssh_port(v: str | None) -> str | None:
    """Validate an SSH port and return it as a canonical decimal string.

    Returns ``None`` when ``v`` is ``None`` or ``""``. Raises
    ``HTTPException`` (400) when the value is not 1-5 digits or falls
    outside the range 1..65535.
    """
    if v is None or v == "":
        return None

    if _SSH_PORT_RE.fullmatch(str(v)) is None:
        raise HTTPException(400, "Invalid ssh_port")

    number = int(v)
    if not (1 <= number <= 65535):
        raise HTTPException(400, "Invalid ssh_port")

    # Re-stringify the parsed integer so leading zeros are dropped.
    return str(number)
+54 -10
View File
@@ -305,6 +305,19 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
if not ok:
raise HTTPException(400, "Cannot rename user")
def _rollback_auth_rename() -> bool:
    """Try to undo the auth rename; return True only if the undo succeeded.

    Any exception from the auth manager is logged and reported as False.
    """
    # When the admin renamed their own account, the session already belongs
    # to the new username, so the undo has to be requested as that user.
    if user == old_username:
        acting_user = new_username
    else:
        acting_user = user
    try:
        outcome = auth_manager.rename_user(new_username, old_username, acting_user)
        return bool(outcome)
    except Exception as rollback_err:
        logger.error(
            "Failed to roll back auth rename %s -> %s after owner migration failure: %s",
            new_username, old_username, rollback_err,
        )
        return False
# Usernames are ownership keys for user data. Rename the common
# owner-scoped DB rows so the account keeps access to its sessions,
# docs, email accounts, tasks, etc.
@@ -330,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
db.close()
except Exception as e:
logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e)
if not _rollback_auth_rename():
logger.error(
"Auth rename %s -> %s could not be rolled back after owner migration failure",
old_username, new_username,
)
raise HTTPException(500, "Failed to rename user data")
# Per-user prefs are JSON-backed, not SQL-backed.
@@ -349,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
except Exception as e:
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
# In-flight deep-research tasks live in the process-local
# ResearchHandler registry. They are not covered by the persisted JSON
# migration above, but the research routes filter and cancel by this
# owner field while the job is running. Do this before sweeping
# completed JSON files so a job that finishes during the rename saves
# with the new owner or is caught by the disk sweep below.
try:
rh = getattr(request.app.state, "research_handler", None)
rename_owner = getattr(rh, "rename_owner", None)
if callable(rename_owner):
rename_owner(old_username, new_username)
except Exception as e:
logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e)
# deep_research: each completed report is a standalone JSON file with
# an `owner` field. research_routes filters by d.get("owner") == user,
# so a stale owner makes every report invisible to the renamed user.
@@ -391,7 +423,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
skills_root = Path(SKILLS_DIR)
if skills_root.is_dir():
_owner_re = re.compile(
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$'
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$',
re.IGNORECASE,
)
for p in skills_root.rglob("SKILL.md"):
try:
@@ -406,12 +439,12 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
try:
usage = json.loads(usage_path.read_text(encoding="utf-8"))
if isinstance(usage, dict):
prefix = old_username + "::"
new_usage = {}
changed = False
for k, v in usage.items():
if k.startswith(prefix):
new_usage[new_username + "::" + k[len(prefix):]] = v
owner_part, sep, skill_part = k.partition("::")
if sep and owner_part.lower() == old_username:
new_usage[new_username + "::" + skill_part] = v
changed = True
else:
new_usage[k] = v
@@ -473,7 +506,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
def _invalidate_api_token_cache():
    """Best-effort call of the app's ``invalidate_token_cache`` hook.

    Looks up ``request.app.state.invalidate_token_cache`` and calls it when
    present. Failures never propagate to the caller, but they are logged:
    a missed invalidation can leave cached bearer tokens usable, so it must
    not disappear silently.
    """
    try:
        invalidator = getattr(request.app.state, "invalidate_token_cache", None)
        if invalidator:
            invalidator()
    except Exception as e:
        # Stay best-effort (no re-raise) but leave a trace for operators.
        logger.warning("Failed to invalidate API token cache: %s", e)
try:
ok = auth_manager.delete_user(body.username, user)
except Exception:
# delete_user can touch ApiToken rows before a later auth-store write
# fails. Dirty the bearer cache anyway so a partial token purge does
# not leave already-cached tokens authenticating until restart.
_invalidate_api_token_cache()
raise
if not ok:
raise HTTPException(400, "Cannot delete user")
# delete_user removes the user's ApiToken rows, but the bearer-auth
@@ -481,12 +530,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
# rebuilds when flagged dirty. Without this, a deleted user's already
# cached token keeps authenticating until some other token op or a
# restart clears the cache. Mirror what the token routes do.
try:
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
if invalidator:
invalidator()
except Exception:
pass
_invalidate_api_token_cache()
return {"ok": True}
# ---- Feature visibility (admin-managed) ----
+5 -2
View File
@@ -729,8 +729,11 @@ def setup_contacts_routes():
@router.post("/import")
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
text = data.get("vcf") or data.get("text") or ""
csv_text = data.get("csv") or ""
# Coerce defensively: a non-string vcf/text/csv (e.g. a number or list
# in the JSON body) would otherwise reach .strip() and 500 with an
# AttributeError instead of degrading to a clean "no data" response.
text = str(data.get("vcf") or data.get("text") or "")
csv_text = str(data.get("csv") or "")
if text.strip():
if "BEGIN:VCARD" not in text.upper():
return {"success": False, "error": "No vCard data found"}
+1 -23
View File
@@ -11,6 +11,7 @@ import shlex
from fastapi import HTTPException
from pydantic import BaseModel
from routes._validators import validate_remote_host, validate_ssh_port
from core.platform_compat import _ssh_exec_argv
logger = logging.getLogger(__name__)
@@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
# Include pattern is a glob: allow typical safe glyphs only.
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
# Remote host: either `user@host` or plain `host` (alias is allowed), where host
# is a safe DNS-like token or a short SSH config alias.
_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$")
# HF tokens and API tokens are url-safe base64-like.
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
# Anything beyond plain alphanumerics + dash + underscore could break out
# of the shell/PowerShell contexts the value lands in.
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
# A download target directory. Absolute or ~-relative path; safe path glyphs
# only (no quotes or shell metacharacters). Spaces are allowed because command
@@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None:
return v
def _validate_remote_host(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _REMOTE_HOST_RE.match(v):
raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax")
return v
def _validate_token(v: str | None) -> str | None:
if v is None or v == "":
return None
@@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None:
return v
def _validate_ssh_port(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _SSH_PORT_RE.fullmatch(str(v)):
raise HTTPException(400, "Invalid ssh_port")
port = int(v)
if port < 1 or port > 65535:
raise HTTPException(400, "Invalid ssh_port")
return str(port)
def _validate_gpus(v: str | None) -> str | None:
if v is None or v == "":
return None
+38 -26
View File
@@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE
from pydantic import BaseModel
from core.middleware import require_admin
from routes._validators import validate_remote_host, validate_ssh_port
from core.platform_compat import (
IS_WINDOWS,
detached_popen_kwargs,
@@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR
logger = logging.getLogger(__name__)
from routes.cookbook_helpers import (
_SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE,
_validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token,
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
_SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token,
_validate_local_dir, _validate_gpus, _shell_path,
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
@@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter:
else:
_validate_repo_id(req.repo_id)
_validate_include(req.include)
_validate_remote_host(req.remote_host)
req.ssh_port = _validate_ssh_port(req.ssh_port)
validate_remote_host(req.remote_host)
req.ssh_port = validate_ssh_port(req.ssh_port)
req.local_dir = _validate_local_dir(req.local_dir)
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
_validate_token(req.hf_token)
@@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter:
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
# `host`/`ssh_port` are interpolated into an ssh command below, so an
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
host = _validate_remote_host(host)
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(host)
ssh_port = validate_ssh_port(ssh_port)
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
model_dirs = []
@@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter:
# listening" check without requiring ss/netstat/nmap.
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
if ssh_port and str(ssh_port) != "22":
if not _SSH_PORT_RE.match(str(ssh_port)):
try:
ssh_port = validate_ssh_port(ssh_port)
except HTTPException:
return None
ssh_base.extend(["-p", str(ssh_port)])
host_arg = remote
if not _REMOTE_HOST_RE.match(host_arg):
try:
host_arg = validate_remote_host(remote)
except HTTPException:
return None
if not host_arg:
return None
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
script = (
@@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter:
"""
require_admin(request)
# Defence-in-depth: reject values that could break out of shell contexts.
_validate_remote_host(req.remote_host)
req.ssh_port = _validate_ssh_port(req.ssh_port)
validate_remote_host(req.remote_host)
req.ssh_port = validate_ssh_port(req.ssh_port)
req.gpus = _validate_gpus(req.gpus)
req.hf_token = req.hf_token or _load_stored_hf_token()
_validate_token(req.hf_token)
@@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter:
async def server_setup(request: Request, req: SetupRequest):
"""Install required dependencies on a remote server via SSH."""
require_admin(request)
host = _validate_remote_host(req.host)
host = validate_remote_host(req.host)
if not host:
raise HTTPException(400, "host is required")
port = req.ssh_port
if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port):
raise HTTPException(400, "Invalid ssh_port")
port = validate_ssh_port(port)
pf = f"-p {port} " if port and port != "22" else ""
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
@@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter:
`busy` is True when free_mb/total_mb < 0.5.
"""
require_admin(request)
host = _validate_remote_host(host)
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(host)
ssh_port = validate_ssh_port(ssh_port)
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
nvidia_error = None
try:
@@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter:
sig = (req.signal or "TERM").upper()
if sig not in ("TERM", "KILL", "INT"):
raise HTTPException(400, "signal must be TERM, KILL, or INT")
host = _validate_remote_host(req.host)
if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(req.host)
req.ssh_port = validate_ssh_port(req.ssh_port)
kill_cmd = f"kill -{sig} {req.pid}"
try:
if host:
@@ -2382,13 +2383,18 @@ def setup_cookbook_routes() -> APIRouter:
host = (srv.get("host") or "").strip()
if not host:
continue # local-only entry; the /proc scan handles it
if not _REMOTE_HOST_RE.match(host):
try:
host = validate_remote_host(host)
except HTTPException:
continue
sport = str(srv.get("port") or "").strip()
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
if sport and sport != "22":
if not _SSH_PORT_RE.match(sport):
try:
sport = validate_ssh_port(sport)
except HTTPException:
continue
if sport != "22":
ssh_base.extend(["-p", sport])
try:
@@ -2743,10 +2749,16 @@ def setup_cookbook_routes() -> APIRouter:
if not _SESSION_ID_RE.match(session_id):
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
continue
if remote and not _REMOTE_HOST_RE.match(remote):
if remote:
try:
remote = validate_remote_host(remote)
except HTTPException:
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
continue
if _tport and not _SSH_PORT_RE.match(str(_tport)):
if _tport:
try:
_tport = validate_ssh_port(str(_tport))
except HTTPException:
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
continue
if task_platform == "windows" and remote:
+23 -3
View File
@@ -1,7 +1,9 @@
import re
from copy import deepcopy
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from routes._validators import validate_remote_host, validate_ssh_port
# Backends the manual hardware simulator accepts. Must stay a subset of what
@@ -11,6 +13,14 @@ from fastapi import APIRouter
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
    """Validate the optional remote host / SSH port pair for hardware detection.

    Both values are passed through the shared validators; a falsy validator
    result is normalised to the empty string.

    Returns:
        ``(host, ssh_port)`` as strings, each possibly empty.

    Raises:
        HTTPException: 400 when a port is supplied without a host (in addition
            to whatever the shared validators raise themselves).
    """
    checked_host = validate_remote_host(host)
    checked_port = validate_ssh_port(ssh_port)
    target = (checked_host or "", checked_port or "")
    # A port on its own is meaningless: there is no host to connect to.
    if target[1] and not target[0]:
        raise HTTPException(400, "ssh_port requires host")
    return target
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
"""Detect and return current system hardware info. Pass host=user@server for remote.
fresh=true bypasses the per-host cache (the Rescan button)."""
from services.hwfit.hardware import detect_system
host, ssh_port = _validate_detection_target(host, ssh_port)
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
@router.get("/models")
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.fit import rank_models
from services.hwfit.models import get_models, model_catalog_path
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
@@ -165,8 +177,14 @@ def setup_hwfit_routes():
system["gpu_name"] = g["name"]
system["active_group"] = {**g, "use_count": n}
if gpu_count != "":
n = int(gpu_count)
# Parse the optional count defensively (matches the gpu_group guard
# above): a non-numeric query param previously raised ValueError ->
# HTTP 500. A malformed value is ignored, same as omitting it.
try:
n = int(gpu_count) if gpu_count != "" else None
except ValueError:
n = None
if n is not None:
if n == 0:
# RAM-only mode: rank against system memory, offload allowed.
system["has_gpu"] = False
@@ -229,6 +247,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.models import get_models
from services.hwfit.profiles import compute_serve_profiles
host, ssh_port = _validate_detection_target(host, ssh_port)
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
if system.get("error"):
return {"system": system, "profiles": [], "error": system["error"]}
@@ -279,6 +298,7 @@ def setup_hwfit_routes():
"""Rank image generation models against detected hardware."""
from services.hwfit.hardware import detect_system
from services.hwfit.image_models import rank_image_models
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
+12 -8
View File
@@ -11,7 +11,7 @@ from core.session_manager import SessionManager
from core.models import ChatMessage
from src.request_models import SessionResponse
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter
from src.session_actions import is_session_recently_active
@@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
last_msg_map = {}
mode_map = {}
msg_count_map = {}
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False)
q = owner_filter(q, DbSession, user)
rows = q.all()
for row in rows:
folder_map[row.id] = row.folder
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
@@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
# Sessions with active documents that have content
from sqlalchemy import func
doc_session_ids = set(
r[0] for r in db.query(Document.session_id)
r[0] for r in owner_filter(
db.query(Document.session_id)
.filter(Document.is_active == True,
Document.current_content != None,
func.trim(Document.current_content) != "",
Document.owner == user)
func.trim(Document.current_content) != ""),
Document, user)
.distinct().all()
)
img_session_ids = set(
r[0] for r in db.query(GalleryImage.session_id)
.filter(GalleryImage.session_id != None,
GalleryImage.owner == user)
r[0] for r in owner_filter(
db.query(GalleryImage.session_id)
.filter(GalleryImage.session_id != None),
GalleryImage, user)
.distinct().all()
)
finally:
+1 -1
View File
@@ -417,7 +417,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti
return []
try:
from duckduckgo_search import DDGS
from ddgs import DDGS
except ImportError:
logger.warning("duckduckgo-search package not installed; using HTML fallback")
return _html_fallback()
+21 -8
View File
@@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple:
return etype, None
def _memory_context_lines(mems, limit: int = 40) -> list:
"""Render Memory rows into short personal-context bullets for event classify.
Reads the Memory ORM `text` column. The previous inline code read a
non-existent `content` attribute, so it raised AttributeError on the first
row, the surrounding except swallowed it, and the classifier ran with no
personal context at all. getattr keeps it robust to future schema drift.
"""
lines: list = []
for m in mems:
c = (getattr(m, "text", "") or "").strip()
if c:
lines.append(f"- {c[:200]}")
if len(lines) >= limit:
break
return lines
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
"""Hybrid classification of upcoming calendar events: fast heuristic for
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
@@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
try:
from core.database import Memory as _Mem
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
if _mems:
_lines = []
for m in _mems:
c = (m.content or "").strip()
if c:
_lines.append(f"- {c[:200]}")
_lines = _memory_context_lines(_mems)
if _lines:
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n"
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n"
except Exception as _me:
logger.debug(f"Could not load memory for classify: {_me}")
logger.warning(f"Could not load memory for classify: {_me}")
classified_h = 0
classified_llm = 0
+19
View File
@@ -223,6 +223,25 @@ class ModelDiscovery:
)
return {"hosts": hosts, "items": items}
def warmup_ping_urls(self, limit: int = 5) -> List[str]:
    """The ``/models`` URLs of up to ``limit`` discovered endpoints.

    Used by the startup warmup / keepalive loop to prime connections. Each
    discovered item carries a chat-completions url; the ``/chat/completions``
    part is swapped for the cheap ``/models`` probe. Any failure or malformed
    discovery payload degrades to fewer (or zero) URLs so warmup never
    crashes the caller.

    Args:
        limit: How many discovered items to consider (items without a usable
            url inside that window are dropped, not replaced).

    Returns:
        A list of probe URLs, possibly empty.
    """
    try:
        # ``or []`` covers a payload whose "items" key is present but None.
        items = (self.discover_models() or {}).get("items") or []
    except Exception:
        return []
    if not isinstance(items, (list, tuple)):
        # Not sliceable in the expected way; treat as "nothing discovered".
        return []
    urls: List[str] = []
    for ep in items[:limit]:
        if not isinstance(ep, dict):
            continue  # malformed entry: skip rather than raise
        raw_url = ep.get("url")
        if not isinstance(raw_url, str) or not raw_url:
            continue
        urls.append(raw_url.replace("/chat/completions", "/models"))
    return urls
def get_providers(self) -> Dict[str, Any]:
"""Get all available providers"""
discovery = self.discover_models()
+24 -1
View File
@@ -221,6 +221,22 @@ class ResearchHandler:
# Task registry — background research with persistence
# ------------------------------------------------------------------
def rename_owner(self, old_owner: str, new_owner: str) -> int:
    """Re-key in-flight research tasks from one owner to another.

    Owner names are compared after stripping whitespace and lower-casing.
    Matching entries get the normalised new owner written back in place.

    Returns:
        The number of task entries whose owner was changed; 0 when either
        name is empty after normalisation.
    """
    source = str(old_owner or "").strip().lower()
    target = str(new_owner or "").strip().lower()
    if not (source and target):
        return 0

    moved = 0
    # Snapshot the values so iteration is unaffected by concurrent inserts.
    for task in list(self._active_tasks.values()):
        if isinstance(task, dict) and str(task.get("owner", "")).strip().lower() == source:
            task["owner"] = target
            moved += 1
    return moved
def start_research(
self,
session_id: str,
@@ -390,7 +406,6 @@ class ResearchHandler:
def get_status(self, session_id: str) -> Optional[dict]:
"""Get current research status for a session."""
avg = self.get_avg_duration()
if session_id in self._active_tasks:
entry = self._active_tasks[session_id]
result = {
@@ -399,6 +414,14 @@ class ResearchHandler:
"query": entry["query"],
"started_at": entry["started_at"],
}
# avg_duration is a historical figure over completed reports on
# disk; get_avg_duration() globs and JSON-parses the whole research
# dir, so compute it at most once per active stream (memoized on the
# entry) instead of on every ~1s SSE poll. The disk branch below
# never used it, so it no longer pays that cost at all.
if "_avg_duration" not in entry:
entry["_avg_duration"] = self.get_avg_duration()
avg = entry["_avg_duration"]
if avg is not None:
result["avg_duration"] = round(avg, 1)
return result
+36
View File
@@ -1453,6 +1453,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
except ValueError:
return {"error": "Invalid JSON arguments", "exit_code": 1}
# ── Batch normalization ──
# Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
# instead of individual create_event calls. Iterate and create each.
if isinstance(args.get("events"), list) and not args.get("action"):
results = []
for ev in args["events"]:
if not isinstance(ev, dict):
continue
# Normalize start/end from {dateTime: "..."} object to flat string
for field, target in [("start", "dtstart"), ("end", "dtend")]:
val = ev.pop(field, None)
if val and target not in ev:
ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
ev.setdefault("action", "create_event")
r = await do_manage_calendar(json.dumps(ev), owner=owner)
results.append(r)
created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
failed = [r for r in results if r.get("error")]
if not results:
return {"error": "No events to create", "exit_code": 1}
# Surface both successes and failures
parts = []
if created:
summaries = [r.get("response", "") for r in created]
parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
if failed:
first_error = failed[0].get("error", "Unknown error")
parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")
response = "\n\n".join(parts)
# Non-zero exit code for partial or total failure
exit_code = 0 if not failed else 1
return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}
# Normalize action — some models emit hyphens ("list-calendars") instead
# of underscores. Treat them as equivalent so we don't bounce a
# cosmetic typo back to the model and waste a round-trip. Also accept
+15 -2
View File
@@ -162,13 +162,26 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool:
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
"""Return True for admins, or when auth is not configured yet."""
"""Return True for admins, or in intentional single-user mode.
Single-user mode means the operator explicitly disabled auth
(``AUTH_ENABLED=false``) — the local/self-host default where the owner has
full access to their own box.
The pre-setup window (auth ENABLED but no admin created yet) is treated as
NON-admin: returning True there would hand server-execution tools
(``bash``/``python``) to any caller before setup completes. The auth
middleware already 401s ``/api/`` requests pre-setup, so this is
defense-in-depth for callers that bypass it (e.g. trusted loopback).
"""
try:
from core.auth import AuthManager
auth = AuthManager()
if not auth.is_configured:
return True
from src.auth_helpers import _auth_disabled
return _auth_disabled()
return bool(owner and auth.is_admin(owner))
except Exception as exc:
logger.warning("Unable to evaluate owner admin status: %s", exc)
+6 -2
View File
@@ -1082,7 +1082,7 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
let _lastToolName = '';
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
const _toolLabels = {
'web_search': _searchIcon + 'Searching',
'web_search': 'Searching',
'bash': 'Running',
'python': 'Running',
'create_document': 'Writing',
@@ -1102,6 +1102,9 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
'list_models': 'Browsing',
'ui_control': 'Adjusting',
};
const _toolIcons = {
'web_search': _searchIcon,
};
function _thinkingLabel() {
if (!_lastToolName) {
return 'Thinking';
@@ -2049,10 +2052,11 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
}
threadWrap.classList.add('streaming');
const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool;
const toolIcon = _toolIcons[json.tool.toLowerCase()] || '\u25B6';
const node = document.createElement('div')
node.className = 'agent-thread-node running';
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">\u25B6</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${toolIcon}</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
// Expand/collapse via delegated click handler (init at module bottom).
threadWrap.appendChild(node);
currentToolBubble = node;
+16 -1
View File
@@ -862,6 +862,20 @@ export function stripToolBlocks(text) {
return cleaned.trim();
}
/**
 * Plain-text payload for the message copy buttons: the reply as the renderer
 * displays it, i.e. with tool blocks and <think> reasoning stripped.
 *
 * dataset.raw keeps the full model output (chat.js even embeds the elapsed
 * time into the <think> tag for reload persistence), so copying it verbatim
 * leaks the thinking block (#3722). Falls back to the raw text when stripping
 * leaves nothing (e.g. turns interrupted mid-thinking).
 */
export function copyMessageText(msgElement) {
  const rawText = msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '';
  const withoutTools = stripToolBlocks(rawText);
  const visible = markdownModule.extractThinkingBlocks(withoutTools).content;
  return visible || rawText;
}
/**
* Build a collapsible sources box (used by both research and web search).
*/
@@ -1372,7 +1386,7 @@ export function createMsgFooter(msgElement) {
{ id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) {
e.stopPropagation();
const btn = e.currentTarget;
uiModule.copyToClipboard(msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '');
uiModule.copyToClipboard(copyMessageText(msgElement));
btn.innerHTML = CHECK_ICON;
setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500);
}},
@@ -2444,6 +2458,7 @@ const chatRenderer = {
updateSessionCostUI,
roleTimestamp,
stripToolBlocks,
copyMessageText,
safeToolScreenshotSrc,
safeDisplayImageSrc,
buildSourcesBox,
+1 -1
View File
@@ -380,7 +380,7 @@ function _slashFooter(msgEl) {
copyBtn.innerHTML = _copySvg;
copyBtn.onclick = (e) => {
e.stopPropagation();
uiModule.copyToClipboard(msgEl.dataset.raw || msgEl.querySelector('.body')?.textContent || '');
uiModule.copyToClipboard(chatRenderer.copyMessageText(msgEl));
copyBtn.innerHTML = _checkSvg;
setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500);
};
@@ -8,6 +8,9 @@ with missing users or assertion errors.
import json
import threading
import time
import contextlib
import sys
import types
from concurrent.futures import ThreadPoolExecutor, as_completed
import pytest
@@ -15,6 +18,41 @@ import pytest
from tests.helpers.import_state import clear_module
class _OwnerColumn:
def __eq__(self, other):
return ("owner ==", other)
class _FakeApiToken:
    """Minimal fake of the ``ApiToken`` model: only the ``owner`` column exists."""

    # Class-level attribute so ``_FakeApiToken.owner == x`` builds a fake condition.
    owner = _OwnerColumn()
class _FakeQuery:
def filter(self, *_conds):
return self
def delete(self, *args, **kwargs):
return 0
class _FakeSession:
    """Fake DB session that only supports querying the fake ApiToken model."""

    def query(self, model):
        # Any other model means the code under test queried something unexpected.
        assert model is _FakeApiToken
        fake_query = _FakeQuery()
        return fake_query
@pytest.fixture(autouse=True)
def _stub_api_token_purge(monkeypatch):
    """Replace ``core.database`` in ``sys.modules`` with a stub for every test.

    The stub exposes only ``get_db_session`` (a context manager yielding a
    fake session) and ``ApiToken`` (the fake model). monkeypatch restores the
    original ``sys.modules`` entry afterwards.
    """

    @contextlib.contextmanager
    def _fake_db_session():
        yield _FakeSession()

    stub_module = types.ModuleType("core.database")
    stub_module.get_db_session = _fake_db_session
    stub_module.ApiToken = _FakeApiToken
    monkeypatch.setitem(sys.modules, "core.database", stub_module)
def _fresh_auth_manager(tmp_path):
clear_module("core.auth")
from core.auth import AuthManager
+125
View File
@@ -0,0 +1,125 @@
"""Test that do_manage_calendar handles the batch {"events": [...]} format
that models like deepseek-v4-flash emit instead of individual create_event calls.
"""
import json
import sys
import uuid
import pytest
from tests.helpers.import_state import clear_fake_database_modules
from tests.helpers.sqlite_db import make_temp_sqlite
clear_fake_database_modules()
import core.database as cdb
from core.database import CalendarEvent
_TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata)
@pytest.fixture(autouse=True)
def _bind_temp_db(monkeypatch):
    """Point ``core.database`` at this module's temp SQLite setup for each test.

    Re-registers the imported ``cdb`` module under ``core.database`` (both in
    ``sys.modules`` and, when the parent package is loaded, as its
    ``database`` attribute), then swaps ``SessionLocal`` for the temp session
    factory. monkeypatch undoes all of it after the test.
    """
    monkeypatch.setitem(sys.modules, "core.database", cdb)

    core_pkg = sys.modules.get("core")
    if core_pkg is not None:
        monkeypatch.setattr(core_pkg, "database", cdb, raising=False)

    monkeypatch.setattr(cdb, "SessionLocal", _TS)
    yield
async def test_batch_events_with_datetime_objects():
    """Model emits {"events": [{"summary": ..., "start": {"dateTime": ...}, "end": {"dateTime": ...}}]}."""
    from src.tool_implementations import do_manage_calendar

    owner = "tester-" + uuid.uuid4().hex[:6]
    payload = {
        "events": [
            {
                "summary": "Morning Gym",
                "start": {"dateTime": "2026-06-09T06:00:00+05:30"},
                "end": {"dateTime": "2026-06-09T07:00:00+05:30"},
            },
            {
                "summary": "Morning Gym",
                "start": {"dateTime": "2026-06-10T06:00:00+05:30"},
                "end": {"dateTime": "2026-06-10T07:00:00+05:30"},
            },
        ]
    }
    res = await do_manage_calendar(json.dumps(payload), owner=owner)
    assert res.get("exit_code") == 0, res
    assert "Created 2 event(s)" in res.get("response", "")

    # Verify events exist in DB. Close the session even when the assertion
    # fails so a failing test does not leak a connection into later tests.
    db = _TS()
    try:
        events = db.query(CalendarEvent).filter(CalendarEvent.summary == "Morning Gym").all()
        assert len(events) == 2
    finally:
        db.close()
async def test_batch_events_with_flat_strings():
    """Model emits {"events": [{"summary": ..., "start": "ISO", "end": "ISO"}]}."""
    from src.tool_implementations import do_manage_calendar

    standup = {
        "summary": "Standup",
        "start": "2026-06-09T09:00:00",
        "end": "2026-06-09T09:30:00",
    }
    body = json.dumps({"events": [standup]})
    owner = "tester-" + uuid.uuid4().hex[:6]

    result = await do_manage_calendar(body, owner=owner)

    # A single flat-string event must be created successfully.
    assert result.get("exit_code") == 0, result
    assert "Created 1 event(s)" in result.get("response", "")
async def test_batch_events_partial_failure():
    """Batch with some valid and some invalid events — should surface both counts and first error."""
    from src.tool_implementations import do_manage_calendar

    owner = "tester-" + uuid.uuid4().hex[:6]
    payload = {
        "events": [
            {
                "summary": "Valid Event 1",
                "start": "2026-06-09T10:00:00",
                "end": "2026-06-09T11:00:00",
            },
            {
                "summary": "Invalid Event",
                # Missing required dtstart — will fail
            },
            {
                "summary": "Valid Event 2",
                "start": "2026-06-09T14:00:00",
                "end": "2026-06-09T15:00:00",
            },
        ]
    }
    res = await do_manage_calendar(json.dumps(payload), owner=owner)

    # Partial failure = non-zero exit code
    assert res.get("exit_code") != 0, "Partial failure should return non-zero exit code"

    # Response should mention both created and failed counts
    response = res.get("response", "")
    assert "Created 2 event(s)" in response, f"Should report 2 created: {response}"
    assert "Failed to create 1 event(s)" in response, f"Should report 1 failed: {response}"
    assert "error" in response.lower() or "required" in response.lower(), "Should include error details"

    # Metadata fields
    assert res.get("created_count") == 2
    assert res.get("failed_count") == 1

    # Verify only valid events were created. Close the session even when the
    # assertion fails so a failing test does not leak a connection.
    db = _TS()
    try:
        events = db.query(CalendarEvent).filter(
            CalendarEvent.summary.in_(["Valid Event 1", "Valid Event 2"])
        ).all()
        assert len(events) == 2
    finally:
        db.close()
+33
View File
@@ -0,0 +1,33 @@
"""classify_events must read the Memory `text` column, not a non-existent
`content` attribute.
The previous inline loop did `m.content`, which raised AttributeError on the
first Memory row; the surrounding except swallowed it, so the personal-context
block the LLM relies on was always empty. The logic now lives in
`_memory_context_lines`, which reads `text`.
"""
from src.builtin_actions import _memory_context_lines
class _Mem:
def __init__(self, text):
self.text = text
def test_uses_text_and_truncates_and_skips_blank():
    """``text`` is read, blank rows are dropped, long rows are cut at 200 chars."""
    rows = [_Mem("Alice is my spouse"), _Mem(" "), _Mem("y" * 250)]
    rendered = _memory_context_lines(rows)
    # Exactly two bullets: the blank row is skipped, the long one truncated.
    assert rendered == ["- Alice is my spouse", "- " + "y" * 200]
def test_skips_rows_without_text_attribute():
    """A row lacking ``text`` altogether is ignored rather than raising."""

    class _NoText:  # mimics a schema where the attribute is absent
        pass

    rendered = _memory_context_lines([_NoText(), _Mem("ok")])
    assert rendered == ["- ok"]
def test_respects_limit():
    """No more than ``limit`` bullets come back even when more rows exist."""
    rendered = _memory_context_lines(
        [_Mem(f"memory {i}") for i in range(50)],
        limit=40,
    )
    assert len(rendered) == 40
+39
View File
@@ -0,0 +1,39 @@
"""POST /api/contacts/import must not 500 on a non-string vcf/text/csv value.
`text = data.get("vcf") or ... or ""` left a non-string value (e.g. a number)
in place, so the next `text.strip()` raised AttributeError -> HTTP 500. The
handler now coerces with str() and degrades to a structured "no data" response.
"""
import asyncio
from routes.contacts_routes import setup_contacts_routes
def _import_handler():
    """Return the endpoint function of the POST ``.../import`` contacts route.

    Raises:
        AssertionError: when the router exposes no such route.
    """
    for candidate in setup_contacts_routes().routes:
        if not getattr(candidate, "path", "").endswith("/import"):
            continue
        if "POST" not in getattr(candidate, "methods", set()):
            continue
        return candidate.endpoint
    raise AssertionError("import route not found")
def _call(data):
    """Invoke the import endpoint directly with ``data`` and return its response."""
    endpoint = _import_handler()
    # The endpoint is a coroutine function; drive it to completion synchronously.
    pending = endpoint(data=data, _admin="admin")
    return asyncio.run(pending)
def test_non_string_vcf_degrades_cleanly():
    """A numeric ``vcf`` value yields a structured failure, not an exception."""
    result = _call({"vcf": 123})
    assert result["success"] is False
    assert "error" in result
def test_non_string_csv_degrades_cleanly():
    """A list-valued ``csv`` field yields a structured failure, not an exception."""
    result = _call({"csv": ["a", "b"]})
    assert result["success"] is False
def test_empty_body_reports_no_data():
    """An empty body reports the specific "no contact data" error."""
    result = _call({})
    assert result["success"] is False
    assert result["error"] == "No contact data found"
-7
View File
@@ -26,7 +26,6 @@ from routes.cookbook_helpers import (
_validate_repo_id,
_validate_serve_cmd,
_validate_serve_model_id,
_validate_ssh_port,
_shell_path,
run_ssh_command_async,
)
@@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path():
)
def test_validate_ssh_port_rejects_shell_payload():
with pytest.raises(HTTPException):
_validate_ssh_port("22; touch /tmp/pwned")
assert _validate_ssh_port("2222") == "2222"
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
path = "/Volumes/T7 2TB/AI Models/llamacpp"
@@ -0,0 +1,160 @@
"""Regression coverage for issue #3722 — the message copy button copied the
full raw model output (``dataset.raw``), which still contains the
``<think time="...">...</think>`` reasoning block that the renderer strips for
display. Pasting therefore leaked the model's thinking, and the first heading
after ``</think>`` lost its markdown formatting because it was glued to the
closing tag.
The fix adds chatRenderer.copyMessageText(), which mirrors the display
pipeline (``stripToolBlocks()`` then ``extractThinkingBlocks()``), and routes
both AI-message copy buttons (createMsgFooter and the slash-reply footer)
through it. extractThinkingBlocks() behavior is pinned here under node
(including on the payload from the issue report); the helper and handler
wiring are guarded at the source level because chatRenderer.js pulls in
browser globals and can't be imported under node (same approach as
test_new_chat_clears_input.py).
"""
import json
import re
import shutil
import subprocess
import textwrap
from pathlib import Path
import pytest
# Parent of the directory containing this test file; used below as the node
# working directory and as the base for the static/js source lookups.
_REPO = Path(__file__).resolve().parent.parent
# True when a `node` executable can be found on PATH.
_HAS_NODE = shutil.which("node") is not None
@pytest.fixture(scope="module")
def node_available():
    """Skip the requesting test when no ``node`` binary was found on PATH."""
    if _HAS_NODE:
        return
    pytest.skip("node binary not on PATH")
def _extract_thinking_blocks(text: str) -> dict:
"""Run markdown.js extractThinkingBlocks(text) under node.

The generated script stubs ``window``, ``document`` and ``MutationObserver``,
removes or inlines the imports of markdown.js (ui.js, markdown/tableRow.js,
emojiShortcodes.js), loads the patched source as a base64 ``data:`` module and
prints ``{"out": <result>}`` as JSON.

Returns the decoded ``out`` value. Raises AssertionError carrying node's
stderr/stdout when node exits with a non-zero status.
"""
# Raw string: the JS regex escapes below must reach node unmodified.
script = textwrap.dedent(
r"""
import fs from 'node:fs';
globalThis.window = { location: { origin: 'http://localhost' }, katex: null };
globalThis.document = {
readyState: 'loading',
addEventListener() {},
createElement(tag) {
if (tag !== 'template') throw new Error(`unsupported element: ${tag}`);
return {
_html: '',
content: { querySelectorAll() { return []; } },
set innerHTML(value) { this._html = value; },
get innerHTML() { return this._html; },
};
},
};
globalThis.MutationObserver = class { observe() {} };
let source = fs.readFileSync('./static/js/markdown.js', 'utf8');
source = source.replace(
/import uiModule from ['"]\.\/ui\.js['"];/,
''
);
source = source.replace(
/import \{ splitTableRow \} from ['"]\.\/markdown\/tableRow\.js['"];/,
`function splitTableRow(row) {
return (row || '').replace(/^\\s*\\|/, '').replace(/\\|\\s*$/, '').split('|').map(c => c.trim());
}`
);
const emojiSource = fs.readFileSync('./static/js/emojiShortcodes.js', 'utf8')
.replace(/^export default .*$/m, '')
.replace(/export const /g, 'const ')
.replace(/export function /g, 'function ');
source = source.replace(
/import \{ replaceEmojiShortcodes, hasEmojiShortcode \} from ['"]\.\/emojiShortcodes\.js['"];/,
() => emojiSource
);
source = source.replace(
/var escapeHtml = uiModule\.esc;/,
`var escapeHtml = (value) => String(value ?? '')
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&#39;');`
);
const moduleUrl = 'data:text/javascript;base64,' + Buffer.from(source).toString('base64');
const mod = await import(moduleUrl);
const input = JSON.parse(process.argv[1]);
console.log(JSON.stringify({ out: mod.extractThinkingBlocks(input) }));
"""
)
# The input travels as a JSON-encoded extra argument (read by the script as
# process.argv[1]); cwd is _REPO so the relative ./static/js paths resolve.
result = subprocess.run(
["node", "--input-type=module", "-e", script, json.dumps(text)],
cwd=_REPO,
capture_output=True,
timeout=15,
text=True,
)
if result.returncode != 0:
raise AssertionError(f"node failed:\nSTDERR:\n{result.stderr}\nSTDOUT:\n{result.stdout}")
# Only the last stdout line is parsed as the JSON payload.
return json.loads(result.stdout.splitlines()[-1])["out"]
def test_issue_payload_copy_text_excludes_thinking(node_available):
    """The #3722 payload copies as the reply only, starting at its heading."""
    # Shape reported in #3722: timed think block glued to the reply heading.
    payload = "".join(
        [
            '<think time="24.5">\n',
            "Here's a thinking process that leads to the desired summary:\n\n",
            "6. **Generate the Output.** (This matches the final provided response.)",
            "</think>### Juxtaposition: Interweaving Cultural Norms in Lesson Design\n",
            "The most effective lesson structure is created by deliberately juxtaposing.",
        ]
    )
    extracted = _extract_thinking_blocks(payload)
    content = extracted["content"]
    assert content.startswith("### Juxtaposition:"), content
    assert "thinking process" not in content
    assert "<think" not in content
    assert extracted["thinkingTime"] == "24.5"
def test_plain_reply_copy_text_is_unchanged(node_available):
    """A reply without reasoning markup passes through extraction untouched."""
    reply = "### Heading\nJust a normal reply with no reasoning markup."
    extracted = _extract_thinking_blocks(reply)
    assert extracted["content"] == reply
def test_thinking_only_message_yields_empty_content(node_available):
    """A message that is nothing but a think block extracts to empty content."""
    # The copy handler falls back to the raw text in this case so the button
    # still copies something for turns interrupted mid-thinking.
    thinking_only = "<think>only reasoning, no reply yet</think>"
    extracted = _extract_thinking_blocks(thinking_only)
    assert extracted["content"] == ""
def _function_body(text: str, marker: str) -> str:
start = text.index(marker)
rest = text[start + len(marker):]
m = re.search(r"\nexport function |\nfunction ", rest)
return rest[: m.start()] if m else rest
def test_copy_message_text_mirrors_display_pipeline():
    """copyMessageText references both display-pipeline helpers and dataset.raw."""
    renderer_source = (_REPO / "static/js/chatRenderer.js").read_text(encoding="utf-8")
    helper_body = _function_body(renderer_source, "export function copyMessageText")
    # Mirrors the display path: tool blocks stripped, then thinking extracted.
    assert "extractThinkingBlocks" in helper_body
    assert "stripToolBlocks" in helper_body
    assert "dataset.raw" in helper_body
def test_copy_handlers_route_through_copy_message_text():
    """Each JS file routes exactly one copy handler through copyMessageText()."""
    expected_calls = (
        ("static/js/chatRenderer.js", 1),
        ("static/js/slashCommands.js", 1),
    )
    for path, count in expected_calls:
        text = (_REPO / path).read_text(encoding="utf-8")
        direct = text.count("copyToClipboard(copyMessageText(")
        namespaced = text.count("copyToClipboard(chatRenderer.copyMessageText(")
        assert direct + namespaced == count, path
        # The old behavior passed dataset.raw straight to the clipboard.
        assert "copyToClipboard(msgElement.dataset.raw" not in text, path
        assert "copyToClipboard(msgEl.dataset.raw" not in text, path
@@ -36,6 +36,17 @@ def _auth_manager(delete_result):
)
def _auth_manager_raising():
def _delete_user(_username, _requesting_user):
raise RuntimeError("auth save failed after token purge")
return types.SimpleNamespace(
get_username_for_token=lambda token: "admin",
is_admin=lambda user: True,
delete_user=_delete_user,
)
def test_successful_delete_invalidates_cache():
invalidations = []
router = setup_auth_routes(_auth_manager(delete_result=True))
@@ -56,3 +67,16 @@ def test_refused_delete_does_not_invalidate_cache():
raised = True
assert raised, "a refused delete should raise (HTTP 400)"
assert invalidations == [], "a refused delete must not touch the token cache"
def test_delete_exception_invalidates_cache_for_partial_token_purge():
    """A delete that raises still propagates and still dirties the token cache."""
    invalidations = []
    handler = _handler(setup_auth_routes(_auth_manager_raising()))
    payload = DeleteUserRequest(username="bob")
    try:
        asyncio.run(handler(payload, _fake_request(invalidations)))
    except RuntimeError:
        raised = True
    else:
        raised = False
    assert raised, "delete_user exception should still propagate"
    assert invalidations == [True], "partial token purge must dirty the bearer cache"
@@ -114,3 +114,21 @@ def test_refused_delete_leaves_tokens_alone(manager, db_calls):
def test_unknown_user_leaves_tokens_alone(manager, db_calls):
    """Deleting a user that does not exist returns False and records no DB call."""
    outcome = manager.delete_user("ghost", "admin")
    assert outcome is False
    assert db_calls == []
def test_delete_user_fails_closed_when_api_token_purge_fails(manager, monkeypatch):
    """If the API-token purge cannot open a DB session, the user and session survive."""
    session_token = manager.create_session("bob", "secret-bob-pw")

    @contextlib.contextmanager
    def _failing_db_session():
        raise RuntimeError("database unavailable")
        yield  # unreachable; present only so this function is a generator

    fake_database = types.ModuleType("core.database")
    fake_database.get_db_session = _failing_db_session
    fake_database.ApiToken = _FakeApiToken
    monkeypatch.setitem(sys.modules, "core.database", fake_database)

    deleted = manager.delete_user("bob", "admin")
    assert deleted is False
    assert "bob" in manager.users
    assert manager.validate_token(session_token) is True
+38
View File
@@ -0,0 +1,38 @@
"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count.
The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any
non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored,
matching how the neighbouring gpu_group param is already parsed.
"""
from routes.hwfit_routes import setup_hwfit_routes
def _get_models():
    """Return the endpoint registered for GET on a path ending in ``/models``."""
    registered = setup_hwfit_routes().routes
    matching = (
        candidate.endpoint
        for candidate in registered
        if getattr(candidate, "path", "").endswith("/models")
        and "GET" in getattr(candidate, "methods", set())
    )
    # First match wins, exactly like an early return from a plain loop.
    for endpoint in matching:
        return endpoint
    raise AssertionError("hwfit /models route not found")
def test_non_numeric_gpu_count_does_not_raise():
    """``gpu_count="abc"`` returns a dict instead of raising."""
    models_endpoint = _get_models()
    # Previously raised ValueError (HTTP 500); now degrades to a normal ranking.
    ranking = models_endpoint(gpu_count="abc")
    assert isinstance(ranking, dict)
def test_numeric_gpu_count_still_accepted():
    """A well-formed numeric ``gpu_count`` keeps returning a dict."""
    models_endpoint = _get_models()
    ranking = models_endpoint(gpu_count="0")
    assert isinstance(ranking, dict)
def test_non_numeric_manual_gpu_count_does_not_raise():
    """``manual_gpu_count="abc"`` in manual GPU mode returns a dict instead of raising."""
    # manual_gpu_count is the other count param on this endpoint (the hardware
    # simulator in _apply_manual_hardware). A non-numeric value must also degrade
    # (default to 1) rather than 500, so the endpoint's count parsing is fully
    # covered.
    models_endpoint = _get_models()
    ranking = models_endpoint(manual_mode="gpu", manual_gpu_count="abc")
    assert isinstance(ranking, dict)
+47
View File
@@ -0,0 +1,47 @@
import pytest
from fastapi import HTTPException
from core.platform_compat import _ssh_exec_argv
from routes.hwfit_routes import setup_hwfit_routes
def _endpoint(path: str):
    """Return the endpoint of the first hwfit route whose path equals *path*."""
    registered = setup_hwfit_routes().routes
    hits = [candidate for candidate in registered if getattr(candidate, "path", "") == path]
    if not hits:
        raise AssertionError(f"{path} route not found")
    return hits[0].endpoint
@pytest.mark.parametrize(
    "path,kwargs",
    [
        ("/api/hwfit/system", {}),
        ("/api/hwfit/models", {"limit": 1}),
        ("/api/hwfit/profiles", {"model": "demo"}),
        ("/api/hwfit/image-models", {}),
    ],
)
def test_hwfit_routes_reject_ssh_option_host(path, kwargs):
    """Every listed hwfit route answers HTTP 400 to a host shaped like an ssh option."""
    route_endpoint = _endpoint(path)
    call_kwargs = dict(kwargs, host="-oProxyCommand=sh", ssh_port="22")
    with pytest.raises(HTTPException) as caught:
        route_endpoint(**call_kwargs)
    assert caught.value.status_code == 400
def test_hwfit_routes_reject_port_without_host():
    """An ssh_port supplied with an empty host is answered with HTTP 400."""
    system_endpoint = _endpoint("/api/hwfit/system")
    with pytest.raises(HTTPException) as caught:
        system_endpoint(host="", ssh_port="2222")
    assert caught.value.status_code == 400
def test_ssh_argv_rejects_option_shaped_remote():
    """_ssh_exec_argv raises ValueError for option-shaped remotes, with or without a user@ prefix."""
    bare_option = "-oProxyCommand=sh"
    with pytest.raises(ValueError):
        _ssh_exec_argv(bare_option, "22", remote_cmd="true")

    with_user = "alice@-oProxyCommand=sh"
    with pytest.raises(ValueError):
        _ssh_exec_argv(with_user, "22", remote_cmd="true")
+253 -4
View File
@@ -11,7 +11,10 @@ owner column, but three file-backed / in-memory stores are left stale:
research_routes filters by `d.get("owner") == user`, making every report
invisible after rename.
3. data/memory.json a flat array where every entry has an `owner` field;
3. research_handler._active_tasks — in-flight research jobs carry the same
owner key while status/cancel/active routes filter by it.
4. data/memory.json — a flat array where every entry has an `owner` field;
memory_manager.load(owner=user) filters on it, so all memories vanish.
Regression coverage: these bugs are invisible in unit tests that mock the DB
@@ -26,6 +29,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
def _route(router, name):
@@ -63,18 +67,69 @@ def rename_endpoint(monkeypatch, tmp_path):
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
def _request(tmp_path, session_manager=None):
def _request(tmp_path, session_manager=None, token="t", research_handler=None):
state = SimpleNamespace(
invalidate_token_cache=lambda: None,
session_manager=session_manager,
research_handler=research_handler,
)
return SimpleNamespace(
cookies={"odysseus_session": "t"},
cookies={"odysseus_session": token},
app=SimpleNamespace(state=state),
state=SimpleNamespace(current_user="admin"),
)
def _auth_manager_for_rollback_test(monkeypatch, tmp_path):
    """Create an AuthManager under *tmp_path* with stubbed hashing and users admin + alice."""
    import core.auth as auth_mod

    def _fake_hash(password):
        return f"hash:{password}"

    def _fake_verify(password, hashed):
        return hashed == f"hash:{password}"

    # Replace the module's hashing helpers with trivial string stand-ins.
    monkeypatch.setattr(auth_mod, "_hash_password", _fake_hash)
    monkeypatch.setattr(auth_mod, "_verify_password", _fake_verify)

    manager = auth_mod.AuthManager(str(tmp_path / "auth.json"))
    assert manager.create_user("admin", "pw-123456", is_admin=True) is True
    assert manager.create_user("alice", "pw-123456") is True
    return manager
def _force_sql_owner_migration_failure(monkeypatch):
    """Patch core.database so any owner ``update()`` raises; return the fake session.

    The returned session records ``rolled_back`` and ``closed`` so callers can
    assert on cleanup.
    """
    import core.database as cdb

    class OwnerModel:
        owner = "owner"

    class FailingQuery:
        def filter(self, *_args, **_kwargs):
            return self

        def update(self, *_args, **_kwargs):
            raise RuntimeError("forced owner migration failure")

    class FailingSession:
        def __init__(self):
            self.rolled_back = False
            self.closed = False

        def query(self, _model):
            return FailingQuery()

        def rollback(self):
            self.rolled_back = True

        def close(self):
            self.closed = True

    session = FailingSession()

    def _session_factory():
        # Always hand back the same instance so the caller can inspect it.
        return session

    fake_mapper = SimpleNamespace(class_=OwnerModel)
    fake_base = SimpleNamespace(registry=SimpleNamespace(mappers=[fake_mapper]))
    monkeypatch.setattr(cdb, "SessionLocal", _session_factory)
    monkeypatch.setattr(cdb, "Base", fake_base, raising=False)
    return session
# ---------------------------------------------------------------------------
# 1. In-memory session cache
# ---------------------------------------------------------------------------
@@ -183,6 +238,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint):
assert res["ok"] is True
def test_rename_updates_active_research_task_owner(rename_endpoint):
    """Renaming alice re-keys her in-flight research task and leaves carol's alone."""
    from routes.research_routes import setup_research_routes
    from src.research_handler import ResearchHandler

    endpoint, _am, tmp_path = rename_endpoint

    def _task(owner, query, started_at):
        return {
            "owner": owner,
            "status": "running",
            "query": query,
            "progress": {},
            "started_at": started_at,
        }

    # Bypass __init__; only the in-memory task table is needed here.
    handler = ResearchHandler.__new__(ResearchHandler)
    handler._active_tasks = {
        "alice-task": _task("Alice", "q", 1),
        "carol-task": _task("carol", "q2", 2),
    }

    request = _request(tmp_path, research_handler=handler)
    asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), request))

    assert handler._active_tasks["alice-task"]["owner"] == "alice2"
    assert handler._active_tasks["carol-task"]["owner"] == "carol"

    router = setup_research_routes(handler)
    active = next(
        route.endpoint
        for route in router.routes
        if getattr(route, "path", "") == "/api/research/active"
    )

    def _active_for(user):
        caller = SimpleNamespace(state=SimpleNamespace(current_user=user))
        return asyncio.run(active(caller))

    seen_by_new_name = _active_for("alice2")
    seen_by_old_name = _active_for("alice")

    assert [item["session_id"] for item in seen_by_new_name["active"]] == ["alice-task"]
    assert seen_by_old_name["active"] == []
def test_research_handler_rename_owner_canonicalizes_new_owner():
    """rename_owner stores the new owner lower-cased and reports one changed task."""
    from src.research_handler import ResearchHandler

    handler = ResearchHandler.__new__(ResearchHandler)
    handler._active_tasks = {"task": {"owner": "Alice", "status": "running"}}

    changed = handler.rename_owner("alice", "Alice2")

    assert changed == 1
    assert handler._active_tasks["task"]["owner"] == "alice2"
def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold():
    """Renaming "straße" must not touch "strasse": the two are distinct owners here."""
    from src.research_handler import ResearchHandler

    handler = ResearchHandler.__new__(ResearchHandler)
    tasks = {}
    tasks["task-strasse"] = {"owner": "strasse", "status": "running"}
    tasks["task-sharp-s"] = {"owner": "straße", "status": "running"}
    handler._active_tasks = tasks

    changed = handler.rename_owner("straße", "renamed")

    assert changed == 1
    assert handler._active_tasks["task-strasse"]["owner"] == "strasse"
    assert handler._active_tasks["task-sharp-s"]["owner"] == "renamed"
def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint):
    """The active-task hook runs while the on-disk report still carries the old owner."""
    endpoint, _am, tmp_path = rename_endpoint

    research_dir = tmp_path / "deep_research"
    research_dir.mkdir()
    report = research_dir / "race-window.json"
    report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8")

    def _owner_on_disk():
        return json.loads(report.read_text(encoding="utf-8"))["owner"]

    owner_seen_by_active_hook = []

    class FakeResearchHandler:
        def rename_owner(self, _old, _new):
            # Snapshot the file's owner at the moment the hook fires.
            owner_seen_by_active_hook.append(_owner_on_disk())

    request = _request(tmp_path, research_handler=FakeResearchHandler())
    asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), request))

    assert owner_seen_by_active_hook == ["alice"]
    assert _owner_on_disk() == "alice2"
def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path):
"""DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a
hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made
@@ -333,8 +490,100 @@ def test_rename_no_skills_dir_does_not_crash(rename_endpoint):
assert res["ok"] is True
def test_rename_skill_md_owner_case_insensitive(rename_endpoint):
    """SKILL.md written with ``owner: Alice`` (mixed case) must be updated when
    renaming alice; the regex was missing re.IGNORECASE."""
    endpoint, _am, tmp_path = rename_endpoint

    skill_file = tmp_path / "skills" / "general" / "s" / "SKILL.md"
    skill_file.parent.mkdir(parents=True)
    skill_file.write_text(_SKILL_MD.format(owner="Alice"), encoding="utf-8")

    new_name = SimpleNamespace(username="alice2")
    asyncio.run(endpoint("alice", new_name, _request(tmp_path)))

    assert "owner: alice2" in skill_file.read_text(encoding="utf-8")
def test_rename_usage_keys_case_insensitive(rename_endpoint):
    """_usage.json keys stored as ``Alice::skill-name`` must be migrated when
    renaming alice; the old startswith check was not lowercasing."""
    endpoint, _am, tmp_path = rename_endpoint

    usage_file = tmp_path / "skills" / "_usage.json"
    usage_file.parent.mkdir(parents=True)
    before = {"Alice::my-skill": {"uses": 5, "last_used": 999}}
    usage_file.write_text(json.dumps(before), encoding="utf-8")

    new_name = SimpleNamespace(username="alice2")
    asyncio.run(endpoint("alice", new_name, _request(tmp_path)))

    after = json.loads(usage_file.read_text(encoding="utf-8"))
    assert "alice2::my-skill" in after
    assert "Alice::my-skill" not in after
# ---------------------------------------------------------------------------
# 5. P1 regression: rejected auth rename must not mutate file-backed stores
# 5. Rollback: auth rename must be restored if SQL owner migration fails
# ---------------------------------------------------------------------------
def test_owner_migration_failure_rolls_back_auth_rename(monkeypatch, tmp_path):
    """A failing SQL owner migration answers HTTP 500 and leaves alice fully intact."""
    import routes.auth_routes as ar

    fake_session = _force_sql_owner_migration_failure(monkeypatch)
    manager = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
    admin_token = manager.create_session_trusted("admin")
    alice_token = manager.create_session_trusted("alice")
    endpoint = _route(ar.setup_auth_routes(manager), "rename_user")

    request = _request(tmp_path, token=admin_token)
    with pytest.raises(HTTPException) as caught:
        asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), request))

    assert caught.value.status_code == 500
    assert fake_session.rolled_back is True
    assert fake_session.closed is True

    # In-memory auth state is back to the pre-rename shape.
    assert "alice" in manager.users
    assert "alice2" not in manager.users
    assert manager.get_username_for_token(alice_token) == "alice"

    # The persisted auth file agrees.
    persisted = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))
    saved_users = persisted["users"]
    assert "alice" in saved_users
    assert "alice2" not in saved_users
def test_self_rename_owner_migration_failure_rolls_back_auth_session(monkeypatch, tmp_path):
    """When admin renames itself and the migration fails, its session still maps to "admin"."""
    import routes.auth_routes as ar

    fake_session = _force_sql_owner_migration_failure(monkeypatch)
    manager = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
    admin_token = manager.create_session_trusted("admin")
    endpoint = _route(ar.setup_auth_routes(manager), "rename_user")

    request = _request(tmp_path, token=admin_token)
    with pytest.raises(HTTPException) as caught:
        asyncio.run(endpoint("admin", SimpleNamespace(username="chief"), request))

    assert caught.value.status_code == 500
    assert fake_session.rolled_back is True
    assert fake_session.closed is True

    # In-memory auth state is back to the pre-rename shape.
    assert "admin" in manager.users
    assert "chief" not in manager.users
    assert manager.get_username_for_token(admin_token) == "admin"

    # The persisted auth file agrees.
    persisted = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))
    saved_users = persisted["users"]
    assert "admin" in saved_users
    assert "chief" not in saved_users
# ---------------------------------------------------------------------------
# 6. P1 regression: rejected auth rename must not mutate file-backed stores
# ---------------------------------------------------------------------------
def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path):
@@ -0,0 +1,41 @@
"""get_status must not rescan the whole research dir on every SSE poll.
get_avg_duration() globs and JSON-parses every file under the research data dir.
get_status() called it unconditionally on each poll, including for sessions that
are not active (the common case while a client polls a finished report). It is
now computed only for active sessions and memoized on the entry.
"""
from src.research_handler import ResearchHandler
def _handler():
    """Return a ResearchHandler created without ``__init__`` and with no active tasks."""
    handler = ResearchHandler.__new__(ResearchHandler)
    setattr(handler, "_active_tasks", {})
    return handler
def test_inactive_session_does_not_compute_avg(monkeypatch):
    """get_status on an unknown session returns None without calling get_avg_duration."""
    handler = _handler()
    calls = []

    def _counting_avg():
        calls.append(1)
        return 5.0

    monkeypatch.setattr(handler, "get_avg_duration", _counting_avg)
    # Unknown session, no disk file -> None, and no expensive avg scan.
    assert handler.get_status("missing-session") is None
    assert calls == []
def test_active_session_memoizes_avg(monkeypatch):
    """Three polls of an active session report the same average from a single computation."""
    handler = _handler()
    handler._active_tasks["s1"] = {
        "status": "running",
        "progress": {},
        "query": "q",
        "started_at": 0,
    }
    calls = []

    def _counting_avg():
        calls.append(1)
        return 12.0

    monkeypatch.setattr(handler, "get_avg_duration", _counting_avg)

    first = handler.get_status("s1")
    second = handler.get_status("s1")
    third = handler.get_status("s1")

    assert first["avg_duration"] == 12.0
    assert second["avg_duration"] == 12.0 and third["avg_duration"] == 12.0
    # Computed once across many polls, not once per poll.
    assert len(calls) == 1
@@ -58,6 +58,62 @@ def test_rename_into_reserved_username_is_blocked(tmp_path):
assert "bob" in mgr.users
def test_legacy_reserved_username_is_removed_on_load(tmp_path):
    """A stored "internal-tool" user is dropped on load and scrubbed from auth.json."""
    auth_path = tmp_path / "auth.json"
    legacy_payload = (
        '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}, '
        '"admin": {"password_hash": "unused", "is_admin": true}}}'
    )
    auth_path.write_text(legacy_payload, encoding="utf-8")

    manager = _fresh_auth_manager(tmp_path)

    assert "internal-tool" not in manager.users
    assert "admin" in manager.users
    assert "internal-tool" not in auth_path.read_text(encoding="utf-8")
def test_legacy_reserved_username_session_cannot_authenticate(tmp_path):
    """A persisted session for "internal-tool" neither validates nor resolves to a user."""
    stored_files = {
        "auth.json": '{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}}}',
        "sessions.json": '{"tok": {"username": "internal-tool", "expiry": 9999999999}}',
    }
    for file_name, payload in stored_files.items():
        (tmp_path / file_name).write_text(payload, encoding="utf-8")

    manager = _fresh_auth_manager(tmp_path)

    assert manager.validate_token("tok") is False
    assert manager.get_username_for_token("tok") is None
def test_legacy_reserved_single_user_migrates_to_admin(tmp_path):
    """A single-user auth file naming "internal-tool" loads as an "admin" admin user."""
    legacy_payload = '{"username": "internal-tool", "password_hash": "unused"}'
    (tmp_path / "auth.json").write_text(legacy_payload, encoding="utf-8")

    manager = _fresh_auth_manager(tmp_path)

    assert "internal-tool" not in manager.users
    assert "admin" in manager.users
    assert manager.is_admin("admin") is True
def test_token_cache_owner_normalization_requires_current_user():
    """normalize_known_username returns the lower-cased key for known users, else None."""
    clear_module("core.auth")
    from core.auth import normalize_known_username

    known_users = {"alice": {}, "admin": {}}

    assert normalize_known_username(known_users, " Alice ") == "alice"
    # Names that are not in the user table all resolve to None.
    assert normalize_known_username(known_users, "internal-tool") is None
    assert normalize_known_username(known_users, "api") is None
    assert normalize_known_username(known_users, "") is None
def test_normal_usernames_still_allowed(tmp_path):
mgr = _fresh_auth_manager(tmp_path)
assert mgr.create_user("alice", "pw-123456") is True
+54
View File
@@ -647,6 +647,60 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
assert "manage_tasks" in blocked
def test_presetup_does_not_grant_admin_tools_when_auth_enabled(monkeypatch):
    """Pre-setup window: auth is enabled but no admin user exists yet.

    This must NOT be treated as single-user/admin at the tool layer: the
    server-execution tools (bash/python) stay blocked as defense-in-depth so
    an unauthenticated caller that slips past the auth middleware (e.g. via a
    loopback bypass) can't reach an RCE before setup completes.
    """
    monkeypatch.delenv("AUTH_ENABLED", raising=False)  # default: enabled
    auth_mod = _install_core_auth_stub(monkeypatch)

    class FakeAuth:
        is_configured = False

        def is_admin(self, username):
            return False

    def _make_auth():
        return FakeAuth()

    monkeypatch.setattr(auth_mod, "AuthManager", _make_auth)

    from src.tool_security import (
        blocked_tools_for_owner,
        owner_is_admin_or_single_user,
    )

    assert owner_is_admin_or_single_user(None) is False
    blocked = blocked_tools_for_owner(None)
    assert "bash" in blocked
    assert "python" in blocked
def test_single_user_mode_keeps_full_tool_access_when_auth_disabled(monkeypatch):
    """Intentional single-user mode (AUTH_ENABLED=false) keeps full tool access
    even with no admin user; this is the default local/self-host UX and must
    not regress."""
    monkeypatch.setenv("AUTH_ENABLED", "false")
    auth_mod = _install_core_auth_stub(monkeypatch)

    class FakeAuth:
        is_configured = False

        def is_admin(self, username):
            return False

    def _make_auth():
        return FakeAuth()

    monkeypatch.setattr(auth_mod, "AuthManager", _make_auth)

    from src.tool_security import (
        blocked_tools_for_owner,
        owner_is_admin_or_single_user,
    )

    assert owner_is_admin_or_single_user(None) is True
    assert blocked_tools_for_owner(None) == set()
@pytest.mark.asyncio
async def test_webhook_tool_reuses_private_url_validation():
class FakeDb:
+23
View File
@@ -0,0 +1,23 @@
import pytest
from fastapi import HTTPException
from routes._validators import validate_remote_host, validate_ssh_port
def test_validate_ssh_port_rejects_shell_payload():
    """Shell payloads, option-shaped strings and out-of-range ports raise; "2222" passes through."""
    rejected_ports = ("22;id", "$(id)", "-p 22", "0", "65536")
    for candidate in rejected_ports:
        with pytest.raises(HTTPException):
            validate_ssh_port(candidate)
    assert validate_ssh_port("2222") == "2222"
def test_validate_remote_host_rejects_ssh_option_shape():
    """Hosts that look like ssh options raise; a normal user@host value passes through."""
    rejected_hosts = (
        "-oProxyCommand=sh",
        "alice@-oProxyCommand=sh",
        "--",
        "-p2222",
    )
    for candidate in rejected_hosts:
        with pytest.raises(HTTPException):
            validate_remote_host(candidate)
    accepted = "alice@gpu-box_1"
    assert validate_remote_host(accepted) == accepted
+1 -1
View File
@@ -90,8 +90,8 @@ def test_service_ddg_html_fallback_sends_safesearch(monkeypatch):
seen["params"] = kwargs["params"]
return _Response()
monkeypatch.setitem(sys.modules, "duckduckgo_search", None)
monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"})
monkeypatch.setitem(sys.modules, "ddgs", None)
monkeypatch.setattr(providers.httpx, "get", fake_get)
results = providers.duckduckgo_search("odysseus", count=1)
+47
View File
@@ -0,0 +1,47 @@
"""Startup warmup must resolve real endpoint URLs.
The warmup/keepalive loop called `model_discovery.get_endpoints()`, which does
not exist on ModelDiscovery, so it raised AttributeError every run and pinged
nothing. `ModelDiscovery.warmup_ping_urls()` resolves the /models probe URLs
from the real discovery API.
"""
from src.model_discovery import ModelDiscovery
def _md():
    """Return a ModelDiscovery instance created without running ``__init__``."""
    discovery = ModelDiscovery.__new__(ModelDiscovery)
    return discovery
def test_old_method_never_existed():
    """ModelDiscovery has no ``get_endpoints`` attribute."""
    # Documents why the old warmup was a silent no-op.
    has_old_method = hasattr(ModelDiscovery, "get_endpoints")
    assert not has_old_method
def test_resolves_models_urls_from_discovered_items():
    """Each discovered chat-completions URL maps to the matching ``/v1/models`` URL, in order."""
    discovery = _md()

    def _fake_discover():
        return {
            "items": [
                {"url": "http://host:8000/v1/chat/completions", "models": ["a"]},
                {"url": "http://host:1234/v1/chat/completions", "models": ["b"]},
            ]
        }

    discovery.discover_models = _fake_discover
    expected = ["http://host:8000/v1/models", "http://host:1234/v1/models"]
    assert discovery.warmup_ping_urls() == expected
def test_limit_caps_results():
    """``limit=3`` returns three URLs even when ten endpoints were discovered."""
    discovery = _md()

    def _fake_discover():
        items = []
        for offset in range(10):
            items.append({"url": f"http://h:{8000 + offset}/v1/chat/completions"})
        return {"items": items}

    discovery.discover_models = _fake_discover
    urls = discovery.warmup_ping_urls(limit=3)
    assert len(urls) == 3
def test_discovery_failure_degrades_to_empty():
    """When discover_models raises, warmup_ping_urls returns an empty list."""
    discovery = _md()

    def _exploding_discover():
        raise RuntimeError("port scan failed")

    discovery.discover_models = _exploding_discover
    urls = discovery.warmup_ping_urls()
    assert urls == []
+119
View File
@@ -0,0 +1,119 @@
"""Pin the web_search tool-icon rendering in the agent thread (PR #??).
Verifies:
- web_search renders an <svg> icon instead of raw markup
- Other tools get the default icon
- Hostile tool names are HTML-escaped in the label
Pure JS via node --input-type=module (same approach as
test_composer_arrow_up_recall_js.py). Skips when node is not installed.
"""
import json
import shutil
import subprocess
from pathlib import Path
import pytest
# Parent of the directory containing this test file; used as node's cwd.
_REPO = Path(__file__).resolve().parent.parent
# True when a `node` executable can be found on PATH.
_HAS_NODE = shutil.which("node") is not None
# Self-contained JS harness. It defines its own esc(), label/icon lookup tables
# and agent-thread HTML template (nothing is imported from the app), renders
# each case and prints the results as one JSON array. The bare CASES_JSON
# token is a placeholder replaced with the JSON-encoded cases before running.
_CHECK_JS = r"""
function esc(s) {
const map = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
return (s || '').replace(/[&<>"']/g, (m) => map[m]);
}
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
const _toolLabels = {
web_search: 'Searching',
bash: 'Running',
};
const _toolIcons = {
web_search: _searchIcon,
};
function renderIcon(toolName) {
return _toolIcons[toolName.toLowerCase()] || '\u25B6';
}
function renderLabel(toolName) {
return _toolLabels[toolName.toLowerCase()] || toolName;
}
function renderThreadHTML(toolName, cmd) {
const label = renderLabel(toolName);
const icon = renderIcon(toolName);
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
return `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${icon}</span><span class="agent-thread-tool">${esc(label)}</span><span class="agent-thread-wave">\u2581\u2582\u2583</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
}
const cases = CASES_JSON;
const results = cases.map(c => {
const html = renderThreadHTML(c.tool, c.cmd || '');
return { tool: c.tool, html };
});
console.log(JSON.stringify(results));
"""
def _run(cases: list) -> list:
    """Render *cases* with the JS harness under node and return the decoded results.

    The JSON-encoded cases replace the ``CASES_JSON`` placeholder, and the
    script is fed to node on stdin. Asserts that node exits with status 0.
    """
    script = _CHECK_JS.replace("CASES_JSON", json.dumps(cases))
    node_argv = ["node", "--input-type=module"]
    run_options = {
        "input": script,
        "capture_output": True,
        "text": True,
        "encoding": "utf-8",
        "cwd": str(_REPO),
        "timeout": 30,
    }
    completed = subprocess.run(node_argv, **run_options)
    assert completed.returncode == 0, completed.stderr
    return json.loads(completed.stdout.strip())
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_web_search_icon_contains_svg():
    """web_search renders an inline SVG icon together with the "Searching" label."""
    rendered = _run([{"tool": "web_search"}])[0]
    html = rendered["html"]
    assert "<svg" in html, "Expected <svg> in agent-thread-icon for web_search"
    assert "Searching" in html, "Expected 'Searching' label for web_search"
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_default_tool_icon_is_triangle():
    """A tool without a custom icon renders the default U+25B6 triangle, not an SVG."""
    out = _run([{"tool": "bash"}])[0]
    # Look for the actual glyph. An empty-string needle is contained in every
    # string, which would make this check pass no matter what was rendered.
    assert "\u25b6" in out["html"], "Expected ▶ icon for tools without custom icon"
    assert "<svg" not in out["html"], "Expected no <svg> for bash"
    assert "Running" in out["html"], "Expected 'Running' label for bash"
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_unknown_tool_falls_back_to_name():
    """An unmapped tool renders the default U+25B6 triangle and its own name as label."""
    out = _run([{"tool": "my_custom_tool"}])[0]
    # Look for the actual glyph. An empty-string needle is contained in every
    # string, which would make this check pass no matter what was rendered.
    assert "\u25b6" in out["html"], "Expected ▶ for unknown tool"
    assert "my_custom_tool" in out["html"], "Expected tool name as label"
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_hostile_tool_name_is_escaped():
    """Markup in a tool name is HTML-escaped in the rendered label."""
    hostile_name = '<img src=x onerror="alert(1)">'
    rendered = _run([{"tool": hostile_name}])[0]
    html = rendered["html"]
    assert "&lt;img" in html, "Expected < to be HTML-escaped"
    assert "&gt;" in html, "Expected > to be HTML-escaped"
    assert "<img" not in html, "Raw <img> must not appear"
    assert "onerror" not in html or "&quot;" in html, "onerror must not be executable"
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_unknown_tool_case_insensitive_matches_icons():
    """Upper- and mixed-case spellings of web_search still get the SVG icon."""
    variants = [{"tool": "WEB_SEARCH"}, {"tool": "Web_Search"}]
    for rendered in _run(variants):
        assert "<svg" in rendered["html"], f"Expected SVG for case-variant '{rendered['tool']}'"
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
def test_command_is_escaped():
out = _run([{"tool": "bash", "cmd": "echo $HOME && ls"}])[0]
assert "echo $HOME" in out["html"], "Expected command text in output"