mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 01:35:36 -04:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68f19a889a | |||
| 422f23fb12 | |||
| 0f966d6b9f | |||
| 7b09491557 | |||
| fafaf089c5 | |||
| b58af4267b | |||
| 8ff76f083c | |||
| 2196869c86 | |||
| dd2e23c9af | |||
| facc50cb0f | |||
| 074a1e6eff | |||
| 2fab378c6a | |||
| 5bafc30622 | |||
| d6d2e17214 | |||
| f4e8990635 | |||
| fc3a5e555e |
@@ -331,8 +331,8 @@ if AUTH_ENABLED:
|
||||
request.state.current_user = "internal-tool"
|
||||
request.state.api_token = False
|
||||
return await call_next(request)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.warning("Internal tool auth header check failed", exc_info=_e)
|
||||
# Allow DIRECT localhost requests (internal service calls from
|
||||
# heartbeats etc.). Tunnel/proxy-forwarded requests are excluded by
|
||||
# _is_trusted_loopback so LOCALHOST_BYPASS can't be abused over a
|
||||
@@ -385,11 +385,10 @@ if AUTH_ENABLED:
|
||||
_db.close()
|
||||
try:
|
||||
await _asyncio.to_thread(_do)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.debug("Failed to update token last_used_at", exc_info=_e)
|
||||
_asyncio.create_task(_touch_last_used(matched_id))
|
||||
# Keep bearer-token callers out of normal cookie/user
|
||||
# routes. API-aware routes can read api_token_owner.
|
||||
request.state.current_user = "api"
|
||||
request.state.api_token = True
|
||||
request.state.api_token_id = matched_id
|
||||
@@ -464,8 +463,8 @@ async def serve_generated_image(filename: str, request: Request):
|
||||
_db.close()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.warning("Image ownership verification failed for %r", filename, exc_info=_e)
|
||||
ext = filename.rsplit('.', 1)[-1].lower()
|
||||
mime = {
|
||||
"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg",
|
||||
|
||||
+17
-3
@@ -5,8 +5,9 @@ offers and pair to it, without duplicating any LLM logic.
|
||||
|
||||
Auth is enforced globally by AuthMiddleware (app.py), so reaching a handler here
|
||||
means the caller is authenticated by either a cookie session or a Bearer `ody_`
|
||||
API token. The read endpoints (ping/info/models) accept either; the pairing
|
||||
endpoints are admin-cookie only.
|
||||
API token. Ping/info accept either credential type, models requires a chat-
|
||||
scoped API token for bearer callers, and the pairing endpoints are admin-cookie
|
||||
only.
|
||||
|
||||
Pairing CSRF posture: minting happens ONLY on POST. The session cookie is
|
||||
SameSite=Lax (routes/auth_routes.py), which a browser does not send on a
|
||||
@@ -18,7 +19,7 @@ on a GET would be unsafe (Lax cookies ride top-level GET navigations), so GET
|
||||
|
||||
import html
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from core.middleware import require_admin
|
||||
@@ -52,6 +53,18 @@ def owner_can_see(row_owner, owner) -> bool:
|
||||
return row_owner is None or row_owner == owner
|
||||
|
||||
|
||||
def require_models_scope(request: Request) -> None:
|
||||
"""Require the companion chat scope for bearer-token model inventory."""
|
||||
if not getattr(request.state, "api_token", False):
|
||||
return
|
||||
scopes = getattr(request.state, "api_token_scopes", None) or []
|
||||
if isinstance(scopes, str):
|
||||
scopes = [scope.strip() for scope in scopes.split(",")]
|
||||
scope_set = {str(scope).strip() for scope in scopes if str(scope).strip()}
|
||||
if _pairing.COMPANION_SCOPE not in scope_set:
|
||||
raise HTTPException(403, "API token requires chat scope")
|
||||
|
||||
|
||||
def mint_pairing_token(owner: str, invalidate=None) -> tuple[str, str]:
|
||||
"""Mint a pairing token AND invalidate the auth middleware's in-memory token
|
||||
cache, so the new token is accepted on the very next request without a server
|
||||
@@ -103,6 +116,7 @@ def setup_companion_routes() -> APIRouter:
|
||||
rows -- the same rule as owner_filter. Read-only; never returns api_key
|
||||
material.
|
||||
"""
|
||||
require_models_scope(request)
|
||||
import json as _json
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
|
||||
+22
-2
@@ -2,12 +2,15 @@ import os
|
||||
import logging
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from sqlalchemy import event, create_engine, Column, String, Text, Boolean, DateTime, Integer, ForeignKey, JSON, Index, func, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
from sqlalchemy.orm import relationship, sessionmaker, backref
|
||||
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create base class for declarative models
|
||||
@@ -29,9 +32,26 @@ class TimestampMixin:
|
||||
def updated_at(cls):
|
||||
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
||||
|
||||
# Get database URL from environment, default to SQLite in DATA_DIR
|
||||
# Ensure the writable data directory exists before SQLite connects.
|
||||
from src.constants import DATA_DIR, AUTH_FILE, MEMORY_FILE, USER_PREFS_FILE, SETTINGS_FILE
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite:///{DATA_DIR}/app.db")
|
||||
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _default_database_url() -> str:
|
||||
return f"sqlite:///{Path(DATA_DIR) / 'app.db'}"
|
||||
|
||||
|
||||
def _normalize_sqlite_url(url: str) -> str:
|
||||
if not url.startswith("sqlite:///"):
|
||||
return url
|
||||
db_path = url.replace("sqlite:///", "", 1)
|
||||
if db_path == ":memory:" or os.path.isabs(db_path):
|
||||
return url
|
||||
return f"sqlite:///{(Path(get_app_root()) / db_path).resolve().as_posix()}"
|
||||
|
||||
|
||||
# Get database URL from environment, default to SQLite in DATA_DIR
|
||||
DATABASE_URL = _normalize_sqlite_url(os.getenv("DATABASE_URL", _default_database_url()))
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(
|
||||
|
||||
@@ -6,6 +6,7 @@ Imports MemoryManager and MemoryVectorStore from the Odysseus codebase.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
@@ -23,6 +24,55 @@ _memory_manager = None
|
||||
_memory_vector = None
|
||||
_initialized = False
|
||||
|
||||
_OWNER_ENV_KEYS = ("ODYSSEUS_MCP_MEMORY_OWNER", "ODYSSEUS_MEMORY_OWNER")
|
||||
_OWNER_SCOPE_ERROR = (
|
||||
"Error: Memory MCP owner is not configured for an owner-scoped memory store. "
|
||||
"Set ODYSSEUS_MCP_MEMORY_OWNER for this server or use the owner-aware native memory tool."
|
||||
)
|
||||
|
||||
|
||||
def _configured_owner() -> str | None:
|
||||
for key in _OWNER_ENV_KEYS:
|
||||
owner = os.environ.get(key, "").strip()
|
||||
if owner:
|
||||
return owner
|
||||
return None
|
||||
|
||||
|
||||
def _entry_owner(entry: dict) -> str | None:
|
||||
owner = entry.get("owner")
|
||||
if owner is None:
|
||||
return None
|
||||
owner_text = str(owner).strip()
|
||||
return owner_text or None
|
||||
|
||||
|
||||
def _owner_scoped_store(entries: list[dict]) -> bool:
|
||||
return any(_entry_owner(entry) for entry in entries if isinstance(entry, dict))
|
||||
|
||||
|
||||
def _scope_entries() -> tuple[str | None, list[dict], list[dict], str | None]:
|
||||
"""Return configured owner, all entries, visible entries, and optional error."""
|
||||
entries = _memory_manager.load_all()
|
||||
owner = _configured_owner()
|
||||
if owner is None and _owner_scoped_store(entries):
|
||||
return None, entries, [], _OWNER_SCOPE_ERROR
|
||||
if owner is None:
|
||||
visible = [
|
||||
entry for entry in entries
|
||||
if isinstance(entry, dict) and _entry_owner(entry) is None
|
||||
]
|
||||
else:
|
||||
visible = [
|
||||
entry for entry in entries
|
||||
if isinstance(entry, dict) and _entry_owner(entry) == owner
|
||||
]
|
||||
return owner, entries, visible, None
|
||||
|
||||
|
||||
def _text_result(text: str) -> list[TextContent]:
|
||||
return [TextContent(type="text", text=text)]
|
||||
|
||||
|
||||
def _ensure_init():
|
||||
"""Lazy-init memory managers on first use."""
|
||||
@@ -75,24 +125,26 @@ async def list_tools() -> list[Tool]:
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
if name != "manage_memory":
|
||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
return _text_result(f"Unknown tool: {name}")
|
||||
|
||||
_ensure_init()
|
||||
if not _memory_manager:
|
||||
return [TextContent(type="text", text="Error: Memory manager not available")]
|
||||
return _text_result("Error: Memory manager not available")
|
||||
|
||||
action = arguments.get("action", "")
|
||||
|
||||
if action == "list":
|
||||
category_filter = arguments.get("category", "")
|
||||
memories = _memory_manager.load()
|
||||
_owner, _all_memories, memories, scope_error = _scope_entries()
|
||||
if scope_error:
|
||||
return _text_result(scope_error)
|
||||
if category_filter:
|
||||
memories = [m for m in memories if m.get("category", "").lower() == category_filter.lower()]
|
||||
if not memories:
|
||||
msg = "No memories found"
|
||||
if category_filter:
|
||||
msg += f" in category '{category_filter}'"
|
||||
return [TextContent(type="text", text=msg + ".")]
|
||||
return _text_result(msg + ".")
|
||||
|
||||
lines = [f"Found {len(memories)} memory entries:\n"]
|
||||
for m in memories:
|
||||
@@ -102,15 +154,17 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
if len(text) > 150:
|
||||
text = text[:150] + "..."
|
||||
lines.append(f"- [{cat}] `{mid}` — {text}")
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
return _text_result("\n".join(lines))
|
||||
|
||||
elif action == "add":
|
||||
text = arguments.get("text", "")
|
||||
category = arguments.get("category", "fact")
|
||||
if not text:
|
||||
return [TextContent(type="text", text="Error: Memory text cannot be empty")]
|
||||
entry = _memory_manager.add_entry(text, source="ai_agent", category=category)
|
||||
memories = _memory_manager.load_all()
|
||||
return _text_result("Error: Memory text cannot be empty")
|
||||
owner, memories, _visible, scope_error = _scope_entries()
|
||||
if scope_error:
|
||||
return _text_result(scope_error)
|
||||
entry = _memory_manager.add_entry(text, source="ai_agent", category=category, owner=owner)
|
||||
memories.append(entry)
|
||||
_memory_manager.save(memories)
|
||||
if _memory_vector and _memory_vector.healthy:
|
||||
@@ -118,25 +172,28 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
_memory_vector.add(entry["id"], text)
|
||||
except Exception:
|
||||
pass
|
||||
return [TextContent(type="text", text=f"Memory added: [{category}] {text} (id: {entry['id'][:8]})")]
|
||||
return _text_result(f"Memory added: [{category}] {text} (id: {entry['id'][:8]})")
|
||||
|
||||
elif action == "edit":
|
||||
memory_id = arguments.get("memory_id", "")
|
||||
new_text = arguments.get("text", "")
|
||||
if not memory_id or not new_text:
|
||||
return [TextContent(type="text", text="Error: edit needs memory_id and text")]
|
||||
memories = _memory_manager.load_all()
|
||||
found = False
|
||||
return _text_result("Error: edit needs memory_id and text")
|
||||
_owner, memories, visible, scope_error = _scope_entries()
|
||||
if scope_error:
|
||||
return _text_result(scope_error)
|
||||
full_id = None
|
||||
for m in memories:
|
||||
for m in visible:
|
||||
if m.get("id", "").startswith(memory_id):
|
||||
m["text"] = new_text
|
||||
m["timestamp"] = int(time.time())
|
||||
found = True
|
||||
full_id = m["id"]
|
||||
break
|
||||
if not found:
|
||||
return [TextContent(type="text", text=f"Error: Memory '{memory_id}' not found")]
|
||||
if not full_id:
|
||||
return _text_result(f"Error: Memory '{memory_id}' not found")
|
||||
for m in memories:
|
||||
if m.get("id") == full_id:
|
||||
m["text"] = new_text
|
||||
m["timestamp"] = int(time.time())
|
||||
break
|
||||
_memory_manager.save(memories)
|
||||
if _memory_vector and _memory_vector.healthy and full_id:
|
||||
try:
|
||||
@@ -144,24 +201,26 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
_memory_vector.add(full_id, new_text)
|
||||
except Exception:
|
||||
pass
|
||||
return [TextContent(type="text", text=f"Memory updated: {new_text}")]
|
||||
return _text_result(f"Memory updated: {new_text}")
|
||||
|
||||
elif action == "delete":
|
||||
memory_id = arguments.get("memory_id", "")
|
||||
if not memory_id:
|
||||
return [TextContent(type="text", text="Error: delete needs memory_id")]
|
||||
memories = _memory_manager.load_all()
|
||||
return _text_result("Error: delete needs memory_id")
|
||||
_owner, memories, visible, scope_error = _scope_entries()
|
||||
if scope_error:
|
||||
return _text_result(scope_error)
|
||||
full_id = None
|
||||
deleted_text = ""
|
||||
deleted_category = ""
|
||||
for m in memories:
|
||||
for m in visible:
|
||||
if m.get("id", "").startswith(memory_id):
|
||||
full_id = m["id"]
|
||||
deleted_text = m.get("text", "")
|
||||
deleted_category = m.get("category", "")
|
||||
break
|
||||
if not full_id:
|
||||
return [TextContent(type="text", text=f"Error: Memory '{memory_id}' not found")]
|
||||
return _text_result(f"Error: Memory '{memory_id}' not found")
|
||||
memories = [m for m in memories if m.get("id") != full_id]
|
||||
_memory_manager.save(memories)
|
||||
if _memory_vector and _memory_vector.healthy and full_id:
|
||||
@@ -171,30 +230,32 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
pass
|
||||
cat = f"[{deleted_category}] " if deleted_category else ""
|
||||
snippet = deleted_text if len(deleted_text) <= 120 else deleted_text[:117] + "..."
|
||||
return [TextContent(type="text", text=f"Memory deleted: {cat}{snippet} (id: {memory_id})")]
|
||||
return _text_result(f"Memory deleted: {cat}{snippet} (id: {memory_id})")
|
||||
|
||||
elif action == "search":
|
||||
query = arguments.get("text", "")
|
||||
if not query:
|
||||
return [TextContent(type="text", text="Error: search needs text (query)")]
|
||||
memories = _memory_manager.load()
|
||||
return _text_result("Error: search needs text (query)")
|
||||
_owner, _all_memories, memories, scope_error = _scope_entries()
|
||||
if scope_error:
|
||||
return _text_result(scope_error)
|
||||
if hasattr(_memory_manager, 'get_relevant_memories'):
|
||||
results = _memory_manager.get_relevant_memories(query, memories, threshold=0.05, max_items=20)
|
||||
else:
|
||||
query_lower = query.lower()
|
||||
results = [m for m in memories if query_lower in m.get("text", "").lower()][:20]
|
||||
if not results:
|
||||
return [TextContent(type="text", text=f"No memories found matching '{query}'.")]
|
||||
return _text_result(f"No memories found matching '{query}'.")
|
||||
lines = [f"Found {len(results)} matching memories:\n"]
|
||||
for m in results:
|
||||
cat = m.get("category", "fact")
|
||||
mid = m.get("id", "?")[:8]
|
||||
text = m.get("text", "")
|
||||
lines.append(f"- [{cat}] `{mid}` — {text}")
|
||||
return [TextContent(type="text", text="\n".join(lines))]
|
||||
return _text_result("\n".join(lines))
|
||||
|
||||
else:
|
||||
return [TextContent(type="text", text=f"Error: Unknown action '{action}'. Use: list, add, edit, delete, search")]
|
||||
return _text_result(f"Error: Unknown action '{action}'. Use: list, add, edit, delete, search")
|
||||
|
||||
|
||||
async def run():
|
||||
|
||||
@@ -160,6 +160,8 @@ def setup_api_token_routes() -> APIRouter:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
if not isinstance(payload, dict):
|
||||
payload = {}
|
||||
with get_db_session() as db:
|
||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||
if not token:
|
||||
|
||||
@@ -14,7 +14,7 @@ from core.database import Session as DBSession, ModelEndpoint
|
||||
from src.llm_core import normalize_model_id
|
||||
from src.endpoint_resolver import normalize_base
|
||||
from src.context_compactor import maybe_compact, trim_for_context
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import effective_user
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from routes.prefs_routes import _load_for_user as load_prefs_for_user
|
||||
|
||||
@@ -78,7 +78,7 @@ def _enforce_chat_privileges(request, sess) -> None:
|
||||
which means unrestricted allowed_models / zero cap -> no-op for them.
|
||||
"""
|
||||
try:
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
except Exception:
|
||||
user = None
|
||||
if not user:
|
||||
@@ -346,11 +346,11 @@ def add_user_message(sess, chat_handler, preprocessed: PreprocessedMessage, inco
|
||||
def fire_message_event(request, webhook_manager, session_id: str, sess, message: str, compare_mode: bool = False):
|
||||
"""Fire webhook and event_bus events for a new user message."""
|
||||
if webhook_manager and not compare_mode:
|
||||
asyncio.create_task(webhook_manager.fire("chat.message", {
|
||||
webhook_manager.fire_and_forget("chat.message", {
|
||||
"session_id": session_id, "model": sess.model, "message": message[:2000],
|
||||
}))
|
||||
})
|
||||
from src.event_bus import fire_event
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
fire_event("message_sent", user)
|
||||
|
||||
|
||||
@@ -577,7 +577,7 @@ async def build_chat_context(
|
||||
fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode)
|
||||
|
||||
# Resolve user prefs
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
uprefs = load_prefs_for_user(user)
|
||||
|
||||
# Memory enabled?
|
||||
@@ -1120,10 +1120,10 @@ def run_post_response_tasks(
|
||||
|
||||
# Webhook
|
||||
if webhook_manager and not compare_mode:
|
||||
asyncio.create_task(webhook_manager.fire("chat.completed", {
|
||||
webhook_manager.fire_and_forget("chat.completed", {
|
||||
"session_id": session_id, "model": sess.model,
|
||||
"user_message": message, "response": full_response[:2000],
|
||||
}))
|
||||
})
|
||||
|
||||
# Auto-name
|
||||
if needs_auto_name(sess.name):
|
||||
|
||||
+13
-12
@@ -23,7 +23,7 @@ from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_
|
||||
from src.session_search import search_session_messages
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import effective_user, get_current_user
|
||||
from routes.session_routes import _verify_session_owner
|
||||
from routes.document_helpers import _owner_session_filter
|
||||
from core.database import SessionLocal, get_session_mode, set_session_mode
|
||||
@@ -126,7 +126,8 @@ def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
|
||||
sess.model = ""
|
||||
sess.headers = {}
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear orphaned session endpoint", exc_info=e)
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
@@ -144,7 +145,8 @@ def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
||||
return True
|
||||
try:
|
||||
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse cached models list, treating as containing model", exc_info=e)
|
||||
return True
|
||||
if not isinstance(models, list) or not models:
|
||||
return True
|
||||
@@ -236,7 +238,8 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
is_chatgpt_subscription = False
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse cached_models for endpoint %r", getattr(ep, "id", "?"), exc_info=e)
|
||||
cached = []
|
||||
if not cached:
|
||||
visible = []
|
||||
@@ -360,7 +363,7 @@ def setup_chat_routes(
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session}' not found")
|
||||
owner = get_current_user(request)
|
||||
owner = effective_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
|
||||
@@ -600,7 +603,7 @@ def setup_chat_routes(
|
||||
# but BEFORE loading. Prevents cross-user session hijack.
|
||||
_verify_session_owner(request, session)
|
||||
sess = session_manager.get_session(session)
|
||||
owner = get_current_user(request)
|
||||
owner = effective_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
# Issue #587: picker shows a model from the endpoint cache but
|
||||
@@ -631,7 +634,7 @@ def setup_chat_routes(
|
||||
_enforce_chat_privileges(request, sess)
|
||||
|
||||
# Ensure session has auth headers
|
||||
resolve_session_auth(sess, session, owner=get_current_user(request))
|
||||
resolve_session_auth(sess, session, owner=effective_user(request))
|
||||
|
||||
# Check for research_pending BEFORE mode persist overwrites it
|
||||
do_research = str(use_research).lower() == "true"
|
||||
@@ -646,8 +649,8 @@ def setup_chat_routes(
|
||||
elif attachments:
|
||||
try:
|
||||
att_ids = [str(x) for x in json.loads(attachments)]
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse attachments JSON, ignoring attachments", exc_info=e)
|
||||
|
||||
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
||||
pre_context_tool_policy = build_effective_tool_policy(
|
||||
@@ -1297,8 +1300,6 @@ def setup_chat_routes(
|
||||
"doc_stream_open", "doc_stream_delta",
|
||||
"doc_update", "doc_suggestions", "ui_control",
|
||||
"rounds_exhausted",
|
||||
"loop_breaker_triggered",
|
||||
"intent_nudge_exhausted",
|
||||
"ask_user",
|
||||
"plan_update",
|
||||
):
|
||||
@@ -1484,7 +1485,7 @@ def setup_chat_routes(
|
||||
if not q or not q.strip():
|
||||
return []
|
||||
|
||||
_user = get_current_user(request)
|
||||
_user = effective_user(request)
|
||||
return [
|
||||
result.to_dict()
|
||||
for result in search_session_messages(
|
||||
|
||||
@@ -505,6 +505,8 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache:
|
||||
" if u.startswith('KB'): return int(n * 1024)",
|
||||
" return int(n)",
|
||||
"def scan_ollama():",
|
||||
" if any(m.get('is_ollama') for m in models): return",
|
||||
" if os.name == 'nt' and not os.environ.get('ODYSSEUS_ALLOW_OLLAMA_CLI_SCAN'): return",
|
||||
" if not shutil.which('ollama'): return",
|
||||
" try:",
|
||||
" p = subprocess.run(['ollama', 'list'], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, timeout=6)",
|
||||
@@ -535,8 +537,8 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache:
|
||||
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
||||
" return",
|
||||
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
|
||||
"scan_ollama()",
|
||||
"scan_ollama_api()",
|
||||
"scan_ollama()",
|
||||
]
|
||||
for model_dir in model_dirs or []:
|
||||
lines.append(f"scan_dir(os.path.expanduser({model_dir!r}))")
|
||||
|
||||
@@ -503,7 +503,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
user = get_current_user(request)
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse export request body, defaulting to empty", exc_info=e)
|
||||
data = {}
|
||||
ids = data.get("ids") or []
|
||||
if not ids:
|
||||
@@ -645,8 +646,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
try:
|
||||
from src.agent_tools.document_tools import clear_active_document
|
||||
clear_active_document(doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear active document %r on detach", doc_id, exc_info=e)
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
return _doc_to_dict(doc)
|
||||
|
||||
@@ -79,15 +79,16 @@ def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[st
|
||||
cfg.get("smtp_user") or "",
|
||||
cfg.get("from_address") or "",
|
||||
])
|
||||
except Exception:
|
||||
except Exception as _e:
|
||||
logger.warning("Failed to resolve email account alias", exc_info=_e)
|
||||
resolved_account_id = None
|
||||
row = db.get(_EA, resolved_account_id) if resolved_account_id else None
|
||||
if row:
|
||||
aliases.extend([row.owner or "", row.imap_user or "", row.from_address or ""])
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.warning("Failed to load email aliases", exc_info=_e)
|
||||
out = []
|
||||
for a in aliases:
|
||||
a = (a or "").strip()
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Form, Depends
|
||||
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
|
||||
from core.middleware import require_admin
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter
|
||||
from src.auth_helpers import effective_user, _auth_disabled, owner_filter
|
||||
from src.session_actions import is_session_recently_active
|
||||
|
||||
|
||||
@@ -328,7 +328,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
endpoint_id: str = Form(""),
|
||||
):
|
||||
skip_val = str(skip_validation).lower() == "true"
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
@@ -477,7 +477,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.close()
|
||||
# Switch model/endpoint mid-session
|
||||
if model is not None and endpoint_url is not None:
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi import APIRouter, Request, File, UploadFile, HTTPException
|
||||
from typing import List
|
||||
import logging
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import effective_user
|
||||
from src.upload_handler import count_recent_uploads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,7 +78,7 @@ def setup_upload_routes(upload_handler):
|
||||
|
||||
for u in files:
|
||||
try:
|
||||
meta = upload_handler.save_upload(u, client_ip, owner=get_current_user(request))
|
||||
meta = upload_handler.save_upload(u, client_ip, owner=effective_user(request))
|
||||
out.append({
|
||||
"id": meta["id"],
|
||||
"name": meta["name"],
|
||||
@@ -138,7 +138,7 @@ def setup_upload_routes(upload_handler):
|
||||
original_name = info.get("name", file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
current_user = effective_user(request)
|
||||
file_owner = info.get("owner") if info else None
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
@@ -204,7 +204,7 @@ def setup_upload_routes(upload_handler):
|
||||
info = _load_upload_info(file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
current_user = effective_user(request)
|
||||
file_owner = info.get("owner") if info else None
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
@@ -247,7 +247,7 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(404, "File not found")
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
current_user = effective_user(request)
|
||||
file_owner = info.get("owner")
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Webhook, API Token, and sync chat routes."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional
|
||||
@@ -385,10 +384,10 @@ def setup_webhook_routes(
|
||||
sess.add_message(ChatMessage("assistant", reply))
|
||||
session_manager.save_sessions()
|
||||
|
||||
asyncio.create_task(webhook_manager.fire("chat.completed", {
|
||||
webhook_manager.fire_and_forget("chat.completed", {
|
||||
"session_id": session_id, "model": sess.model,
|
||||
"user_message": message[:2000], "response": reply[:2000],
|
||||
}))
|
||||
})
|
||||
|
||||
return {"response": reply, "session_id": session_id, "model": sess.model}
|
||||
|
||||
|
||||
@@ -19,6 +19,10 @@ GPU_BANDWIDTH = {
|
||||
"6950 xt": 576, "6900 xt": 512, "6800 xt": 512, "6800": 512, "6700 xt": 384, "6600 xt": 256, "6600": 224,
|
||||
"mi300x": 5300, "mi300": 5300, "mi250x": 3277, "mi250": 3277, "mi210": 1638, "mi100": 1229,
|
||||
"9070 xt": 624, "9070": 488, "9060 xt": 322, "9060": 322,
|
||||
# NVIDIA GB10 Grace-Blackwell superchip (DGX Spark). Unified LPDDR5X memory,
|
||||
# not Apple Silicon, so it lives in the generic GPU table — the Apple-only
|
||||
# lookup never matches it (its name carries no "apple").
|
||||
"gb10": 273,
|
||||
}
|
||||
|
||||
# Pre-sort keys by length descending for correct substring matching
|
||||
|
||||
+163
-14
@@ -15,6 +15,8 @@ from urllib.parse import urljoin, urlparse
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from src.constants import WEB_FETCH_SOFT_MAX_BYTES, WEB_FETCH_HARD_MAX_BYTES, WEB_FETCH_USER_AGENT
|
||||
|
||||
from .analytics import RateLimitError, error_logger
|
||||
from .cache import (
|
||||
CONTENT_CACHE_DIR,
|
||||
@@ -89,18 +91,128 @@ def _public_http_url(url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _get_public_url(url: str, headers: dict, timeout: int, max_redirects: int = 5) -> httpx.Response:
|
||||
class BodyTooLargeError(Exception):
|
||||
"""The server declared a body larger than the hard fetch ceiling."""
|
||||
|
||||
def __init__(self, url: str, declared_bytes: int):
|
||||
self.url = url
|
||||
self.declared_bytes = declared_bytes
|
||||
super().__init__(
|
||||
f"response body is {declared_bytes:,} bytes, over the "
|
||||
f"{WEB_FETCH_HARD_MAX_BYTES:,}-byte hard cap"
|
||||
)
|
||||
|
||||
|
||||
class _CappedFetch:
|
||||
"""Result of a size-capped streaming GET.
|
||||
|
||||
Carries just what fetch_webpage_content needs from an httpx.Response,
|
||||
plus the cap bookkeeping: the (possibly truncated) body, whether the
|
||||
cap cut it short, and the size the server declared via Content-Length
|
||||
(wire bytes; None when absent).
|
||||
"""
|
||||
|
||||
__slots__ = ("status_code", "headers", "content", "truncated",
|
||||
"declared_bytes", "encoding", "url")
|
||||
|
||||
def __init__(self, status_code, headers, content, truncated,
|
||||
declared_bytes, encoding, url):
|
||||
self.status_code = status_code
|
||||
self.headers = headers
|
||||
self.content = content
|
||||
self.truncated = truncated
|
||||
self.declared_bytes = declared_bytes
|
||||
self.encoding = encoding
|
||||
self.url = url
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return self.content.decode(self.encoding or "utf-8", errors="replace")
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code >= 400:
|
||||
request = httpx.Request("GET", self.url)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"HTTP {self.status_code} for {self.url}",
|
||||
request=request,
|
||||
response=httpx.Response(self.status_code, request=request),
|
||||
)
|
||||
|
||||
|
||||
def _get_public_url(url: str, headers: dict, timeout: int, max_redirects: int = 5,
|
||||
max_bytes: int = None) -> "_CappedFetch":
|
||||
"""Capped streaming GET with SSRF-guarded manual redirects.
|
||||
|
||||
The body is streamed and buffering stops at ``max_bytes`` (default: the
|
||||
soft cap), so an oversized resource cannot be pulled into memory or the
|
||||
content cache in full. When Content-Length already declares a body over
|
||||
the hard ceiling, the fetch is refused before any body bytes are read.
|
||||
"""
|
||||
cap = min(max_bytes or WEB_FETCH_SOFT_MAX_BYTES, WEB_FETCH_HARD_MAX_BYTES)
|
||||
current = url
|
||||
for _ in range(max_redirects + 1):
|
||||
if not _public_http_url(current):
|
||||
raise httpx.RequestError("Blocked private/internal URL", request=httpx.Request("GET", current))
|
||||
response = httpx.get(current, headers=headers, timeout=timeout, follow_redirects=False)
|
||||
if response.status_code not in (301, 302, 303, 307, 308):
|
||||
return response
|
||||
location = response.headers.get("location")
|
||||
if not location:
|
||||
return response
|
||||
current = urljoin(str(response.url), location)
|
||||
# Force identity transfer-encoding. With gzip/deflate the wire bytes
|
||||
# (and Content-Length) can be a small fraction of the decoded body, so
|
||||
# a tiny compressed response could pass the hard-cap preflight and then
|
||||
# expand past the ceiling in a single decoded chunk before the streamed
|
||||
# cap below can slice it. Identity makes Content-Length the true body
|
||||
# size and keeps each streamed chunk bounded by the network read.
|
||||
req_headers = dict(headers or {})
|
||||
req_headers["Accept-Encoding"] = "identity"
|
||||
with httpx.stream("GET", current, headers=req_headers, timeout=timeout,
|
||||
follow_redirects=False) as response:
|
||||
if response.status_code in (301, 302, 303, 307, 308):
|
||||
location = response.headers.get("location")
|
||||
if not location:
|
||||
return _CappedFetch(response.status_code, response.headers, b"",
|
||||
False, None, response.encoding, str(response.url))
|
||||
current = urljoin(str(response.url), location)
|
||||
continue
|
||||
|
||||
# A server can ignore the identity request and still return a
|
||||
# compressed body; httpx.iter_bytes would then decode it, and a tiny
|
||||
# gzip can balloon into one decoded chunk far past the cap before we
|
||||
# slice. Refuse a compressed Content-Encoding so the streamed cap
|
||||
# stays a real memory bound (Content-Length is the compressed wire
|
||||
# length here, so the preflight and size metadata are unreliable too).
|
||||
enc = (response.headers.get("content-encoding") or "").strip().lower()
|
||||
if enc and enc != "identity":
|
||||
raise httpx.RequestError(
|
||||
f"Refusing compressed response (Content-Encoding: {enc}) after "
|
||||
"requesting identity: cannot bound decoded body size",
|
||||
request=httpx.Request("GET", current),
|
||||
)
|
||||
|
||||
declared = None
|
||||
raw_len = response.headers.get("content-length")
|
||||
if raw_len and raw_len.isdigit():
|
||||
declared = int(raw_len)
|
||||
# Refuse before buffering anything when the server already tells
|
||||
# us the body exceeds the absolute ceiling (Content-Length is wire
|
||||
# bytes; the decompressed body can only be larger).
|
||||
if declared is not None and declared > WEB_FETCH_HARD_MAX_BYTES:
|
||||
raise BodyTooLargeError(current, declared)
|
||||
|
||||
chunks = []
|
||||
read = 0
|
||||
truncated = False
|
||||
# We requested identity above, so iter_bytes yields the raw body in
|
||||
# network-read-sized chunks (no decompression expansion); the cap
|
||||
# therefore bounds what we actually buffer.
|
||||
for chunk in response.iter_bytes():
|
||||
read += len(chunk)
|
||||
if read > cap:
|
||||
keep = cap - (read - len(chunk))
|
||||
if keep > 0:
|
||||
chunks.append(chunk[:keep])
|
||||
truncated = True
|
||||
break
|
||||
chunks.append(chunk)
|
||||
return _CappedFetch(response.status_code, response.headers,
|
||||
b"".join(chunks), truncated, declared,
|
||||
response.encoding, str(response.url))
|
||||
raise httpx.RequestError("Too many redirects", request=httpx.Request("GET", current))
|
||||
|
||||
# PDF extraction (optional dependency)
|
||||
@@ -222,9 +334,19 @@ def _empty_result(url: str, error: str = "") -> dict:
|
||||
# ----------------------------------------------------------------------
|
||||
# Main content fetcher
|
||||
# ----------------------------------------------------------------------
|
||||
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
|
||||
"""Fetch and extract meaningful content from a webpage with caching."""
|
||||
cache_key = generate_cache_key(url)
|
||||
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0,
|
||||
max_bytes: int = None) -> dict:
|
||||
"""Fetch and extract meaningful content from a webpage with caching.
|
||||
|
||||
``max_bytes`` raises the download budget per call (clamped to the hard
|
||||
cap); the default is the soft cap. When the body is cut short the result
|
||||
carries ``truncated``/``fetched_bytes``/``total_bytes`` so callers can
|
||||
tell the model the content is partial (#3812).
|
||||
"""
|
||||
effective_cap = min(max_bytes or WEB_FETCH_SOFT_MAX_BYTES, WEB_FETCH_HARD_MAX_BYTES)
|
||||
# The cap is part of the cache identity: a truncated soft-cap fetch must
|
||||
# not be served to a later full-budget request for the same URL.
|
||||
cache_key = generate_cache_key(f"{url}#cap={effective_cap}")
|
||||
cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"
|
||||
|
||||
# Check cache
|
||||
@@ -247,18 +369,24 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
# Fetch
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
"User-Agent": WEB_FETCH_USER_AGENT,
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
# identity so the streamed size cap in _get_public_url stays honest
|
||||
# (a compressed body can decode to far more than Content-Length).
|
||||
"Accept-Encoding": "identity",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
response = _get_public_url(url, headers=headers, timeout=timeout)
|
||||
response = _get_public_url(url, headers=headers, timeout=timeout,
|
||||
max_bytes=effective_cap)
|
||||
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||
|
||||
response.raise_for_status()
|
||||
except BodyTooLargeError as e:
|
||||
error_logger.warning(f"Refused oversized body for {url}: {e}")
|
||||
return _empty_result(url, f"TooLarge: {e}")
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_logger.warning(f"HTTP {e.response.status_code} fetching {url}: {e}")
|
||||
return _empty_result(url, f"HTTP {e.response.status_code}: {e}")
|
||||
@@ -269,9 +397,27 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
error_logger.error(str(e))
|
||||
return _empty_result(url, str(e))
|
||||
|
||||
# Size bookkeeping shared by every content branch below. getattr keeps
|
||||
# plain httpx.Response stand-ins (tests) working without the cap fields.
|
||||
_size_fields = {
|
||||
"truncated": getattr(response, "truncated", False),
|
||||
"fetched_bytes": len(response.content),
|
||||
"total_bytes": getattr(response, "declared_bytes", None),
|
||||
}
|
||||
|
||||
# PDF handling
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
if "application/pdf" in content_type or url.lower().endswith(".pdf"):
|
||||
if _size_fields["truncated"]:
|
||||
# A PDF cut mid-stream is not parseable; unlike text there is no
|
||||
# useful partial result, so report the budget problem instead.
|
||||
_declared = _size_fields["total_bytes"]
|
||||
return _empty_result(
|
||||
url,
|
||||
f"TooLarge: PDF exceeds the {effective_cap:,}-byte fetch budget"
|
||||
+ (f" (size {_declared:,} bytes)" if _declared else "")
|
||||
+ "; retry with a larger budget if it fits under the hard cap",
|
||||
)
|
||||
if pdf_extract_text is None:
|
||||
logger.error("pdfminer.six is not installed; cannot extract PDF text.")
|
||||
pdf_text = ""
|
||||
@@ -295,6 +441,7 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
"js_message": "",
|
||||
"success": bool(pdf_text),
|
||||
"error": "" if pdf_text else "Failed to extract PDF text",
|
||||
**_size_fields,
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
@@ -329,6 +476,7 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
"js_message": "",
|
||||
"success": bool(text_body),
|
||||
"error": "" if text_body else "Empty response body",
|
||||
**_size_fields,
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
@@ -391,6 +539,7 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
"js_message": js_message,
|
||||
"success": True,
|
||||
"error": "",
|
||||
**_size_fields,
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
@@ -9,14 +9,12 @@ from urllib.parse import urljoin, urlparse, parse_qs
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from src.constants import SEARXNG_INSTANCE
|
||||
from src.constants import SEARXNG_INSTANCE, REQUEST_TIMEOUT, WEB_FETCH_USER_AGENT
|
||||
from .analytics import RateLimitError, error_logger
|
||||
from .query import build_enhanced_query
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUEST_TIMEOUT = 20
|
||||
|
||||
# Provider registry — maps setting value to (label, needs_key, needs_url)
|
||||
PROVIDER_INFO = {
|
||||
"searxng": ("SearXNG", False, True),
|
||||
@@ -140,7 +138,7 @@ def searxng_search_api(query: str, count: Optional[int] = None, categories: str
|
||||
count = count if count is not None else _get_result_count()
|
||||
instance = _get_search_instance()
|
||||
api_key = ""
|
||||
headers = {"User-Agent": "Mozilla/5.0"}
|
||||
headers = {"User-Agent": WEB_FETCH_USER_AGENT}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
# News/fresh queries do badly in the 'general' category — it favours
|
||||
@@ -252,7 +250,7 @@ def searxng_search(query, max_results=10):
|
||||
"""Search using SearXNG instance - parsing HTML."""
|
||||
instance = _get_search_instance()
|
||||
api_key = ""
|
||||
req_headers = {"User-Agent": "Mozilla/5.0"}
|
||||
req_headers = {"User-Agent": WEB_FETCH_USER_AGENT}
|
||||
if api_key:
|
||||
req_headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
@@ -391,7 +389,7 @@ def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Opti
|
||||
response = httpx.get(
|
||||
"https://html.duckduckgo.com/html/",
|
||||
params={"q": query, "kp": _safesearch_for("duckduckgo_html")},
|
||||
headers={"User-Agent": "Mozilla/5.0"},
|
||||
headers={"User-Agent": WEB_FETCH_USER_AGENT},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
+19
-228
@@ -38,167 +38,6 @@ from src.agent_tools import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redaction patterns for common secret-bearing shapes. Explicit and tested
|
||||
# (see tests/test_loop_guard_signals.py) rather than one clever broad regex —
|
||||
# safety first, but we try not to mangle harmless prose. Applied in order.
|
||||
_REDACTED = "[redacted]"
|
||||
|
||||
# Cookie: ... / Set-Cookie: ... — redact the rest of the line (cookies hold spaces).
|
||||
_SENSITIVE_COOKIE_RE = re.compile(
|
||||
r"(?i)\b((?:set-)?cookie\s*[:=]\s*)[^\r\n]+"
|
||||
)
|
||||
# URL credentials, e.g. postgres://user:pass@host/db. The password half allows
|
||||
# inner colons (postgres://user:pa:ss@host/db) but still stops at / and @.
|
||||
_SENSITIVE_URL_CRED_RE = re.compile(
|
||||
r"(?i)\b([a-z][a-z0-9+.\-]*://)[^\s:/@]+:[^\s/@]+@"
|
||||
)
|
||||
# Prefix-only discovery regexes. Each matches the key and its separator (the part
|
||||
# we KEEP); the value that follows is found by a linear scanner rather than by a
|
||||
# regex, so there is no backtracking-prone quantifier over uncontrolled input.
|
||||
#
|
||||
# Authorization: Bearer <tok> / Authorization: Basic "two word secret"
|
||||
_AUTH_PREFIX_RE = re.compile(
|
||||
r"(?i)authorization\s*[:=]\s*(?:bearer|basic)\s+"
|
||||
)
|
||||
# Provider-prefixed env names, e.g. OPENAI_API_KEY=..., AWS_SECRET_ACCESS_KEY=...,
|
||||
# GITHUB_TOKEN=... — require a sensitive suffix preceded by `_` so benign names
|
||||
# that merely end in KEY (MONKEY, TURKEY) are left alone.
|
||||
_ENV_PREFIX_RE = re.compile(
|
||||
r"(?:export\s+)?\b[A-Z][A-Z0-9_]*"
|
||||
r"_(?:KEY|TOKEN|SECRET|PASSWORD|PASSWD|PWD|CREDENTIALS?)\s*=\s*"
|
||||
)
|
||||
# Generic sensitive key, e.g. password=..., api_key: ..., client_secret=...
|
||||
_KEY_PREFIX_RE = re.compile(
|
||||
r"(?i)\b(?:password|passwd|pwd|token|api[_-]?key|client_secret|secret)\b\s*[:=]\s*"
|
||||
)
|
||||
# Obvious provider-shaped bare tokens (no surrounding key needed).
|
||||
_SENSITIVE_BARE_TOKEN_RE = re.compile(
|
||||
r"\b("
|
||||
r"sk-[A-Za-z0-9_\-]{16,}" # OpenAI / Anthropic style
|
||||
r"|gh[pousr]_[A-Za-z0-9]{20,}" # GitHub PAT
|
||||
r"|xox[baprs]-[A-Za-z0-9\-]{10,}" # Slack
|
||||
r"|AKIA[0-9A-Z]{16}" # AWS access key id
|
||||
r"|hf_[A-Za-z0-9]{16,}" # Hugging Face token
|
||||
r"|AIza[0-9A-Za-z_\-]{20,}" # Google API key
|
||||
r")\b"
|
||||
)
|
||||
|
||||
|
||||
def _consume_secret_value_end(text: str, start: int) -> int:
|
||||
"""Return the exclusive end index of the secret value beginning at ``start``.
|
||||
|
||||
If the value is quoted, scan to the matching unescaped quote (backslash
|
||||
escapes are skipped two chars at a time). Otherwise scan to the first
|
||||
whitespace, comma, or semicolon. The scan is linear in the length of the
|
||||
input, so it cannot exhibit catastrophic backtracking.
|
||||
"""
|
||||
n = len(text)
|
||||
if start >= n:
|
||||
return start
|
||||
quote = text[start]
|
||||
if quote in ("'", '"'):
|
||||
i = start + 1
|
||||
while i < n:
|
||||
ch = text[i]
|
||||
if ch == "\\":
|
||||
i += 2
|
||||
continue
|
||||
if ch == quote:
|
||||
return i + 1
|
||||
i += 1
|
||||
return n # unterminated quote: redact to the end
|
||||
i = start
|
||||
while i < n and not text[i].isspace() and text[i] not in (",", ";"):
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _redact_after_prefix(text: str, prefix_re: "re.Pattern") -> str:
|
||||
"""Redact the value following each ``prefix_re`` match using a linear scan."""
|
||||
result = []
|
||||
pos = 0
|
||||
n = len(text)
|
||||
while pos < n:
|
||||
match = prefix_re.search(text, pos)
|
||||
if match is None:
|
||||
result.append(text[pos:])
|
||||
break
|
||||
result.append(text[pos:match.end()])
|
||||
value_end = _consume_secret_value_end(text, match.end())
|
||||
if value_end > match.end():
|
||||
result.append(_REDACTED)
|
||||
pos = value_end
|
||||
else:
|
||||
# Empty value: nothing to redact; step past the prefix and continue.
|
||||
pos = match.end()
|
||||
if pos < n:
|
||||
result.append(text[pos])
|
||||
pos += 1
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _redact_private_keys(text: str) -> str:
|
||||
"""Replace PEM private-key blocks with a placeholder via linear scanning.
|
||||
|
||||
Finds ``-----BEGIN `` markers, verifies the header names a PRIVATE KEY,
|
||||
locates the matching ``-----END `` marker, and collapses the whole block.
|
||||
No regex is used, so the (multi-line, uncontrolled) body cannot trigger
|
||||
polynomial matching.
|
||||
"""
|
||||
begin_marker = "-----BEGIN "
|
||||
end_marker = "-----END "
|
||||
dash = "-----"
|
||||
max_header = 64 # generous bound on "[TYPE ]PRIVATE KEY"
|
||||
result = []
|
||||
pos = 0
|
||||
while True:
|
||||
begin = text.find(begin_marker, pos)
|
||||
if begin == -1:
|
||||
result.append(text[pos:])
|
||||
return "".join(result)
|
||||
header_start = begin + len(begin_marker)
|
||||
header_close = text.find(dash, header_start)
|
||||
if (
|
||||
header_close == -1
|
||||
or header_close - header_start > max_header
|
||||
or not text[header_start:header_close].endswith("PRIVATE KEY")
|
||||
):
|
||||
result.append(text[pos:header_start])
|
||||
pos = header_start
|
||||
continue
|
||||
end = text.find(end_marker, header_close)
|
||||
if end == -1:
|
||||
result.append(text[pos:])
|
||||
return "".join(result)
|
||||
end_header_start = end + len(end_marker)
|
||||
end_close = text.find(dash, end_header_start)
|
||||
if (
|
||||
end_close == -1
|
||||
or end_close - end_header_start > max_header
|
||||
or not text[end_header_start:end_close].endswith("PRIVATE KEY")
|
||||
):
|
||||
result.append(text[pos:header_start])
|
||||
pos = header_start
|
||||
continue
|
||||
result.append(text[pos:begin])
|
||||
result.append("[redacted private key]")
|
||||
pos = end_close + len(dash)
|
||||
|
||||
|
||||
def _redact_sensitive_text(value: object) -> str:
|
||||
"""Redact obvious credential values before surfacing tool output."""
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
text = str(value)
|
||||
text = _redact_private_keys(text)
|
||||
text = _redact_after_prefix(text, _AUTH_PREFIX_RE)
|
||||
text = _SENSITIVE_COOKIE_RE.sub(r"\1" + _REDACTED, text)
|
||||
text = _SENSITIVE_URL_CRED_RE.sub(r"\1" + _REDACTED + "@", text)
|
||||
text = _redact_after_prefix(text, _ENV_PREFIX_RE)
|
||||
text = _redact_after_prefix(text, _KEY_PREFIX_RE)
|
||||
return _SENSITIVE_BARE_TOKEN_RE.sub(_REDACTED, text)
|
||||
|
||||
|
||||
def _load_mcp_disabled_map() -> Dict[str, set]:
|
||||
"""Load per-server disabled tool sets from the database."""
|
||||
@@ -685,7 +524,7 @@ def get_builtin_overrides() -> dict:
|
||||
ov = get_setting("builtin_tool_overrides", {})
|
||||
return ov if isinstance(ov, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning('Failed to load builtin tool overrides: %s', e)
|
||||
logger.warning("Failed to load builtin tool overrides, using defaults", exc_info=e)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -1090,8 +929,8 @@ def _build_system_prompt(
|
||||
try:
|
||||
from src.user_time import current_datetime_context_message
|
||||
_datetime_message = current_datetime_context_message()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to build datetime context message", exc_info=e)
|
||||
|
||||
# Document context is kept as a SEPARATE message (not merged into the tool
|
||||
# prompt) so the context trimmer doesn't destroy it when truncating the
|
||||
@@ -1134,8 +973,8 @@ def _build_system_prompt(
|
||||
try:
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
_is_form_backed = bool(find_source_upload_id(active_document.current_content or ""))
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to detect if document is form-backed, assuming plain", exc_info=e)
|
||||
|
||||
if _is_form_backed:
|
||||
doc_ctx = (
|
||||
@@ -2376,7 +2215,6 @@ async def stream_agent_loop(
|
||||
# signatures + consecutive no-text tool rounds to bail early.
|
||||
_recent_call_sigs = collections.deque(maxlen=6)
|
||||
_stuck_rounds = 0
|
||||
_MAX_STUCK_ROUNDS = 4 # consecutive no-progress rounds before loop-breaker bails
|
||||
# Frequency of each exact call signature (tool + args), for the runaway
|
||||
# backstop. Counting identical repeats — not distinct same-tool calls —
|
||||
# lets a legit batch (e.g. 18 calendar events at once) through.
|
||||
@@ -2799,22 +2637,17 @@ async def stream_agent_loop(
|
||||
# promise: short response (<400 chars), no fenced code/answer,
|
||||
# and an action-intent phrase was matched. Long answers that
|
||||
# happen to contain "let me know" are not stalls.
|
||||
_promise_shape = (
|
||||
_looks_like_promise = (
|
||||
not guide_only
|
||||
and _intent_match is not None
|
||||
and len(_intent_text) < 400
|
||||
and "```" not in _intent_text
|
||||
and _intent_nudge_count < _MAX_INTENT_NUDGES
|
||||
)
|
||||
_looks_like_promise = _promise_shape and _intent_nudge_count < _MAX_INTENT_NUDGES
|
||||
if _looks_like_promise:
|
||||
_intent_nudge_count += 1
|
||||
_matched_phrase = _intent_match.group(0).strip()
|
||||
# Don't log the matched phrase — it's raw model text that may
|
||||
# carry credentials. Structural metadata only.
|
||||
logger.info(
|
||||
"[agent] intent-without-action nudge #%d on round %d",
|
||||
_intent_nudge_count, round_num,
|
||||
)
|
||||
logger.info(f"[agent] intent-without-action nudge #{_intent_nudge_count} on round {round_num}: {_matched_phrase!r}")
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": (
|
||||
@@ -2830,24 +2663,6 @@ async def stream_agent_loop(
|
||||
# Visible signal in the stream so the user knows we caught it.
|
||||
yield f'data: {json.dumps({"type": "agent_step", "round": round_num + 1})}\n\n'
|
||||
continue
|
||||
# The model keeps announcing actions it never takes and we've spent
|
||||
# every nudge — surface why the turn is ending instead of letting it
|
||||
# look like a clean completion.
|
||||
if _promise_shape and _intent_nudge_count >= _MAX_INTENT_NUDGES:
|
||||
_matched_phrase = _intent_match.group(0).strip()
|
||||
_matched_phrase_safe = _redact_sensitive_text(_matched_phrase)
|
||||
_in_message = (
|
||||
f"Intent-nudge cap reached on round {round_num}: the model "
|
||||
f"announced an action ({_matched_phrase_safe!r}) without a tool call "
|
||||
f"after {_intent_nudge_count} nudge(s); ending the turn."
|
||||
)
|
||||
# Do not log the matched phrase, even redacted. It is raw model
|
||||
# text and may contain credentials; keep logs structural only.
|
||||
logger.warning(
|
||||
"[agent] intent-nudge cap exhausted on round %d (%d/%d)",
|
||||
round_num, _intent_nudge_count, _MAX_INTENT_NUDGES,
|
||||
)
|
||||
yield f'data: {json.dumps({"type": "intent_nudge_exhausted", "round": round_num, "nudges": _intent_nudge_count, "max_nudges": _MAX_INTENT_NUDGES, "message": _in_message})}\n\n'
|
||||
break # no tools — done
|
||||
|
||||
# ── Loop-breaker (Terminus-style stall detector) ──────────────
|
||||
@@ -2880,23 +2695,10 @@ async def stream_agent_loop(
|
||||
# Distinct calls to one tool (a real batch) are legitimate work, so we
|
||||
# count identical call signatures, not raw per-tool-type totals.
|
||||
_runaway = _detect_runaway_call(_call_freq)
|
||||
if _stuck_rounds >= _MAX_STUCK_ROUNDS or _runaway:
|
||||
if _stuck_rounds >= 4 or _runaway:
|
||||
reason = (f"calling {_runaway} with identical arguments over and over" if _runaway
|
||||
else "repeating the same tool calls without new progress")
|
||||
_lb_message = (
|
||||
f"Loop-breaker stopped the agent on round {round_num}: {reason}. "
|
||||
"Forced one tool-free round to converge on an answer or state what's blocked."
|
||||
)
|
||||
# Log structural metadata only — `_sig` is raw tool-call content
|
||||
# that may carry credentials.
|
||||
logger.warning(
|
||||
"[agent] loop-breaker tripped on round %d (%s); "
|
||||
"stuck_rounds=%d/%d runaway=%r",
|
||||
round_num, reason, _stuck_rounds, _MAX_STUCK_ROUNDS, _runaway,
|
||||
)
|
||||
# Surface the stop cause to the stream so the user (and journalctl)
|
||||
# can tell a guard fired, not a clean completion.
|
||||
yield f'data: {json.dumps({"type": "loop_breaker_triggered", "round": round_num, "reason": reason, "stuck_rounds": _stuck_rounds, "max_stuck_rounds": _MAX_STUCK_ROUNDS, "runaway": _runaway, "message": _lb_message})}\n\n'
|
||||
logger.warning(f"[agent] loop-breaker tripped on round {round_num} ({reason}); sig={_sig[:80]!r}")
|
||||
# The model has been executing tools, so its results are already
|
||||
# in context. Force ONE tool-free round to converge: write the
|
||||
# answer from what it has, or state plainly what's blocking it.
|
||||
@@ -2975,10 +2777,6 @@ async def stream_agent_loop(
|
||||
cmd_display = block.content.split("\n")[0].strip()[:80]
|
||||
else:
|
||||
cmd_display = block.content.strip()
|
||||
# The display string is streamed (tool_start/tool_output) and persisted;
|
||||
# redact any secrets in it. block.content itself is left untouched so
|
||||
# tool execution still sees the real command.
|
||||
cmd_display = _redact_sensitive_text(cmd_display)
|
||||
|
||||
if tool_policy and tool_policy.blocks(block.tool_type):
|
||||
desc = f"{block.tool_type}: BLOCKED"
|
||||
@@ -3024,15 +2822,8 @@ async def stream_agent_loop(
|
||||
evt = await _progress_q.get()
|
||||
if evt is None:
|
||||
break
|
||||
# Redact secrets in the live tail before streaming — the
|
||||
# final tool_output is redacted, so the progress tail must
|
||||
# be too, or a secret could flash by mid-run. Copy so we
|
||||
# don't mutate the tool's own event payload.
|
||||
_evt = dict(evt)
|
||||
if isinstance(_evt.get("tail"), str):
|
||||
_evt["tail"] = _redact_sensitive_text(_evt["tail"])
|
||||
yield (
|
||||
f'data: {json.dumps({"type": "tool_progress", "tool": block.tool_type, "round": round_num, **_evt})}\n\n'
|
||||
f'data: {json.dumps({"type": "tool_progress", "tool": block.tool_type, "round": round_num, **evt})}\n\n'
|
||||
)
|
||||
desc, result = await _tool_task
|
||||
|
||||
@@ -3098,7 +2889,7 @@ async def stream_agent_loop(
|
||||
result["results"] = _clean
|
||||
elif "stdout" in result:
|
||||
result["stdout"] = _clean
|
||||
except Exception:
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
|
||||
# Emit doc-specific event for document tools — the frontend
|
||||
@@ -3167,29 +2958,29 @@ async def stream_agent_loop(
|
||||
# empty) stdout/stderr; fall back to the error so the "timed
|
||||
# out" reason reaches the UI instead of a blank result.
|
||||
raw = result["stdout"] or result["stderr"] or result.get("error", "")
|
||||
output_text = _truncate(_redact_sensitive_text(raw))
|
||||
output_text = _truncate(raw)
|
||||
elif "output" in result:
|
||||
# bash / python canonical result: {"output": ..., "exit_code": ...}
|
||||
raw = result["output"] or ""
|
||||
output_text = _truncate(_redact_sensitive_text(raw))
|
||||
output_text = _truncate(raw)
|
||||
elif "response" in result:
|
||||
# AI interaction tools (chat_with_model, send_to_session)
|
||||
label = result.get("model", result.get("session_name", "AI"))
|
||||
output_text = _truncate(_redact_sensitive_text(f"{label}: {result['response']}"))
|
||||
output_text = _truncate(f"{label}: {result['response']}")
|
||||
elif "content" in result:
|
||||
output_text = _truncate(_redact_sensitive_text(result["content"]))
|
||||
output_text = _truncate(result["content"])
|
||||
elif "results" in result:
|
||||
output_text = _truncate(_redact_sensitive_text(result["results"]))
|
||||
output_text = _truncate(result["results"])
|
||||
elif "session_id" in result and "name" in result:
|
||||
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
|
||||
elif "success" in result:
|
||||
output_text = (
|
||||
f"Written: {result.get('path', '')}"
|
||||
if result["success"]
|
||||
else f"Error: {_redact_sensitive_text(result.get('error', ''))}"
|
||||
else f"Error: {result.get('error', '')}"
|
||||
)
|
||||
elif "error" in result:
|
||||
output_text = _truncate(_redact_sensitive_text(result["error"]))
|
||||
output_text = _truncate(result["error"])
|
||||
|
||||
# Emit tool_output (include ui_event data if present)
|
||||
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
|
||||
|
||||
@@ -57,13 +57,23 @@ class WebSearchTool:
|
||||
class WebFetchTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.search.content import fetch_webpage_content
|
||||
from src.constants import WEB_FETCH_HARD_MAX_BYTES
|
||||
raw = content.strip()
|
||||
url = ""
|
||||
max_bytes = None
|
||||
if raw.startswith("{"):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
url = str(parsed.get("url") or "").strip()
|
||||
# Download-budget override (#3812): "full": true raises the
|
||||
# budget to the hard cap; an explicit max_bytes is clamped
|
||||
# to the hard cap downstream. Default stays the soft cap.
|
||||
if parsed.get("full") is True:
|
||||
max_bytes = WEB_FETCH_HARD_MAX_BYTES
|
||||
mb = parsed.get("max_bytes")
|
||||
if isinstance(mb, int) and mb > 0:
|
||||
max_bytes = mb
|
||||
except json.JSONDecodeError:
|
||||
url = ""
|
||||
if not url:
|
||||
@@ -78,7 +88,7 @@ class WebFetchTool:
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)),
|
||||
loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10, max_bytes=max_bytes)),
|
||||
timeout=30,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -94,8 +104,28 @@ class WebFetchTool:
|
||||
return {"error": f"web_fetch: {url}: {err}", "exit_code": 1}
|
||||
return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1}
|
||||
|
||||
# Tell the model when the download budget cut the body short and how
|
||||
# to get the rest, instead of silently presenting a partial page as
|
||||
# the whole thing.
|
||||
size_note = ""
|
||||
if result.get("truncated"):
|
||||
fetched = result.get("fetched_bytes") or 0
|
||||
total = result.get("total_bytes")
|
||||
total_txt = f" of {total:,} bytes" if total else ""
|
||||
size_note = (
|
||||
f"[partial content: download stopped at {fetched:,} bytes{total_txt}. "
|
||||
f'Re-call with {{"url": "{url}", "full": true}} to fetch up to '
|
||||
f"{WEB_FETCH_HARD_MAX_BYTES:,} bytes.]\n\n"
|
||||
)
|
||||
|
||||
# The notice must lead the output so the MAX_OUTPUT_CHARS trim below can
|
||||
# never drop it. The title is untrusted, uncapped page content, so a
|
||||
# giant title ahead of the notice could push it out of range; keep the
|
||||
# notice first and cap the title as a second guard.
|
||||
if len(title) > 300:
|
||||
title = title[:300] + "..."
|
||||
header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n"
|
||||
output = header + text
|
||||
output = size_note + header + text
|
||||
if len(output) > MAX_OUTPUT_CHARS:
|
||||
output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]"
|
||||
return {"output": output, "exit_code": 0}
|
||||
|
||||
+3
-2
@@ -14,6 +14,7 @@ import subprocess
|
||||
import sys
|
||||
|
||||
from core.platform_compat import IS_WINDOWS, which_tool
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -81,7 +82,7 @@ _BUILTIN_NPX_SERVERS = {
|
||||
"name": "Built-in: Browser",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest", "--headless", "--caps", "vision"],
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# Global flag to disable MCP if there are compatibility issues
|
||||
@@ -94,7 +95,7 @@ async def register_builtin_servers(mcp_manager):
|
||||
logger.info("Built-in MCP servers disabled via ODYSSEUS_DISABLE_MCP")
|
||||
return
|
||||
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
base_dir = get_app_root()
|
||||
python = sys.executable
|
||||
|
||||
async def _connect_python_server(server_id: str, script_path: str, name: str):
|
||||
|
||||
+3
-2
@@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from src.constants import DATA_DIR as _DATA_DIR_CONST
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
# Cross-platform OS flag, exposed here so callers can `from src.config import
|
||||
# IS_WINDOWS`. Defined locally (a trivial `os.name == "nt"`) rather than imported
|
||||
@@ -19,7 +20,7 @@ IS_WINDOWS = os.name == "nt"
|
||||
class DataConfig(BaseSettings):
|
||||
"""Configuration for data storage and file handling."""
|
||||
# Base directory
|
||||
base_dir: Path = Field(default=Path(__file__).parent.parent, description="Base directory for the application")
|
||||
base_dir: Path = Field(default=Path(get_app_root()), description="Base directory for the application")
|
||||
|
||||
# Data paths
|
||||
data_dir: Path = Field(default=Path(_DATA_DIR_CONST), description="Main data directory")
|
||||
@@ -138,7 +139,7 @@ class AppConfig(BaseSettings):
|
||||
if isinstance(v, dict) and "base_dir" in v:
|
||||
base_dir = v["base_dir"]
|
||||
else:
|
||||
base_dir = Path(__file__).parent.parent
|
||||
base_dir = Path(get_app_root())
|
||||
|
||||
# Convert string paths to Path objects relative to base_dir
|
||||
data_dir = Path(_DATA_DIR_CONST)
|
||||
|
||||
+26
-3
@@ -2,12 +2,14 @@
|
||||
"""Application-wide constants and configuration values."""
|
||||
import os
|
||||
|
||||
from src.runtime_paths import get_app_root, get_default_data_dir
|
||||
|
||||
APP_VERSION = "1.0.0"
|
||||
|
||||
# Base paths
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
||||
BASE_DIR = os.path.join(get_app_root(), "")
|
||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
||||
DATA_DIR = os.getenv("ODYSSEUS_DATA_DIR", os.path.join(BASE_DIR, "data"))
|
||||
DATA_DIR = os.getenv("ODYSSEUS_DATA_DIR", get_default_data_dir())
|
||||
|
||||
# Data file paths
|
||||
# Single source of truth: every persisted file/dir lives under DATA_DIR, which
|
||||
@@ -55,7 +57,13 @@ MEMORY_VECTORS_DIR = os.path.join(DATA_DIR, "memory_vectors")
|
||||
|
||||
# Paths with an intentional dedicated env override, defaulting under DATA_DIR.
|
||||
MAIL_ATTACHMENTS_DIR = os.getenv("ODYSSEUS_MAIL_ATTACHMENTS_DIR", os.path.join(DATA_DIR, "mail-attachments"))
|
||||
FASTEMBED_CACHE_DIR = os.getenv("FASTEMBED_CACHE_PATH", os.path.join(DATA_DIR, "fastembed_cache"))
|
||||
# `or` (not os.getenv's default arg) so a PRESENT-but-EMPTY value falls back to
|
||||
# the default. docker-compose.yml injects `FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}`,
|
||||
# which sets the var to "" when the host hasn't defined it. os.getenv(name, default)
|
||||
# only returns the default when the var is ABSENT, so the empty string would win →
|
||||
# os.makedirs("") raises [Errno 2] No such file or directory: '' → FastEmbed fails to
|
||||
# init and all vector features (RAG, semantic memory, tool index) silently degrade.
|
||||
FASTEMBED_CACHE_DIR = os.getenv("FASTEMBED_CACHE_PATH") or os.path.join(DATA_DIR, "fastembed_cache")
|
||||
|
||||
# Agent tool output limits (single source of truth — imported by tool_execution.py,
|
||||
# tool_implementations.py, agent_tools.py, and any other module that needs them)
|
||||
@@ -63,11 +71,26 @@ MAX_OUTPUT_CHARS = 10_000 # cap for bash/python/web_search/web_fetch outpu
|
||||
MAX_READ_CHARS = 20_000 # cap for read_file / document preview
|
||||
MAX_DIFF_LINES = 400 # cap for edit_file unified-diff display
|
||||
|
||||
# web_fetch response-size policy (#3812). MAX_OUTPUT_CHARS above only trims
|
||||
# what the agent SEES; these caps bound what the server downloads, parses,
|
||||
# and writes to the content cache. The soft cap is the default download
|
||||
# budget; the agent can raise it per call (full/max_bytes) but never past
|
||||
# the hard cap, so a model can't decide to pull a multi-GB file.
|
||||
WEB_FETCH_SOFT_MAX_BYTES = 2_000_000 # default download budget (2 MB)
|
||||
WEB_FETCH_HARD_MAX_BYTES = 20_000_000 # absolute ceiling, even with override (20 MB)
|
||||
|
||||
# API Configuration
|
||||
MAX_CONTEXT_MESSAGES = 90
|
||||
REQUEST_TIMEOUT = 20
|
||||
OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
||||
|
||||
# Outbound UA for web_fetch / web_search scraping; common desktop UA so pages serve normal HTML.
|
||||
WEB_FETCH_USER_AGENT = os.environ.get(
|
||||
"WEB_FETCH_USER_AGENT",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
||||
"(KHTML, like Gecko) Chrome/148.0.0.0 Safari/537.36",
|
||||
)
|
||||
|
||||
# Environment variables with defaults
|
||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
||||
|
||||
@@ -31,6 +31,8 @@ import numpy as np
|
||||
import httpx
|
||||
from typing import List, Optional
|
||||
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL = "all-minilm:l6-v2"
|
||||
|
||||
+4
-3
@@ -283,7 +283,8 @@ def _is_ollama_native_url(url: str) -> bool:
|
||||
"""Return True for native Ollama API URLs, including Ollama Cloud."""
|
||||
try:
|
||||
parsed = urlparse(url or "")
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse URL for Ollama detection", exc_info=e)
|
||||
return False
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
@@ -1345,8 +1346,8 @@ def list_model_ids(
|
||||
r = httpx.get(root + "/api/tags", timeout=timeout)
|
||||
r.raise_for_status()
|
||||
return [m.get("name") or m.get("model") for m in (r.json().get("models") or []) if m.get("name") or m.get("model")]
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to fetch model list from configured endpoint", exc_info=e)
|
||||
return []
|
||||
|
||||
def normalize_model_id(
|
||||
|
||||
+3
-1
@@ -11,6 +11,8 @@ import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _format_mcp_connection_error(name: str, command: str = "", args: Optional[List[str]] = None, error: Exception = None) -> str:
|
||||
@@ -508,7 +510,7 @@ class McpManager:
|
||||
return False
|
||||
|
||||
script_rel, name = _BUILTIN_SERVERS[server_id]
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
base_dir = get_app_root()
|
||||
script_path = os.path.join(base_dir, script_rel)
|
||||
|
||||
# Clean up old connection
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
from src.constants import RAG_DIR
|
||||
from src.runtime_paths import get_app_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Helpers for resolving runtime paths in source and frozen builds."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def get_app_root() -> str:
|
||||
"""Return the app root directory.
|
||||
|
||||
In normal source runs, this is the repository root. In a frozen Windows
|
||||
build, it is the bundle content root (PyInstaller's internal directory)
|
||||
so bundled runtime folders like `static/`, `scripts/`, and `data/` stay
|
||||
together with the executable payload.
|
||||
"""
|
||||
if getattr(sys, "frozen", False):
|
||||
return getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(sys.executable)))
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def get_default_data_dir() -> str:
|
||||
"""Return the default path to the data directory.
|
||||
|
||||
In normal runs, this is a 'data' subdirectory under the app root.
|
||||
In frozen builds, it is a persistent user directory (~/.odysseus/data)
|
||||
to prevent SQLite databases and other persistent files from being
|
||||
written to the ephemeral, temporary extraction bundle directory.
|
||||
"""
|
||||
if getattr(sys, "frozen", False):
|
||||
return os.path.join(os.path.expanduser("~"), ".odysseus", "data")
|
||||
return os.path.join(get_app_root(), "data")
|
||||
+24
-5
@@ -236,6 +236,29 @@ def _digest_windows(now):
|
||||
]
|
||||
|
||||
|
||||
def _checkin_calendar_events(db, owner, start, end):
|
||||
"""Calendar events in [start, end] for ONE owner, for the check-in digest.
|
||||
|
||||
Ownership lives on CalendarCal.owner; events inherit it via calendar_id.
|
||||
The digest query had no owner scope, so it pulled EVERY user's events into
|
||||
one user's check-in (a cross-tenant leak of summaries/locations). Scope it
|
||||
by joining CalendarCal, mirroring routes/calendar_routes.list_events.
|
||||
"""
|
||||
from core.database import CalendarEvent as _CE, CalendarCal as _CC
|
||||
return (
|
||||
db.query(_CE)
|
||||
.join(_CC, _CE.calendar_id == _CC.id)
|
||||
.filter(
|
||||
_CC.owner == owner,
|
||||
_CE.dtstart >= start,
|
||||
_CE.dtstart <= end,
|
||||
_CE.status != "cancelled",
|
||||
)
|
||||
.order_by(_CE.dtstart)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
def __init__(self, session_manager):
|
||||
self._session_manager = session_manager
|
||||
@@ -1127,11 +1150,7 @@ class TaskScheduler:
|
||||
# Strip timezone for naive DB comparison
|
||||
_s = start.replace(tzinfo=None) if start.tzinfo else start
|
||||
_e = end.replace(tzinfo=None) if end.tzinfo else end
|
||||
evs = _db.query(_CE).filter(
|
||||
_CE.dtstart >= _s,
|
||||
_CE.dtstart <= _e,
|
||||
_CE.status != "cancelled",
|
||||
).order_by(_CE.dtstart).all()
|
||||
evs = _checkin_calendar_events(_db, task.owner, _s, _e)
|
||||
if not evs:
|
||||
continue
|
||||
# Group by importance for richer output
|
||||
|
||||
@@ -3797,7 +3797,7 @@ async def do_resolve_contact(content: str, owner: Optional[str] = None) -> Dict:
|
||||
if not name:
|
||||
return {"error": "name is required", "exit_code": 1}
|
||||
|
||||
contacts = {} # email -> {name, source}
|
||||
contacts = {} # email_or_phone -> {name, source, phone?}
|
||||
|
||||
# 1. CardDAV (Radicale) — structured contacts. Call in-process: a
|
||||
# server-side httpx GET to /api/contacts/search carries no session
|
||||
@@ -3812,10 +3812,18 @@ async def do_resolve_contact(content: str, owner: Optional[str] = None) -> Dict:
|
||||
match = q in hay_name or any(q in (e or "").lower() for e in c.get("emails", []))
|
||||
if not match:
|
||||
continue
|
||||
has_email = False
|
||||
for email in (c.get("emails") or []):
|
||||
email = (email or "").strip().lower()
|
||||
if email and "@" in email:
|
||||
contacts[email] = {"name": c.get("name") or email, "source": "contacts"}
|
||||
has_email = True
|
||||
# Fall back to phone numbers when the contact has no email address
|
||||
if not has_email:
|
||||
for phone in (c.get("phones") or []):
|
||||
phone = (phone or "").strip()
|
||||
if phone:
|
||||
contacts[phone] = {"name": c.get("name") or phone, "source": "contacts", "phone": phone}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -3835,8 +3843,11 @@ async def do_resolve_contact(content: str, owner: Optional[str] = None) -> Dict:
|
||||
return {"output": f"No contacts found matching '{name}'.", "exit_code": 0}
|
||||
|
||||
lines = [f"Contacts matching '{name}':"]
|
||||
for email, info in contacts.items():
|
||||
lines.append(f"- {info['name']} <{email}> ({info['source']})")
|
||||
for key, info in contacts.items():
|
||||
if info.get("phone"):
|
||||
lines.append(f"- {info['name']} — phone: {info['phone']} ({info['source']})")
|
||||
else:
|
||||
lines.append(f"- {info['name']} <{key}> ({info['source']})")
|
||||
return {"output": "\n".join(lines), "exit_code": 0}
|
||||
|
||||
|
||||
|
||||
+4
-3
@@ -68,11 +68,12 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_fetch",
|
||||
"description": "Fetch and read the text content of a specific URL the user names (e.g. 'check example.com', 'what's on this page <url>'). Use when you already have a concrete URL/domain. NOT for open-ended searches (use web_search) or 'research X' jobs (use trigger_research).",
|
||||
"description": "Fetch and read the text content of a specific URL the user names (e.g. 'check example.com', 'what's on this page <url>'). Use when you already have a concrete URL/domain. NOT for open-ended searches (use web_search) or 'research X' jobs (use trigger_research). Downloads are size-budgeted; a '[partial content: ...]' notice in the result means the body was cut short and you can re-call with full=true for the rest.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "The URL or domain to fetch (http/https; a bare domain like example.com is fine)"}
|
||||
"url": {"type": "string", "description": "The URL or domain to fetch (http/https; a bare domain like example.com is fine)"},
|
||||
"full": {"type": "boolean", "description": "Raise the download budget to the hard cap for large pages/files. Use only after a result reported partial content."}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
@@ -1008,7 +1009,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "resolve_contact",
|
||||
"description": "Look up a contact's email address by name. Searches CardDAV address book and sent email history. Use when the user says 'message [name]' or 'email [name]' without an email address.",
|
||||
"description": "Look up a contact by name. Searches CardDAV address book and sent email history. Returns email addresses (when available) or phone numbers. Use when the user says 'message [name]', 'email [name]', or asks for someone's contact details.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -1911,23 +1911,6 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
||||
_chatBox.appendChild(note);
|
||||
try { note.scrollIntoView({ block: 'end', behavior: 'smooth' }); } catch (_) { uiModule.scrollHistory && uiModule.scrollHistory(); }
|
||||
}
|
||||
} else if (json.type === 'loop_breaker_triggered' || json.type === 'intent_nudge_exhausted') {
|
||||
// A loop guard ended the turn — surface why so it isn't mistaken
|
||||
// for a clean completion or a silent stall.
|
||||
const _chatBox = document.getElementById('chat-history');
|
||||
if (!_isBg && _chatBox) {
|
||||
const note = document.createElement('div');
|
||||
note.className = 'stopped-indicator loop-guard-stop';
|
||||
const label = document.createElement('span');
|
||||
label.className = 'rounds-exhausted-label';
|
||||
label.textContent = json.message ||
|
||||
(json.type === 'loop_breaker_triggered'
|
||||
? 'Stopped by the loop-breaker (no new progress).'
|
||||
: 'Stopped: announced an action but never called the tool.');
|
||||
note.appendChild(label);
|
||||
_chatBox.appendChild(note);
|
||||
try { note.scrollIntoView({ block: 'end', behavior: 'smooth' }); } catch (_) { uiModule.scrollHistory && uiModule.scrollHistory(); }
|
||||
}
|
||||
} else if (json.type === 'model_actual') {
|
||||
if (!_isBg && holder) {
|
||||
holder._requestedModel = json.requested_model || holder._requestedModel || modelName;
|
||||
|
||||
@@ -219,6 +219,9 @@ class _WebhookManager:
|
||||
async def fire(self, event, payload):
|
||||
return None
|
||||
|
||||
def fire_and_forget(self, event, payload):
|
||||
return None
|
||||
|
||||
|
||||
def _install_sync_chat_stubs(monkeypatch):
|
||||
# FastAPI checks for python_multipart at import time when Form is used;
|
||||
|
||||
@@ -502,3 +502,77 @@ def test_delete_token_owner_check_skipped_when_auth_disabled(monkeypatch, token_
|
||||
resp = delete_token(request=req, token_id="tok123")
|
||||
assert resp == {"status": "deleted"}
|
||||
fake_session.delete.assert_called_once_with(fake_token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. PATCH /api/tokens/{id} — non-object JSON bodies must not 500
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_update_token_with_array_body_does_not_500(monkeypatch, token_routes_mod):
|
||||
"""PATCH body of [] must be normalised to {} and not raise."""
|
||||
monkeypatch.setenv("AUTH_ENABLED", "true")
|
||||
mod = token_routes_mod
|
||||
|
||||
token = SimpleNamespace(
|
||||
id="tok123", name="original", owner="alice",
|
||||
token_prefix="ody_orig", scopes="email:read", is_active=True,
|
||||
)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.return_value.filter.return_value.first.return_value = token
|
||||
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||
|
||||
invalidator = MagicMock()
|
||||
req = _patch_request(invalidator, [])
|
||||
update_token = _get_handler(mod, "PATCH", "/tokens/{token_id}")
|
||||
resp = asyncio.run(update_token(request=req, token_id="tok123"))
|
||||
|
||||
# Name and scopes must be unchanged — payload was normalised to {}
|
||||
assert token.name == "original"
|
||||
assert token.scopes == "email:read"
|
||||
assert resp["name"] == "original"
|
||||
|
||||
|
||||
def test_update_token_with_null_body_does_not_500(monkeypatch, token_routes_mod):
|
||||
"""PATCH body of null must be normalised to {} and not raise."""
|
||||
monkeypatch.setenv("AUTH_ENABLED", "true")
|
||||
mod = token_routes_mod
|
||||
|
||||
token = SimpleNamespace(
|
||||
id="tok123", name="original", owner="alice",
|
||||
token_prefix="ody_orig", scopes="chat", is_active=True,
|
||||
)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.return_value.filter.return_value.first.return_value = token
|
||||
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||
|
||||
invalidator = MagicMock()
|
||||
req = _patch_request(invalidator, None)
|
||||
update_token = _get_handler(mod, "PATCH", "/tokens/{token_id}")
|
||||
resp = asyncio.run(update_token(request=req, token_id="tok123"))
|
||||
|
||||
assert token.name == "original"
|
||||
assert token.scopes == "chat"
|
||||
|
||||
|
||||
def test_update_token_normal_object_still_works(monkeypatch, token_routes_mod):
|
||||
"""Normal dict payload continues to update fields as before."""
|
||||
monkeypatch.setenv("AUTH_ENABLED", "true")
|
||||
mod = token_routes_mod
|
||||
|
||||
token = SimpleNamespace(
|
||||
id="tok123", name="original", owner="alice",
|
||||
token_prefix="ody_orig", scopes="email:read", is_active=True,
|
||||
)
|
||||
fake_session = MagicMock()
|
||||
fake_session.query.return_value.filter.return_value.first.return_value = token
|
||||
monkeypatch.setattr(mod, "get_db_session", lambda: _db_ctx(fake_session))
|
||||
|
||||
invalidator = MagicMock()
|
||||
req = _patch_request(invalidator, {"name": "updated"})
|
||||
update_token = _get_handler(mod, "PATCH", "/tokens/{token_id}")
|
||||
resp = asyncio.run(update_token(request=req, token_id="tok123"))
|
||||
|
||||
assert token.name == "updated"
|
||||
assert resp["name"] == "updated"
|
||||
invalidator.assert_called_once()
|
||||
|
||||
@@ -30,7 +30,7 @@ class _Session:
|
||||
|
||||
|
||||
def test_allowed_models_legacy_empty_list_remains_unrestricted(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
_enforce_chat_privileges(
|
||||
_Request({"allowed_models": [], "max_messages_per_day": 0}),
|
||||
@@ -39,7 +39,7 @@ def test_allowed_models_legacy_empty_list_remains_unrestricted(monkeypatch):
|
||||
|
||||
|
||||
def test_allowed_models_explicit_empty_restricted_list_blocks_all_models(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
_enforce_chat_privileges(
|
||||
@@ -56,7 +56,7 @@ def test_allowed_models_explicit_empty_restricted_list_blocks_all_models(monkeyp
|
||||
|
||||
|
||||
def test_allowed_models_nonempty_list_still_restricts_without_new_flag(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
_enforce_chat_privileges(
|
||||
_Request({"allowed_models": ["provider/model-a"], "max_messages_per_day": 0}),
|
||||
@@ -70,7 +70,7 @@ def test_allowed_models_nonempty_list_still_restricts_without_new_flag(monkeypat
|
||||
|
||||
|
||||
def test_no_restriction_allows_any_model(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
privs = {"allowed_models": [], "block_all_models": False, "max_messages_per_day": 0}
|
||||
_enforce_chat_privileges(_Request(privs), _Session("provider/model-a"))
|
||||
@@ -78,7 +78,7 @@ def test_no_restriction_allows_any_model(monkeypatch):
|
||||
|
||||
|
||||
def test_specific_allowlist_blocks_models_outside_it(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
privs = {
|
||||
"allowed_models": ["gpt-4"],
|
||||
@@ -92,7 +92,7 @@ def test_specific_allowlist_blocks_models_outside_it(monkeypatch):
|
||||
|
||||
|
||||
def test_block_all_models_blocks_regardless_of_allowed_models_contents(monkeypatch):
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "alice")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "alice")
|
||||
|
||||
# Even if allowed_models contains entries, block_all_models wins.
|
||||
privs = {
|
||||
@@ -111,7 +111,7 @@ def test_block_all_models_blocks_regardless_of_allowed_models_contents(monkeypat
|
||||
def test_admin_user_is_never_blocked(monkeypatch):
|
||||
from core.auth import ADMIN_PRIVILEGES
|
||||
|
||||
monkeypatch.setattr("routes.chat_helpers.get_current_user", lambda request: "admin")
|
||||
monkeypatch.setattr("routes.chat_helpers.effective_user", lambda request: "admin")
|
||||
|
||||
class _AdminAuthManager:
|
||||
def get_privileges(self, username):
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Check-in calendar digest must be scoped to the task owner.
|
||||
|
||||
The digest query selected CalendarEvent with no owner scope, so a scheduled
|
||||
check-in for one user pulled EVERY user's calendar events (summaries,
|
||||
locations) into their digest — a cross-tenant leak. Ownership lives on
|
||||
CalendarCal.owner; the query must join it, like routes/calendar_routes.
|
||||
"""
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
import core.database as cdb
|
||||
from core.database import CalendarEvent, CalendarCal
|
||||
from src.task_scheduler import _checkin_calendar_events
|
||||
|
||||
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_ENGINE = create_engine(f"sqlite:///{_TMPDB.name}", connect_args={"check_same_thread": False}, poolclass=NullPool)
|
||||
cdb.Base.metadata.create_all(_ENGINE)
|
||||
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
def _seed():
|
||||
db = _TS()
|
||||
try:
|
||||
db.query(CalendarEvent).delete(); db.query(CalendarCal).delete()
|
||||
db.add(CalendarCal(id="calA", owner="alice", name="A"))
|
||||
db.add(CalendarCal(id="calB", owner="bob", name="B"))
|
||||
db.add(CalendarEvent(uid="a1", calendar_id="calA", summary="Alice mtg",
|
||||
dtstart=datetime(2026, 6, 10, 9, 0),
|
||||
dtend=datetime(2026, 6, 10, 10, 0), status="confirmed"))
|
||||
db.add(CalendarEvent(uid="b1", calendar_id="calB", summary="Bob secret",
|
||||
dtstart=datetime(2026, 6, 10, 10, 0),
|
||||
dtend=datetime(2026, 6, 10, 11, 0), status="confirmed"))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_digest_only_returns_owner_events():
|
||||
_seed()
|
||||
db = _TS()
|
||||
try:
|
||||
s, e = datetime(2026, 6, 1), datetime(2026, 6, 30)
|
||||
alice = _checkin_calendar_events(db, "alice", s, e)
|
||||
assert [ev.summary for ev in alice] == ["Alice mtg"] # not Bob's
|
||||
bob = _checkin_calendar_events(db, "bob", s, e)
|
||||
assert [ev.summary for ev in bob] == ["Bob secret"]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_cancelled_excluded_and_window_respected():
|
||||
_seed()
|
||||
db = _TS()
|
||||
try:
|
||||
db2 = _TS()
|
||||
db2.add(CalendarEvent(uid="a2", calendar_id="calA", summary="cancelled",
|
||||
dtstart=datetime(2026, 6, 11),
|
||||
dtend=datetime(2026, 6, 11, 1, 0), status="cancelled"))
|
||||
db2.commit(); db2.close()
|
||||
s, e = datetime(2026, 6, 1), datetime(2026, 6, 30)
|
||||
out = _checkin_calendar_events(db, "alice", s, e)
|
||||
assert "cancelled" not in [ev.summary for ev in out]
|
||||
finally:
|
||||
db.close()
|
||||
@@ -13,6 +13,9 @@ import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# core.database instantiates SQLAlchemy declarative classes at import time, which
|
||||
@@ -225,12 +228,34 @@ def test_models_route_scopes_api_token_to_token_owner(monkeypatch):
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=True, api_token_owner="alice", current_user="api"),
|
||||
_request(
|
||||
api_token=True,
|
||||
api_token_owner="alice",
|
||||
api_token_scopes=["chat"],
|
||||
current_user="api",
|
||||
),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["alice-endpoint", "shared-endpoint"]
|
||||
|
||||
|
||||
def test_models_route_rejects_api_token_without_chat_scope(monkeypatch):
|
||||
monkeypatch.setattr(companion_routes, "get_current_user", lambda request: "api")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
_models_route()(
|
||||
_request(
|
||||
api_token=True,
|
||||
api_token_owner="alice",
|
||||
api_token_scopes=["todos:read"],
|
||||
current_user="api",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 403
|
||||
assert "chat scope" in exc.value.detail
|
||||
|
||||
|
||||
def test_models_route_unresolved_owner_returns_only_shared_rows(monkeypatch):
|
||||
rows = [
|
||||
_ep(1, "alice-endpoint", "alice"),
|
||||
@@ -242,7 +267,12 @@ def test_models_route_unresolved_owner_returns_only_shared_rows(monkeypatch):
|
||||
endpoints = _call_models_route(
|
||||
monkeypatch,
|
||||
rows,
|
||||
_request(api_token=True, api_token_owner=None, current_user="api"),
|
||||
_request(
|
||||
api_token=True,
|
||||
api_token_owner=None,
|
||||
api_token_scopes=["chat"],
|
||||
current_user="api",
|
||||
),
|
||||
)
|
||||
|
||||
assert _endpoint_names(endpoints) == ["shared-endpoint"]
|
||||
|
||||
@@ -786,6 +786,50 @@ def test_cached_model_scan_reports_plain_dir_gguf(tmp_path):
|
||||
assert ggufs[3]["quant"] == "BF16"
|
||||
|
||||
|
||||
def test_cached_model_scan_uses_ollama_api_before_cli_and_windows_opt_in():
|
||||
script = _cached_model_scan_script()
|
||||
|
||||
assert "scan_ollama_api()\nscan_ollama()" in script
|
||||
assert "if any(m.get('is_ollama') for m in models): return" in script
|
||||
assert "os.name == 'nt'" in script
|
||||
assert "ODYSSEUS_ALLOW_OLLAMA_CLI_SCAN" in script
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "nt", reason="Windows Ollama CLI startup guard")
|
||||
def test_cached_model_scan_does_not_launch_ollama_cli_on_windows(tmp_path):
|
||||
"""Official Ollama for Windows can auto-start the tray/server on `ollama list`.
|
||||
The read-only cache scanner must not invoke that CLI unless explicitly opted in.
|
||||
"""
|
||||
marker = tmp_path / "ollama-called.txt"
|
||||
fake_ollama = tmp_path / "ollama.cmd"
|
||||
fake_ollama.write_text(
|
||||
"@echo off\r\n"
|
||||
f'echo called>"{marker}"\r\n'
|
||||
"echo NAME ID SIZE MODIFIED\r\n"
|
||||
"echo local-model:latest abc 1 GB now\r\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
empty_home = tmp_path / "home"
|
||||
empty_home.mkdir()
|
||||
scan_py = tmp_path / "scan_cache.py"
|
||||
scan_py.write_text(_cached_model_scan_script(), encoding="utf-8")
|
||||
env = dict(os.environ)
|
||||
env["PATH"] = str(tmp_path) + os.pathsep + env.get("PATH", "")
|
||||
env["HOME"] = str(empty_home)
|
||||
env.pop("ODYSSEUS_ALLOW_OLLAMA_CLI_SCAN", None)
|
||||
proc = subprocess.run(
|
||||
[sys.executable, str(scan_py)],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
assert marker.exists() is False
|
||||
assert all(m.get("backend") != "ollama" for m in json.loads(proc.stdout))
|
||||
|
||||
|
||||
def test_cached_model_scan_uses_huggingface_cache_env(tmp_path):
|
||||
"""Docker recreates can leave the persisted HF cache outside HOME.
|
||||
The Serve scanner should honor the cache env path instead of only ~/.cache.
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Regression: FASTEMBED_CACHE_DIR must tolerate a PRESENT-but-EMPTY
|
||||
FASTEMBED_CACHE_PATH.
|
||||
|
||||
docker-compose.yml injects ``FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}``,
|
||||
which sets the variable to ``""`` when the host has not defined it. The old
|
||||
``os.getenv("FASTEMBED_CACHE_PATH", default)`` only used the default when the
|
||||
variable was ABSENT, so an empty value made ``FASTEMBED_CACHE_DIR == ""`` →
|
||||
``os.makedirs("")`` raised ``[Errno 2] No such file or directory: ''`` →
|
||||
FastEmbed failed to initialise and every vector feature (RAG, semantic memory,
|
||||
tool index) silently degraded on the default Docker stack.
|
||||
|
||||
These tests pin the fix: empty is treated like absent → use the DATA_DIR
|
||||
default, while an explicit non-empty override is still honoured.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import src.constants as constants
|
||||
|
||||
|
||||
def _reload_with(monkeypatch, value):
|
||||
"""Reload src.constants with FASTEMBED_CACHE_PATH set to ``value`` (or
|
||||
removed when ``value`` is None) and return the reloaded module."""
|
||||
if value is None:
|
||||
monkeypatch.delenv("FASTEMBED_CACHE_PATH", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("FASTEMBED_CACHE_PATH", value)
|
||||
return importlib.reload(constants)
|
||||
|
||||
|
||||
def _restore(monkeypatch):
|
||||
"""Return the module to its env-default state so reloading it here does
|
||||
not leak a test-specific FASTEMBED_CACHE_DIR into other tests."""
|
||||
monkeypatch.delenv("FASTEMBED_CACHE_PATH", raising=False)
|
||||
importlib.reload(constants)
|
||||
|
||||
|
||||
def test_empty_fastembed_cache_path_falls_back_to_default(monkeypatch):
|
||||
"""The bug: an empty FASTEMBED_CACHE_PATH (exactly what Docker injects)
|
||||
must fall back to the DATA_DIR default, never the empty string."""
|
||||
try:
|
||||
mod = _reload_with(monkeypatch, "")
|
||||
assert mod.FASTEMBED_CACHE_DIR, "empty env must not yield an empty path"
|
||||
assert mod.FASTEMBED_CACHE_DIR == os.path.join(mod.DATA_DIR, "fastembed_cache")
|
||||
finally:
|
||||
_restore(monkeypatch)
|
||||
|
||||
|
||||
def test_unset_fastembed_cache_path_uses_default(monkeypatch):
|
||||
"""Sanity: an absent variable also resolves to the default."""
|
||||
try:
|
||||
mod = _reload_with(monkeypatch, None)
|
||||
assert mod.FASTEMBED_CACHE_DIR == os.path.join(mod.DATA_DIR, "fastembed_cache")
|
||||
finally:
|
||||
_restore(monkeypatch)
|
||||
|
||||
|
||||
def test_explicit_fastembed_cache_path_is_respected(monkeypatch):
|
||||
"""A real explicit override must still win — the fix only changes the
|
||||
empty-value handling, not the documented FASTEMBED_CACHE_PATH override."""
|
||||
custom = os.path.join("custom", "fastembed-cache")
|
||||
try:
|
||||
mod = _reload_with(monkeypatch, custom)
|
||||
assert mod.FASTEMBED_CACHE_DIR == custom
|
||||
finally:
|
||||
_restore(monkeypatch)
|
||||
@@ -79,7 +79,7 @@ def _build_context_harness(monkeypatch, chat_helpers, history):
|
||||
monkeypatch.setattr(chat_helpers, "extract_preset", fake_extract_preset)
|
||||
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
|
||||
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
|
||||
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
|
||||
monkeypatch.setattr(chat_helpers, "effective_user", lambda request: "tester")
|
||||
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
|
||||
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
|
||||
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
|
||||
|
||||
@@ -1,350 +0,0 @@
|
||||
"""Regression: stream_agent_loop surfaces *why* a guard ended the turn.
|
||||
|
||||
Two internal guards used to stop the agent in ways that looked like a clean
|
||||
completion or a vague blocked message:
|
||||
|
||||
* the loop-breaker stall detector -> now emits `loop_breaker_triggered`
|
||||
* the intent-without-action nudge cap -> now emits `intent_nudge_exhausted`
|
||||
|
||||
These tests run the real loop body against a fake LLM stream (no model calls,
|
||||
no sleeps) and assert the structured stop event is emitted.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
import src.agent_loop as al
|
||||
|
||||
|
||||
def _collect(gen):
|
||||
async def _run():
|
||||
return [c async for c in gen]
|
||||
return asyncio.run(_run())
|
||||
|
||||
|
||||
def _types(chunks):
|
||||
out = []
|
||||
for c in chunks:
|
||||
if c.startswith("data: ") and not c.startswith("data: [DONE]"):
|
||||
try:
|
||||
out.append(json.loads(c[6:]))
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def _patch_common(monkeypatch):
|
||||
monkeypatch.setattr(al, "get_setting", lambda key, default=None: default, raising=False)
|
||||
monkeypatch.setattr(al, "get_mcp_manager", lambda: None, raising=False)
|
||||
monkeypatch.setattr(al, "estimate_tokens", lambda *a, **k: 10, raising=False)
|
||||
|
||||
async def _fake_exec(block, *a, **k):
|
||||
return ("bash", {"output": "ok", "exit_code": 0})
|
||||
monkeypatch.setattr(al, "execute_tool_block", _fake_exec, raising=False)
|
||||
|
||||
|
||||
def _run_loop(monkeypatch, round_text, max_rounds, relevant_tools={"bash"}):
|
||||
async def _fake_stream(_candidates, messages, **kwargs):
|
||||
yield f'data: {json.dumps({"delta": round_text})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)
|
||||
|
||||
gen = al.stream_agent_loop(
|
||||
"http://x/v1", "m",
|
||||
[{"role": "user", "content": "do a long multi-step task"}],
|
||||
max_rounds=max_rounds,
|
||||
relevant_tools=relevant_tools,
|
||||
)
|
||||
return _types(_collect(gen))
|
||||
|
||||
|
||||
def test_emits_loop_breaker_triggered_on_repeated_no_progress(monkeypatch):
|
||||
_patch_common(monkeypatch)
|
||||
# Same exact tool call every round, no answer text -> stuck-round streak
|
||||
# trips the loop-breaker once the cap is reached.
|
||||
events = _run_loop(monkeypatch, "```bash\necho hi\n```", max_rounds=8)
|
||||
lb = [e for e in events if e.get("type") == "loop_breaker_triggered"]
|
||||
assert lb, events
|
||||
e = lb[0]
|
||||
assert e["reason"]
|
||||
assert e["max_stuck_rounds"] == 4
|
||||
assert e["stuck_rounds"] >= 4
|
||||
assert "message" in e
|
||||
|
||||
|
||||
def test_no_loop_breaker_on_normal_finish(monkeypatch):
|
||||
_patch_common(monkeypatch)
|
||||
events = _run_loop(monkeypatch, "All done, here is your answer.", max_rounds=8)
|
||||
assert not any(e.get("type") == "loop_breaker_triggered" for e in events), events
|
||||
|
||||
|
||||
def test_emits_intent_nudge_exhausted_when_cap_reached(monkeypatch):
|
||||
_patch_common(monkeypatch)
|
||||
# The model keeps announcing an action with no tool call. After the nudge
|
||||
# cap is spent, the turn ends with an explicit intent_nudge_exhausted event.
|
||||
events = _run_loop(monkeypatch, "Let me check the logs now", max_rounds=5)
|
||||
inx = [e for e in events if e.get("type") == "intent_nudge_exhausted"]
|
||||
assert inx, events
|
||||
e = inx[0]
|
||||
assert e["max_nudges"] == 2
|
||||
assert e["nudges"] >= 2
|
||||
assert "message" in e
|
||||
|
||||
|
||||
def test_no_intent_nudge_exhausted_on_normal_finish(monkeypatch):
|
||||
_patch_common(monkeypatch)
|
||||
events = _run_loop(monkeypatch, "Here is the complete answer to your question.", max_rounds=5)
|
||||
assert not any(e.get("type") == "intent_nudge_exhausted" for e in events), events
|
||||
|
||||
|
||||
def _assert_guard_log_safe(caplog, *, structural, secret="secret123"):
|
||||
"""The guard's own structural log line fired, and that record carries no raw
|
||||
secret. Scoped to the guard's records on purpose: an unrelated, pre-existing
|
||||
round-summary log echoes raw model text and is out of scope for this PR."""
|
||||
records = [r for r in caplog.records if structural in r.getMessage()]
|
||||
assert records, caplog.text
|
||||
for r in records:
|
||||
assert secret not in r.getMessage(), r.getMessage()
|
||||
|
||||
|
||||
def test_intent_nudge_logging_does_not_leak_secret(monkeypatch, caplog):
|
||||
# The model announces an action (no tool call) with a secret in the text.
|
||||
# The nudge logger must record only structural metadata, never the matched
|
||||
# phrase — so the credential never lands in journalctl.
|
||||
_patch_common(monkeypatch)
|
||||
with caplog.at_level(logging.INFO, logger="src.agent_loop"):
|
||||
events = _run_loop(monkeypatch, "Let me check api_key=secret123 now", max_rounds=5)
|
||||
assert any(e.get("type") == "intent_nudge_exhausted" for e in events), events
|
||||
_assert_guard_log_safe(caplog, structural="intent-without-action nudge")
|
||||
|
||||
|
||||
def test_loop_breaker_logging_does_not_leak_secret(monkeypatch, caplog):
|
||||
# A repeated tool command carrying a secret trips the loop-breaker. The
|
||||
# structural log must not contain `_sig` / raw tool-call content.
|
||||
_patch_common(monkeypatch)
|
||||
with caplog.at_level(logging.INFO, logger="src.agent_loop"):
|
||||
events = _run_loop(monkeypatch, "```bash\necho api_key=secret123\n```", max_rounds=8)
|
||||
assert any(e.get("type") == "loop_breaker_triggered" for e in events), events
|
||||
_assert_guard_log_safe(caplog, structural="loop-breaker tripped")
|
||||
|
||||
|
||||
def test_redacts_sensitive_tool_output_before_surfacing():
|
||||
text = al._redact_sensitive_text(
|
||||
"password: private-value\n"
|
||||
"api_key=private-key\n"
|
||||
"Authorization: Bearer private-token\n"
|
||||
"normal output"
|
||||
)
|
||||
|
||||
assert "private-value" not in text
|
||||
assert "private-key" not in text
|
||||
assert "private-token" not in text
|
||||
assert "password: [redacted]" in text
|
||||
assert "api_key=[redacted]" in text
|
||||
assert "Authorization: Bearer [redacted]" in text
|
||||
assert "normal output" in text
|
||||
|
||||
|
||||
_GCP_API_KEY_SAMPLE = "AI" + "za" + ("A" * 35)
|
||||
|
||||
# (input, secret substring that must be gone, expected substring that must remain)
|
||||
_REDACTION_CASES = [
|
||||
("Authorization: Bearer abc123tok", "abc123tok", "Authorization: Bearer [redacted]"),
|
||||
("Authorization: Basic dXNlcjpwYXNz", "dXNlcjpwYXNz", "Authorization: Basic [redacted]"),
|
||||
# Quoted Authorization value (spaces) must be redacted whole.
|
||||
('Authorization: Bearer "two word secret"', "two word secret", "Authorization: Bearer [redacted]"),
|
||||
# Escaped quote inside a quoted secret must not leak the tail.
|
||||
(r'password="abc\"def secret"', "def secret", "password=[redacted]"),
|
||||
# URL password containing a colon must still be redacted whole.
|
||||
("postgres://user:pa:ss@host/db", "pa:ss", "postgres://[redacted]@host/db"),
|
||||
# Provider-shaped bare tokens.
|
||||
("token is hf_abcdefghij1234567890XYZ", "hf_abcdefghij1234567890XYZ", "[redacted]"),
|
||||
("key " + _GCP_API_KEY_SAMPLE, _GCP_API_KEY_SAMPLE, "[redacted]"),
|
||||
("Cookie: session=abc123secret", "abc123secret", "Cookie: [redacted]"),
|
||||
("Set-Cookie: sid=xyz789; HttpOnly", "xyz789", "Set-Cookie: [redacted]"),
|
||||
("postgres://user:pa55word@host/db", "pa55word", "postgres://[redacted]@host/db"),
|
||||
("client_secret=supersecretvalue", "supersecretvalue", "client_secret=[redacted]"),
|
||||
("OPENAI_API_KEY=abcd1234deadbeef", "abcd1234deadbeef", "OPENAI_API_KEY=[redacted]"),
|
||||
# Quoted multi-word env value must be fully redacted, not clipped at the space.
|
||||
('OPENAI_API_KEY="two word secret"', "two word secret", "OPENAI_API_KEY=[redacted]"),
|
||||
('password: "my secret value"', "my secret value", "password: [redacted]"),
|
||||
("here is sk-abcdefghij1234567890", "sk-abcdefghij1234567890", "[redacted]"),
|
||||
(
|
||||
"-----BEGIN PRIVATE KEY-----\nMIIfakeKEYbody\n-----END PRIVATE KEY-----",
|
||||
"MIIfakeKEYbody",
|
||||
"[redacted private key]",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw, secret, expected", _REDACTION_CASES)
|
||||
def test_redaction_covers_requested_secret_shapes(raw, secret, expected):
|
||||
out = al._redact_sensitive_text(raw)
|
||||
assert secret not in out, out
|
||||
assert expected in out, out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw", [
|
||||
"the build completed in 3.2s with 0 errors",
|
||||
"password reset email sent to the user",
|
||||
"Listing 5 files: a.py b.py c.py d.py e.py",
|
||||
"https://example.com/path?page=2",
|
||||
# Benign uppercase names that merely end in KEY must not be redacted.
|
||||
"MONKEY=banana",
|
||||
"TURKEY=dinner",
|
||||
])
|
||||
def test_redaction_keeps_normal_output_readable(raw):
|
||||
assert al._redact_sensitive_text(raw) == raw
|
||||
|
||||
|
||||
def test_redacts_before_truncating():
|
||||
# A secret near the start must be gone even if truncation would otherwise
|
||||
# only clip the tail — redaction runs first.
|
||||
raw = "api_key=topsecretvalue " + ("x" * 50_000)
|
||||
out = al._truncate(al._redact_sensitive_text(raw))
|
||||
assert "topsecretvalue" not in out
|
||||
assert "api_key=[redacted]" in out
|
||||
|
||||
|
||||
def _run_tool_result(monkeypatch, tool, exec_result, max_rounds=2):
|
||||
"""Drive one tool round whose execution returns `exec_result`, and collect
|
||||
the streamed events. Used to assert restored per-tool-result emissions."""
|
||||
_patch_common(monkeypatch)
|
||||
|
||||
async def _fake_exec(block, *a, **k):
|
||||
return (tool, exec_result)
|
||||
monkeypatch.setattr(al, "execute_tool_block", _fake_exec, raising=False)
|
||||
|
||||
round_text = f"```{tool}\n{{}}\n```"
|
||||
|
||||
async def _fake_stream(_candidates, messages, **kwargs):
|
||||
yield f'data: {json.dumps({"delta": round_text})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)
|
||||
|
||||
gen = al.stream_agent_loop(
|
||||
"http://x/v1", "m",
|
||||
[{"role": "user", "content": "do something"}],
|
||||
max_rounds=max_rounds,
|
||||
relevant_tools={tool},
|
||||
)
|
||||
return _types(_collect(gen))
|
||||
|
||||
|
||||
def test_restores_doc_suggestions_event(monkeypatch):
|
||||
events = _run_tool_result(
|
||||
monkeypatch, "suggest_document",
|
||||
{"action": "suggest", "doc_id": "d1", "suggestions": [{"text": "x"}], "exit_code": 0},
|
||||
)
|
||||
assert any(e.get("type") == "doc_suggestions" for e in events), events
|
||||
|
||||
|
||||
def test_restores_doc_update_event(monkeypatch):
|
||||
events = _run_tool_result(
|
||||
monkeypatch, "edit_document",
|
||||
{"action": "edit", "doc_id": "d1", "content": "body", "version": 2,
|
||||
"title": "T", "language": "md", "exit_code": 0},
|
||||
)
|
||||
# A native document block also emits doc_update AFTER tool_output, so a plain
|
||||
# "any doc_update" check would pass even if the restored generic block were
|
||||
# gone. Prove the restored block fires BEFORE the first tool_output.
|
||||
types = [e.get("type") for e in events]
|
||||
assert "doc_update" in types, events
|
||||
assert "tool_output" in types, events
|
||||
assert types.index("doc_update") < types.index("tool_output"), types
|
||||
|
||||
|
||||
def test_restores_ui_control_event(monkeypatch):
|
||||
events = _run_tool_result(
|
||||
monkeypatch, "ui_control",
|
||||
{"ui_event": "toggle", "toggle_name": "bash", "state": "off", "exit_code": 0},
|
||||
)
|
||||
assert any(e.get("type") == "ui_control" for e in events), events
|
||||
|
||||
|
||||
def test_restores_plan_update_event(monkeypatch):
|
||||
events = _run_tool_result(
|
||||
monkeypatch, "update_plan",
|
||||
{"plan_update": {"steps": [{"text": "step", "done": True}]}, "exit_code": 0},
|
||||
)
|
||||
assert any(e.get("type") == "plan_update" for e in events), events
|
||||
|
||||
|
||||
def test_restores_ask_user_event_and_persists_question(monkeypatch):
|
||||
events = _run_tool_result(
|
||||
monkeypatch, "ask_user",
|
||||
{"ask_user": {"question": "Which option?", "options": [{"label": "A"}, {"label": "B"}]},
|
||||
"exit_code": 0},
|
||||
)
|
||||
# Exactly one ask_user event — not re-emitted on a follow-up round.
|
||||
_ask_events = [e for e in events if e.get("type") == "ask_user"]
|
||||
assert len(_ask_events) == 1, events
|
||||
# The question is streamed as assistant text so it persists for replay.
|
||||
# Upstream prepends "\n\n" when full_response already holds streamed text,
|
||||
# so match on containment — and it must be streamed exactly once.
|
||||
_q_deltas = [e for e in events if "Which option?" in (e.get("delta") or "")]
|
||||
assert len(_q_deltas) == 1, events
|
||||
# Setting `_awaiting_user` breaks the loop, so the turn does NOT advance into
|
||||
# another agent round (which would emit an agent_step event) after the ask.
|
||||
assert not any(e.get("type") == "agent_step" for e in events), events
|
||||
|
||||
|
||||
def test_redacts_command_display_in_streamed_events(monkeypatch):
|
||||
# A tool command line can carry a secret. The streamed command display
|
||||
# (tool_start / tool_output) must be redacted, even though the real command
|
||||
# passed to execution is left untouched.
|
||||
_patch_common(monkeypatch)
|
||||
|
||||
round_text = "```bash\necho api_key=secret123\n```"
|
||||
|
||||
async def _fake_stream(_candidates, messages, **kwargs):
|
||||
yield f'data: {json.dumps({"delta": round_text})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)
|
||||
|
||||
gen = al.stream_agent_loop(
|
||||
"http://x/v1", "m",
|
||||
[{"role": "user", "content": "run it"}],
|
||||
max_rounds=2,
|
||||
relevant_tools={"bash"},
|
||||
)
|
||||
events = _types(_collect(gen))
|
||||
cmds = [e for e in events if e.get("type") in ("tool_start", "tool_output")]
|
||||
assert cmds, events
|
||||
assert all("secret123" not in (e.get("command") or "") for e in cmds), cmds
|
||||
assert any("api_key=[redacted]" in (e.get("command") or "") for e in cmds), cmds
|
||||
|
||||
|
||||
def test_redacts_live_tool_progress_tail(monkeypatch):
|
||||
# A secret in the live progress tail must be redacted before streaming —
|
||||
# otherwise it flashes by before the (already redacted) final tool_output.
|
||||
_patch_common(monkeypatch)
|
||||
|
||||
async def _fake_exec(block, *a, **k):
|
||||
await k["progress_cb"]({"tail": "api_key=secret123", "elapsed_s": 1})
|
||||
return ("bash", {"output": "done", "exit_code": 0})
|
||||
monkeypatch.setattr(al, "execute_tool_block", _fake_exec, raising=False)
|
||||
|
||||
round_text = "```bash\necho hi\n```"
|
||||
|
||||
async def _fake_stream(_candidates, messages, **kwargs):
|
||||
yield f'data: {json.dumps({"delta": round_text})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)
|
||||
|
||||
gen = al.stream_agent_loop(
|
||||
"http://x/v1", "m",
|
||||
[{"role": "user", "content": "run it"}],
|
||||
max_rounds=2,
|
||||
relevant_tools={"bash"},
|
||||
)
|
||||
events = _types(_collect(gen))
|
||||
prog = [e for e in events if e.get("type") == "tool_progress"]
|
||||
assert prog, events
|
||||
assert all("secret123" not in (e.get("tail") or "") for e in prog), prog
|
||||
assert any("api_key=[redacted]" in (e.get("tail") or "") for e in prog), prog
|
||||
# Other fields are preserved.
|
||||
assert any(e.get("elapsed_s") == 1 for e in prog), prog
|
||||
@@ -0,0 +1,150 @@
|
||||
import asyncio
|
||||
|
||||
import mcp_servers.memory_server as memory_server
|
||||
from src.memory import MemoryManager
|
||||
|
||||
|
||||
class FakeVector:
|
||||
healthy = True
|
||||
|
||||
def __init__(self):
|
||||
self.added = []
|
||||
self.removed = []
|
||||
|
||||
def add(self, memory_id, text):
|
||||
self.added.append((memory_id, text))
|
||||
|
||||
def remove(self, memory_id):
|
||||
self.removed.append(memory_id)
|
||||
|
||||
|
||||
def _tool_text(arguments):
|
||||
result = asyncio.run(memory_server.call_tool("manage_memory", arguments))
|
||||
return result[0].text
|
||||
|
||||
|
||||
def _entry(manager, text, owner=None, memory_id=None, category="fact"):
|
||||
entry = manager.add_entry(text, owner=owner, category=category)
|
||||
if memory_id:
|
||||
entry["id"] = memory_id
|
||||
return entry
|
||||
|
||||
|
||||
def _configure_server(monkeypatch, manager, vector=None):
|
||||
monkeypatch.setattr(memory_server, "_memory_manager", manager)
|
||||
monkeypatch.setattr(memory_server, "_memory_vector", vector)
|
||||
monkeypatch.setattr(memory_server, "_initialized", True)
|
||||
for key in memory_server._OWNER_ENV_KEYS:
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def test_mcp_memory_uses_configured_owner_for_all_operations(monkeypatch, tmp_path):
|
||||
manager = MemoryManager(str(tmp_path))
|
||||
vector = FakeVector()
|
||||
alice = _entry(
|
||||
manager,
|
||||
"Alice likes green tea",
|
||||
owner="alice",
|
||||
memory_id="aaaaaaaa-0000-0000-0000-000000000000",
|
||||
)
|
||||
bob = _entry(
|
||||
manager,
|
||||
"Bob likes espresso",
|
||||
owner="bob",
|
||||
memory_id="bbbbbbbb-0000-0000-0000-000000000000",
|
||||
)
|
||||
manager.save([alice, bob])
|
||||
_configure_server(monkeypatch, manager, vector)
|
||||
monkeypatch.setenv("ODYSSEUS_MCP_MEMORY_OWNER", "alice")
|
||||
|
||||
list_text = _tool_text({"action": "list"})
|
||||
assert "Alice likes green tea" in list_text
|
||||
assert "Bob likes espresso" not in list_text
|
||||
|
||||
search_text = _tool_text({"action": "search", "text": "likes"})
|
||||
assert "Alice likes green tea" in search_text
|
||||
assert "Bob likes espresso" not in search_text
|
||||
|
||||
add_text = _tool_text({
|
||||
"action": "add",
|
||||
"text": "Alice prefers concise notes",
|
||||
"category": "preference",
|
||||
})
|
||||
assert "Memory added" in add_text
|
||||
added = next(
|
||||
entry for entry in manager.load_all()
|
||||
if entry["text"] == "Alice prefers concise notes"
|
||||
)
|
||||
assert added["owner"] == "alice"
|
||||
assert vector.added == [(added["id"], "Alice prefers concise notes")]
|
||||
|
||||
edit_text = _tool_text({
|
||||
"action": "edit",
|
||||
"memory_id": bob["id"][:8],
|
||||
"text": "Bob changed",
|
||||
})
|
||||
assert edit_text == "Error: Memory 'bbbbbbbb' not found"
|
||||
bob_after_edit = next(
|
||||
entry for entry in manager.load_all()
|
||||
if entry["id"] == bob["id"]
|
||||
)
|
||||
assert bob_after_edit["text"] == "Bob likes espresso"
|
||||
|
||||
delete_text = _tool_text({"action": "delete", "memory_id": bob["id"][:8]})
|
||||
assert delete_text == "Error: Memory 'bbbbbbbb' not found"
|
||||
assert any(entry["id"] == bob["id"] for entry in manager.load_all())
|
||||
|
||||
|
||||
def test_mcp_memory_fails_closed_without_owner_for_owner_scoped_store(monkeypatch, tmp_path):
|
||||
manager = MemoryManager(str(tmp_path))
|
||||
alice = _entry(manager, "Alice private memory", owner="alice", memory_id="aaaaaaaa-0000")
|
||||
bob = _entry(manager, "Bob private memory", owner="bob", memory_id="bbbbbbbb-0000")
|
||||
manager.save([alice, bob])
|
||||
_configure_server(monkeypatch, manager, FakeVector())
|
||||
before = manager.load_all()
|
||||
|
||||
actions = [
|
||||
{"action": "list"},
|
||||
{"action": "search", "text": "private"},
|
||||
{"action": "add", "text": "new ownerless memory"},
|
||||
{"action": "edit", "memory_id": alice["id"][:8], "text": "changed"},
|
||||
{"action": "delete", "memory_id": alice["id"][:8]},
|
||||
]
|
||||
|
||||
for arguments in actions:
|
||||
assert _tool_text(arguments).startswith("Error: Memory MCP owner is not configured")
|
||||
|
||||
assert manager.load_all() == before
|
||||
|
||||
|
||||
def test_mcp_memory_preserves_ownerless_local_behavior(monkeypatch, tmp_path):
|
||||
manager = MemoryManager(str(tmp_path))
|
||||
legacy = _entry(
|
||||
manager,
|
||||
"Legacy local memory",
|
||||
memory_id="llllllll-0000-0000-0000-000000000000",
|
||||
)
|
||||
manager.save([legacy])
|
||||
_configure_server(monkeypatch, manager, FakeVector())
|
||||
|
||||
assert "Legacy local memory" in _tool_text({"action": "list"})
|
||||
assert "Legacy local memory" in _tool_text({"action": "search", "text": "legacy"})
|
||||
|
||||
add_text = _tool_text({"action": "add", "text": "Another local memory"})
|
||||
assert "Memory added" in add_text
|
||||
added = next(
|
||||
entry for entry in manager.load_all()
|
||||
if entry["text"] == "Another local memory"
|
||||
)
|
||||
assert "owner" not in added
|
||||
|
||||
assert _tool_text({
|
||||
"action": "edit",
|
||||
"memory_id": legacy["id"][:8],
|
||||
"text": "Updated local memory",
|
||||
}) == "Memory updated: Updated local memory"
|
||||
assert any(entry["text"] == "Updated local memory" for entry in manager.load_all())
|
||||
|
||||
delete_text = _tool_text({"action": "delete", "memory_id": legacy["id"][:8]})
|
||||
assert delete_text.startswith("Memory deleted:")
|
||||
assert all(entry["id"] != legacy["id"] for entry in manager.load_all())
|
||||
@@ -385,7 +385,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
|
||||
monkeypatch.setattr(chat_helpers, "extract_preset", fake_extract_preset)
|
||||
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
|
||||
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
|
||||
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
|
||||
monkeypatch.setattr(chat_helpers, "effective_user", lambda request: "tester")
|
||||
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
|
||||
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
|
||||
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import sys
|
||||
from unittest import mock
|
||||
import pytest
|
||||
from src.runtime_paths import get_app_root, get_default_data_dir
|
||||
|
||||
|
||||
def test_get_app_root_normal_run():
|
||||
"""Verify that get_app_root returns the repository root parent of src/ when not frozen."""
|
||||
with mock.patch.object(sys, "frozen", False, create=True):
|
||||
app_root = get_app_root()
|
||||
# Verify it is a valid directory path and matches expected parent structure
|
||||
assert os.path.isdir(app_root)
|
||||
assert os.path.exists(os.path.join(app_root, "src"))
|
||||
|
||||
|
||||
def test_get_app_root_frozen_with_meipass():
|
||||
"""Verify that get_app_root returns the sys._MEIPASS directory when frozen by PyInstaller."""
|
||||
mock_meipass = os.path.abspath("mock_meipass_dir")
|
||||
with mock.patch.object(sys, "frozen", True, create=True), \
|
||||
mock.patch.object(sys, "_MEIPASS", mock_meipass, create=True):
|
||||
app_root = get_app_root()
|
||||
assert app_root == mock_meipass
|
||||
|
||||
|
||||
def test_get_app_root_frozen_without_meipass():
|
||||
"""Verify that get_app_root falls back to the sys.executable parent directory when frozen but _MEIPASS is absent."""
|
||||
mock_exe_path = os.path.join(os.path.abspath("mock_exe_dir"), "Odysseus.exe")
|
||||
with mock.patch.object(sys, "frozen", True, create=True), \
|
||||
mock.patch.object(sys, "executable", mock_exe_path, create=True):
|
||||
# Remove sys._MEIPASS if it exists in the test process environment
|
||||
if hasattr(sys, "_MEIPASS"):
|
||||
delattr(sys, "_MEIPASS")
|
||||
app_root = get_app_root()
|
||||
assert app_root == os.path.abspath("mock_exe_dir")
|
||||
|
||||
|
||||
def test_get_default_data_dir_normal():
|
||||
"""Verify that get_default_data_dir resolves to get_app_root() / 'data' when not frozen."""
|
||||
with mock.patch.object(sys, "frozen", False, create=True):
|
||||
res = get_default_data_dir()
|
||||
assert res == os.path.join(get_app_root(), "data")
|
||||
|
||||
|
||||
def test_get_default_data_dir_frozen():
|
||||
"""Verify that get_default_data_dir resolves to a persistent user path under ~ when frozen."""
|
||||
with mock.patch.object(sys, "frozen", True, create=True):
|
||||
res = get_default_data_dir()
|
||||
expected = os.path.join(os.path.expanduser("~"), ".odysseus", "data")
|
||||
assert res == expected
|
||||
@@ -58,7 +58,7 @@ def test_content_fetcher_extracts_og_image_and_body_fallback(module, tmp_path, m
|
||||
|
||||
monkeypatch.setattr(module, "CONTENT_CACHE_DIR", tmp_path)
|
||||
module.content_cache_index.clear()
|
||||
monkeypatch.setattr(module, "_get_public_url", lambda url, headers, timeout: _FakeResponse(html))
|
||||
monkeypatch.setattr(module, "_get_public_url", lambda url, headers, timeout, **kwargs: _FakeResponse(html))
|
||||
|
||||
result = module.fetch_webpage_content("https://example.com/parity-test")
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_fetch_webpage_content_returns_empty_result_on_http_status_error(status_
|
||||
monkeypatch.setattr(
|
||||
service_content,
|
||||
"_get_public_url",
|
||||
lambda url, headers, timeout: _FakeErrorResponse(status_code),
|
||||
lambda url, headers, timeout, **kwargs: _FakeErrorResponse(status_code),
|
||||
)
|
||||
|
||||
result = service_content.fetch_webpage_content(f"https://example.com/status-{status_code}")
|
||||
@@ -119,7 +119,7 @@ def test_fetch_webpage_content_429_takes_distinct_rate_limit_path(tmp_path, monk
|
||||
monkeypatch.setattr(
|
||||
service_content,
|
||||
"_get_public_url",
|
||||
lambda url, headers, timeout: _FakeRateLimitResponse(),
|
||||
lambda url, headers, timeout, **kwargs: _FakeRateLimitResponse(),
|
||||
)
|
||||
|
||||
result = service_content.fetch_webpage_content("https://example.com/rate-limited")
|
||||
|
||||
@@ -904,7 +904,13 @@ def test_web_fetch_guard_blocks_redirect_into_private(monkeypatch):
|
||||
url = "http://public.example/start"
|
||||
headers = {"location": "http://169.254.169.254/latest/meta-data/"}
|
||||
|
||||
monkeypatch.setattr(httpx, "get", lambda url, **kwargs: _Resp())
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def _fake_stream(method, url, **kwargs):
|
||||
yield _Resp()
|
||||
|
||||
monkeypatch.setattr(httpx, "stream", _fake_stream)
|
||||
|
||||
with _pytest.raises(httpx.RequestError) as exc:
|
||||
content._get_public_url("http://public.example/start", headers={}, timeout=5)
|
||||
|
||||
@@ -52,6 +52,6 @@ def test_chat_endpoint_recovery_paths_are_owner_scoped():
|
||||
assert "def _clear_orphaned_session_endpoint(sess, owner:" in chat_routes
|
||||
assert "def _recover_empty_session_model(sess, session_id: str, owner:" in chat_routes
|
||||
assert "q = owner_filter(q, ModelEndpoint, owner)" in chat_routes
|
||||
assert "resolve_session_auth(sess, session, owner=get_current_user(request))" in chat_routes
|
||||
assert "resolve_session_auth(sess, session, owner=effective_user(request))" in chat_routes
|
||||
assert "def resolve_session_auth(sess, session_id: str, owner:" in chat_helpers
|
||||
assert "update_q = update_q.filter(DBSession.owner == owner)" in chat_helpers
|
||||
|
||||
@@ -35,7 +35,7 @@ def _patch_fetch(monkeypatch, text, content_type):
|
||||
monkeypatch.setattr(
|
||||
content_mod,
|
||||
"_get_public_url",
|
||||
lambda url, headers=None, timeout=5: _FakeResponse(text, content_type),
|
||||
lambda url, headers=None, timeout=5, **kwargs: _FakeResponse(text, content_type),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
"""web_fetch download budgets (#3812).
|
||||
|
||||
MAX_OUTPUT_CHARS only trims what the agent sees; these caps bound what the
|
||||
server downloads, parses, and caches. Soft cap by default with a truncation
|
||||
notice, per-call override clamped to the hard cap, and a pre-buffer refusal
|
||||
when Content-Length already exceeds the hard ceiling.
|
||||
"""
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
|
||||
from src.constants import WEB_FETCH_SOFT_MAX_BYTES, WEB_FETCH_HARD_MAX_BYTES
|
||||
from services.search import content as content_mod
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
"""Stands in for the httpx.stream(...) context manager."""
|
||||
|
||||
def __init__(self, body: bytes, content_type="text/plain", content_length=None,
|
||||
status_code=200, chunk=8192):
|
||||
self._body = body
|
||||
self._chunk = chunk
|
||||
self.status_code = status_code
|
||||
self.encoding = "utf-8"
|
||||
self.url = "https://example.com/x"
|
||||
self.headers = {"Content-Type": content_type}
|
||||
if content_length is not None:
|
||||
self.headers["content-length"] = str(content_length)
|
||||
self.body_reads = 0
|
||||
|
||||
def iter_bytes(self):
|
||||
for i in range(0, len(self._body), self._chunk):
|
||||
self.body_reads += 1
|
||||
yield self._body[i:i + self._chunk]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_cache(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(content_mod, "CONTENT_CACHE_DIR", tmp_path)
|
||||
monkeypatch.setattr(content_mod, "_cache_result", lambda *a, **k: None)
|
||||
monkeypatch.setattr(content_mod, "_public_http_url", lambda u: True)
|
||||
|
||||
|
||||
def _patch_stream(monkeypatch, fake):
|
||||
@contextmanager
|
||||
def fake_stream(method, url, **kwargs):
|
||||
yield fake
|
||||
monkeypatch.setattr(content_mod.httpx, "stream", fake_stream)
|
||||
return fake
|
||||
|
||||
|
||||
def test_body_under_cap_is_untouched(monkeypatch, no_cache):
|
||||
_patch_stream(monkeypatch, _FakeStream(b"hello world"))
|
||||
r = content_mod.fetch_webpage_content("https://example.com/a.txt")
|
||||
assert r["success"] is True
|
||||
assert r["content"] == "hello world"
|
||||
assert r["truncated"] is False
|
||||
assert r["fetched_bytes"] == len(b"hello world")
|
||||
|
||||
|
||||
def test_body_over_soft_cap_truncates_with_flags(monkeypatch, no_cache):
|
||||
body = b"x" * (WEB_FETCH_SOFT_MAX_BYTES + 50_000)
|
||||
_patch_stream(monkeypatch, _FakeStream(body, content_length=len(body)))
|
||||
r = content_mod.fetch_webpage_content("https://example.com/big.txt")
|
||||
assert r["truncated"] is True
|
||||
assert r["fetched_bytes"] == WEB_FETCH_SOFT_MAX_BYTES
|
||||
assert r["total_bytes"] == len(body)
|
||||
assert len(r["content"]) == WEB_FETCH_SOFT_MAX_BYTES
|
||||
|
||||
|
||||
def test_max_bytes_override_raises_budget(monkeypatch, no_cache):
|
||||
body = b"y" * (WEB_FETCH_SOFT_MAX_BYTES + 50_000)
|
||||
_patch_stream(monkeypatch, _FakeStream(body))
|
||||
r = content_mod.fetch_webpage_content(
|
||||
"https://example.com/big.txt", max_bytes=len(body) + 1
|
||||
)
|
||||
assert r["truncated"] is False
|
||||
assert r["fetched_bytes"] == len(body)
|
||||
|
||||
|
||||
def test_override_is_clamped_to_hard_cap(monkeypatch, no_cache):
|
||||
# Ask for more than the ceiling; the effective budget must be the ceiling.
|
||||
fake = _patch_stream(monkeypatch, _FakeStream(b"z" * 10, chunk=4))
|
||||
r = content_mod.fetch_webpage_content(
|
||||
"https://example.com/a.txt", max_bytes=WEB_FETCH_HARD_MAX_BYTES * 10
|
||||
)
|
||||
assert r["success"] is True
|
||||
# The clamp itself: effective cap recorded in the cache key path is the
|
||||
# hard cap, and a declared body over the ceiling is refused regardless.
|
||||
big = _FakeStream(b"", content_length=WEB_FETCH_HARD_MAX_BYTES + 1)
|
||||
_patch_stream(monkeypatch, big)
|
||||
r = content_mod.fetch_webpage_content(
|
||||
"https://example.com/huge.bin", max_bytes=WEB_FETCH_HARD_MAX_BYTES * 10
|
||||
)
|
||||
assert r["success"] is False
|
||||
assert "TooLarge" in r["error"]
|
||||
assert big.body_reads == 0 # refused before buffering
|
||||
|
||||
|
||||
def test_declared_over_hard_cap_refused_before_buffering(monkeypatch, no_cache):
|
||||
fake = _FakeStream(b"irrelevant", content_length=WEB_FETCH_HARD_MAX_BYTES + 1)
|
||||
_patch_stream(monkeypatch, fake)
|
||||
r = content_mod.fetch_webpage_content("https://example.com/huge.iso")
|
||||
assert r["success"] is False
|
||||
assert "TooLarge" in r["error"]
|
||||
assert fake.body_reads == 0
|
||||
|
||||
|
||||
def test_truncated_pdf_is_an_error_not_garbage(monkeypatch, no_cache):
|
||||
body = b"%PDF-1.4 " + b"p" * (WEB_FETCH_SOFT_MAX_BYTES + 10)
|
||||
_patch_stream(monkeypatch, _FakeStream(body, content_type="application/pdf"))
|
||||
r = content_mod.fetch_webpage_content("https://example.com/big.pdf")
|
||||
assert r["success"] is False
|
||||
assert "TooLarge" in r["error"]
|
||||
|
||||
|
||||
def test_fetch_requests_identity_encoding(monkeypatch, no_cache):
|
||||
# Compressed responses can decode to far more than Content-Length, so the
|
||||
# streamed cap and the hard-cap preflight are only honest when we refuse
|
||||
# transfer compression. Pin that the fetch advertises identity, not gzip.
|
||||
seen = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_stream(method, url, **kwargs):
|
||||
seen["headers"] = kwargs.get("headers") or {}
|
||||
yield _FakeStream(b"hello")
|
||||
monkeypatch.setattr(content_mod.httpx, "stream", fake_stream)
|
||||
|
||||
content_mod.fetch_webpage_content("https://example.com/a.txt")
|
||||
assert seen["headers"].get("Accept-Encoding") == "identity"
|
||||
|
||||
|
||||
def test_rejects_compressed_response_that_ignored_identity(monkeypatch, no_cache):
|
||||
# We request Accept-Encoding: identity, but a server can ignore it and send
|
||||
# gzip anyway. httpx would decode it, so a tiny compressed body could balloon
|
||||
# past the cap in one decoded chunk. Refuse before reading the body.
|
||||
fake = _FakeStream(b"x" * 5000, content_length=40)
|
||||
fake.headers["content-encoding"] = "gzip"
|
||||
_patch_stream(monkeypatch, fake)
|
||||
r = content_mod.fetch_webpage_content("https://example.com/a.txt")
|
||||
assert r["success"] is False
|
||||
assert "Content-Encoding" in r["error"] or "compressed" in r["error"]
|
||||
assert fake.body_reads == 0 # refused before decoding any body
|
||||
|
||||
|
||||
def test_oversized_title_does_not_hide_partial_notice(monkeypatch):
|
||||
# The partial-content notice is the PR's core contract; an untrusted,
|
||||
# oversized page title must not push it past MAX_OUTPUT_CHARS.
|
||||
import asyncio
|
||||
from src.agent_tools.web_tools import WebFetchTool
|
||||
from src.constants import MAX_OUTPUT_CHARS
|
||||
|
||||
def fake_fetch(url, timeout=10, max_bytes=None):
|
||||
return {
|
||||
"content": "partial body",
|
||||
"title": "T" * (MAX_OUTPUT_CHARS + 5_000),
|
||||
"error": "",
|
||||
"truncated": True,
|
||||
"fetched_bytes": WEB_FETCH_SOFT_MAX_BYTES,
|
||||
"total_bytes": 9_000_000,
|
||||
}
|
||||
|
||||
import src.search.content as alias_mod
|
||||
monkeypatch.setattr(alias_mod, "fetch_webpage_content", fake_fetch)
|
||||
|
||||
out = asyncio.run(WebFetchTool().execute(
|
||||
json.dumps({"url": "https://example.com/big.txt"}), ctx={}
|
||||
))
|
||||
assert out["exit_code"] == 0
|
||||
assert out["output"].startswith("[partial content:")
|
||||
assert '"full": true' in out["output"]
|
||||
|
||||
|
||||
def test_tool_layer_emits_partial_notice_and_parses_full(monkeypatch):
|
||||
import asyncio
|
||||
from src.agent_tools.web_tools import WebFetchTool
|
||||
|
||||
calls = {}
|
||||
|
||||
def fake_fetch(url, timeout=10, max_bytes=None):
|
||||
calls["max_bytes"] = max_bytes
|
||||
return {
|
||||
"content": "partial body",
|
||||
"title": "Big File",
|
||||
"error": "",
|
||||
"truncated": True,
|
||||
"fetched_bytes": WEB_FETCH_SOFT_MAX_BYTES,
|
||||
"total_bytes": 5_000_000,
|
||||
}
|
||||
|
||||
import src.search.content as alias_mod
|
||||
monkeypatch.setattr(alias_mod, "fetch_webpage_content", fake_fetch)
|
||||
|
||||
out = asyncio.run(WebFetchTool().execute(
|
||||
json.dumps({"url": "https://example.com/big.txt"}), ctx={}
|
||||
))
|
||||
assert out["exit_code"] == 0
|
||||
assert "[partial content:" in out["output"]
|
||||
assert '"full": true' in out["output"]
|
||||
assert calls["max_bytes"] is None
|
||||
|
||||
asyncio.run(WebFetchTool().execute(
|
||||
json.dumps({"url": "https://example.com/big.txt", "full": True}), ctx={}
|
||||
))
|
||||
assert calls["max_bytes"] == WEB_FETCH_HARD_MAX_BYTES
|
||||
@@ -0,0 +1,18 @@
|
||||
"""The web scraping path routes its User-Agent through one constant.
|
||||
|
||||
Guards the dedup: web_fetch / web_search outbound UAs go through
|
||||
WEB_FETCH_USER_AGENT, so a stale or bare Mozilla string cannot be re-inlined in
|
||||
the search sources.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
_SEARCH = Path(__file__).resolve().parent.parent / "services" / "search"
|
||||
|
||||
|
||||
def test_search_sources_have_no_inline_mozilla_ua():
|
||||
offenders = [
|
||||
str(py.relative_to(_SEARCH.parent.parent))
|
||||
for py in _SEARCH.rglob("*.py")
|
||||
if "Mozilla/" in py.read_text(encoding="utf-8")
|
||||
]
|
||||
assert not offenders, f"inline Mozilla UA found; use WEB_FETCH_USER_AGENT: {offenders}"
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Guard: every public webhook emitter goes through the manager.
|
||||
|
||||
Public emitters in `routes/` must schedule their fire through
|
||||
`webhook_manager.fire_and_forget(...)` (or `_spawn_tracked`). A bare
|
||||
`asyncio.create_task(webhook_manager.fire(...))` escapes
|
||||
`WebhookManager._bg_tasks`, so asyncio only holds a weak reference to the
|
||||
delivery task and the GC can collect it before it sends — silently dropping
|
||||
the webhook. Catching this with a scan stops a regression from sneaking
|
||||
back in via a copy-paste.
|
||||
"""
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
ROUTES_DIR = Path(__file__).resolve().parent.parent / "routes"
|
||||
|
||||
|
||||
def _untracked_fire_calls(tree: ast.AST) -> list[tuple[int, str]]:
|
||||
"""Return (lineno, snippet) for any asyncio.create_task(webhook_manager.fire(...))."""
|
||||
hits: list[tuple[int, str]] = []
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
func = node.func
|
||||
if not (isinstance(func, ast.Attribute) and func.attr == "create_task"):
|
||||
continue
|
||||
if not (isinstance(func.value, ast.Name) and func.value.id == "asyncio"):
|
||||
continue
|
||||
if not node.args:
|
||||
continue
|
||||
inner = node.args[0]
|
||||
if not isinstance(inner, ast.Call):
|
||||
continue
|
||||
inner_func = inner.func
|
||||
if (
|
||||
isinstance(inner_func, ast.Attribute)
|
||||
and inner_func.attr == "fire"
|
||||
and isinstance(inner_func.value, ast.Name)
|
||||
and inner_func.value.id == "webhook_manager"
|
||||
):
|
||||
hits.append((node.lineno, ast.unparse(node)))
|
||||
return hits
|
||||
|
||||
|
||||
def test_no_untracked_webhook_fire_in_routes():
|
||||
offenders: list[str] = []
|
||||
for path in ROUTES_DIR.rglob("*.py"):
|
||||
tree = ast.parse(path.read_text(), filename=str(path))
|
||||
for lineno, snippet in _untracked_fire_calls(tree):
|
||||
offenders.append(f"{path.relative_to(ROUTES_DIR.parent)}:{lineno}: {snippet}")
|
||||
assert not offenders, (
|
||||
"Public webhook emitters must use webhook_manager.fire_and_forget(...) "
|
||||
"so the delivery task is tracked in WebhookManager._bg_tasks. Found "
|
||||
"untracked emitter(s):\n " + "\n ".join(offenders)
|
||||
)
|
||||
Reference in New Issue
Block a user