mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 18:25:26 -04:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 64cf0f3fc1 | |||
| 76c3cac175 | |||
| 620fdd0859 | |||
| 02f25f0a1c | |||
| d528da8308 | |||
| e32150ad96 | |||
| 95c54ac3cb | |||
| 263d41c58a | |||
| f941db29d3 | |||
| bfac1d55d6 | |||
| cc8ba04ea8 | |||
| 4fa4d0100a | |||
| c500bcb47d | |||
| f7a3605b16 | |||
| 1a2bcfcae4 | |||
| 65d9603c8c | |||
| a7b03398b6 | |||
| 4f48cfa9ae | |||
| af61b2d4e6 | |||
| 0b0656df11 | |||
| 9f47c5ff87 | |||
| dd2d375c7b | |||
| 73823c878e | |||
| 50fedff2f2 | |||
| 66c25cbc2f | |||
| 09ec880c06 | |||
| 5e16126bde | |||
| c01034f9cb | |||
| 8adca3a924 | |||
| d5603ee575 | |||
| 9c00da6d1c | |||
| d1a5a7d680 | |||
| 218b9ecbc8 | |||
| d9a4b99046 | |||
| f5b91f1e9e | |||
| 8bf8212846 | |||
| a0b0420e6f | |||
| 96975f8dd9 | |||
| 4e210d3337 | |||
| 800d391234 | |||
| 9c8df89973 | |||
| 6f73c8afaa | |||
| e384c5a2a6 | |||
| edce608008 | |||
| ee6cfbd25a | |||
| cd3fb4e96b | |||
| e115b0155c | |||
| 59fc6604be | |||
| 725d174243 | |||
| e98567c2b9 | |||
| f34ae6b965 | |||
| 1ef50279fb | |||
| c0d8c4de3e | |||
| 5deea5664e |
@@ -329,7 +329,7 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
|||||||
| Package | Feature unlocked |
|
| Package | Feature unlocked |
|
||||||
|---------|-----------------|
|
|---------|-----------------|
|
||||||
| `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. |
|
| `faster-whisper` | Local speech-to-text (microphone -> text) via the "local" STT provider. |
|
||||||
| `duckduckgo-search` | DuckDuckGo as a search provider option. |
|
| `ddgs` | DuckDuckGo as a search provider option. |
|
||||||
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
||||||
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ from core.constants import (
|
|||||||
)
|
)
|
||||||
from core.database import SessionLocal, ApiToken
|
from core.database import SessionLocal, ApiToken
|
||||||
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager, normalize_known_username
|
||||||
from core.exceptions import (
|
from core.exceptions import (
|
||||||
SessionNotFoundError, InvalidFileUploadError,
|
SessionNotFoundError, InvalidFileUploadError,
|
||||||
LLMServiceError, WebSearchError,
|
LLMServiceError, WebSearchError,
|
||||||
@@ -228,8 +228,16 @@ if AUTH_ENABLED:
|
|||||||
try:
|
try:
|
||||||
rows = db.query(ApiToken).filter(ApiToken.is_active == True).all()
|
rows = db.query(ApiToken).filter(ApiToken.is_active == True).all()
|
||||||
for r in rows:
|
for r in rows:
|
||||||
|
owner_key = normalize_known_username(auth_manager.users, getattr(r, "owner", None))
|
||||||
|
if not owner_key:
|
||||||
|
logger.warning(
|
||||||
|
"Ignoring active API token '%s' for unknown auth user '%s'",
|
||||||
|
getattr(r, "id", ""),
|
||||||
|
getattr(r, "owner", None),
|
||||||
|
)
|
||||||
|
continue
|
||||||
scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()]
|
scopes = [s.strip() for s in (getattr(r, "scopes", "") or "chat").split(",") if s.strip()]
|
||||||
new_map[r.token_prefix].append((r.id, r.token_hash, getattr(r, "owner", None), scopes))
|
new_map[r.token_prefix].append((r.id, r.token_hash, owner_key, scopes))
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
_token_cache.clear()
|
_token_cache.clear()
|
||||||
@@ -490,11 +498,13 @@ app.state.session_manager = session_manager
|
|||||||
memory_manager = components["memory_manager"]
|
memory_manager = components["memory_manager"]
|
||||||
memory_vector = components.get("memory_vector")
|
memory_vector = components.get("memory_vector")
|
||||||
upload_handler = components["upload_handler"]
|
upload_handler = components["upload_handler"]
|
||||||
|
app.state.upload_handler = upload_handler
|
||||||
personal_docs_mgr = components["personal_docs_manager"]
|
personal_docs_mgr = components["personal_docs_manager"]
|
||||||
api_key_manager = components["api_key_manager"]
|
api_key_manager = components["api_key_manager"]
|
||||||
preset_manager = components["preset_manager"]
|
preset_manager = components["preset_manager"]
|
||||||
chat_processor = components["chat_processor"]
|
chat_processor = components["chat_processor"]
|
||||||
research_handler = components["research_handler"]
|
research_handler = components["research_handler"]
|
||||||
|
app.state.research_handler = research_handler
|
||||||
chat_handler = components["chat_handler"]
|
chat_handler = components["chat_handler"]
|
||||||
model_discovery = components["model_discovery"]
|
model_discovery = components["model_discovery"]
|
||||||
skills_manager = components["skills_manager"]
|
skills_manager = components["skills_manager"]
|
||||||
@@ -666,6 +676,9 @@ app.include_router(setup_shell_routes())
|
|||||||
from routes.cookbook_routes import setup_cookbook_routes
|
from routes.cookbook_routes import setup_cookbook_routes
|
||||||
app.include_router(setup_cookbook_routes())
|
app.include_router(setup_cookbook_routes())
|
||||||
|
|
||||||
|
from routes.workspace_routes import setup_workspace_routes
|
||||||
|
app.include_router(setup_workspace_routes())
|
||||||
|
|
||||||
# Hardware model fitting (cookbook "What Fits?" tab)
|
# Hardware model fitting (cookbook "What Fits?" tab)
|
||||||
from routes.hwfit_routes import setup_hwfit_routes
|
from routes.hwfit_routes import setup_hwfit_routes
|
||||||
app.include_router(setup_hwfit_routes())
|
app.include_router(setup_hwfit_routes())
|
||||||
@@ -938,10 +951,15 @@ async def _startup_event():
|
|||||||
async def _warmup_endpoints():
|
async def _warmup_endpoints():
|
||||||
try:
|
try:
|
||||||
import httpx
|
import httpx
|
||||||
endpoints = model_discovery.get_endpoints() if model_discovery else []
|
# model_discovery has no get_endpoints(); that call raised
|
||||||
for ep in endpoints[:5]:
|
# AttributeError every run and silently disabled warmup/keepalive.
|
||||||
url = ep.get("url", "").replace("/chat/completions", "/models")
|
# Resolve the /models probe URLs via the real discovery API, off the
|
||||||
if url:
|
# event loop since discovery does a blocking port scan.
|
||||||
|
urls = (
|
||||||
|
await asyncio.to_thread(model_discovery.warmup_ping_urls)
|
||||||
|
if model_discovery else []
|
||||||
|
)
|
||||||
|
for url in urls:
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||||
await client.get(url)
|
await client.get(url)
|
||||||
|
|||||||
+56
-13
@@ -67,6 +67,14 @@ TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
|||||||
RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"})
|
RESERVED_USERNAMES = frozenset({"internal-tool", "api", "demo", "system"})
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_known_username(users: Dict[str, Any], username: str | None) -> Optional[str]:
|
||||||
|
"""Return a normalized username only when it exists in the auth user map."""
|
||||||
|
key = str(username or "").strip().lower()
|
||||||
|
if not key or key not in users:
|
||||||
|
return None
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
def _hash_password(password: str) -> str:
|
||||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
@@ -96,6 +104,7 @@ class AuthManager:
|
|||||||
self._load()
|
self._load()
|
||||||
self._load_sessions()
|
self._load_sessions()
|
||||||
self._migrate_single_user()
|
self._migrate_single_user()
|
||||||
|
self._drop_reserved_loaded_users()
|
||||||
self._migrate_legacy_admin_role()
|
self._migrate_legacy_admin_role()
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
@@ -148,7 +157,13 @@ class AuthManager:
|
|||||||
def _migrate_single_user(self):
|
def _migrate_single_user(self):
|
||||||
"""Migrate old single-user format to multi-user format."""
|
"""Migrate old single-user format to multi-user format."""
|
||||||
if "password_hash" in self._config and "users" not in self._config:
|
if "password_hash" in self._config and "users" not in self._config:
|
||||||
old_user = self._config.get("username", "admin")
|
old_user = str(self._config.get("username", "admin") or "admin").strip().lower()
|
||||||
|
if old_user in RESERVED_USERNAMES:
|
||||||
|
logger.warning(
|
||||||
|
"Migrating legacy single-user reserved username '%s' to 'admin'",
|
||||||
|
old_user,
|
||||||
|
)
|
||||||
|
old_user = "admin"
|
||||||
old_hash = self._config["password_hash"]
|
old_hash = self._config["password_hash"]
|
||||||
self._config = {
|
self._config = {
|
||||||
"users": {
|
"users": {
|
||||||
@@ -162,6 +177,30 @@ class AuthManager:
|
|||||||
self._save()
|
self._save()
|
||||||
logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})")
|
logger.info(f"Migrated single-user auth to multi-user (admin: {old_user})")
|
||||||
|
|
||||||
|
def _drop_reserved_loaded_users(self):
|
||||||
|
"""Fail closed for legacy/manual auth rows that collide with sentinels."""
|
||||||
|
users = self._config.get("users")
|
||||||
|
if not isinstance(users, dict):
|
||||||
|
return
|
||||||
|
normalized = {}
|
||||||
|
removed = []
|
||||||
|
for username, data in users.items():
|
||||||
|
key = str(username or "").strip().lower()
|
||||||
|
if not key:
|
||||||
|
continue
|
||||||
|
if key in RESERVED_USERNAMES:
|
||||||
|
removed.append(key)
|
||||||
|
continue
|
||||||
|
normalized[key] = data
|
||||||
|
if removed or normalized != users:
|
||||||
|
self._config["users"] = normalized
|
||||||
|
self._save()
|
||||||
|
if removed:
|
||||||
|
logger.warning(
|
||||||
|
"Removed reserved username(s) from auth config: %s",
|
||||||
|
", ".join(sorted(set(removed))),
|
||||||
|
)
|
||||||
|
|
||||||
def _migrate_legacy_admin_role(self):
|
def _migrate_legacy_admin_role(self):
|
||||||
"""Normalize setup.py's old role='admin' marker to is_admin=True."""
|
"""Normalize setup.py's old role='admin' marker to is_admin=True."""
|
||||||
changed = False
|
changed = False
|
||||||
@@ -244,6 +283,22 @@ class AuthManager:
|
|||||||
return False
|
return False
|
||||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||||
return False
|
return False
|
||||||
|
# Revoke API bearer tokens before removing the auth row. The bearer
|
||||||
|
# path authenticates from ApiToken rows and does not require the
|
||||||
|
# owner to still exist, so a successful delete must not leave active
|
||||||
|
# rows behind. If the token store is unavailable, fail closed and
|
||||||
|
# keep the user/session state intact so the admin can retry.
|
||||||
|
try:
|
||||||
|
from core.database import get_db_session, ApiToken
|
||||||
|
with get_db_session() as db:
|
||||||
|
removed_tokens = db.query(ApiToken).filter(ApiToken.owner == username).delete()
|
||||||
|
if removed_tokens:
|
||||||
|
logger.info(
|
||||||
|
f"Revoked {removed_tokens} API token(s) owned by deleted user '{username}'"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
|
||||||
|
return False
|
||||||
del self._config["users"][username]
|
del self._config["users"][username]
|
||||||
self._save()
|
self._save()
|
||||||
# Purge all sessions belonging to this user. validate_token doesn't
|
# Purge all sessions belonging to this user. validate_token doesn't
|
||||||
@@ -258,18 +313,6 @@ class AuthManager:
|
|||||||
revoked += 1
|
revoked += 1
|
||||||
if revoked:
|
if revoked:
|
||||||
self._save_sessions()
|
self._save_sessions()
|
||||||
# Also revoke API bearer tokens owned by this user. The bearer auth
|
|
||||||
# path authenticates straight against ApiToken rows and never
|
|
||||||
# re-checks that the owner still exists, so leaving the rows behind
|
|
||||||
# would let a deleted user keep full API access indefinitely.
|
|
||||||
try:
|
|
||||||
from core.database import get_db_session, ApiToken
|
|
||||||
with get_db_session() as db:
|
|
||||||
removed = db.query(ApiToken).filter(ApiToken.owner == username).delete()
|
|
||||||
if removed:
|
|
||||||
logger.info(f"Revoked {removed} API token(s) owned by deleted user '{username}'")
|
|
||||||
except Exception:
|
|
||||||
logger.warning(f"Failed to revoke API tokens for deleted user '{username}'")
|
|
||||||
logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)")
|
logger.info(f"Deleted user '{username}' (by {requesting_user}); revoked {revoked} active session(s)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
+150
-25
@@ -688,6 +688,7 @@ def _migrate_add_last_message_at_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||||
@@ -713,10 +714,14 @@ def _migrate_add_last_message_at_column():
|
|||||||
"ON sessions(archived, last_message_at)"
|
"ON sessions(archived, last_message_at)"
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions")
|
logging.getLogger(__name__).info("Migrated: added + backfilled 'last_message_at' on sessions")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}")
|
logging.getLogger(__name__).warning(f"last_message_at migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_document_archived_column():
|
def _migrate_add_document_archived_column():
|
||||||
"""Add `archived` to documents (soft-archive flag). Guarded + idempotent."""
|
"""Add `archived` to documents (soft-archive flag). Guarded + idempotent."""
|
||||||
@@ -724,6 +729,7 @@ def _migrate_add_document_archived_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(documents)")
|
cursor = conn.execute("PRAGMA table_info(documents)")
|
||||||
@@ -732,9 +738,13 @@ def _migrate_add_document_archived_column():
|
|||||||
conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0")
|
conn.execute("ALTER TABLE documents ADD COLUMN archived BOOLEAN DEFAULT 0")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'archived' to documents")
|
logging.getLogger(__name__).info("Migrated: added 'archived' to documents")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}")
|
logging.getLogger(__name__).warning(f"documents.archived migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_owner_column():
|
def _migrate_add_owner_column():
|
||||||
@@ -743,6 +753,7 @@ def _migrate_add_owner_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||||
@@ -752,9 +763,13 @@ def _migrate_add_owner_column():
|
|||||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)")
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_sessions_owner ON sessions(owner)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions")
|
logging.getLogger(__name__).info("Migrated: added 'owner' column to sessions")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"Migration check failed: {e}")
|
logging.getLogger(__name__).warning(f"Migration check failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_model_endpoints():
|
def _migrate_model_endpoints():
|
||||||
"""Recreate model_endpoints table if schema changed (url->base_url)."""
|
"""Recreate model_endpoints table if schema changed (url->base_url)."""
|
||||||
@@ -762,6 +777,7 @@ def _migrate_model_endpoints():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -770,9 +786,13 @@ def _migrate_model_endpoints():
|
|||||||
conn.execute("DROP TABLE IF EXISTS model_endpoints")
|
conn.execute("DROP TABLE IF EXISTS model_endpoints")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)")
|
logging.getLogger(__name__).info("Migrated: dropped old model_endpoints table (schema change)")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints migration check failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_hidden_models_column():
|
def _migrate_add_hidden_models_column():
|
||||||
"""Add hidden_models column to model_endpoints if it doesn't exist."""
|
"""Add hidden_models column to model_endpoints if it doesn't exist."""
|
||||||
@@ -780,6 +800,7 @@ def _migrate_add_hidden_models_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -788,9 +809,13 @@ def _migrate_add_hidden_models_column():
|
|||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN hidden_models TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'hidden_models' column to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}")
|
logging.getLogger(__name__).warning(f"hidden_models migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_model_endpoint_owner_column():
|
def _migrate_add_model_endpoint_owner_column():
|
||||||
"""Add owner column to model_endpoints if it doesn't exist.
|
"""Add owner column to model_endpoints if it doesn't exist.
|
||||||
@@ -805,6 +830,7 @@ def _migrate_add_model_endpoint_owner_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -814,9 +840,13 @@ def _migrate_add_model_endpoint_owner_column():
|
|||||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)")
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_owner ON model_endpoints(owner)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'owner' column + index to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_provider_auth_id_column():
|
def _migrate_add_provider_auth_id_column():
|
||||||
@@ -825,6 +855,7 @@ def _migrate_add_provider_auth_id_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -834,9 +865,13 @@ def _migrate_add_provider_auth_id_column():
|
|||||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_model_type_column():
|
def _migrate_add_model_type_column():
|
||||||
@@ -845,6 +880,7 @@ def _migrate_add_model_type_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -853,9 +889,13 @@ def _migrate_add_model_type_column():
|
|||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_type TEXT DEFAULT 'llm'")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'model_type' column to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_type migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_model_endpoint_refresh_columns():
|
def _migrate_add_model_endpoint_refresh_columns():
|
||||||
"""Add endpoint classification / refresh policy columns if missing."""
|
"""Add endpoint classification / refresh policy columns if missing."""
|
||||||
@@ -863,6 +903,7 @@ def _migrate_add_model_endpoint_refresh_columns():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -876,9 +917,13 @@ def _migrate_add_model_endpoint_refresh_columns():
|
|||||||
if columns and "model_refresh_timeout" not in columns:
|
if columns and "model_refresh_timeout" not in columns:
|
||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN model_refresh_timeout INTEGER")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints refresh-policy migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_task_run_model_column():
|
def _migrate_add_task_run_model_column():
|
||||||
"""Add model column to task_runs if it doesn't exist (records which model ran)."""
|
"""Add model column to task_runs if it doesn't exist (records which model ran)."""
|
||||||
@@ -886,6 +931,7 @@ def _migrate_add_task_run_model_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(task_runs)")
|
cursor = conn.execute("PRAGMA table_info(task_runs)")
|
||||||
@@ -894,9 +940,13 @@ def _migrate_add_task_run_model_column():
|
|||||||
conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT")
|
conn.execute("ALTER TABLE task_runs ADD COLUMN model TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs")
|
logging.getLogger(__name__).info("Migrated: added 'model' column to task_runs")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}")
|
logging.getLogger(__name__).warning(f"task_runs model migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_supports_tools_column():
|
def _migrate_add_supports_tools_column():
|
||||||
"""Add supports_tools column to model_endpoints if it doesn't exist."""
|
"""Add supports_tools column to model_endpoints if it doesn't exist."""
|
||||||
@@ -904,6 +954,7 @@ def _migrate_add_supports_tools_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -912,9 +963,13 @@ def _migrate_add_supports_tools_column():
|
|||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN supports_tools BOOLEAN")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'supports_tools' column to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}")
|
logging.getLogger(__name__).warning(f"supports_tools migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_cached_models_column():
|
def _migrate_add_cached_models_column():
|
||||||
@@ -923,6 +978,7 @@ def _migrate_add_cached_models_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -930,9 +986,13 @@ def _migrate_add_cached_models_column():
|
|||||||
if columns and "cached_models" not in columns:
|
if columns and "cached_models" not in columns:
|
||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN cached_models TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_pinned_models_column():
|
def _migrate_add_pinned_models_column():
|
||||||
"""Add pinned_models column to model_endpoints if it doesn't exist."""
|
"""Add pinned_models column to model_endpoints if it doesn't exist."""
|
||||||
@@ -940,6 +1000,7 @@ def _migrate_add_pinned_models_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
@@ -948,9 +1009,13 @@ def _migrate_add_pinned_models_column():
|
|||||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
|
logging.getLogger(__name__).info("Migrated: added 'pinned_models' column to model_endpoints")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
|
logging.getLogger(__name__).warning(f"pinned_models migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_notes_sort_order():
|
def _migrate_add_notes_sort_order():
|
||||||
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
||||||
@@ -958,6 +1023,7 @@ def _migrate_add_notes_sort_order():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(notes)")
|
cursor = conn.execute("PRAGMA table_info(notes)")
|
||||||
@@ -975,9 +1041,13 @@ def _migrate_add_notes_sort_order():
|
|||||||
if columns and "agent_session_id" not in columns:
|
if columns and "agent_session_id" not in columns:
|
||||||
conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT")
|
conn.execute("ALTER TABLE notes ADD COLUMN agent_session_id TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"notes migration failed: {e}")
|
logging.getLogger(__name__).warning(f"notes migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_mode_column():
|
def _migrate_add_mode_column():
|
||||||
"""Add mode column to sessions table if it doesn't exist."""
|
"""Add mode column to sessions table if it doesn't exist."""
|
||||||
@@ -985,6 +1055,7 @@ def _migrate_add_mode_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||||
@@ -993,9 +1064,13 @@ def _migrate_add_mode_column():
|
|||||||
conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT")
|
conn.execute("ALTER TABLE sessions ADD COLUMN mode TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions")
|
logging.getLogger(__name__).info("Migrated: added 'mode' column to sessions")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}")
|
logging.getLogger(__name__).warning(f"Migration check for mode failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_folder_column():
|
def _migrate_add_folder_column():
|
||||||
"""Add folder column to sessions table if it doesn't exist."""
|
"""Add folder column to sessions table if it doesn't exist."""
|
||||||
@@ -1003,6 +1078,7 @@ def _migrate_add_folder_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||||
@@ -1011,9 +1087,13 @@ def _migrate_add_folder_column():
|
|||||||
conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT")
|
conn.execute("ALTER TABLE sessions ADD COLUMN folder TEXT")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions")
|
logging.getLogger(__name__).info("Migrated: added 'folder' column to sessions")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}")
|
logging.getLogger(__name__).warning(f"Migration check for folder failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_token_columns():
|
def _migrate_add_token_columns():
|
||||||
"""Add cumulative token tracking columns to sessions table."""
|
"""Add cumulative token tracking columns to sessions table."""
|
||||||
@@ -1021,6 +1101,7 @@ def _migrate_add_token_columns():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(sessions)")
|
cursor = conn.execute("PRAGMA table_info(sessions)")
|
||||||
@@ -1030,9 +1111,13 @@ def _migrate_add_token_columns():
|
|||||||
conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0")
|
conn.execute("ALTER TABLE sessions ADD COLUMN total_output_tokens INTEGER DEFAULT 0")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions")
|
logging.getLogger(__name__).info("Migrated: added token tracking columns to sessions")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}")
|
logging.getLogger(__name__).warning(f"Migration check for token columns failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
||||||
"""Generic helper: add owner TEXT column + index to a table if missing."""
|
"""Generic helper: add owner TEXT column + index to a table if missing."""
|
||||||
@@ -1040,6 +1125,7 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute(f"PRAGMA table_info({table_name})")
|
cursor = conn.execute(f"PRAGMA table_info({table_name})")
|
||||||
@@ -1049,9 +1135,13 @@ def _migrate_add_owner_to_table(table_name: str, index_name: str):
|
|||||||
conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)")
|
conn.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}(owner)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}")
|
logging.getLogger(__name__).info(f"Migrated: added 'owner' column to {table_name}")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}")
|
logging.getLogger(__name__).warning(f"Migration owner column for {table_name} failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_add_multiuser_owner_columns():
|
def _migrate_add_multiuser_owner_columns():
|
||||||
"""Add owner column to memories, gallery_images, user_tools, comparisons."""
|
"""Add owner column to memories, gallery_images, user_tools, comparisons."""
|
||||||
@@ -1076,6 +1166,7 @@ def _migrate_add_api_token_scopes_column():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()]
|
columns = [row[1] for row in conn.execute("PRAGMA table_info(api_tokens)").fetchall()]
|
||||||
@@ -1084,9 +1175,13 @@ def _migrate_add_api_token_scopes_column():
|
|||||||
conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''")
|
conn.execute("UPDATE api_tokens SET scopes = 'chat' WHERE scopes IS NULL OR scopes = ''")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens")
|
logging.getLogger(__name__).info("Migrated: added scopes column to api_tokens")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}")
|
logging.getLogger(__name__).warning(f"api_tokens.scopes migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _migrate_assign_legacy_owner():
|
def _migrate_assign_legacy_owner():
|
||||||
"""Assign all null-owner data to the first (admin) user.
|
"""Assign all null-owner data to the first (admin) user.
|
||||||
@@ -1128,6 +1223,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
# Every table with an `owner` column. New tables added later will be
|
# Every table with an `owner` column. New tables added later will be
|
||||||
@@ -1152,9 +1248,13 @@ def _migrate_assign_legacy_owner():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Legacy owner assignment for {table} failed: {e}")
|
logger.warning(f"Legacy owner assignment for {table} failed: {e}")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Legacy owner migration failed: {e}")
|
logger.warning(f"Legacy owner migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Also migrate memory.json
|
# Also migrate memory.json
|
||||||
mem_path = MEMORY_FILE
|
mem_path = MEMORY_FILE
|
||||||
@@ -1773,6 +1873,7 @@ def _migrate_add_email_smtp_security():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(email_accounts)")
|
cursor = conn.execute("PRAGMA table_info(email_accounts)")
|
||||||
@@ -1788,9 +1889,13 @@ def _migrate_add_email_smtp_security():
|
|||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts")
|
logging.getLogger(__name__).info("Migrated: added smtp_security column to email_accounts")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}")
|
logging.getLogger(__name__).warning(f"smtp_security migration skipped: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_encrypt_endpoint_keys():
|
def _migrate_encrypt_endpoint_keys():
|
||||||
@@ -1891,6 +1996,7 @@ def _migrate_add_calendar_is_utc():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||||
@@ -1899,9 +2005,13 @@ def _migrate_add_calendar_is_utc():
|
|||||||
conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL")
|
conn.execute("ALTER TABLE calendar_events ADD COLUMN is_utc BOOLEAN DEFAULT 0 NOT NULL")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events")
|
logging.getLogger(__name__).info("Migrated: added 'is_utc' column to calendar_events")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"is_utc migration failed: {e}")
|
logging.getLogger(__name__).warning(f"is_utc migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_calendar_origin():
|
def _migrate_add_calendar_origin():
|
||||||
@@ -1912,6 +2022,7 @@ def _migrate_add_calendar_origin():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||||
@@ -1921,9 +2032,13 @@ def _migrate_add_calendar_origin():
|
|||||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)")
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events")
|
logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_calendar_account_id():
|
def _migrate_add_calendar_account_id():
|
||||||
@@ -1933,6 +2048,7 @@ def _migrate_add_calendar_account_id():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(calendars)")
|
cursor = conn.execute("PRAGMA table_info(calendars)")
|
||||||
@@ -1942,9 +2058,13 @@ def _migrate_add_calendar_account_id():
|
|||||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
|
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
|
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_calendar_metadata():
|
def _migrate_add_calendar_metadata():
|
||||||
@@ -1953,6 +2073,7 @@ def _migrate_add_calendar_metadata():
|
|||||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
if not os.path.exists(db_path):
|
if not os.path.exists(db_path):
|
||||||
return
|
return
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path)
|
||||||
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
cursor = conn.execute("PRAGMA table_info(calendar_events)")
|
||||||
@@ -1964,9 +2085,13 @@ def _migrate_add_calendar_metadata():
|
|||||||
if columns and "last_pinged" not in columns:
|
if columns and "last_pinged" not in columns:
|
||||||
conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME")
|
conn.execute("ALTER TABLE calendar_events ADD COLUMN last_pinged DATETIME")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}")
|
logging.getLogger(__name__).warning(f"calendar_events migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -191,6 +191,8 @@ def _windows_bash_fallbacks() -> List[str]:
|
|||||||
base = os.environ.get(env_name)
|
base = os.environ.get(env_name)
|
||||||
if base:
|
if base:
|
||||||
roots.append(ntpath.join(base, "Git"))
|
roots.append(ntpath.join(base, "Git"))
|
||||||
|
if env_name == "LocalAppData":
|
||||||
|
roots.append(ntpath.join(base, "Programs", "Git"))
|
||||||
roots.extend(_WINDOWS_BASH_DEFAULT_ROOTS)
|
roots.extend(_WINDOWS_BASH_DEFAULT_ROOTS)
|
||||||
|
|
||||||
paths: List[str] = []
|
paths: List[str] = []
|
||||||
@@ -366,6 +368,10 @@ def _ssh_exec_argv(
|
|||||||
strict_host_key_checking: bool | None = None,
|
strict_host_key_checking: bool | None = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Build a consistent ssh argv for remote command execution."""
|
"""Build a consistent ssh argv for remote command execution."""
|
||||||
|
remote_value = str(remote or "").strip()
|
||||||
|
remote_host = remote_value.rsplit("@", 1)[-1]
|
||||||
|
if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"):
|
||||||
|
raise ValueError("Invalid SSH remote host")
|
||||||
argv = ["ssh"]
|
argv = ["ssh"]
|
||||||
if connect_timeout is not None:
|
if connect_timeout is not None:
|
||||||
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||||
|
|||||||
+10
-3
@@ -25,9 +25,16 @@
|
|||||||
--radius: 8px;
|
--radius: 8px;
|
||||||
}
|
}
|
||||||
* { box-sizing: border-box; }
|
* { box-sizing: border-box; }
|
||||||
html { scroll-behavior: smooth; scroll-snap-type: y proximity; scroll-padding-top: 60px; }
|
html { scroll-behavior: smooth; scroll-padding-top: 60px; }
|
||||||
/* Each section is a full-viewport "page" with its content centered, so only
|
/* REMOVED: "scroll-snap-type: y proximity"
|
||||||
one shows at a time and the snap is obvious. */
|
The idea was: >>Each section is a full-viewport "page" with its content centered,
|
||||||
|
so only one shows at a time and the snap is obvious.<<
|
||||||
|
|
||||||
|
PROBLEM: sections easily grow taller than 100vh IRL
|
||||||
|
This cause forced jumps mid-read. It's intrusive UX.
|
||||||
|
The landing-page is not a PowerPoint presentation!
|
||||||
|
|
||||||
|
Preserved: CSS snap-points to avoid destroying code meta-data*/
|
||||||
.hero, section {
|
.hero, section {
|
||||||
scroll-snap-align: start; min-height: 100vh;
|
scroll-snap-align: start; min-height: 100vh;
|
||||||
display: flex; flex-direction: column; justify-content: center;
|
display: flex; flex-direction: column; justify-content: center;
|
||||||
|
|||||||
+14
-2
@@ -30,14 +30,26 @@ function Fail($msg) {
|
|||||||
exit 1
|
exit 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function Test-WindowsBashStub($path) {
|
||||||
|
if (-not $path) { return $false }
|
||||||
|
$lowered = $path.ToLowerInvariant()
|
||||||
|
foreach ($stub in @("system32\bash.exe", "sysnative\bash.exe", "windowsapps\bash.exe")) {
|
||||||
|
if ($lowered.Contains($stub)) { return $true }
|
||||||
|
}
|
||||||
|
return $false
|
||||||
|
}
|
||||||
|
|
||||||
function Find-GitBash {
|
function Find-GitBash {
|
||||||
$cmd = Get-Command bash -ErrorAction SilentlyContinue
|
$cmd = Get-Command bash -ErrorAction SilentlyContinue
|
||||||
if ($cmd) { return $cmd.Source }
|
if ($cmd -and -not (Test-WindowsBashStub $cmd.Source)) { return $cmd.Source }
|
||||||
|
|
||||||
$roots = @()
|
$roots = @()
|
||||||
foreach ($name in @("ProgramFiles", "ProgramW6432", "ProgramFiles(x86)", "LocalAppData")) {
|
foreach ($name in @("ProgramFiles", "ProgramW6432", "ProgramFiles(x86)", "LocalAppData")) {
|
||||||
$base = [Environment]::GetEnvironmentVariable($name)
|
$base = [Environment]::GetEnvironmentVariable($name)
|
||||||
if ($base) { $roots += (Join-Path $base "Git") }
|
if ($base) {
|
||||||
|
$roots += (Join-Path $base "Git")
|
||||||
|
if ($name -eq "LocalAppData") { $roots += (Join-Path $base "Programs\Git") }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
$roots += @("C:\Program Files\Git", "C:\Program Files (x86)\Git")
|
$roots += @("C:\Program Files\Git", "C:\Program Files (x86)\Git")
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ faster-whisper
|
|||||||
# DuckDuckGo as a search provider option.
|
# DuckDuckGo as a search provider option.
|
||||||
# Install if you want DDG in the search-provider dropdown.
|
# Install if you want DDG in the search-provider dropdown.
|
||||||
# Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE.
|
# Alternatives: SearXNG, Brave, Tavily, Serper, Google PSE.
|
||||||
duckduckgo-search
|
ddgs
|
||||||
|
|
||||||
# PDF form-filling feature (fillable AcroForm detection, field extraction,
|
# PDF form-filling feature (fillable AcroForm detection, field extraction,
|
||||||
# value/annotation/signature stamping, page rendering for the form overlay).
|
# value/annotation/signature stamping, page rendering for the form overlay).
|
||||||
|
|||||||
@@ -43,3 +43,7 @@ qrcode[pil]
|
|||||||
croniter
|
croniter
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
|
# starlette.testclient prefers httpx2 since Starlette 1.2.0 and warns on every
|
||||||
|
# TestClient import when only classic httpx is present. Runtime code keeps
|
||||||
|
# using `httpx` above; this is test-client only.
|
||||||
|
httpx2
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
_REMOTE_HOST_RE = re.compile(
|
||||||
|
r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$"
|
||||||
|
)
|
||||||
|
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_remote_host(v: str | None) -> str | None:
|
||||||
|
if v is None or v == "":
|
||||||
|
return None
|
||||||
|
if not _REMOTE_HOST_RE.match(v):
|
||||||
|
raise HTTPException(
|
||||||
|
400,
|
||||||
|
"Invalid remote_host — must be host or user@host, no SSH option syntax",
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ssh_port(v: str | None) -> str | None:
|
||||||
|
if v is None or v == "":
|
||||||
|
return None
|
||||||
|
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||||
|
raise HTTPException(400, "Invalid ssh_port")
|
||||||
|
port = int(v)
|
||||||
|
if port < 1 or port > 65535:
|
||||||
|
raise HTTPException(400, "Invalid ssh_port")
|
||||||
|
return str(port)
|
||||||
@@ -154,6 +154,7 @@ def setup_api_token_routes() -> APIRouter:
|
|||||||
@router.patch("/tokens/{token_id}")
|
@router.patch("/tokens/{token_id}")
|
||||||
async def update_token(request: Request, token_id: str):
|
async def update_token(request: Request, token_id: str):
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
|
current_user = get_current_user(request)
|
||||||
try:
|
try:
|
||||||
payload = await request.json()
|
payload = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -162,6 +163,8 @@ def setup_api_token_routes() -> APIRouter:
|
|||||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(404, "Token not found")
|
raise HTTPException(404, "Token not found")
|
||||||
|
if current_user and token.owner != current_user:
|
||||||
|
raise HTTPException(403, "Not your token")
|
||||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||||
# Only touch scopes when the caller actually sent them. A partial
|
# Only touch scopes when the caller actually sent them. A partial
|
||||||
@@ -189,10 +192,14 @@ def setup_api_token_routes() -> APIRouter:
|
|||||||
@router.delete("/tokens/{token_id}")
|
@router.delete("/tokens/{token_id}")
|
||||||
def delete_token(request: Request, token_id: str):
|
def delete_token(request: Request, token_id: str):
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
|
current_user = get_current_user(request)
|
||||||
with get_db_session() as db:
|
with get_db_session() as db:
|
||||||
deleted = db.query(ApiToken).filter(ApiToken.id == token_id).delete()
|
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||||
if not deleted:
|
if not token:
|
||||||
raise HTTPException(404, "Token not found")
|
raise HTTPException(404, "Token not found")
|
||||||
|
if current_user and token.owner != current_user:
|
||||||
|
raise HTTPException(403, "Not your token")
|
||||||
|
db.delete(token)
|
||||||
_invalidate_cache(request)
|
_invalidate_cache(request)
|
||||||
return {"status": "deleted"}
|
return {"status": "deleted"}
|
||||||
|
|
||||||
|
|||||||
+65
-10
@@ -305,6 +305,19 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
if not ok:
|
if not ok:
|
||||||
raise HTTPException(400, "Cannot rename user")
|
raise HTTPException(400, "Cannot rename user")
|
||||||
|
|
||||||
|
def _rollback_auth_rename() -> bool:
|
||||||
|
# On self-rename the admin session has already moved to the new
|
||||||
|
# username, so the rollback must authenticate as the new user.
|
||||||
|
rollback_user = new_username if user == old_username else user
|
||||||
|
try:
|
||||||
|
return bool(auth_manager.rename_user(new_username, old_username, rollback_user))
|
||||||
|
except Exception as rollback_err:
|
||||||
|
logger.error(
|
||||||
|
"Failed to roll back auth rename %s -> %s after owner migration failure: %s",
|
||||||
|
new_username, old_username, rollback_err,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# Usernames are ownership keys for user data. Rename the common
|
# Usernames are ownership keys for user data. Rename the common
|
||||||
# owner-scoped DB rows so the account keeps access to its sessions,
|
# owner-scoped DB rows so the account keeps access to its sessions,
|
||||||
# docs, email accounts, tasks, etc.
|
# docs, email accounts, tasks, etc.
|
||||||
@@ -330,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e)
|
logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e)
|
||||||
|
if not _rollback_auth_rename():
|
||||||
|
logger.error(
|
||||||
|
"Auth rename %s -> %s could not be rolled back after owner migration failure",
|
||||||
|
old_username, new_username,
|
||||||
|
)
|
||||||
raise HTTPException(500, "Failed to rename user data")
|
raise HTTPException(500, "Failed to rename user data")
|
||||||
|
|
||||||
# Per-user prefs are JSON-backed, not SQL-backed.
|
# Per-user prefs are JSON-backed, not SQL-backed.
|
||||||
@@ -349,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
||||||
|
|
||||||
|
# In-flight deep-research tasks live in the process-local
|
||||||
|
# ResearchHandler registry. They are not covered by the persisted JSON
|
||||||
|
# migration above, but the research routes filter and cancel by this
|
||||||
|
# owner field while the job is running. Do this before sweeping
|
||||||
|
# completed JSON files so a job that finishes during the rename saves
|
||||||
|
# with the new owner or is caught by the disk sweep below.
|
||||||
|
try:
|
||||||
|
rh = getattr(request.app.state, "research_handler", None)
|
||||||
|
rename_owner = getattr(rh, "rename_owner", None)
|
||||||
|
if callable(rename_owner):
|
||||||
|
rename_owner(old_username, new_username)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e)
|
||||||
|
|
||||||
# deep_research: each completed report is a standalone JSON file with
|
# deep_research: each completed report is a standalone JSON file with
|
||||||
# an `owner` field. research_routes filters by d.get("owner") == user,
|
# an `owner` field. research_routes filters by d.get("owner") == user,
|
||||||
# so a stale owner makes every report invisible to the renamed user.
|
# so a stale owner makes every report invisible to the renamed user.
|
||||||
@@ -384,6 +416,17 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to rename memory.json owner references %s -> %s: %s", old_username, new_username, e)
|
logger.warning("Failed to rename memory.json owner references %s -> %s: %s", old_username, new_username, e)
|
||||||
|
|
||||||
|
# uploads.json: upload rows use owner metadata for access checks and
|
||||||
|
# owner-prefixed index keys for dedupe. Rename both so attachments keep
|
||||||
|
# resolving after the account username changes.
|
||||||
|
try:
|
||||||
|
upload_handler = getattr(request.app.state, "upload_handler", None)
|
||||||
|
rename_owner = getattr(upload_handler, "rename_owner", None)
|
||||||
|
if callable(rename_owner):
|
||||||
|
rename_owner(old_username, new_username)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to rename upload owner references %s -> %s: %s", old_username, new_username, e)
|
||||||
|
|
||||||
# skills: SKILL.md frontmatter carries owner: <username>; the usage
|
# skills: SKILL.md frontmatter carries owner: <username>; the usage
|
||||||
# sidecar (_usage.json) keys entries as owner::skill-name. Both must
|
# sidecar (_usage.json) keys entries as owner::skill-name. Both must
|
||||||
# be updated or the renamed user's Skills panel goes empty.
|
# be updated or the renamed user's Skills panel goes empty.
|
||||||
@@ -391,7 +434,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
skills_root = Path(SKILLS_DIR)
|
skills_root = Path(SKILLS_DIR)
|
||||||
if skills_root.is_dir():
|
if skills_root.is_dir():
|
||||||
_owner_re = re.compile(
|
_owner_re = re.compile(
|
||||||
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$'
|
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$',
|
||||||
|
re.IGNORECASE,
|
||||||
)
|
)
|
||||||
for p in skills_root.rglob("SKILL.md"):
|
for p in skills_root.rglob("SKILL.md"):
|
||||||
try:
|
try:
|
||||||
@@ -406,12 +450,12 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
usage = json.loads(usage_path.read_text(encoding="utf-8"))
|
usage = json.loads(usage_path.read_text(encoding="utf-8"))
|
||||||
if isinstance(usage, dict):
|
if isinstance(usage, dict):
|
||||||
prefix = old_username + "::"
|
|
||||||
new_usage = {}
|
new_usage = {}
|
||||||
changed = False
|
changed = False
|
||||||
for k, v in usage.items():
|
for k, v in usage.items():
|
||||||
if k.startswith(prefix):
|
owner_part, sep, skill_part = k.partition("::")
|
||||||
new_usage[new_username + "::" + k[len(prefix):]] = v
|
if sep and owner_part.lower() == old_username:
|
||||||
|
new_usage[new_username + "::" + skill_part] = v
|
||||||
changed = True
|
changed = True
|
||||||
else:
|
else:
|
||||||
new_usage[k] = v
|
new_usage[k] = v
|
||||||
@@ -473,7 +517,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
user = _get_current_user(request)
|
user = _get_current_user(request)
|
||||||
if not user or not auth_manager.is_admin(user):
|
if not user or not auth_manager.is_admin(user):
|
||||||
raise HTTPException(403, "Admin only")
|
raise HTTPException(403, "Admin only")
|
||||||
|
|
||||||
|
def _invalidate_api_token_cache():
|
||||||
|
try:
|
||||||
|
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||||
|
if invalidator:
|
||||||
|
invalidator()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
ok = auth_manager.delete_user(body.username, user)
|
ok = auth_manager.delete_user(body.username, user)
|
||||||
|
except Exception:
|
||||||
|
# delete_user can touch ApiToken rows before a later auth-store write
|
||||||
|
# fails. Dirty the bearer cache anyway so a partial token purge does
|
||||||
|
# not leave already-cached tokens authenticating until restart.
|
||||||
|
_invalidate_api_token_cache()
|
||||||
|
raise
|
||||||
if not ok:
|
if not ok:
|
||||||
raise HTTPException(400, "Cannot delete user")
|
raise HTTPException(400, "Cannot delete user")
|
||||||
# delete_user removes the user's ApiToken rows, but the bearer-auth
|
# delete_user removes the user's ApiToken rows, but the bearer-auth
|
||||||
@@ -481,12 +541,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
# rebuilds when flagged dirty. Without this, a deleted user's already
|
# rebuilds when flagged dirty. Without this, a deleted user's already
|
||||||
# cached token keeps authenticating until some other token op or a
|
# cached token keeps authenticating until some other token op or a
|
||||||
# restart clears the cache. Mirror what the token routes do.
|
# restart clears the cache. Mirror what the token routes do.
|
||||||
try:
|
_invalidate_api_token_cache()
|
||||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
|
||||||
if invalidator:
|
|
||||||
invalidator()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
# ---- Feature visibility (admin-managed) ----
|
# ---- Feature visibility (admin-managed) ----
|
||||||
|
|||||||
@@ -62,6 +62,33 @@ def _stream_set(session_id: str, **fields) -> None:
|
|||||||
rec.update(fields)
|
rec.update(fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_request_workspace(request, raw_value) -> tuple:
|
||||||
|
"""Resolve the posted workspace for this request: (workspace, rejected).
|
||||||
|
|
||||||
|
Privilege is checked BEFORE the path ever touches the filesystem. Only
|
||||||
|
admin/single-user callers can use the workspace-backed file/shell tools,
|
||||||
|
so only they get vet_workspace() and the workspace_rejected signal. For
|
||||||
|
any other caller the submitted value is dropped uniformly, with no vetting
|
||||||
|
and no event: otherwise the presence/absence of workspace_rejected would
|
||||||
|
let a non-admin chat caller probe which host paths exist.
|
||||||
|
|
||||||
|
vet_workspace rejects non-directories, sensitive roots (.ssh, .gnupg,
|
||||||
|
...), and filesystem roots; on rejection there is no confinement and the
|
||||||
|
default tool-path allowlist applies. The rejected value is surfaced so the
|
||||||
|
stream can tell an admin client (which believes a workspace is active)
|
||||||
|
that it was dropped.
|
||||||
|
"""
|
||||||
|
requested = (raw_value or "").strip()
|
||||||
|
if not requested:
|
||||||
|
return "", ""
|
||||||
|
from src.tool_security import owner_is_admin_or_single_user
|
||||||
|
if not owner_is_admin_or_single_user(get_current_user(request)):
|
||||||
|
return "", ""
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
workspace = vet_workspace(requested) or ""
|
||||||
|
return workspace, (requested if not workspace else "")
|
||||||
|
|
||||||
|
|
||||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||||
if not session_url or not endpoint_base:
|
if not session_url or not endpoint_base:
|
||||||
return False
|
return False
|
||||||
@@ -457,6 +484,10 @@ def setup_chat_routes(
|
|||||||
# manual form posts that still send plan_mode=true.
|
# manual form posts that still send plan_mode=true.
|
||||||
plan_mode = False
|
plan_mode = False
|
||||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||||
|
# Workspace: confine the agent's file/shell tools to this folder.
|
||||||
|
workspace, workspace_rejected = _resolve_request_workspace(
|
||||||
|
request, form_data.get("workspace")
|
||||||
|
)
|
||||||
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||||
if plan_mode:
|
if plan_mode:
|
||||||
chat_mode = "agent"
|
chat_mode = "agent"
|
||||||
@@ -761,6 +792,13 @@ def setup_chat_routes(
|
|||||||
# Register active stream for partial-save safety net
|
# Register active stream for partial-save safety net
|
||||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||||
|
|
||||||
|
# The client sent a workspace the server refused to bind (deleted
|
||||||
|
# folder, file path, sensitive dir, filesystem root). Tell it up
|
||||||
|
# front so the UI can clear the pill instead of displaying a
|
||||||
|
# confinement that is not actually in effect.
|
||||||
|
if workspace_rejected:
|
||||||
|
yield f"data: {json.dumps({'type': 'workspace_rejected', 'data': {'path': workspace_rejected}})}\n\n"
|
||||||
|
|
||||||
if ctx.preprocessed.attachment_meta:
|
if ctx.preprocessed.attachment_meta:
|
||||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||||
|
|
||||||
@@ -1138,6 +1176,7 @@ def setup_chat_routes(
|
|||||||
fallbacks=_fallback_candidates,
|
fallbacks=_fallback_candidates,
|
||||||
plan_mode=plan_mode,
|
plan_mode=plan_mode,
|
||||||
approved_plan=approved_plan or None,
|
approved_plan=approved_plan or None,
|
||||||
|
workspace=workspace or None,
|
||||||
):
|
):
|
||||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -729,8 +729,11 @@ def setup_contacts_routes():
|
|||||||
@router.post("/import")
|
@router.post("/import")
|
||||||
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
|
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
|
||||||
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
|
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
|
||||||
text = data.get("vcf") or data.get("text") or ""
|
# Coerce defensively: a non-string vcf/text/csv (e.g. a number or list
|
||||||
csv_text = data.get("csv") or ""
|
# in the JSON body) would otherwise reach .strip() and 500 with an
|
||||||
|
# AttributeError instead of degrading to a clean "no data" response.
|
||||||
|
text = str(data.get("vcf") or data.get("text") or "")
|
||||||
|
csv_text = str(data.get("csv") or "")
|
||||||
if text.strip():
|
if text.strip():
|
||||||
if "BEGIN:VCARD" not in text.upper():
|
if "BEGIN:VCARD" not in text.upper():
|
||||||
return {"success": False, "error": "No vCard data found"}
|
return {"success": False, "error": "No vCard data found"}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import shlex
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from routes._validators import validate_remote_host, validate_ssh_port
|
||||||
from core.platform_compat import _ssh_exec_argv
|
from core.platform_compat import _ssh_exec_argv
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
|||||||
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
|
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
|
||||||
# Include pattern is a glob: allow typical safe glyphs only.
|
# Include pattern is a glob: allow typical safe glyphs only.
|
||||||
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
||||||
# Remote host: either `user@host` or plain `host` (alias is allowed), where host
|
|
||||||
# is a safe DNS-like token or a short SSH config alias.
|
|
||||||
_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$")
|
|
||||||
# HF tokens and API tokens are url-safe base64-like.
|
# HF tokens and API tokens are url-safe base64-like.
|
||||||
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
||||||
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
||||||
# Anything beyond plain alphanumerics + dash + underscore could break out
|
# Anything beyond plain alphanumerics + dash + underscore could break out
|
||||||
# of the shell/PowerShell contexts the value lands in.
|
# of the shell/PowerShell contexts the value lands in.
|
||||||
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
||||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
|
||||||
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
||||||
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
||||||
# only (no quotes or shell metacharacters). Spaces are allowed because command
|
# only (no quotes or shell metacharacters). Spaces are allowed because command
|
||||||
@@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None:
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def _validate_remote_host(v: str | None) -> str | None:
|
|
||||||
if v is None or v == "":
|
|
||||||
return None
|
|
||||||
if not _REMOTE_HOST_RE.match(v):
|
|
||||||
raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax")
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_token(v: str | None) -> str | None:
|
def _validate_token(v: str | None) -> str | None:
|
||||||
if v is None or v == "":
|
if v is None or v == "":
|
||||||
return None
|
return None
|
||||||
@@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None:
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def _validate_ssh_port(v: str | None) -> str | None:
|
|
||||||
if v is None or v == "":
|
|
||||||
return None
|
|
||||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
port = int(v)
|
|
||||||
if port < 1 or port > 65535:
|
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
return str(port)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_gpus(v: str | None) -> str | None:
|
def _validate_gpus(v: str | None) -> str | None:
|
||||||
if v is None or v == "":
|
if v is None or v == "":
|
||||||
return None
|
return None
|
||||||
|
|||||||
+38
-26
@@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
from routes._validators import validate_remote_host, validate_ssh_port
|
||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
@@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from routes.cookbook_helpers import (
|
from routes.cookbook_helpers import (
|
||||||
_SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE,
|
_SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token,
|
||||||
_validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token,
|
_validate_local_dir, _validate_gpus, _shell_path,
|
||||||
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
|
|
||||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||||
@@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
else:
|
else:
|
||||||
_validate_repo_id(req.repo_id)
|
_validate_repo_id(req.repo_id)
|
||||||
_validate_include(req.include)
|
_validate_include(req.include)
|
||||||
_validate_remote_host(req.remote_host)
|
validate_remote_host(req.remote_host)
|
||||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||||
req.local_dir = _validate_local_dir(req.local_dir)
|
req.local_dir = _validate_local_dir(req.local_dir)
|
||||||
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
|
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
|
||||||
_validate_token(req.hf_token)
|
_validate_token(req.hf_token)
|
||||||
@@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
|
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
|
||||||
# `host`/`ssh_port` are interpolated into an ssh command below, so an
|
# `host`/`ssh_port` are interpolated into an ssh command below, so an
|
||||||
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
|
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
|
||||||
host = _validate_remote_host(host)
|
host = validate_remote_host(host)
|
||||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
ssh_port = validate_ssh_port(ssh_port)
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
model_dirs = []
|
model_dirs = []
|
||||||
@@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# listening" check without requiring ss/netstat/nmap.
|
# listening" check without requiring ss/netstat/nmap.
|
||||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||||
if ssh_port and str(ssh_port) != "22":
|
if ssh_port and str(ssh_port) != "22":
|
||||||
if not _SSH_PORT_RE.match(str(ssh_port)):
|
try:
|
||||||
|
ssh_port = validate_ssh_port(ssh_port)
|
||||||
|
except HTTPException:
|
||||||
return None
|
return None
|
||||||
ssh_base.extend(["-p", str(ssh_port)])
|
ssh_base.extend(["-p", str(ssh_port)])
|
||||||
host_arg = remote
|
try:
|
||||||
if not _REMOTE_HOST_RE.match(host_arg):
|
host_arg = validate_remote_host(remote)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
if not host_arg:
|
||||||
return None
|
return None
|
||||||
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
|
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
|
||||||
script = (
|
script = (
|
||||||
@@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
# Defence-in-depth: reject values that could break out of shell contexts.
|
# Defence-in-depth: reject values that could break out of shell contexts.
|
||||||
_validate_remote_host(req.remote_host)
|
validate_remote_host(req.remote_host)
|
||||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||||
req.gpus = _validate_gpus(req.gpus)
|
req.gpus = _validate_gpus(req.gpus)
|
||||||
req.hf_token = req.hf_token or _load_stored_hf_token()
|
req.hf_token = req.hf_token or _load_stored_hf_token()
|
||||||
_validate_token(req.hf_token)
|
_validate_token(req.hf_token)
|
||||||
@@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
async def server_setup(request: Request, req: SetupRequest):
|
async def server_setup(request: Request, req: SetupRequest):
|
||||||
"""Install required dependencies on a remote server via SSH."""
|
"""Install required dependencies on a remote server via SSH."""
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
host = _validate_remote_host(req.host)
|
host = validate_remote_host(req.host)
|
||||||
if not host:
|
if not host:
|
||||||
raise HTTPException(400, "host is required")
|
raise HTTPException(400, "host is required")
|
||||||
port = req.ssh_port
|
port = req.ssh_port
|
||||||
if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port):
|
port = validate_ssh_port(port)
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
pf = f"-p {port} " if port and port != "22" else ""
|
pf = f"-p {port} " if port and port != "22" else ""
|
||||||
|
|
||||||
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
|
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
|
||||||
@@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
`busy` is True when free_mb/total_mb < 0.5.
|
`busy` is True when free_mb/total_mb < 0.5.
|
||||||
"""
|
"""
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
host = _validate_remote_host(host)
|
host = validate_remote_host(host)
|
||||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
ssh_port = validate_ssh_port(ssh_port)
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
|
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
|
||||||
nvidia_error = None
|
nvidia_error = None
|
||||||
try:
|
try:
|
||||||
@@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
sig = (req.signal or "TERM").upper()
|
sig = (req.signal or "TERM").upper()
|
||||||
if sig not in ("TERM", "KILL", "INT"):
|
if sig not in ("TERM", "KILL", "INT"):
|
||||||
raise HTTPException(400, "signal must be TERM, KILL, or INT")
|
raise HTTPException(400, "signal must be TERM, KILL, or INT")
|
||||||
host = _validate_remote_host(req.host)
|
host = validate_remote_host(req.host)
|
||||||
if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port):
|
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
|
||||||
kill_cmd = f"kill -{sig} {req.pid}"
|
kill_cmd = f"kill -{sig} {req.pid}"
|
||||||
try:
|
try:
|
||||||
if host:
|
if host:
|
||||||
@@ -2382,13 +2383,18 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
host = (srv.get("host") or "").strip()
|
host = (srv.get("host") or "").strip()
|
||||||
if not host:
|
if not host:
|
||||||
continue # local-only entry; the /proc scan handles it
|
continue # local-only entry; the /proc scan handles it
|
||||||
if not _REMOTE_HOST_RE.match(host):
|
try:
|
||||||
|
host = validate_remote_host(host)
|
||||||
|
except HTTPException:
|
||||||
continue
|
continue
|
||||||
sport = str(srv.get("port") or "").strip()
|
sport = str(srv.get("port") or "").strip()
|
||||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||||
if sport and sport != "22":
|
if sport and sport != "22":
|
||||||
if not _SSH_PORT_RE.match(sport):
|
try:
|
||||||
|
sport = validate_ssh_port(sport)
|
||||||
|
except HTTPException:
|
||||||
continue
|
continue
|
||||||
|
if sport != "22":
|
||||||
ssh_base.extend(["-p", sport])
|
ssh_base.extend(["-p", sport])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -2743,10 +2749,16 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
if not _SESSION_ID_RE.match(session_id):
|
if not _SESSION_ID_RE.match(session_id):
|
||||||
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
|
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
|
||||||
continue
|
continue
|
||||||
if remote and not _REMOTE_HOST_RE.match(remote):
|
if remote:
|
||||||
|
try:
|
||||||
|
remote = validate_remote_host(remote)
|
||||||
|
except HTTPException:
|
||||||
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
|
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
|
||||||
continue
|
continue
|
||||||
if _tport and not _SSH_PORT_RE.match(str(_tport)):
|
if _tport:
|
||||||
|
try:
|
||||||
|
_tport = validate_ssh_port(str(_tport))
|
||||||
|
except HTTPException:
|
||||||
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
|
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
|
||||||
continue
|
continue
|
||||||
if task_platform == "windows" and remote:
|
if task_platform == "windows" and remote:
|
||||||
|
|||||||
+54
-14
@@ -304,6 +304,7 @@ OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
|||||||
"email_ai_replies",
|
"email_ai_replies",
|
||||||
"email_calendar_extractions",
|
"email_calendar_extractions",
|
||||||
"email_urgency_alerts",
|
"email_urgency_alerts",
|
||||||
|
"sender_signatures",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -341,6 +342,55 @@ def _ensure_owner_scoped_email_cache_table(conn, table: str, create_sql: str, co
|
|||||||
_lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}")
|
_lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_sender_signatures_table(conn):
|
||||||
|
"""Create/migrate learned sender signatures to an owner-scoped cache."""
|
||||||
|
create_sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS sender_signatures (
|
||||||
|
from_address TEXT,
|
||||||
|
owner TEXT DEFAULT '',
|
||||||
|
signature_text TEXT,
|
||||||
|
sample_count INTEGER,
|
||||||
|
last_built_at TEXT NOT NULL,
|
||||||
|
model_used TEXT,
|
||||||
|
source TEXT,
|
||||||
|
PRIMARY KEY (from_address, owner)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
conn.execute(create_sql)
|
||||||
|
try:
|
||||||
|
info = conn.execute("PRAGMA table_info(sender_signatures)").fetchall()
|
||||||
|
cols = [r[1] for r in info]
|
||||||
|
pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
|
||||||
|
if "owner" in cols and pk_cols == ["from_address", "owner"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
conn.execute("ALTER TABLE sender_signatures RENAME TO sender_signatures__old")
|
||||||
|
conn.execute(create_sql)
|
||||||
|
old_cols = [r[1] for r in conn.execute("PRAGMA table_info(sender_signatures__old)").fetchall()]
|
||||||
|
copy_cols = [
|
||||||
|
c for c in (
|
||||||
|
"from_address",
|
||||||
|
"signature_text",
|
||||||
|
"sample_count",
|
||||||
|
"last_built_at",
|
||||||
|
"model_used",
|
||||||
|
"source",
|
||||||
|
)
|
||||||
|
if c in old_cols
|
||||||
|
]
|
||||||
|
source_owner = "COALESCE(owner, '')" if "owner" in old_cols else "''"
|
||||||
|
conn.execute(
|
||||||
|
f"INSERT OR IGNORE INTO sender_signatures "
|
||||||
|
f"({', '.join([*copy_cols, 'owner'])}) "
|
||||||
|
f"SELECT {', '.join([*copy_cols, source_owner])} "
|
||||||
|
f"FROM sender_signatures__old"
|
||||||
|
)
|
||||||
|
conn.execute("DROP TABLE sender_signatures__old")
|
||||||
|
except Exception as _mig_e:
|
||||||
|
import logging as _lg
|
||||||
|
_lg.getLogger(__name__).warning(f"sender_signatures owner-migration skipped: {_mig_e}")
|
||||||
|
|
||||||
|
|
||||||
def attachment_extract_dir(folder: str, uid: str) -> Path:
|
def attachment_extract_dir(folder: str, uid: str) -> Path:
|
||||||
"""Containment-safe extraction directory for an attachment.
|
"""Containment-safe extraction directory for an attachment.
|
||||||
|
|
||||||
@@ -559,20 +609,10 @@ def _init_scheduled_db():
|
|||||||
conn.execute("ALTER TABLE email_boundaries ADD COLUMN turns_json TEXT")
|
conn.execute("ALTER TABLE email_boundaries ADD COLUMN turns_json TEXT")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Per-sender signature cache. Populated by `learn_sender_signatures`
|
# Per-sender signature cache. Populated by `learn_sender_signatures`.
|
||||||
# action: the LLM extracts the common trailing block across N emails
|
# Message sender addresses are global, so signatures must be scoped to the
|
||||||
# from each sender; the renderer folds it consistently for every
|
# mailbox owner before `/read` returns them to the renderer.
|
||||||
# future email from that address.
|
_ensure_sender_signatures_table(conn)
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS sender_signatures (
|
|
||||||
from_address TEXT PRIMARY KEY,
|
|
||||||
signature_text TEXT,
|
|
||||||
sample_count INTEGER,
|
|
||||||
last_built_at TEXT NOT NULL,
|
|
||||||
model_used TEXT,
|
|
||||||
source TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|||||||
+51
-23
@@ -249,6 +249,41 @@ def _uid_from_fetch_meta(meta_b: bytes) -> str:
|
|||||||
return m.group(1).decode() if m else ""
|
return m.group(1).decode() if m else ""
|
||||||
|
|
||||||
|
|
||||||
|
_FETCH_SEQ_RE = re.compile(rb"^(\d+)\s+\(")
|
||||||
|
|
||||||
|
|
||||||
|
def _group_uid_fetch_records(msg_data) -> list:
|
||||||
|
"""Group an imaplib UID FETCH response into per-message (meta, payload).
|
||||||
|
|
||||||
|
imaplib yields an interleaved list: ``(meta, literal)`` tuples for
|
||||||
|
attributes that carry a literal (``RFC822.HEADER {n}`` etc.) plus bare
|
||||||
|
``bytes`` elements for everything the server sends outside a literal.
|
||||||
|
Where each attribute lands is server-specific: Dovecot sends FLAGS
|
||||||
|
*before* the header literal (so it ends up inside the tuple meta), while
|
||||||
|
Gmail sends FLAGS *after* it, arriving as a bare ``b' FLAGS (\\Seen))'``
|
||||||
|
element. Dropping bare elements therefore silently loses FLAGS on Gmail
|
||||||
|
and every message renders as unread/unflagged.
|
||||||
|
|
||||||
|
A tuple whose meta starts with a sequence number opens a new record;
|
||||||
|
every other part — continuation tuple or bare bytes — is folded into the
|
||||||
|
current record's meta so attribute regexes see the full meta text.
|
||||||
|
Plain ``b')'`` terminators get folded in too, which is harmless.
|
||||||
|
"""
|
||||||
|
grouped: list = [] # list of (meta_bytes, payload_bytes_or_None)
|
||||||
|
for part in (msg_data or []):
|
||||||
|
if isinstance(part, tuple):
|
||||||
|
meta_b = part[0] if isinstance(part[0], (bytes, bytearray)) else str(part[0]).encode()
|
||||||
|
if _FETCH_SEQ_RE.match(meta_b):
|
||||||
|
grouped.append((meta_b, part[1]))
|
||||||
|
elif grouped:
|
||||||
|
cur_meta, cur_payload = grouped[-1]
|
||||||
|
grouped[-1] = (cur_meta + b" " + meta_b, cur_payload or part[1])
|
||||||
|
elif isinstance(part, (bytes, bytearray)) and grouped:
|
||||||
|
cur_meta, cur_payload = grouped[-1]
|
||||||
|
grouped[-1] = (cur_meta + b" " + bytes(part), cur_payload)
|
||||||
|
return grouped
|
||||||
|
|
||||||
|
|
||||||
def _smtp_ready(cfg: dict) -> bool:
|
def _smtp_ready(cfg: dict) -> bool:
|
||||||
return bool(cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password"))
|
return bool(cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password"))
|
||||||
|
|
||||||
@@ -799,20 +834,11 @@ def setup_email_routes():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Batch fetch failed, falling back to per-UID: {e}")
|
logger.warning(f"Batch fetch failed, falling back to per-UID: {e}")
|
||||||
status, msg_data = "NO", []
|
status, msg_data = "NO", []
|
||||||
# imaplib batch responses interleave (meta, payload) tuples and
|
# Group the batched response into per-message (meta, payload)
|
||||||
# `b')'` terminators. Group by message: each tuple where the
|
# records. Bare bytes parts must be kept: Gmail returns FLAGS
|
||||||
# meta begins with a seq number starts a new message record.
|
# after the header literal as a bare element, and dropping it
|
||||||
seq_re = re.compile(rb'^(\d+)\s+\(')
|
# rendered every Gmail message as unread/unflagged.
|
||||||
grouped = [] # list of (meta_str, payload_bytes)
|
grouped = _group_uid_fetch_records(msg_data)
|
||||||
for part in (msg_data or []):
|
|
||||||
if isinstance(part, tuple):
|
|
||||||
meta_b = part[0] if isinstance(part[0], (bytes, bytearray)) else str(part[0]).encode()
|
|
||||||
if seq_re.match(meta_b):
|
|
||||||
grouped.append((meta_b, part[1]))
|
|
||||||
elif grouped:
|
|
||||||
# continuation of previous message — concatenate meta info if any
|
|
||||||
cur_meta, cur_payload = grouped[-1]
|
|
||||||
grouped[-1] = (cur_meta + b" " + meta_b, cur_payload or part[1])
|
|
||||||
|
|
||||||
if status != "OK" and not grouped:
|
if status != "OK" and not grouped:
|
||||||
conn.logout()
|
conn.logout()
|
||||||
@@ -1098,14 +1124,15 @@ def setup_email_routes():
|
|||||||
continue
|
continue
|
||||||
raw_header = None
|
raw_header = None
|
||||||
flags = ""
|
flags = ""
|
||||||
for part in msg_data:
|
# Same Gmail caveat as the list route: FLAGS may
|
||||||
if isinstance(part, tuple):
|
# arrive after the header literal, so group bare
|
||||||
meta = part[0].decode() if isinstance(part[0], bytes) else str(part[0])
|
# parts back into the message meta before scanning.
|
||||||
if b"RFC822.HEADER" in part[0] if isinstance(part[0], bytes) else "RFC822.HEADER" in meta:
|
for meta_b, payload in _group_uid_fetch_records(msg_data):
|
||||||
raw_header = part[1]
|
if payload and b"RFC822.HEADER" in meta_b:
|
||||||
flag_match = re.search(r'FLAGS \(([^)]*)\)', meta)
|
raw_header = payload
|
||||||
|
flag_match = re.search(rb'FLAGS \(([^)]*)\)', meta_b)
|
||||||
if flag_match:
|
if flag_match:
|
||||||
flags = flag_match.group(1)
|
flags = flag_match.group(1).decode(errors="replace")
|
||||||
if not raw_header:
|
if not raw_header:
|
||||||
continue
|
continue
|
||||||
msg = email_mod.message_from_bytes(raw_header)
|
msg = email_mod.message_from_bytes(raw_header)
|
||||||
@@ -1247,8 +1274,9 @@ def setup_email_routes():
|
|||||||
try:
|
try:
|
||||||
if sender_addr:
|
if sender_addr:
|
||||||
_rs = _c.execute(
|
_rs = _c.execute(
|
||||||
"SELECT signature_text FROM sender_signatures WHERE from_address = ?",
|
f"SELECT signature_text FROM sender_signatures "
|
||||||
(sender_addr.lower().strip(),),
|
f"WHERE from_address = ? AND {owner_clause}",
|
||||||
|
(sender_addr.lower().strip(), *owner_params),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if _rs and _rs[0]:
|
if _rs and _rs[0]:
|
||||||
cached_sender_sig = _rs[0]
|
cached_sender_sig = _rs[0]
|
||||||
|
|||||||
+23
-3
@@ -1,7 +1,9 @@
|
|||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from routes._validators import validate_remote_host, validate_ssh_port
|
||||||
|
|
||||||
|
|
||||||
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
||||||
@@ -11,6 +13,14 @@ from fastapi import APIRouter
|
|||||||
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
|
||||||
|
host_value = validate_remote_host(host) or ""
|
||||||
|
port_value = validate_ssh_port(ssh_port) or ""
|
||||||
|
if port_value and not host_value:
|
||||||
|
raise HTTPException(400, "ssh_port requires host")
|
||||||
|
return host_value, port_value
|
||||||
|
|
||||||
|
|
||||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||||
"""Manual hardware is a "what if I had this setup" simulator —
|
"""Manual hardware is a "what if I had this setup" simulator —
|
||||||
REPLACES the detected hardware entirely instead of adding to it.
|
REPLACES the detected hardware entirely instead of adding to it.
|
||||||
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
|
|||||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||||
fresh=true bypasses the per-host cache (the Rescan button)."""
|
fresh=true bypasses the per-host cache (the Rescan button)."""
|
||||||
from services.hwfit.hardware import detect_system
|
from services.hwfit.hardware import detect_system
|
||||||
|
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||||
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||||
|
|
||||||
@router.get("/models")
|
@router.get("/models")
|
||||||
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
|
|||||||
from services.hwfit.hardware import detect_system
|
from services.hwfit.hardware import detect_system
|
||||||
from services.hwfit.fit import rank_models
|
from services.hwfit.fit import rank_models
|
||||||
from services.hwfit.models import get_models, model_catalog_path
|
from services.hwfit.models import get_models, model_catalog_path
|
||||||
|
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||||
if system.get("error"):
|
if system.get("error"):
|
||||||
return {"system": system, "models": [], "error": system["error"]}
|
return {"system": system, "models": [], "error": system["error"]}
|
||||||
@@ -165,8 +177,14 @@ def setup_hwfit_routes():
|
|||||||
system["gpu_name"] = g["name"]
|
system["gpu_name"] = g["name"]
|
||||||
system["active_group"] = {**g, "use_count": n}
|
system["active_group"] = {**g, "use_count": n}
|
||||||
|
|
||||||
if gpu_count != "":
|
# Parse the optional count defensively (matches the gpu_group guard
|
||||||
n = int(gpu_count)
|
# above): a non-numeric query param previously raised ValueError ->
|
||||||
|
# HTTP 500. A malformed value is ignored, same as omitting it.
|
||||||
|
try:
|
||||||
|
n = int(gpu_count) if gpu_count != "" else None
|
||||||
|
except ValueError:
|
||||||
|
n = None
|
||||||
|
if n is not None:
|
||||||
if n == 0:
|
if n == 0:
|
||||||
# RAM-only mode: rank against system memory, offload allowed.
|
# RAM-only mode: rank against system memory, offload allowed.
|
||||||
system["has_gpu"] = False
|
system["has_gpu"] = False
|
||||||
@@ -229,6 +247,7 @@ def setup_hwfit_routes():
|
|||||||
from services.hwfit.hardware import detect_system
|
from services.hwfit.hardware import detect_system
|
||||||
from services.hwfit.models import get_models
|
from services.hwfit.models import get_models
|
||||||
from services.hwfit.profiles import compute_serve_profiles
|
from services.hwfit.profiles import compute_serve_profiles
|
||||||
|
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||||
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||||
if system.get("error"):
|
if system.get("error"):
|
||||||
return {"system": system, "profiles": [], "error": system["error"]}
|
return {"system": system, "profiles": [], "error": system["error"]}
|
||||||
@@ -279,6 +298,7 @@ def setup_hwfit_routes():
|
|||||||
"""Rank image generation models against detected hardware."""
|
"""Rank image generation models against detected hardware."""
|
||||||
from services.hwfit.hardware import detect_system
|
from services.hwfit.hardware import detect_system
|
||||||
from services.hwfit.image_models import rank_image_models
|
from services.hwfit.image_models import rank_image_models
|
||||||
|
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||||
if system.get("error"):
|
if system.get("error"):
|
||||||
return {"system": system, "models": [], "error": system["error"]}
|
return {"system": system, "models": [], "error": system["error"]}
|
||||||
|
|||||||
@@ -105,6 +105,13 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
|||||||
if memory_manager.find_duplicates(text, user_mem):
|
if memory_manager.find_duplicates(text, user_mem):
|
||||||
return {"ok": True, "count": len(user_mem), "message": "Memory already exists"}
|
return {"ok": True, "count": len(user_mem), "message": "Memory already exists"}
|
||||||
|
|
||||||
|
if memory_data.session_id:
|
||||||
|
try:
|
||||||
|
session_obj = session_manager.get_session(memory_data.session_id)
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(404, "Session not found")
|
||||||
|
_assert_session_owner(session_obj, user)
|
||||||
|
|
||||||
new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user)
|
new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user)
|
||||||
if memory_data.session_id:
|
if memory_data.session_id:
|
||||||
new_entry["session_id"] = memory_data.session_id
|
new_entry["session_id"] = memory_data.session_id
|
||||||
@@ -163,8 +170,17 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
|||||||
|
|
||||||
session_id = memory.get("session_id")
|
session_id = memory.get("session_id")
|
||||||
if session_id and session_id in session_manager.sessions:
|
if session_id and session_id in session_manager.sessions:
|
||||||
|
try:
|
||||||
session = session_manager.get_session(session_id)
|
session = session_manager.get_session(session_id)
|
||||||
|
if session:
|
||||||
|
_assert_session_owner(session, user)
|
||||||
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
|
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
|
||||||
|
except KeyError:
|
||||||
|
memory["session_name"] = "Unknown"
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 404:
|
||||||
|
raise
|
||||||
|
memory["session_name"] = "Unknown"
|
||||||
else:
|
else:
|
||||||
memory["session_name"] = "Unknown"
|
memory["session_name"] = "Unknown"
|
||||||
|
|
||||||
|
|||||||
+27
-5
@@ -123,6 +123,21 @@ def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int:
|
|||||||
return cleared_users
|
return cleared_users
|
||||||
|
|
||||||
|
|
||||||
|
def _default_endpoint_needs_assignment(current_default_id: str, enabled_endpoint_ids) -> bool:
|
||||||
|
"""Whether the global default chat endpoint should be (re)assigned.
|
||||||
|
|
||||||
|
True when nothing is configured yet, or the configured default no longer
|
||||||
|
resolves to an enabled endpoint (e.g. the user disabled it). Without the
|
||||||
|
second case, adding a new endpoint after disabling the previous default
|
||||||
|
leaves `default_endpoint_id` pointing at the disabled endpoint, so features
|
||||||
|
that read the raw setting (Memory → Tidy) fail with "No default model
|
||||||
|
configured" even though an enabled endpoint exists. See #3586.
|
||||||
|
"""
|
||||||
|
if not current_default_id:
|
||||||
|
return True
|
||||||
|
return current_default_id not in enabled_endpoint_ids
|
||||||
|
|
||||||
|
|
||||||
# Loopback hosts a user might type for a local model server (LM Studio,
|
# Loopback hosts a user might type for a local model server (LM Studio,
|
||||||
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
|
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
|
||||||
# host the server actually runs on.
|
# host the server actually runs on.
|
||||||
@@ -1727,12 +1742,19 @@ def setup_model_routes(model_discovery):
|
|||||||
)
|
)
|
||||||
db.add(ep)
|
db.add(ep)
|
||||||
db.commit()
|
db.commit()
|
||||||
# Auto-set as default chat endpoint if none configured yet. Seed
|
# Auto-set as default chat endpoint when none is usable yet — either
|
||||||
# the first CHAT model (not raw model_ids[0]) so we don't pin the
|
# nothing is configured, or the configured default points at an
|
||||||
# global default to an embedding/tts/etc. entry a provider happens
|
# endpoint that is now missing/disabled (#3586). Seed the first CHAT
|
||||||
# to list first.
|
# model (not raw model_ids[0]) so we don't pin the global default to
|
||||||
|
# an embedding/tts/etc. entry a provider happens to list first.
|
||||||
settings = _load_settings()
|
settings = _load_settings()
|
||||||
if not settings.get("default_endpoint_id"):
|
enabled_ids = {
|
||||||
|
e.id
|
||||||
|
for e in db.query(ModelEndpoint).filter(
|
||||||
|
ModelEndpoint.is_enabled == True # noqa: E712
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
if _default_endpoint_needs_assignment(settings.get("default_endpoint_id") or "", enabled_ids):
|
||||||
from src.endpoint_resolver import _first_chat_model
|
from src.endpoint_resolver import _first_chat_model
|
||||||
settings["default_endpoint_id"] = ep.id
|
settings["default_endpoint_id"] = ep.id
|
||||||
settings["default_model"] = _first_chat_model(model_ids) or ""
|
settings["default_model"] = _first_chat_model(model_ids) or ""
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from core.session_manager import SessionManager
|
|||||||
from core.models import ChatMessage
|
from core.models import ChatMessage
|
||||||
from src.request_models import SessionResponse
|
from src.request_models import SessionResponse
|
||||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter
|
||||||
from src.session_actions import is_session_recently_active
|
from src.session_actions import is_session_recently_active
|
||||||
|
|
||||||
|
|
||||||
@@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
last_msg_map = {}
|
last_msg_map = {}
|
||||||
mode_map = {}
|
mode_map = {}
|
||||||
msg_count_map = {}
|
msg_count_map = {}
|
||||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False)
|
||||||
|
q = owner_filter(q, DbSession, user)
|
||||||
|
rows = q.all()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
folder_map[row.id] = row.folder
|
folder_map[row.id] = row.folder
|
||||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||||
@@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
# Sessions with active documents that have content
|
# Sessions with active documents that have content
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
doc_session_ids = set(
|
doc_session_ids = set(
|
||||||
r[0] for r in db.query(Document.session_id)
|
r[0] for r in owner_filter(
|
||||||
|
db.query(Document.session_id)
|
||||||
.filter(Document.is_active == True,
|
.filter(Document.is_active == True,
|
||||||
Document.current_content != None,
|
Document.current_content != None,
|
||||||
func.trim(Document.current_content) != "",
|
func.trim(Document.current_content) != ""),
|
||||||
Document.owner == user)
|
Document, user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
img_session_ids = set(
|
img_session_ids = set(
|
||||||
r[0] for r in db.query(GalleryImage.session_id)
|
r[0] for r in owner_filter(
|
||||||
.filter(GalleryImage.session_id != None,
|
db.query(GalleryImage.session_id)
|
||||||
GalleryImage.owner == user)
|
.filter(GalleryImage.session_id != None),
|
||||||
|
GalleryImage, user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
"""Workspace API - browse server directories to pick a tool workspace folder."""
|
||||||
|
import os
|
||||||
|
from fastapi import APIRouter, Request, HTTPException, Query
|
||||||
|
|
||||||
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.tool_security import owner_is_admin_or_single_user
|
||||||
|
|
||||||
|
# Cap entries returned per directory (mirrors filesystem_tools._CODENAV_MAX_HITS).
|
||||||
|
# A huge directory shouldn't dump thousands of rows into the picker; the user can
|
||||||
|
# type/paste a path to jump straight in instead.
|
||||||
|
_MAX_BROWSE_DIRS = 500
|
||||||
|
|
||||||
|
|
||||||
|
def setup_workspace_routes():
|
||||||
|
router = APIRouter(prefix="/api/workspace", tags=["workspace"])
|
||||||
|
|
||||||
|
@router.get("/browse")
|
||||||
|
def browse(request: Request, path: str = Query(default="")):
|
||||||
|
"""List subdirectories of `path` (default: home) so the UI can navigate
|
||||||
|
the server filesystem and pick a workspace folder. Directories only.
|
||||||
|
|
||||||
|
ADMIN-ONLY: this enumerates the server filesystem, so it is gated the
|
||||||
|
same way the file/shell tools are (read_file/write_file/bash are in
|
||||||
|
NON_ADMIN_BLOCKED_TOOLS). A non-admin who can't use those tools must not
|
||||||
|
be able to map the host's directory tree either.
|
||||||
|
"""
|
||||||
|
owner = get_current_user(request)
|
||||||
|
if not owner_is_admin_or_single_user(owner):
|
||||||
|
raise HTTPException(status_code=403, detail="Workspace browsing is admin-only")
|
||||||
|
|
||||||
|
# Resolve symlinks so the reported path is canonical and the UI navigates
|
||||||
|
# real directories (defends against symlink games in displayed paths).
|
||||||
|
target = os.path.realpath(os.path.expanduser(path.strip() or "~"))
|
||||||
|
if not os.path.isdir(target):
|
||||||
|
target = os.path.realpath(os.path.expanduser("~"))
|
||||||
|
|
||||||
|
dirs = []
|
||||||
|
try:
|
||||||
|
with os.scandir(target) as it:
|
||||||
|
for entry in it:
|
||||||
|
try:
|
||||||
|
# Don't follow symlinks when classifying - a symlinked
|
||||||
|
# dir is skipped rather than letting the browser wander
|
||||||
|
# off via a link. Hidden entries are omitted.
|
||||||
|
if entry.is_dir(follow_symlinks=False) and not entry.name.startswith("."):
|
||||||
|
# Build the child path server-side with os.path.join
|
||||||
|
# so it's correct on Windows (backslashes) and Linux.
|
||||||
|
dirs.append({"name": entry.name, "path": os.path.join(target, entry.name)})
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
except (PermissionError, OSError):
|
||||||
|
dirs = []
|
||||||
|
|
||||||
|
dirs_sorted = sorted(dirs, key=lambda d: d["name"].lower())
|
||||||
|
truncated = len(dirs_sorted) > _MAX_BROWSE_DIRS
|
||||||
|
parent = os.path.dirname(target)
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
return {
|
||||||
|
"path": target,
|
||||||
|
"parent": parent if parent and parent != target else None,
|
||||||
|
"dirs": dirs_sorted[:_MAX_BROWSE_DIRS],
|
||||||
|
"truncated": truncated,
|
||||||
|
# Whether this directory may be bound as a workspace (filesystem
|
||||||
|
# roots and sensitive dirs may be browsed through but not chosen).
|
||||||
|
"selectable": vet_workspace(target) is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/vet")
|
||||||
|
def vet(request: Request, path: str = Query(default="")):
|
||||||
|
"""Validate a workspace path without binding it.
|
||||||
|
|
||||||
|
The UI calls this before persisting a manually typed path (/workspace
|
||||||
|
set) so a typo, file path, deleted folder, sensitive dir, or filesystem
|
||||||
|
root is rejected up front with the canonical path returned on success,
|
||||||
|
instead of being stored client-side and silently dropped at chat time.
|
||||||
|
Admin-gated like /browse: it confirms path existence on the host.
|
||||||
|
"""
|
||||||
|
owner = get_current_user(request)
|
||||||
|
if not owner_is_admin_or_single_user(owner):
|
||||||
|
raise HTTPException(status_code=403, detail="Workspace selection is admin-only")
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
resolved = vet_workspace(path)
|
||||||
|
return {"ok": resolved is not None, "path": resolved}
|
||||||
|
|
||||||
|
return router
|
||||||
@@ -285,6 +285,7 @@ class ResearchHandler:
|
|||||||
query, report, stats, elapsed,
|
query, report, stats, elapsed,
|
||||||
findings=researcher.findings,
|
findings=researcher.findings,
|
||||||
evolving_report=researcher.evolving_report,
|
evolving_report=researcher.evolving_report,
|
||||||
|
analyzed_urls=getattr(researcher, "analyzed_urls", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -331,7 +332,8 @@ class ResearchHandler:
|
|||||||
|
|
||||||
def _format_research_report(
|
def _format_research_report(
|
||||||
self, query: str, full_report: str, stats: dict, elapsed: float,
|
self, query: str, full_report: str, stats: dict, elapsed: float,
|
||||||
findings: list = None, evolving_report: str = None,
|
findings: Optional[list] = None, evolving_report: Optional[str] = None,
|
||||||
|
analyzed_urls: Optional[list] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Format research report with sources list and expandable raw findings."""
|
"""Format research report with sources list and expandable raw findings."""
|
||||||
summary_lines = [
|
summary_lines = [
|
||||||
@@ -342,20 +344,34 @@ class ResearchHandler:
|
|||||||
]
|
]
|
||||||
summary_text = " | ".join(summary_lines)
|
summary_text = " | ".join(summary_lines)
|
||||||
|
|
||||||
# Build sources list with clickable links
|
# Build sources list with clickable links. Keep the curated Sources
|
||||||
|
# section filtered for citation quality, but also list every unique URL
|
||||||
|
# the research run inspected so the "URLs Analyzed" count is auditable.
|
||||||
sources_section = ""
|
sources_section = ""
|
||||||
if findings:
|
analyzed_urls_section = ""
|
||||||
|
url_items = analyzed_urls if analyzed_urls is not None else findings
|
||||||
|
if findings or url_items:
|
||||||
seen_urls = set()
|
seen_urls = set()
|
||||||
source_lines = []
|
source_lines = []
|
||||||
for f in findings:
|
analyzed_seen = set()
|
||||||
|
analyzed_lines = []
|
||||||
|
for f in findings or []:
|
||||||
url = f.get("url", "")
|
url = f.get("url", "")
|
||||||
title = f.get("title", "") or url
|
title = f.get("title", "") or url
|
||||||
summary = f.get("summary", "") or f.get("evidence", "")
|
summary = f.get("summary", "") or f.get("evidence", "")
|
||||||
if url and url not in seen_urls and not is_low_quality(summary):
|
if url and url not in seen_urls and not is_low_quality(summary):
|
||||||
seen_urls.add(url)
|
seen_urls.add(url)
|
||||||
source_lines.append(f"- [{title}]({url})")
|
source_lines.append(f"- [{title}]({url})")
|
||||||
|
for item in url_items or []:
|
||||||
|
url = item.get("url", "")
|
||||||
|
title = item.get("title", "") or url
|
||||||
|
if url and url not in analyzed_seen:
|
||||||
|
analyzed_seen.add(url)
|
||||||
|
analyzed_lines.append(f"{len(analyzed_lines) + 1}. [{title}]({url})")
|
||||||
if source_lines:
|
if source_lines:
|
||||||
sources_section = "\n### Sources\n\n" + "\n".join(source_lines) + "\n"
|
sources_section = "\n### Sources\n\n" + "\n".join(source_lines) + "\n"
|
||||||
|
if analyzed_lines:
|
||||||
|
analyzed_urls_section = "\n### Analyzed URLs\n\n" + "\n".join(analyzed_lines) + "\n"
|
||||||
|
|
||||||
# Build raw findings section (individual extractions per source)
|
# Build raw findings section (individual extractions per source)
|
||||||
raw_findings_section = ""
|
raw_findings_section = ""
|
||||||
@@ -391,6 +407,7 @@ class ResearchHandler:
|
|||||||
{full_report}
|
{full_report}
|
||||||
|
|
||||||
{sources_section}
|
{sources_section}
|
||||||
|
{analyzed_urls_section}
|
||||||
{collected_section}
|
{collected_section}
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -299,6 +299,40 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
|||||||
_cache_result(cache_file, cache_key, result, url)
|
_cache_result(cache_file, cache_key, result, url)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
# Plain-text / Markdown / JSON handling. Sources like
|
||||||
|
# raw.githubusercontent.com serve Markdown as `text/plain`, JSON APIs and
|
||||||
|
# raw config files serve `application/json`, and a lot of code and tool
|
||||||
|
# docs live in `.md` / `.txt`. These have no HTML structure, so the HTML
|
||||||
|
# branch below would extract nothing and report "no readable text content".
|
||||||
|
# Return the body verbatim instead. The `is_html` guard keeps real HTML
|
||||||
|
# (including `application/xhtml+xml`) on the parsing path; the `json` check
|
||||||
|
# covers `application/json` and `+json` suffixes; the URL-suffix fallback
|
||||||
|
# catches servers that mislabel text files as `application/octet-stream`.
|
||||||
|
is_html = "html" in content_type
|
||||||
|
is_json = "json" in content_type
|
||||||
|
url_path = url.lower().split("?", 1)[0].split("#", 1)[0]
|
||||||
|
looks_like_text_file = url_path.endswith(
|
||||||
|
(".md", ".markdown", ".txt", ".text", ".json", ".jsonl")
|
||||||
|
)
|
||||||
|
if not is_html and (content_type.startswith("text/") or is_json or looks_like_text_file):
|
||||||
|
text_body = (response.text or "").strip()
|
||||||
|
result = {
|
||||||
|
"url": url,
|
||||||
|
"title": os.path.basename(url_path) or url,
|
||||||
|
"content": text_body,
|
||||||
|
"lists": [],
|
||||||
|
"tables": [],
|
||||||
|
"code_blocks": [],
|
||||||
|
"meta_description": "",
|
||||||
|
"meta_keywords": "",
|
||||||
|
"js_rendered": False,
|
||||||
|
"js_message": "",
|
||||||
|
"success": bool(text_body),
|
||||||
|
"error": "" if text_body else "Empty response body",
|
||||||
|
}
|
||||||
|
_cache_result(cache_file, cache_key, result, url)
|
||||||
|
return result
|
||||||
|
|
||||||
# HTML handling
|
# HTML handling
|
||||||
try:
|
try:
|
||||||
soup = BeautifulSoup(response.text, "html.parser")
|
soup = BeautifulSoup(response.text, "html.parser")
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from duckduckgo_search import DDGS
|
from ddgs import DDGS
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("duckduckgo-search package not installed; using HTML fallback")
|
logger.warning("duckduckgo-search package not installed; using HTML fallback")
|
||||||
return _html_fallback()
|
return _html_fallback()
|
||||||
|
|||||||
+28
-8
@@ -21,7 +21,7 @@ from src.settings import get_setting
|
|||||||
from src.prompt_security import untrusted_context_message
|
from src.prompt_security import untrusted_context_message
|
||||||
from src.tool_security import blocked_tools_for_owner, plan_mode_disabled_tools
|
from src.tool_security import blocked_tools_for_owner, plan_mode_disabled_tools
|
||||||
from src.tool_policy import GUIDE_ONLY_DIRECTIVE, ToolPolicy
|
from src.tool_policy import GUIDE_ONLY_DIRECTIVE, ToolPolicy
|
||||||
from src.tool_utils import get_mcp_manager
|
from src.tool_utils import _truncate, get_mcp_manager
|
||||||
from src.agent_tools import (
|
from src.agent_tools import (
|
||||||
parse_tool_blocks,
|
parse_tool_blocks,
|
||||||
strip_tool_blocks,
|
strip_tool_blocks,
|
||||||
@@ -272,7 +272,7 @@ _DOMAIN_TOOL_MAP = {
|
|||||||
"notes_calendar_tasks": {"manage_notes", "manage_calendar", "manage_tasks"},
|
"notes_calendar_tasks": {"manage_notes", "manage_calendar", "manage_tasks"},
|
||||||
"ui": {"ui_control"},
|
"ui": {"ui_control"},
|
||||||
"sessions": {"create_session", "list_sessions", "manage_session", "send_to_session", "search_chats"},
|
"sessions": {"create_session", "list_sessions", "manage_session", "send_to_session", "search_chats"},
|
||||||
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls"},
|
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls", "get_workspace"},
|
||||||
"settings": {"manage_settings", "manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "app_api"},
|
"settings": {"manage_settings", "manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "app_api"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,6 +309,7 @@ NEVER pipe multi-line Python through `python -c "..."` — shell quoting eats re
|
|||||||
<python code>
|
<python code>
|
||||||
```
|
```
|
||||||
Execute Python code. Use for computation, data processing, scripting. NOT for writing code for the user (use create_document for that). Same sandbox limits as bash — no TTY, no GUI, no `input()`; for anything the user should interact with, generate a single HTML file with inline JS instead.
|
Execute Python code. Use for computation, data processing, scripting. NOT for writing code for the user (use create_document for that). Same sandbox limits as bash — no TTY, no GUI, no `input()`; for anything the user should interact with, generate a single HTML file with inline JS instead.
|
||||||
|
Prefer a dedicated tool whenever one fits the job (reading, searching, or writing files); use python only for computation/processing no dedicated tool covers - not for reading or writing files.
|
||||||
Do NOT use Python/requests for web lookup/search/latest/current requests when `web_search` or `web_fetch` is available.""",
|
Do NOT use Python/requests for web lookup/search/latest/current requests when `web_search` or `web_fetch` is available.""",
|
||||||
|
|
||||||
"web_search": """\
|
"web_search": """\
|
||||||
@@ -347,6 +348,11 @@ Write content to a file. First line is the path, rest is the content.""",
|
|||||||
```
|
```
|
||||||
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
|
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
|
||||||
|
|
||||||
|
"get_workspace": """\
|
||||||
|
```get_workspace
|
||||||
|
```
|
||||||
|
Return the absolute path of the active workspace folder. File tools are CONFINED to it (paths can be RELATIVE to it); the shell starts there (cwd) but is NOT sandboxed. Call this first when the user says "the project"/"the code"/"this folder" without a path, instead of asking them. No arguments.""",
|
||||||
|
|
||||||
"create_document": """\
|
"create_document": """\
|
||||||
```create_document
|
```create_document
|
||||||
<title>
|
<title>
|
||||||
@@ -1726,6 +1732,7 @@ async def stream_agent_loop(
|
|||||||
plan_mode: bool = False,
|
plan_mode: bool = False,
|
||||||
approved_plan: Optional[str] = None,
|
approved_plan: Optional[str] = None,
|
||||||
tool_policy: Optional[ToolPolicy] = None,
|
tool_policy: Optional[ToolPolicy] = None,
|
||||||
|
workspace: Optional[str] = None,
|
||||||
_is_teacher_run: bool = False,
|
_is_teacher_run: bool = False,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Streaming agent loop generator.
|
"""Streaming agent loop generator.
|
||||||
@@ -1795,6 +1802,16 @@ async def stream_agent_loop(
|
|||||||
if not guide_only and not _relevant_tools and bool(_intent.get("low_signal")):
|
if not guide_only and not _relevant_tools and bool(_intent.get("low_signal")):
|
||||||
from src.tool_index import ALWAYS_AVAILABLE
|
from src.tool_index import ALWAYS_AVAILABLE
|
||||||
_relevant_tools = set(ALWAYS_AVAILABLE)
|
_relevant_tools = set(ALWAYS_AVAILABLE)
|
||||||
|
if workspace:
|
||||||
|
# An active workspace IS the file-work signal: a vague "look at the
|
||||||
|
# project" means explore this folder. Surface only the READ-ONLY file
|
||||||
|
# tools (intersection with the plan-mode read-only allowlist) so the
|
||||||
|
# agent can investigate; write/shell tools stay out until the request
|
||||||
|
# actually calls for them (RAG retrieval adds those on a real ask).
|
||||||
|
from src.tool_security import PLAN_MODE_READONLY_TOOLS
|
||||||
|
_relevant_tools |= (_DOMAIN_TOOL_MAP["files"] & PLAN_MODE_READONLY_TOOLS)
|
||||||
|
logger.info("[tool-rag] Low-signal but workspace active; including read-only file tools")
|
||||||
|
else:
|
||||||
logger.info("[tool-rag] Low-signal agent message; skipping retrieval and using always-available tools only")
|
logger.info("[tool-rag] Low-signal agent message; skipping retrieval and using always-available tools only")
|
||||||
if not guide_only and not _relevant_tools:
|
if not guide_only and not _relevant_tools:
|
||||||
try:
|
try:
|
||||||
@@ -2644,6 +2661,7 @@ async def stream_agent_loop(
|
|||||||
tool_policy=tool_policy,
|
tool_policy=tool_policy,
|
||||||
owner=owner,
|
owner=owner,
|
||||||
progress_cb=_push_progress,
|
progress_cb=_push_progress,
|
||||||
|
workspace=workspace,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Sentinel so the drainer knows to stop.
|
# Sentinel so the drainer knows to stop.
|
||||||
@@ -2751,18 +2769,20 @@ async def stream_agent_loop(
|
|||||||
# On a bash/python timeout the result carries error + (often
|
# On a bash/python timeout the result carries error + (often
|
||||||
# empty) stdout/stderr; fall back to the error so the "timed
|
# empty) stdout/stderr; fall back to the error so the "timed
|
||||||
# out" reason reaches the UI instead of a blank result.
|
# out" reason reaches the UI instead of a blank result.
|
||||||
output_text = (result["stdout"] or result["stderr"] or result.get("error", ""))[:2000]
|
raw = result["stdout"] or result["stderr"] or result.get("error", "")
|
||||||
|
output_text = _truncate(raw)
|
||||||
elif "output" in result:
|
elif "output" in result:
|
||||||
# bash / python canonical result: {"output": ..., "exit_code": ...}
|
# bash / python canonical result: {"output": ..., "exit_code": ...}
|
||||||
output_text = (result["output"] or "")[:2000]
|
raw = result["output"] or ""
|
||||||
|
output_text = _truncate(raw)
|
||||||
elif "response" in result:
|
elif "response" in result:
|
||||||
# AI interaction tools (chat_with_model, send_to_session)
|
# AI interaction tools (chat_with_model, send_to_session)
|
||||||
label = result.get("model", result.get("session_name", "AI"))
|
label = result.get("model", result.get("session_name", "AI"))
|
||||||
output_text = f"{label}: {result['response']}"[:4000]
|
output_text = _truncate(f"{label}: {result['response']}")
|
||||||
elif "content" in result:
|
elif "content" in result:
|
||||||
output_text = result["content"][:2000]
|
output_text = _truncate(result["content"])
|
||||||
elif "results" in result:
|
elif "results" in result:
|
||||||
output_text = result["results"][:4000]
|
output_text = _truncate(result["results"])
|
||||||
elif "session_id" in result and "name" in result:
|
elif "session_id" in result and "name" in result:
|
||||||
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
|
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
|
||||||
elif "success" in result:
|
elif "success" in result:
|
||||||
@@ -2772,7 +2792,7 @@ async def stream_agent_loop(
|
|||||||
else f"Error: {result.get('error', '')}"
|
else f"Error: {result.get('error', '')}"
|
||||||
)
|
)
|
||||||
elif "error" in result:
|
elif "error" in result:
|
||||||
output_text = result["error"][:2000]
|
output_text = _truncate(result["error"])
|
||||||
|
|
||||||
# Emit tool_output (include ui_event data if present)
|
# Emit tool_output (include ui_event data if present)
|
||||||
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
|
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from .subprocess_tools import BashTool, PythonTool
|
from .subprocess_tools import BashTool, PythonTool
|
||||||
from .web_tools import WebSearchTool, WebFetchTool
|
from .web_tools import WebSearchTool, WebFetchTool
|
||||||
from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool
|
from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool, GetWorkspaceTool
|
||||||
from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool
|
from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool
|
||||||
|
|
||||||
TOOL_HANDLERS = {
|
TOOL_HANDLERS = {
|
||||||
@@ -39,6 +39,7 @@ TOOL_HANDLERS = {
|
|||||||
"edit_document": EditDocumentTool().execute,
|
"edit_document": EditDocumentTool().execute,
|
||||||
"suggest_document": SuggestDocumentTool().execute,
|
"suggest_document": SuggestDocumentTool().execute,
|
||||||
"manage_documents": ManageDocumentTool().execute,
|
"manage_documents": ManageDocumentTool().execute,
|
||||||
|
"get_workspace": GetWorkspaceTool().execute,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -51,7 +52,7 @@ PYTHON_TIMEOUT = 30
|
|||||||
|
|
||||||
# Tool types that trigger execution
|
# Tool types that trigger execution
|
||||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||||
"grep", "glob", "ls",
|
"grep", "glob", "ls", "get_workspace",
|
||||||
"create_document", "update_document", "edit_document",
|
"create_document", "update_document", "edit_document",
|
||||||
"search_chats",
|
"search_chats",
|
||||||
"chat_with_model", "create_session", "list_sessions",
|
"chat_with_model", "create_session", "list_sessions",
|
||||||
|
|||||||
@@ -46,13 +46,7 @@ def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
|
|||||||
|
|
||||||
class EditFileTool:
|
class EditFileTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
try:
|
try:
|
||||||
args = json.loads(content) if content.strip().startswith("{") else {}
|
args = json.loads(content) if content.strip().startswith("{") else {}
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
@@ -64,8 +58,7 @@ class EditFileTool:
|
|||||||
if not raw_path:
|
if not raw_path:
|
||||||
return {"error": "edit_file: path required", "exit_code": 1}
|
return {"error": "edit_file: path required", "exit_code": 1}
|
||||||
try:
|
try:
|
||||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
path = _resolve_tool_path(raw_path)
|
||||||
if workspace else _resolve_tool_path(raw_path))
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": f"edit_file: {e}", "exit_code": 1}
|
return {"error": f"edit_file: {e}", "exit_code": 1}
|
||||||
if old == "":
|
if old == "":
|
||||||
@@ -113,13 +106,7 @@ class EditFileTool:
|
|||||||
|
|
||||||
class ReadFileTool:
|
class ReadFileTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
|
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
|
||||||
_stripped = content.strip()
|
_stripped = content.strip()
|
||||||
if _stripped.startswith("{"):
|
if _stripped.startswith("{"):
|
||||||
@@ -131,8 +118,7 @@ class ReadFileTool:
|
|||||||
except (json.JSONDecodeError, TypeError, ValueError):
|
except (json.JSONDecodeError, TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
path = _resolve_tool_path(raw_path)
|
||||||
if workspace else _resolve_tool_path(raw_path))
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": f"read_file: {e}", "exit_code": 1}
|
return {"error": f"read_file: {e}", "exit_code": 1}
|
||||||
try:
|
try:
|
||||||
@@ -170,19 +156,12 @@ class ReadFileTool:
|
|||||||
|
|
||||||
class WriteFileTool:
|
class WriteFileTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
lines = content.split("\n", 1)
|
lines = content.split("\n", 1)
|
||||||
raw_path = lines[0].strip()
|
raw_path = lines[0].strip()
|
||||||
body = lines[1] if len(lines) > 1 else ""
|
body = lines[1] if len(lines) > 1 else ""
|
||||||
try:
|
try:
|
||||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
path = _resolve_tool_path(raw_path)
|
||||||
if workspace else _resolve_tool_path(raw_path))
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"error": f"write_file: {e}", "exit_code": 1}
|
return {"error": f"write_file: {e}", "exit_code": 1}
|
||||||
try:
|
try:
|
||||||
@@ -212,13 +191,7 @@ class WriteFileTool:
|
|||||||
|
|
||||||
class LsTool:
|
class LsTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
raw_path = ""
|
raw_path = ""
|
||||||
_s = (content or "").strip()
|
_s = (content or "").strip()
|
||||||
if _s.startswith("{"):
|
if _s.startswith("{"):
|
||||||
@@ -267,13 +240,7 @@ class LsTool:
|
|||||||
|
|
||||||
class GlobTool:
|
class GlobTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
args = {}
|
args = {}
|
||||||
_s = (content or "").strip()
|
_s = (content or "").strip()
|
||||||
if _s.startswith("{"):
|
if _s.startswith("{"):
|
||||||
@@ -325,13 +292,7 @@ class GlobTool:
|
|||||||
|
|
||||||
class GrepTool:
|
class GrepTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import (
|
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||||
_resolve_tool_path,
|
|
||||||
_resolve_tool_path_in_workspace,
|
|
||||||
_resolve_search_root,
|
|
||||||
_truncate
|
|
||||||
)
|
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
args: Dict[str, Any] = {}
|
args: Dict[str, Any] = {}
|
||||||
_s = (content or "").strip()
|
_s = (content or "").strip()
|
||||||
if _s.startswith("{"):
|
if _s.startswith("{"):
|
||||||
@@ -417,3 +378,21 @@ class GrepTool:
|
|||||||
if len(lines) >= max_hits:
|
if len(lines) >= max_hits:
|
||||||
out += f"\n... [capped at {max_hits} matches]"
|
out += f"\n... [capped at {max_hits} matches]"
|
||||||
return {"output": _truncate(out), "exit_code": 0}
|
return {"output": _truncate(out), "exit_code": 0}
|
||||||
|
|
||||||
|
class GetWorkspaceTool:
|
||||||
|
"""Report the active workspace folder (no args). File tools are confined to
|
||||||
|
it; the shell starts there (cwd) but is NOT sandboxed."""
|
||||||
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
|
from src.tool_execution import get_active_workspace
|
||||||
|
ws = get_active_workspace()
|
||||||
|
if ws:
|
||||||
|
return {
|
||||||
|
"output": f"{ws}\n(File tools are confined to this folder; the shell starts "
|
||||||
|
f"here but is not sandboxed and can reach outside it.)",
|
||||||
|
"exit_code": 0,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"output": "No workspace is set. File tools use the default allowed roots; "
|
||||||
|
"resolve paths from the user or use absolute paths.",
|
||||||
|
"exit_code": 0,
|
||||||
|
}
|
||||||
|
|||||||
@@ -102,16 +102,15 @@ async def _run_subprocess_streaming(
|
|||||||
|
|
||||||
class BashTool:
|
class BashTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import _AGENT_WORKDIR, _truncate
|
from src.tool_execution import agent_cwd, _truncate
|
||||||
progress_cb = ctx.get("progress_cb")
|
progress_cb = ctx.get("progress_cb")
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
_subproc_env = ctx.get("subproc_env")
|
_subproc_env = ctx.get("subproc_env")
|
||||||
proc = await asyncio.create_subprocess_shell(
|
proc = await asyncio.create_subprocess_shell(
|
||||||
content,
|
content,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
env=_subproc_env,
|
env=_subproc_env,
|
||||||
cwd=workspace or _AGENT_WORKDIR,
|
cwd=agent_cwd(),
|
||||||
)
|
)
|
||||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||||
proc,
|
proc,
|
||||||
@@ -129,16 +128,15 @@ class BashTool:
|
|||||||
|
|
||||||
class PythonTool:
|
class PythonTool:
|
||||||
async def execute(self, content: str, ctx: dict) -> dict:
|
async def execute(self, content: str, ctx: dict) -> dict:
|
||||||
from src.tool_execution import _AGENT_WORKDIR, _truncate
|
from src.tool_execution import agent_cwd, _truncate
|
||||||
progress_cb = ctx.get("progress_cb")
|
progress_cb = ctx.get("progress_cb")
|
||||||
workspace = ctx.get("workspace")
|
|
||||||
_subproc_env = ctx.get("subproc_env")
|
_subproc_env = ctx.get("subproc_env")
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
(sys.executable or "python"), "-I", "-c", content,
|
(sys.executable or "python"), "-I", "-c", content,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
env=_subproc_env,
|
env=_subproc_env,
|
||||||
cwd=workspace or _AGENT_WORKDIR,
|
cwd=agent_cwd(),
|
||||||
)
|
)
|
||||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||||
proc,
|
proc,
|
||||||
|
|||||||
+31
-15
@@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple:
|
|||||||
return etype, None
|
return etype, None
|
||||||
|
|
||||||
|
|
||||||
|
def _memory_context_lines(mems, limit: int = 40) -> list:
|
||||||
|
"""Render Memory rows into short personal-context bullets for event classify.
|
||||||
|
|
||||||
|
Reads the Memory ORM `text` column. The previous inline code read a
|
||||||
|
non-existent `content` attribute, so it raised AttributeError on the first
|
||||||
|
row, the surrounding except swallowed it, and the classifier ran with no
|
||||||
|
personal context at all. getattr keeps it robust to future schema drift.
|
||||||
|
"""
|
||||||
|
lines: list = []
|
||||||
|
for m in mems:
|
||||||
|
c = (getattr(m, "text", "") or "").strip()
|
||||||
|
if c:
|
||||||
|
lines.append(f"- {c[:200]}")
|
||||||
|
if len(lines) >= limit:
|
||||||
|
break
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||||
"""Hybrid classification of upcoming calendar events: fast heuristic for
|
"""Hybrid classification of upcoming calendar events: fast heuristic for
|
||||||
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
|
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
|
||||||
@@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
try:
|
try:
|
||||||
from core.database import Memory as _Mem
|
from core.database import Memory as _Mem
|
||||||
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
|
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
|
||||||
if _mems:
|
_lines = _memory_context_lines(_mems)
|
||||||
_lines = []
|
|
||||||
for m in _mems:
|
|
||||||
c = (m.content or "").strip()
|
|
||||||
if c:
|
|
||||||
_lines.append(f"- {c[:200]}")
|
|
||||||
if _lines:
|
if _lines:
|
||||||
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n"
|
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n"
|
||||||
except Exception as _me:
|
except Exception as _me:
|
||||||
logger.debug(f"Could not load memory for classify: {_me}")
|
logger.warning(f"Could not load memory for classify: {_me}")
|
||||||
|
|
||||||
classified_h = 0
|
classified_h = 0
|
||||||
classified_llm = 0
|
classified_llm = 0
|
||||||
@@ -796,14 +809,14 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
import email as _email_mod
|
import email as _email_mod
|
||||||
import asyncio as _aio
|
import asyncio as _aio
|
||||||
from datetime import datetime as _dt, timedelta as _td
|
from datetime import datetime as _dt, timedelta as _td
|
||||||
from routes.email_helpers import _imap_connect, SCHEDULED_DB
|
from routes.email_helpers import _email_cache_owner_clause, _imap_connect, SCHEDULED_DB
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
# 1. Pull recent UIDs + From headers cheaply (header-only fetch).
|
# 1. Pull recent UIDs + From headers cheaply (header-only fetch).
|
||||||
def _pull_headers():
|
def _pull_headers():
|
||||||
results = []
|
results = []
|
||||||
conn = _imap_connect(None)
|
conn = _imap_connect(None, owner=owner)
|
||||||
try:
|
try:
|
||||||
conn.select("INBOX", readonly=True)
|
conn.select("INBOX", readonly=True)
|
||||||
status, data = conn.search(None, "ALL")
|
status, data = conn.search(None, "ALL")
|
||||||
@@ -855,9 +868,11 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
# 3. Eligibility: ≥3 emails AND (no cache OR cache > 30 days old).
|
# 3. Eligibility: ≥3 emails AND (no cache OR cache > 30 days old).
|
||||||
try:
|
try:
|
||||||
conn = _sql3.connect(SCHEDULED_DB)
|
conn = _sql3.connect(SCHEDULED_DB)
|
||||||
|
owner_clause, owner_params = _email_cache_owner_clause(owner)
|
||||||
cached = {
|
cached = {
|
||||||
r[0]: r[1] for r in conn.execute(
|
r[0]: r[1] for r in conn.execute(
|
||||||
"SELECT from_address, last_built_at FROM sender_signatures"
|
f"SELECT from_address, last_built_at FROM sender_signatures WHERE {owner_clause}",
|
||||||
|
owner_params,
|
||||||
).fetchall()
|
).fetchall()
|
||||||
}
|
}
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -888,7 +903,7 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
|
|
||||||
def _fetch_bodies(_msgs):
|
def _fetch_bodies(_msgs):
|
||||||
bodies = []
|
bodies = []
|
||||||
conn2 = _imap_connect(None)
|
conn2 = _imap_connect(None, owner=owner)
|
||||||
try:
|
try:
|
||||||
conn2.select("INBOX", readonly=True)
|
conn2.select("INBOX", readonly=True)
|
||||||
for mm in _msgs:
|
for mm in _msgs:
|
||||||
@@ -965,11 +980,12 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
conn = _sql3.connect(SCHEDULED_DB)
|
conn = _sql3.connect(SCHEDULED_DB)
|
||||||
|
owner_value = (owner or "").strip()
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT OR REPLACE INTO sender_signatures "
|
"INSERT OR REPLACE INTO sender_signatures "
|
||||||
"(from_address, signature_text, sample_count, last_built_at, model_used, source) "
|
"(from_address, owner, signature_text, sample_count, last_built_at, model_used, source) "
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||||
(addr, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
|
(addr, owner_value, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|||||||
@@ -232,6 +232,7 @@ class DeepResearcher:
|
|||||||
self._start_time: float = 0
|
self._start_time: float = 0
|
||||||
self.queries_used: Set[str] = set()
|
self.queries_used: Set[str] = set()
|
||||||
self.urls_fetched: Set[str] = set()
|
self.urls_fetched: Set[str] = set()
|
||||||
|
self.analyzed_urls: List[Dict[str, str]] = []
|
||||||
self.round_count: int = 0
|
self.round_count: int = 0
|
||||||
# Track which search providers actually returned results during the
|
# Track which search providers actually returned results during the
|
||||||
# run, in arrival order — surfaced in the visual report so users can
|
# run, in arrival order — surfaced in the visual report so users can
|
||||||
@@ -525,6 +526,10 @@ class DeepResearcher:
|
|||||||
if url and url not in self.urls_fetched:
|
if url and url not in self.urls_fetched:
|
||||||
urls_to_fetch.append(r)
|
urls_to_fetch.append(r)
|
||||||
self.urls_fetched.add(url)
|
self.urls_fetched.add(url)
|
||||||
|
self.analyzed_urls.append({
|
||||||
|
"url": url,
|
||||||
|
"title": r.get("title", "") or url,
|
||||||
|
})
|
||||||
if len(urls_to_fetch) >= self.max_urls_per_round * len(queries):
|
if len(urls_to_fetch) >= self.max_urls_per_round * len(queries):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
+41
-7
@@ -457,15 +457,25 @@ def _detect_provider(url: str) -> str:
|
|||||||
|
|
||||||
def _is_self_hosted_openai_compatible(url: str) -> bool:
|
def _is_self_hosted_openai_compatible(url: str) -> bool:
|
||||||
"""True for custom/local OpenAI-compatible servers (llama.cpp, LM Studio,
|
"""True for custom/local OpenAI-compatible servers (llama.cpp, LM Studio,
|
||||||
vLLM, text-generation-webui, etc.) as opposed to api.openai.com itself.
|
vLLM, text-generation-webui, etc.) as opposed to cloud APIs.
|
||||||
|
|
||||||
Used to gate llama.cpp-server-specific payload extras (``session_id``,
|
Used to gate llama.cpp-server-specific payload extras (``session_id``,
|
||||||
``cache_prompt``) — sending unrecognized top-level fields to OpenAI's
|
``cache_prompt``) used for KV-cache slot affinity (issue #2927). Strict
|
||||||
actual API returns a 400 ("Unrecognized request argument"), but
|
cloud providers reject unrecognized top-level fields (api.openai.com
|
||||||
self-hosted servers generally ignore unknown fields and many (notably
|
returns 400, Mistral returns 422 "extra_forbidden", issue #3793), and any
|
||||||
llama.cpp's server) use them for KV-cache slot affinity (issue #2927).
|
unknown OpenAI-compatible host used to be treated as self-hosted, so those
|
||||||
|
fields leaked to every strict provider added as a custom endpoint.
|
||||||
|
|
||||||
|
A server only counts as self-hosted when it also resolves as local:
|
||||||
|
loopback/private/tailscale host, or the endpoint explicitly configured
|
||||||
|
with kind "local". A self-hosted server exposed via a public hostname
|
||||||
|
loses the affinity hint unless its endpoint kind is set to "local" -
|
||||||
|
a lost perf hint, versus a hard 4xx on every request the other way.
|
||||||
"""
|
"""
|
||||||
return _detect_provider(url) == "openai" and not _host_match(url, "openai.com")
|
if _detect_provider(url) != "openai" or _host_match(url, "openai.com"):
|
||||||
|
return False
|
||||||
|
from src.model_context import is_local_endpoint
|
||||||
|
return is_local_endpoint(url)
|
||||||
|
|
||||||
|
|
||||||
def _apply_local_cache_affinity(payload: Dict, url: str, session_id: Optional[str]) -> None:
|
def _apply_local_cache_affinity(payload: Dict, url: str, session_id: Optional[str]) -> None:
|
||||||
@@ -681,6 +691,27 @@ def _restricts_temperature(model: str) -> bool:
|
|||||||
m = model.lower()
|
m = model.lower()
|
||||||
return any(m.startswith(p) or f"/{p}" in m for p in _FIXED_TEMPERATURE_MODELS)
|
return any(m.startswith(p) or f"/{p}" in m for p in _FIXED_TEMPERATURE_MODELS)
|
||||||
|
|
||||||
|
# Anthropic removed the sampling parameters (temperature, top_p, top_k) starting
|
||||||
|
# with Claude Opus 4.7. On Opus 4.7 and later, sending `temperature` at all —
|
||||||
|
# even 0.0 — returns HTTP 400. Earlier Claude models (Opus 4.6 and below, every
|
||||||
|
# Sonnet/Haiku) still accept temperature in [0.0, 1.0], so the omission must be
|
||||||
|
# version-gated rather than applied to all `claude-*` models.
|
||||||
|
def _anthropic_rejects_temperature(model: str) -> bool:
|
||||||
|
"""Check if a native-Anthropic model rejects the temperature field (Opus 4.7+)."""
|
||||||
|
if not isinstance(model, str) or not model:
|
||||||
|
return False
|
||||||
|
# `(?<![a-z])` anchors "opus" to a word boundary so a substring match like
|
||||||
|
# `oct-opus`/`octopus-4-8` can't be read as Opus (it would otherwise strip
|
||||||
|
# temperature). Cap the minor at 1-2 digits and forbid a trailing digit so a
|
||||||
|
# dated id like `claude-opus-4-20250514` (Opus 4.0) parses as major-only (no
|
||||||
|
# minor match, kept) instead of reading the date `20250514` as a giant minor
|
||||||
|
# that would falsely test >= 4.7. Dated 4.7+ snapshots (`claude-opus-4-7-
|
||||||
|
# 20260201`) keep their explicit minor and are still matched.
|
||||||
|
match = re.search(r"(?<![a-z])opus[-_]?(\d+)[-_.](\d{1,2})(?!\d)", model.lower())
|
||||||
|
if not match:
|
||||||
|
return False
|
||||||
|
return (int(match.group(1)), int(match.group(2))) >= (4, 7)
|
||||||
|
|
||||||
# Models that support structured thinking — may output </think> without opening tag
|
# Models that support structured thinking — may output </think> without opening tag
|
||||||
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap", "gemma")
|
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap", "gemma")
|
||||||
|
|
||||||
@@ -784,8 +815,11 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
|
|||||||
"model": model,
|
"model": model,
|
||||||
"messages": chat_messages,
|
"messages": chat_messages,
|
||||||
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
|
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
|
||||||
"temperature": temperature,
|
|
||||||
}
|
}
|
||||||
|
# Opus 4.7+ removed the sampling parameters — sending `temperature` (even 0.0)
|
||||||
|
# returns HTTP 400. Omit it for those models; older Claude models still take it.
|
||||||
|
if not _anthropic_rejects_temperature(model):
|
||||||
|
payload["temperature"] = temperature
|
||||||
if system_parts:
|
if system_parts:
|
||||||
system_text = "\n\n".join(system_parts)
|
system_text = "\n\n".join(system_parts)
|
||||||
# Send `system` as a structured text block so we can attach a prompt-cache
|
# Send `system` as a structured text block so we can attach a prompt-cache
|
||||||
|
|||||||
+20
-6
@@ -5,6 +5,7 @@ Query and cache model context window sizes from OpenAI-compatible APIs.
|
|||||||
Provides token estimation for context usage tracking.
|
Provides token estimation for context usage tracking.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@@ -19,7 +20,20 @@ _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.interna
|
|||||||
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||||
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
|
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
|
||||||
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
|
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
|
||||||
"172.30.", "172.31.", "192.168.", "100.")
|
"172.30.", "172.31.", "192.168.")
|
||||||
|
|
||||||
|
# Tailscale uses the CGNAT range 100.64.0.0/10, NOT all of 100.0.0.0/8.
|
||||||
|
# A bare "100." prefix would classify public addresses (e.g. AWS ranges
|
||||||
|
# under 100.x outside the CGNAT block) as local; routes/model_routes.py
|
||||||
|
# already narrows this the same way for endpoint classification.
|
||||||
|
_TAILSCALE_CGNAT = ipaddress.ip_network("100.64.0.0/10")
|
||||||
|
|
||||||
|
|
||||||
|
def _in_tailscale_range(host: str) -> bool:
|
||||||
|
try:
|
||||||
|
return ipaddress.ip_address(host) in _TAILSCALE_CGNAT
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _normalize_base_for_compare(url: str) -> str:
|
def _normalize_base_for_compare(url: str) -> str:
|
||||||
@@ -64,7 +78,7 @@ def _configured_endpoint_kind(url: str) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _is_local_endpoint(url: str) -> bool:
|
def is_local_endpoint(url: str) -> bool:
|
||||||
"""Check if URL points to a local/private/tailscale address."""
|
"""Check if URL points to a local/private/tailscale address."""
|
||||||
kind = _configured_endpoint_kind(url)
|
kind = _configured_endpoint_kind(url)
|
||||||
if kind in ("api", "proxy"):
|
if kind in ("api", "proxy"):
|
||||||
@@ -73,7 +87,7 @@ def _is_local_endpoint(url: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
host = urlparse(url).hostname or ""
|
host = urlparse(url).hostname or ""
|
||||||
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
|
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) or _in_tailscale_range(host)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -219,7 +233,7 @@ def get_context_length(endpoint_url: str, model: str) -> int:
|
|||||||
Falls back to DEFAULT_CONTEXT if unavailable.
|
Falls back to DEFAULT_CONTEXT if unavailable.
|
||||||
"""
|
"""
|
||||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||||
is_local = _is_local_endpoint(endpoint_url)
|
is_local = is_local_endpoint(endpoint_url)
|
||||||
# Key on (endpoint_url, model): the same model id can be served by two
|
# Key on (endpoint_url, model): the same model id can be served by two
|
||||||
# different remote endpoints with different real context windows (e.g. a
|
# different remote endpoints with different real context windows (e.g. a
|
||||||
# capped proxy vs. the full provider), so caching by model id alone would
|
# capped proxy vs. the full provider), so caching by model id alone would
|
||||||
@@ -273,7 +287,7 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
|||||||
return DEFAULT_CONTEXT
|
return DEFAULT_CONTEXT
|
||||||
|
|
||||||
# Try llama.cpp /slots endpoint first — reports actual serving context
|
# Try llama.cpp /slots endpoint first — reports actual serving context
|
||||||
if _is_local_endpoint(endpoint_url):
|
if is_local_endpoint(endpoint_url):
|
||||||
try:
|
try:
|
||||||
base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0]
|
base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0]
|
||||||
r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
|
r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
|
||||||
@@ -337,7 +351,7 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
|||||||
# For local/self-hosted endpoints, trust the API value (user set --max-model-len)
|
# For local/self-hosted endpoints, trust the API value (user set --max-model-len)
|
||||||
# For cloud APIs, use the larger value (API can report low defaults)
|
# For cloud APIs, use the larger value (API can report low defaults)
|
||||||
if api_ctx and known:
|
if api_ctx and known:
|
||||||
_is_local = _is_local_endpoint(endpoint_url)
|
_is_local = is_local_endpoint(endpoint_url)
|
||||||
if _is_local and api_ctx < known:
|
if _is_local and api_ctx < known:
|
||||||
logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
|
logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
|
||||||
return api_ctx
|
return api_ctx
|
||||||
|
|||||||
@@ -223,6 +223,25 @@ class ModelDiscovery:
|
|||||||
)
|
)
|
||||||
return {"hosts": hosts, "items": items}
|
return {"hosts": hosts, "items": items}
|
||||||
|
|
||||||
|
def warmup_ping_urls(self, limit: int = 5) -> List[str]:
|
||||||
|
"""The ``/models`` URLs of up to ``limit`` discovered endpoints.
|
||||||
|
|
||||||
|
Used by the startup warmup / keepalive loop to prime connections. Each
|
||||||
|
discovered item already carries a ``/v1/chat/completions`` url; swap the
|
||||||
|
suffix for the cheap ``/models`` probe. Failures degrade to an empty list
|
||||||
|
so warmup never crashes the caller.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
items = (self.discover_models() or {}).get("items", [])
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
urls: List[str] = []
|
||||||
|
for ep in items[:limit]:
|
||||||
|
url = (ep.get("url") or "").replace("/chat/completions", "/models")
|
||||||
|
if url:
|
||||||
|
urls.append(url)
|
||||||
|
return urls
|
||||||
|
|
||||||
def get_providers(self) -> Dict[str, Any]:
|
def get_providers(self) -> Dict[str, Any]:
|
||||||
"""Get all available providers"""
|
"""Get all available providers"""
|
||||||
discovery = self.discover_models()
|
discovery = self.discover_models()
|
||||||
|
|||||||
+24
-1
@@ -221,6 +221,22 @@ class ResearchHandler:
|
|||||||
# Task registry — background research with persistence
|
# Task registry — background research with persistence
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||||
|
"""Move in-flight research tasks from one owner key to another."""
|
||||||
|
old_key = str(old_owner or "").strip().lower()
|
||||||
|
new_key = str(new_owner or "").strip().lower()
|
||||||
|
if not old_key or not new_key:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
changed = 0
|
||||||
|
for entry in list(self._active_tasks.values()):
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
if str(entry.get("owner", "")).strip().lower() == old_key:
|
||||||
|
entry["owner"] = new_key
|
||||||
|
changed += 1
|
||||||
|
return changed
|
||||||
|
|
||||||
def start_research(
|
def start_research(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
@@ -390,7 +406,6 @@ class ResearchHandler:
|
|||||||
|
|
||||||
def get_status(self, session_id: str) -> Optional[dict]:
|
def get_status(self, session_id: str) -> Optional[dict]:
|
||||||
"""Get current research status for a session."""
|
"""Get current research status for a session."""
|
||||||
avg = self.get_avg_duration()
|
|
||||||
if session_id in self._active_tasks:
|
if session_id in self._active_tasks:
|
||||||
entry = self._active_tasks[session_id]
|
entry = self._active_tasks[session_id]
|
||||||
result = {
|
result = {
|
||||||
@@ -399,6 +414,14 @@ class ResearchHandler:
|
|||||||
"query": entry["query"],
|
"query": entry["query"],
|
||||||
"started_at": entry["started_at"],
|
"started_at": entry["started_at"],
|
||||||
}
|
}
|
||||||
|
# avg_duration is a historical figure over completed reports on
|
||||||
|
# disk; get_avg_duration() globs and JSON-parses the whole research
|
||||||
|
# dir, so compute it at most once per active stream (memoized on the
|
||||||
|
# entry) instead of on every ~1s SSE poll. The disk branch below
|
||||||
|
# never used it, so it no longer pays that cost at all.
|
||||||
|
if "_avg_duration" not in entry:
|
||||||
|
entry["_avg_duration"] = self.get_avg_duration()
|
||||||
|
avg = entry["_avg_duration"]
|
||||||
if avg is not None:
|
if avg is not None:
|
||||||
result["avg_duration"] = round(avg, 1)
|
result["avg_duration"] = round(avg, 1)
|
||||||
return result
|
return result
|
||||||
|
|||||||
+23
-11
@@ -214,6 +214,24 @@ def _search_like(
|
|||||||
return _rows_to_results(db, shaped, query, context_messages)
|
return _rows_to_results(db, shaped, query, context_messages)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_messages_by_id(db, message_ids):
|
||||||
|
"""Fetch (message, session_name) for many message ids in a single query.
|
||||||
|
|
||||||
|
The FTS search returns a list of hit ids; fetching each row on its own was an
|
||||||
|
N+1 query (one SELECT per hit). Batch them with one IN(...) query and return
|
||||||
|
a lookup so the caller can reassemble results in hit (relevance) order.
|
||||||
|
"""
|
||||||
|
if not message_ids:
|
||||||
|
return {}
|
||||||
|
rows = (
|
||||||
|
db.query(DBChatMessage, DBSession.name)
|
||||||
|
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
||||||
|
.filter(DBChatMessage.id.in_(message_ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return {msg.id: (msg, session_name) for msg, session_name in rows}
|
||||||
|
|
||||||
|
|
||||||
def _search_fts(
|
def _search_fts(
|
||||||
db,
|
db,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -267,19 +285,13 @@ def _search_fts(
|
|||||||
if not hits:
|
if not hits:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
by_id = _fetch_messages_by_id(db, [hit[0] for hit in hits])
|
||||||
rows = []
|
rows = []
|
||||||
for hit in hits:
|
for hit in hits:
|
||||||
message_id = hit[0]
|
found = by_id.get(hit[0])
|
||||||
snippet = hit[1] or ""
|
if found:
|
||||||
row = (
|
msg, session_name = found
|
||||||
db.query(DBChatMessage, DBSession.name)
|
rows.append((msg, session_name, hit[1] or ""))
|
||||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
|
||||||
.filter(DBChatMessage.id == message_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if row:
|
|
||||||
msg, session_name = row
|
|
||||||
rows.append((msg, session_name, snippet))
|
|
||||||
return _rows_to_results(db, rows, query, context_messages)
|
return _rows_to_results(db, rows, query, context_messages)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+11
-1
@@ -12,6 +12,8 @@ tunnel / reverse proxy. Scrubbing is deep (recurses nested dicts/lists) and keye
|
|||||||
on secret-shaped names.
|
on secret-shaped names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
_SECRET_KEY_PATTERNS = (
|
_SECRET_KEY_PATTERNS = (
|
||||||
"_api_key", "_apikey", "_password", "_passwd", "_pass", "_pwd",
|
"_api_key", "_apikey", "_password", "_passwd", "_pass", "_pwd",
|
||||||
"_secret", "_client_secret", "_token", "_access_token", "_refresh_token",
|
"_secret", "_client_secret", "_token", "_access_token", "_refresh_token",
|
||||||
@@ -26,8 +28,16 @@ _SENSITIVE_KEY_EXACT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _canonical_key_name(name: str) -> str:
|
||||||
|
"""Normalize common JS-style key names so secret matching is style-agnostic."""
|
||||||
|
n = (name or "").replace("-", "_")
|
||||||
|
n = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", n)
|
||||||
|
n = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", n)
|
||||||
|
return n.lower()
|
||||||
|
|
||||||
|
|
||||||
def is_secret_key(name: str) -> bool:
|
def is_secret_key(name: str) -> bool:
|
||||||
n = (name or "").lower()
|
n = _canonical_key_name(name)
|
||||||
if n in _SECRET_KEY_ALLOW:
|
if n in _SECRET_KEY_ALLOW:
|
||||||
return False
|
return False
|
||||||
if n in _SENSITIVE_KEY_EXACT:
|
if n in _SENSITIVE_KEY_EXACT:
|
||||||
|
|||||||
+94
-7
@@ -9,6 +9,7 @@ Extracted from agent_tools.py.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -146,7 +147,13 @@ def _resolve_tool_path(raw_path: str) -> str:
|
|||||||
|
|
||||||
Returns the realpath on success. Raises ValueError on rejection.
|
Returns the realpath on success. Raises ValueError on rejection.
|
||||||
Symlinks are resolved before comparison.
|
Symlinks are resolved before comparison.
|
||||||
|
|
||||||
|
When a workspace is active for this turn, paths are confined to it instead
|
||||||
|
of the default allowlist (see _resolve_tool_path_in_workspace).
|
||||||
"""
|
"""
|
||||||
|
ws = get_active_workspace()
|
||||||
|
if ws:
|
||||||
|
return _resolve_tool_path_in_workspace(ws, raw_path)
|
||||||
if raw_path is None or not str(raw_path).strip():
|
if raw_path is None or not str(raw_path).strip():
|
||||||
raise ValueError("path is required")
|
raise ValueError("path is required")
|
||||||
expanded = os.path.expanduser(str(raw_path).strip())
|
expanded = os.path.expanduser(str(raw_path).strip())
|
||||||
@@ -207,6 +214,55 @@ def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Active workspace (per-turn, context-local)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Set ONCE in execute_tool_block from the request's `workspace`. The path
|
||||||
|
# resolvers (_resolve_tool_path / _resolve_search_root) and the subprocess cwd
|
||||||
|
# helper (agent_cwd) read it from here, so confinement is enforced in a single
|
||||||
|
# place: any tool that resolves paths through these helpers is confined
|
||||||
|
# automatically and cannot accidentally bypass the workspace. contextvars are
|
||||||
|
# task-local, so concurrent turns don't leak into each other.
|
||||||
|
_active_workspace: contextvars.ContextVar = contextvars.ContextVar(
|
||||||
|
"agent_active_workspace", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_workspace() -> Optional[str]:
|
||||||
|
"""The folder the agent is confined to this turn, or None."""
|
||||||
|
return _active_workspace.get()
|
||||||
|
|
||||||
|
|
||||||
|
def vet_workspace(raw: str) -> Optional[str]:
|
||||||
|
"""Validate a requested workspace path at bind time.
|
||||||
|
|
||||||
|
Returns the canonical path, or None when it is unusable: not a real
|
||||||
|
directory, or itself a sensitive path (.ssh, .gnupg, ...). The in-workspace
|
||||||
|
resolver deny-lists sensitive paths *inside* the workspace, but the
|
||||||
|
empty-path search root is the workspace itself, so the root has to be
|
||||||
|
vetted before it is ever bound.
|
||||||
|
"""
|
||||||
|
raw = (raw or "").strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
resolved = os.path.realpath(os.path.expanduser(raw))
|
||||||
|
if not os.path.isdir(resolved) or _is_sensitive_path(resolved):
|
||||||
|
return None
|
||||||
|
# Reject filesystem roots: binding / (or a Windows drive/UNC root) as the
|
||||||
|
# workspace would make every absolute path "inside" it, collapsing the
|
||||||
|
# confinement into host-wide file access. A root is its own dirname, which
|
||||||
|
# also covers C:\ and \\server\share without platform-specific lists.
|
||||||
|
if os.path.dirname(resolved) == resolved:
|
||||||
|
return None
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def agent_cwd() -> str:
|
||||||
|
"""Working directory for agent subprocesses (bash/python/background jobs):
|
||||||
|
the active workspace when set, else the persistent data dir."""
|
||||||
|
return get_active_workspace() or _AGENT_WORKDIR
|
||||||
|
|
||||||
|
|
||||||
def get_mcp_manager():
|
def get_mcp_manager():
|
||||||
from src import agent_tools
|
from src import agent_tools
|
||||||
return agent_tools.get_mcp_manager()
|
return agent_tools.get_mcp_manager()
|
||||||
@@ -217,10 +273,15 @@ def get_mcp_manager():
|
|||||||
def _resolve_search_root(raw_path: str) -> str:
|
def _resolve_search_root(raw_path: str) -> str:
|
||||||
"""Resolve + confine a code-nav path (grep/glob/ls).
|
"""Resolve + confine a code-nav path (grep/glob/ls).
|
||||||
|
|
||||||
An empty path defaults to the agent's primary root (project data dir) and a
|
With a workspace active, the workspace folder is the root and a supplied
|
||||||
supplied path is confined by the global allowlist + sensitive-file policy.
|
path is confined inside it. Otherwise an empty path defaults to the agent's
|
||||||
|
primary root (project data dir) and a supplied path is confined by the
|
||||||
|
global allowlist + sensitive-file policy.
|
||||||
"""
|
"""
|
||||||
raw = (raw_path or "").strip()
|
raw = (raw_path or "").strip()
|
||||||
|
ws = get_active_workspace()
|
||||||
|
if ws:
|
||||||
|
return os.path.realpath(ws) if not raw else _resolve_tool_path_in_workspace(ws, raw)
|
||||||
if not raw:
|
if not raw:
|
||||||
roots = _tool_path_roots()
|
roots = _tool_path_roots()
|
||||||
return roots[0] if roots else os.path.realpath(".")
|
return roots[0] if roots else os.path.realpath(".")
|
||||||
@@ -392,7 +453,6 @@ async def _direct_fallback(
|
|||||||
tool: str,
|
tool: str,
|
||||||
content: str,
|
content: str,
|
||||||
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||||
workspace: Optional[str] = None,
|
|
||||||
) -> Optional[Dict]:
|
) -> Optional[Dict]:
|
||||||
_subproc_env = {
|
_subproc_env = {
|
||||||
**os.environ,
|
**os.environ,
|
||||||
@@ -405,7 +465,6 @@ async def _direct_fallback(
|
|||||||
try:
|
try:
|
||||||
ctx = {
|
ctx = {
|
||||||
"progress_cb": progress_cb,
|
"progress_cb": progress_cb,
|
||||||
"workspace": workspace,
|
|
||||||
"subproc_env": _subproc_env,
|
"subproc_env": _subproc_env,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -448,6 +507,34 @@ async def execute_tool_block(
|
|||||||
) -> Tuple[str, Dict]:
|
) -> Tuple[str, Dict]:
|
||||||
"""Execute a single tool block. Returns (description, result_dict).
|
"""Execute a single tool block. Returns (description, result_dict).
|
||||||
|
|
||||||
|
Thin wrapper: bind the per-turn workspace (so the path resolvers + subprocess
|
||||||
|
cwd confine to it) for the duration of this call, then delegate. Reset on the
|
||||||
|
way out so the binding never leaks to the next tool call.
|
||||||
|
"""
|
||||||
|
token = _active_workspace.set(workspace or None)
|
||||||
|
try:
|
||||||
|
return await _execute_tool_block_impl(
|
||||||
|
block,
|
||||||
|
session_id=session_id,
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
owner=owner,
|
||||||
|
progress_cb=progress_cb,
|
||||||
|
tool_policy=tool_policy,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
_active_workspace.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
async def _execute_tool_block_impl(
|
||||||
|
block: Any,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
disabled_tools: Optional[set] = None,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||||
|
tool_policy: Optional[Any] = None,
|
||||||
|
) -> Tuple[str, Dict]:
|
||||||
|
"""Execute a single tool block. Returns (description, result_dict).
|
||||||
|
|
||||||
`progress_cb` is forwarded to long-running subprocess tools
|
`progress_cb` is forwarded to long-running subprocess tools
|
||||||
(bash, python) so the agent loop can emit `tool_progress` SSE
|
(bash, python) so the agent loop can emit `tool_progress` SSE
|
||||||
events while the command is in flight. Ignored by other tools.
|
events while the command is in flight. Ignored by other tools.
|
||||||
@@ -621,7 +708,7 @@ async def execute_tool_block(
|
|||||||
_is_bg, _bg_cmd = _split_bg_marker(content)
|
_is_bg, _bg_cmd = _split_bg_marker(content)
|
||||||
if _is_bg and _bg_cmd:
|
if _is_bg and _bg_cmd:
|
||||||
from src import bg_jobs
|
from src import bg_jobs
|
||||||
rec = bg_jobs.launch(_bg_cmd, session_id=session_id, cwd=_AGENT_WORKDIR)
|
rec = bg_jobs.launch(_bg_cmd, session_id=session_id, cwd=agent_cwd())
|
||||||
short = _bg_cmd.strip().split(chr(10))[0][:80]
|
short = _bg_cmd.strip().split(chr(10))[0][:80]
|
||||||
desc = f"bash (background): {short}"
|
desc = f"bash (background): {short}"
|
||||||
result = {
|
result = {
|
||||||
@@ -644,7 +731,7 @@ async def execute_tool_block(
|
|||||||
first_line = content.split(chr(10))[0][:80]
|
first_line = content.split(chr(10))[0][:80]
|
||||||
desc = f"{tool}: {first_line}"
|
desc = f"{tool}: {first_line}"
|
||||||
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb)
|
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb)
|
||||||
elif tool in ("grep", "glob", "ls"):
|
elif tool in ("grep", "glob", "ls", "get_workspace"):
|
||||||
# Code-navigation tools — no MCP server; run the direct implementation.
|
# Code-navigation tools — no MCP server; run the direct implementation.
|
||||||
first_line = content.split(chr(10))[0][:80]
|
first_line = content.split(chr(10))[0][:80]
|
||||||
desc = f"{tool}: {first_line}"
|
desc = f"{tool}: {first_line}"
|
||||||
@@ -744,7 +831,7 @@ async def execute_tool_block(
|
|||||||
desc = "edit_image"
|
desc = "edit_image"
|
||||||
result = await do_edit_image(content, owner=owner)
|
result = await do_edit_image(content, owner=owner)
|
||||||
elif tool == "edit_file":
|
elif tool == "edit_file":
|
||||||
result = await _direct_fallback(tool, content, workspace=workspace) or {"error": "edit failed", "exit_code": 1}
|
result = await _direct_fallback(tool, content) or {"error": "edit failed", "exit_code": 1}
|
||||||
desc = result.get("output") or result.get("error") or "edit_file"
|
desc = result.get("output") or result.get("error") or "edit_file"
|
||||||
elif tool == "trigger_research":
|
elif tool == "trigger_research":
|
||||||
desc = "trigger_research"
|
desc = "trigger_research"
|
||||||
|
|||||||
@@ -1453,6 +1453,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||||
|
|
||||||
|
# ── Batch normalization ──
|
||||||
|
# Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
|
||||||
|
# instead of individual create_event calls. Iterate and create each.
|
||||||
|
if isinstance(args.get("events"), list) and not args.get("action"):
|
||||||
|
results = []
|
||||||
|
for ev in args["events"]:
|
||||||
|
if not isinstance(ev, dict):
|
||||||
|
continue
|
||||||
|
# Normalize start/end from {dateTime: "..."} object to flat string
|
||||||
|
for field, target in [("start", "dtstart"), ("end", "dtend")]:
|
||||||
|
val = ev.pop(field, None)
|
||||||
|
if val and target not in ev:
|
||||||
|
ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
|
||||||
|
ev.setdefault("action", "create_event")
|
||||||
|
r = await do_manage_calendar(json.dumps(ev), owner=owner)
|
||||||
|
results.append(r)
|
||||||
|
created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
|
||||||
|
failed = [r for r in results if r.get("error")]
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {"error": "No events to create", "exit_code": 1}
|
||||||
|
|
||||||
|
# Surface both successes and failures
|
||||||
|
parts = []
|
||||||
|
if created:
|
||||||
|
summaries = [r.get("response", "") for r in created]
|
||||||
|
parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
|
||||||
|
if failed:
|
||||||
|
first_error = failed[0].get("error", "Unknown error")
|
||||||
|
parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")
|
||||||
|
|
||||||
|
response = "\n\n".join(parts)
|
||||||
|
# Non-zero exit code for partial or total failure
|
||||||
|
exit_code = 0 if not failed else 1
|
||||||
|
return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}
|
||||||
|
|
||||||
# Normalize action — some models emit hyphens ("list-calendars") instead
|
# Normalize action — some models emit hyphens ("list-calendars") instead
|
||||||
# of underscores. Treat them as equivalent so we don't bounce a
|
# of underscores. Treat them as equivalent so we don't bounce a
|
||||||
# cosmetic typo back to the model and waste a round-trip. Also accept
|
# cosmetic typo back to the model and waste a round-trip. Also accept
|
||||||
|
|||||||
+3
-2
@@ -67,14 +67,15 @@ COLLECTION_NAME = "odysseus_tool_index"
|
|||||||
# Each tool gets a searchable description that helps retrieval.
|
# Each tool gets a searchable description that helps retrieval.
|
||||||
# These are richer than the system prompt one-liners — they're for embedding.
|
# These are richer than the system prompt one-liners — they're for embedding.
|
||||||
BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
|
BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
|
||||||
"bash": "Run shell commands on the server. Install packages, check files, git operations, system info, and process management. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
"bash": "Run shell commands on the server. Install packages, git operations, builds, system info, process management. Prefer a dedicated tool whenever one fits the job (file read/write/edit, search, listing); use bash only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||||
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Prefer a dedicated tool for reading, writing, or searching files; use python only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||||
"web_search": "Quick single web lookup for a fact, current event, latest/current information, or doc mid-task. Use this instead of bash/curl/python/requests for web searches. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
|
"web_search": "Quick single web lookup for a fact, current event, latest/current information, or doc mid-task. Use this instead of bash/curl/python/requests for web searches. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
|
||||||
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
|
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
|
||||||
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
|
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
|
||||||
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
|
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
|
||||||
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
|
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
|
||||||
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
|
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
|
||||||
|
"get_workspace": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without giving a path, instead of asking them.",
|
||||||
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
|
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
|
||||||
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
|
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
|
||||||
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
|
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
|
||||||
|
|||||||
+12
-2
@@ -25,7 +25,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "bash",
|
"name": "bash",
|
||||||
"description": "Run a shell command (full access)",
|
"description": "Run a shell command (full access). Prefer a dedicated tool whenever one fits the job (reading, writing, editing, searching, or listing files); use bash only for what no dedicated tool covers (installs, git, builds, running programs, system info). Do NOT create or edit files via bash redirects/heredocs/sed -- use the dedicated file tools.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -39,7 +39,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"description": "Execute Python code to compute a result or test something",
|
"description": "Execute Python code to compute a result or test something. Prefer a dedicated tool whenever one fits the job (reading, writing, or searching files); use python only for computation, data processing, or scripting no dedicated tool covers.",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -141,6 +141,14 @@ FUNCTION_TOOL_SCHEMAS = [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_workspace",
|
||||||
|
"description": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without a path, instead of asking them. Takes no arguments.",
|
||||||
|
"parameters": {"type": "object", "properties": {}, "required": []}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
@@ -1246,6 +1254,8 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
|
|||||||
content = args.get("path", "")
|
content = args.get("path", "")
|
||||||
elif tool_type in ("grep", "glob", "ls"):
|
elif tool_type in ("grep", "glob", "ls"):
|
||||||
content = json.dumps(args) if args else "{}"
|
content = json.dumps(args) if args else "{}"
|
||||||
|
elif tool_type == "get_workspace":
|
||||||
|
content = ""
|
||||||
elif tool_type == "write_file":
|
elif tool_type == "write_file":
|
||||||
content = args.get("path", "") + "\n" + args.get("content", "")
|
content = args.get("path", "") + "\n" + args.get("content", "")
|
||||||
elif tool_type == "edit_file":
|
elif tool_type == "edit_file":
|
||||||
|
|||||||
+17
-2
@@ -20,6 +20,7 @@ NON_ADMIN_BLOCKED_TOOLS = {
|
|||||||
"grep",
|
"grep",
|
||||||
"glob",
|
"glob",
|
||||||
"ls",
|
"ls",
|
||||||
|
"get_workspace",
|
||||||
"search_chats",
|
"search_chats",
|
||||||
"manage_memory",
|
"manage_memory",
|
||||||
"manage_skills",
|
"manage_skills",
|
||||||
@@ -66,6 +67,7 @@ PLAN_MODE_READONLY_TOOLS = {
|
|||||||
"grep",
|
"grep",
|
||||||
"glob",
|
"glob",
|
||||||
"ls",
|
"ls",
|
||||||
|
"get_workspace",
|
||||||
"web_search",
|
"web_search",
|
||||||
"web_fetch",
|
"web_fetch",
|
||||||
"search_chats",
|
"search_chats",
|
||||||
@@ -162,13 +164,26 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
|
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
|
||||||
"""Return True for admins, or when auth is not configured yet."""
|
"""Return True for admins, or in intentional single-user mode.
|
||||||
|
|
||||||
|
Single-user mode means the operator explicitly disabled auth
|
||||||
|
(``AUTH_ENABLED=false``) — the local/self-host default where the owner has
|
||||||
|
full access to their own box.
|
||||||
|
|
||||||
|
The pre-setup window (auth ENABLED but no admin created yet) is treated as
|
||||||
|
NON-admin: returning True there would hand server-execution tools
|
||||||
|
(``bash``/``python``) to any caller before setup completes. The auth
|
||||||
|
middleware already 401s ``/api/`` requests pre-setup, so this is
|
||||||
|
defense-in-depth for callers that bypass it (e.g. trusted loopback).
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
|
|
||||||
auth = AuthManager()
|
auth = AuthManager()
|
||||||
if not auth.is_configured:
|
if not auth.is_configured:
|
||||||
return True
|
from src.auth_helpers import _auth_disabled
|
||||||
|
|
||||||
|
return _auth_disabled()
|
||||||
return bool(owner and auth.is_admin(owner))
|
return bool(owner and auth.is_admin(owner))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Unable to evaluate owner admin status: %s", exc)
|
logger.warning("Unable to evaluate owner admin status: %s", exc)
|
||||||
|
|||||||
@@ -352,6 +352,86 @@ class UploadHandler:
|
|||||||
return dict(info)
|
return dict(info)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _renamed_upload_index_key(self, key: str, info: Dict[str, Any], old_owner: str, new_owner: str) -> str:
|
||||||
|
"""Return the storage key to use after renaming an owned upload row."""
|
||||||
|
if isinstance(key, str) and ":" in key:
|
||||||
|
owner_part, rest = key.split(":", 1)
|
||||||
|
if owner_part.strip().lower() == old_owner:
|
||||||
|
return f"{new_owner}:{rest}"
|
||||||
|
file_hash = info.get("hash")
|
||||||
|
if file_hash:
|
||||||
|
return f"{new_owner}:{file_hash}"
|
||||||
|
return key
|
||||||
|
|
||||||
|
def _unique_upload_index_key(self, base_key: str, used_keys: set, reserved_keys: set, info: Dict[str, Any]) -> str:
|
||||||
|
"""Choose a deterministic collision key without overwriting an existing row."""
|
||||||
|
if base_key not in used_keys and base_key not in reserved_keys:
|
||||||
|
return base_key
|
||||||
|
|
||||||
|
upload_id = str(info.get("id") or "renamed").strip() or "renamed"
|
||||||
|
candidate = f"{base_key}:{upload_id}"
|
||||||
|
if candidate not in used_keys and candidate not in reserved_keys:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
index = 2
|
||||||
|
while True:
|
||||||
|
candidate = f"{base_key}:{upload_id}:{index}"
|
||||||
|
if candidate not in used_keys and candidate not in reserved_keys:
|
||||||
|
return candidate
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||||
|
"""Rename upload metadata ownership from old_owner to new_owner.
|
||||||
|
|
||||||
|
Upload rows are keyed by owner-qualified hashes for dedupe and also
|
||||||
|
carry an `owner` field for access checks. Both must move together when
|
||||||
|
usernames change.
|
||||||
|
"""
|
||||||
|
old_owner_normalized = str(old_owner or "").strip().lower()
|
||||||
|
new_owner = str(new_owner or "").strip()
|
||||||
|
if not old_owner_normalized or not new_owner:
|
||||||
|
return 0
|
||||||
|
if old_owner_normalized == new_owner.lower():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
uploads_db_path = os.path.join(self.upload_dir, "uploads.json")
|
||||||
|
with self._index_lock:
|
||||||
|
current = self._load_upload_index()
|
||||||
|
if not current:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
updated = {}
|
||||||
|
renamed = 0
|
||||||
|
original_keys = set(current.keys())
|
||||||
|
|
||||||
|
for key, info in current.items():
|
||||||
|
new_key = key
|
||||||
|
new_info = info
|
||||||
|
if isinstance(info, dict) and str(info.get("owner", "")).strip().lower() == old_owner_normalized:
|
||||||
|
new_info = dict(info)
|
||||||
|
new_info["owner"] = new_owner
|
||||||
|
base_key = self._renamed_upload_index_key(key, new_info, old_owner_normalized, new_owner)
|
||||||
|
new_key = self._unique_upload_index_key(
|
||||||
|
base_key,
|
||||||
|
set(updated.keys()),
|
||||||
|
original_keys - {key},
|
||||||
|
new_info,
|
||||||
|
)
|
||||||
|
if new_key != base_key:
|
||||||
|
logger.warning(
|
||||||
|
"Upload owner rename key collision for %s -> %s at %s; preserving row as %s",
|
||||||
|
old_owner_normalized,
|
||||||
|
new_owner,
|
||||||
|
base_key,
|
||||||
|
new_key,
|
||||||
|
)
|
||||||
|
renamed += 1
|
||||||
|
updated[new_key] = new_info
|
||||||
|
|
||||||
|
if renamed:
|
||||||
|
self._atomic_write_json(uploads_db_path, updated)
|
||||||
|
return renamed
|
||||||
|
|
||||||
def _find_upload_path(self, upload_id: str) -> Optional[str]:
|
def _find_upload_path(self, upload_id: str) -> Optional[str]:
|
||||||
"""Find an upload file by ID while staying inside upload_dir."""
|
"""Find an upload file by ID while staying inside upload_dir."""
|
||||||
if not self.validate_upload_id(upload_id):
|
if not self.validate_upload_id(upload_id):
|
||||||
|
|||||||
+15
-3
@@ -202,6 +202,18 @@ class WebhookManager:
|
|||||||
self._client = httpx.AsyncClient(timeout=10, follow_redirects=False)
|
self._client = httpx.AsyncClient(timeout=10, follow_redirects=False)
|
||||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
self._api_key_manager = api_key_manager
|
self._api_key_manager = api_key_manager
|
||||||
|
# Strong references to in-flight fire-and-forget tasks. asyncio only
|
||||||
|
# keeps weak references to tasks, so without this the GC can collect a
|
||||||
|
# delivery task mid-flight and the webhook is silently never sent.
|
||||||
|
self._bg_tasks: set = set()
|
||||||
|
|
||||||
|
def _spawn_tracked(self, coro):
|
||||||
|
"""Schedule a background task and hold a strong reference until it
|
||||||
|
finishes, so it can't be garbage-collected before delivery completes."""
|
||||||
|
task = asyncio.ensure_future(coro)
|
||||||
|
self._bg_tasks.add(task)
|
||||||
|
task.add_done_callback(self._bg_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
@@ -223,8 +235,8 @@ class WebhookManager:
|
|||||||
if event not in ALLOWED_EVENTS:
|
if event not in ALLOWED_EVENTS:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
asyncio.get_running_loop()
|
||||||
loop.create_task(self.fire(event, payload))
|
self._spawn_tracked(self.fire(event, payload))
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# Called from a sync thread (e.g. sync FastAPI route in threadpool)
|
# Called from a sync thread (e.g. sync FastAPI route in threadpool)
|
||||||
if self._loop and self._loop.is_running():
|
if self._loop and self._loop.is_running():
|
||||||
@@ -243,7 +255,7 @@ class WebhookManager:
|
|||||||
|
|
||||||
for wh in matching:
|
for wh in matching:
|
||||||
decrypted_secret = self._decrypt_secret(wh.secret)
|
decrypted_secret = self._decrypt_secret(wh.secret)
|
||||||
asyncio.create_task(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
|
self._spawn_tracked(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
|
||||||
|
|
||||||
async def deliver_test(self, webhook_id: str, url: str, encrypted_secret: Optional[str]):
|
async def deliver_test(self, webhook_id: str, url: str, encrypted_secret: Optional[str]):
|
||||||
"""Public method for the test-webhook route."""
|
"""Public method for the test-webhook route."""
|
||||||
|
|||||||
+6
-6
@@ -4,6 +4,7 @@
|
|||||||
// ============================================
|
// ============================================
|
||||||
import Storage from './js/storage.js';
|
import Storage from './js/storage.js';
|
||||||
import uiModule from './js/ui.js';
|
import uiModule from './js/ui.js';
|
||||||
|
import workspaceModule from './js/workspace.js';
|
||||||
import fileHandlerModule from './js/fileHandler.js';
|
import fileHandlerModule from './js/fileHandler.js';
|
||||||
import modelsModule from './js/models.js';
|
import modelsModule from './js/models.js';
|
||||||
import ragModule from './js/rag.js';
|
import ragModule from './js/rag.js';
|
||||||
@@ -1159,7 +1160,7 @@ function initializeEventListeners() {
|
|||||||
if (!p.can_use_bash) {
|
if (!p.can_use_bash) {
|
||||||
const bashToggle = document.getElementById('bash-toggle');
|
const bashToggle = document.getElementById('bash-toggle');
|
||||||
if (bashToggle) bashToggle.closest('.chat-input-toggle')?.style.setProperty('display', 'none');
|
if (bashToggle) bashToggle.closest('.chat-input-toggle')?.style.setProperty('display', 'none');
|
||||||
const bashBtn = document.getElementById('tool-bash-btn');
|
const bashBtn = document.getElementById('bash-toggle-btn');
|
||||||
if (bashBtn) bashBtn.style.display = 'none';
|
if (bashBtn) bashBtn.style.display = 'none';
|
||||||
}
|
}
|
||||||
// Hide document button
|
// Hide document button
|
||||||
@@ -1176,11 +1177,7 @@ function initializeEventListeners() {
|
|||||||
const resOverflow = document.getElementById('overflow-research-btn');
|
const resOverflow = document.getElementById('overflow-research-btn');
|
||||||
if (resOverflow) resOverflow.style.display = 'none';
|
if (resOverflow) resOverflow.style.display = 'none';
|
||||||
}
|
}
|
||||||
// Hide image generation options
|
|
||||||
if (!p.can_generate_images) {
|
|
||||||
const imgBtn = document.getElementById('tool-image-btn');
|
|
||||||
if (imgBtn) imgBtn.style.display = 'none';
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.catch(() => {});
|
.catch(() => {});
|
||||||
@@ -1626,6 +1623,8 @@ function initializeEventListeners() {
|
|||||||
// Slide the pill to the active button
|
// Slide the pill to the active button
|
||||||
const toggle = agentBtn.closest('.mode-toggle');
|
const toggle = agentBtn.closest('.mode-toggle');
|
||||||
if (toggle) toggle.classList.toggle('mode-chat', mode === 'chat');
|
if (toggle) toggle.classList.toggle('mode-chat', mode === 'chat');
|
||||||
|
// Workspace pill + overflow entry are agent-only - hide immediately (no flash).
|
||||||
|
try { workspaceModule.applyMode(mode); } catch (_) {}
|
||||||
// Delay tool glow-up for a staggered effect
|
// Delay tool glow-up for a staggered effect
|
||||||
setTimeout(() => applyModeToToggles(mode), 500);
|
setTimeout(() => applyModeToToggles(mode), 500);
|
||||||
}
|
}
|
||||||
@@ -1701,6 +1700,7 @@ function initializeEventListeners() {
|
|||||||
}
|
}
|
||||||
setupToggle('web-toggle-btn', 'web-toggle', 'web');
|
setupToggle('web-toggle-btn', 'web-toggle', 'web');
|
||||||
setupToggle('bash-toggle-btn', 'bash-toggle', 'bash');
|
setupToggle('bash-toggle-btn', 'bash-toggle', 'bash');
|
||||||
|
try { workspaceModule.initWorkspace(); } catch (_) {}
|
||||||
|
|
||||||
// Document editor toggle (special: uses module panel, not a checkbox)
|
// Document editor toggle (special: uses module panel, not a checkbox)
|
||||||
const overflowDocBtn = el('overflow-doc-btn');
|
const overflowDocBtn = el('overflow-doc-btn');
|
||||||
|
|||||||
+14
-1
@@ -1040,6 +1040,13 @@
|
|||||||
<span>RAG</span>
|
<span>RAG</span>
|
||||||
<span class="overflow-active-dot"></span>
|
<span class="overflow-active-dot"></span>
|
||||||
</button>
|
</button>
|
||||||
|
<button type="button" class="overflow-menu-item" id="overflow-workspace-btn">
|
||||||
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<path d="M3 7a2 2 0 0 1 2-2h4l2 2h8a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2z"/>
|
||||||
|
</svg>
|
||||||
|
<span>Workspace</span>
|
||||||
|
<span class="overflow-active-dot"></span>
|
||||||
|
</button>
|
||||||
<!-- Inline "deep research mode" toggle removed (superseded by the
|
<!-- Inline "deep research mode" toggle removed (superseded by the
|
||||||
Deep Research sidebar / trigger_research). The hidden
|
Deep Research sidebar / trigger_research). The hidden
|
||||||
#research-toggle checkbox is kept inert so existing JS refs
|
#research-toggle checkbox is kept inert so existing JS refs
|
||||||
@@ -1071,6 +1078,12 @@
|
|||||||
<polyline points="4 17 10 11 4 5"/><line x1="12" y1="19" x2="20" y2="19"/>
|
<polyline points="4 17 10 11 4 5"/><line x1="12" y1="19" x2="20" y2="19"/>
|
||||||
</svg>
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
|
<!-- Workspace indicator (hidden until a folder is set) -->
|
||||||
|
<button type="button" class="input-icon-btn tool-indicator" title="Workspace - click to clear" id="workspace-indicator-btn" aria-label="Clear workspace" style="display:none;">
|
||||||
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M3 7a2 2 0 0 1 2-2h4l2 2h8a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2z"/></svg>
|
||||||
|
<span style="font-size:11px;margin-left:2px;max-width:120px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;" id="workspace-indicator-name"></span>
|
||||||
|
<svg class="tool-indicator-x" width="10" height="10" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="3" stroke-linecap="round"><line x1="6" y1="6" x2="18" y2="18"/><line x1="18" y1="6" x2="6" y2="18"/></svg>
|
||||||
|
</button>
|
||||||
<!-- RAG toolbar indicator (hidden until active) -->
|
<!-- RAG toolbar indicator (hidden until active) -->
|
||||||
<button type="button" class="input-icon-btn tool-indicator" title="RAG active — click to deactivate" id="rag-indicator-btn" style="display:none;">
|
<button type="button" class="input-icon-btn tool-indicator" title="RAG active — click to deactivate" id="rag-indicator-btn" style="display:none;">
|
||||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
|
||||||
@@ -2342,7 +2355,7 @@
|
|||||||
<script type="module" src="/static/js/chatRenderer.js"></script>
|
<script type="module" src="/static/js/chatRenderer.js"></script>
|
||||||
<script type="module" src="/static/js/codeRunner.js"></script>
|
<script type="module" src="/static/js/codeRunner.js"></script>
|
||||||
<script type="module" src="/static/js/chatStream.js"></script>
|
<script type="module" src="/static/js/chatStream.js"></script>
|
||||||
<script type="module" src="/static/js/chat.js?v=20260604s"></script>
|
<script type="module" src="/static/js/chat.js?v=20260609ws"></script>
|
||||||
<script type="module" src="/static/js/cookbook.js"></script>
|
<script type="module" src="/static/js/cookbook.js"></script>
|
||||||
<script src="/static/js/cookbookSchedule.js"></script>
|
<script src="/static/js/cookbookSchedule.js"></script>
|
||||||
<script type="module" src="/static/js/search-chat.js"></script>
|
<script type="module" src="/static/js/search-chat.js"></script>
|
||||||
|
|||||||
+25
-2
@@ -819,6 +819,10 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
|||||||
if (incognitoChk && incognitoChk.checked) {
|
if (incognitoChk && incognitoChk.checked) {
|
||||||
fd.append('incognito', 'true');
|
fd.append('incognito', 'true');
|
||||||
}
|
}
|
||||||
|
const _ws = (Storage.KEYS && Storage.get(Storage.KEYS.WORKSPACE, '')) || '';
|
||||||
|
if (_ws) {
|
||||||
|
fd.append('workspace', _ws);
|
||||||
|
}
|
||||||
if (presetsModule.getSelectedPreset()) {
|
if (presetsModule.getSelectedPreset()) {
|
||||||
fd.append('preset_id', presetsModule.getSelectedPreset());
|
fd.append('preset_id', presetsModule.getSelectedPreset());
|
||||||
}
|
}
|
||||||
@@ -1082,7 +1086,7 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
|||||||
let _lastToolName = '';
|
let _lastToolName = '';
|
||||||
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
|
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
|
||||||
const _toolLabels = {
|
const _toolLabels = {
|
||||||
'web_search': _searchIcon + 'Searching',
|
'web_search': 'Searching',
|
||||||
'bash': 'Running',
|
'bash': 'Running',
|
||||||
'python': 'Running',
|
'python': 'Running',
|
||||||
'create_document': 'Writing',
|
'create_document': 'Writing',
|
||||||
@@ -1102,6 +1106,9 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
|||||||
'list_models': 'Browsing',
|
'list_models': 'Browsing',
|
||||||
'ui_control': 'Adjusting',
|
'ui_control': 'Adjusting',
|
||||||
};
|
};
|
||||||
|
const _toolIcons = {
|
||||||
|
'web_search': _searchIcon,
|
||||||
|
};
|
||||||
function _thinkingLabel() {
|
function _thinkingLabel() {
|
||||||
if (!_lastToolName) {
|
if (!_lastToolName) {
|
||||||
return 'Thinking';
|
return 'Thinking';
|
||||||
@@ -1778,6 +1785,21 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
|||||||
_sourcesData = json.data; _sourcesType = 'web';
|
_sourcesData = json.data; _sourcesType = 'web';
|
||||||
_sourcesHtml = _buildSourcesBox(json.data, 'web');
|
_sourcesHtml = _buildSourcesBox(json.data, 'web');
|
||||||
}
|
}
|
||||||
|
} else if (json.type === 'workspace_rejected') {
|
||||||
|
// Server refused to bind the posted workspace (deleted folder,
|
||||||
|
// file path, sensitive dir, filesystem root). Clear the stored
|
||||||
|
// value so the pill stops claiming a confinement that is not in
|
||||||
|
// effect, and tell the user.
|
||||||
|
const _wsPath = (json.data && json.data.path) || '';
|
||||||
|
import('./workspace.js').then((m) => {
|
||||||
|
const ws = m.default || m;
|
||||||
|
if (ws && ws.setWorkspace) ws.setWorkspace('');
|
||||||
|
});
|
||||||
|
uiModule.showToast(
|
||||||
|
`Workspace ${_wsPath || '(unknown)'} is no longer usable; running without confinement`,
|
||||||
|
6000
|
||||||
|
);
|
||||||
|
continue;
|
||||||
} else if (json.type === 'model_fallback') {
|
} else if (json.type === 'model_fallback') {
|
||||||
// Model went offline — switched to fallback
|
// Model went offline — switched to fallback
|
||||||
var _fbData = json.data || {};
|
var _fbData = json.data || {};
|
||||||
@@ -2049,10 +2071,11 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
|||||||
}
|
}
|
||||||
threadWrap.classList.add('streaming');
|
threadWrap.classList.add('streaming');
|
||||||
const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool;
|
const toolLabel = _toolLabels[json.tool.toLowerCase()] || json.tool;
|
||||||
|
const toolIcon = _toolIcons[json.tool.toLowerCase()] || '\u25B6';
|
||||||
const node = document.createElement('div')
|
const node = document.createElement('div')
|
||||||
node.className = 'agent-thread-node running';
|
node.className = 'agent-thread-node running';
|
||||||
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
|
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
|
||||||
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">\u25B6</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
node.innerHTML = `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${toolIcon}</span><span class="agent-thread-tool">${esc(toolLabel)}</span><span class="agent-thread-wave">▁▂▃</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
||||||
// Expand/collapse via delegated click handler (init at module bottom).
|
// Expand/collapse via delegated click handler (init at module bottom).
|
||||||
threadWrap.appendChild(node);
|
threadWrap.appendChild(node);
|
||||||
currentToolBubble = node;
|
currentToolBubble = node;
|
||||||
|
|||||||
@@ -862,6 +862,20 @@ export function stripToolBlocks(text) {
|
|||||||
return cleaned.trim();
|
return cleaned.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Plain-text payload for the message copy buttons: the reply as the renderer
|
||||||
|
* displays it — tool blocks and <think> reasoning stripped. dataset.raw keeps
|
||||||
|
* the full model output (chat.js even embeds the elapsed time into the
|
||||||
|
* <think> tag for reload persistence), so copying it verbatim leaks the
|
||||||
|
* thinking block (#3722). Falls back to the raw text when stripping leaves
|
||||||
|
* nothing (e.g. turns interrupted mid-thinking).
|
||||||
|
*/
|
||||||
|
export function copyMessageText(msgElement) {
|
||||||
|
const raw = msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '';
|
||||||
|
const { content } = markdownModule.extractThinkingBlocks(stripToolBlocks(raw));
|
||||||
|
return content || raw;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Build a collapsible sources box (used by both research and web search).
|
* Build a collapsible sources box (used by both research and web search).
|
||||||
*/
|
*/
|
||||||
@@ -1372,7 +1386,7 @@ export function createMsgFooter(msgElement) {
|
|||||||
{ id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) {
|
{ id: 'copy', icon: COPY_ICON, title: 'Copy message', cls: 'footer-copy-btn', html: true, handler(e) {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
const btn = e.currentTarget;
|
const btn = e.currentTarget;
|
||||||
uiModule.copyToClipboard(msgElement.dataset.raw || msgElement.querySelector('.body')?.textContent || '');
|
uiModule.copyToClipboard(copyMessageText(msgElement));
|
||||||
btn.innerHTML = CHECK_ICON;
|
btn.innerHTML = CHECK_ICON;
|
||||||
setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500);
|
setTimeout(() => { btn.innerHTML = COPY_ICON; }, 1500);
|
||||||
}},
|
}},
|
||||||
@@ -2444,6 +2458,7 @@ const chatRenderer = {
|
|||||||
updateSessionCostUI,
|
updateSessionCostUI,
|
||||||
roleTimestamp,
|
roleTimestamp,
|
||||||
stripToolBlocks,
|
stripToolBlocks,
|
||||||
|
copyMessageText,
|
||||||
safeToolScreenshotSrc,
|
safeToolScreenshotSrc,
|
||||||
safeDisplayImageSrc,
|
safeDisplayImageSrc,
|
||||||
buildSourcesBox,
|
buildSourcesBox,
|
||||||
|
|||||||
@@ -406,7 +406,7 @@ export const ERROR_PATTERNS = [
|
|||||||
{ label: 'Repair kernel package', action: () => {
|
{ label: 'Repair kernel package', action: () => {
|
||||||
const _vp = (_envState.env === 'venv' && _envState.envPath)
|
const _vp = (_envState.env === 'venv' && _envState.envPath)
|
||||||
? `${_envState.envPath.replace(/\/+$/, '')}/bin/python3` : 'python3';
|
? `${_envState.envPath.replace(/\/+$/, '')}/bin/python3` : 'python3';
|
||||||
_launchServeTask('repair-kernels', 'pip-update', `${_vp} -m pip install --user --break-system-packages kernels<0.15`);
|
_launchServeTask('repair-kernels', 'pip-update', `${_vp} -m pip install --user --break-system-packages "kernels<0.15"`);
|
||||||
}},
|
}},
|
||||||
{ label: 'Open Dependencies', action: () => _openCookbookDependencies('sglang') },
|
{ label: 'Open Dependencies', action: () => _openCookbookDependencies('sglang') },
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import chatRenderer from './chatRenderer.js';
|
|||||||
import spinnerModule from './spinner.js';
|
import spinnerModule from './spinner.js';
|
||||||
import themeModule from './theme.js';
|
import themeModule from './theme.js';
|
||||||
import documentModule from './document.js';
|
import documentModule from './document.js';
|
||||||
|
import workspaceModule from './workspace.js';
|
||||||
import settingsModule from './settings.js';
|
import settingsModule from './settings.js';
|
||||||
import cookbookModule from './cookbook.js';
|
import cookbookModule from './cookbook.js';
|
||||||
import { EVAL_PROMPTS } from './compare/index.js';
|
import { EVAL_PROMPTS } from './compare/index.js';
|
||||||
@@ -380,7 +381,7 @@ function _slashFooter(msgEl) {
|
|||||||
copyBtn.innerHTML = _copySvg;
|
copyBtn.innerHTML = _copySvg;
|
||||||
copyBtn.onclick = (e) => {
|
copyBtn.onclick = (e) => {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
uiModule.copyToClipboard(msgEl.dataset.raw || msgEl.querySelector('.body')?.textContent || '');
|
uiModule.copyToClipboard(chatRenderer.copyMessageText(msgEl));
|
||||||
copyBtn.innerHTML = _checkSvg;
|
copyBtn.innerHTML = _checkSvg;
|
||||||
setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500);
|
setTimeout(() => { copyBtn.innerHTML = _copySvg; }, 1500);
|
||||||
};
|
};
|
||||||
@@ -1229,6 +1230,40 @@ async function _cmdToggleDoc(args, ctx) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Workspace: confine the agent's file/shell tools to a folder. Not a boolean -
|
||||||
|
// show / set <path> / clear / pick (open the directory browser).
|
||||||
|
async function _cmdWorkspace(args, ctx) {
|
||||||
|
const sub = (args[0] || '').toLowerCase();
|
||||||
|
const rest = args.slice(1).join(' ').trim();
|
||||||
|
const cur = workspaceModule.getWorkspace();
|
||||||
|
if (!sub || sub === 'show' || sub === 'status' || sub === 'info') {
|
||||||
|
slashReply(cur ? `Workspace: <code>${uiModule.esc(cur)}</code>` : 'No workspace set. <code>/workspace pick</code> or <code>/workspace set /path</code>.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (sub === 'set' || sub === 'cd' || sub === 'use') {
|
||||||
|
if (!rest) { slashReply('Usage: <code>/workspace set /absolute/path</code>'); return true; }
|
||||||
|
// Validate server-side before persisting so the pill never claims a
|
||||||
|
// workspace the backend will refuse to bind (typo, file path, deleted
|
||||||
|
// folder, sensitive dir, filesystem root).
|
||||||
|
workspaceModule.vetAndSetWorkspace(rest).then(({ ok, path }) => {
|
||||||
|
if (ok) slashReply(`Workspace set: <code>${uiModule.esc(path)}</code>`);
|
||||||
|
else slashReply(`Not a usable workspace folder: <code>${uiModule.esc(rest)}</code>. It must be an existing directory, not a filesystem root or sensitive path.`);
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (sub === 'clear' || sub === 'off' || sub === 'none' || sub === 'unset') {
|
||||||
|
workspaceModule.clearWorkspace();
|
||||||
|
slashReply('Workspace cleared.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (sub === 'pick' || sub === 'browse' || sub === 'open') {
|
||||||
|
workspaceModule.openWorkspaceBrowser();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
slashReply('Usage: <code>/workspace</code> · <code>set /path</code> · <code>clear</code> · <code>pick</code>');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
async function _cmdToggleShow(args, ctx) {
|
async function _cmdToggleShow(args, ctx) {
|
||||||
const name = (args[0] || '').toLowerCase();
|
const name = (args[0] || '').toLowerCase();
|
||||||
const val = (args[1] || '').toLowerCase();
|
const val = (args[1] || '').toLowerCase();
|
||||||
@@ -5731,6 +5766,14 @@ const COMMANDS = {
|
|||||||
'_show': { handler: _cmdToggleShow, alias: [], help: 'Show all toggle states', usage: '/toggle' }
|
'_show': { handler: _cmdToggleShow, alias: [], help: 'Show all toggle states', usage: '/toggle' }
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
workspace: {
|
||||||
|
alias: ['ws'],
|
||||||
|
category: 'Agent',
|
||||||
|
help: 'Set the folder the agent works in',
|
||||||
|
handler: _cmdWorkspace,
|
||||||
|
noUserBubble: true,
|
||||||
|
usage: '/workspace [set <path> | clear | pick]',
|
||||||
|
},
|
||||||
memory: {
|
memory: {
|
||||||
alias: ['m'],
|
alias: ['m'],
|
||||||
category: 'Memory',
|
category: 'Memory',
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ export const KEYS = {
|
|||||||
MCP_ACTIVE: 'odysseus-mcp-active',
|
MCP_ACTIVE: 'odysseus-mcp-active',
|
||||||
SECTION_ORDER: 'sidebar-section-order',
|
SECTION_ORDER: 'sidebar-section-order',
|
||||||
ADMIN_LAST_TAB: 'admin-last-tab',
|
ADMIN_LAST_TAB: 'admin-last-tab',
|
||||||
DENSITY: 'odysseus-density'
|
DENSITY: 'odysseus-density',
|
||||||
|
WORKSPACE: 'odysseus-workspace'
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -0,0 +1,208 @@
|
|||||||
|
// static/js/workspace.js
|
||||||
|
//
|
||||||
|
// Workspace picker: browse server directories in a draggable modal, choose a
|
||||||
|
// folder, and show it as a removable pill in the chat input bar. While set, the
|
||||||
|
// chat request sends `workspace` so the agent's file/shell tools are confined
|
||||||
|
// to that folder (see routes/chat_routes.py + src/tool_execution.py).
|
||||||
|
|
||||||
|
import Storage, { KEYS } from './storage.js';
|
||||||
|
import uiModule from './ui.js';
|
||||||
|
import { makeWindowDraggable } from './windowDrag.js';
|
||||||
|
|
||||||
|
const API_BASE = window.location.origin;
|
||||||
|
// Same folder glyph as the overflow menu item + pill (not an emoji).
|
||||||
|
const _FOLDER_SVG = '<svg class="workspace-row-icon" width="15" height="15" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M3 7a2 2 0 0 1 2-2h4l2 2h8a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2z"/></svg>';
|
||||||
|
let _modal = null;
|
||||||
|
let _curPath = '';
|
||||||
|
|
||||||
|
export function getWorkspace() {
|
||||||
|
return Storage.get(KEYS.WORKSPACE, '') || '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function _basename(p) {
|
||||||
|
if (!p) return '';
|
||||||
|
// Handle both POSIX (/) and Windows (\) separators.
|
||||||
|
const parts = p.replace(/[\\/]+$/, '').split(/[\\/]/);
|
||||||
|
return parts[parts.length - 1] || p;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Workspace only applies to agent mode (it scopes the file/shell tools), so the
|
||||||
|
// pill + overflow entry are hidden in chat mode, like the bash toggle.
|
||||||
|
function _isChatMode() {
|
||||||
|
const b = document.getElementById('mode-chat-btn');
|
||||||
|
return !!(b && b.classList.contains('active'));
|
||||||
|
}
|
||||||
|
|
||||||
|
export function syncWorkspaceIndicator(path) {
|
||||||
|
const chat = _isChatMode();
|
||||||
|
const pill = document.getElementById('workspace-indicator-btn');
|
||||||
|
const name = document.getElementById('workspace-indicator-name');
|
||||||
|
const overflow = document.getElementById('overflow-workspace-btn');
|
||||||
|
if (pill) {
|
||||||
|
pill.style.display = (path && !chat) ? '' : 'none';
|
||||||
|
pill.classList.toggle('active', !!path);
|
||||||
|
if (path) pill.title = `Workspace: ${path}\nFile tools are confined here; shell commands start here but are not sandboxed and can reach outside it.\nClick to clear.`;
|
||||||
|
}
|
||||||
|
if (name) name.textContent = path ? _basename(path) : '';
|
||||||
|
if (overflow) {
|
||||||
|
overflow.style.display = chat ? 'none' : '';
|
||||||
|
overflow.classList.toggle('active', !!path);
|
||||||
|
}
|
||||||
|
// Recompute the "+" overflow dot (app.js owns updatePlusDot via this event).
|
||||||
|
try { document.dispatchEvent(new CustomEvent('overflow-state-change')); } catch (_) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Called by the agent/chat mode toggle so the pill + overflow entry follow mode.
|
||||||
|
export function applyMode(_mode) {
|
||||||
|
syncWorkspaceIndicator(getWorkspace());
|
||||||
|
}
|
||||||
|
|
||||||
|
export function setWorkspace(path) {
|
||||||
|
if (path) Storage.set(KEYS.WORKSPACE, path);
|
||||||
|
else Storage.remove(KEYS.WORKSPACE);
|
||||||
|
syncWorkspaceIndicator(path || '');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a manually entered path server-side, then persist the canonical
|
||||||
|
* form. Returns {ok, path|null}. Without this, a typo / file path / deleted
|
||||||
|
* folder / filesystem root would be stored and shown as active while the
|
||||||
|
* backend silently refuses to bind it on every send.
|
||||||
|
*/
|
||||||
|
export async function vetAndSetWorkspace(path) {
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_BASE}/api/workspace/vet?path=${encodeURIComponent(path)}`, { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) return { ok: false, path: null };
|
||||||
|
const data = await res.json();
|
||||||
|
if (data.ok && data.path) {
|
||||||
|
setWorkspace(data.path);
|
||||||
|
return { ok: true, path: data.path };
|
||||||
|
}
|
||||||
|
return { ok: false, path: null };
|
||||||
|
} catch (e) {
|
||||||
|
return { ok: false, path: null };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearWorkspace() {
|
||||||
|
setWorkspace('');
|
||||||
|
if (uiModule && uiModule.showToast) uiModule.showToast('Workspace cleared');
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _load(path) {
|
||||||
|
const url = `${API_BASE}/api/workspace/browse${path ? `?path=${encodeURIComponent(path)}` : ''}`;
|
||||||
|
const res = await fetch(url, { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) throw new Error(`browse failed: ${res.status}`);
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
function _render(data) {
|
||||||
|
_curPath = data.path;
|
||||||
|
const body = _modal.querySelector('#workspace-body');
|
||||||
|
const pathEl = _modal.querySelector('#workspace-cur-path');
|
||||||
|
if (pathEl) {
|
||||||
|
// Reflect the resolved (realpath) location back into the editable field.
|
||||||
|
pathEl.value = data.path;
|
||||||
|
pathEl.title = data.path;
|
||||||
|
}
|
||||||
|
let rows = '';
|
||||||
|
if (data.parent) {
|
||||||
|
rows += `<div class="workspace-row workspace-up" data-path="${encodeURIComponent(data.parent)}">↑ ..</div>`;
|
||||||
|
}
|
||||||
|
for (const d of data.dirs) {
|
||||||
|
// Backend supplies the full child path (os.path.join → cross-platform).
|
||||||
|
rows += `<div class="workspace-row" data-path="${encodeURIComponent(d.path)}">${_FOLDER_SVG}<span>${uiModule.esc(d.name)}</span></div>`;
|
||||||
|
}
|
||||||
|
if (data.truncated) {
|
||||||
|
rows += '<div class="workspace-empty">Too many folders to list. Type or paste a path above to jump in.</div>';
|
||||||
|
}
|
||||||
|
if (!data.dirs.length && !data.parent) rows = '<div class="workspace-empty">No subfolders</div>';
|
||||||
|
body.innerHTML = rows || '<div class="workspace-empty">No subfolders</div>';
|
||||||
|
body.querySelectorAll('.workspace-row').forEach((row) => {
|
||||||
|
row.addEventListener('click', () => _navigate(decodeURIComponent(row.dataset.path)));
|
||||||
|
});
|
||||||
|
// Filesystem roots (and sensitive dirs) can be browsed through but never
|
||||||
|
// bound as the workspace; the backend rejects them too.
|
||||||
|
const useBtn = _modal.querySelector('#workspace-use');
|
||||||
|
if (useBtn) {
|
||||||
|
useBtn.disabled = data.selectable === false;
|
||||||
|
useBtn.title = data.selectable === false ? 'This folder cannot be used as a workspace' : '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _navigate(path) {
|
||||||
|
try {
|
||||||
|
_render(await _load(path));
|
||||||
|
} catch (e) {
|
||||||
|
if (uiModule && uiModule.showError) uiModule.showError('Could not open folder');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function _getModal() {
|
||||||
|
if (_modal) return _modal;
|
||||||
|
_modal = document.createElement('div');
|
||||||
|
_modal.id = 'workspace-modal';
|
||||||
|
_modal.className = 'modal';
|
||||||
|
_modal.style.display = 'none';
|
||||||
|
_modal.innerHTML = `
|
||||||
|
<div class="modal-content">
|
||||||
|
<div class="modal-header">
|
||||||
|
<h4><svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-2px;margin-right:6px"><path d="M3 7a2 2 0 0 1 2-2h4l2 2h8a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2z"/></svg>Select workspace</h4>
|
||||||
|
<button class="close-btn" id="workspace-close" aria-label="Close">✖</button>
|
||||||
|
</div>
|
||||||
|
<input type="text" class="styled-prompt-input workspace-cur" id="workspace-cur-path"
|
||||||
|
spellcheck="false" autocomplete="off" autocapitalize="off" autocorrect="off"
|
||||||
|
placeholder="Type or paste a folder path, then press Enter" />
|
||||||
|
<p class="muted workspace-note">File tools are <strong>confined</strong> to this folder. Shell commands start here but are <strong>not sandboxed</strong> and can reach outside it. A workspace scopes the tools; it is not a security boundary.</p>
|
||||||
|
<div class="modal-body workspace-body" id="workspace-body"></div>
|
||||||
|
<div class="modal-footer workspace-footer">
|
||||||
|
<button type="button" class="confirm-btn confirm-btn-secondary" id="workspace-cancel">Cancel</button>
|
||||||
|
<button type="button" class="confirm-btn confirm-btn-primary" id="workspace-use">Use this folder</button>
|
||||||
|
</div>
|
||||||
|
</div>`;
|
||||||
|
document.body.appendChild(_modal);
|
||||||
|
_modal.querySelector('#workspace-close').addEventListener('click', closeWorkspaceBrowser);
|
||||||
|
_modal.querySelector('#workspace-cancel').addEventListener('click', closeWorkspaceBrowser);
|
||||||
|
// Editable path bar: Enter navigates to a typed/pasted folder.
|
||||||
|
_modal.querySelector('#workspace-cur-path').addEventListener('keydown', (e) => {
|
||||||
|
if (e.key === 'Enter') {
|
||||||
|
e.preventDefault();
|
||||||
|
const v = e.target.value.trim();
|
||||||
|
if (v) _navigate(v);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
_modal.querySelector('#workspace-use').addEventListener('click', () => {
|
||||||
|
setWorkspace(_curPath);
|
||||||
|
if (uiModule && uiModule.showToast) uiModule.showToast(`Workspace set: ${_basename(_curPath)}`);
|
||||||
|
closeWorkspaceBrowser();
|
||||||
|
});
|
||||||
|
const content = _modal.querySelector('.modal-content');
|
||||||
|
const header = _modal.querySelector('.modal-header');
|
||||||
|
if (content && header) makeWindowDraggable(_modal, { content, header });
|
||||||
|
return _modal;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function openWorkspaceBrowser() {
|
||||||
|
const modal = _getModal();
|
||||||
|
modal.style.display = 'flex';
|
||||||
|
try {
|
||||||
|
_render(await _load(getWorkspace() || ''));
|
||||||
|
} catch (e) {
|
||||||
|
if (uiModule && uiModule.showError) uiModule.showError('Could not browse folders');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function closeWorkspaceBrowser() {
|
||||||
|
if (_modal) _modal.style.display = 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
export function initWorkspace() {
|
||||||
|
// Restore persisted workspace into the pill on load.
|
||||||
|
syncWorkspaceIndicator(getWorkspace());
|
||||||
|
const overflow = document.getElementById('overflow-workspace-btn');
|
||||||
|
if (overflow) overflow.addEventListener('click', openWorkspaceBrowser);
|
||||||
|
const pill = document.getElementById('workspace-indicator-btn');
|
||||||
|
if (pill) pill.addEventListener('click', clearWorkspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default { initWorkspace, openWorkspaceBrowser, getWorkspace, setWorkspace, vetAndSetWorkspace, clearWorkspace, syncWorkspaceIndicator, applyMode };
|
||||||
@@ -36606,3 +36606,48 @@ body.theme-frosted .modal {
|
|||||||
the input beside it (.confirm-btn won't stretch on its own). */
|
the input beside it (.confirm-btn won't stretch on its own). */
|
||||||
.ask-user-other-send { flex-shrink: 0; white-space: nowrap; min-height: 39px; }
|
.ask-user-other-send { flex-shrink: 0; white-space: nowrap; min-height: 39px; }
|
||||||
.ask-user-other-send:disabled { opacity: 0.5; cursor: default; }
|
.ask-user-other-send:disabled { opacity: 0.5; cursor: default; }
|
||||||
|
|
||||||
|
/* ── Workspace picker ───────────────────────────────────────────── */
|
||||||
|
/* Layout (width/flex column/max-height) inherited from base .modal-content. */
|
||||||
|
/* Editable path/address bar: reuses .styled-prompt-input for border/bg/radius/
|
||||||
|
focus ring (set in the element's class list). Overrides only the deltas:
|
||||||
|
mono font, and full-bleed via flex stretch with no horizontal margin (the
|
||||||
|
modal-content's 10px padding is the gutter) instead of the base width:100%,
|
||||||
|
which overflowed against the overflow:auto scrollbar. */
|
||||||
|
.workspace-cur {
|
||||||
|
align-self: stretch;
|
||||||
|
width: auto;
|
||||||
|
min-width: 0;
|
||||||
|
margin: 4px 0 8px;
|
||||||
|
font-family: var(--mono, monospace);
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
|
/* flex/overflow inherited from base .modal-body; only the padding differs. */
|
||||||
|
.workspace-body { padding: 6px 0; }
|
||||||
|
.workspace-row {
|
||||||
|
padding: 7px 18px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 13px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
.workspace-row > span {
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
}
|
||||||
|
.workspace-row-icon { flex-shrink: 0; opacity: 0.75; }
|
||||||
|
.workspace-row:hover {
|
||||||
|
background: color-mix(in srgb, var(--border) 20%, transparent);
|
||||||
|
}
|
||||||
|
.workspace-up { opacity: 0.7; }
|
||||||
|
.workspace-empty { padding: 14px 18px; opacity: 0.5; font-size: 13px; }
|
||||||
|
.workspace-footer {
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-end;
|
||||||
|
gap: 8px;
|
||||||
|
padding: 10px 18px;
|
||||||
|
border-top: 1px solid var(--border);
|
||||||
|
}
|
||||||
|
.workspace-note { margin: 0 0 8px; font-size: 11px; line-height: 1.4; }
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ disturb marker registration or focused selection.
|
|||||||
|
|
||||||
Groups whose tests need no route/app setup and no real DB/session setup:
|
Groups whose tests need no route/app setup and no real DB/session setup:
|
||||||
|
|
||||||
1. **CLI / script tests** (`area_cli`, 27 files) - load `scripts/` entry
|
1. **CLI / script tests** (`area_cli`, 28 files) - load `scripts/` entry
|
||||||
points via `tests.helpers.cli_loader.load_script`; DB access is stubbed
|
points via `tests.helpers.cli_loader.load_script`; DB access is stubbed
|
||||||
with `tests.helpers.db_stubs` (`SessionLocal` is a plain stub attribute).
|
with `tests.helpers.db_stubs` (`SessionLocal` is a plain stub attribute).
|
||||||
No `TestClient`, no FastAPI app import, no SQLite files.
|
No `TestClient`, no FastAPI app import, no SQLite files.
|
||||||
@@ -59,7 +59,9 @@ Why this group over the alternatives:
|
|||||||
|
|
||||||
## Files included in the first move
|
## Files included in the first move
|
||||||
|
|
||||||
The 27 files classified `area_cli` (verified against `_taxonomy.py`):
|
The 28 files classified `area_cli` (verified against `_taxonomy.py`):
|
||||||
|
|
||||||
|
Note: this inventory was refreshed against current `dev` after `tests/test_research_cli_status.py` was added to the `area_cli` set.
|
||||||
|
|
||||||
- `tests/test_calendar_cli_name.py`
|
- `tests/test_calendar_cli_name.py`
|
||||||
- `tests/test_contacts_cli_rows.py`
|
- `tests/test_contacts_cli_rows.py`
|
||||||
@@ -80,6 +82,7 @@ The 27 files classified `area_cli` (verified against `_taxonomy.py`):
|
|||||||
- `tests/test_preset_cli_store.py`
|
- `tests/test_preset_cli_store.py`
|
||||||
- `tests/test_research_cli_preview.py`
|
- `tests/test_research_cli_preview.py`
|
||||||
- `tests/test_research_cli_status_filter.py`
|
- `tests/test_research_cli_status_filter.py`
|
||||||
|
- `tests/test_research_cli_status.py`
|
||||||
- `tests/test_research_cli_store.py`
|
- `tests/test_research_cli_store.py`
|
||||||
- `tests/test_sessions_cli.py`
|
- `tests/test_sessions_cli.py`
|
||||||
- `tests/test_signature_cli_export.py`
|
- `tests/test_signature_cli_export.py`
|
||||||
@@ -115,7 +118,7 @@ Read-only checks, run from the repo root on this branch. Note the real API is
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Compute the area_cli set and confirm test_backup_cli_security.py is
|
# Compute the area_cli set and confirm test_backup_cli_security.py is
|
||||||
# area_security. Expected: 27 files, then "security".
|
# area_security. Expected: 28 files, then "security".
|
||||||
.venv/bin/python - <<'PY'
|
.venv/bin/python - <<'PY'
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tests._taxonomy import classify_test_path
|
from tests._taxonomy import classify_test_path
|
||||||
@@ -151,7 +154,7 @@ rg -n -F -f /tmp/area_cli_paths.txt .github scripts docs \
|
|||||||
Also checked by reading the code: `tests/conftest.py` registers sub-markers
|
Also checked by reading the code: `tests/conftest.py` registers sub-markers
|
||||||
from a recursive `rglob` scan, and `tests/_taxonomy.py` classifies by filename
|
from a recursive `rglob` scan, and `tests/_taxonomy.py` classifies by filename
|
||||||
tokens only (plus the `tests/helpers/` directory rule), so the markers of the
|
tokens only (plus the `tests/helpers/` directory rule), so the markers of the
|
||||||
27 files do not change when they move into `tests/cli/`.
|
28 files do not change when they move into `tests/cli/`.
|
||||||
|
|
||||||
## Validation for the future move PR
|
## Validation for the future move PR
|
||||||
|
|
||||||
@@ -159,7 +162,7 @@ Run with the project venv (`.venv/bin/python`); system `python3` may miss
|
|||||||
pinned deps. Before the move, record the baseline; after, compare:
|
pinned deps. Before the move, record the baseline; after, compare:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Selection must match the 27 files before and after the move.
|
# Selection must match the 28 files before and after the move.
|
||||||
.venv/bin/python tests/run_focus.py --dry-run --area cli
|
.venv/bin/python tests/run_focus.py --dry-run --area cli
|
||||||
.venv/bin/python -m pytest -m area_cli -q
|
.venv/bin/python -m pytest -m area_cli -q
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
"""Tool-output display truncation uses _truncate with an indicator.
|
||||||
|
|
||||||
|
Previously agent_loop sliced tool output to a hard character limit ([:2000]
|
||||||
|
or [:4000]) with no signal to the UI that data was lost. Now it delegates to
|
||||||
|
tool_utils._truncate which caps at MAX_OUTPUT_CHARS (10 000) and appends
|
||||||
|
a ``... (truncated, N chars total)`` suffix so the frontend can show a
|
||||||
|
truncation indicator in the tool bubble.
|
||||||
|
"""
|
||||||
|
from src.tool_utils import _truncate, MAX_OUTPUT_CHARS
|
||||||
|
|
||||||
|
|
||||||
|
def test_short_output_unchanged():
|
||||||
|
"""Outputs within the limit pass through verbatim."""
|
||||||
|
text = "hello world"
|
||||||
|
assert _truncate(text) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_long_output_truncated_with_indicator():
|
||||||
|
"""Outputs exceeding MAX_OUTPUT_CHARS are truncated with a suffix."""
|
||||||
|
text = "x" * (MAX_OUTPUT_CHARS + 500)
|
||||||
|
result = _truncate(text)
|
||||||
|
assert len(result) > MAX_OUTPUT_CHARS # includes suffix
|
||||||
|
assert result.startswith("x" * MAX_OUTPUT_CHARS)
|
||||||
|
assert "truncated" in result
|
||||||
|
assert str(len(text)) in result # original length reported
|
||||||
|
|
||||||
|
|
||||||
|
def test_exact_limit_unchanged():
|
||||||
|
"""An output exactly at the limit is not truncated."""
|
||||||
|
text = "a" * MAX_OUTPUT_CHARS
|
||||||
|
assert _truncate(text) == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_limit_matches_constant():
|
||||||
|
"""_truncate default limit equals MAX_OUTPUT_CHARS (10 000)."""
|
||||||
|
assert MAX_OUTPUT_CHARS == 10_000
|
||||||
|
text = "y" * 10_001
|
||||||
|
result = _truncate(text)
|
||||||
|
assert "truncated" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_string():
|
||||||
|
assert _truncate("") == ""
|
||||||
@@ -287,8 +287,9 @@ def test_delete_token_deletes_and_invalidates_cache(monkeypatch, token_routes_mo
|
|||||||
monkeypatch.setattr(mod, "get_current_user", lambda req: req.state.current_user)
|
monkeypatch.setattr(mod, "get_current_user", lambda req: req.state.current_user)
|
||||||
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
||||||
|
|
||||||
|
fake_token = SimpleNamespace(id="abcd1234", owner="alice", name="test")
|
||||||
fake_session = MagicMock()
|
fake_session = MagicMock()
|
||||||
fake_session.query.return_value.filter.return_value.delete.return_value = 1
|
fake_session.query.return_value.filter.return_value.first.return_value = fake_token
|
||||||
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
invalidator = MagicMock()
|
invalidator = MagicMock()
|
||||||
@@ -297,6 +298,7 @@ def test_delete_token_deletes_and_invalidates_cache(monkeypatch, token_routes_mo
|
|||||||
resp = delete_token(request=req, token_id="abcd1234")
|
resp = delete_token(request=req, token_id="abcd1234")
|
||||||
|
|
||||||
assert resp == {"status": "deleted"}
|
assert resp == {"status": "deleted"}
|
||||||
|
fake_session.delete.assert_called_once_with(fake_token)
|
||||||
invalidator.assert_called_once()
|
invalidator.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@@ -312,7 +314,7 @@ def test_delete_missing_token_returns_404_without_invalidating_cache(monkeypatch
|
|||||||
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
||||||
|
|
||||||
fake_session = MagicMock()
|
fake_session = MagicMock()
|
||||||
fake_session.query.return_value.filter.return_value.delete.return_value = 0
|
fake_session.query.return_value.filter.return_value.first.return_value = None
|
||||||
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
invalidator = MagicMock()
|
invalidator = MagicMock()
|
||||||
@@ -404,3 +406,99 @@ def test_update_missing_token_returns_404(monkeypatch, token_routes_mod):
|
|||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
asyncio.run(update_token(request=req, token_id="missing99"))
|
asyncio.run(update_token(request=req, token_id="missing99"))
|
||||||
assert exc.value.status_code == 404
|
assert exc.value.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 7. Owner check — update/delete reject a different admin's token with 403
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _bob_patch_request(invalidator, body):
|
||||||
|
"""An admin request from bob whose async .json() yields `body`."""
|
||||||
|
req = _req("bob", is_admin=True, invalidator=invalidator)
|
||||||
|
|
||||||
|
async def _json():
|
||||||
|
return body
|
||||||
|
|
||||||
|
req.json = _json
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_token_rejects_non_owner(monkeypatch, token_routes_mod):
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "true")
|
||||||
|
mod = token_routes_mod
|
||||||
|
monkeypatch.setattr(mod, "get_current_user", lambda req: req.state.current_user)
|
||||||
|
|
||||||
|
token = SimpleNamespace(
|
||||||
|
id="tok123", name="alice-token", owner="alice",
|
||||||
|
token_prefix="ody_alic", scopes="chat", is_active=True,
|
||||||
|
)
|
||||||
|
fake_session = MagicMock()
|
||||||
|
fake_session.query.return_value.filter.return_value.first.return_value = token
|
||||||
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
|
req = _bob_patch_request(MagicMock(), {"name": "hijacked"})
|
||||||
|
update_token = _get_handler(mod, "PATCH", "/tokens/{token_id}")
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
asyncio.run(update_token(request=req, token_id="tok123"))
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
assert token.name == "alice-token"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_token_rejects_non_owner(monkeypatch, token_routes_mod):
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "true")
|
||||||
|
mod = token_routes_mod
|
||||||
|
monkeypatch.setattr(mod, "get_current_user", lambda req: req.state.current_user)
|
||||||
|
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
||||||
|
|
||||||
|
fake_token = SimpleNamespace(id="tok123", owner="alice", name="alice-token")
|
||||||
|
fake_session = MagicMock()
|
||||||
|
fake_session.query.return_value.filter.return_value.first.return_value = fake_token
|
||||||
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
|
invalidator = MagicMock()
|
||||||
|
req = _req("bob", is_admin=True, invalidator=invalidator)
|
||||||
|
delete_token = _get_handler(mod, "DELETE", "/tokens/{token_id}")
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
delete_token(request=req, token_id="tok123")
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
fake_session.delete.assert_not_called()
|
||||||
|
invalidator.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_token_owner_check_skipped_when_auth_disabled(monkeypatch, token_routes_mod):
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||||
|
mod = token_routes_mod
|
||||||
|
monkeypatch.setattr(mod, "get_current_user", lambda req: None)
|
||||||
|
|
||||||
|
token = SimpleNamespace(
|
||||||
|
id="tok123", name="original", owner="alice",
|
||||||
|
token_prefix="ody_alic", scopes="chat", is_active=True,
|
||||||
|
)
|
||||||
|
fake_session = MagicMock()
|
||||||
|
fake_session.query.return_value.filter.return_value.first.return_value = token
|
||||||
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
|
req = _bob_patch_request(MagicMock(), {"name": "renamed-in-single-user"})
|
||||||
|
update_token = _get_handler(mod, "PATCH", "/tokens/{token_id}")
|
||||||
|
resp = asyncio.run(update_token(request=req, token_id="tok123"))
|
||||||
|
assert resp["name"] == "renamed-in-single-user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_token_owner_check_skipped_when_auth_disabled(monkeypatch, token_routes_mod):
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||||
|
mod = token_routes_mod
|
||||||
|
monkeypatch.setattr(mod, "get_current_user", lambda req: None)
|
||||||
|
monkeypatch.setattr(mod, "ApiToken", MagicMock())
|
||||||
|
|
||||||
|
fake_token = SimpleNamespace(id="tok123", owner="alice", name="alice-token")
|
||||||
|
fake_session = MagicMock()
|
||||||
|
fake_session.query.return_value.filter.return_value.first.return_value = fake_token
|
||||||
|
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||||
|
|
||||||
|
invalidator = MagicMock()
|
||||||
|
req = _req("", is_admin=True, invalidator=invalidator)
|
||||||
|
delete_token = _get_handler(mod, "DELETE", "/tokens/{token_id}")
|
||||||
|
resp = delete_token(request=req, token_id="tok123")
|
||||||
|
assert resp == {"status": "deleted"}
|
||||||
|
fake_session.delete.assert_called_once_with(fake_token)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ with missing users or assertion errors.
|
|||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import contextlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -15,6 +18,41 @@ import pytest
|
|||||||
from tests.helpers.import_state import clear_module
|
from tests.helpers.import_state import clear_module
|
||||||
|
|
||||||
|
|
||||||
|
class _OwnerColumn:
|
||||||
|
def __eq__(self, other):
|
||||||
|
return ("owner ==", other)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeApiToken:
|
||||||
|
owner = _OwnerColumn()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQuery:
|
||||||
|
def filter(self, *_conds):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def delete(self, *args, **kwargs):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def query(self, model):
|
||||||
|
assert model is _FakeApiToken
|
||||||
|
return _FakeQuery()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _stub_api_token_purge(monkeypatch):
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _fake_db_session():
|
||||||
|
yield _FakeSession()
|
||||||
|
|
||||||
|
db_stub = types.ModuleType("core.database")
|
||||||
|
db_stub.get_db_session = _fake_db_session
|
||||||
|
db_stub.ApiToken = _FakeApiToken
|
||||||
|
monkeypatch.setitem(sys.modules, "core.database", db_stub)
|
||||||
|
|
||||||
|
|
||||||
def _fresh_auth_manager(tmp_path):
|
def _fresh_auth_manager(tmp_path):
|
||||||
clear_module("core.auth")
|
clear_module("core.auth")
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
|
|||||||
@@ -106,6 +106,9 @@ async def test_learn_sender_signatures_resolves_llm_for_task_owner(monkeypatch):
|
|||||||
from src.builtin_actions import action_learn_sender_signatures
|
from src.builtin_actions import action_learn_sender_signatures
|
||||||
|
|
||||||
class FakeImap:
|
class FakeImap:
|
||||||
|
def __init__(self, owner=""):
|
||||||
|
self.owner = owner
|
||||||
|
|
||||||
def select(self, *_args, **_kwargs):
|
def select(self, *_args, **_kwargs):
|
||||||
return "OK", []
|
return "OK", []
|
||||||
|
|
||||||
@@ -119,13 +122,20 @@ async def test_learn_sender_signatures_resolves_llm_for_task_owner(monkeypatch):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
calls, _fallback_calls = _resolver_spy(monkeypatch, utility_result=("", "", {}), default_result=("", "", {}))
|
calls, _fallback_calls = _resolver_spy(monkeypatch, utility_result=("", "", {}), default_result=("", "", {}))
|
||||||
monkeypatch.setattr(email_helpers, "_imap_connect", lambda _account_id=None: FakeImap())
|
imap_owners = []
|
||||||
|
|
||||||
|
def fake_imap_connect(_account_id=None, owner=""):
|
||||||
|
imap_owners.append(owner)
|
||||||
|
return FakeImap(owner)
|
||||||
|
|
||||||
|
monkeypatch.setattr(email_helpers, "_imap_connect", fake_imap_connect)
|
||||||
|
|
||||||
message, ok = await action_learn_sender_signatures("alice")
|
message, ok = await action_learn_sender_signatures("alice")
|
||||||
|
|
||||||
assert ok is False
|
assert ok is False
|
||||||
assert message == "No LLM endpoint available"
|
assert message == "No LLM endpoint available"
|
||||||
assert calls == [("utility", "alice"), ("default", "alice")]
|
assert calls == [("utility", "alice"), ("default", "alice")]
|
||||||
|
assert imap_owners == ["alice"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
"""llama.cpp slot-affinity fields must never reach cloud providers (#3793).
|
||||||
|
|
||||||
|
_apply_local_cache_affinity adds session_id + cache_prompt to outgoing
|
||||||
|
payloads for KV-cache slot affinity (#2927). The old gate treated any unknown
|
||||||
|
OpenAI-compatible host as self-hosted, so strict cloud APIs added as custom
|
||||||
|
endpoints (Mistral at api.mistral.ai) received the extra fields and rejected
|
||||||
|
every request with 422 extra_forbidden. Self-hosted now also requires the
|
||||||
|
endpoint to resolve as local: loopback/private/tailscale host, or endpoint
|
||||||
|
kind explicitly configured as "local".
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import src.llm_core as llm_core
|
||||||
|
import src.model_context as model_context
|
||||||
|
|
||||||
|
|
||||||
|
def _affinity_fields(url, monkeypatch, kind=None):
|
||||||
|
monkeypatch.setattr(model_context, "_configured_endpoint_kind", lambda _u: kind)
|
||||||
|
payload = {}
|
||||||
|
llm_core._apply_local_cache_affinity(payload, url, "sess-123")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_cloud_api_gets_no_affinity_fields(monkeypatch):
|
||||||
|
# The #3793 repro: Mistral rejects unknown body fields with 422.
|
||||||
|
payload = _affinity_fields("https://api.mistral.ai/v1", monkeypatch)
|
||||||
|
assert payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_api_gets_no_affinity_fields(monkeypatch):
|
||||||
|
payload = _affinity_fields("https://api.openai.com/v1", monkeypatch)
|
||||||
|
assert payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_public_host_gets_no_affinity_fields(monkeypatch):
|
||||||
|
# Any strict cloud provider added as a custom endpoint, not just Mistral.
|
||||||
|
payload = _affinity_fields("https://llm.example-cloud.com/v1", monkeypatch)
|
||||||
|
assert payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_localhost_server_gets_affinity_fields(monkeypatch):
|
||||||
|
payload = _affinity_fields("http://localhost:8080/v1", monkeypatch)
|
||||||
|
assert payload == {"session_id": "sess-123", "cache_prompt": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_private_lan_server_gets_affinity_fields(monkeypatch):
|
||||||
|
payload = _affinity_fields("http://192.168.1.50:8000/v1", monkeypatch)
|
||||||
|
assert payload == {"session_id": "sess-123", "cache_prompt": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_host_with_local_kind_override_gets_affinity_fields(monkeypatch):
|
||||||
|
# Escape hatch: a self-hosted llama.cpp exposed via a tunnel keeps the
|
||||||
|
# slot-affinity hint when its endpoint kind is configured as "local".
|
||||||
|
payload = _affinity_fields("https://my-llama.example.com/v1", monkeypatch, kind="local")
|
||||||
|
assert payload == {"session_id": "sess-123", "cache_prompt": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_session_id_is_a_noop(monkeypatch):
|
||||||
|
monkeypatch.setattr(model_context, "_configured_endpoint_kind", lambda _u: None)
|
||||||
|
payload = {}
|
||||||
|
llm_core._apply_local_cache_affinity(payload, "http://localhost:8080/v1", None)
|
||||||
|
assert payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
# Cloud-host sweep absorbed from #3839 (credit: Shabablinchikow) - every cloud
|
||||||
|
# API that falls through provider detection to the OpenAI-compatible default
|
||||||
|
# must stay clean, not just the Mistral host from the original report.
|
||||||
|
@pytest.mark.parametrize("url", [
|
||||||
|
"https://api.mistral.ai/v1/chat/completions",
|
||||||
|
"https://api.deepseek.com/v1/chat/completions",
|
||||||
|
"https://api.x.ai/v1/chat/completions",
|
||||||
|
"https://api.together.xyz/v1/chat/completions",
|
||||||
|
"https://api.fireworks.ai/inference/v1/chat/completions",
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
|
||||||
|
])
|
||||||
|
def test_cloud_openai_compatible_hosts_get_no_affinity_fields(monkeypatch, url):
|
||||||
|
assert _affinity_fields(url, monkeypatch) == {}
|
||||||
|
|
||||||
|
|
||||||
|
# Tailscale CGNAT boundaries (review finding on #3945): only 100.64.0.0/10 is
|
||||||
|
# Tailscale; the rest of 100.0.0.0/8 contains public ranges, and a strict
|
||||||
|
# provider addressed by one must not receive the llama.cpp extras.
|
||||||
|
def test_host_just_below_cgnat_gets_no_affinity_fields(monkeypatch):
|
||||||
|
assert _affinity_fields("http://100.63.255.255/v1", monkeypatch) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_just_above_cgnat_gets_no_affinity_fields(monkeypatch):
|
||||||
|
assert _affinity_fields("http://100.128.0.1/v1", monkeypatch) == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("host", ["100.64.0.1", "100.100.50.2", "100.127.255.254"])
|
||||||
|
def test_hosts_inside_cgnat_get_affinity_fields(monkeypatch, host):
|
||||||
|
payload = _affinity_fields(f"http://{host}:8080/v1", monkeypatch)
|
||||||
|
assert payload == {"session_id": "sess-123", "cache_prompt": True}
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
"""Test that do_manage_calendar handles the batch {"events": [...]} format
|
||||||
|
that models like deepseek-v4-flash emit instead of individual create_event calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.helpers.import_state import clear_fake_database_modules
|
||||||
|
from tests.helpers.sqlite_db import make_temp_sqlite
|
||||||
|
|
||||||
|
clear_fake_database_modules()
|
||||||
|
|
||||||
|
import core.database as cdb
|
||||||
|
from core.database import CalendarEvent
|
||||||
|
|
||||||
|
_TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _bind_temp_db(monkeypatch):
|
||||||
|
monkeypatch.setitem(sys.modules, "core.database", cdb)
|
||||||
|
parent = sys.modules.get("core")
|
||||||
|
if parent is not None:
|
||||||
|
monkeypatch.setattr(parent, "database", cdb, raising=False)
|
||||||
|
monkeypatch.setattr(cdb, "SessionLocal", _TS)
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_events_with_datetime_objects():
|
||||||
|
"""Model emits {"events": [{"summary": ..., "start": {"dateTime": ...}, "end": {"dateTime": ...}}]}."""
|
||||||
|
from src.tool_implementations import do_manage_calendar
|
||||||
|
|
||||||
|
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||||
|
payload = {
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"summary": "Morning Gym",
|
||||||
|
"start": {"dateTime": "2026-06-09T06:00:00+05:30"},
|
||||||
|
"end": {"dateTime": "2026-06-09T07:00:00+05:30"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Morning Gym",
|
||||||
|
"start": {"dateTime": "2026-06-10T06:00:00+05:30"},
|
||||||
|
"end": {"dateTime": "2026-06-10T07:00:00+05:30"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||||
|
assert res.get("exit_code") == 0, res
|
||||||
|
assert "Created 2 event(s)" in res.get("response", "")
|
||||||
|
|
||||||
|
# Verify events exist in DB
|
||||||
|
db = _TS()
|
||||||
|
events = db.query(CalendarEvent).filter(CalendarEvent.summary == "Morning Gym").all()
|
||||||
|
assert len(events) == 2
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_events_with_flat_strings():
|
||||||
|
"""Model emits {"events": [{"summary": ..., "start": "ISO", "end": "ISO"}]}."""
|
||||||
|
from src.tool_implementations import do_manage_calendar
|
||||||
|
|
||||||
|
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||||
|
payload = {
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"summary": "Standup",
|
||||||
|
"start": "2026-06-09T09:00:00",
|
||||||
|
"end": "2026-06-09T09:30:00",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||||
|
assert res.get("exit_code") == 0, res
|
||||||
|
assert "Created 1 event(s)" in res.get("response", "")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_batch_events_partial_failure():
|
||||||
|
"""Batch with some valid and some invalid events — should surface both counts and first error."""
|
||||||
|
from src.tool_implementations import do_manage_calendar
|
||||||
|
|
||||||
|
owner = "tester-" + uuid.uuid4().hex[:6]
|
||||||
|
payload = {
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"summary": "Valid Event 1",
|
||||||
|
"start": "2026-06-09T10:00:00",
|
||||||
|
"end": "2026-06-09T11:00:00",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Invalid Event",
|
||||||
|
# Missing required dtstart — will fail
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"summary": "Valid Event 2",
|
||||||
|
"start": "2026-06-09T14:00:00",
|
||||||
|
"end": "2026-06-09T15:00:00",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
res = await do_manage_calendar(json.dumps(payload), owner=owner)
|
||||||
|
|
||||||
|
# Partial failure = non-zero exit code
|
||||||
|
assert res.get("exit_code") != 0, "Partial failure should return non-zero exit code"
|
||||||
|
|
||||||
|
# Response should mention both created and failed counts
|
||||||
|
response = res.get("response", "")
|
||||||
|
assert "Created 2 event(s)" in response, f"Should report 2 created: {response}"
|
||||||
|
assert "Failed to create 1 event(s)" in response, f"Should report 1 failed: {response}"
|
||||||
|
assert "error" in response.lower() or "required" in response.lower(), "Should include error details"
|
||||||
|
|
||||||
|
# Metadata fields
|
||||||
|
assert res.get("created_count") == 2
|
||||||
|
assert res.get("failed_count") == 1
|
||||||
|
|
||||||
|
# Verify only valid events were created
|
||||||
|
db = _TS()
|
||||||
|
events = db.query(CalendarEvent).filter(
|
||||||
|
CalendarEvent.summary.in_(["Valid Event 1", "Valid Event 2"])
|
||||||
|
).all()
|
||||||
|
assert len(events) == 2
|
||||||
|
db.close()
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""classify_events must read the Memory `text` column, not a non-existent
|
||||||
|
`content` attribute.
|
||||||
|
|
||||||
|
The previous inline loop did `m.content`, which raised AttributeError on the
|
||||||
|
first Memory row; the surrounding except swallowed it, so the personal-context
|
||||||
|
block the LLM relies on was always empty. The logic now lives in
|
||||||
|
`_memory_context_lines`, which reads `text`.
|
||||||
|
"""
|
||||||
|
from src.builtin_actions import _memory_context_lines
|
||||||
|
|
||||||
|
|
||||||
|
class _Mem:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
|
||||||
|
|
||||||
|
def test_uses_text_and_truncates_and_skips_blank():
|
||||||
|
lines = _memory_context_lines([_Mem("Alice is my spouse"), _Mem(" "), _Mem("y" * 250)])
|
||||||
|
assert lines[0] == "- Alice is my spouse"
|
||||||
|
assert len(lines) == 2 # the blank row is skipped
|
||||||
|
assert lines[1] == "- " + "y" * 200 # truncated to 200 chars
|
||||||
|
|
||||||
|
|
||||||
|
def test_skips_rows_without_text_attribute():
|
||||||
|
class _Bad: # mimics a schema where the attribute is absent
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert _memory_context_lines([_Bad(), _Mem("ok")]) == ["- ok"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_respects_limit():
|
||||||
|
mems = [_Mem(f"memory {i}") for i in range(50)]
|
||||||
|
assert len(_memory_context_lines(mems, limit=40)) == 40
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
"""POST /api/contacts/import must not 500 on a non-string vcf/text/csv value.
|
||||||
|
|
||||||
|
`text = data.get("vcf") or ... or ""` left a non-string value (e.g. a number)
|
||||||
|
in place, so the next `text.strip()` raised AttributeError -> HTTP 500. The
|
||||||
|
handler now coerces with str() and degrades to a structured "no data" response.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from routes.contacts_routes import setup_contacts_routes
|
||||||
|
|
||||||
|
|
||||||
|
def _import_handler():
|
||||||
|
router = setup_contacts_routes()
|
||||||
|
for route in router.routes:
|
||||||
|
if getattr(route, "path", "").endswith("/import") and "POST" in getattr(route, "methods", set()):
|
||||||
|
return route.endpoint
|
||||||
|
raise AssertionError("import route not found")
|
||||||
|
|
||||||
|
|
||||||
|
def _call(data):
|
||||||
|
handler = _import_handler()
|
||||||
|
return asyncio.run(handler(data=data, _admin="admin"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_string_vcf_degrades_cleanly():
|
||||||
|
resp = _call({"vcf": 123})
|
||||||
|
assert resp["success"] is False
|
||||||
|
assert "error" in resp
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_string_csv_degrades_cleanly():
|
||||||
|
resp = _call({"csv": ["a", "b"]})
|
||||||
|
assert resp["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_body_reports_no_data():
|
||||||
|
resp = _call({})
|
||||||
|
assert resp["success"] is False
|
||||||
|
assert resp["error"] == "No contact data found"
|
||||||
@@ -11,7 +11,7 @@ import src.model_context as mc
|
|||||||
|
|
||||||
def _setup(monkeypatch, windows):
|
def _setup(monkeypatch, windows):
|
||||||
"""windows: {endpoint_url: context_length}. Force the remote path."""
|
"""windows: {endpoint_url: context_length}. Force the remote path."""
|
||||||
monkeypatch.setattr(mc, "_is_local_endpoint", lambda url: False)
|
monkeypatch.setattr(mc, "is_local_endpoint", lambda url: False)
|
||||||
monkeypatch.setattr(mc, "_configured_endpoint_kind", lambda url: "api")
|
monkeypatch.setattr(mc, "_configured_endpoint_kind", lambda url: "api")
|
||||||
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: windows[url])
|
monkeypatch.setattr(mc, "_query_context_length", lambda url, model: windows[url])
|
||||||
mc._context_cache.clear()
|
mc._context_cache.clear()
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
DIAGNOSIS_JS = ROOT / "static" / "js" / "cookbook-diagnosis.js"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repair_kernels_pip_spec_is_shell_quoted():
|
||||||
|
source = DIAGNOSIS_JS.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '"kernels<0.15"' in source
|
||||||
|
assert " --break-system-packages kernels<0.15" not in source
|
||||||
@@ -26,7 +26,6 @@ from routes.cookbook_helpers import (
|
|||||||
_validate_repo_id,
|
_validate_repo_id,
|
||||||
_validate_serve_cmd,
|
_validate_serve_cmd,
|
||||||
_validate_serve_model_id,
|
_validate_serve_model_id,
|
||||||
_validate_ssh_port,
|
|
||||||
_shell_path,
|
_shell_path,
|
||||||
run_ssh_command_async,
|
run_ssh_command_async,
|
||||||
)
|
)
|
||||||
@@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_validate_ssh_port_rejects_shell_payload():
|
|
||||||
with pytest.raises(HTTPException):
|
|
||||||
_validate_ssh_port("22; touch /tmp/pwned")
|
|
||||||
assert _validate_ssh_port("2222") == "2222"
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
|
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
|
||||||
path = "/Volumes/T7 2TB/AI Models/llamacpp"
|
path = "/Volumes/T7 2TB/AI Models/llamacpp"
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
"""Regression coverage for issue #3722 — the message copy button copied the
|
||||||
|
full raw model output (``dataset.raw``), which still contains the
|
||||||
|
``<think time="...">...</think>`` reasoning block that the renderer strips for
|
||||||
|
display. Pasting therefore leaked the model's thinking, and the first heading
|
||||||
|
after ``</think>`` lost its markdown formatting because it was glued to the
|
||||||
|
closing tag.
|
||||||
|
|
||||||
|
The fix adds chatRenderer.copyMessageText(), which mirrors the display
|
||||||
|
pipeline (``stripToolBlocks()`` then ``extractThinkingBlocks()``), and routes
|
||||||
|
both AI-message copy buttons (createMsgFooter and the slash-reply footer)
|
||||||
|
through it. extractThinkingBlocks() behavior is pinned here under node
|
||||||
|
(including on the payload from the issue report); the helper and handler
|
||||||
|
wiring are guarded at the source level because chatRenderer.js pulls in
|
||||||
|
browser globals and can't be imported under node (same approach as
|
||||||
|
test_new_chat_clears_input.py).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_HAS_NODE = shutil.which("node") is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def node_available():
|
||||||
|
if not _HAS_NODE:
|
||||||
|
pytest.skip("node binary not on PATH")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_thinking_blocks(text: str) -> dict:
|
||||||
|
"""Run markdown.js extractThinkingBlocks(text) under node."""
|
||||||
|
script = textwrap.dedent(
|
||||||
|
r"""
|
||||||
|
import fs from 'node:fs';
|
||||||
|
|
||||||
|
globalThis.window = { location: { origin: 'http://localhost' }, katex: null };
|
||||||
|
globalThis.document = {
|
||||||
|
readyState: 'loading',
|
||||||
|
addEventListener() {},
|
||||||
|
createElement(tag) {
|
||||||
|
if (tag !== 'template') throw new Error(`unsupported element: ${tag}`);
|
||||||
|
return {
|
||||||
|
_html: '',
|
||||||
|
content: { querySelectorAll() { return []; } },
|
||||||
|
set innerHTML(value) { this._html = value; },
|
||||||
|
get innerHTML() { return this._html; },
|
||||||
|
};
|
||||||
|
},
|
||||||
|
};
|
||||||
|
globalThis.MutationObserver = class { observe() {} };
|
||||||
|
|
||||||
|
let source = fs.readFileSync('./static/js/markdown.js', 'utf8');
|
||||||
|
source = source.replace(
|
||||||
|
/import uiModule from ['"]\.\/ui\.js['"];/,
|
||||||
|
''
|
||||||
|
);
|
||||||
|
source = source.replace(
|
||||||
|
/import \{ splitTableRow \} from ['"]\.\/markdown\/tableRow\.js['"];/,
|
||||||
|
`function splitTableRow(row) {
|
||||||
|
return (row || '').replace(/^\\s*\\|/, '').replace(/\\|\\s*$/, '').split('|').map(c => c.trim());
|
||||||
|
}`
|
||||||
|
);
|
||||||
|
const emojiSource = fs.readFileSync('./static/js/emojiShortcodes.js', 'utf8')
|
||||||
|
.replace(/^export default .*$/m, '')
|
||||||
|
.replace(/export const /g, 'const ')
|
||||||
|
.replace(/export function /g, 'function ');
|
||||||
|
source = source.replace(
|
||||||
|
/import \{ replaceEmojiShortcodes, hasEmojiShortcode \} from ['"]\.\/emojiShortcodes\.js['"];/,
|
||||||
|
() => emojiSource
|
||||||
|
);
|
||||||
|
source = source.replace(
|
||||||
|
/var escapeHtml = uiModule\.esc;/,
|
||||||
|
`var escapeHtml = (value) => String(value ?? '')
|
||||||
|
.replace(/&/g, '&')
|
||||||
|
.replace(/</g, '<')
|
||||||
|
.replace(/>/g, '>')
|
||||||
|
.replace(/"/g, '"')
|
||||||
|
.replace(/'/g, ''');`
|
||||||
|
);
|
||||||
|
|
||||||
|
const moduleUrl = 'data:text/javascript;base64,' + Buffer.from(source).toString('base64');
|
||||||
|
const mod = await import(moduleUrl);
|
||||||
|
const input = JSON.parse(process.argv[1]);
|
||||||
|
console.log(JSON.stringify({ out: mod.extractThinkingBlocks(input) }));
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
result = subprocess.run(
|
||||||
|
["node", "--input-type=module", "-e", script, json.dumps(text)],
|
||||||
|
cwd=_REPO,
|
||||||
|
capture_output=True,
|
||||||
|
timeout=15,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise AssertionError(f"node failed:\nSTDERR:\n{result.stderr}\nSTDOUT:\n{result.stdout}")
|
||||||
|
return json.loads(result.stdout.splitlines()[-1])["out"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue_payload_copy_text_excludes_thinking(node_available):
|
||||||
|
# Shape reported in #3722: timed think block glued to the reply heading.
|
||||||
|
raw = (
|
||||||
|
'<think time="24.5">\n'
|
||||||
|
"Here's a thinking process that leads to the desired summary:\n\n"
|
||||||
|
"6. **Generate the Output.** (This matches the final provided response.)"
|
||||||
|
"</think>### Juxtaposition: Interweaving Cultural Norms in Lesson Design\n"
|
||||||
|
"The most effective lesson structure is created by deliberately juxtaposing."
|
||||||
|
)
|
||||||
|
out = _extract_thinking_blocks(raw)
|
||||||
|
|
||||||
|
assert out["content"].startswith("### Juxtaposition:"), out["content"]
|
||||||
|
assert "thinking process" not in out["content"]
|
||||||
|
assert "<think" not in out["content"]
|
||||||
|
assert out["thinkingTime"] == "24.5"
|
||||||
|
|
||||||
|
|
||||||
|
def test_plain_reply_copy_text_is_unchanged(node_available):
|
||||||
|
raw = "### Heading\nJust a normal reply with no reasoning markup."
|
||||||
|
out = _extract_thinking_blocks(raw)
|
||||||
|
assert out["content"] == raw
|
||||||
|
|
||||||
|
|
||||||
|
def test_thinking_only_message_yields_empty_content(node_available):
|
||||||
|
# The copy handler falls back to the raw text in this case so the button
|
||||||
|
# still copies something for turns interrupted mid-thinking.
|
||||||
|
out = _extract_thinking_blocks("<think>only reasoning, no reply yet</think>")
|
||||||
|
assert out["content"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def _function_body(text: str, marker: str) -> str:
|
||||||
|
start = text.index(marker)
|
||||||
|
rest = text[start + len(marker):]
|
||||||
|
m = re.search(r"\nexport function |\nfunction ", rest)
|
||||||
|
return rest[: m.start()] if m else rest
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_message_text_mirrors_display_pipeline():
|
||||||
|
text = (_REPO / "static/js/chatRenderer.js").read_text(encoding="utf-8")
|
||||||
|
body = _function_body(text, "export function copyMessageText")
|
||||||
|
# Mirrors the display path: tool blocks stripped, then thinking extracted.
|
||||||
|
assert "extractThinkingBlocks" in body
|
||||||
|
assert "stripToolBlocks" in body
|
||||||
|
assert "dataset.raw" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_handlers_route_through_copy_message_text():
|
||||||
|
for path, count in (("static/js/chatRenderer.js", 1), ("static/js/slashCommands.js", 1)):
|
||||||
|
text = (_REPO / path).read_text(encoding="utf-8")
|
||||||
|
assert text.count("copyToClipboard(copyMessageText(") + text.count(
|
||||||
|
"copyToClipboard(chatRenderer.copyMessageText("
|
||||||
|
) == count, path
|
||||||
|
# The old behavior passed dataset.raw straight to the clipboard.
|
||||||
|
assert "copyToClipboard(msgElement.dataset.raw" not in text, path
|
||||||
|
assert "copyToClipboard(msgEl.dataset.raw" not in text, path
|
||||||
@@ -45,6 +45,20 @@ async def test_search_and_extract_respects_extraction_concurrency():
|
|||||||
assert researcher.max_active == 2
|
assert researcher.max_active == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_and_extract_tracks_all_urls_selected_for_analysis():
|
||||||
|
researcher = _ControlledResearcher(extraction_concurrency=2, max_urls_per_round=2)
|
||||||
|
researcher._start_time = time.time()
|
||||||
|
|
||||||
|
findings = await researcher._search_and_extract(["a"], "question")
|
||||||
|
|
||||||
|
assert len(findings) == 2
|
||||||
|
assert researcher.analyzed_urls == [
|
||||||
|
{"url": "https://example.test/a/0", "title": "a-0"},
|
||||||
|
{"url": "https://example.test/a/1", "title": "a-1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_fetch_and_extract_uses_configured_timeout(monkeypatch):
|
async def test_fetch_and_extract_uses_configured_timeout(monkeypatch):
|
||||||
captured = {}
|
captured = {}
|
||||||
|
|||||||
@@ -36,6 +36,17 @@ def _auth_manager(delete_result):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_manager_raising():
|
||||||
|
def _delete_user(_username, _requesting_user):
|
||||||
|
raise RuntimeError("auth save failed after token purge")
|
||||||
|
|
||||||
|
return types.SimpleNamespace(
|
||||||
|
get_username_for_token=lambda token: "admin",
|
||||||
|
is_admin=lambda user: True,
|
||||||
|
delete_user=_delete_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_successful_delete_invalidates_cache():
|
def test_successful_delete_invalidates_cache():
|
||||||
invalidations = []
|
invalidations = []
|
||||||
router = setup_auth_routes(_auth_manager(delete_result=True))
|
router = setup_auth_routes(_auth_manager(delete_result=True))
|
||||||
@@ -56,3 +67,16 @@ def test_refused_delete_does_not_invalidate_cache():
|
|||||||
raised = True
|
raised = True
|
||||||
assert raised, "a refused delete should raise (HTTP 400)"
|
assert raised, "a refused delete should raise (HTTP 400)"
|
||||||
assert invalidations == [], "a refused delete must not touch the token cache"
|
assert invalidations == [], "a refused delete must not touch the token cache"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_exception_invalidates_cache_for_partial_token_purge():
|
||||||
|
invalidations = []
|
||||||
|
router = setup_auth_routes(_auth_manager_raising())
|
||||||
|
handler = _handler(router)
|
||||||
|
try:
|
||||||
|
asyncio.run(handler(DeleteUserRequest(username="bob"), _fake_request(invalidations)))
|
||||||
|
raised = False
|
||||||
|
except RuntimeError:
|
||||||
|
raised = True
|
||||||
|
assert raised, "delete_user exception should still propagate"
|
||||||
|
assert invalidations == [True], "partial token purge must dirty the bearer cache"
|
||||||
|
|||||||
@@ -114,3 +114,21 @@ def test_refused_delete_leaves_tokens_alone(manager, db_calls):
|
|||||||
def test_unknown_user_leaves_tokens_alone(manager, db_calls):
|
def test_unknown_user_leaves_tokens_alone(manager, db_calls):
|
||||||
assert manager.delete_user("ghost", "admin") is False
|
assert manager.delete_user("ghost", "admin") is False
|
||||||
assert db_calls == []
|
assert db_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_user_fails_closed_when_api_token_purge_fails(manager, monkeypatch):
|
||||||
|
token = manager.create_session("bob", "secret-bob-pw")
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _failing_db_session():
|
||||||
|
raise RuntimeError("database unavailable")
|
||||||
|
yield
|
||||||
|
|
||||||
|
db_stub = types.ModuleType("core.database")
|
||||||
|
db_stub.get_db_session = _failing_db_session
|
||||||
|
db_stub.ApiToken = _FakeApiToken
|
||||||
|
monkeypatch.setitem(sys.modules, "core.database", db_stub)
|
||||||
|
|
||||||
|
assert manager.delete_user("bob", "admin") is False
|
||||||
|
assert "bob" in manager.users
|
||||||
|
assert manager.validate_token(token) is True
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Regression tests for _group_uid_fetch_records (Gmail FLAGS placement).
|
||||||
|
|
||||||
|
imaplib hands back UID FETCH responses as an interleaved list of
|
||||||
|
``(meta, literal)`` tuples and bare ``bytes`` elements. Dovecot sends FLAGS
|
||||||
|
before the RFC822.HEADER literal, so they sit inside the tuple meta; Gmail
|
||||||
|
sends FLAGS *after* the literal, as a bare ``b' FLAGS (\\Seen))'`` element.
|
||||||
|
The old grouping loop only looked at tuples, so on Gmail every message lost
|
||||||
|
its FLAGS and rendered as unread/unflagged in the email library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from routes.email_routes import _group_uid_fetch_records, _uid_from_fetch_meta
|
||||||
|
|
||||||
|
|
||||||
|
def _flags(meta_b: bytes) -> str:
|
||||||
|
m = re.search(rb"FLAGS \(([^)]*)\)", meta_b)
|
||||||
|
return m.group(1).decode() if m else ""
|
||||||
|
|
||||||
|
|
||||||
|
# Captured shape of a real Gmail response to
|
||||||
|
# UID FETCH a,b (UID FLAGS RFC822.HEADER RFC822.SIZE):
|
||||||
|
GMAIL_RESPONSE = [
|
||||||
|
(b"10779 (UID 18723 RFC822.SIZE 54308 RFC822.HEADER {24}", b"Subject: read one\r\n\r\n"),
|
||||||
|
rb" FLAGS (\Seen))",
|
||||||
|
(b"10780 (UID 18724 RFC822.SIZE 124310 RFC822.HEADER {26}", b"Subject: unread one\r\n\r\n"),
|
||||||
|
rb" FLAGS ())",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Dovecot puts FLAGS before the literal and terminates with a bare b')'.
|
||||||
|
DOVECOT_RESPONSE = [
|
||||||
|
(rb"1 (UID 5 FLAGS (\Seen) RFC822.SIZE 100 RFC822.HEADER {18}", b"Subject: hi\r\n\r\n"),
|
||||||
|
b")",
|
||||||
|
(b"2 (UID 6 FLAGS () RFC822.SIZE 90 RFC822.HEADER {19}", b"Subject: new\r\n\r\n"),
|
||||||
|
b")",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_gmail_post_literal_flags_attach_to_their_own_message():
|
||||||
|
grouped = _group_uid_fetch_records(GMAIL_RESPONSE)
|
||||||
|
|
||||||
|
assert len(grouped) == 2
|
||||||
|
assert _uid_from_fetch_meta(grouped[0][0]) == "18723"
|
||||||
|
assert _flags(grouped[0][0]) == r"\Seen"
|
||||||
|
assert grouped[0][1] == b"Subject: read one\r\n\r\n"
|
||||||
|
|
||||||
|
assert _uid_from_fetch_meta(grouped[1][0]) == "18724"
|
||||||
|
assert _flags(grouped[1][0]) == ""
|
||||||
|
assert grouped[1][1] == b"Subject: unread one\r\n\r\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dovecot_pre_literal_flags_unchanged():
|
||||||
|
grouped = _group_uid_fetch_records(DOVECOT_RESPONSE)
|
||||||
|
|
||||||
|
assert len(grouped) == 2
|
||||||
|
assert _flags(grouped[0][0]) == r"\Seen"
|
||||||
|
assert _flags(grouped[1][0]) == ""
|
||||||
|
assert grouped[1][1] == b"Subject: new\r\n\r\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_size_and_uid_survive_grouping():
|
||||||
|
grouped = _group_uid_fetch_records(GMAIL_RESPONSE)
|
||||||
|
sizes = [re.search(rb"RFC822\.SIZE (\d+)", m).group(1) for m, _ in grouped]
|
||||||
|
assert sizes == [b"54308", b"124310"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_and_none_inputs():
|
||||||
|
assert _group_uid_fetch_records(None) == []
|
||||||
|
assert _group_uid_fetch_records([]) == []
|
||||||
|
# A stray bare element before any tuple opens no record and must not crash.
|
||||||
|
assert _group_uid_fetch_records([rb" FLAGS (\Seen))"]) == []
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -117,6 +119,71 @@ def test_email_ai_cache_tables_are_owner_scoped_and_migrate_legacy_rows(tmp_path
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sender_signature_cache_is_owner_scoped_and_migrates_legacy_rows(tmp_path, monkeypatch):
|
||||||
|
import routes.email_helpers as email_helpers
|
||||||
|
|
||||||
|
db_path = tmp_path / "scheduled_emails.db"
|
||||||
|
monkeypatch.setattr(email_helpers, "SCHEDULED_DB", db_path)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE sender_signatures (
|
||||||
|
from_address TEXT PRIMARY KEY,
|
||||||
|
signature_text TEXT,
|
||||||
|
sample_count INTEGER,
|
||||||
|
last_built_at TEXT NOT NULL,
|
||||||
|
model_used TEXT,
|
||||||
|
source TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES ('writer@example.com', 'legacy sig', 3, '2026-01-01', 'm', 'llm')
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
email_helpers._init_scheduled_db()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
info = conn.execute("PRAGMA table_info(sender_signatures)").fetchall()
|
||||||
|
pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
|
||||||
|
assert pk_cols == ["from_address", "owner"]
|
||||||
|
assert conn.execute(
|
||||||
|
"SELECT owner, signature_text FROM sender_signatures WHERE from_address=?",
|
||||||
|
("writer@example.com",),
|
||||||
|
).fetchone() == ("", "legacy sig")
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "alice", "alice sig", 3, "2026-01-02", "m", "llm"),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "bob", "bob sig", 3, "2026-01-03", "m", "llm"),
|
||||||
|
)
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT owner, signature_text FROM sender_signatures WHERE from_address=? ORDER BY owner",
|
||||||
|
("writer@example.com",),
|
||||||
|
).fetchall()
|
||||||
|
assert rows == [("", "legacy sig"), ("alice", "alice sig"), ("bob", "bob sig")]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ai_reply_cache_lookup_is_owner_scoped(tmp_path, monkeypatch):
|
async def test_ai_reply_cache_lookup_is_owner_scoped(tmp_path, monkeypatch):
|
||||||
import routes.email_helpers as email_helpers
|
import routes.email_helpers as email_helpers
|
||||||
@@ -166,6 +233,136 @@ async def test_ai_reply_cache_lookup_is_owner_scoped(tmp_path, monkeypatch):
|
|||||||
assert result["model_used"] == "m-b"
|
assert result["model_used"] == "m-b"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sender_signature_read_lookup_is_owner_scoped(tmp_path, monkeypatch):
|
||||||
|
import routes.email_helpers as email_helpers
|
||||||
|
import routes.email_routes as email_routes
|
||||||
|
|
||||||
|
db_path = tmp_path / "scheduled_emails.db"
|
||||||
|
monkeypatch.setattr(email_helpers, "SCHEDULED_DB", db_path)
|
||||||
|
monkeypatch.setattr(email_routes, "SCHEDULED_DB", db_path)
|
||||||
|
email_helpers._init_scheduled_db()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "alice", "alice private sig", 3, "2026-01-01", "m-a", "llm"),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "bob", "bob private sig", 3, "2026-01-02", "m-b", "llm"),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
raw = (
|
||||||
|
b"From: Writer <writer@example.com>\r\n"
|
||||||
|
b"To: Bob <bob@example.com>\r\n"
|
||||||
|
b"Subject: Hello\r\n"
|
||||||
|
b"Message-ID: <shared@example.com>\r\n"
|
||||||
|
b"Date: Tue, 01 Jan 2026 12:00:00 +0000\r\n"
|
||||||
|
b"Content-Type: text/plain; charset=utf-8\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
b"Body"
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeImap:
|
||||||
|
def select(self, *_args, **_kwargs):
|
||||||
|
return "OK", []
|
||||||
|
|
||||||
|
def uid(self, command, _uid, query):
|
||||||
|
assert command == "FETCH"
|
||||||
|
assert query == "(BODY.PEEK[])"
|
||||||
|
return "OK", [(b"1 (UID 1 BODY[])", raw)]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def fake_imap(_account_id=None, owner=""):
|
||||||
|
assert owner == "bob"
|
||||||
|
yield FakeImap()
|
||||||
|
|
||||||
|
monkeypatch.setattr(email_routes, "_imap", fake_imap)
|
||||||
|
router = email_routes.setup_email_routes()
|
||||||
|
read_email = _route_endpoint(router, "/api/email/read/{uid}", "GET")
|
||||||
|
|
||||||
|
result = await read_email("1", folder="INBOX", account_id=None, owner="bob", mark_seen=False)
|
||||||
|
|
||||||
|
assert result["sender_signature"] == "bob private sig"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sender_signature_clear_cache_keeps_other_owner_rows(tmp_path, monkeypatch):
|
||||||
|
import routes.email_helpers as email_helpers
|
||||||
|
import routes.task_routes as task_routes
|
||||||
|
|
||||||
|
db_path = tmp_path / "scheduled_emails.db"
|
||||||
|
monkeypatch.setattr(email_helpers, "SCHEDULED_DB", db_path)
|
||||||
|
email_helpers._init_scheduled_db()
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "alice", "alice private sig", 3, "2026-01-01", "m-a", "llm"),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO sender_signatures
|
||||||
|
(from_address, owner, signature_text, sample_count, last_built_at, model_used, source)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
("writer@example.com", "bob", "bob private sig", 3, "2026-01-02", "m-b", "llm"),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
class FakeQuery:
|
||||||
|
def filter(self, *_args):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id="task-1",
|
||||||
|
owner="alice",
|
||||||
|
action="learn_sender_signatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeDb:
|
||||||
|
def query(self, _model):
|
||||||
|
return FakeQuery()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_routes, "SessionLocal", lambda: FakeDb())
|
||||||
|
monkeypatch.setattr(task_routes, "get_current_user", lambda _request: "alice")
|
||||||
|
|
||||||
|
router = task_routes.setup_task_routes(task_scheduler=SimpleNamespace(pop_notifications=lambda owner: []))
|
||||||
|
clear_cache = _route_endpoint(router, "/api/tasks/{task_id}/clear-cache", "POST")
|
||||||
|
|
||||||
|
result = await clear_cache(SimpleNamespace(), "task-1")
|
||||||
|
|
||||||
|
assert result["cleared"]["sender_signatures"] == 1
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT owner, signature_text FROM sender_signatures ORDER BY owner",
|
||||||
|
).fetchall()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
assert rows == [("bob", "bob private sig")]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_scheduled_email_routes_are_owner_scoped(tmp_path, monkeypatch):
|
async def test_scheduled_email_routes_are_owner_scoped(tmp_path, monkeypatch):
|
||||||
import routes.email_helpers as email_helpers
|
import routes.email_helpers as email_helpers
|
||||||
|
|||||||
@@ -0,0 +1,38 @@
|
|||||||
|
"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count.
|
||||||
|
|
||||||
|
The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any
|
||||||
|
non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored,
|
||||||
|
matching how the neighbouring gpu_group param is already parsed.
|
||||||
|
"""
|
||||||
|
from routes.hwfit_routes import setup_hwfit_routes
|
||||||
|
|
||||||
|
|
||||||
|
def _get_models():
|
||||||
|
router = setup_hwfit_routes()
|
||||||
|
for route in router.routes:
|
||||||
|
if getattr(route, "path", "").endswith("/models") and "GET" in getattr(route, "methods", set()):
|
||||||
|
return route.endpoint
|
||||||
|
raise AssertionError("hwfit /models route not found")
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_numeric_gpu_count_does_not_raise():
|
||||||
|
handler = _get_models()
|
||||||
|
# Previously raised ValueError (HTTP 500); now degrades to a normal ranking.
|
||||||
|
result = handler(gpu_count="abc")
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_numeric_gpu_count_still_accepted():
|
||||||
|
handler = _get_models()
|
||||||
|
result = handler(gpu_count="0")
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_numeric_manual_gpu_count_does_not_raise():
|
||||||
|
# manual_gpu_count is the other count param on this endpoint (the hardware
|
||||||
|
# simulator in _apply_manual_hardware). A non-numeric value must also degrade
|
||||||
|
# (default to 1) rather than 500, so the endpoint's count parsing is fully
|
||||||
|
# covered.
|
||||||
|
handler = _get_models()
|
||||||
|
result = handler(manual_mode="gpu", manual_gpu_count="abc")
|
||||||
|
assert isinstance(result, dict)
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from core.platform_compat import _ssh_exec_argv
|
||||||
|
from routes.hwfit_routes import setup_hwfit_routes
|
||||||
|
|
||||||
|
|
||||||
|
def _endpoint(path: str):
|
||||||
|
router = setup_hwfit_routes()
|
||||||
|
for route in router.routes:
|
||||||
|
if getattr(route, "path", "") == path:
|
||||||
|
return route.endpoint
|
||||||
|
raise AssertionError(f"{path} route not found")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path,kwargs",
|
||||||
|
[
|
||||||
|
("/api/hwfit/system", {}),
|
||||||
|
("/api/hwfit/models", {"limit": 1}),
|
||||||
|
("/api/hwfit/profiles", {"model": "demo"}),
|
||||||
|
("/api/hwfit/image-models", {}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_hwfit_routes_reject_ssh_option_host(path, kwargs):
|
||||||
|
endpoint = _endpoint(path)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_hwfit_routes_reject_port_without_host():
|
||||||
|
endpoint = _endpoint("/api/hwfit/system")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
endpoint(host="", ssh_port="2222")
|
||||||
|
|
||||||
|
assert exc.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssh_argv_rejects_option_shaped_remote():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true")
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
"""Regression guard: Opus 4.7+ rejects the temperature field entirely.
|
||||||
|
|
||||||
|
Anthropic removed the sampling parameters (temperature, top_p, top_k) starting
|
||||||
|
with Claude Opus 4.7 — sending `temperature` at all, even 0.0, returns HTTP 400.
|
||||||
|
This broke every native-Anthropic call to Opus 4.7/4.8, including the research
|
||||||
|
endpoint probe (temperature=0) and all DeepResearcher LLM calls, because
|
||||||
|
_build_anthropic_payload sent `temperature` unconditionally.
|
||||||
|
|
||||||
|
Earlier Claude models (Opus 4.6 and below, every Sonnet/Haiku) still accept
|
||||||
|
temperature in [0.0, 1.0], so the omission is version-gated — the clamp-to-[0,1]
|
||||||
|
behavior for those models (test_llm_core_anthropic_temp_clamp.py) is unchanged.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:")
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.llm_core import _anthropic_rejects_temperature, _build_anthropic_payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"claude-opus-4-7",
|
||||||
|
"claude-opus-4-8",
|
||||||
|
"claude-opus-4-8-20260101", # tolerate a dated snapshot suffix
|
||||||
|
"claude-opus-4-7-20260201", # dated 4.7 snapshot — explicit minor, still >= 4.7
|
||||||
|
"anthropic/claude-opus-4-7", # tolerate a provider-prefixed id
|
||||||
|
"claude-opus-4-10", # future minor still >= 4.7
|
||||||
|
"claude-opus-5-0", # future major
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_opus_47_plus_rejects_temperature(model):
|
||||||
|
assert _anthropic_rejects_temperature(model) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"claude-opus-4-6",
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-opus-4-1",
|
||||||
|
"claude-opus-4-0",
|
||||||
|
"claude-opus-4", # bare major (no minor) — kept
|
||||||
|
"claude-opus-4-20250514", # Opus 4.0 dated id — the date must NOT read as a 4.7+ minor
|
||||||
|
"claude-opus-4-1-20250805", # Opus 4.1 dated id — explicit minor before the date
|
||||||
|
"claude-opus-4-6-20251201", # dated 4.6 snapshot — older, still keeps temperature
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
"claude-3-5-sonnet",
|
||||||
|
"claude-3-opus-20240229", # legacy Claude 3 Opus — no opus-N-M pattern, kept
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
"claude-x",
|
||||||
|
"octopus-4-8", # "opus" only as a substring of another word — must not match
|
||||||
|
"myproxy/octopus-4-8", # same, behind a provider prefix
|
||||||
|
"",
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_older_claude_models_keep_temperature(model):
|
||||||
|
assert _anthropic_rejects_temperature(model) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", [123, 1.5, ["claude-opus-4-8"], {"a": 1}, object()])
|
||||||
|
def test_non_string_model_is_handled_without_crashing(model):
|
||||||
|
# Defensive: the gate must not raise on a non-string model (the old builder
|
||||||
|
# never called .lower() on it). Truthy non-strings should classify as False.
|
||||||
|
assert _anthropic_rejects_temperature(model) is False
|
||||||
|
|
||||||
|
|
||||||
|
def _payload(model, temperature=0.0):
|
||||||
|
return _build_anthropic_payload(
|
||||||
|
model, [{"role": "user", "content": "hi"}], temperature, 100
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_payload_omits_temperature_for_opus_47_plus():
|
||||||
|
# The endpoint probe sends temperature=0; on Opus 4.7+ that field must be gone.
|
||||||
|
payload = _payload("claude-opus-4-8", 0.0)
|
||||||
|
assert "temperature" not in payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_payload_keeps_temperature_for_older_models():
|
||||||
|
payload = _payload("claude-opus-4-6", 0.3)
|
||||||
|
assert payload["temperature"] == 0.3
|
||||||
|
# Older models retain the [0,1] clamp (Nietzsche preset at 1.2 -> 1.0).
|
||||||
|
assert _payload("claude-3-5-sonnet", 1.2)["temperature"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_payload_keeps_temperature_for_dated_opus_4_0():
|
||||||
|
# Anthropic's dated id for Opus 4.0 (claude-opus-4-20250514) is in this repo's
|
||||||
|
# ANTHROPIC_MODELS list. The date must not be misread as a >= 4.7 minor, or the
|
||||||
|
# user's temperature would be silently dropped on a model that accepts it.
|
||||||
|
assert _payload("claude-opus-4-20250514", 0.5)["temperature"] == 0.5
|
||||||
@@ -14,6 +14,7 @@ import pytest
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
import routes.memory_routes as mr
|
import routes.memory_routes as mr
|
||||||
|
from src.request_models import MemoryAddRequest
|
||||||
|
|
||||||
|
|
||||||
def _route(router, path, method):
|
def _route(router, path, method):
|
||||||
@@ -38,6 +39,13 @@ def _router(monkeypatch, caller):
|
|||||||
return mr.setup_memory_routes(mem, sm)
|
return mr.setup_memory_routes(mem, sm)
|
||||||
|
|
||||||
|
|
||||||
|
def _request(user):
|
||||||
|
return SimpleNamespace(
|
||||||
|
state=SimpleNamespace(current_user=user),
|
||||||
|
app=SimpleNamespace(state=SimpleNamespace(auth_manager=None)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_rejects_other_users_session(monkeypatch):
|
def test_extract_rejects_other_users_session(monkeypatch):
|
||||||
router = _router(monkeypatch, caller="bob")
|
router = _router(monkeypatch, caller="bob")
|
||||||
extract = _route(router, "/api/memory/extract", "POST")
|
extract = _route(router, "/api/memory/extract", "POST")
|
||||||
@@ -59,3 +67,61 @@ def test_owner_can_access_own_session(monkeypatch):
|
|||||||
gbs = _route(router, "/api/memory/by-session/{session_id}", "GET")
|
gbs = _route(router, "/api/memory/by-session/{session_id}", "GET")
|
||||||
out = gbs(request=None, session_id="alice-sess")
|
out = gbs(request=None, session_id="alice-sess")
|
||||||
assert out["session_name"] == "Secret project"
|
assert out["session_name"] == "Secret project"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_memory_rejects_other_users_session(monkeypatch):
|
||||||
|
memory_manager = MagicMock()
|
||||||
|
session_manager = MagicMock()
|
||||||
|
memory_vector = MagicMock(healthy=True)
|
||||||
|
router = mr.setup_memory_routes(
|
||||||
|
memory_manager=memory_manager,
|
||||||
|
session_manager=session_manager,
|
||||||
|
memory_vector=memory_vector,
|
||||||
|
)
|
||||||
|
add_memory = _route(router, "/api/memory/add", "POST")
|
||||||
|
|
||||||
|
memory_manager.load.return_value = []
|
||||||
|
memory_manager.find_duplicates.return_value = False
|
||||||
|
session_manager.get_session.return_value = SimpleNamespace(owner="bob", name="Bob session")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
asyncio.run(
|
||||||
|
add_memory(
|
||||||
|
request=_request("alice"),
|
||||||
|
memory_data=MemoryAddRequest(
|
||||||
|
text="Alice note",
|
||||||
|
category="fact",
|
||||||
|
source="user",
|
||||||
|
session_id="bob-session",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 404
|
||||||
|
assert exc.value.detail == "Session not found"
|
||||||
|
session_manager.get_session.assert_called_once_with("bob-session")
|
||||||
|
memory_manager.add_entry.assert_not_called()
|
||||||
|
memory_manager.save.assert_not_called()
|
||||||
|
memory_vector.add.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeline_does_not_expose_other_users_session_name():
|
||||||
|
memory_manager = MagicMock()
|
||||||
|
session_manager = MagicMock()
|
||||||
|
session_manager.sessions = {"bob-session": object()}
|
||||||
|
session_manager.get_session.return_value = SimpleNamespace(owner="bob", name="Bob roadmap")
|
||||||
|
memory_manager.load.return_value = [
|
||||||
|
{
|
||||||
|
"id": "m1",
|
||||||
|
"text": "Alice note",
|
||||||
|
"owner": "alice",
|
||||||
|
"session_id": "bob-session",
|
||||||
|
"timestamp": 1,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
router = mr.setup_memory_routes(memory_manager, session_manager)
|
||||||
|
timeline = _route(router, "/api/memory/timeline", "GET")
|
||||||
|
|
||||||
|
out = timeline(request=_request("alice"))
|
||||||
|
|
||||||
|
assert out["timeline"][0]["session_name"] == "Unknown"
|
||||||
|
|||||||
+11
-11
@@ -6,7 +6,7 @@ import types
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import src.model_context as model_context
|
import src.model_context as model_context
|
||||||
from src.model_context import _is_local_endpoint, estimate_tokens, _lookup_known
|
from src.model_context import is_local_endpoint, estimate_tokens, _lookup_known
|
||||||
|
|
||||||
|
|
||||||
class _Column:
|
class _Column:
|
||||||
@@ -56,20 +56,20 @@ def _install_endpoint_db(monkeypatch, rows):
|
|||||||
|
|
||||||
class TestIsLocalEndpoint:
|
class TestIsLocalEndpoint:
|
||||||
def test_localhost(self):
|
def test_localhost(self):
|
||||||
assert _is_local_endpoint("http://localhost:5000/v1/chat/completions") is True
|
assert is_local_endpoint("http://localhost:5000/v1/chat/completions") is True
|
||||||
|
|
||||||
def test_loopback_ipv4(self):
|
def test_loopback_ipv4(self):
|
||||||
assert _is_local_endpoint("http://127.0.0.1:8080/v1/chat/completions") is True
|
assert is_local_endpoint("http://127.0.0.1:8080/v1/chat/completions") is True
|
||||||
|
|
||||||
def test_private_192_168(self):
|
def test_private_192_168(self):
|
||||||
assert _is_local_endpoint("http://192.168.1.1:11434/v1/chat/completions") is True
|
assert is_local_endpoint("http://192.168.1.1:11434/v1/chat/completions") is True
|
||||||
|
|
||||||
def test_private_10(self):
|
def test_private_10(self):
|
||||||
assert _is_local_endpoint("http://10.0.0.5:8000/v1/chat/completions") is True
|
assert is_local_endpoint("http://10.0.0.5:8000/v1/chat/completions") is True
|
||||||
|
|
||||||
def test_tailscale_100(self):
|
def test_tailscale_100(self):
|
||||||
# 100.64.0.0/10 is the CGNAT range Tailscale uses.
|
# 100.64.0.0/10 is the CGNAT range Tailscale uses.
|
||||||
assert _is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True
|
assert is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True
|
||||||
|
|
||||||
def test_configured_tailscale_proxy_is_remote(self, monkeypatch):
|
def test_configured_tailscale_proxy_is_remote(self, monkeypatch):
|
||||||
_install_endpoint_db(monkeypatch, [
|
_install_endpoint_db(monkeypatch, [
|
||||||
@@ -81,19 +81,19 @@ class TestIsLocalEndpoint:
|
|||||||
)
|
)
|
||||||
])
|
])
|
||||||
|
|
||||||
assert _is_local_endpoint("http://100.117.136.97:34521/v1/chat/completions") is False
|
assert is_local_endpoint("http://100.117.136.97:34521/v1/chat/completions") is False
|
||||||
|
|
||||||
def test_openai_is_remote(self):
|
def test_openai_is_remote(self):
|
||||||
assert _is_local_endpoint("https://api.openai.com/v1/chat/completions") is False
|
assert is_local_endpoint("https://api.openai.com/v1/chat/completions") is False
|
||||||
|
|
||||||
def test_anthropic_is_remote(self):
|
def test_anthropic_is_remote(self):
|
||||||
assert _is_local_endpoint("https://api.anthropic.com/v1/messages") is False
|
assert is_local_endpoint("https://api.anthropic.com/v1/messages") is False
|
||||||
|
|
||||||
def test_empty_url(self):
|
def test_empty_url(self):
|
||||||
assert _is_local_endpoint("") is False
|
assert is_local_endpoint("") is False
|
||||||
|
|
||||||
def test_malformed_url(self):
|
def test_malformed_url(self):
|
||||||
assert _is_local_endpoint("not-a-url") is False
|
assert is_local_endpoint("not-a-url") is False
|
||||||
|
|
||||||
|
|
||||||
class TestEstimateTokens:
|
class TestEstimateTokens:
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ with preserve_import_state("core.database", "src.database", "core.session_manage
|
|||||||
_endpoint_settings_using_endpoint,
|
_endpoint_settings_using_endpoint,
|
||||||
_clear_endpoint_settings_for_endpoint,
|
_clear_endpoint_settings_for_endpoint,
|
||||||
_clear_user_pref_endpoint_refs,
|
_clear_user_pref_endpoint_refs,
|
||||||
|
_default_endpoint_needs_assignment,
|
||||||
_PROVIDER_CURATED,
|
_PROVIDER_CURATED,
|
||||||
)
|
)
|
||||||
from src.llm_core import ANTHROPIC_MODELS
|
from src.llm_core import ANTHROPIC_MODELS
|
||||||
@@ -154,6 +155,26 @@ def test_endpoint_cleanup_updates_scoped_and_legacy_user_prefs():
|
|||||||
assert legacy["default_model_fallbacks"] == []
|
assert legacy["default_model_fallbacks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── _default_endpoint_needs_assignment (add-endpoint auto-default) ──
|
||||||
|
|
||||||
|
def test_default_assignment_when_none_configured():
|
||||||
|
# Nothing configured yet → first added endpoint should become the default.
|
||||||
|
assert _default_endpoint_needs_assignment("", {"a", "b"}) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_assignment_when_current_default_disabled():
|
||||||
|
# #3586: the configured default points at an endpoint that is no longer
|
||||||
|
# enabled (the user disabled it). Adding a new endpoint must reassign the
|
||||||
|
# default — otherwise Memory → Tidy keeps failing with "No default model
|
||||||
|
# configured" even though an enabled endpoint exists.
|
||||||
|
assert _default_endpoint_needs_assignment("disabled-ep", {"new-ep"}) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_preserved_when_current_default_enabled():
|
||||||
|
# Normal case: the configured default is still enabled → leave it alone.
|
||||||
|
assert _default_endpoint_needs_assignment("live-ep", {"live-ep", "new-ep"}) is False
|
||||||
|
|
||||||
|
|
||||||
# ── _match_provider_curated ──
|
# ── _match_provider_curated ──
|
||||||
|
|
||||||
class TestMatchProviderCurated:
|
class TestMatchProviderCurated:
|
||||||
@@ -966,16 +987,21 @@ def _create_form_kwargs(**overrides):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
def _patch_create_deps(monkeypatch, db):
|
def _patch_create_deps(monkeypatch, db, settings=None):
|
||||||
import src.auth_helpers as auth_helpers
|
import src.auth_helpers as auth_helpers
|
||||||
|
# Shared, in-memory settings so the auto-default write path stays hermetic
|
||||||
|
# (no real settings.json). Returned so tests can assert what was persisted.
|
||||||
|
settings = {"default_endpoint_id": "exists"} if settings is None else settings
|
||||||
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||||
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||||
monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint)
|
monkeypatch.setattr(model_routes, "ModelEndpoint", _RecordingEndpoint)
|
||||||
monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b)
|
monkeypatch.setattr(model_routes, "_normalize_base", lambda b: b)
|
||||||
monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b)
|
monkeypatch.setattr(model_routes, "_rewrite_loopback_for_docker", lambda b, **k: b)
|
||||||
monkeypatch.setattr(model_routes, "_load_settings", lambda: {"default_endpoint_id": "exists"})
|
monkeypatch.setattr(model_routes, "_load_settings", lambda: settings)
|
||||||
|
monkeypatch.setattr(model_routes, "_save_settings", lambda s: settings.update(s))
|
||||||
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
|
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
|
||||||
monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
|
monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
|
||||||
|
return settings
|
||||||
|
|
||||||
|
|
||||||
def test_list_model_endpoints_returns_key_fingerprint(monkeypatch):
|
def test_list_model_endpoints_returns_key_fingerprint(monkeypatch):
|
||||||
@@ -1091,6 +1117,48 @@ def test_post_same_base_url_different_api_key_creates_distinct_endpoint(monkeypa
|
|||||||
assert db.added[0].api_key == "key-two"
|
assert db.added[0].api_key == "key-two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_reassigns_default_when_current_default_disabled(monkeypatch):
|
||||||
|
# #3586: the configured default points at a now-disabled endpoint. Adding a
|
||||||
|
# new endpoint must promote it to the default, otherwise raw-setting readers
|
||||||
|
# (Memory → Tidy) keep failing with "No default model configured".
|
||||||
|
disabled = _make_endpoint(id="dead", base_url="http://old-host/v1", is_enabled=False)
|
||||||
|
db = _PinnedFakeDb([disabled])
|
||||||
|
settings = _patch_create_deps(
|
||||||
|
monkeypatch, db, settings={"default_endpoint_id": "dead", "default_model": "stale"}
|
||||||
|
)
|
||||||
|
create = _get_route("/api/model-endpoints", "POST")
|
||||||
|
|
||||||
|
create(
|
||||||
|
_PinnedFakeRequest(),
|
||||||
|
base_url="http://new-host:1234/v1",
|
||||||
|
**_create_form_kwargs(),
|
||||||
|
)
|
||||||
|
|
||||||
|
new_id = db.added[0].id
|
||||||
|
assert settings["default_endpoint_id"] == new_id
|
||||||
|
assert settings["default_endpoint_id"] != "dead"
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_keeps_default_when_current_default_enabled(monkeypatch):
|
||||||
|
# Counter-case: an enabled default must be left untouched when another
|
||||||
|
# endpoint is added.
|
||||||
|
live = _make_endpoint(id="live", base_url="http://live-host/v1", is_enabled=True)
|
||||||
|
db = _PinnedFakeDb([live])
|
||||||
|
settings = _patch_create_deps(
|
||||||
|
monkeypatch, db, settings={"default_endpoint_id": "live", "default_model": "live-model"}
|
||||||
|
)
|
||||||
|
create = _get_route("/api/model-endpoints", "POST")
|
||||||
|
|
||||||
|
create(
|
||||||
|
_PinnedFakeRequest(),
|
||||||
|
base_url="http://another-host:1234/v1",
|
||||||
|
**_create_form_kwargs(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert settings["default_endpoint_id"] == "live"
|
||||||
|
assert settings["default_model"] == "live-model"
|
||||||
|
|
||||||
|
|
||||||
def test_post_same_base_url_same_api_key_still_dedupes(monkeypatch):
|
def test_post_same_base_url_same_api_key_still_dedupes(monkeypatch):
|
||||||
existing = _make_endpoint(
|
existing = _make_endpoint(
|
||||||
base_url="https://api.example.test/v1",
|
base_url="https://api.example.test/v1",
|
||||||
|
|||||||
@@ -47,6 +47,20 @@ def test_find_bash_checks_local_app_data_git_install(monkeypatch):
|
|||||||
assert platform_compat.find_bash() == expected
|
assert platform_compat.find_bash() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_bash_checks_local_app_data_programs_git_install(monkeypatch):
|
||||||
|
_reset_bash_cache(monkeypatch)
|
||||||
|
monkeypatch.setattr(platform_compat, "IS_WINDOWS", True)
|
||||||
|
monkeypatch.setattr(platform_compat.shutil, "which", lambda _name: None)
|
||||||
|
for env_name in platform_compat._WINDOWS_BASH_ROOT_ENV_VARS:
|
||||||
|
monkeypatch.delenv(env_name, raising=False)
|
||||||
|
monkeypatch.setenv("LocalAppData", r"C:\Users\alice\AppData\Local")
|
||||||
|
|
||||||
|
expected = r"C:\Users\alice\AppData\Local\Programs\Git\bin\bash.exe"
|
||||||
|
monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected)
|
||||||
|
|
||||||
|
assert platform_compat.find_bash() == expected
|
||||||
|
|
||||||
|
|
||||||
def test_find_bash_skips_windows_wsl_stub(monkeypatch):
|
def test_find_bash_skips_windows_wsl_stub(monkeypatch):
|
||||||
_reset_bash_cache(monkeypatch)
|
_reset_bash_cache(monkeypatch)
|
||||||
monkeypatch.setattr(platform_compat, "IS_WINDOWS", True)
|
monkeypatch.setattr(platform_compat, "IS_WINDOWS", True)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Renaming a user must update all three owner caches, not just the SQL DB.
|
"""Renaming a user must update non-SQL owner stores, not just the SQL DB.
|
||||||
|
|
||||||
The DB owner-rename loop in the rename_user route updates every SQL-backed
|
The DB owner-rename loop in the rename_user route updates every SQL-backed
|
||||||
owner column, but three file-backed / in-memory stores are left stale:
|
owner column, but three file-backed / in-memory stores are left stale:
|
||||||
@@ -11,9 +11,15 @@ owner column, but three file-backed / in-memory stores are left stale:
|
|||||||
research_routes filters by `d.get("owner") == user`, making every report
|
research_routes filters by `d.get("owner") == user`, making every report
|
||||||
invisible after rename.
|
invisible after rename.
|
||||||
|
|
||||||
3. data/memory.json — a flat array where every entry has an `owner` field;
|
3. research_handler._active_tasks — in-flight research jobs carry the same
|
||||||
|
owner key while status/cancel/active routes filter by it.
|
||||||
|
|
||||||
|
4. data/memory.json — a flat array where every entry has an `owner` field;
|
||||||
memory_manager.load(owner=user) filters on it, so all memories vanish.
|
memory_manager.load(owner=user) filters on it, so all memories vanish.
|
||||||
|
|
||||||
|
5. data/uploads/uploads.json — each upload row carries an `owner` field and
|
||||||
|
owner-prefixed index key; stale metadata denies renamed users their uploads.
|
||||||
|
|
||||||
Regression coverage: these bugs are invisible in unit tests that mock the DB
|
Regression coverage: these bugs are invisible in unit tests that mock the DB
|
||||||
loop but don't exercise the file/cache patches added to the route.
|
loop but don't exercise the file/cache patches added to the route.
|
||||||
"""
|
"""
|
||||||
@@ -26,6 +32,7 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
def _route(router, name):
|
def _route(router, name):
|
||||||
@@ -63,18 +70,70 @@ def rename_endpoint(monkeypatch, tmp_path):
|
|||||||
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
|
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
|
||||||
|
|
||||||
|
|
||||||
def _request(tmp_path, session_manager=None):
|
def _request(tmp_path, session_manager=None, token="t", research_handler=None, upload_handler=None):
|
||||||
state = SimpleNamespace(
|
state = SimpleNamespace(
|
||||||
invalidate_token_cache=lambda: None,
|
invalidate_token_cache=lambda: None,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
|
research_handler=research_handler,
|
||||||
|
upload_handler=upload_handler,
|
||||||
)
|
)
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
cookies={"odysseus_session": "t"},
|
cookies={"odysseus_session": token},
|
||||||
app=SimpleNamespace(state=state),
|
app=SimpleNamespace(state=state),
|
||||||
state=SimpleNamespace(current_user="admin"),
|
state=SimpleNamespace(current_user="admin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_manager_for_rollback_test(monkeypatch, tmp_path):
|
||||||
|
import core.auth as auth_mod
|
||||||
|
|
||||||
|
monkeypatch.setattr(auth_mod, "_hash_password", lambda password: f"hash:{password}")
|
||||||
|
monkeypatch.setattr(auth_mod, "_verify_password", lambda password, hashed: hashed == f"hash:{password}")
|
||||||
|
|
||||||
|
am = auth_mod.AuthManager(str(tmp_path / "auth.json"))
|
||||||
|
assert am.create_user("admin", "pw-123456", is_admin=True) is True
|
||||||
|
assert am.create_user("alice", "pw-123456") is True
|
||||||
|
return am
|
||||||
|
|
||||||
|
|
||||||
|
def _force_sql_owner_migration_failure(monkeypatch):
|
||||||
|
import core.database as cdb
|
||||||
|
|
||||||
|
class OwnerModel:
|
||||||
|
owner = "owner"
|
||||||
|
|
||||||
|
class FailingQuery:
|
||||||
|
def filter(self, *_args, **_kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def update(self, *_args, **_kwargs):
|
||||||
|
raise RuntimeError("forced owner migration failure")
|
||||||
|
|
||||||
|
class FailingSession:
|
||||||
|
def __init__(self):
|
||||||
|
self.rolled_back = False
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def query(self, _model):
|
||||||
|
return FailingQuery()
|
||||||
|
|
||||||
|
def rollback(self):
|
||||||
|
self.rolled_back = True
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
db = FailingSession()
|
||||||
|
monkeypatch.setattr(cdb, "SessionLocal", lambda: db)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
cdb,
|
||||||
|
"Base",
|
||||||
|
SimpleNamespace(registry=SimpleNamespace(mappers=[SimpleNamespace(class_=OwnerModel)])),
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 1. In-memory session cache
|
# 1. In-memory session cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -183,6 +242,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint):
|
|||||||
assert res["ok"] is True
|
assert res["ok"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_updates_active_research_task_owner(rename_endpoint):
|
||||||
|
endpoint, _am, tmp_path = rename_endpoint
|
||||||
|
|
||||||
|
from routes.research_routes import setup_research_routes
|
||||||
|
from src.research_handler import ResearchHandler
|
||||||
|
|
||||||
|
rh = ResearchHandler.__new__(ResearchHandler)
|
||||||
|
rh._active_tasks = {
|
||||||
|
"alice-task": {
|
||||||
|
"owner": "Alice",
|
||||||
|
"status": "running",
|
||||||
|
"query": "q",
|
||||||
|
"progress": {},
|
||||||
|
"started_at": 1,
|
||||||
|
},
|
||||||
|
"carol-task": {
|
||||||
|
"owner": "carol",
|
||||||
|
"status": "running",
|
||||||
|
"query": "q2",
|
||||||
|
"progress": {},
|
||||||
|
"started_at": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
asyncio.run(endpoint(
|
||||||
|
"alice",
|
||||||
|
SimpleNamespace(username="alice2"),
|
||||||
|
_request(tmp_path, research_handler=rh),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert rh._active_tasks["alice-task"]["owner"] == "alice2"
|
||||||
|
assert rh._active_tasks["carol-task"]["owner"] == "carol"
|
||||||
|
|
||||||
|
router = setup_research_routes(rh)
|
||||||
|
active = next(
|
||||||
|
r.endpoint for r in router.routes
|
||||||
|
if getattr(r, "path", "") == "/api/research/active"
|
||||||
|
)
|
||||||
|
|
||||||
|
alice2 = asyncio.run(active(
|
||||||
|
SimpleNamespace(state=SimpleNamespace(current_user="alice2")),
|
||||||
|
))
|
||||||
|
alice = asyncio.run(active(
|
||||||
|
SimpleNamespace(state=SimpleNamespace(current_user="alice")),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert [item["session_id"] for item in alice2["active"]] == ["alice-task"]
|
||||||
|
assert alice["active"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_handler_rename_owner_canonicalizes_new_owner():
|
||||||
|
from src.research_handler import ResearchHandler
|
||||||
|
|
||||||
|
rh = ResearchHandler.__new__(ResearchHandler)
|
||||||
|
rh._active_tasks = {
|
||||||
|
"task": {"owner": "Alice", "status": "running"},
|
||||||
|
}
|
||||||
|
|
||||||
|
changed = rh.rename_owner("alice", "Alice2")
|
||||||
|
assert changed == 1
|
||||||
|
assert rh._active_tasks["task"]["owner"] == "alice2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold():
|
||||||
|
from src.research_handler import ResearchHandler
|
||||||
|
|
||||||
|
rh = ResearchHandler.__new__(ResearchHandler)
|
||||||
|
rh._active_tasks = {
|
||||||
|
"task-strasse": {"owner": "strasse", "status": "running"},
|
||||||
|
"task-sharp-s": {"owner": "straße", "status": "running"},
|
||||||
|
}
|
||||||
|
|
||||||
|
changed = rh.rename_owner("straße", "renamed")
|
||||||
|
|
||||||
|
assert changed == 1
|
||||||
|
assert rh._active_tasks["task-strasse"]["owner"] == "strasse"
|
||||||
|
assert rh._active_tasks["task-sharp-s"]["owner"] == "renamed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint):
|
||||||
|
endpoint, _am, tmp_path = rename_endpoint
|
||||||
|
|
||||||
|
dr_dir = tmp_path / "deep_research"
|
||||||
|
dr_dir.mkdir()
|
||||||
|
report = dr_dir / "race-window.json"
|
||||||
|
report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8")
|
||||||
|
owner_seen_by_active_hook = []
|
||||||
|
|
||||||
|
class FakeResearchHandler:
|
||||||
|
def rename_owner(self, _old, _new):
|
||||||
|
owner_seen_by_active_hook.append(json.loads(report.read_text(encoding="utf-8"))["owner"])
|
||||||
|
|
||||||
|
asyncio.run(endpoint(
|
||||||
|
"alice",
|
||||||
|
SimpleNamespace(username="alice2"),
|
||||||
|
_request(tmp_path, research_handler=FakeResearchHandler()),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert owner_seen_by_active_hook == ["alice"]
|
||||||
|
assert json.loads(report.read_text(encoding="utf-8"))["owner"] == "alice2"
|
||||||
|
|
||||||
|
|
||||||
def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path):
|
def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path):
|
||||||
"""DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a
|
"""DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a
|
||||||
hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made
|
hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made
|
||||||
@@ -258,7 +419,56 @@ def test_rename_no_memory_json_does_not_crash(rename_endpoint):
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 4. Skills (SKILL.md frontmatter + _usage.json sidecar)
|
# 4. uploads.json
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_rename_updates_upload_metadata_owner(rename_endpoint):
|
||||||
|
endpoint, _am, tmp_path = rename_endpoint
|
||||||
|
from src.upload_handler import UploadHandler
|
||||||
|
|
||||||
|
upload_dir = tmp_path / "uploads"
|
||||||
|
dated = upload_dir / "2026" / "06" / "09"
|
||||||
|
dated.mkdir(parents=True)
|
||||||
|
upload_id = "a" * 32 + ".txt"
|
||||||
|
upload_path = dated / upload_id
|
||||||
|
upload_path.write_text("alice private upload", encoding="utf-8")
|
||||||
|
handler = UploadHandler(str(tmp_path), str(upload_dir))
|
||||||
|
handler._atomic_write_json(
|
||||||
|
str(upload_dir / "uploads.json"),
|
||||||
|
{
|
||||||
|
"alice:hash-alice": {
|
||||||
|
"id": upload_id,
|
||||||
|
"path": str(upload_path),
|
||||||
|
"mime": "text/plain",
|
||||||
|
"size": upload_path.stat().st_size,
|
||||||
|
"name": "note.txt",
|
||||||
|
"hash": "hash-alice",
|
||||||
|
"original_name": "note.txt",
|
||||||
|
"uploaded_at": "2026-06-09T10:00:00",
|
||||||
|
"last_accessed": "2026-06-09T10:00:00",
|
||||||
|
"client_ip": "127.0.0.1",
|
||||||
|
"owner": "alice",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(
|
||||||
|
endpoint(
|
||||||
|
"alice",
|
||||||
|
SimpleNamespace(username="alice2"),
|
||||||
|
_request(tmp_path, upload_handler=handler),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
updated = json.loads((upload_dir / "uploads.json").read_text(encoding="utf-8"))
|
||||||
|
assert "alice:hash-alice" not in updated
|
||||||
|
assert updated["alice2:hash-alice"]["owner"] == "alice2"
|
||||||
|
assert handler.resolve_upload(upload_id, owner="alice2")["path"] == str(upload_path)
|
||||||
|
assert handler.resolve_upload(upload_id, owner="alice") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. Skills (SKILL.md frontmatter + _usage.json sidecar)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
_SKILL_MD = """\
|
_SKILL_MD = """\
|
||||||
@@ -333,8 +543,100 @@ def test_rename_no_skills_dir_does_not_crash(rename_endpoint):
|
|||||||
assert res["ok"] is True
|
assert res["ok"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_skill_md_owner_case_insensitive(rename_endpoint):
|
||||||
|
"""SKILL.md written with owner: Alice (mixed case) must be updated when
|
||||||
|
renaming alice — the regex was missing re.IGNORECASE."""
|
||||||
|
endpoint, _am, tmp_path = rename_endpoint
|
||||||
|
|
||||||
|
skill_dir = tmp_path / "skills" / "general" / "s"
|
||||||
|
skill_dir.mkdir(parents=True)
|
||||||
|
(skill_dir / "SKILL.md").write_text(_SKILL_MD.format(owner="Alice"), encoding="utf-8")
|
||||||
|
|
||||||
|
asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path)))
|
||||||
|
|
||||||
|
assert "owner: alice2" in (skill_dir / "SKILL.md").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_usage_keys_case_insensitive(rename_endpoint):
|
||||||
|
"""_usage.json keys stored as Alice::skill-name must be migrated when
|
||||||
|
renaming alice — the old startswith check was not lowercasing."""
|
||||||
|
endpoint, _am, tmp_path = rename_endpoint
|
||||||
|
|
||||||
|
skills_root = tmp_path / "skills"
|
||||||
|
skills_root.mkdir(parents=True)
|
||||||
|
usage = {"Alice::my-skill": {"uses": 5, "last_used": 999}}
|
||||||
|
(skills_root / "_usage.json").write_text(json.dumps(usage), encoding="utf-8")
|
||||||
|
|
||||||
|
asyncio.run(endpoint("alice", SimpleNamespace(username="alice2"), _request(tmp_path)))
|
||||||
|
|
||||||
|
updated = json.loads((skills_root / "_usage.json").read_text(encoding="utf-8"))
|
||||||
|
assert "alice2::my-skill" in updated
|
||||||
|
assert "Alice::my-skill" not in updated
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 5. P1 regression: rejected auth rename must not mutate file-backed stores
|
# 6. Rollback: auth rename must be restored if SQL owner migration fails
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_owner_migration_failure_rolls_back_auth_rename(monkeypatch, tmp_path):
|
||||||
|
import routes.auth_routes as ar
|
||||||
|
|
||||||
|
db = _force_sql_owner_migration_failure(monkeypatch)
|
||||||
|
am = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
|
||||||
|
admin_token = am.create_session_trusted("admin")
|
||||||
|
alice_token = am.create_session_trusted("alice")
|
||||||
|
endpoint = _route(ar.setup_auth_routes(am), "rename_user")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
asyncio.run(
|
||||||
|
endpoint(
|
||||||
|
"alice",
|
||||||
|
SimpleNamespace(username="alice2"),
|
||||||
|
_request(tmp_path, token=admin_token),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 500
|
||||||
|
assert db.rolled_back is True
|
||||||
|
assert db.closed is True
|
||||||
|
assert "alice" in am.users
|
||||||
|
assert "alice2" not in am.users
|
||||||
|
assert am.get_username_for_token(alice_token) == "alice"
|
||||||
|
saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"]
|
||||||
|
assert "alice" in saved_users
|
||||||
|
assert "alice2" not in saved_users
|
||||||
|
|
||||||
|
|
||||||
|
def test_self_rename_owner_migration_failure_rolls_back_auth_session(monkeypatch, tmp_path):
|
||||||
|
import routes.auth_routes as ar
|
||||||
|
|
||||||
|
db = _force_sql_owner_migration_failure(monkeypatch)
|
||||||
|
am = _auth_manager_for_rollback_test(monkeypatch, tmp_path)
|
||||||
|
admin_token = am.create_session_trusted("admin")
|
||||||
|
endpoint = _route(ar.setup_auth_routes(am), "rename_user")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
asyncio.run(
|
||||||
|
endpoint(
|
||||||
|
"admin",
|
||||||
|
SimpleNamespace(username="chief"),
|
||||||
|
_request(tmp_path, token=admin_token),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 500
|
||||||
|
assert db.rolled_back is True
|
||||||
|
assert db.closed is True
|
||||||
|
assert "admin" in am.users
|
||||||
|
assert "chief" not in am.users
|
||||||
|
assert am.get_username_for_token(admin_token) == "admin"
|
||||||
|
saved_users = json.loads((tmp_path / "auth.json").read_text(encoding="utf-8"))["users"]
|
||||||
|
assert "admin" in saved_users
|
||||||
|
assert "chief" not in saved_users
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 7. P1 regression: rejected auth rename must not mutate file-backed stores
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path):
|
def test_rejected_rename_does_not_mutate_files(monkeypatch, tmp_path):
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""`odysseus-research list --status complete` must match completed runs.
|
||||||
|
|
||||||
|
Completed research runs are persisted with status "done" (research_handler),
|
||||||
|
but the user-facing CLI value is the friendlier "complete". The CLI offered
|
||||||
|
"complete" yet filtered `status != args.status`, so `--status complete` never
|
||||||
|
matched any record. The fix keeps "complete" as the CLI value and maps it to
|
||||||
|
the stored "done" at filter time, so the on-disk corpus stays the source of
|
||||||
|
truth and the documented CLI surface keeps working.
|
||||||
|
"""
|
||||||
|
import importlib.machinery
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cli():
|
||||||
|
path = ROOT / "scripts" / "odysseus-research"
|
||||||
|
loader = importlib.machinery.SourceFileLoader("odysseus_research_cli_status", str(path))
|
||||||
|
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_is_a_valid_status_choice():
|
||||||
|
cli = _load_cli()
|
||||||
|
parser = cli._build_parser()
|
||||||
|
ns = parser.parse_args(["list", "--status", "complete"])
|
||||||
|
assert ns.status == "complete"
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_returns_completed_runs(tmp_path, monkeypatch):
|
||||||
|
cli = _load_cli(); cli._DATA_DIR = tmp_path
|
||||||
|
(tmp_path / "r1.json").write_text(json.dumps({"query": "q1", "status": "done"}))
|
||||||
|
(tmp_path / "r2.json").write_text(json.dumps({"query": "q2", "status": "running"}))
|
||||||
|
emitted = []
|
||||||
|
monkeypatch.setattr(cli, "emit", lambda value, args: emitted.append(value))
|
||||||
|
# CLI "complete" must map to the stored "done" and match r1.
|
||||||
|
cli.cmd_list(SimpleNamespace(status="complete", limit=50))
|
||||||
|
ids = [r["id"] for r in emitted[0]]
|
||||||
|
assert ids == ["r1"] # only the completed run
|
||||||
|
|
||||||
|
|
||||||
|
def test_verbatim_status_still_filters(tmp_path, monkeypatch):
|
||||||
|
cli = _load_cli(); cli._DATA_DIR = tmp_path
|
||||||
|
(tmp_path / "r1.json").write_text(json.dumps({"query": "q1", "status": "done"}))
|
||||||
|
(tmp_path / "r2.json").write_text(json.dumps({"query": "q2", "status": "running"}))
|
||||||
|
emitted = []
|
||||||
|
monkeypatch.setattr(cli, "emit", lambda value, args: emitted.append(value))
|
||||||
|
cli.cmd_list(SimpleNamespace(status="running", limit=50))
|
||||||
|
ids = [r["id"] for r in emitted[0]]
|
||||||
|
assert ids == ["r2"] # verbatim choices pass through unchanged
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
from services.research.research_handler import ResearchHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _format_report(findings):
|
||||||
|
handler = object.__new__(ResearchHandler)
|
||||||
|
return handler._format_research_report(
|
||||||
|
"test query",
|
||||||
|
"# Report\n\nBody",
|
||||||
|
{"Rounds": 1, "Queries": 1, "URLs": len(findings)},
|
||||||
|
1.0,
|
||||||
|
findings=findings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_report_with_analyzed_urls(findings, analyzed_urls):
|
||||||
|
handler = object.__new__(ResearchHandler)
|
||||||
|
return handler._format_research_report(
|
||||||
|
"test query",
|
||||||
|
"# Report\n\nBody",
|
||||||
|
{"Rounds": 1, "Queries": 1, "URLs": len(analyzed_urls)},
|
||||||
|
1.0,
|
||||||
|
findings=findings,
|
||||||
|
analyzed_urls=analyzed_urls,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_report_lists_every_analyzed_url_once():
|
||||||
|
findings = [
|
||||||
|
{
|
||||||
|
"url": "https://example.com/good",
|
||||||
|
"title": "Good Source",
|
||||||
|
"summary": "Detailed useful evidence about the query.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://example.com/low-quality",
|
||||||
|
"title": "Low Quality Page",
|
||||||
|
"summary": "",
|
||||||
|
"evidence": "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://example.com/good",
|
||||||
|
"title": "Good Source Duplicate",
|
||||||
|
"summary": "Repeated extraction from the same URL.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
report = _format_report(findings)
|
||||||
|
|
||||||
|
assert "### Analyzed URLs" in report
|
||||||
|
analyzed_section = report.split("### Analyzed URLs", 1)[1].split("<details>", 1)[0]
|
||||||
|
assert "1. [Good Source](https://example.com/good)" in analyzed_section
|
||||||
|
assert "2. [Low Quality Page](https://example.com/low-quality)" in analyzed_section
|
||||||
|
assert analyzed_section.count("https://example.com/good") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_report_keeps_sources_section_curated():
|
||||||
|
findings = [
|
||||||
|
{
|
||||||
|
"url": "https://example.com/good",
|
||||||
|
"title": "Good Source",
|
||||||
|
"summary": "Detailed useful evidence about the query.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://example.com/low-quality",
|
||||||
|
"title": "Low Quality Page",
|
||||||
|
"summary": "",
|
||||||
|
"evidence": "",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
report = _format_report(findings)
|
||||||
|
|
||||||
|
sources_section = report.split("### Sources", 1)[1].split("### Analyzed URLs", 1)[0]
|
||||||
|
assert "[Good Source](https://example.com/good)" in sources_section
|
||||||
|
assert "https://example.com/low-quality" not in sources_section
|
||||||
|
|
||||||
|
|
||||||
|
def test_research_report_uses_full_analyzed_url_set_not_just_findings():
|
||||||
|
findings = [
|
||||||
|
{
|
||||||
|
"url": "https://example.com/finding",
|
||||||
|
"title": "Finding Source",
|
||||||
|
"summary": "Detailed useful evidence about the query.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
analyzed_urls = [
|
||||||
|
{"url": "https://example.com/finding", "title": "Finding Source"},
|
||||||
|
{"url": "https://example.com/fetched-no-finding", "title": "Fetched No Finding"},
|
||||||
|
{"url": "https://example.com/finding", "title": "Duplicate"},
|
||||||
|
]
|
||||||
|
|
||||||
|
report = _format_report_with_analyzed_urls(findings, analyzed_urls)
|
||||||
|
|
||||||
|
sources_section = report.split("### Sources", 1)[1].split("### Analyzed URLs", 1)[0]
|
||||||
|
analyzed_section = report.split("### Analyzed URLs", 1)[1].split("<details>", 1)[0]
|
||||||
|
assert "https://example.com/fetched-no-finding" not in sources_section
|
||||||
|
assert "1. [Finding Source](https://example.com/finding)" in analyzed_section
|
||||||
|
assert "2. [Fetched No Finding](https://example.com/fetched-no-finding)" in analyzed_section
|
||||||
|
assert analyzed_section.count("https://example.com/finding") == 1
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
"""get_status must not rescan the whole research dir on every SSE poll.
|
||||||
|
|
||||||
|
get_avg_duration() globs and JSON-parses every file under the research data dir.
|
||||||
|
get_status() called it unconditionally on each poll, including for sessions that
|
||||||
|
are not active (the common case while a client polls a finished report). It is
|
||||||
|
now computed only for active sessions and memoized on the entry.
|
||||||
|
"""
|
||||||
|
from src.research_handler import ResearchHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _handler():
|
||||||
|
h = ResearchHandler.__new__(ResearchHandler)
|
||||||
|
h._active_tasks = {}
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def test_inactive_session_does_not_compute_avg(monkeypatch):
|
||||||
|
h = _handler()
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 5.0)[1])
|
||||||
|
# Unknown session, no disk file -> None, and no expensive avg scan.
|
||||||
|
assert h.get_status("missing-session") is None
|
||||||
|
assert calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_session_memoizes_avg(monkeypatch):
|
||||||
|
h = _handler()
|
||||||
|
h._active_tasks["s1"] = {
|
||||||
|
"status": "running", "progress": {}, "query": "q", "started_at": 0,
|
||||||
|
}
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(h, "get_avg_duration", lambda: (calls.append(1), 12.0)[1])
|
||||||
|
|
||||||
|
r1 = h.get_status("s1")
|
||||||
|
r2 = h.get_status("s1")
|
||||||
|
r3 = h.get_status("s1")
|
||||||
|
|
||||||
|
assert r1["avg_duration"] == 12.0
|
||||||
|
assert r2["avg_duration"] == 12.0 and r3["avg_duration"] == 12.0
|
||||||
|
# Computed once across many polls, not once per poll.
|
||||||
|
assert len(calls) == 1
|
||||||
@@ -58,6 +58,62 @@ def test_rename_into_reserved_username_is_blocked(tmp_path):
|
|||||||
assert "bob" in mgr.users
|
assert "bob" in mgr.users
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_reserved_username_is_removed_on_load(tmp_path):
|
||||||
|
auth_path = tmp_path / "auth.json"
|
||||||
|
auth_path.write_text(
|
||||||
|
'{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}, '
|
||||||
|
'"admin": {"password_hash": "unused", "is_admin": true}}}',
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
|
||||||
|
assert "internal-tool" not in mgr.users
|
||||||
|
assert "admin" in mgr.users
|
||||||
|
assert "internal-tool" not in auth_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_reserved_username_session_cannot_authenticate(tmp_path):
|
||||||
|
auth_path = tmp_path / "auth.json"
|
||||||
|
sessions_path = tmp_path / "sessions.json"
|
||||||
|
auth_path.write_text(
|
||||||
|
'{"users": {"internal-tool": {"password_hash": "unused", "is_admin": false}}}',
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
sessions_path.write_text(
|
||||||
|
'{"tok": {"username": "internal-tool", "expiry": 9999999999}}',
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
|
||||||
|
assert mgr.validate_token("tok") is False
|
||||||
|
assert mgr.get_username_for_token("tok") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_reserved_single_user_migrates_to_admin(tmp_path):
|
||||||
|
auth_path = tmp_path / "auth.json"
|
||||||
|
auth_path.write_text(
|
||||||
|
'{"username": "internal-tool", "password_hash": "unused"}',
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
|
||||||
|
assert "internal-tool" not in mgr.users
|
||||||
|
assert "admin" in mgr.users
|
||||||
|
assert mgr.is_admin("admin") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_cache_owner_normalization_requires_current_user():
|
||||||
|
clear_module("core.auth")
|
||||||
|
from core.auth import normalize_known_username
|
||||||
|
|
||||||
|
users = {"alice": {}, "admin": {}}
|
||||||
|
|
||||||
|
assert normalize_known_username(users, " Alice ") == "alice"
|
||||||
|
assert normalize_known_username(users, "internal-tool") is None
|
||||||
|
assert normalize_known_username(users, "api") is None
|
||||||
|
assert normalize_known_username(users, "") is None
|
||||||
|
|
||||||
|
|
||||||
def test_normal_usernames_still_allowed(tmp_path):
|
def test_normal_usernames_still_allowed(tmp_path):
|
||||||
mgr = _fresh_auth_manager(tmp_path)
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
assert mgr.create_user("alice", "pw-123456") is True
|
assert mgr.create_user("alice", "pw-123456") is True
|
||||||
|
|||||||
@@ -647,6 +647,60 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
|
|||||||
assert "manage_tasks" in blocked
|
assert "manage_tasks" in blocked
|
||||||
|
|
||||||
|
|
||||||
|
def test_presetup_does_not_grant_admin_tools_when_auth_enabled(monkeypatch):
|
||||||
|
"""Pre-setup window: auth is enabled but no admin user exists yet.
|
||||||
|
|
||||||
|
This must NOT be treated as single-user/admin at the tool layer — the
|
||||||
|
server-execution tools (bash/python) stay blocked as defense-in-depth so
|
||||||
|
an unauthenticated caller that slips past the auth middleware (e.g. via a
|
||||||
|
loopback bypass) can't reach an RCE before setup completes.
|
||||||
|
"""
|
||||||
|
monkeypatch.delenv("AUTH_ENABLED", raising=False) # default: enabled
|
||||||
|
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||||
|
|
||||||
|
class FakeAuth:
|
||||||
|
is_configured = False
|
||||||
|
|
||||||
|
def is_admin(self, username):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||||
|
|
||||||
|
from src.tool_security import (
|
||||||
|
blocked_tools_for_owner,
|
||||||
|
owner_is_admin_or_single_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner_is_admin_or_single_user(None) is False
|
||||||
|
blocked = blocked_tools_for_owner(None)
|
||||||
|
assert "bash" in blocked
|
||||||
|
assert "python" in blocked
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_user_mode_keeps_full_tool_access_when_auth_disabled(monkeypatch):
|
||||||
|
"""Intentional single-user mode (AUTH_ENABLED=false) keeps full tool
|
||||||
|
access even with no admin user — this is the default local/self-host UX
|
||||||
|
and must not regress."""
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||||
|
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||||
|
|
||||||
|
class FakeAuth:
|
||||||
|
is_configured = False
|
||||||
|
|
||||||
|
def is_admin(self, username):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||||
|
|
||||||
|
from src.tool_security import (
|
||||||
|
blocked_tools_for_owner,
|
||||||
|
owner_is_admin_or_single_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert owner_is_admin_or_single_user(None) is True
|
||||||
|
assert blocked_tools_for_owner(None) == set()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_webhook_tool_reuses_private_url_validation():
|
async def test_webhook_tool_reuses_private_url_validation():
|
||||||
class FakeDb:
|
class FakeDb:
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from routes._validators import validate_remote_host, validate_ssh_port
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_ssh_port_rejects_shell_payload():
|
||||||
|
for port in ["22;id", "$(id)", "-p 22", "0", "65536"]:
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
validate_ssh_port(port)
|
||||||
|
assert validate_ssh_port("2222") == "2222"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_remote_host_rejects_ssh_option_shape():
|
||||||
|
for host in [
|
||||||
|
"-oProxyCommand=sh",
|
||||||
|
"alice@-oProxyCommand=sh",
|
||||||
|
"--",
|
||||||
|
"-p2222",
|
||||||
|
]:
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
validate_remote_host(host)
|
||||||
|
assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1"
|
||||||
@@ -90,8 +90,8 @@ def test_service_ddg_html_fallback_sends_safesearch(monkeypatch):
|
|||||||
seen["params"] = kwargs["params"]
|
seen["params"] = kwargs["params"]
|
||||||
return _Response()
|
return _Response()
|
||||||
|
|
||||||
monkeypatch.setitem(sys.modules, "duckduckgo_search", None)
|
|
||||||
monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"})
|
monkeypatch.setattr(providers, "_get_search_settings", lambda: {"search_safesearch": "off"})
|
||||||
|
monkeypatch.setitem(sys.modules, "ddgs", None)
|
||||||
monkeypatch.setattr(providers.httpx, "get", fake_get)
|
monkeypatch.setattr(providers.httpx, "get", fake_get)
|
||||||
|
|
||||||
results = providers.duckduckgo_search("odysseus", count=1)
|
results = providers.duckduckgo_search("odysseus", count=1)
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""FTS session search must fetch hit rows in one query, not one per hit.
|
||||||
|
|
||||||
|
_search_fts looked up each FTS hit's full row with its own
|
||||||
|
db.query(...).filter(id == message_id).first(), an N+1 query. The lookup is now
|
||||||
|
a single batched IN(...) query via _fetch_messages_by_id.
|
||||||
|
"""
|
||||||
|
from src.session_search import _fetch_messages_by_id
|
||||||
|
|
||||||
|
|
||||||
|
class _Msg:
|
||||||
|
def __init__(self, mid):
|
||||||
|
self.id = mid
|
||||||
|
|
||||||
|
|
||||||
|
class _Query:
|
||||||
|
def __init__(self, rows, calls):
|
||||||
|
self._rows = rows
|
||||||
|
self._calls = calls
|
||||||
|
|
||||||
|
def join(self, *a, **k):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def filter(self, *a, **k):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
self._calls["all"] += 1
|
||||||
|
return self._rows
|
||||||
|
|
||||||
|
|
||||||
|
class _DB:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self._rows = rows
|
||||||
|
self.calls = {"query": 0, "all": 0}
|
||||||
|
|
||||||
|
def query(self, *a, **k):
|
||||||
|
self.calls["query"] += 1
|
||||||
|
return _Query(self._rows, self.calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batches_into_single_query():
|
||||||
|
rows = [(_Msg("m1"), "Session One"), (_Msg("m2"), "Session Two")]
|
||||||
|
db = _DB(rows)
|
||||||
|
out = _fetch_messages_by_id(db, ["m1", "m2"])
|
||||||
|
# One query for all hits, not one per hit.
|
||||||
|
assert db.calls["query"] == 1
|
||||||
|
assert db.calls["all"] == 1
|
||||||
|
assert out["m1"][1] == "Session One"
|
||||||
|
assert out["m2"][0].id == "m2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_ids_does_no_query():
|
||||||
|
db = _DB([])
|
||||||
|
assert _fetch_messages_by_id(db, []) == {}
|
||||||
|
assert db.calls["query"] == 0
|
||||||
@@ -40,7 +40,8 @@ def test_secret_in_list_of_dicts_blanked():
|
|||||||
|
|
||||||
def test_non_secret_keys_preserved():
|
def test_non_secret_keys_preserved():
|
||||||
s = {"keybinds": {"send": "Enter"}, "theme": "dark", "image_model": "x",
|
s = {"keybinds": {"send": "Enter"}, "theme": "dark", "image_model": "x",
|
||||||
"default_endpoint_id": "ep1", "search_result_count": 5, "tts_enabled": True}
|
"default_endpoint_id": "ep1", "search_result_count": 5, "tts_enabled": True,
|
||||||
|
"tokenId": "public-id", "keyId": "public-key-id"}
|
||||||
assert scrub_settings(s) == s # untouched
|
assert scrub_settings(s) == s # untouched
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +72,23 @@ def test_exact_name_matches():
|
|||||||
assert all(v == "" for v in out.values()), out
|
assert all(v == "" for v in out.values()), out
|
||||||
|
|
||||||
|
|
||||||
|
def test_camel_case_secret_keys_blanked():
|
||||||
|
out = scrub_settings({
|
||||||
|
"apiKey": "api-secret",
|
||||||
|
"accessToken": "access-secret",
|
||||||
|
"refreshToken": "refresh-secret",
|
||||||
|
"clientSecret": "client-secret",
|
||||||
|
"hfToken": "hf-secret",
|
||||||
|
"nested": {"privateKey": "private-secret"},
|
||||||
|
})
|
||||||
|
assert out["apiKey"] == ""
|
||||||
|
assert out["accessToken"] == ""
|
||||||
|
assert out["refreshToken"] == ""
|
||||||
|
assert out["clientSecret"] == ""
|
||||||
|
assert out["hfToken"] == ""
|
||||||
|
assert out["nested"]["privateKey"] == ""
|
||||||
|
|
||||||
|
|
||||||
def test_non_object_settings_return_empty_mapping():
|
def test_non_object_settings_return_empty_mapping():
|
||||||
assert scrub_settings(["not", "settings"]) == {}
|
assert scrub_settings(["not", "settings"]) == {}
|
||||||
assert scrub_settings("not settings") == {}
|
assert scrub_settings("not settings") == {}
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.upload_handler import UploadHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _make_handler(tmp_path: Path) -> UploadHandler:
|
||||||
|
base = tmp_path / "base"
|
||||||
|
upload = tmp_path / "uploads"
|
||||||
|
base.mkdir()
|
||||||
|
upload.mkdir()
|
||||||
|
return UploadHandler(base_dir=str(base), upload_dir=str(upload))
|
||||||
|
|
||||||
|
|
||||||
|
def _db_path(handler: UploadHandler) -> str:
|
||||||
|
return os.path.join(handler.upload_dir, "uploads.json")
|
||||||
|
|
||||||
|
|
||||||
|
def _write_upload_file(handler: UploadHandler, file_id: str, content: bytes = b"content") -> str:
|
||||||
|
upload_day = Path(handler.upload_dir) / "2026" / "06" / "09"
|
||||||
|
upload_day.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = upload_day / file_id
|
||||||
|
path.write_bytes(content)
|
||||||
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _entry(handler: UploadHandler, owner: str, file_hash: str, file_id: str) -> dict:
|
||||||
|
path = _write_upload_file(handler, file_id, content=f"{owner}:{file_hash}".encode())
|
||||||
|
return {
|
||||||
|
"id": file_id,
|
||||||
|
"path": path,
|
||||||
|
"mime": "text/plain",
|
||||||
|
"size": os.path.getsize(path),
|
||||||
|
"name": f"{file_id}.txt",
|
||||||
|
"hash": file_hash,
|
||||||
|
"original_name": f"{file_id}.txt",
|
||||||
|
"uploaded_at": "2026-06-09T10:00:00",
|
||||||
|
"last_accessed": "2026-06-09T10:00:00",
|
||||||
|
"client_ip": "127.0.0.1",
|
||||||
|
"owner": owner,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_owner_updates_upload_metadata_key_and_resolver(tmp_path):
|
||||||
|
handler = _make_handler(tmp_path)
|
||||||
|
alice_id = "a" * 32 + ".txt"
|
||||||
|
alice_entry = _entry(handler, "Alice", "hash-alice", alice_id)
|
||||||
|
bob_entry = _entry(handler, "bob", "hash-bob", "b" * 32 + ".txt")
|
||||||
|
handler._atomic_write_json(
|
||||||
|
_db_path(handler),
|
||||||
|
{
|
||||||
|
"Alice:hash-alice": alice_entry,
|
||||||
|
"bob:hash-bob": bob_entry,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
renamed = handler.rename_owner("alice", "alice2")
|
||||||
|
|
||||||
|
assert renamed == 1
|
||||||
|
updated = json.loads(Path(_db_path(handler)).read_text(encoding="utf-8"))
|
||||||
|
assert "Alice:hash-alice" not in updated
|
||||||
|
assert "alice2:hash-alice" in updated
|
||||||
|
assert updated["alice2:hash-alice"]["owner"] == "alice2"
|
||||||
|
assert updated["alice2:hash-alice"]["path"] == alice_entry["path"]
|
||||||
|
assert updated["alice2:hash-alice"]["hash"] == alice_entry["hash"]
|
||||||
|
assert updated["alice2:hash-alice"]["uploaded_at"] == alice_entry["uploaded_at"]
|
||||||
|
assert updated["alice2:hash-alice"]["last_accessed"] == alice_entry["last_accessed"]
|
||||||
|
assert updated["bob:hash-bob"]["owner"] == "bob"
|
||||||
|
|
||||||
|
assert handler.resolve_upload(alice_id, owner="alice2")["id"] == alice_id
|
||||||
|
assert handler.resolve_upload(alice_id, owner="alice") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename_owner_preserves_rows_when_target_key_collides(tmp_path):
|
||||||
|
handler = _make_handler(tmp_path)
|
||||||
|
migrated_id = "c" * 32 + ".txt"
|
||||||
|
existing_id = "d" * 32 + ".txt"
|
||||||
|
migrated = _entry(handler, "alice", "same-hash", migrated_id)
|
||||||
|
existing = _entry(handler, "alice2", "same-hash", existing_id)
|
||||||
|
unrelated = _entry(handler, "carol", "other-hash", "e" * 32 + ".txt")
|
||||||
|
handler._atomic_write_json(
|
||||||
|
_db_path(handler),
|
||||||
|
{
|
||||||
|
"alice:same-hash": migrated,
|
||||||
|
"alice2:same-hash": existing,
|
||||||
|
"carol:other-hash": unrelated,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
renamed = handler.rename_owner("alice", "alice2")
|
||||||
|
|
||||||
|
assert renamed == 1
|
||||||
|
updated = json.loads(Path(_db_path(handler)).read_text(encoding="utf-8"))
|
||||||
|
assert len(updated) == 3
|
||||||
|
assert updated["alice2:same-hash"]["id"] == existing_id
|
||||||
|
migrated_key = f"alice2:same-hash:{migrated_id}"
|
||||||
|
assert updated[migrated_key]["id"] == migrated_id
|
||||||
|
assert updated[migrated_key]["owner"] == "alice2"
|
||||||
|
assert updated[migrated_key]["path"] == migrated["path"]
|
||||||
|
assert updated["carol:other-hash"] == unrelated
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
"""Startup warmup must resolve real endpoint URLs.
|
||||||
|
|
||||||
|
The warmup/keepalive loop called `model_discovery.get_endpoints()`, which does
|
||||||
|
not exist on ModelDiscovery, so it raised AttributeError every run and pinged
|
||||||
|
nothing. `ModelDiscovery.warmup_ping_urls()` resolves the /models probe URLs
|
||||||
|
from the real discovery API.
|
||||||
|
"""
|
||||||
|
from src.model_discovery import ModelDiscovery
|
||||||
|
|
||||||
|
|
||||||
|
def _md():
|
||||||
|
return ModelDiscovery.__new__(ModelDiscovery)
|
||||||
|
|
||||||
|
|
||||||
|
def test_old_method_never_existed():
|
||||||
|
# Documents why the old warmup was a silent no-op.
|
||||||
|
assert not hasattr(ModelDiscovery, "get_endpoints")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolves_models_urls_from_discovered_items():
|
||||||
|
md = _md()
|
||||||
|
md.discover_models = lambda: {"items": [
|
||||||
|
{"url": "http://host:8000/v1/chat/completions", "models": ["a"]},
|
||||||
|
{"url": "http://host:1234/v1/chat/completions", "models": ["b"]},
|
||||||
|
]}
|
||||||
|
assert md.warmup_ping_urls() == [
|
||||||
|
"http://host:8000/v1/models",
|
||||||
|
"http://host:1234/v1/models",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_limit_caps_results():
|
||||||
|
md = _md()
|
||||||
|
md.discover_models = lambda: {"items": [
|
||||||
|
{"url": f"http://h:{8000 + i}/v1/chat/completions"} for i in range(10)
|
||||||
|
]}
|
||||||
|
assert len(md.warmup_ping_urls(limit=3)) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_discovery_failure_degrades_to_empty():
|
||||||
|
md = _md()
|
||||||
|
|
||||||
|
def boom():
|
||||||
|
raise RuntimeError("port scan failed")
|
||||||
|
|
||||||
|
md.discover_models = boom
|
||||||
|
assert md.warmup_ping_urls() == []
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
"""fetch_webpage_content must return plain-text and Markdown bodies verbatim.
|
||||||
|
|
||||||
|
raw.githubusercontent.com serves Markdown as `text/plain`, and a lot of code
|
||||||
|
and tool documentation lives in `.md` / `.txt`. Those have no HTML structure,
|
||||||
|
so the HTML branch extracted nothing and web_fetch reported "no readable text
|
||||||
|
content". The plain-text branch returns the body as-is. HTML stays on the
|
||||||
|
parsing path.
|
||||||
|
"""
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.search import content as content_mod
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, text, content_type, status_code=200):
|
||||||
|
self.text = text
|
||||||
|
self.content = text.encode("utf-8")
|
||||||
|
self.headers = {"Content-Type": content_type}
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
def raise_for_status(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def no_cache(monkeypatch, tmp_path):
|
||||||
|
# Force a cache miss and skip disk writes so the test is hermetic.
|
||||||
|
monkeypatch.setattr(content_mod, "CONTENT_CACHE_DIR", tmp_path)
|
||||||
|
monkeypatch.setattr(content_mod, "_cache_result", lambda *a, **k: None)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_fetch(monkeypatch, text, content_type):
|
||||||
|
monkeypatch.setattr(
|
||||||
|
content_mod,
|
||||||
|
"_get_public_url",
|
||||||
|
lambda url, headers=None, timeout=5: _FakeResponse(text, content_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MARKDOWN = "# Title\n\nSome **docs** with a [link](https://example.com).\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_markdown_text_plain_returns_body(monkeypatch, no_cache):
|
||||||
|
_patch_fetch(monkeypatch, MARKDOWN, "text/plain; charset=utf-8")
|
||||||
|
r = content_mod.fetch_webpage_content(
|
||||||
|
"https://raw.githubusercontent.com/o/r/master/Documentation/Patterns.md"
|
||||||
|
)
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == MARKDOWN.strip()
|
||||||
|
assert r["title"] == "patterns.md"
|
||||||
|
assert r["error"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_markdown_content_type_returns_body(monkeypatch, no_cache):
|
||||||
|
_patch_fetch(monkeypatch, MARKDOWN, "text/markdown")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/readme")
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == MARKDOWN.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_octet_stream_with_txt_suffix_returns_body(monkeypatch, no_cache):
|
||||||
|
# Some servers mislabel text files; the URL-suffix fallback still reads it.
|
||||||
|
_patch_fetch(monkeypatch, "plain notes\nline two\n", "application/octet-stream")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/notes.txt")
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == "plain notes\nline two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_application_json_returns_body(monkeypatch, no_cache):
|
||||||
|
# application/json is not text/*; it must still be returned verbatim
|
||||||
|
# instead of being fed to the HTML parser (which yields empty content).
|
||||||
|
body = '{"name": "odysseus", "items": [1, 2, 3]}'
|
||||||
|
_patch_fetch(monkeypatch, body, "application/json")
|
||||||
|
r = content_mod.fetch_webpage_content("https://api.example.com/data")
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == body
|
||||||
|
|
||||||
|
|
||||||
|
def test_ld_json_suffix_content_type_returns_body(monkeypatch, no_cache):
|
||||||
|
body = '{"@context": "https://schema.org"}'
|
||||||
|
_patch_fetch(monkeypatch, body, "application/ld+json")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/meta")
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == body
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_suffix_with_octet_stream_returns_body(monkeypatch, no_cache):
|
||||||
|
body = '{"raw": true}'
|
||||||
|
_patch_fetch(monkeypatch, body, "application/octet-stream")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/package.json")
|
||||||
|
assert r["success"] is True
|
||||||
|
assert r["content"] == body
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_text_body_is_not_success(monkeypatch, no_cache):
|
||||||
|
_patch_fetch(monkeypatch, " \n ", "text/plain")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/blank.txt")
|
||||||
|
assert r["success"] is False
|
||||||
|
assert r["content"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_html_still_uses_parser(monkeypatch, no_cache):
|
||||||
|
# An HTML body must not be short-circuited by the text branch.
|
||||||
|
html = "<html><head><title>Hi</title></head><body><p>Hello world body text</p></body></html>"
|
||||||
|
_patch_fetch(monkeypatch, html, "text/html; charset=utf-8")
|
||||||
|
r = content_mod.fetch_webpage_content("https://example.com/page")
|
||||||
|
assert r["title"] == "Hi"
|
||||||
|
assert "Hello world body text" in r["content"]
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
"""Pin the web_search tool-icon rendering in the agent thread (PR #??).
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- web_search renders an <svg> icon instead of raw markup
|
||||||
|
- Other tools get the default ▶ icon
|
||||||
|
- Hostile tool names are HTML-escaped in the label
|
||||||
|
|
||||||
|
Pure JS via node --input-type=module (same approach as
|
||||||
|
test_composer_arrow_up_recall_js.py). Skips when node is not installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_HAS_NODE = shutil.which("node") is not None
|
||||||
|
|
||||||
|
_CHECK_JS = r"""
|
||||||
|
function esc(s) {
|
||||||
|
const map = { '&': '&', '<': '<', '>': '>', '"': '"', "'": ''' };
|
||||||
|
return (s || '').replace(/[&<>"']/g, (m) => map[m]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const _searchIcon = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" style="vertical-align:-2px;margin-right:4px"><circle cx="11" cy="11" r="8"/><line x1="21" y1="21" x2="16.65" y2="16.65"/></svg>';
|
||||||
|
|
||||||
|
const _toolLabels = {
|
||||||
|
web_search: 'Searching',
|
||||||
|
bash: 'Running',
|
||||||
|
};
|
||||||
|
|
||||||
|
const _toolIcons = {
|
||||||
|
web_search: _searchIcon,
|
||||||
|
};
|
||||||
|
|
||||||
|
function renderIcon(toolName) {
|
||||||
|
return _toolIcons[toolName.toLowerCase()] || '\u25B6';
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderLabel(toolName) {
|
||||||
|
return _toolLabels[toolName.toLowerCase()] || toolName;
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderThreadHTML(toolName, cmd) {
|
||||||
|
const label = renderLabel(toolName);
|
||||||
|
const icon = renderIcon(toolName);
|
||||||
|
const cmdHtml = cmd ? `<pre class="agent-thread-cmd">${esc(cmd)}</pre>` : '';
|
||||||
|
return `<div class="agent-thread-dot"></div><div class="agent-thread-header"><span class="agent-thread-icon">${icon}</span><span class="agent-thread-tool">${esc(label)}</span><span class="agent-thread-wave">\u2581\u2582\u2583</span></div><div class="agent-thread-content">${cmdHtml}</div>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cases = CASES_JSON;
|
||||||
|
const results = cases.map(c => {
|
||||||
|
const html = renderThreadHTML(c.tool, c.cmd || '');
|
||||||
|
return { tool: c.tool, html };
|
||||||
|
});
|
||||||
|
console.log(JSON.stringify(results));
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _run(cases: list) -> list:
|
||||||
|
js = _CHECK_JS.replace("CASES_JSON", json.dumps(cases))
|
||||||
|
proc = subprocess.run(
|
||||||
|
["node", "--input-type=module"],
|
||||||
|
input=js,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
cwd=str(_REPO),
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
assert proc.returncode == 0, proc.stderr
|
||||||
|
return json.loads(proc.stdout.strip())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_web_search_icon_contains_svg():
|
||||||
|
out = _run([{"tool": "web_search"}])[0]
|
||||||
|
assert "<svg" in out["html"], "Expected <svg> in agent-thread-icon for web_search"
|
||||||
|
assert "Searching" in out["html"], "Expected 'Searching' label for web_search"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_default_tool_icon_is_triangle():
|
||||||
|
out = _run([{"tool": "bash"}])[0]
|
||||||
|
assert "▶" in out["html"], "Expected ▶ icon for tools without custom icon"
|
||||||
|
assert "<svg" not in out["html"], "Expected no <svg> for bash"
|
||||||
|
assert "Running" in out["html"], "Expected 'Running' label for bash"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_unknown_tool_falls_back_to_name():
|
||||||
|
out = _run([{"tool": "my_custom_tool"}])[0]
|
||||||
|
assert "▶" in out["html"], "Expected ▶ for unknown tool"
|
||||||
|
assert "my_custom_tool" in out["html"], "Expected tool name as label"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_hostile_tool_name_is_escaped():
|
||||||
|
out = _run([{"tool": '<img src=x onerror="alert(1)">'}])[0]
|
||||||
|
assert "<img" in out["html"], "Expected < to be HTML-escaped"
|
||||||
|
assert ">" in out["html"], "Expected > to be HTML-escaped"
|
||||||
|
assert "<img" not in out["html"], "Raw <img> must not appear"
|
||||||
|
assert "onerror" not in out["html"] or """ in out["html"], "onerror must not be executable"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_unknown_tool_case_insensitive_matches_icons():
|
||||||
|
out = _run([{"tool": "WEB_SEARCH"}, {"tool": "Web_Search"}])
|
||||||
|
for r in out:
|
||||||
|
assert "<svg" in r["html"], f"Expected SVG for case-variant '{r['tool']}'"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _HAS_NODE, reason="node binary not on PATH")
|
||||||
|
def test_command_is_escaped():
|
||||||
|
out = _run([{"tool": "bash", "cmd": "echo $HOME && ls"}])[0]
|
||||||
|
assert "echo $HOME" in out["html"], "Expected command text in output"
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
"""Fire-and-forget webhook tasks must be referenced until they finish.
|
||||||
|
|
||||||
|
asyncio keeps only a weak reference to a bare create_task() result, so a
|
||||||
|
delivery task could be garbage-collected before it ran and the webhook silently
|
||||||
|
dropped. WebhookManager now holds a strong reference for the task's lifetime and
|
||||||
|
releases it on completion.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# webhook_manager does `from src.database import SessionLocal, Webhook` at import
|
||||||
|
# time. The shared test harness stubs src.database without Webhook, so ensure the
|
||||||
|
# attribute exists before importing the manager. These tests never touch the DB
|
||||||
|
# (the manager is built via __new__), so a placeholder class is sufficient.
|
||||||
|
_db = sys.modules.get("src.database")
|
||||||
|
if _db is not None and not hasattr(_db, "Webhook"):
|
||||||
|
_db.Webhook = type("Webhook", (), {})
|
||||||
|
|
||||||
|
from src.webhook_manager import WebhookManager # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def test_spawn_tracked_holds_then_releases_reference():
|
||||||
|
async def run():
|
||||||
|
wm = WebhookManager.__new__(WebhookManager)
|
||||||
|
wm._bg_tasks = set()
|
||||||
|
|
||||||
|
gate = asyncio.Event()
|
||||||
|
|
||||||
|
async def work():
|
||||||
|
await gate.wait()
|
||||||
|
|
||||||
|
task = wm._spawn_tracked(work())
|
||||||
|
# Referenced while in flight (this is what stops GC from collecting it).
|
||||||
|
assert task in wm._bg_tasks
|
||||||
|
gate.set()
|
||||||
|
await task
|
||||||
|
# Reference released once done, so the set does not grow unbounded.
|
||||||
|
assert task not in wm._bg_tasks
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_spawn_tracked_runs_the_coroutine():
|
||||||
|
async def run():
|
||||||
|
wm = WebhookManager.__new__(WebhookManager)
|
||||||
|
wm._bg_tasks = set()
|
||||||
|
ran = []
|
||||||
|
|
||||||
|
async def work():
|
||||||
|
ran.append(True)
|
||||||
|
|
||||||
|
await wm._spawn_tracked(work())
|
||||||
|
assert ran == [True]
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
@@ -0,0 +1,328 @@
|
|||||||
|
"""Workspace confinement.
|
||||||
|
|
||||||
|
The agent's per-turn workspace is a single context-local binding set in
|
||||||
|
execute_tool_block. The shared path resolvers (_resolve_tool_path /
|
||||||
|
_resolve_search_root) and the subprocess cwd helper (agent_cwd) read it, so
|
||||||
|
confinement is enforced in ONE place: a tool that uses the shared helpers is
|
||||||
|
confined automatically and a new tool cannot accidentally bypass it.
|
||||||
|
|
||||||
|
Covers: the resolver helper, the central binding (the safety net), end-to-end
|
||||||
|
confinement of read/write/edit/grep/ls + subprocess cwd via execute_tool_block,
|
||||||
|
the get_workspace tool, no-leak across calls, and the admin-gated browse route.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.tool_execution import (
|
||||||
|
_AGENT_WORKDIR,
|
||||||
|
_active_workspace,
|
||||||
|
_resolve_search_root,
|
||||||
|
_resolve_tool_path,
|
||||||
|
_resolve_tool_path_in_workspace,
|
||||||
|
agent_cwd,
|
||||||
|
execute_tool_block,
|
||||||
|
get_active_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _block(tool, content=""):
|
||||||
|
return SimpleNamespace(tool_type=tool, content=content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ws():
|
||||||
|
d = tempfile.mkdtemp()
|
||||||
|
with open(os.path.join(d, "a.txt"), "w") as f:
|
||||||
|
f.write("x")
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def admin(monkeypatch):
|
||||||
|
"""Pass the public-tool gate so file tools dispatch in tests."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"src.tool_execution.owner_is_admin_or_single_user", lambda owner: True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── the resolver helper ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_resolver_confines(ws):
|
||||||
|
real = os.path.realpath(os.path.join(ws, "a.txt"))
|
||||||
|
assert _resolve_tool_path_in_workspace(ws, "a.txt") == real # relative
|
||||||
|
assert _resolve_tool_path_in_workspace(ws, os.path.join(ws, "a.txt")) == real # abs inside
|
||||||
|
outside = tempfile.mkdtemp()
|
||||||
|
with pytest.raises(ValueError): # abs outside
|
||||||
|
_resolve_tool_path_in_workspace(ws, os.path.join(outside, "x.txt"))
|
||||||
|
with pytest.raises(ValueError): # parent escape
|
||||||
|
_resolve_tool_path_in_workspace(ws, os.path.join("..", "..", "escape.txt"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_blocks_sensitive_inside_workspace(ws):
|
||||||
|
os.makedirs(os.path.join(ws, ".ssh"), exist_ok=True)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_resolve_tool_path_in_workspace(ws, ".ssh/authorized_keys")
|
||||||
|
|
||||||
|
|
||||||
|
# ── the central binding: the safety net ─────────────────────────────────
|
||||||
|
|
||||||
|
def test_active_binding_confines_shared_resolvers(ws):
|
||||||
|
"""ANY tool resolving paths through the shared helpers is confined while the
|
||||||
|
binding is active, without doing anything workspace-specific itself. This is
|
||||||
|
what stops a newly added tool from accidentally ignoring the workspace."""
|
||||||
|
token = _active_workspace.set(ws)
|
||||||
|
try:
|
||||||
|
assert get_active_workspace() == ws
|
||||||
|
assert agent_cwd() == ws
|
||||||
|
assert _resolve_tool_path("a.txt") == os.path.realpath(os.path.join(ws, "a.txt"))
|
||||||
|
with pytest.raises(ValueError): # normally-allowed root, now outside ws
|
||||||
|
_resolve_tool_path("/tmp/whatever.txt")
|
||||||
|
assert _resolve_search_root("") == os.path.realpath(ws)
|
||||||
|
finally:
|
||||||
|
_active_workspace.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_binding_uses_default_roots():
|
||||||
|
assert get_active_workspace() is None
|
||||||
|
assert agent_cwd() == _AGENT_WORKDIR
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_resolve_tool_path("/etc/hosts")
|
||||||
|
|
||||||
|
|
||||||
|
# ── end-to-end via execute_tool_block (sets + resets the binding) ───────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_write_edit_confined_e2e(ws, admin):
|
||||||
|
_, r = await execute_tool_block(_block("write_file", "note.txt\nhello"), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0 and os.path.isfile(os.path.join(ws, "note.txt"))
|
||||||
|
_, r = await execute_tool_block(_block("read_file", "note.txt"), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0 and r["output"] == "hello"
|
||||||
|
|
||||||
|
with open(os.path.join(ws, "f.txt"), "w") as f:
|
||||||
|
f.write("foo bar")
|
||||||
|
_, r = await execute_tool_block(
|
||||||
|
_block("edit_file", json.dumps({"path": "f.txt", "old_string": "foo", "new_string": "baz"})),
|
||||||
|
owner="a", workspace=ws,
|
||||||
|
)
|
||||||
|
assert r["exit_code"] == 0
|
||||||
|
with open(os.path.join(ws, "f.txt")) as f:
|
||||||
|
assert f.read() == "baz bar"
|
||||||
|
|
||||||
|
# outside the workspace is rejected, and nothing is created
|
||||||
|
outside = tempfile.mkdtemp()
|
||||||
|
of = os.path.join(outside, "secret.txt")
|
||||||
|
with open(of, "w") as f:
|
||||||
|
f.write("nope")
|
||||||
|
_, r = await execute_tool_block(_block("read_file", of), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 1 and "outside the workspace" in r["error"]
|
||||||
|
escape = os.path.join(outside, "_esc.txt")
|
||||||
|
_, r = await execute_tool_block(_block("write_file", f"{escape}\nx"), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 1 and "outside the workspace" in r["error"]
|
||||||
|
assert not os.path.exists(escape)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grep_and_ls_confined_e2e(ws, admin):
|
||||||
|
with open(os.path.join(ws, "doc.txt"), "w") as f:
|
||||||
|
f.write("hello workspace\n")
|
||||||
|
_, r = await execute_tool_block(_block("grep", json.dumps({"pattern": "hello"})), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0 and "doc.txt" in r["output"]
|
||||||
|
outside = tempfile.mkdtemp()
|
||||||
|
_, r = await execute_tool_block(_block("grep", json.dumps({"pattern": "x", "path": outside})), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 1 and "outside the workspace" in r["error"]
|
||||||
|
_, r = await execute_tool_block(_block("ls", ""), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0 and "doc.txt" in r["output"]
|
||||||
|
_, r = await execute_tool_block(_block("ls", outside), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 1 and "outside the workspace" in r["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subprocess_cwd_is_workspace_e2e(ws, admin):
|
||||||
|
"""python tool runs with cwd = workspace (OS-agnostic probe)."""
|
||||||
|
_, r = await execute_tool_block(_block("python", "import os; print(os.getcwd())"), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0
|
||||||
|
assert os.path.realpath(r["output"].strip()) == os.path.realpath(ws)
|
||||||
|
|
||||||
|
|
||||||
|
# ── get_workspace tool ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_workspace_tool(ws, admin):
|
||||||
|
_, r = await execute_tool_block(_block("get_workspace", ""), owner="a", workspace=ws)
|
||||||
|
assert r["exit_code"] == 0 and r["output"].startswith(ws) and "not sandboxed" in r["output"]
|
||||||
|
_, r = await execute_tool_block(_block("get_workspace", ""), owner="a") # none active
|
||||||
|
assert r["exit_code"] == 0 and "No workspace" in r["output"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── no leak across calls ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_binding_does_not_leak(ws, admin):
|
||||||
|
await execute_tool_block(_block("ls", ""), owner="a", workspace=ws)
|
||||||
|
assert get_active_workspace() is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── tool selection: an active workspace is the file-work signal ─────────
|
||||||
|
# A vague ("low-signal") message like "look at the local project" matches no
|
||||||
|
# domain keywords, so retrieval is normally skipped. When a workspace is set it
|
||||||
|
# must still surface the file tools, otherwise the agent says it has no file
|
||||||
|
# access (the bug this guards against).
|
||||||
|
|
||||||
|
def _sent_tool_names(monkeypatch, *, workspace):
|
||||||
|
import asyncio
|
||||||
|
import src.agent_loop as al
|
||||||
|
|
||||||
|
monkeypatch.setattr(al, "get_setting", lambda key, default=None: default, raising=False)
|
||||||
|
monkeypatch.setattr(al, "get_mcp_manager", lambda: None, raising=False)
|
||||||
|
monkeypatch.setattr(al, "estimate_tokens", lambda *a, **k: 10, raising=False)
|
||||||
|
# Isolate the selection logic from owner gating (tested separately).
|
||||||
|
monkeypatch.setattr(al, "blocked_tools_for_owner", lambda owner: set(), raising=False)
|
||||||
|
|
||||||
|
captured = []
|
||||||
|
|
||||||
|
async def _fake_stream(_candidates, messages, **kwargs):
|
||||||
|
captured.append(kwargs.get("tools"))
|
||||||
|
yield "data: " + json.dumps({"delta": "ok"}) + "\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
gen = al.stream_agent_loop(
|
||||||
|
"https://api.openai.com/v1", "gpt-test",
|
||||||
|
[{"role": "user", "content": "look at the local project"}],
|
||||||
|
max_rounds=1, relevant_tools=None, owner="admin", workspace=workspace,
|
||||||
|
)
|
||||||
|
return [c async for c in gen]
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
schemas = captured[0] or []
|
||||||
|
return {t["function"]["name"] for t in schemas if isinstance(t, dict) and "function" in t}
|
||||||
|
|
||||||
|
|
||||||
|
def test_low_signal_with_workspace_surfaces_readonly_file_tools(monkeypatch):
|
||||||
|
names = _sent_tool_names(monkeypatch, workspace="/tmp")
|
||||||
|
# read-only nav tools surface so the agent can explore
|
||||||
|
assert "read_file" in names
|
||||||
|
assert "get_workspace" in names
|
||||||
|
assert "grep" in names
|
||||||
|
# write/shell tools do NOT surface on a vague message
|
||||||
|
assert "write_file" not in names
|
||||||
|
assert "edit_file" not in names
|
||||||
|
assert "bash" not in names
|
||||||
|
assert "python" not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_low_signal_without_workspace_excludes_file_tools(monkeypatch):
|
||||||
|
names = _sent_tool_names(monkeypatch, workspace=None)
|
||||||
|
assert "read_file" not in names
|
||||||
|
assert "get_workspace" not in names
|
||||||
|
|
||||||
|
|
||||||
|
# ── browse route is admin-gated ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_browse_is_admin_gated(monkeypatch):
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import routes.workspace_routes as wr
|
||||||
|
|
||||||
|
router = wr.setup_workspace_routes()
|
||||||
|
browse = next(r.endpoint for r in router.routes if r.path == "/api/workspace/browse")
|
||||||
|
|
||||||
|
monkeypatch.setattr(wr, "get_current_user", lambda req: "bob")
|
||||||
|
monkeypatch.setattr(wr, "owner_is_admin_or_single_user", lambda owner: False)
|
||||||
|
with pytest.raises(HTTPException) as ei:
|
||||||
|
browse(request=object(), path="/")
|
||||||
|
assert ei.value.status_code == 403
|
||||||
|
|
||||||
|
monkeypatch.setattr(wr, "owner_is_admin_or_single_user", lambda owner: True)
|
||||||
|
out = browse(request=object(), path=os.path.expanduser("~"))
|
||||||
|
assert "dirs" in out and "path" in out
|
||||||
|
assert all("name" in d and "path" in d for d in out["dirs"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── bind-time vetting of the workspace root ─────────────────────────────
|
||||||
|
|
||||||
|
def test_vet_workspace_accepts_normal_dir(ws):
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
assert vet_workspace(ws) == os.path.realpath(ws)
|
||||||
|
|
||||||
|
|
||||||
|
def test_vet_workspace_rejects_sensitive_root(tmp_path):
|
||||||
|
# The resolver deny-lists sensitive paths inside the workspace, but the
|
||||||
|
# empty-path search root is the workspace itself - a sensitive root must
|
||||||
|
# be rejected before it is bound or `ls` with no path would list it.
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
ssh_dir = tmp_path / ".ssh"
|
||||||
|
ssh_dir.mkdir()
|
||||||
|
assert vet_workspace(str(ssh_dir)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_vet_workspace_rejects_nondir_and_empty(ws):
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
assert vet_workspace(os.path.join(ws, "a.txt")) is None # file, not dir
|
||||||
|
assert vet_workspace("/nonexistent/path/xyz") is None
|
||||||
|
assert vet_workspace("") is None
|
||||||
|
assert vet_workspace(" ") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_vet_workspace_rejects_filesystem_root():
|
||||||
|
# Binding / would make every absolute path "inside" the workspace,
|
||||||
|
# collapsing confinement into host-wide file access.
|
||||||
|
from src.tool_execution import vet_workspace
|
||||||
|
assert vet_workspace("/") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_browse_marks_root_unselectable_and_vet_endpoint(monkeypatch):
|
||||||
|
import routes.workspace_routes as wr
|
||||||
|
|
||||||
|
router = wr.setup_workspace_routes()
|
||||||
|
browse = next(r.endpoint for r in router.routes if r.path == "/api/workspace/browse")
|
||||||
|
vet = next(r.endpoint for r in router.routes if r.path == "/api/workspace/vet")
|
||||||
|
|
||||||
|
monkeypatch.setattr(wr, "get_current_user", lambda req: "admin")
|
||||||
|
monkeypatch.setattr(wr, "owner_is_admin_or_single_user", lambda owner: True)
|
||||||
|
|
||||||
|
out = browse(request=object(), path="/")
|
||||||
|
assert out["selectable"] is False
|
||||||
|
out = browse(request=object(), path=os.path.expanduser("~"))
|
||||||
|
assert out["selectable"] is True
|
||||||
|
|
||||||
|
assert vet(request=object(), path="/") == {"ok": False, "path": None}
|
||||||
|
home = os.path.realpath(os.path.expanduser("~"))
|
||||||
|
assert vet(request=object(), path="~") == {"ok": True, "path": home}
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
monkeypatch.setattr(wr, "owner_is_admin_or_single_user", lambda owner: False)
|
||||||
|
with pytest.raises(HTTPException) as ei:
|
||||||
|
vet(request=object(), path="/tmp")
|
||||||
|
assert ei.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ── send-time privilege gate (no path oracle for non-admins) ────────────
|
||||||
|
|
||||||
|
def test_request_workspace_gate(ws, monkeypatch):
|
||||||
|
"""Non-admin chat callers must get a uniform drop with no vetting: the
|
||||||
|
workspace_rejected signal would otherwise reveal which host paths exist."""
|
||||||
|
import routes.chat_routes as cr
|
||||||
|
|
||||||
|
monkeypatch.setattr(cr, "get_current_user", lambda req: "bob")
|
||||||
|
vet_calls = []
|
||||||
|
import src.tool_execution as te
|
||||||
|
real_vet = te.vet_workspace
|
||||||
|
monkeypatch.setattr(te, "vet_workspace", lambda p: vet_calls.append(p) or real_vet(p))
|
||||||
|
|
||||||
|
import src.tool_security as ts
|
||||||
|
monkeypatch.setattr(ts, "owner_is_admin_or_single_user", lambda owner: False)
|
||||||
|
# Valid and invalid paths are indistinguishable for a non-admin: both
|
||||||
|
# drop silently, and the path never reaches the filesystem.
|
||||||
|
assert cr._resolve_request_workspace(object(), ws) == ("", "")
|
||||||
|
assert cr._resolve_request_workspace(object(), "/nonexistent/xyz") == ("", "")
|
||||||
|
assert vet_calls == []
|
||||||
|
|
||||||
|
monkeypatch.setattr(ts, "owner_is_admin_or_single_user", lambda owner: True)
|
||||||
|
assert cr._resolve_request_workspace(object(), ws) == (os.path.realpath(ws), "")
|
||||||
|
assert cr._resolve_request_workspace(object(), "/nonexistent/xyz") == ("", "/nonexistent/xyz")
|
||||||
Reference in New Issue
Block a user