mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
Merge branch 'dev' into fix/no-scroll-snapping
This commit is contained in:
@@ -329,7 +329,7 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
||||
| Package | Feature unlocked |
|
||||
|---------|-----------------|
|
||||
| `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. |
|
||||
| `duckduckgo-search` | DuckDuckGo as a search provider option. |
|
||||
| `ddgs` | DuckDuckGo as a search provider option. |
|
||||
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
||||
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ from core.constants import (
|
||||
)
|
||||
from core.database import SessionLocal, ApiToken
|
||||
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
||||
from core.auth import AuthManager
|
||||
from core.auth import AuthManager, normalize_known_username
|
||||
from core.exceptions import (
|
||||
SessionNotFoundError, InvalidFileUploadError,
|
||||
LLMServiceError, WebSearchError,
|
||||
@@ -228,8 +228,16 @@ if AUTH_ENABLED:
|
||||
try:
|
||||
rows = db.query(ApiToken).filter(ApiToken.is_active == True).all()
|
||||
for r in rows:
|
||||
owner_key = normalize_known_username(auth_manager.users, getattr(r, "owner", None))
|
||||
if not owner_key:
|
||||
logger.warning(
|
||||
"Ignoring active API token '%s' for unknown auth user '%s'",
|
||||
getattr(r, "id", ""),
|
||||
getattr(r, "owner", None),
|
||||
)
|
||||
continue
|
||||
scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()]
|
||||
new_map[r.token_prefix].append((r.id, r.token_hash, getattr(r, "owner", None), scopes))
|
||||
new_map[r.token_prefix].append((r.id, r.token_hash, owner_key, scopes))
|
||||
finally:
|
||||
db.close()
|
||||
_token_cache.clear()
|
||||
@@ -495,6 +503,7 @@ api_key_manager = components["api_key_manager"]
|
||||
preset_manager = components["preset_manager"]
|
||||
chat_processor = components["chat_processor"]
|
||||
research_handler = components["research_handler"]
|
||||
app.state.research_handler = research_handler
|
||||
chat_handler = components["chat_handler"]
|
||||
model_discovery = components["model_discovery"]
|
||||
skills_manager = components["skills_manager"]
|
||||
@@ -938,10 +947,15 @@ async def _startup_event():
|
||||
async def _warmup_endpoints():
|
||||
try:
|
||||
import httpx
|
||||
endpoints = model_discovery.get_endpoints() if model_discovery else []
|
||||
for ep in endpoints[:5]:
|
||||
url = ep.get("url", "").replace("/chat/completions", "/models")
|
||||
if url:
|
||||
# model_discovery has no get_endpoints(); that call raised
|
||||
# AttributeError every run and silently disabled warmup/keepalive.
|
||||
# Resolve the /models probe URLs via the real discovery API, off the
|
||||
# event loop since discovery does a blocking port scan.
|
||||
urls = (
|
||||
await asyncio.to_thread(model_discovery.warmup_ping_urls)
|
||||
if model_discovery else []
|
||||
)
|
||||
for url in urls:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
await client.get(url)
|
||||
|
||||
+56
-13
@@ -67,6 +67,14 @@ TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
||||
RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"})
|
||||
|
||||
|
||||
def normalize_known_username(users: Dict[str, Any], username: str | None) -> Optional[str]:
|
||||
"""Return a normalized username only when it exists in the auth user map."""
|
||||
key = str(username or "").strip().lower()
|
||||
if not key or key not in users:
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
@@ -96,6 +104,7 @@ class AuthManager:
|
||||
self._load()
|
||||
self._load_sessions()
|
||||
self._migrate_single_user()
|
||||
self._drop_reserved_loaded_users()
|
||||
self._migrate_legacy_admin_role()
|
||||
|
||||
def _load(self):
|
||||
@@ -148,7 +157,13 @@ class AuthManager:
|
||||
def _migrate_single_user(self):
|
||||
"""Migrate old single-user format to multi-user format."""
|
||||
if "password_hash" in self._config and "users" not in self._config:
|
||||
old_user = self._config.get("username", "admin")
|
||||
old_user = str(self._config.get("username", "admin") or "admin").strip().lower()
|
||||
if old_user in RESERVED_USERNAMES:
|
||||
logger.warning(
|
||||
"Migrating legacy single-user reserved username '%s' to 'admin'",
|
||||
old_user,
|
||||
)
|
||||
old_user = "admin"
|
||||
old_hash = self._config["password_hash"]
|
||||
self._config = {
|
||||
"users": {
|
||||
@@ -162,6 +177,30 @@ class AuthManager:
|
||||
self._save()
|
||||
logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})")
|
||||
|
||||
def _drop_reserved_loaded_users(self):
|
||||
"""Fail closed for legacy/manual auth rows that collide with sentinels."""
|
||||
users = self._config.get("users")
|
||||
if not isinstance(users, dict):
|
||||
return
|
||||
normalized = {}
|
||||
removed = []
|
||||
for username, data in users.items():
|
||||
key = str(username or "").strip().lower()
|
||||
if not key:
|
||||
continue
|
||||
if key in RESERVED_USERNAMES:
|
||||
removed.append(key)
|
||||
continue
|
||||
normalized[key] = data
|
||||
if removed or normalized != users:
|
||||
self._config["users"] = normalized
|
||||
self._save()
|
||||
if removed:
|
||||
logger.warning(
|
||||
"Removed reserved username(s) from auth config: %s",
|
||||
", ".join(sorted(set(removed))),
|
||||
)
|
||||
|
||||
def _migrate_legacy_admin_role(self):
|
||||
"""Normalize setup.py's old role='admin' marker to is_admin=True."""
|
||||
changed = False
|
||||
@@ -244,6 +283,22 @@ class AuthManager:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
# Revoke API bearer tokens before removing the auth row. The bearer
|
||||
# path authenticates from ApiToken rows and does not require the
|
||||
# owner to still exist, so a successful delete must not leave active
|
||||
# rows behind. If the token store is unavailable, fail closed and
|
||||
# keep the user/session state intact so the admin can retry.
|
||||
try:
|
||||
from core.database import get_db_session, ApiToken
|
||||
with get_db_session() as db:
|
||||
removed_tokens = db.query(ApiToken).filter(ApiToken.owner == username).delete()
|
||||
if removed_tokens:
|
||||
logger.info(
|
||||
f"Revoked {removed_tokens} API token(s) owned by deleted user '{username}'"
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
|
||||
return False
|
||||
del self._config["users"][username]
|
||||
self._save()
|
||||
# Purge all sessions belonging to this user. validate_token doesn't
|
||||
@@ -258,18 +313,6 @@ class AuthManager:
|
||||
revoked += 1
|
||||
if revoked:
|
||||
self._save_sessions()
|
||||
# Also revoke API bearer tokens owned by this user. The bearer auth
|
||||
# path authenticates straight against ApiToken rows and never
|
||||
# re-checks that the owner still exists, so leaving the rows behind
|
||||
# would let a deleted user keep full API access indefinitely.
|
||||
try:
|
||||
from core.database import get_db_session, ApiToken
|
||||
with get_db_session() as db:
|
||||
removed = db.query(ApiToken).filter(ApiToken.owner == username).delete()
|
||||
if removed:
|
||||
logger.info(f"Revoked {removed} API token(s) owned by deleted user '{username}'")
|
||||
except Exception:
|
||||
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
|
||||
logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)")
|
||||
return True
|
||||
|
||||
|
||||
+150
-25
@@ -688,6 +688,7 @@ def _migrate_add_last_message_at_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||
@@ -713,10 +714,14 @@ def _migrate_add_last_message_at_column():
|
||||
"ON sessions(archived, last_message_at)"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions")
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_document_archived_column():
|
||||
"""Add `archived` to documents (soft-archive flag). Guarded + idempotent."""
|
||||
@@ -724,6 +729,7 @@ def _migrate_add_document_archived_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(documents)")
|
||||
@@ -732,9 +738,13 @@ def _migrate_add_document_archived_column():
|
||||
conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'archived' to documents")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_owner_column():
|
||||
@@ -743,6 +753,7 @@ def _migrate_add_owner_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||
@@ -752,9 +763,13 @@ def _migrate_add_owner_column():
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Migration check failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_model_endpoints():
|
||||
"""Recreate model_endpoints table if schema changed (url->base_url)."""
|
||||
@@ -762,6 +777,7 @@ def _migrate_model_endpoints():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -770,9 +786,13 @@ def _migrate_model_endpoints():
|
||||
conn.execute("DROP TABLE IF EXISTS model_endpoints")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_hidden_models_column():
|
||||
"""Add hidden_models column to model_endpoints if it doesn't exist."""
|
||||
@@ -780,6 +800,7 @@ def _migrate_add_hidden_models_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -788,9 +809,13 @@ def _migrate_add_hidden_models_column():
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_model_endpoint_owner_column():
|
||||
"""Add owner column to model_endpoints if it doesn't exist.
|
||||
@@ -805,6 +830,7 @@ def _migrate_add_model_endpoint_owner_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -814,9 +840,13 @@ def _migrate_add_model_endpoint_owner_column():
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_provider_auth_id_column():
|
||||
@@ -825,6 +855,7 @@ def _migrate_add_provider_auth_id_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -834,9 +865,13 @@ def _migrate_add_provider_auth_id_column():
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_model_type_column():
|
||||
@@ -845,6 +880,7 @@ def _migrate_add_model_type_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -853,9 +889,13 @@ def _migrate_add_model_type_column():
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_model_endpoint_refresh_columns():
|
||||
"""Add endpoint classification / refresh policy columns if missing."""
|
||||
@@ -863,6 +903,7 @@ def _migrate_add_model_endpoint_refresh_columns():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -876,9 +917,13 @@ def _migrate_add_model_endpoint_refresh_columns():
|
||||
if columns and "model_refresh_timeout" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_task_run_model_column():
|
||||
"""Add model column to task_runs if it doesn't exist (records which model ran)."""
|
||||
@@ -886,6 +931,7 @@ def _migrate_add_task_run_model_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(task_runs)")
|
||||
@@ -894,9 +940,13 @@ def _migrate_add_task_run_model_column():
|
||||
conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_supports_tools_column():
|
||||
"""Add supports_tools column to model_endpoints if it doesn't exist."""
|
||||
@@ -904,6 +954,7 @@ def _migrate_add_supports_tools_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -912,9 +963,13 @@ def _migrate_add_supports_tools_column():
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_cached_models_column():
|
||||
@@ -923,6 +978,7 @@ def _migrate_add_cached_models_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -930,9 +986,13 @@ def _migrate_add_cached_models_column():
|
||||
if columns and "cached_models" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_pinned_models_column():
|
||||
"""Add pinned_models column to model_endpoints if it doesn't exist."""
|
||||
@@ -940,6 +1000,7 @@ def _migrate_add_pinned_models_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
@@ -948,9 +1009,13 @@ def _migrate_add_pinned_models_column():
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_notes_sort_order():
|
||||
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
||||
@@ -958,6 +1023,7 @@ def _migrate_add_notes_sort_order():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(notes)")
|
||||
@@ -975,9 +1041,13 @@ def _migrate_add_notes_sort_order():
|
||||
if columns and "agent_session_id" not in columns:
|
||||
conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"notes migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_mode_column():
|
||||
"""Add mode column to sessions table if it doesn't exist."""
|
||||
@@ -985,6 +1055,7 @@ def _migrate_add_mode_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||
@@ -993,9 +1064,13 @@ def _migrate_add_mode_column():
|
||||
conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_folder_column():
|
||||
"""Add folder column to sessions table if it doesn't exist."""
|
||||
@@ -1003,6 +1078,7 @@ def _migrate_add_folder_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||
@@ -1011,9 +1087,13 @@ def _migrate_add_folder_column():
|
||||
conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_token_columns():
|
||||
"""Add cumulative token tracking columns to sessions table."""
|
||||
@@ -1021,6 +1101,7 @@ def _migrate_add_token_columns():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||
@@ -1030,9 +1111,13 @@ def _migrate_add_token_columns():
|
||||
conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
||||
"""Generic helper: add owner TEXT column + index to a table if missing."""
|
||||
@@ -1040,6 +1125,7 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute(f"PRAGMA table_info({table_name})")
|
||||
@@ -1049,9 +1135,13 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
||||
conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_add_multiuser_owner_columns():
|
||||
"""Add owner column to memories, gallery_images, user_tools, comparisons."""
|
||||
@@ -1076,6 +1166,7 @@ def _migrate_add_api_token_scopes_column():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()]
|
||||
@@ -1084,9 +1175,13 @@ def _migrate_add_api_token_scopes_column():
|
||||
conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _migrate_assign_legacy_owner():
|
||||
"""Assign all null-owner data to the first (admin) user.
|
||||
@@ -1128,6 +1223,7 @@ def _migrate_assign_legacy_owner():
|
||||
return
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
# Every table with an `owner` column. New tables added later will be
|
||||
@@ -1152,9 +1248,13 @@ def _migrate_assign_legacy_owner():
|
||||
except Exception as e:
|
||||
logger.warning(f"Legacy owner assignment for {table} failed: {e}")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Legacy owner migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Also migrate memory.json
|
||||
mem_path = MEMORY_FILE
|
||||
@@ -1773,6 +1873,7 @@ def _migrate_add_email_smtp_security():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(email_accounts)")
|
||||
@@ -1788,9 +1889,13 @@ def _migrate_add_email_smtp_security():
|
||||
)
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_encrypt_endpoint_keys():
|
||||
@@ -1891,6 +1996,7 @@ def _migrate_add_calendar_is_utc():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||
@@ -1899,9 +2005,13 @@ def _migrate_add_calendar_is_utc():
|
||||
conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"is_utc migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_calendar_origin():
|
||||
@@ -1912,6 +2022,7 @@ def _migrate_add_calendar_origin():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||
@@ -1921,9 +2032,13 @@ def _migrate_add_calendar_origin():
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_calendar_account_id():
|
||||
@@ -1933,6 +2048,7 @@ def _migrate_add_calendar_account_id():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(calendars)")
|
||||
@@ -1942,9 +2058,13 @@ def _migrate_add_calendar_account_id():
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_calendar_metadata():
|
||||
@@ -1953,6 +2073,7 @@ def _migrate_add_calendar_metadata():
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||
@@ -1964,9 +2085,13 @@ def _migrate_add_calendar_metadata():
|
||||
if columns and "last_pinged" not in columns:
|
||||
conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_db():
|
||||
"""
|
||||
|
||||
@@ -366,6 +366,10 @@ def _ssh_exec_argv(
|
||||
strict_host_key_checking: bool | None = None,
|
||||
) -> list[str]:
|
||||
"""Build a consistent ssh argv for remote command execution."""
|
||||
remote_value = str(remote or "").strip()
|
||||
remote_host = remote_value.rsplit("@", 1)[-1]
|
||||
if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"):
|
||||
raise ValueError("Invalid SSH remote host")
|
||||
argv = ["ssh"]
|
||||
if connect_timeout is not None:
|
||||
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||
|
||||
@@ -15,7 +15,7 @@ faster-whisper
|
||||
# DuckDuckGo as a search provider option.
|
||||
# Install if you want DDG in the search-provider dropdown.
|
||||
# Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE.
|
||||
duckduckgo-search
|
||||
ddgs
|
||||
|
||||
# PDF form-filling feature (fillable AcroForm detection, field extraction,
|
||||
# value/annotation/signature stamping, page rendering for the form overlay).
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
_REMOTE_HOST_RE = re.compile(
|
||||
r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$"
|
||||
)
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
|
||||
|
||||
def validate_remote_host(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _REMOTE_HOST_RE.match(v):
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Invalid remote_host — must be host or user@host, no SSH option syntax",
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
def validate_ssh_port(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = int(v)
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
return str(port)
|
||||
+54
-10
@@ -305,6 +305,19 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot rename user")
|
||||
|
||||
def _rollback_auth_rename() -> bool:
|
||||
# On self-rename the admin session has already moved to the new
|
||||
# username, so the rollback must authenticate as the new user.
|
||||
rollback_user = new_username if user == old_username else user
|
||||
try:
|
||||
return bool(auth_manager.rename_user(new_username, old_username, rollback_user))
|
||||
except Exception as rollback_err:
|
||||
logger.error(
|
||||
"Failed to roll back auth rename %s -> %s after owner migration failure: %s",
|
||||
new_username, old_username, rollback_err,
|
||||
)
|
||||
return False
|
||||
|
||||
# Usernames are ownership keys for user data. Rename the common
|
||||
# owner-scoped DB rows so the account keeps access to its sessions,
|
||||
# docs, email accounts, tasks, etc.
|
||||
@@ -330,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e)
|
||||
if not _rollback_auth_rename():
|
||||
logger.error(
|
||||
"Auth rename %s -> %s could not be rolled back after owner migration failure",
|
||||
old_username, new_username,
|
||||
)
|
||||
raise HTTPException(500, "Failed to rename user data")
|
||||
|
||||
# Per-user prefs are JSON-backed, not SQL-backed.
|
||||
@@ -349,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# In-flight deep-research tasks live in the process-local
|
||||
# ResearchHandler registry. They are not covered by the persisted JSON
|
||||
# migration above, but the research routes filter and cancel by this
|
||||
# owner field while the job is running. Do this before sweeping
|
||||
# completed JSON files so a job that finishes during the rename saves
|
||||
# with the new owner or is caught by the disk sweep below.
|
||||
try:
|
||||
rh = getattr(request.app.state, "research_handler", None)
|
||||
rename_owner = getattr(rh, "rename_owner", None)
|
||||
if callable(rename_owner):
|
||||
rename_owner(old_username, new_username)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# deep_research: each completed report is a standalone JSON file with
|
||||
# an `owner` field. research_routes filters by d.get("owner") == user,
|
||||
# so a stale owner makes every report invisible to the renamed user.
|
||||
@@ -391,7 +423,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
skills_root = Path(SKILLS_DIR)
|
||||
if skills_root.is_dir():
|
||||
_owner_re = re.compile(
|
||||
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$'
|
||||
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
for p in skills_root.rglob("SKILL.md"):
|
||||
try:
|
||||
@@ -406,12 +439,12 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
try:
|
||||
usage = json.loads(usage_path.read_text(encoding="utf-8"))
|
||||
if isinstance(usage, dict):
|
||||
prefix = old_username + "::"
|
||||
new_usage = {}
|
||||
changed = False
|
||||
for k, v in usage.items():
|
||||
if k.startswith(prefix):
|
||||
new_usage[new_username + "::" + k[len(prefix):]] = v
|
||||
owner_part, sep, skill_part = k.partition("::")
|
||||
if sep and owner_part.lower() == old_username:
|
||||
new_usage[new_username + "::" + skill_part] = v
|
||||
changed = True
|
||||
else:
|
||||
new_usage[k] = v
|
||||
@@ -473,7 +506,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
|
||||
def _invalidate_api_token_cache():
|
||||
try:
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if invalidator:
|
||||
invalidator()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
ok = auth_manager.delete_user(body.username, user)
|
||||
except Exception:
|
||||
# delete_user can touch ApiToken rows before a later auth-store write
|
||||
# fails. Dirty the bearer cache anyway so a partial token purge does
|
||||
# not leave already-cached tokens authenticating until restart.
|
||||
_invalidate_api_token_cache()
|
||||
raise
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot delete user")
|
||||
# delete_user removes the user's ApiToken rows, but the bearer-auth
|
||||
@@ -481,12 +530,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
# rebuilds when flagged dirty. Without this, a deleted user's already
|
||||
# cached token keeps authenticating until some other token op or a
|
||||
# restart clears the cache. Mirror what the token routes do.
|
||||
try:
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if invalidator:
|
||||
invalidator()
|
||||
except Exception:
|
||||
pass
|
||||
_invalidate_api_token_cache()
|
||||
return {"ok": True}
|
||||
|
||||
# ---- Feature visibility (admin-managed) ----
|
||||
|
||||
@@ -729,8 +729,11 @@ def setup_contacts_routes():
|
||||
@router.post("/import")
|
||||
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
|
||||
text = data.get("vcf") or data.get("text") or ""
|
||||
csv_text = data.get("csv") or ""
|
||||
# Coerce defensively: a non-string vcf/text/csv (e.g. a number or list
|
||||
# in the JSON body) would otherwise reach .strip() and 500 with an
|
||||
# AttributeError instead of degrading to a clean "no data" response.
|
||||
text = str(data.get("vcf") or data.get("text") or "")
|
||||
csv_text = str(data.get("csv") or "")
|
||||
if text.strip():
|
||||
if "BEGIN:VCARD" not in text.upper():
|
||||
return {"success": False, "error": "No vCard data found"}
|
||||
|
||||
@@ -11,6 +11,7 @@ import shlex
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
|
||||
# Include pattern is a glob: allow typical safe glyphs only.
|
||||
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
||||
# Remote host: either `user@host` or plain `host` (alias is allowed), where host
|
||||
# is a safe DNS-like token or a short SSH config alias.
|
||||
_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$")
|
||||
# HF tokens and API tokens are url-safe base64-like.
|
||||
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
||||
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
||||
# Anything beyond plain alphanumerics + dash + underscore could break out
|
||||
# of the shell/PowerShell contexts the value lands in.
|
||||
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
||||
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
||||
# only (no quotes or shell metacharacters). Spaces are allowed because command
|
||||
@@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_remote_host(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _REMOTE_HOST_RE.match(v):
|
||||
raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_token(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
@@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_ssh_port(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = int(v)
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
return str(port)
|
||||
|
||||
|
||||
def _validate_gpus(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
|
||||
+38
-26
@@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
detached_popen_kwargs,
|
||||
@@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from routes.cookbook_helpers import (
|
||||
_SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE,
|
||||
_validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token,
|
||||
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
|
||||
_SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token,
|
||||
_validate_local_dir, _validate_gpus, _shell_path,
|
||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||
@@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
else:
|
||||
_validate_repo_id(req.repo_id)
|
||||
_validate_include(req.include)
|
||||
_validate_remote_host(req.remote_host)
|
||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
||||
validate_remote_host(req.remote_host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
req.local_dir = _validate_local_dir(req.local_dir)
|
||||
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
|
||||
_validate_token(req.hf_token)
|
||||
@@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
|
||||
# `host`/`ssh_port` are interpolated into an ssh command below, so an
|
||||
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
|
||||
host = _validate_remote_host(host)
|
||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(host)
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_dirs = []
|
||||
@@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# listening" check without requiring ss/netstat/nmap.
|
||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||
if ssh_port and str(ssh_port) != "22":
|
||||
if not _SSH_PORT_RE.match(str(ssh_port)):
|
||||
try:
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
except HTTPException:
|
||||
return None
|
||||
ssh_base.extend(["-p", str(ssh_port)])
|
||||
host_arg = remote
|
||||
if not _REMOTE_HOST_RE.match(host_arg):
|
||||
try:
|
||||
host_arg = validate_remote_host(remote)
|
||||
except HTTPException:
|
||||
return None
|
||||
if not host_arg:
|
||||
return None
|
||||
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
|
||||
script = (
|
||||
@@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"""
|
||||
require_admin(request)
|
||||
# Defence-in-depth: reject values that could break out of shell contexts.
|
||||
_validate_remote_host(req.remote_host)
|
||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
||||
validate_remote_host(req.remote_host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
req.gpus = _validate_gpus(req.gpus)
|
||||
req.hf_token = req.hf_token or _load_stored_hf_token()
|
||||
_validate_token(req.hf_token)
|
||||
@@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
async def server_setup(request: Request, req: SetupRequest):
|
||||
"""Install required dependencies on a remote server via SSH."""
|
||||
require_admin(request)
|
||||
host = _validate_remote_host(req.host)
|
||||
host = validate_remote_host(req.host)
|
||||
if not host:
|
||||
raise HTTPException(400, "host is required")
|
||||
port = req.ssh_port
|
||||
if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = validate_ssh_port(port)
|
||||
pf = f"-p {port} " if port and port != "22" else ""
|
||||
|
||||
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
|
||||
@@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
`busy` is True when free_mb/total_mb < 0.5.
|
||||
"""
|
||||
require_admin(request)
|
||||
host = _validate_remote_host(host)
|
||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(host)
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
|
||||
nvidia_error = None
|
||||
try:
|
||||
@@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
sig = (req.signal or "TERM").upper()
|
||||
if sig not in ("TERM", "KILL", "INT"):
|
||||
raise HTTPException(400, "signal must be TERM, KILL, or INT")
|
||||
host = _validate_remote_host(req.host)
|
||||
if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(req.host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
kill_cmd = f"kill -{sig} {req.pid}"
|
||||
try:
|
||||
if host:
|
||||
@@ -2382,13 +2383,18 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
host = (srv.get("host") or "").strip()
|
||||
if not host:
|
||||
continue # local-only entry; the /proc scan handles it
|
||||
if not _REMOTE_HOST_RE.match(host):
|
||||
try:
|
||||
host = validate_remote_host(host)
|
||||
except HTTPException:
|
||||
continue
|
||||
sport = str(srv.get("port") or "").strip()
|
||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||
if sport and sport != "22":
|
||||
if not _SSH_PORT_RE.match(sport):
|
||||
try:
|
||||
sport = validate_ssh_port(sport)
|
||||
except HTTPException:
|
||||
continue
|
||||
if sport != "22":
|
||||
ssh_base.extend(["-p", sport])
|
||||
|
||||
try:
|
||||
@@ -2743,10 +2749,16 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
if not _SESSION_ID_RE.match(session_id):
|
||||
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
|
||||
continue
|
||||
if remote and not _REMOTE_HOST_RE.match(remote):
|
||||
if remote:
|
||||
try:
|
||||
remote = validate_remote_host(remote)
|
||||
except HTTPException:
|
||||
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
|
||||
continue
|
||||
if _tport and not _SSH_PORT_RE.match(str(_tport)):
|
||||
if _tport:
|
||||
try:
|
||||
_tport = validate_ssh_port(str(_tport))
|
||||
except HTTPException:
|
||||
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
|
||||
continue
|
||||
if task_platform == "windows" and remote:
|
||||
|
||||
+23
-3
@@ -1,7 +1,9 @@
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
|
||||
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
||||
@@ -11,6 +13,14 @@ from fastapi import APIRouter
|
||||
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
||||
|
||||
|
||||
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
|
||||
host_value = validate_remote_host(host) or ""
|
||||
port_value = validate_ssh_port(ssh_port) or ""
|
||||
if port_value and not host_value:
|
||||
raise HTTPException(400, "ssh_port requires host")
|
||||
return host_value, port_value
|
||||
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
|
||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||
fresh=true bypasses the per-host cache (the Rescan button)."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
|
||||
@router.get("/models")
|
||||
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.fit import rank_models
|
||||
from services.hwfit.models import get_models, model_catalog_path
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
@@ -165,8 +177,14 @@ def setup_hwfit_routes():
|
||||
system["gpu_name"] = g["name"]
|
||||
system["active_group"] = {**g, "use_count": n}
|
||||
|
||||
if gpu_count != "":
|
||||
n = int(gpu_count)
|
||||
# Parse the optional count defensively (matches the gpu_group guard
|
||||
# above): a non-numeric query param previously raised ValueError ->
|
||||
# HTTP 500. A malformed value is ignored, same as omitting it.
|
||||
try:
|
||||
n = int(gpu_count) if gpu_count != "" else None
|
||||
except ValueError:
|
||||
n = None
|
||||
if n is not None:
|
||||
if n == 0:
|
||||
# RAM-only mode: rank against system memory, offload allowed.
|
||||
system["has_gpu"] = False
|
||||
@@ -229,6 +247,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.models import get_models
|
||||
from services.hwfit.profiles import compute_serve_profiles
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
if system.get("error"):
|
||||
return {"system": system, "profiles": [], "error": system["error"]}
|
||||
@@ -279,6 +298,7 @@ def setup_hwfit_routes():
|
||||
"""Rank image generation models against detected hardware."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.image_models import rank_image_models
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter
|
||||
from src.session_actions import is_session_recently_active
|
||||
|
||||
|
||||
@@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
last_msg_map = {}
|
||||
mode_map = {}
|
||||
msg_count_map = {}
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False)
|
||||
q = owner_filter(q, DbSession, user)
|
||||
rows = q.all()
|
||||
for row in rows:
|
||||
folder_map[row.id] = row.folder
|
||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||
@@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
# Sessions with active documents that have content
|
||||
from sqlalchemy import func
|
||||
doc_session_ids = set(
|
||||
r[0] for r in db.query(Document.session_id)
|
||||
r[0] for r in owner_filter(
|
||||
db.query(Document.session_id)
|
||||
.filter(Document.is_active == True,
|
||||
Document.current_content != None,
|
||||
func.trim(Document.current_content) != "",
|
||||
Document.owner == user)
|
||||
func.trim(Document.current_content) != ""),
|
||||
Document, user)
|
||||
.distinct().all()
|
||||
)
|
||||
img_session_ids = set(
|
||||
r[0] for r in db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None,
|
||||
GalleryImage.owner == user)
|
||||
r[0] for r in owner_filter(
|
||||
db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None),
|
||||
GalleryImage, user)
|
||||
.distinct().all()
|
||||
)
|
||||
finally:
|
||||
|
||||
@@ -417,7 +417,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti
|
||||
return []
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
from ddgs import DDGS
|
||||
except ImportError:
|
||||
logger.warning("duckduckgo-search package not installed; using HTML fallback")
|
||||
return _html_fallback()
|
||||
|
||||
+21
-8
@@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple:
|
||||
return etype, None
|
||||
|
||||
|
||||
def _memory_context_lines(mems, limit: int = 40) -> list:
|
||||
"""Render Memory rows into short personal-context bullets for event classify.
|
||||
|
||||
Reads the Memory ORM `text` column. The previous inline code read a
|
||||
non-existent `content` attribute, so it raised AttributeError on the first
|
||||
row, the surrounding except swallowed it, and the classifier ran with no
|
||||
personal context at all. getattr keeps it robust to future schema drift.
|
||||
"""
|
||||
lines: list = []
|
||||
for m in mems:
|
||||
c = (getattr(m, "text", "") or "").strip()
|
||||
if c:
|
||||
lines.append(f"- {c[:200]}")
|
||||
if len(lines) >= limit:
|
||||
break
|
||||
return lines
|
||||
|
||||
|
||||
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
"""Hybrid classification of upcoming calendar events: fast heuristic for
|
||||
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
|
||||
@@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
try:
|
||||
from core.database import Memory as _Mem
|
||||
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
|
||||
if _mems:
|
||||
_lines = []
|
||||
for m in _mems:
|
||||
c = (m.content or "").strip()
|
||||
if c:
|
||||
_lines.append(f"- {c[:200]}")
|
||||
_lines = _memory_context_lines(_mems)
|
||||
if _lines:
|
||||
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n"
|
||||
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n"
|
||||
except Exception as _me:
|
||||
logger.debug(f"Could not load memory for classify: {_me}")
|
||||
logger.warning(f"Could not load memory for classify: {_me}")
|
||||
|
||||
classified_h = 0
|
||||
classified_llm = 0
|
||||
|
||||
@@ -223,6 +223,25 @@ class ModelDiscovery:
|
||||
)
|
||||
return {"hosts": hosts, "items": items}
|
||||
|
||||
def warmup_ping_urls(self, limit: int = 5) -> List[str]:
|
||||
"""The ``/models`` URLs of up to ``limit`` discovered endpoints.
|
||||
|
||||
Used by the startup warmup / keepalive loop to prime connections. Each
|
||||
discovered item already carries a ``/v1/chat/completions`` url; swap the
|
||||
suffix for the cheap ``/models`` probe. Failures degrade to an empty list
|
||||
so warmup never crashes the caller.
|
||||
"""
|
||||
try:
|
||||
items = (self.discover_models() or {}).get("items", [])
|
||||
except Exception:
|
||||
return []
|
||||
urls: List[str] = []
|
||||
for ep in items[:limit]:
|
||||
url = (ep.get("url") or "").replace("/chat/completions", "/models")
|
||||
if url:
|
||||
urls.append(url)
|
||||
return urls
|
||||
|
||||
def get_providers(self) -> Dict[str, Any]:
|
||||
"""Get all available providers"""
|
||||
discovery = self.discover_models()
|
||||
|
||||
+24
-1
@@ -221,6 +221,22 @@ class ResearchHandler:
|
||||
# Task registry — background research with persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||
"""Move in-flight research tasks from one owner key to another."""
|
||||
old_key = str(old_owner or "").strip().lower()
|
||||
new_key = str(new_owner or "").strip().lower()
|
||||
if not old_key or not new_key:
|
||||
return 0
|
||||
|
||||
changed = 0
|
||||
for entry in list(self._active_tasks.values()):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if str(entry.get("owner", "")).strip().lower() == old_key:
|
||||
entry["owner"] = new_key
|
||||
changed += 1
|
||||
return changed
|
||||
|
||||
def start_research(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -390,7 +406,6 @@ class ResearchHandler:
|
||||
|
||||
def get_status(self, session_id: str) -> Optional[dict]:
|
||||
"""Get current research status for a session."""
|
||||
avg = self.get_avg_duration()
|
||||
if session_id in self._active_tasks:
|
||||
entry = self._active_tasks[session_id]
|
||||
result = {
|
||||
@@ -399,6 +414,14 @@ class ResearchHandler:
|
||||
"query": entry["query"],
|
||||
"started_at": entry["started_at"],
|
||||
}
|
||||
# avg_duration is a historical figure over completed reports on
|
||||
# disk; get_avg_duration() globs and JSON-parses the whole research
|
||||
# dir, so compute it at most once per active stream (memoized on the
|
||||
# entry) instead of on every ~1s SSE poll. The disk branch below
|
||||
# never used it, so it no longer pays that cost at all.
|
||||
if "_avg_duration" not in entry:
|
||||
entry["_avg_duration"] = self.get_avg_duration()
|
||||
avg = entry["_avg_duration"]
|
||||
if avg is not None:
|
||||
result["avg_duration"] = round(avg, 1)
|
||||
return result
|
||||
|
||||
@@ -1453,6 +1453,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
# ── Batch normalization ──
|
||||
# Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
|
||||
# instead of individual create_event calls. Iterate and create each.
|
||||
if isinstance(args.get("events"), list) and not args.get("action"):
|
||||
results = []
|
||||
for ev in args["events"]:
|
||||
if not isinstance(ev, dict):
|
||||
continue
|
||||
# Normalize start/end from {dateTime: "..."} object to flat string
|
||||
for field, target in [("start", "dtstart"), ("end", "dtend")]:
|
||||
val = ev.pop(field, None)
|
||||
if val and target not in ev:
|
||||
ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
|
||||
ev.setdefault("action", "create_event")
|
||||
r = await do_manage_calendar(json.dumps(ev), owner=owner)
|
||||
results.append(r)
|
||||
created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
|
||||
failed = [r for r in results if r.get("error")]
|
||||
|
||||
if not results:
|
||||
return {"error": "No events to create", "exit_code": 1}
|
||||
|
||||
# Surface both successes and failures
|
||||
parts = []
|
||||
if created:
|
||||
summaries = [r.get("response", "") for r in created]
|
||||
parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
|
||||
if failed:
|
||||
first_error = failed[0].get("error", "Unknown error")
|
||||
parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")
|
||||
|
||||
response = "\n\n".join(parts)
|
||||
# Non-zero exit code for partial or total failure
|
||||
exit_code = 0 if not failed else 1
|
||||
return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}
|
||||
|
||||
# Normalize action — some models emit hyphens ("list-calendars") instead
|
||||
# of underscores. Treat them as equivalent so we don't bounce a
|
||||
# cosmetic typo back to the model and waste a round-trip. Also accept
|
||||
|
||||
+15
-2
@@ -162,13 +162,26 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool:
|
||||
|
||||
|
||||
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
|
||||
"""Return True for admins, or when auth is not configured yet."""
|
||||
"""Return True for admins, or in intentional single-user mode.
|
||||
|
||||
Single-user mode means the operator explicitly disabled auth
|
||||
(``AUTH_ENABLED=false``) — the local/self-host default where the owner has
|
||||
full access to their own box.
|
||||
|
||||
The pre-setup window (auth ENABLED but no admin created yet) is treated as
|
||||
NON-admin: returning True there would hand server-execution tools
|
||||
(``bash``/``python``) to any caller before setup completes. The auth
|
||||
middleware already 401s ``/api/`` requests pre-setup, so this is
|
||||
defense-in-depth for callers that bypass it (e.g. trusted loopback).
|
||||
"""
|
||||
try:
|
||||
from core.auth import AuthManager
|
||||
|
||||
auth = AuthManager()
|
||||
if not auth.is_configured:
|
||||
return True
|
||||
from src.auth_helpers import _auth_disabled
|
||||
|
||||
return _auth_disabled()
|
||||
return bool(owner and auth.is_admin(owner))
|
||||
except Exception as exc:
|
||||
logger.warning("Unable to evaluate owner admin status: %s", exc)
|
||||
|
||||
+6
-2
@@ -1082,7 +1082,7 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
||||
let _lastToolName = '';
|
||||
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
|
||||
const _toolLabels = {
|
||||
'web_search': _searchIcon + 'Searching',
|
||||
'web_search': 'Searching',
|
||||
'bash': 'Running',
|
||||
'python': 'Running',
|
||||
'create_document': 'Writing',
|
||||
@@ -1102,6 +1102,9 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
||||
'list_models': 'Browsing',
|
||||
'ui_control': 'Adjusting',
|
||||
};
|
||||
const _toolIcons = {
|
||||
'web_search': _searchIcon,
|
||||
};
|
||||
function _thinkingLabel() {
|
||||
if (!_lastToolName) {
|
||||
return 'Thinking';
|
||||
@@ -2049,10 +2052,11 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
||||
}
|
||||
threadWrap.classList.add('streaming');
|
||||
const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool;
|
||||
const toolIcon = _toolIcons[json.tool.toLowerCase()] || '\u25B6';
|
||||
const node = document.createElement('div')
|
||||
node.className = 'agent-thread-node running';
|
||||
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
|
||||
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">\u25B6</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
||||
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${toolIcon}</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
||||
// Expand/collapse via delegated click handler (init at module bottom).
|
||||
threadWrap.appendChild(node);
|
||||
currentToolBubble = node;
|
||||
|
||||
@@ -862,6 +862,20 @@ export function stripToolBlocks(text) {
|
||||
return cleaned.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Plain-text payload for the message copy buttons: the reply as the renderer
|
||||
* displays it — tool blocks and <think> reasoning stripped. dataset.raw keeps
|
||||
* the full model output (chat.js even embeds the elapsed time into the
|
||||
* <think> tag for reload persistence), so copying it verbatim leaks the
|
||||
* thinking block (#3722). Falls back to the raw text when stripping leaves
|
||||
* nothing (e.g. turns interrupted mid-thinking).
|
||||
*/
|
||||
export function copyMessageText(msgElement) {
|
||||
const raw = msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '';
|
||||
const { content } = markdownModule.extractThinkingBlocks(stripToolBlocks(raw));
|
||||
return content || raw;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a collapsible sources box (used by both research and web search).
|
||||
*/
|
||||
@@ -1372,7 +1386,7 @@ export function createMsgFooter(msgElement) {
|
||||
{ id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) {
|
||||
e.stopPropagation();
|
||||
const btn = e.currentTarget;
|
||||
uiModule.copyToClipboard(msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '');
|
||||
uiModule.copyToClipboard(copyMessageText(msgElement));
|
||||
btn.innerHTML = CHECK_ICON;
|
||||
setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500);
|
||||
}},
|
||||
@@ -2444,6 +2458,7 @@ const chatRenderer = {
|
||||
updateSessionCostUI,
|
||||
roleTimestamp,
|
||||
stripToolBlocks,
|
||||
copyMessageText,
|
||||
safeToolScreenshotSrc,
|
||||
safeDisplayImageSrc,
|
||||
buildSourcesBox,
|
||||
|
||||
@@ -380,7 +380,7 @@ function _slashFooter(msgEl) {
|
||||
copyBtn.innerHTML = _copySvg;
|
||||
copyBtn.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
uiModule.copyToClipboard(msgEl.dataset.raw || msgEl.querySelector('.body')?.textContent || '');
|
||||
uiModule.copyToClipboard(chatRenderer.copyMessageText(msgEl));
|
||||
copyBtn.innerHTML = _checkSvg;
|
||||
setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500);
|
||||
};
|
||||
|
||||
@@ -8,6 +8,9 @@ with missing users or assertion errors.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
@@ -15,6 +18,41 @@ import pytest
|
||||
from tests.helpers.import_state import clear_module
|
||||
|
||||
|
||||
class _OwnerColumn:
|
||||
def __eq__(self, other):
|
||||
return ("owner ==", other)
|
||||
|
||||
|
||||
class _FakeApiToken:
|
||||
owner = _OwnerColumn()
|
||||
|
||||
|
||||
class _FakeQuery:
|
||||
def filter(self, *_conds):
|
||||
return self
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
return 0
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def query(self, model):
|
||||
assert model is _FakeApiToken
|
||||
return _FakeQuery()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_api_token_purge(monkeypatch):
|
||||
@contextlib.contextmanager
|
||||
def _fake_db_session():
|
||||
yield _FakeSession()
|
||||
|
||||
db_stub = types.ModuleType("core.database")
|
||||
db_stub.get_db_session = _fake_db_session
|
||||
db_stub.ApiToken = _FakeApiToken
|
||||
monkeypatch.setitem(sys.modules, "core.database", db_stub)
|
||||
|
||||
|
||||
def _fresh_auth_manager(tmp_path):
|
||||
clear_module("core.auth")
|
||||
from core.auth import AuthManager
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Test that do_manage_calendar handles the batch {"events": [...]} format
|
||||
that models like deepseek-v4-flash emit instead of individual create_event calls.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.helpers.import_state import clear_fake_database_modules
|
||||
from tests.helpers.sqlite_db import make_temp_sqlite
|
||||
|
||||
clear_fake_database_modules()
|
||||
|
||||
import core.database as cdb
|
||||
from core.database import CalendarEvent
|
||||
|
||||
_TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _bind_temp_db(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "core.database", cdb)
|
||||
parent = sys.modules.get("core")
|
||||
if parent is not None:
|
||||
monkeypatch.setattr(parent, "database", cdb, raising=False)
|
||||
monkeypatch.setattr(cdb, "SessionLocal", _TS)
|
||||
yield
|
||||
|
||||
|
||||
async def test_batch_events_with_datetime_objects():
|
||||
"""Model emits {"events": [{"summary": ..., "start": {"dateTime": ...}, "end": {"dateTime": ...}}]}."""
|
||||
from src.tool_implementations import do_manage_calendar
|
||||
|
||||
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||
payload = {
|
||||
"events": [
|
||||
{
|
||||
"summary": "Morning Gym",
|
||||
"start": {"dateTime": "2026-06-09T06:00:00+05:30"},
|
||||
"end": {"dateTime": "2026-06-09T07:00:00+05:30"},
|
||||
},
|
||||
{
|
||||
"summary": "Morning Gym",
|
||||
"start": {"dateTime": "2026-06-10T06:00:00+05:30"},
|
||||
"end": {"dateTime": "2026-06-10T07:00:00+05:30"},
|
||||
},
|
||||
]
|
||||
}
|
||||
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||
assert res.get("exit_code") == 0, res
|
||||
assert "Created 2 event(s)" in res.get("response", "")
|
||||
|
||||
# Verify events exist in DB
|
||||
db = _TS()
|
||||
events = db.query(CalendarEvent).filter(CalendarEvent.summary == "Morning Gym").all()
|
||||
assert len(events) == 2
|
||||
db.close()
|
||||
|
||||
|
||||
async def test_batch_events_with_flat_strings():
|
||||
"""Model emits {"events": [{"summary": ..., "start": "ISO", "end": "ISO"}]}."""
|
||||
from src.tool_implementations import do_manage_calendar
|
||||
|
||||
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||
payload = {
|
||||
"events": [
|
||||
{
|
||||
"summary": "Standup",
|
||||
"start": "2026-06-09T09:00:00",
|
||||
"end": "2026-06-09T09:30:00",
|
||||
},
|
||||
]
|
||||
}
|
||||
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||
assert res.get("exit_code") == 0, res
|
||||
assert "Created 1 event(s)" in res.get("response", "")
|
||||
|
||||
|
||||
async def test_batch_events_partial_failure():
|
||||
"""Batch with some valid and some invalid events — should surface both counts and first error."""
|
||||
from src.tool_implementations import do_manage_calendar
|
||||
|
||||
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||
payload = {
|
||||
"events": [
|
||||
{
|
||||
"summary": "Valid Event 1",
|
||||
"start": "2026-06-09T10:00:00",
|
||||
"end": "2026-06-09T11:00:00",
|
||||
},
|
||||
{
|
||||
"summary": "Invalid Event",
|
||||
# Missing required dtstart — will fail
|
||||
},
|
||||
{
|
||||
"summary": "Valid Event 2",
|
||||
"start": "2026-06-09T14:00:00",
|
||||
"end": "2026-06-09T15:00:00",
|
||||
},
|
||||
]
|
||||
}
|
||||
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||
|
||||
# Partial failure = non-zero exit code
|
||||
assert res.get("exit_code") != 0, "Partial failure should return non-zero exit code"
|
||||
|
||||
# Response should mention both created and failed counts
|
||||
response = res.get("response", "")
|
||||
assert "Created 2 event(s)" in response, f"Should report 2 created: {response}"
|
||||
assert "Failed to create 1 event(s)" in response, f"Should report 1 failed: {response}"
|
||||
assert "error" in response.lower() or "required" in response.lower(), "Should include error details"
|
||||
|
||||
# Metadata fields
|
||||
assert res.get("created_count") == 2
|
||||
assert res.get("failed_count") == 1
|
||||
|
||||
# Verify only valid events were created
|
||||
db = _TS()
|
||||
events = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.summary.in_(["Valid Event 1", "Valid Event 2"])
|
||||
).all()
|
||||
assert len(events) == 2
|
||||
db.close()
|
||||
@@ -0,0 +1,33 @@
|
||||
"""classify_events must read the Memory `text` column, not a non-existent
|
||||
`content` attribute.
|
||||
|
||||
The previous inline loop did `m.content`, which raised AttributeError on the
|
||||
first Memory row; the surrounding except swallowed it, so the personal-context
|
||||
block the LLM relies on was always empty. The logic now lives in
|
||||
`_memory_context_lines`, which reads `text`.
|
||||
"""
|
||||
from src.builtin_actions import _memory_context_lines
|
||||
|
||||
|
||||
class _Mem:
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
|
||||
def test_uses_text_and_truncates_and_skips_blank():
|
||||
lines = _memory_context_lines([_Mem("Alice is my spouse"), _Mem(" "), _Mem("y" * 250)])
|
||||
assert lines[0] == "- Alice is my spouse"
|
||||
assert len(lines) == 2 # the blank row is skipped
|
||||
assert lines[1] == "- " + "y" * 200 # truncated to 200 chars
|
||||
|
||||
|
||||
def test_skips_rows_without_text_attribute():
|
||||
class _Bad: # mimics a schema where the attribute is absent
|
||||
pass
|
||||
|
||||
assert _memory_context_lines([_Bad(), _Mem("ok")]) == ["- ok"]
|
||||
|
||||
|
||||
def test_respects_limit():
|
||||
mems = [_Mem(f"memory {i}") for i in range(50)]
|
||||
assert len(_memory_context_lines(mems, limit=40)) == 40
|
||||
@@ -0,0 +1,39 @@
|
||||
"""POST /api/contacts/import must not 500 on a non-string vcf/text/csv value.
|
||||
|
||||
`text = data.get("vcf") or ... or ""` left a non-string value (e.g. a number)
|
||||
in place, so the next `text.strip()` raised AttributeError -> HTTP 500. The
|
||||
handler now coerces with str() and degrades to a structured "no data" response.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from routes.contacts_routes import setup_contacts_routes
|
||||
|
||||
|
||||
def _import_handler():
|
||||
router = setup_contacts_routes()
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "").endswith("/import") and "POST" in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError("import route not found")
|
||||
|
||||
|
||||
def _call(data):
|
||||
handler = _import_handler()
|
||||
return asyncio.run(handler(data=data, _admin="admin"))
|
||||
|
||||
|
||||
def test_non_string_vcf_degrades_cleanly():
|
||||
resp = _call({"vcf": 123})
|
||||
assert resp["success"] is False
|
||||
assert "error" in resp
|
||||
|
||||
|
||||
def test_non_string_csv_degrades_cleanly():
|
||||
resp = _call({"csv": ["a", "b"]})
|
||||
assert resp["success"] is False
|
||||
|
||||
|
||||
def test_empty_body_reports_no_data():
|
||||
resp = _call({})
|
||||
assert resp["success"] is False
|
||||
assert resp["error"] == "No contact data found"
|
||||
@@ -26,7 +26,6 @@ from routes.cookbook_helpers import (
|
||||
_validate_repo_id,
|
||||
_validate_serve_cmd,
|
||||
_validate_serve_model_id,
|
||||
_validate_ssh_port,
|
||||
_shell_path,
|
||||
run_ssh_command_async,
|
||||
)
|
||||
@@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path():
|
||||
)
|
||||
|
||||
|
||||
def test_validate_ssh_port_rejects_shell_payload():
|
||||
with pytest.raises(HTTPException):
|
||||
_validate_ssh_port("22; touch /tmp/pwned")
|
||||
assert _validate_ssh_port("2222") == "2222"
|
||||
|
||||
|
||||
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
|
||||
path = "/Volumes/T7 2TB/AI Models/llamacpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Regression coverage for issue #3722 — the message copy button copied the
|
||||
full raw model output (``dataset.raw``), which still contains the
|
||||
``<think time="...">...</think>`` reasoning block that the renderer strips for
|
||||
display. Pasting therefore leaked the model's thinking, and the first heading
|
||||
after ``</think>`` lost its markdown formatting because it was glued to the
|
||||
closing tag.
|
||||
|
||||
The fix adds chatRenderer.copyMessageText(), which mirrors the display
|
||||
pipeline (``stripToolBlocks()`` then ``extractThinkingBlocks()``), and routes
|
||||
both AI-message copy buttons (createMsgFooter and the slash-reply footer)
|
||||
through it. extractThinkingBlocks() behavior is pinned here under node
|
||||
(including on the payload from the issue report); the helper and handler
|
||||
wiring are guarded at the source level because chatRenderer.js pulls in
|
||||
browser globals and can't be imported under node (same approach as
|
||||
test_new_chat_clears_input.py).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_HAS_NODE = shutil.which("node") is not None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def node_available():
|
||||
if not _HAS_NODE:
|
||||
pytest.skip("node binary not on PATH")
|
||||
|
||||
|
||||
def _extract_thinking_blocks(text: str) -> dict:
|
||||
"""Run markdown.js extractThinkingBlocks(text) under node."""
|
||||
script = textwrap.dedent(
|
||||
r"""
|
||||
import fs from 'node:fs';
|
||||
|
||||
globalThis.window = { location: { origin: 'http://localhost' }, katex: null };
|
||||
globalThis.document = {
|
||||
readyState: 'loading',
|
||||
addEventListener() {},
|
||||
createElement(tag) {
|
||||
if (tag !== 'template') throw new Error(`unsupported element: ${tag}`);
|
||||
return {
|
||||
_html: '',
|
||||
content: { querySelectorAll() { return []; } },
|
||||
set innerHTML(value) { this._html = value; },
|
||||
get innerHTML() { return this._html; },
|
||||
};
|
||||
},
|
||||
};
|
||||
globalThis.MutationObserver = class { observe() {} };
|
||||
|
||||
let source = fs.readFileSync('./static/js/markdown.js', 'utf8');
|
||||
source = source.replace(
|
||||
/import uiModule from ['"]\.\/ui\.js['"];/,
|
||||
''
|
||||
);
|
||||
source = source.replace(
|
||||
/import \{ splitTableRow \} from ['"]\.\/markdown\/tableRow\.js['"];/,
|
||||
`function splitTableRow(row) {
|
||||
return (row || '').replace(/^\\s*\\|/, '').replace(/\\|\\s*$/, '').split('|').map(c => c.trim());
|
||||
}`
|
||||
);
|
||||
const emojiSource = fs.readFileSync('./static/js/emojiShortcodes.js', 'utf8')
|
||||
.replace(/^export default .*$/m, '')
|
||||
.replace(/export const /g, 'const ')
|
||||
.replace(/export function /g, 'function ');
|
||||
source = source.replace(
|
||||
/import \{ replaceEmojiShortcodes, hasEmojiShortcode \} from ['"]\.\/emojiShortcodes\.js['"];/,
|
||||
() => emojiSource
|
||||
);
|
||||
source = source.replace(
|
||||
/var escapeHtml = uiModule\.esc;/,
|
||||
`var escapeHtml = (value) => String(value ?? '')
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, ''');`
|
||||
);
|
||||
|
||||
const moduleUrl = 'data:text/javascript;base64,' + Buffer.from(source).toString('base64');
|
||||
const mod = await import(moduleUrl);
|
||||
const input = JSON.parse(process.argv[1]);
|
||||
console.log(JSON.stringify({ out: mod.extractThinkingBlocks(input) }));
|
||||
"""
|
||||
)
|
||||
result = subprocess.run(
|
||||
["node", "--input-type=module", "-e", script, json.dumps(text)],
|
||||
cwd=_REPO,
|
||||
capture_output=True,
|
||||
timeout=15,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise AssertionError(f"node failed:\nSTDERR:\n{result.stderr}\nSTDOUT:\n{result.stdout}")
|
||||
return json.loads(result.stdout.splitlines()[-1])["out"]
|
||||
|
||||
|
||||
def test_issue_payload_copy_text_excludes_thinking(node_available):
|
||||
# Shape reported in #3722: timed think block glued to the reply heading.
|
||||
raw = (
|
||||
'<think time="24.5">\n'
|
||||
"Here's a thinking process that leads to the desired summary:\n\n"
|
||||
"6. **Generate the Output.** (This matches the final provided response.)"
|
||||
"</think>### Juxtaposition: Interweaving Cultural Norms in Lesson Design\n"
|
||||
"The most effective lesson structure is created by deliberately juxtaposing."
|
||||
)
|
||||
out = _extract_thinking_blocks(raw)
|
||||
|
||||
assert out["content"].startswith("### Juxtaposition:"), out["content"]
|
||||
assert "thinking process" not in out["content"]
|
||||
assert "<think" not in out["content"]
|
||||
assert out["thinkingTime"] == "24.5"
|
||||
|
||||
|
||||
def test_plain_reply_copy_text_is_unchanged(node_available):
|
||||
raw = "### Heading\nJust a normal reply with no reasoning markup."
|
||||
out = _extract_thinking_blocks(raw)
|
||||
assert out["content"] == raw
|
||||
|
||||
|
||||
def test_thinking_only_message_yields_empty_content(node_available):
|
||||
# The copy handler falls back to the raw text in this case so the button
|
||||
# still copies something for turns interrupted mid-thinking.
|
||||
out = _extract_thinking_blocks("<think>only reasoning, no reply yet</think>")
|
||||
assert out["content"] == ""
|
||||
|
||||
|
||||
def _function_body(text: str, marker: str) -> str:
|
||||
start = text.index(marker)
|
||||
rest = text[start + len(marker):]
|
||||
m = re.search(r"\nexport function |\nfunction ", rest)
|
||||
return rest[: m.start()] if m else rest
|
||||
|
||||
|
||||
def test_copy_message_text_mirrors_display_pipeline():
|
||||
text = (_REPO / "static/js/chatRenderer.js").read_text(encoding="utf-8")
|
||||
body = _function_body(text, "export function copyMessageText")
|
||||
# Mirrors the display path: tool blocks stripped, then thinking extracted.
|
||||
assert "extractThinkingBlocks" in body
|
||||
assert "stripToolBlocks" in body
|
||||
assert "dataset.raw" in body
|
||||
|
||||
|
||||
def test_copy_handlers_route_through_copy_message_text():
|
||||
for path, count in (("static/js/chatRenderer.js", 1), ("static/js/slashCommands.js", 1)):
|
||||
text = (_REPO / path).read_text(encoding="utf-8")
|
||||
assert text.count("copyToClipboard(copyMessageText(") + text.count(
|
||||
"copyToClipboard(chatRenderer.copyMessageText("
|
||||
) == count, path
|
||||
# The old behavior passed dataset.raw straight to the clipboard.
|
||||
assert "copyToClipboard(msgElement.dataset.raw" not in text, path
|
||||
assert "copyToClipboard(msgEl.dataset.raw" not in text, path
|
||||
@@ -36,6 +36,17 @@ def _auth_manager(delete_result):
|
||||
)
|
||||
|
||||
|
||||
def _auth_manager_raising():
|
||||
def _delete_user(_username, _requesting_user):
|
||||
raise RuntimeError("auth save failed after token purge")
|
||||
|
||||
return types.SimpleNamespace(
|
||||
get_username_for_token=lambda token: "admin",
|
||||
is_admin=lambda user: True,
|
||||
delete_user=_delete_user,
|
||||
)
|
||||
|
||||
|
||||
def test_successful_delete_invalidates_cache():
|
||||
invalidations = []
|
||||
router = setup_auth_routes(_auth_manager(delete_result=True))
|
||||
@@ -56,3 +67,16 @@ def test_refused_delete_does_not_invalidate_cache():
|
||||
raised = True
|
||||
assert raised, "a refused delete should raise (HTTP 400)"
|
||||
assert invalidations == [], "a refused delete must not touch the token cache"
|
||||
|
||||
|
||||
def test_delete_exception_invalidates_cache_for_partial_token_purge():
|
||||
invalidations = []
|
||||
router = setup_auth_routes(_auth_manager_raising())
|
||||
handler = _handler(router)
|
||||
try:
|
||||
asyncio.run(handler(DeleteUserRequest(username="bob"), _fake_request(invalidations)))
|
||||
raised = False
|
||||
except RuntimeError:
|
||||
raised = True
|
||||
assert raised, "delete_user exception should still propagate"
|
||||
assert invalidations == [True], "partial token purge must dirty the bearer cache"
|
||||
|
||||
@@ -114,3 +114,21 @@ def test_refused_delete_leaves_tokens_alone(manager, db_calls):
|
||||
def test_unknown_user_leaves_tokens_alone(manager, db_calls):
|
||||
assert manager.delete_user("ghost", "admin") is False
|
||||
assert db_calls == []
|
||||
|
||||
|
||||
def test_delete_user_fails_closed_when_api_token_purge_fails(manager, monkeypatch):
|
||||
token = manager.create_session("bob", "secret-bob-pw")
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _failing_db_session():
|
||||
raise RuntimeError("database unavailable")
|
||||
yield
|
||||
|
||||
db_stub = types.ModuleType("core.database")
|
||||
db_stub.get_db_session = _failing_db_session
|
||||
db_stub.ApiToken = _FakeApiToken
|
||||
monkeypatch.setitem(sys.modules, "core.database", db_stub)
|
||||
|
||||
assert manager.delete_user("bob", "admin") is False
|
||||
assert "bob" in manager.users
|
||||
assert manager.validate_token(token) is True
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count.
|
||||
|
||||
The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any
|
||||
non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored,
|
||||
matching how the neighbouring gpu_group param is already parsed.
|
||||
"""
|
||||
from routes.hwfit_routes import setup_hwfit_routes
|
||||
|
||||
|
||||
def _get_models():
|
||||
router = setup_hwfit_routes()
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "").endswith("/models") and "GET" in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError("hwfit /models route not found")
|
||||
|
||||
|
||||
def test_non_numeric_gpu_count_does_not_raise():
|
||||
handler = _get_models()
|
||||
# Previously raised ValueError (HTTP 500); now degrades to a normal ranking.
|
||||
result = handler(gpu_count="abc")
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_numeric_gpu_count_still_accepted():
|
||||
handler = _get_models()
|
||||
result = handler(gpu_count="0")
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_non_numeric_manual_gpu_count_does_not_raise():
|
||||
# manual_gpu_count is the other count param on this endpoint (the hardware
|
||||
# simulator in _apply_manual_hardware). A non-numeric value must also degrade
|
||||
# (default to 1) rather than 500, so the endpoint's count parsing is fully
|
||||
# covered.
|
||||
handler = _get_models()
|
||||
result = handler(manual_mode="gpu", manual_gpu_count="abc")
|
||||
assert isinstance(result, dict)
|
||||
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
from routes.hwfit_routes import setup_hwfit_routes
|
||||
|
||||
|
||||
def _endpoint(path: str):
|
||||
router = setup_hwfit_routes()
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "") == path:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"{path} route not found")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,kwargs",
|
||||
[
|
||||
("/api/hwfit/system", {}),
|
||||
("/api/hwfit/models", {"limit": 1}),
|
||||
("/api/hwfit/profiles", {"model": "demo"}),
|
||||
("/api/hwfit/image-models", {}),
|
||||
],
|
||||
)
|
||||
def test_hwfit_routes_reject_ssh_option_host(path, kwargs):
|
||||
endpoint = _endpoint(path)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
|
||||
def test_hwfit_routes_reject_port_without_host():
|
||||
endpoint = _endpoint("/api/hwfit/system")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
endpoint(host="", ssh_port="2222")
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
|
||||
def test_ssh_argv_rejects_option_shaped_remote():
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true")
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true")
|
||||
@@ -11,7 +11,10 @@ owner column, but three file-backed / in-memory stores are left stale:
|
||||
research_routes filters by `d.get("owner") == user`, making every report
|
||||
invisible after rename.
|
||||
|
||||
3. data/memory.json — a flat array where every entry has an `owner` field;
|
||||
3. research_handler._active_tasks — in-flight research jobs carry the same
|
||||
owner key while status/cancel/active routes filter by it.
|
||||
|
||||
4. data/memory.json — a flat array where every entry has an `owner` field;
|
||||
memory_manager.load(owner=user) filters on it, so all memories vanish.
|
||||
|
||||
Regression coverage: these bugs are invisible in unit tests that mock the DB
|
||||
@@ -26,6 +29,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def _route(router, name):
|
||||
@@ -63,18 +67,69 @@ def rename_endpoint(monkeypatch, tmp_path):
|
||||
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
|
||||
|
||||
|
||||
def _request(tmp_path, session_manager=None):
|
||||
def _request(tmp_path, session_manager=None, token="t", research_handler=None):
|
||||
state = SimpleNamespace(
|
||||
invalidate_token_cache=lambda: None,
|
||||
session_manager=session_manager,
|
||||
research_handler=research_handler,
|
||||
)
|
||||
return SimpleNamespace(
|
||||
cookies={"odysseus_session": "t"},
|
||||
cookies={"odysseus_session": token},
|
||||
app=SimpleNamespace(state=state),
|
||||
state=SimpleNamespace(current_user="admin"),
|
||||
)
|
||||
|
||||
|
||||
def _auth_manager_for_rollback_test(monkeypatch, tmp_path):
|
||||
import core.auth as auth_mod
|
||||
|
||||
monkeypatch.setattr(auth_mod, "_hash_password", lambda password: f"hash:{password}")
|
||||
monkeypatch.setattr(auth_mod, "_verify_password", lambda password, hashed: hashed == f"hash:{password}")
|
||||
|
||||
am = auth_mod.AuthManager(str(tmp_path / "auth.json"))
|
||||
assert am.create_user("admin", "pw-123456", is_admin=True) is True
|
||||
assert am.create_user("alice", "pw-123456") is True
|
||||
return am
|
||||
|
||||
|
||||
def _force_sql_owner_migration_failure(monkeypatch):
|
||||
import core.database as cdb
|
||||
|
||||
class OwnerModel:
|
||||
owner = "owner"
|
||||
|
||||
class FailingQuery:
|
||||
def filter(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def update(self, *_args, **_kwargs):
|
||||
raise RuntimeError("forced owner migration failure")
|
||||
|
||||
class FailingSession:
|
||||
def __init__(self):
|
||||
self.rolled_back = False
|
||||
self.closed = False
|
||||
|
||||
def query(self, _model):
|
||||
return FailingQuery()
|
||||
|
||||
def rollback(self):
|
||||
self.rolled_back = True
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
db = FailingSession()
|
||||
monkeypatch.setattr(cdb, "SessionLocal", lambda: db)
|
||||
monkeypatch.setattr(
|
||||
cdb,
|
||||
"Base",
|
||||
SimpleNamespace(registry=SimpleNamespace(mappers=[SimpleNamespace(class_=OwnerModel)])),
|
||||
raising=False,
|
||||
)
|
||||
return db
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. In-memory session cache
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -183,6 +238,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint):
|
||||
assert res["ok"] is True
|
||||
|
||||
|
||||
def test_rename_updates_active_research_task_owner(rename_endpoint):
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
from routes.research_routes import setup_research_routes
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"alice-task": {
|
||||
"owner": "Alice",
|
||||
"status": "running",
|
||||
"query": "q",
|
||||
"progress": {},
|
||||
"started_at": 1,
|
||||
},
|
||||
"carol-task": {
|
||||
"owner": "carol",
|
||||
"status": "running",
|
||||
"query": "q2",
|
||||
"progress": {},
|
||||
"started_at": 2,
|
||||
},
|
||||
}
|
||||
|
||||
asyncio.run(endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, research_handler=rh),
|
||||
))
|
||||
|
||||
assert rh._active_tasks["alice-task"]["owner"] == "alice2"
|
||||
assert rh._active_tasks["carol-task"]["owner"] == "carol"
|
||||
|
||||
router = setup_research_routes(rh)
|
||||
active = next(
|
||||
r.endpoint for r in router.routes
|
||||
if getattr(r, "path", "") == "/api/research/active"
|
||||
)
|
||||
|
||||
alice2 = asyncio.run(active(
|
||||
SimpleNamespace(state=SimpleNamespace(current_user="alice2")),
|
||||
))
|
||||
alice = asyncio.run(active(
|
||||
SimpleNamespace(state=SimpleNamespace(current_user="alice")),
|
||||
))
|
||||
|
||||
assert [item["session_id"] for item in alice2["active"]] == ["alice-task"]
|
||||
assert alice["active"] == []
|
||||
|
||||
|
||||
def test_research_handler_rename_owner_canonicalizes_new_owner():
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"task": {"owner": "Alice", "status": "running"},
|
||||
}
|
||||
|
||||
changed = rh.rename_owner("alice", "Alice2")
|
||||
assert changed == 1
|
||||
assert rh._active_tasks["task"]["owner"] == "alice2"
|
||||
|
||||
|
||||
def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold():
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"task-strasse": {"owner": "strasse", "status": "running"},
|
||||
"task-sharp-s": {"owner": "straße", "status": "running"},
|
||||
}
|
||||
|
||||
changed = rh.rename_owner("straße", "renamed")
|
||||
|
||||
assert changed == 1
|
||||
assert rh._active_tasks["task-strasse"]["owner"] == "strasse"
|
||||
assert rh._active_tasks["task-sharp-s"]["owner"] == "renamed"
|
||||
|
||||
|
||||
def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint):
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
dr_dir = tmp_path / "deep_research"
|
||||
dr_dir.mkdir()
|
||||
report = dr_dir / "race-window.json"
|
||||
report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8")
|
||||
owner_seen_by_active_hook = []
|
||||
|
||||
class FakeResearchHandler:
|
||||
def rename_owner(self, _old, _new):
|
||||
owner_seen_by_active_hook.append(json.loads(report.read_text(encoding="utf-8"))["owner"])
|
||||
|
||||
asyncio.run(endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, research_handler=FakeResearchHandler()),
|
||||
))
|
||||
|
||||
assert owner_seen_by_active_hook == ["alice"]
|
||||
assert json.loads(report.read_text(encoding="utf-8"))["owner"] == "alice2"
|
||||
|
||||
|
||||
def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path):
|
||||
"""DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a
|
||||
hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made
|
||||
@@ -333,8 +490,100 @@ def test_rename_no_skills_dir_does_not_crash(rename_endpoint):
|
||||
assert res["ok"] is True
|
||||
|
||||
|
||||
def test_rename_skill_md_owner_case_insensitive(rename_endpoint):
|
||||
"""SKILL.md written with owner: Alice (mixed case) must be updated when
|
||||
renaming alice — the regex was missing re.IGNORECASE."""
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
skill_dir = tmp_path / "skills" / "general" / "s"
|
||||
skill_dir.mkdir(parents=True)
|
||||
(skill_dir / "SKILL.md").write_text(_SKILL_MD.format(owner="Alice"), encoding="utf-8")
|
||||
|
||||
asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path)))
|
||||
|
||||
assert "owner: alice2" in (skill_dir / "SKILL.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_rename_usage_keys_case_insensitive(rename_endpoint):
|
||||
"""_usage.json keys stored as Alice::skill-name must be migrated when
|
||||
renaming alice — the old startswith check was not lowercasing."""
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir(parents=True)
|
||||
usage = {"Alice::my-skill": {"uses": 5, "last_used": 999}}
|
||||
(skills_root / "_usage.json").write_text(json.dumps(usage), encoding="utf-8")
|
||||
|
||||
asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path)))
|
||||
|
||||
updated = json.loads((skills_root / "_usage.json").read_text(encoding="utf-8"))
|
||||
assert "alice2::my-skill" in updated
|
||||
assert "Alice::my-skill" not in updated
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. P1 regression: rejected auth rename must not mutate file-backed stores
|
||||
# 5. Rollback: auth rename must be restored if SQL owner migration fails
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_owner_migration_failure_rolls_back_auth_rename(monkeypatch, tmp_path):
|
||||
import routes.auth_routes as ar
|
||||
|
||||
db = _force_sql_owner_migration_failure(monkeypatch)
|
||||
am = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
|
||||
admin_token = am.create_session_trusted("admin")
|
||||
alice_token = am.create_session_trusted("alice")
|
||||
endpoint = _route(ar.setup_auth_routes(am), "rename_user")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
asyncio.run(
|
||||
endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, token=admin_token),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert db.rolled_back is True
|
||||
assert db.closed is True
|
||||
assert "alice" in am.users
|
||||
assert "alice2" not in am.users
|
||||
assert am.get_username_for_token(alice_token) == "alice"
|
||||
saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"]
|
||||
assert "alice" in saved_users
|
||||
assert "alice2" not in saved_users
|
||||
|
||||
|
||||
def test_self_rename_owner_migration_failure_rolls_back_auth_session(monkeypatch, tmp_path):
|
||||
import routes.auth_routes as ar
|
||||
|
||||
db = _force_sql_owner_migration_failure(monkeypatch)
|
||||
am = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
|
||||
admin_token = am.create_session_trusted("admin")
|
||||
endpoint = _route(ar.setup_auth_routes(am), "rename_user")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
asyncio.run(
|
||||
endpoint(
|
||||
"admin",
|
||||
SimpleNamespace(username="chief"),
|
||||
_request(tmp_path, token=admin_token),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 500
|
||||
assert db.rolled_back is True
|
||||
assert db.closed is True
|
||||
assert "admin" in am.users
|
||||
assert "chief" not in am.users
|
||||
assert am.get_username_for_token(admin_token) == "admin"
|
||||
saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"]
|
||||
assert "admin" in saved_users
|
||||
assert "chief" not in saved_users
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. P1 regression: rejected auth rename must not mutate file-backed stores
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path):
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
"""get_status must not rescan the whole research dir on every SSE poll.
|
||||
|
||||
get_avg_duration() globs and JSON-parses every file under the research data dir.
|
||||
get_status() called it unconditionally on each poll, including for sessions that
|
||||
are not active (the common case while a client polls a finished report). It is
|
||||
now computed only for active sessions and memoized on the entry.
|
||||
"""
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
|
||||
def _handler():
|
||||
h = ResearchHandler.__new__(ResearchHandler)
|
||||
h._active_tasks = {}
|
||||
return h
|
||||
|
||||
|
||||
def test_inactive_session_does_not_compute_avg(monkeypatch):
|
||||
h = _handler()
|
||||
calls = []
|
||||
monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 5.0)[1])
|
||||
# Unknown session, no disk file -> None, and no expensive avg scan.
|
||||
assert h.get_status("missing-session") is None
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_active_session_memoizes_avg(monkeypatch):
|
||||
h = _handler()
|
||||
h._active_tasks["s1"] = {
|
||||
"status": "running", "progress": {}, "query": "q", "started_at": 0,
|
||||
}
|
||||
calls = []
|
||||
monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 12.0)[1])
|
||||
|
||||
r1 = h.get_status("s1")
|
||||
r2 = h.get_status("s1")
|
||||
r3 = h.get_status("s1")
|
||||
|
||||
assert r1["avg_duration"] == 12.0
|
||||
assert r2["avg_duration"] == 12.0 and r3["avg_duration"] == 12.0
|
||||
# Computed once across many polls, not once per poll.
|
||||
assert len(calls) == 1
|
||||
@@ -58,6 +58,62 @@ def test_rename_into_reserved_username_is_blocked(tmp_path):
|
||||
assert "bob" in mgr.users
|
||||
|
||||
|
||||
def test_legacy_reserved_username_is_removed_on_load(tmp_path):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path.write_text(
|
||||
'{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}, '
|
||||
'"admin": {"password_hash": "unused", "is_admin": true}}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
|
||||
assert "internal-tool" not in mgr.users
|
||||
assert "admin" in mgr.users
|
||||
assert "internal-tool" not in auth_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_legacy_reserved_username_session_cannot_authenticate(tmp_path):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
sessions_path = tmp_path / "sessions.json"
|
||||
auth_path.write_text(
|
||||
'{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
sessions_path.write_text(
|
||||
'{"tok": {"username": "internal-tool", "expiry": 9999999999}}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
|
||||
assert mgr.validate_token("tok") is False
|
||||
assert mgr.get_username_for_token("tok") is None
|
||||
|
||||
|
||||
def test_legacy_reserved_single_user_migrates_to_admin(tmp_path):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path.write_text(
|
||||
'{"username": "internal-tool", "password_hash": "unused"}',
|
||||
encoding="utf-8",
|
||||
)
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
|
||||
assert "internal-tool" not in mgr.users
|
||||
assert "admin" in mgr.users
|
||||
assert mgr.is_admin("admin") is True
|
||||
|
||||
|
||||
def test_token_cache_owner_normalization_requires_current_user():
|
||||
clear_module("core.auth")
|
||||
from core.auth import normalize_known_username
|
||||
|
||||
users = {"alice": {}, "admin": {}}
|
||||
|
||||
assert normalize_known_username(users, " Alice ") == "alice"
|
||||
assert normalize_known_username(users, "internal-tool") is None
|
||||
assert normalize_known_username(users, "api") is None
|
||||
assert normalize_known_username(users, "") is None
|
||||
|
||||
|
||||
def test_normal_usernames_still_allowed(tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
assert mgr.create_user("alice", "pw-123456") is True
|
||||
|
||||
@@ -647,6 +647,60 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
|
||||
assert "manage_tasks" in blocked
|
||||
|
||||
|
||||
def test_presetup_does_not_grant_admin_tools_when_auth_enabled(monkeypatch):
|
||||
"""Pre-setup window: auth is enabled but no admin user exists yet.
|
||||
|
||||
This must NOT be treated as single-user/admin at the tool layer — the
|
||||
server-execution tools (bash/python) stay blocked as defense-in-depth so
|
||||
an unauthenticated caller that slips past the auth middleware (e.g. via a
|
||||
loopback bypass) can't reach an RCE before setup completes.
|
||||
"""
|
||||
monkeypatch.delenv("AUTH_ENABLED", raising=False) # default: enabled
|
||||
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||
|
||||
class FakeAuth:
|
||||
is_configured = False
|
||||
|
||||
def is_admin(self, username):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||
|
||||
from src.tool_security import (
|
||||
blocked_tools_for_owner,
|
||||
owner_is_admin_or_single_user,
|
||||
)
|
||||
|
||||
assert owner_is_admin_or_single_user(None) is False
|
||||
blocked = blocked_tools_for_owner(None)
|
||||
assert "bash" in blocked
|
||||
assert "python" in blocked
|
||||
|
||||
|
||||
def test_single_user_mode_keeps_full_tool_access_when_auth_disabled(monkeypatch):
|
||||
"""Intentional single-user mode (AUTH_ENABLED=false) keeps full tool
|
||||
access even with no admin user — this is the default local/self-host UX
|
||||
and must not regress."""
|
||||
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||
|
||||
class FakeAuth:
|
||||
is_configured = False
|
||||
|
||||
def is_admin(self, username):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||
|
||||
from src.tool_security import (
|
||||
blocked_tools_for_owner,
|
||||
owner_is_admin_or_single_user,
|
||||
)
|
||||
|
||||
assert owner_is_admin_or_single_user(None) is True
|
||||
assert blocked_tools_for_owner(None) == set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_tool_reuses_private_url_validation():
|
||||
class FakeDb:
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
|
||||
def test_validate_ssh_port_rejects_shell_payload():
|
||||
for port in ["22;id", "$(id)", "-p 22", "0", "65536"]:
|
||||
with pytest.raises(HTTPException):
|
||||
validate_ssh_port(port)
|
||||
assert validate_ssh_port("2222") == "2222"
|
||||
|
||||
|
||||
def test_validate_remote_host_rejects_ssh_option_shape():
|
||||
for host in [
|
||||
"-oProxyCommand=sh",
|
||||
"alice@-oProxyCommand=sh",
|
||||
"--",
|
||||
"-p2222",
|
||||
]:
|
||||
with pytest.raises(HTTPException):
|
||||
validate_remote_host(host)
|
||||
assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1"
|
||||
@@ -90,8 +90,8 @@ def test_service_ddg_html_fallback_sends_safesearch(monkeypatch):
|
||||
seen["params"] = kwargs["params"]
|
||||
return _Response()
|
||||
|
||||
monkeypatch.setitem(sys.modules, "duckduckgo_search", None)
|
||||
monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"})
|
||||
monkeypatch.setitem(sys.modules, "ddgs", None)
|
||||
monkeypatch.setattr(providers.httpx, "get", fake_get)
|
||||
|
||||
results = providers.duckduckgo_search("odysseus", count=1)
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Startup warmup must resolve real endpoint URLs.
|
||||
|
||||
The warmup/keepalive loop called `model_discovery.get_endpoints()`, which does
|
||||
not exist on ModelDiscovery, so it raised AttributeError every run and pinged
|
||||
nothing. `ModelDiscovery.warmup_ping_urls()` resolves the /models probe URLs
|
||||
from the real discovery API.
|
||||
"""
|
||||
from src.model_discovery import ModelDiscovery
|
||||
|
||||
|
||||
def _md():
|
||||
return ModelDiscovery.__new__(ModelDiscovery)
|
||||
|
||||
|
||||
def test_old_method_never_existed():
|
||||
# Documents why the old warmup was a silent no-op.
|
||||
assert not hasattr(ModelDiscovery, "get_endpoints")
|
||||
|
||||
|
||||
def test_resolves_models_urls_from_discovered_items():
|
||||
md = _md()
|
||||
md.discover_models = lambda: {"items": [
|
||||
{"url": "http://host:8000/v1/chat/completions", "models": ["a"]},
|
||||
{"url": "http://host:1234/v1/chat/completions", "models": ["b"]},
|
||||
]}
|
||||
assert md.warmup_ping_urls() == [
|
||||
"http://host:8000/v1/models",
|
||||
"http://host:1234/v1/models",
|
||||
]
|
||||
|
||||
|
||||
def test_limit_caps_results():
|
||||
md = _md()
|
||||
md.discover_models = lambda: {"items": [
|
||||
{"url": f"http://h:{8000 + i}/v1/chat/completions"} for i in range(10)
|
||||
]}
|
||||
assert len(md.warmup_ping_urls(limit=3)) == 3
|
||||
|
||||
|
||||
def test_discovery_failure_degrades_to_empty():
|
||||
md = _md()
|
||||
|
||||
def boom():
|
||||
raise RuntimeError("port scan failed")
|
||||
|
||||
md.discover_models = boom
|
||||
assert md.warmup_ping_urls() == []
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Pin the web_search tool-icon rendering in the agent thread (PR #??).
|
||||
|
||||
Verifies:
|
||||
- web_search renders an <svg> icon instead of raw markup
|
||||
- Other tools get the default ▶ icon
|
||||
- Hostile tool names are HTML-escaped in the label
|
||||
|
||||
Pure JS via node --input-type=module (same approach as
|
||||
test_composer_arrow_up_recall_js.py). Skips when node is not installed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_HAS_NODE = shutil.which("node") is not None
|
||||
|
||||
_CHECK_JS = r"""
|
||||
function esc(s) {
|
||||
const map = { '&': '&', '<': '<', '>': '>', '"': '"', "'": ''' };
|
||||
return (s || '').replace(/[&<>"']/g, (m) => map[m]);
|
||||
}
|
||||
|
||||
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
|
||||
|
||||
const _toolLabels = {
|
||||
web_search: 'Searching',
|
||||
bash: 'Running',
|
||||
};
|
||||
|
||||
const _toolIcons = {
|
||||
web_search: _searchIcon,
|
||||
};
|
||||
|
||||
function renderIcon(toolName) {
|
||||
return _toolIcons[toolName.toLowerCase()] || '\u25B6';
|
||||
}
|
||||
|
||||
function renderLabel(toolName) {
|
||||
return _toolLabels[toolName.toLowerCase()] || toolName;
|
||||
}
|
||||
|
||||
function renderThreadHTML(toolName, cmd) {
|
||||
const label = renderLabel(toolName);
|
||||
const icon = renderIcon(toolName);
|
||||
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
|
||||
return `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${icon}</span><span class="agent-thread-tool">${esc(label)}</span><span class="agent-thread-wave">\u2581\u2582\u2583</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
||||
}
|
||||
|
||||
const cases = CASES_JSON;
|
||||
const results = cases.map(c => {
|
||||
const html = renderThreadHTML(c.tool, c.cmd || '');
|
||||
return { tool: c.tool, html };
|
||||
});
|
||||
console.log(JSON.stringify(results));
|
||||
"""
|
||||
|
||||
|
||||
def _run(cases: list) -> list:
|
||||
js = _CHECK_JS.replace("CASES_JSON", json.dumps(cases))
|
||||
proc = subprocess.run(
|
||||
["node", "--input-type=module"],
|
||||
input=js,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
cwd=str(_REPO),
|
||||
timeout=30,
|
||||
)
|
||||
assert proc.returncode == 0, proc.stderr
|
||||
return json.loads(proc.stdout.strip())
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_web_search_icon_contains_svg():
|
||||
out = _run([{"tool": "web_search"}])[0]
|
||||
assert "<svg" in out["html"], "Expected <svg> in agent-thread-icon for web_search"
|
||||
assert "Searching" in out["html"], "Expected 'Searching' label for web_search"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_default_tool_icon_is_triangle():
|
||||
out = _run([{"tool": "bash"}])[0]
|
||||
assert "▶" in out["html"], "Expected ▶ icon for tools without custom icon"
|
||||
assert "<svg" not in out["html"], "Expected no <svg> for bash"
|
||||
assert "Running" in out["html"], "Expected 'Running' label for bash"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_unknown_tool_falls_back_to_name():
|
||||
out = _run([{"tool": "my_custom_tool"}])[0]
|
||||
assert "▶" in out["html"], "Expected ▶ for unknown tool"
|
||||
assert "my_custom_tool" in out["html"], "Expected tool name as label"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_hostile_tool_name_is_escaped():
|
||||
out = _run([{"tool": '<img src=x onerror="alert(1)">'}])[0]
|
||||
assert "<img" in out["html"], "Expected < to be HTML-escaped"
|
||||
assert ">" in out["html"], "Expected > to be HTML-escaped"
|
||||
assert "<img" not in out["html"], "Raw <img> must not appear"
|
||||
assert "onerror" not in out["html"] or """ in out["html"], "onerror must not be executable"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_unknown_tool_case_insensitive_matches_icons():
|
||||
out = _run([{"tool": "WEB_SEARCH"}, {"tool": "Web_Search"}])
|
||||
for r in out:
|
||||
assert "<svg" in r["html"], f"Expected SVG for case-variant '{r['tool']}'"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||
def test_command_is_escaped():
|
||||
out = _run([{"tool": "bash", "cmd": "echo $HOME && ls"}])[0]
|
||||
assert "echo $HOME" in out["html"], "Expected command text in output"
|
||||
Reference in New Issue
Block a user