10 Commits

Author SHA1 Message Date
Rares Tudor 016157019c fix(tools): use _INTERNAL_BASE in serve-session endpoint registration (#3675)
#3322 renamed the loopback base to _INTERNAL_BASE, but a later Cookbook
commit reintroduced one call site using the old _COOKBOOK_BASE name,
raising NameError whenever the agent registers a model endpoint for a
running serve session.

Fixes #3669
2026-06-09 20:31:29 +02:00
RaresKeY 5d33393a28 fix(gallery): fail closed for null-user owner scope (#3613) 2026-06-09 20:20:21 +02:00
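The fail-closed rule named in this commit can be illustrated without a database. This is a minimal sketch, not the project's `_owner_filter`: `visible_images` and the dict-based image rows are hypothetical, and `auth_disabled` is passed in as a plain flag rather than read from configuration.

```python
from typing import Dict, List, Optional


def visible_images(images: List[Dict[str, str]], user: Optional[str],
                   auth_disabled: bool) -> List[Dict[str, str]]:
    """Hypothetical in-memory analogue of an owner-scoped gallery query."""
    if user is not None:
        # Normal case: scope results to the resolved user.
        return [img for img in images if img["owner"] == user]
    if auth_disabled:
        # Single-user mode: there is no per-user scoping.
        return list(images)
    # Auth is enabled but no user was resolved: fail closed.
    return []


images = [{"id": "1", "owner": "alice"}, {"id": "2", "owner": "bob"}]
print(visible_images(images, "alice", auth_disabled=False))
print(visible_images(images, None, auth_disabled=True))
print(visible_images(images, None, auth_disabled=False))
```

The point of the ordering is that "no user" only widens the result set when auth is explicitly disabled; every other null-user state returns nothing.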
Alexandre Teixeira cdfda4bd16 test: add fast lane and duration visibility (#3659) 2026-06-09 20:11:47 +02:00
Sid 9e74a327f8 fix(llm): remove max_output_tokens from ChatGPT Subscription payload (#3656)
ChatGPT's Codex API rejects any request that includes max_output_tokens,
returning HTTP 400 "Unsupported parameter: max_output_tokens". This caused
Deep Research to always fail during the endpoint probe when a ChatGPT
Subscription model was selected.

Remove the conditional that set payload["max_output_tokens"] in
_build_chatgpt_responses_payload(). The parameter is simply not sent.
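
A minimal sketch of the behaviour described above, assuming a simplified builder. `build_payload`, its parameters, and the payload keys other than `max_output_tokens` are hypothetical stand-ins; the real `_build_chatgpt_responses_payload()` signature is not shown in this log.

```python
from typing import Any, Dict, List, Optional


def build_payload(model: str, messages: List[Dict[str, str]],
                  max_tokens: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical builder: max_tokens is accepted but never forwarded."""
    payload: Dict[str, Any] = {"model": model, "input": messages}
    # Intentionally no payload["max_output_tokens"]: per the commit message,
    # the endpoint answers HTTP 400 whenever that key is present.
    return payload


payload = build_payload("example-model",
                        [{"role": "user", "content": "hi"}],
                        max_tokens=256)
print(sorted(payload))
```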

Also update the two affected tests:
- Rename test_chatgpt_subscription_payload_uses_max_output_tokens →
  test_chatgpt_subscription_payload_omits_max_output_tokens
- Rename test_chatgpt_subscription_payload_omits_empty_max_output_tokens →
  test_chatgpt_subscription_payload_omits_max_output_tokens_when_zero
- Assert max_output_tokens is absent rather than present

Fixes #3650
2026-06-09 17:42:12 +02:00
Ashvin 60d25e0e26 fix(cookbook): use COOKBOOK_STATE_FILE constant for state path (#3623)
The module derived its state file path as Path(os.environ.get("DATA_DIR", "data"))
/ "cookbook_state.json". The correct env var is ODYSSEUS_DATA_DIR, which is
already read by src/constants.py and exported as COOKBOOK_STATE_FILE. When
ODYSSEUS_DATA_DIR is set (Docker, custom installs), the old code read the wrong
env var and silently wrote state to data/cookbook_state.json relative to CWD
while every other file resolved under the custom data directory.
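
The mismatch can be reproduced in isolation. This is a sketch under stated assumptions: both helper names are hypothetical, the environment is passed in as a mapping instead of reading `os.environ`, and the `"data"` fallback for `ODYSSEUS_DATA_DIR` is assumed rather than taken from `src/constants.py`.

```python
from pathlib import Path
from typing import Mapping


def legacy_state_path(env: Mapping[str, str]) -> Path:
    """Pre-fix behaviour: reads DATA_DIR, which custom installs do not set."""
    return Path(env.get("DATA_DIR", "data")) / "cookbook_state.json"


def shared_state_path(env: Mapping[str, str]) -> Path:
    """Assumed shape of the shared constant: derived from ODYSSEUS_DATA_DIR."""
    return Path(env.get("ODYSSEUS_DATA_DIR", "data")) / "cookbook_state.json"


env = {"ODYSSEUS_DATA_DIR": "/srv/odysseus"}
print(legacy_state_path(env))   # falls back to the CWD-relative default
print(shared_state_path(env))   # follows the configured data directory
```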

Fixes #3621
2026-06-09 17:39:06 +02:00
RosenTomov c46d37d876 test(tool_execution): stop two tests leaking src.tool_execution into the suite (#2686)
* Make in-venv pip-fallback test independent of the runner's environment

test_pip_install_fallback_chain_propagates_failure_in_venv simulated the in-venv case by probing the real interpreter (sys.prefix != sys.base_prefix). That assumes the test runner is itself inside a venv. CI runs pytest with no venv, so venv_check reported not-in-venv, the negated guard flipped, the --user branch fired, and the assertion failed. Make venv_check exit 0 directly to simulate the in-venv condition deterministically, mirroring the outside-venv companion test.

* Stop agent-tool import shims from leaking into the admin-gate test

test_function_call_non_object_args and test_unknown_tool_calls stub heavy DB/auth deps at import time to load the real agent-tool stack, but they popped src.tool_execution and left core.auth stubbed without restoring. Popping and re-importing src.tool_execution rebinds the src package's tool_execution attribute, so test_edit_file's later 'import src.tool_execution as te' resolved to a different module object than the one execute_tool_block lives in. The monkeypatch on _owner_is_admin then missed, the non-admin edit_file gate never fired, and the edit went through (exit_code 0). Stop touching src.tool_execution and restore the heavy stubs after import. Verified the full suite is green on Linux (Python 3.11, matching CI).
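
The underlying mechanism (popping a module from `sys.modules` and re-importing it yields a new module object) can be shown with a throwaway module. This is a simplified sketch: `rebind_demo_mod` and `owner_is_admin` are hypothetical names, and it uses a top-level module rather than a package attribute such as `src.tool_execution`.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def demo_rebinding():
    """Import a throwaway module twice and patch only the second copy."""
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "rebind_demo_mod.py").write_text(
            "def owner_is_admin():\n    return True\n", encoding="utf-8"
        )
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()
            first = importlib.import_module("rebind_demo_mod")
            # Popping the cache entry forces the next import to build a new module.
            sys.modules.pop("rebind_demo_mod")
            second = importlib.import_module("rebind_demo_mod")
            # Patching the second object leaves the first one untouched.
            second.owner_is_admin = lambda: False
            return first is second, first.owner_is_admin(), second.owner_is_admin()
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("rebind_demo_mod", None)


same_object, original_result, patched_result = demo_rebinding()
print(same_object, original_result, patched_result)
```

Code that still holds a reference to the first module object never sees the patch, which matches the missed `_owner_is_admin` monkeypatch described above.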

---------

Co-authored-by: Alexandre Teixeira <111787685+alteixeira20@users.noreply.github.com>
2026-06-09 16:35:10 +01:00
Alexandre Teixeira d4ab09e8e1 test: add focused test selection runner (#3556) 2026-06-09 17:03:47 +02:00
Sheikh Rahat Mahmud 9180847c0e feat(diagnostics): add consolidated service health endpoint for degraded-state reporting (#964)
* Add consolidated service health endpoint for degraded-state reporting

ROADMAP (High Priority) asks for "Better degraded-state reporting for
ChromaDB, SearXNG, email, ntfy, and provider probes." Until now there was no
single readout of which subsystems are actually working: /api/health is only a
liveness ping and each subsystem's signal lives in a different module, so a
misconfigured self-host install gives no consolidated picture.

This adds an admin-only GET /api/diagnostics/services endpoint backed by a new
src/service_health.py aggregator. Each subsystem reports a uniform
{name, status, detail, meta} where status is ok | degraded | down | disabled,
and the response rolls up an overall verdict (worst non-disabled status).
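
The rollup rule (worst non-disabled status) might look like the following sketch. The severity ordering and the `"disabled"` result when every subsystem is disabled are assumptions of this example, and `overall_status` is a hypothetical name, not the function in `src/service_health.py`.

```python
from typing import Dict, List

# Assumed severity order for this sketch: a higher number is worse.
_SEVERITY = {"ok": 0, "degraded": 1, "down": 2}


def overall_status(services: List[Dict[str, str]]) -> str:
    """Return the worst status among subsystems that are not disabled."""
    active = [s["status"] for s in services if s["status"] != "disabled"]
    if not active:
        return "disabled"  # assumption: nothing enabled, nothing to report
    return max(active, key=_SEVERITY.__getitem__)


report = [
    {"name": "chromadb", "status": "ok"},
    {"name": "ntfy", "status": "disabled"},
    {"name": "searxng", "status": "degraded"},
]
print(overall_status(report))
```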

Probes are deliberately non-intrusive and safe to poll:
- ChromaDB: reads the .healthy flags on the RAG and memory vector stores.
- SearXNG: GET /healthz (2xx), falling back to the instance root (<500). No
  search query is run.
- ntfy: GET the server's built-in /v1/health. No test notification is sent.
- email: short IMAP connect+logout per configured account (no credentials in
  meta).
- providers: probe each enabled ModelEndpoint's model list (no api_key in meta).

Probe functions take their inputs as parameters and isolate the network call to
injectable callables, so they unit-test without touching the network (same
pattern as the merged provider-endpoint tests). Network probes run concurrently
off the event loop via asyncio.to_thread with bounded per-probe timeouts.

memory_vector is now passed into setup_diagnostics_routes (new optional param,
backward-compatible) so ChromaDB's vector-memory store can be reported too.

Tests: tests/test_service_health.py — 29 tests covering every status mapping
per subsystem, the overall rollup, and that no secrets leak into meta.

Verification:
  python -m pytest tests/test_service_health.py -q          # 29 passed
  python -m py_compile src/service_health.py routes/diagnostics_routes.py app.py
  python -m pytest tests/test_endpoint_resolver.py tests/test_provider_endpoints.py -q

Backend + tests only; an Admin/Settings UI badge that renders this endpoint is
a natural follow-up.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(diagnostics): bound service-health wall-clock and redact secrets

Addresses review on #964.

Blocker 1 — genuinely bounded wall-clock:
- providers_health and email_health now fan out per-item probes across a
  bounded thread pool (_bounded_map) with a hard total budget (_FANOUT_BUDGET),
  instead of probing endpoints/accounts sequentially. Stragglers are reported
  as a controlled `timeout` and never block; the pool is shut down with
  wait=False so the response returns on time regardless of endpoint/account
  count.
- The IMAP connect path now honors the service-health budget: _imap_connect
  gained a pass-through `timeout` param and the probe calls it with
  _PROBE_TIMEOUT instead of the default 15s.
- collect_service_health runs the four network subsystems concurrently, each
  under a per-subsystem deadline (_SUBSYSTEM_DEADLINE), with an overall
  wait_for ceiling (_AGGREGATE_DEADLINE) as a backstop.
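
The fan-out described in Blocker 1 can be sketched as follows. This is not the project's `_bounded_map`: the signature, the `"error"` token, and the worker count are assumptions, and `cancel_futures` requires Python 3.9 or later.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Callable, Dict, Hashable, Iterable


def bounded_map(fn: Callable[[Hashable], str], items: Iterable[Hashable],
                budget_s: float, max_workers: int = 8) -> Dict[Hashable, str]:
    """Hypothetical helper: probe every item, but return within budget_s."""
    pool = ThreadPoolExecutor(max_workers=max_workers)
    futures = {pool.submit(fn, item): item for item in items}
    done, pending = wait(list(futures), timeout=budget_s)
    results: Dict[Hashable, str] = {}
    for fut in done:
        try:
            results[futures[fut]] = fut.result()
        except Exception:
            results[futures[fut]] = "error"  # controlled token, no raw text
    for fut in pending:
        results[futures[fut]] = "timeout"  # straggler: report it, do not block
    # Do not join worker threads; queued work that has not started is dropped.
    pool.shutdown(wait=False, cancel_futures=True)
    return results


def probe(name: str) -> str:
    """Stand-in probe: the 'slow' item sleeps past the budget."""
    time.sleep(1.0 if name == "slow" else 0.0)
    return "ok"


started = time.monotonic()
report = bounded_map(probe, ["fast", "slow"], budget_s=0.2)
elapsed = time.monotonic() - started
print(report, round(elapsed, 2))
```

A thread that is already running cannot be interrupted, so the straggler keeps running in the background; the budget bounds the response time, not the work itself.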

Blocker 2 — no secret/raw-error leakage in the response:
- _safe_url strips userinfo, query, and fragment from every URL surfaced in
  meta (searxng instance, ntfy base, provider name fallback), keeping only
  scheme/host/port/path.
- _classify_error maps every probe failure to a controlled category token
  (timeout, connection_refused, dns_error, tls_error, network_error,
  http_error, auth_or_protocol_error, …) — raw str(exception), which can embed
  credentialed URLs or server text, is never returned.
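
Both redaction steps can be approximated with the standard library. This is a sketch only: `safe_url` and `classify_error` are hypothetical counterparts of `_safe_url` and `_classify_error`, and the mapping from exception types to category tokens is an assumption.

```python
import socket
import ssl
from urllib.parse import urlsplit, urlunsplit


def safe_url(url: str) -> str:
    """Keep scheme/host/port/path; drop userinfo, query, and fragment."""
    parts = urlsplit(url)
    host = parts.hostname or ""
    if ":" in host:  # IPv6 literal needs its brackets back
        host = f"[{host}]"
    netloc = f"{host}:{parts.port}" if parts.port else host
    return urlunsplit((parts.scheme, netloc, parts.path, "", ""))


def classify_error(exc: BaseException) -> str:
    """Map an exception to a fixed token; never return str(exc)."""
    if isinstance(exc, (TimeoutError, socket.timeout)):
        return "timeout"
    if isinstance(exc, ConnectionRefusedError):
        return "connection_refused"
    if isinstance(exc, socket.gaierror):
        return "dns_error"
    if isinstance(exc, ssl.SSLError):
        return "tls_error"
    if isinstance(exc, OSError):
        return "network_error"
    return "auth_or_protocol_error"


print(safe_url("https://user:pw@example.com:8443/search?q=secret#frag"))
print(classify_error(ConnectionRefusedError("https://user:pw@example.com")))
```

The specific subclasses are checked before the generic `OSError` branch because they all derive from it.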

Tests (tests/test_service_health.py, +tests/test_diagnostics_service_route.py):
- URL userinfo/query redaction for searxng/ntfy/providers.
- secret-bearing exception strings map to categories and don't leak.
- multiple slow providers/accounts stay bounded (single + 25-endpoint cases).
- subsystems run concurrently; aggregate deadline yields a controlled result.
- route-level unauthenticated (401) / non-admin (403) / admin (200) coverage.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* test(diagnostics): isolate route tests so they don't leak module globals

The new route tests replaced src.service_health.collect_service_health and
routes.diagnostics_routes.require_admin via direct assignment, which persisted
for the rest of the pytest session. In CI's full alphabetical run that fake
collector (returning services=[]) leaked into the later collect_service_health
tests and failed them. Switch to monkeypatch.setattr so both are restored after
each test. No production code change.
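
The difference between a scoped patch and a direct assignment can be shown with the standard library. This sketch substitutes `unittest.mock.patch.object` for pytest's `monkeypatch.setattr` so that it runs without pytest; `service_health` here is a hypothetical stand-in namespace, not the real module.

```python
import types
from unittest import mock

# Hypothetical stand-in for a module with a collector function.
service_health = types.SimpleNamespace(
    collect=lambda: {"services": ["chromadb"]}
)


def fake_collect():
    return {"services": []}


# Scoped patch: the original attribute is restored when the block exits.
with mock.patch.object(service_health, "collect", fake_collect):
    inside = service_health.collect()
after = service_health.collect()

print(inside, after)
```

A direct `service_health.collect = fake_collect` has no such exit step, so the fake would stay in place for every later test in the session.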

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Co-authored-by: Alexandre Teixeira <111787685+alteixeira20@users.noreply.github.com>
2026-06-09 16:00:24 +01:00
Maanas c1674fc2aa refactor(tools): migrate execution logic to src/agent_tools/ package with handler registry (#3435)
* refactor(tools): implement strict cohesive class coordinator pattern per #2917

* test: update edit_file tests to use EditFileTool class

* fix(tools): restore tool_policy param and security backstop in coordinator

* refactor(tools): migrate domain tools to agent_tools package per #2917

* test: update test imports for new agent_tools package

* fix: resolve circular import between tool_execution and agent_tools

* fix: remove leftover git conflict markers

* fix(tools): resolve pytest failure and document _apply method

* fix(tools): clean up whitespace and remove dead _tool_python helper
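
A handler registry of the kind this refactor introduces might be used as in the sketch below. `EchoTool`, `HANDLERS`, and `dispatch` are hypothetical; only the `execute(content, ctx) -> dict` shape and the `exit_code` convention are taken from the diff further down, and the unknown-tool response is an assumption.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

Handler = Callable[[str, dict], Awaitable[Dict[str, Any]]]


class EchoTool:
    """Hypothetical tool with the same execute() shape as the real handlers."""

    async def execute(self, content: str, ctx: dict) -> dict:
        return {"output": content, "exit_code": 0}


# Registry: tool name -> bound coroutine function.
HANDLERS: Dict[str, Handler] = {"echo": EchoTool().execute}


async def dispatch(name: str, content: str, ctx: dict) -> dict:
    """Look the tool up by name and await its handler."""
    handler = HANDLERS.get(name)
    if handler is None:
        return {"error": f"unknown tool: {name}", "exit_code": 1}
    return await handler(content, ctx)


print(asyncio.run(dispatch("echo", "hello", {})))
print(asyncio.run(dispatch("missing", "", {})))
```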

---------

Co-authored-by: Alexandre Teixeira <111787685+alteixeira20@users.noreply.github.com>
2026-06-09 14:35:36 +01:00
Joshua Valderrama 35b4dd2824 fix: session context drifting — messages leaking between chats (#135) (#267)
* docs: add implementation plan for fixing chat context drifting (#135)

* fix: make Session.history immutable + fix {}.history crash

- Session.history now exposes a COPY of the internal _history list
- add_message() replaces history with a fresh copy each time
- get_context_messages() derives from _history directly
- replace_messages() updates both _history and history
- truncate_messages() updates both _history and history
- _persist_message() line 207: fixed {}.history fallback crash
- Added 11 tests for session isolation and edge cases

Addresses #135 root cause #1: shared mutable references

* fix: task scheduler uses SessionManager methods instead of overwriting sessions

- Added ensure_task_session() to SessionManager (checks cache first)
- Task scheduler now uses ensure_task_session() instead of direct dict assignment
- Task scheduler now uses SessionManager.add_message() for message persistence
- Removed direct sess_obj.history.append() that was silently losing data

Addresses #135 root causes #2 and #3

* fix: add age guard to cleanup_empty_sessions — don't delete sessions <1h old

Prevents the cleanup task from deleting sessions that were just created
and haven't received any messages yet (message_count == 0).
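
The age guard reduces to a single predicate. This is a sketch, not the code in `cleanup_empty_sessions`: `can_delete_empty` is a hypothetical name, and it takes a timezone-aware `now` and normalizes naive timestamps to UTC so both sides of the comparison are aware.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def can_delete_empty(created_at: Optional[datetime], now: datetime,
                     min_age_hours: int = 1) -> bool:
    """Return True if an empty session is old enough to delete."""
    if created_at is None:
        return True  # no timestamp recorded: the guard does not apply
    if created_at.tzinfo is None:
        # Treat naive timestamps as UTC before comparing.
        created_at = created_at.replace(tzinfo=timezone.utc)
    return created_at <= now - timedelta(hours=min_age_hours)


now = datetime(2026, 6, 9, 12, 0, tzinfo=timezone.utc)
print(can_delete_empty(now - timedelta(minutes=5), now))
print(can_delete_empty(now - timedelta(hours=3), now))
```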

Addresses #135 root cause #5

* test: comprehensive session isolation tests (10/10 passing)

* refactor: consolidate _session_manager into singleton pattern

- Added set_session_manager_instance / get_session_manager_instance to core/models
- Kept backward-compat aliases (set_session_manager, get_session_manager)
- session_manager.py re-exports the singleton functions
- ai_interaction.set_session_manager now syncs with the core singleton
- context_compactor uses get_session_manager_instance() instead of getattr hack
- app.py initializes the singleton once

Addresses #135 root cause #4: fragile global wiring

* test: add concurrent session isolation integration tests

Verifies:
- Concurrent add_message to different sessions doesn't cross-contaminate
- Rapid parallel writes maintain isolation
- Read-write concurrent access is safe

All 3 async tests pass, proving the immutable history fix works under concurrency

* fix: pre-import core.models in conftest to prevent test pollution

test_agent_loop.py stubs sys.modules['core.models'] = MagicMock() at
module level during collection. Any test collected after it imports
Session as a MagicMock. Pre-importing core.models in conftest.py
before test_agent_loop.py's module-level code runs prevents this.

* fix: make .history authoritative mutable list, address PR review

Per review feedback: keep .history as the authoritative mutable list so
existing code doing .history.pop(), .history = [...], etc. still works.
Fix the cross-contamination bug by ensuring __post_init__() gives each
Session its OWN unique history list (never shared).

Changes:
- core/models.py: .history IS the authoritative list. _history aliases it.
  Each Session gets its own list in __post_init__.
- core/session_manager.py: add_message() delegates to Session.add_message()
  instead of appending directly — no double-append, single source of truth.
- tests/test_session_manager.py: updated test to reflect that .history
  references see new messages (same list, not a snapshot).
- docs/plans/2026-06-01-fix-chat-context-drifting.md: removed (not for
  shipping — useful design context but too much process/doc to ship).
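
The shape described in these changes can be reduced to a small dataclass. This is a simplified sketch with a hypothetical two-field `Session` and string messages; it only shows the per-instance list and the `_history` alias, not persistence.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Session:
    """Reduced stand-in for core.models.Session (hypothetical fields)."""
    id: str
    history: Optional[List[str]] = None

    def __post_init__(self):
        # Each instance builds its own list; nothing is shared via defaults.
        if self.history is None:
            self.history = []

    @property
    def _history(self) -> List[str]:
        """Alias that always resolves to the authoritative list."""
        return self.history

    @_history.setter
    def _history(self, messages: List[str]) -> None:
        self.history = messages


a, b = Session("a"), Session("b")
a.history.append("hello")
print(a.history, b.history, a._history is a.history)
```

Because `_history` is a property rather than a second attribute, reassigning either name cannot leave the two out of sync.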

All 272 tests pass (3 pre-existing failures unrelated).

* Fix session manager message persistence

* Fix session history alias regressions

* Fix session history aliasing and task delivery
2026-06-09 14:12:52 +01:00
41 changed files with 3348 additions and 839 deletions
+1
@@ -89,3 +89,4 @@ docs/windows-port/
compound.config.json
*.error.log
_scratch/
+/odysseus/
+4 -1
@@ -472,6 +472,9 @@ components = initialize_managers(BASE_DIR, rag_manager)
session_manager = components["session_manager"]
from src.assistant_log import set_session_manager as _set_asst_sm
_set_asst_sm(session_manager)
+# Set the global session manager singleton (used by core.models.Session.add_message)
+from core.models import set_session_manager_instance
+set_session_manager_instance(session_manager)
app.state.session_manager = session_manager
memory_manager = components["memory_manager"]
memory_vector = components.get("memory_vector")
@@ -574,7 +577,7 @@ app.include_router(setup_preset_routes(preset_manager))
# Diagnostics
from routes.diagnostics_routes import setup_diagnostics_routes
-app.include_router(setup_diagnostics_routes(rag_manager, rag_available, research_handler))
+app.include_router(setup_diagnostics_routes(rag_manager, rag_available, research_handler, memory_vector))
# Cleanup
from routes.cleanup_routes import setup_cleanup_routes
+48 -13
@@ -11,14 +11,24 @@ from typing import Dict, List, Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .session_manager import SessionManager
-# Module-level session manager reference (set at app startup)
-_session_manager: Optional["SessionManager"] = None
+# Module-level session manager singleton (single source of truth)
+_SESSION_MANAGER_INSTANCE: Optional["SessionManager"] = None
-def set_session_manager(manager: "SessionManager"):
-    """Set the global session manager reference."""
-    global _session_manager
-    _session_manager = manager
+def set_session_manager_instance(manager: "SessionManager"):
+    """Set the global SessionManager singleton."""
+    global _SESSION_MANAGER_INSTANCE
+    _SESSION_MANAGER_INSTANCE = manager
+def get_session_manager_instance() -> Optional["SessionManager"]:
+    """Get the global SessionManager singleton."""
+    return _SESSION_MANAGER_INSTANCE
+# Keep legacy name for backward compatibility
+set_session_manager = set_session_manager_instance
+get_session_manager = get_session_manager_instance
@dataclass
@@ -42,7 +52,17 @@ class ChatMessage:
@dataclass
class Session:
-"""A chat session — pure data container."""
+"""A chat session — pure data container.
+``.history`` is the authoritative mutable message list. Callers may
+read, append, pop, or reassign it directly — these changes take
+effect immediately. ``_history`` remains a compatibility alias that
+always resolves to the authoritative ``history`` list.
+Each session gets its own unique history list at construction time
+(the dataclass default is never shared between instances).
+"""
id: str
name: str
endpoint_url: str
@@ -56,24 +76,35 @@ class Session:
message_count: int = 0
def __post_init__(self):
-if self.history is None:
-    self.history = []
if self.headers is None:
    self.headers = {}
+# Ensure each session gets its OWN list (not the shared dataclass default)
+if self.history is None:
+    self.history = []
+@property
+def _history(self) -> List[ChatMessage]:
+    """Compatibility alias for callers that still reference ``_history``."""
+    return self.history
+@_history.setter
+def _history(self, messages: List[ChatMessage]):
+    self.history = messages
def add_message(self, message: ChatMessage):
"""
Add a message to this session.
-Delegates to SessionManager for persistence if available,
-otherwise just appends to history.
+Appends to the authoritative history list and increments
+message_count. Delegates to SessionManager for persistence
+if available.
"""
self.history.append(message)
self.message_count = len(self.history)
# Delegate to session manager for persistence
-if _session_manager:
-    _session_manager._persist_message(self.id, message)
+if _SESSION_MANAGER_INSTANCE:
+    _SESSION_MANAGER_INSTANCE._persist_message(self.id, message)
def get_context_messages(self) -> List[Dict[str, Any]]:
"""Get messages in format for LLM API.
@@ -94,3 +125,7 @@ class Session:
def get(self, key: str, default=None):
"""Dict-like access for compatibility."""
return getattr(self, key, default)
+def __getitem__(self, key: str):
+    """Allow session['field'] syntax."""
+    return getattr(self, key)
+45 -4
@@ -17,6 +17,9 @@ from typing import Dict, Optional
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal, utcnow_naive
from .models import Session, ChatMessage
+# Re-export singleton accessors from models for convenience
+from .models import set_session_manager_instance, get_session_manager_instance
logger = logging.getLogger(__name__)
@@ -188,12 +191,17 @@ class SessionManager:
"""
Add a message to a session and persist to database.
+Updates the authoritative history list and persists through this
+manager directly so tests and temporary managers do not depend on the
+process-wide session-manager singleton.
Args:
session_id: Session ID
message: ChatMessage to add
"""
session = self.get_session(session_id)
session.history.append(message)
+session._history = session.history
session.message_count = len(session.history)
self._persist_message(session_id, message)
@@ -232,7 +240,10 @@ class SessionManager:
)
db.add(db_message)
-db_session.message_count = len(self.sessions.get(session_id, {}).history) if session_id in self.sessions else 0
+if session_id in self.sessions:
+    db_session.message_count = len(self.sessions[session_id].history)
+else:
+    db_session.message_count = 0
_now = datetime.now(timezone.utc)
db_session.last_accessed = _now
# Clean "last conversation" timestamp — only bumped here on a
@@ -283,6 +294,7 @@ class SessionManager:
# Update in-memory
session.history = session.history[:keep_count]
+session._history = session.history
logger.info(f"Truncated session {session_id} to {keep_count} messages")
return True
@@ -333,6 +345,7 @@ class SessionManager:
db.commit()
session.history = list(messages)
+session._history = session.history
session.message_count = len(messages)
logger.info("Replaced session %s history with %d messages", session_id, len(messages))
return True
@@ -608,24 +621,52 @@ class SessionManager:
def save_sessions(self):
"""No-op for DB compatibility."""
+def ensure_task_session(self, session_id: str, name: str, endpoint_url: str, model: str, owner: str = None, task: object = None) -> Session:
+    """Create a task session if it doesn't exist, or return the existing one.
+    Unlike create_session, this checks the cache first and does NOT
+    overwrite an existing in-memory session. The task scheduler must
+    use this instead of direct dict assignment.
+    """
+    if session_id in self.sessions:
+        return self.sessions[session_id]
+    session = self.create_session(session_id, name, endpoint_url, model, owner=owner)
+    if task is not None:
+        task.session_id = session_id
+    return session
# ------------------------------------------------------------------
# Cleanup
# ------------------------------------------------------------------
-def cleanup_empty_sessions(self, auto_archive_days: int = 30) -> dict:
-    """Clean up empty and old sessions."""
+def cleanup_empty_sessions(self, auto_archive_days: int = 30, min_age_hours: int = 1) -> dict:
+    """Clean up empty and old sessions.
+    Args:
+        auto_archive_days: Age in days before non-important sessions are archived.
+        min_age_hours: Minimum age in hours before an empty session can be deleted.
+            Prevents deleting sessions that were just created.
+    """
db = SessionLocal()
stats = {'deleted_empty': 0, 'archived_old': 0, 'total_checked': 0}
try:
all_sessions = db.query(DbSession).all()
cutoff_date = utcnow_naive() - timedelta(days=auto_archive_days)
+min_age = utcnow_naive() - timedelta(hours=min_age_hours)
for db_session in all_sessions:
stats['total_checked'] += 1
-# Delete empty sessions
+# Delete empty sessions only if older than min_age_hours
if db_session.message_count == 0:
+    if db_session.created_at is not None:
+        created = db_session.created_at
+        if created.tzinfo is None:
+            created = created.replace(tzinfo=timezone.utc)
+        if created > min_age:
+            continue  # Too young to delete
if db_session.id in self.sessions:
del self.sessions[db_session.id]
db.delete(db_session)
+4
@@ -15,4 +15,8 @@ markers = [
"area_helpers: self-tests for the shared test helpers in tests/helpers/",
"area_unit: pure parser / utility tests that do not clearly belong elsewhere",
"area_uncategorized: tests not yet matched by the taxonomy (fallback)",
+# Fast-lane marker (issue #3443). Opt-in and orthogonal to the area_*/sub_*
+# taxonomy. The fast lane runs `not slow`; mark a test slow only with
+# duration evidence (see tests/run_focus.py --durations and tests/README.md).
+"slow: opt-in marker for known-slow tests; excluded by the fast lane (not slow)",
]
+2 -1
@@ -15,6 +15,7 @@ from pathlib import Path
from fastapi import APIRouter, HTTPException, Request, Depends
from src.auth_helpers import require_user
+from src.constants import COOKBOOK_STATE_FILE
from pydantic import BaseModel
from core.middleware import require_admin
@@ -54,7 +55,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
def setup_cookbook_routes() -> APIRouter:
router = APIRouter(tags=["cookbook"])
-_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
+_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
def _mask_secret(value: str) -> str:
if not value:
+9
@@ -16,9 +16,18 @@ def setup_diagnostics_routes(
rag_manager,
rag_available: bool,
research_handler,
+memory_vector=None,
) -> APIRouter:
router = APIRouter(tags=["diagnostics"])
+@router.get("/api/diagnostics/services")
+async def get_service_health(request: Request) -> Dict[str, Any]:
+    """Consolidated degraded-state report for ChromaDB, SearXNG, email,
+    ntfy, and provider endpoints. Non-intrusive probes safe to poll."""
+    require_admin(request)
+    from src.service_health import collect_service_health
+    return await collect_service_health(rag_manager, memory_vector)
@router.get("/api/db/stats")
async def get_database_stats(request: Request) -> Dict[str, Any]:
require_admin(request)
+6 -2
@@ -762,10 +762,14 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
imaplib._MAXLINE = 50_000_000
return conn
-def _imap_connect(account_id: str | None = None, owner: str = ""):
+def _imap_connect(account_id: str | None = None, owner: str = "",
+                  timeout: int = _IMAP_TIMEOUT_SECONDS):
# SECURITY: passing `owner` scopes the fallback config lookup so a brand
# new user doesn't get connected against another user's default mailbox
# when they have no account configured.
+#
+# `timeout` is overridable so short-lived callers (e.g. the service-health
+# probe) can impose a tighter budget than the default IMAP timeout.
cfg = _get_email_config(account_id, owner=owner)
# Connection mode:
# STARTTLS on → plain + upgrade
@@ -778,7 +782,7 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
cfg["imap_host"],
cfg["imap_port"],
starttls=bool(cfg.get("imap_starttls")),
-timeout=_IMAP_TIMEOUT_SECONDS,
+timeout=timeout,
)
try:
conn.login(cfg["imap_user"], cfg["imap_password"])
+9 -9
@@ -11,6 +11,7 @@ from typing import Dict, Any, Optional
from pydantic import BaseModel
from core.database import GalleryImage
+from src.auth_helpers import _auth_disabled
logger = logging.getLogger(__name__)
@@ -120,19 +121,18 @@ def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any
}
-def _owner_filter(q, user):
+def _owner_filter(q, user, model_cls=GalleryImage):
    """Apply owner filtering to a gallery query.
-    When auth is disabled (single-user mode) get_current_user returns None
-    and there is no per-user scoping. The main library list and stats already
-    treat None as "show everything" (`if user is not None`), so this helper
-    must too otherwise the tag/model filter sidebars come back empty and the
-    tag-cleanup endpoints (clear-user-tags, clear-ai-tags, dedupe-tags)
-    silently affect zero rows in the most common self-hosted deployment.
+    ``get_current_user`` returns None both in auth-disabled single-user mode
+    and when auth is enabled but no current user was resolved. Preserve the
+    single-user behavior, but fail closed for auth-enabled null-user states.
    """
-    if user is None:
+    if user is not None:
+        return q.filter(model_cls.owner == user)
+    if _auth_disabled():
        return q
-    return q.filter(GalleryImage.owner == user)
+    return q.filter(False)
+10 -15
@@ -476,8 +476,7 @@ def setup_gallery_routes() -> APIRouter:
.outerjoin(DbSession, GalleryImage.session_id == DbSession.id)
.filter(GalleryImage.is_active == True)
)
-if user is not None:
-    q = q.filter(GalleryImage.owner == user)
+q = _owner_filter(q, user)
# Search filter (prompt + tags + ai_tags)
if search:
@@ -579,28 +578,26 @@ def setup_gallery_routes() -> APIRouter:
db = SessionLocal()
try:
q = db.query(GalleryAlbum)
-if user:
-    q = q.filter(GalleryAlbum.owner == user)
+q = _owner_filter(q, user, GalleryAlbum)
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
result = []
for a in albums:
_count_q = db.query(GalleryImage).filter(
GalleryImage.album_id == a.id, GalleryImage.is_active == True
)
-if user:
-    _count_q = _count_q.filter(GalleryImage.owner == user)
+_count_q = _owner_filter(_count_q, user)
count = _count_q.count()
cover_url = None
if a.cover_id:
-cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
+cover_q = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id)
+cover = _owner_filter(cover_q, user).first()
if cover:
cover_url = f"/api/generated-image/{cover.filename}"
elif count > 0:
_cover_q = db.query(GalleryImage).filter(
GalleryImage.album_id == a.id, GalleryImage.is_active == True
)
-if user:
-    _cover_q = _cover_q.filter(GalleryImage.owner == user)
+_cover_q = _owner_filter(_cover_q, user)
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
if first:
cover_url = f"/api/generated-image/{first.filename}"
@@ -643,10 +640,9 @@ def setup_gallery_routes() -> APIRouter:
base = db.query(GalleryImage).filter(GalleryImage.is_active == True)
size_q = db.query(func.sum(GalleryImage.file_size)).filter(GalleryImage.is_active == True)
album_q = db.query(GalleryAlbum)
-if user:
-    base = base.filter(GalleryImage.owner == user)
-    size_q = size_q.filter(GalleryImage.owner == user)
-    album_q = album_q.filter(GalleryAlbum.owner == user)
+base = _owner_filter(base, user)
+size_q = _owner_filter(size_q, user)
+album_q = _owner_filter(album_q, user, GalleryAlbum)
total = base.count()
total_size = size_q.scalar() or 0
fav_count = base.filter(GalleryImage.favorite == True).count()
@@ -674,8 +670,7 @@ def setup_gallery_routes() -> APIRouter:
GalleryImage.is_active == True,
(GalleryImage.ai_tags == None) | (GalleryImage.ai_tags == ""),
)
-if user:
-    q = q.filter(GalleryImage.owner == user)
+q = _owner_filter(q, user)
if album_id:
q = q.filter(GalleryImage.album_id == album_id)
untagged = q.count()
@@ -18,6 +18,23 @@ from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
logger = logging.getLogger(__name__)
+from .subprocess_tools import BashTool, PythonTool
+from .web_tools import WebSearchTool, WebFetchTool
+from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool
+TOOL_HANDLERS = {
+    "bash": BashTool().execute,
+    "python": PythonTool().execute,
+    "web_search": WebSearchTool().execute,
+    "web_fetch": WebFetchTool().execute,
+    "read_file": ReadFileTool().execute,
+    "write_file": WriteFileTool().execute,
+    "edit_file": EditFileTool().execute,
+    "ls": LsTool().execute,
+    "glob": GlobTool().execute,
+    "grep": GrepTool().execute,
+}
# ---------------------------------------------------------------------------
# Constants (re-exported for backward compatibility — single source of truth
# is src.constants; always prefer importing from there for new code)
+419
@@ -0,0 +1,419 @@
import asyncio
import json
import os
import difflib
import fnmatch
import shutil
from typing import Optional, Dict, Any, Tuple
from src.constants import MAX_READ_CHARS, MAX_DIFF_LINES, MAX_OUTPUT_CHARS
_CODENAV_SKIP_DIRS = frozenset({
".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
".next", ".cache", "site-packages", ".idea", ".tox",
})
_CODENAV_MAX_HITS = 200
_CODENAV_MAX_LINE = 400
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
if old == new:
return None
old_lines = old.splitlines()
new_lines = new.splitlines()
label = path or "file"
diff_lines = list(difflib.unified_diff(
old_lines, new_lines,
fromfile=f"a/{label}", tofile=f"b/{label}",
lineterm="",
))
added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++"))
removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---"))
truncated = False
if len(diff_lines) > MAX_DIFF_LINES:
diff_lines = diff_lines[:MAX_DIFF_LINES]
truncated = True
text = "\n".join(diff_lines)
if truncated:
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
return {
"text": text,
"added": added,
"removed": removed,
"new_file": old == "",
"file": os.path.basename(path) or (path or "file"),
}
class EditFileTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
try:
args = json.loads(content) if content.strip().startswith("{") else {}
except (json.JSONDecodeError, TypeError):
args = {}
raw_path = (args.get("path") or "").strip()
old = args.get("old_string", "")
new = args.get("new_string", "")
replace_all = bool(args.get("replace_all", False))
if not raw_path:
return {"error": "edit_file: path required", "exit_code": 1}
try:
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"edit_file: {e}", "exit_code": 1}
if old == "":
return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
if old == new:
return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
def _apply():
"""Helper function that performs the actual string replacement and file writing logic."""
with open(path, "r", encoding="utf-8") as f:
original = f.read()
count = original.count(old)
if count == 0:
return original, None, "not_found"
if count > 1 and not replace_all:
return original, None, f"not_unique:{count}"
updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
with open(path, "w", encoding="utf-8") as f:
f.write(updated)
return original, updated, "ok"
try:
original, updated, status = await asyncio.to_thread(_apply)
except FileNotFoundError:
return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
except (IsADirectoryError, UnicodeDecodeError):
return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
except PermissionError:
return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
if status == "not_found":
return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
if status.startswith("not_unique"):
n = status.split(":", 1)[1]
return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
n = original.count(old)
result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
diff = _unified_diff(original, updated, path)
if diff:
result["diff"] = diff
return result
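The match-count rule used by `EditFileTool` can be exercised without touching the filesystem. The following is a minimal sketch, not the tool's own code: `apply_edit` is a hypothetical pure function that works on an in-memory string and returns the same three status values. It assumes `old` is non-empty, which the tool checks before reaching this step.

```python
from typing import Optional, Tuple


def apply_edit(original: str, old: str, new: str,
               replace_all: bool = False) -> Tuple[Optional[str], str]:
    """Return (updated_text, status); updated_text is None unless status is "ok"."""
    count = original.count(old)
    if count == 0:
        return None, "not_found"
    if count > 1 and not replace_all:
        # Ambiguous edit: refuse rather than guess which occurrence was meant.
        return None, f"not_unique:{count}"
    if replace_all:
        return original.replace(old, new), "ok"
    return original.replace(old, new, 1), "ok"


text = "x = 1\ny = 1\n"
unique = apply_edit(text, "x = 1", "x = 2")
ambiguous = apply_edit(text, "= 1", "= 2")
everywhere = apply_edit(text, "= 1", "= 2", replace_all=True)
```

Refusing an ambiguous match forces the caller to add surrounding context, which is the behaviour the `not_unique` error message asks for.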
class ReadFileTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
_stripped = content.strip()
if _stripped.startswith("{"):
try:
_a = json.loads(_stripped)
raw_path = str(_a.get("path", "")).strip()
offset = int(_a.get("offset") or 0)
limit = int(_a.get("limit") or 0)
except (json.JSONDecodeError, TypeError, ValueError):
pass
try:
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"read_file: {e}", "exit_code": 1}
try:
def _read():
if offset > 0 or limit > 0:
start = max(offset, 1)
out, n, budget = [], 0, MAX_READ_CHARS
with open(path, "r", encoding="utf-8", errors="replace") as f:
for i, line in enumerate(f, 1):
if i < start:
continue
if limit > 0 and n >= limit:
break
out.append(line)
n += 1
budget -= len(line)
if budget <= 0:
out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
break
return "".join(out)
with open(path, "r", encoding="utf-8", errors="replace") as f:
return f.read(MAX_READ_CHARS + 1)
data = await asyncio.to_thread(_read)
except FileNotFoundError:
return {"error": f"read_file: {path}: not found", "exit_code": 1}
except PermissionError:
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
except IsADirectoryError:
return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
except OSError as e:
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
return {"output": data, "exit_code": 0}
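The offset/limit branch of `ReadFileTool` streams lines and stops on either a line limit or a character budget. A minimal sketch of that loop follows, assuming a hypothetical `read_window` helper that takes any text stream and a `max_chars` parameter standing in for `MAX_READ_CHARS`.

```python
import io
from typing import List, TextIO


def read_window(stream: TextIO, offset: int = 0, limit: int = 0,
                max_chars: int = 1000) -> str:
    """Return lines starting at 1-based `offset`, at most `limit` lines (0 = no limit)."""
    start = max(offset, 1)
    out: List[str] = []
    taken, budget = 0, max_chars
    for lineno, line in enumerate(stream, 1):
        if lineno < start:
            continue
        if limit > 0 and taken >= limit:
            break
        out.append(line)
        taken += 1
        budget -= len(line)
        if budget <= 0:
            # The line that crosses the budget is kept whole, then reading stops.
            out.append(f"\n... [truncated at {max_chars} chars]")
            break
    return "".join(out)


window = read_window(io.StringIO("one\ntwo\nthree\nfour\n"), offset=2, limit=2)
clipped = read_window(io.StringIO("aaaa\nbbbb\ncccc\n"), offset=1, max_chars=6)
```

Because the stream is consumed line by line, lines before `offset` are skipped without being stored and nothing after the cut-off is read.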
class WriteFileTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
lines = content.split("\n", 1)
raw_path = lines[0].strip()
body = lines[1] if len(lines) > 1 else ""
try:
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"write_file: {e}", "exit_code": 1}
try:
def _write():
old = ""
try:
with open(path, "r", encoding="utf-8") as f:
old = f.read()
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
old = ""
d = os.path.dirname(path)
if d:
os.makedirs(d, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
f.write(body)
return old, len(body)
old_content, size = await asyncio.to_thread(_write)
except PermissionError:
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
diff = _unified_diff(old_content, body, path)
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
if diff:
result["diff"] = diff
return result
class LsTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
raw_path = ""
_s = (content or "").strip()
if _s.startswith("{"):
try:
raw_path = str(json.loads(_s).get("path", "")).strip()
except json.JSONDecodeError:
raw_path = ""
else:
raw_path = _s.split("\n", 1)[0].strip()
try:
root = _resolve_search_root(raw_path)
except ValueError as e:
return {"error": f"ls: {e}", "exit_code": 1}
def _ls():
if not os.path.isdir(root):
return None, f"ls: {root}: not a directory"
rows = []
try:
with os.scandir(root) as it:
for entry in it:
if entry.name.startswith("."):
continue
try:
is_dir = entry.is_dir(follow_symlinks=False)
size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0
except OSError:
continue
rows.append((is_dir, entry.name, size))
except (PermissionError, OSError) as _e:
return None, f"ls: {_e}"
rows.sort(key=lambda r: (not r[0], r[1].lower()))
lines = [f"{root}:"]
for is_dir, name, size in rows[:_CODENAV_MAX_HITS]:
lines.append(f" {name}/" if is_dir else f" {name} ({size} B)")
if len(rows) > _CODENAV_MAX_HITS:
lines.append(f" ... [{len(rows) - _CODENAV_MAX_HITS} more]")
if not rows:
lines.append(" (empty)")
return "\n".join(lines), None
out, err = await asyncio.to_thread(_ls)
if err:
return {"error": err, "exit_code": 1}
return {"output": _truncate(out), "exit_code": 0}
class GlobTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
args = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = json.loads(_s)
except json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "glob: pattern is required", "exit_code": 1}
try:
root = _resolve_search_root(str(args.get("path", "")))
except ValueError as e:
return {"error": f"glob: {e}", "exit_code": 1}
def _glob():
from pathlib import Path
base = Path(root)
if not base.is_dir():
return None, f"glob: {root}: not a directory"
matched = []
try:
for p in base.rglob(pattern):
if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
continue
try:
mtime = p.stat().st_mtime
except OSError:
mtime = 0
matched.append((mtime, str(p)))
if len(matched) > _CODENAV_MAX_HITS * 5:
break
except (OSError, ValueError) as _e:
return None, f"glob: {_e}"
matched.sort(key=lambda t: t[0], reverse=True)
return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
paths, err = await asyncio.to_thread(_glob)
if err:
return {"error": err, "exit_code": 1}
if not paths:
return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(paths)
if len(paths) >= _CODENAV_MAX_HITS:
out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
return {"output": _truncate(out), "exit_code": 0}
class GrepTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import (
_resolve_tool_path,
_resolve_tool_path_in_workspace,
_resolve_search_root,
_truncate
)
workspace = ctx.get("workspace")
args: Dict[str, Any] = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = json.loads(_s)
except json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "grep: pattern is required", "exit_code": 1}
ignore_case = bool(args.get("ignore_case"))
glob_pat = str(args.get("glob", "") or "").strip()
try:
max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
except (TypeError, ValueError):
max_hits = _CODENAV_MAX_HITS
max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
try:
root = _resolve_search_root(str(args.get("path", "")))
except ValueError as e:
return {"error": f"grep: {e}", "exit_code": 1}
def _grep():
import re as _re
import shutil
rg = shutil.which("rg")
if rg:
cmd = [rg, "--line-number", "--no-heading", "--color=never",
"--max-count", str(max_hits)]
if ignore_case:
cmd.append("--ignore-case")
if glob_pat:
cmd += ["--glob", glob_pat]
for _d in _CODENAV_SKIP_DIRS:
cmd += ["--glob", f"!**/{_d}/**"]
cmd += ["--regexp", pattern, root]
try:
import subprocess
p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
return lines, None
except subprocess.TimeoutExpired:
return None, "grep: timed out"
except Exception as _e:
return None, f"grep: {_e}"
try:
rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
except _re.error as _e:
return None, f"grep: bad pattern: {_e}"
hits = []
if os.path.isfile(root):
file_iter = [root]
else:
file_iter = []
for dp, dns, fns in os.walk(root):
dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
for fn in fns:
if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
continue
file_iter.append(os.path.join(dp, fn))
for fp in file_iter:
if len(hits) >= max_hits:
break
try:
with open(fp, "r", encoding="utf-8", errors="strict") as f:
for i, line in enumerate(f, 1):
if rx.search(line):
hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
if len(hits) >= max_hits:
break
except (UnicodeDecodeError, OSError):
continue
return hits, None
lines, err = await asyncio.to_thread(_grep)
if err:
return {"error": err, "exit_code": 1}
if not lines:
return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
if len(lines) >= max_hits:
out += f"\n... [capped at {max_hits} matches]"
return {"output": _truncate(out), "exit_code": 0}
+155
@@ -0,0 +1,155 @@
import asyncio
import sys
import time
import collections
from typing import Optional, Callable, Awaitable, Tuple, Dict
from src.constants import MAX_OUTPUT_CHARS
DEFAULT_BASH_TIMEOUT = 60 * 60 # 1 hour
DEFAULT_PYTHON_TIMEOUT = 60 * 60
PROGRESS_INTERVAL_S = 2.0
PROGRESS_TAIL_LINES = 12
async def _run_subprocess_streaming(
proc: asyncio.subprocess.Process,
*,
timeout: float,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
) -> Tuple[str, str, Optional[int], bool]:
"""Wait for ``proc`` while collecting its stdout and stderr line by line.
Returns ``(stdout, stderr, returncode, timed_out)``. If ``timeout`` elapses the
process is killed and ``timed_out`` is True. When ``progress_cb`` is given it
is awaited every ``PROGRESS_INTERVAL_S`` seconds with the elapsed time and the
last ``PROGRESS_TAIL_LINES`` output lines (stderr lines are prefixed "! ").
On cancellation the process is killed and the CancelledError is re-raised.
"""
started = time.time()
stdout_full: list[str] = []
stderr_full: list[str] = []
tail = collections.deque(maxlen=PROGRESS_TAIL_LINES)
async def _reader(stream, full_buf, label: str):
if stream is None:
return
while True:
line = await stream.readline()
if not line:
break
decoded = line.decode("utf-8", errors="replace").rstrip("\n")
full_buf.append(decoded)
if label == "err":
tail.append(f"! {decoded}")
else:
tail.append(decoded)
async def _progress_emitter():
await asyncio.sleep(PROGRESS_INTERVAL_S)
while True:
if progress_cb:
try:
await progress_cb({
"elapsed_s": round(time.time() - started, 1),
"tail": "\n".join(list(tail)),
})
except Exception:
pass
await asyncio.sleep(PROGRESS_INTERVAL_S)
rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out"))
rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err"))
prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None
timed_out = False
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except asyncio.TimeoutError:
timed_out = True
try:
proc.kill()
except Exception:
pass
try:
await asyncio.wait_for(proc.wait(), timeout=2)
except Exception:
pass
except asyncio.CancelledError:
try:
proc.kill()
except Exception:
pass
try:
await asyncio.wait_for(proc.wait(), timeout=2)
except Exception:
pass
for t in (rd_out, rd_err):
t.cancel()
if prog_task is not None:
prog_task.cancel()
raise
finally:
if prog_task is not None and not prog_task.done():
prog_task.cancel()
try:
await prog_task
except (asyncio.CancelledError, Exception):
pass
for t in (rd_out, rd_err):
try:
await asyncio.wait_for(t, timeout=1)
except Exception:
pass
return (
"\n".join(stdout_full),
"\n".join(stderr_full),
proc.returncode,
timed_out,
)
class BashTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import _AGENT_WORKDIR, _truncate
progress_cb = ctx.get("progress_cb")
workspace = ctx.get("workspace")
_subproc_env = ctx.get("subproc_env")
proc = await asyncio.create_subprocess_shell(
content,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
timeout=DEFAULT_BASH_TIMEOUT,
progress_cb=progress_cb,
)
if timed_out:
return {"error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
output = stdout.rstrip()
err = stderr.rstrip()
if err:
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
output = _truncate(output, MAX_OUTPUT_CHARS)
return {"output": output or "(no output)", "exit_code": rc or 0}
class PythonTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.tool_execution import _AGENT_WORKDIR, _truncate
progress_cb = ctx.get("progress_cb")
workspace = ctx.get("workspace")
_subproc_env = ctx.get("subproc_env")
proc = await asyncio.create_subprocess_exec(
(sys.executable or "python"), "-I", "-c", content,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
timeout=DEFAULT_PYTHON_TIMEOUT,
progress_cb=progress_cb,
)
if timed_out:
return {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
output = stdout.rstrip()
err = stderr.rstrip()
if err:
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
output = _truncate(output, MAX_OUTPUT_CHARS)
return {"output": output or "(no output)", "exit_code": rc or 0}
+101
@@ -0,0 +1,101 @@
import asyncio
import json
from typing import Dict, Any
from src.constants import MAX_OUTPUT_CHARS
class WebSearchTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.search import comprehensive_web_search
raw = content.strip()
query = raw
time_filter = None
max_pages = 5
if raw.startswith("{"):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict) and "query" in parsed:
query = str(parsed.get("query", "")).strip()
tf = parsed.get("time_filter") or parsed.get("freshness")
if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"):
time_filter = tf.lower()
mp = parsed.get("max_pages")
if isinstance(mp, int) and 1 <= mp <= 10:
max_pages = mp
except json.JSONDecodeError:
pass
if not query:
query = raw.split("\n")[0].strip()
if time_filter is None:
q_lc = query.lower()
if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")):
time_filter = "day"
elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")):
time_filter = "week"
elif any(kw in q_lc for kw in ("this month", "past month")):
time_filter = "month"
elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"):
time_filter = "week"
loop = asyncio.get_running_loop()
text, sources = await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: comprehensive_web_search(
query,
max_pages=max_pages,
time_filter=time_filter,
return_sources=True,
),
),
timeout=30,
)
output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text
if sources:
output += "\n\n<!-- SOURCES:" + json.dumps(sources) + " -->"
return {"output": output, "exit_code": 0}
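The freshness heuristic in `WebSearchTool` is a first-match-wins keyword scan. The sketch below isolates it in a hypothetical `infer_time_filter` function so the precedence is easy to test; the keyword lists are copied from the tool, everything else is illustrative.

```python
from typing import Optional

DAY_WORDS = ("today", "latest", "breaking", "this morning", "right now", "currently")
WEEK_WORDS = ("this week", "past week", "recent news", "last few days")
MONTH_WORDS = ("this month", "past month")


def infer_time_filter(query: str) -> Optional[str]:
    """Guess a freshness window from the query text, or None if nothing matches."""
    q = query.lower()
    if any(kw in q for kw in DAY_WORDS):
        return "day"
    if any(kw in q for kw in WEEK_WORDS):
        return "week"
    if any(kw in q for kw in MONTH_WORDS):
        return "month"
    # A bare mention of "news" defaults to the past week.
    if " news" in q or q.startswith("news ") or q.endswith(" news"):
        return "week"
    return None


day = infer_time_filter("latest rust release")
week = infer_time_filter("python news")
none = infer_time_filter("history of rome")
```

The scan uses plain substring checks, so a query that mentions both "breaking" and "this week" resolves to `"day"` because the day keywords are tested first.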
class WebFetchTool:
async def execute(self, content: str, ctx: dict) -> dict:
from src.search.content import fetch_webpage_content
raw = content.strip()
url = ""
if raw.startswith("{"):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
url = str(parsed.get("url") or "").strip()
except json.JSONDecodeError:
url = ""
if not url:
url = raw.split("\n")[0].strip()
if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")):
return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1}
low = url.lower()
if "://" in low and not low.startswith(("http://", "https://")):
return {"error": f"web_fetch: unsupported URL scheme (only http/https): {url[:80]}", "exit_code": 1}
if not low.startswith(("http://", "https://")):
url = "https://" + url
loop = asyncio.get_running_loop()
try:
result = await asyncio.wait_for(
loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)),
timeout=30,
)
except asyncio.TimeoutError:
return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1}
except Exception as e:
return {"error": f"web_fetch: {url}: {e}", "exit_code": 1}
err = result.get("error")
text = (result.get("content") or "").strip()
title = result.get("title") or ""
if not text:
if err:
return {"error": f"web_fetch: {url}: {err}", "exit_code": 1}
return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1}
header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n"
output = header + text
if len(output) > MAX_OUTPUT_CHARS:
output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]"
return {"output": output, "exit_code": 0}
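The URL checks that `WebFetchTool` runs before fetching can be read as a small normalization step. This is a sketch under assumptions: `normalize_fetch_url` is a hypothetical helper that returns `(url, error)` instead of the tool's result dict, and it skips the JSON-argument parsing.

```python
from typing import Optional, Tuple


def normalize_fetch_url(raw: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (url, None) on success or (None, reason) when the input is rejected."""
    url = raw.strip().split("\n")[0].strip()
    if not url or url.startswith("{") or any(c in url for c in (" ", "\t")):
        return None, "provide a single URL or domain"
    low = url.lower()
    if "://" in low and not low.startswith(("http://", "https://")):
        return None, "unsupported URL scheme (only http/https)"
    if not low.startswith(("http://", "https://")):
        # Bare domains are upgraded to https.
        url = "https://" + url
    return url, None


bare = normalize_fetch_url("example.com")
ftp = normalize_fetch_url("ftp://example.com/file")
spaced = normalize_fetch_url("two words")
```

Rejecting non-http schemes up front keeps inputs such as `file://` or `ftp://` URLs from ever reaching the fetcher.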
+7 -1
@@ -24,7 +24,9 @@ MAX_PIPELINE_STEPS = 10
# ---------------------------------------------------------------------------
# Global managers (set from app.py, same pattern as _mcp_manager)
# ---------------------------------------------------------------------------
# _session_manager is kept as a local cache for performance (avoiding
# repeated get_session_manager_instance() calls). It's synced with
# the authoritative singleton in core.models.
_session_manager = None
_memory_manager = None
_memory_vector = None
@@ -33,11 +35,15 @@ _personal_docs_manager = None
def set_session_manager(mgr):
"""Set the global session manager. Syncs local cache + core singleton."""
global _session_manager
_session_manager = mgr
from core.models import set_session_manager_instance
set_session_manager_instance(mgr)
def get_session_manager():
"""Get the global session manager."""
return _session_manager
+2 -2
@@ -438,8 +438,8 @@ def _update_session_history(session, split_point: int, summary: str,
)
new_history = system_prefix + [summary_msg] + recent_history
try:
from core import models as _core_models
manager = getattr(_core_models, "_session_manager", None)
from core.models import get_session_manager_instance
manager = get_session_manager_instance()
except Exception:
manager = None
if manager and getattr(session, "id", None):
+3 -2
@@ -563,8 +563,9 @@ def _build_chatgpt_responses_payload(
}
if not _restricts_temperature(model):
payload["temperature"] = temperature
if max_tokens and max_tokens > 0:
payload["max_output_tokens"] = max_tokens
# ChatGPT Subscription Codex API does not support max_output_tokens —
# passing it returns HTTP 400 "Unsupported parameter: max_output_tokens".
# Do not include it in the payload.
return payload
+506
@@ -0,0 +1,506 @@
"""Consolidated service health / degraded-state reporting.
ROADMAP: "Better degraded-state reporting for ChromaDB, SearXNG, email, ntfy,
and provider probes." There was no single readout of which subsystems are
actually working: `/api/health` is only a liveness ping, and each subsystem's
signal lives in a different module. This collects them into one uniform,
*non-intrusive* report (no test push is sent, no real search is run), so the
admin endpoint built on top of it is safe to poll.
Each probe returns:
{"name": str, "status": "ok"|"degraded"|"down"|"disabled",
"detail": str, "meta": dict}
- ok: reachable / working
- degraded: partially working (one of several components down)
- down: configured & enabled but unreachable / erroring
- disabled: not configured or turned off (not counted as a failure)
Design notes (driven by review feedback):
- **Bounded wall-clock.** Per-item probes (providers, email accounts) fan out
across a bounded thread pool with a hard total budget (`_FANOUT_BUDGET`);
stragglers are reported as a controlled `timeout` rather than blocking. The
aggregate adds a per-subsystem deadline (`_SUBSYSTEM_DEADLINE`) and an overall
ceiling (`_AGGREGATE_DEADLINE`), so the endpoint cannot hang regardless of how
many endpoints/accounts are configured or how slowly they respond.
- **No secret leakage.** Even though the endpoint is admin-only, the response
never returns credential-bearing URLs or raw exception text: URLs are passed
through `_safe_url` (userinfo / query / fragment stripped) and failures are
mapped to controlled categories via `_classify_error`.
The probe functions take their inputs as parameters (settings dict, account
list, endpoint list, manager objects) and isolate the network call to
``_http_get`` / injected callables, so they unit-test without touching the
network.
"""
import asyncio
import concurrent.futures
import logging
import socket
import ssl
import time
from typing import Any, Callable, Dict, List, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# Status ordering for rolling up an overall verdict. "disabled" is excluded —
# a turned-off feature must never drag the overall status down.
_SEVERITY = {"ok": 0, "degraded": 1, "down": 2}
OK = "ok"
DEGRADED = "degraded"
DOWN = "down"
DISABLED = "disabled"
# Timing budgets (seconds). _PROBE_TIMEOUT bounds a single network op;
# _FANOUT_BUDGET bounds a whole fan-out (providers/email) regardless of count;
# the aggregate layer adds a per-subsystem deadline and an overall ceiling.
_PROBE_TIMEOUT = 4
_PROBE_CONCURRENCY = 8
_FANOUT_BUDGET = 8
_SUBSYSTEM_DEADLINE = 10
_AGGREGATE_DEADLINE = 14
# Controlled, secret-free phrasing for each failure category.
_ERROR_DETAIL = {
"timeout": "probe timed out",
"connection_refused": "connection refused",
"dns_error": "host could not be resolved",
"tls_error": "TLS handshake failed",
"network_error": "network error",
"http_error": "server returned an error response",
"auth_or_protocol_error": "authentication or protocol error",
"no_models": "endpoint returned no models",
"no_host": "no host configured",
"error": "probe failed",
}
def _svc(name: str, status: str, detail: str, **meta: Any) -> Dict[str, Any]:
return {"name": name, "status": status, "detail": detail, "meta": dict(meta)}
def _safe_url(url: Optional[str]) -> str:
"""Strip credentials (userinfo), query, and fragment from a URL.
Keeps scheme / host / port / path so the report is still useful, but never
echoes `user:pass@`, `?api_key=`, or `#…` back to the caller. Returns
"<redacted>" if the URL can't be parsed into at least a host.
"""
if not url:
return ""
raw = url.strip()
try:
p = urlparse(raw if "://" in raw else "//" + raw)
host = p.hostname or ""
if not host:
return "<redacted>"
netloc = f"{host}:{p.port}" if p.port else host
path = (p.path or "").rstrip("/")
scheme = f"{p.scheme}://" if p.scheme else ""
return f"{scheme}{netloc}{path}"
except Exception:
return "<redacted>"
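To make the redaction concrete, here is a standalone sketch of the same idea built only on `urllib.parse.urlparse`. The name `safe_url` and the sample URLs are illustrative; the logic mirrors the function above but is not a substitute for it.

```python
from typing import Optional
from urllib.parse import urlparse


def safe_url(url: Optional[str]) -> str:
    """Keep scheme, host, port and path; drop userinfo, query and fragment."""
    if not url:
        return ""
    raw = url.strip()
    try:
        # Prefix "//" so scheme-less input is still parsed as host + path.
        p = urlparse(raw if "://" in raw else "//" + raw)
        host = p.hostname or ""
        if not host:
            return "<redacted>"
        netloc = f"{host}:{p.port}" if p.port else host
        path = (p.path or "").rstrip("/")
        scheme = f"{p.scheme}://" if p.scheme else ""
        return f"{scheme}{netloc}{path}"
    except Exception:
        # An invalid port, for example, raises ValueError when p.port is read.
        return "<redacted>"


with_secret = safe_url("https://user:pass@host:8080/path/?api_key=x#frag")
no_scheme = safe_url("example.com/x")
```

Rebuilding the string from `hostname` and `port` rather than reusing `netloc` is what drops the `user:pass@` part.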
def _classify_error(exc: BaseException) -> str:
"""Map an exception to a controlled, secret-free category token.
Never returns `str(exc)`: httpx/imaplib exception text can embed the target
URL (which may carry credentials) or server-supplied detail.
"""
if isinstance(exc, (asyncio.TimeoutError, concurrent.futures.TimeoutError,
TimeoutError, socket.timeout)):
return "timeout"
name = type(exc).__name__
mod = (type(exc).__module__ or "")
if isinstance(exc, ssl.SSLError) or "SSL" in name or "Certificate" in name:
return "tls_error"
if isinstance(exc, socket.gaierror) or name in ("gaierror", "herror"):
return "dns_error"
if isinstance(exc, ConnectionRefusedError) or "ConnectionRefused" in name \
or name in ("ConnectError",):
return "connection_refused"
if "Timeout" in name:
return "timeout"
if mod.startswith("imaplib") or name in ("error", "abort", "readonly"):
return "auth_or_protocol_error"
if name == "HTTPStatusError":
return "http_error"
if name in ("ConnectTimeout", "ReadTimeout", "ReadError", "WriteError",
"PoolTimeout", "RemoteProtocolError", "NetworkError",
"ProxyError", "ProtocolError"):
return "network_error"
if isinstance(exc, OSError):
return "network_error"
return "error"
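The point of the classifier is that only a fixed token ever leaves the function. A reduced sketch, covering just the standard-library exception types and using a hypothetical `classify` name, shows why check order matters: `ssl.SSLError`, `socket.gaierror`, and `ConnectionRefusedError` are all `OSError` subclasses, so the generic `OSError` test has to come last.

```python
import socket
import ssl


def classify(exc: BaseException) -> str:
    """Map an exception to a fixed category token; never echo str(exc)."""
    if isinstance(exc, (TimeoutError, socket.timeout)):
        return "timeout"
    if isinstance(exc, ssl.SSLError):
        return "tls_error"
    if isinstance(exc, socket.gaierror):
        return "dns_error"
    if isinstance(exc, ConnectionRefusedError):
        return "connection_refused"
    if isinstance(exc, OSError):
        # Catch-all for the remaining socket-level failures.
        return "network_error"
    return "error"


leaky = ValueError("GET https://user:secret@host failed")
category = classify(leaky)
```

Even when the exception message carries a credential-bearing URL, the caller only sees `category`.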
def _detail_for(category: str) -> str:
return _ERROR_DETAIL.get(category, _ERROR_DETAIL["error"])
def _http_get(url: str, timeout: float = _PROBE_TIMEOUT):
"""Single network entry point for the HTTP probes (monkeypatched in tests)."""
import httpx
return httpx.get(url, timeout=timeout)
def _bounded_map(items: List[Any], worker: Callable[[int, Any], Dict[str, Any]],
*, budget: float = _FANOUT_BUDGET,
concurrency: int = _PROBE_CONCURRENCY) -> List[Optional[Dict[str, Any]]]:
"""Run ``worker(index, item)`` across a bounded thread pool, in order.
`worker` must catch its own exceptions and return a per-item dict. Any item
not finished within `budget` seconds *in total* is left as ``None`` (the
caller substitutes a controlled `timeout` entry). The pool is shut down with
``wait=False`` so stragglers never block the response; their own per-op
timeout reaps them shortly after.
"""
n = len(items)
out: List[Optional[Dict[str, Any]]] = [None] * n
if n == 0:
return out
ex = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(concurrency, n)))
futures = {ex.submit(worker, i, items[i]): i for i in range(n)}
try:
for fut in concurrent.futures.as_completed(futures, timeout=budget):
i = futures[fut]
try:
out[i] = fut.result()
except Exception as e: # worker is expected to handle its own errors
out[i] = {"ok": False, "error": _classify_error(e)}
except concurrent.futures.TimeoutError:
pass # unfinished items stay None → marked timeout by the caller
finally:
ex.shutdown(wait=False, cancel_futures=True)
return out
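The total-budget behaviour is easiest to see with a deliberately slow item. The sketch below is a trimmed restatement of the pattern, assuming Python 3.9+ for `cancel_futures`; `bounded_map`, `slow_probe`, and the delays are illustrative names and values, not part of the module.

```python
import concurrent.futures
import time
from typing import Any, Callable, Dict, List, Optional


def bounded_map(items: List[Any], worker: Callable[[int, Any], Dict[str, Any]],
                *, budget: float, concurrency: int = 8
                ) -> List[Optional[Dict[str, Any]]]:
    """Run worker(index, item) in a thread pool; unfinished slots stay None."""
    out: List[Optional[Dict[str, Any]]] = [None] * len(items)
    if not items:
        return out
    ex = concurrent.futures.ThreadPoolExecutor(
        max_workers=max(1, min(concurrency, len(items))))
    futures = {ex.submit(worker, i, item): i for i, item in enumerate(items)}
    try:
        # The timeout applies to the whole iteration, not to each future.
        for fut in concurrent.futures.as_completed(futures, timeout=budget):
            try:
                out[futures[fut]] = fut.result()
            except Exception:
                out[futures[fut]] = {"ok": False, "error": "error"}
    except concurrent.futures.TimeoutError:
        pass  # slots for unfinished items remain None
    finally:
        ex.shutdown(wait=False, cancel_futures=True)
    return out


def slow_probe(_index: int, delay: float) -> Dict[str, Any]:
    time.sleep(delay)
    return {"ok": True, "delay": delay}


results = bounded_map([0.0, 1.5, 0.0], slow_probe, budget=0.4)
```

The slow worker keeps running in its thread after `bounded_map` returns; `shutdown(wait=False)` only stops the caller from waiting on it, which is why each probe also needs its own per-operation timeout.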
# ── ChromaDB (vector RAG + vector memory) ──
def chromadb_health(rag_manager: Any, memory_vector: Any) -> Dict[str, Any]:
"""Report on the two ChromaDB-backed stores via their `.healthy` flags.
Both absent → disabled (Chroma/embeddings not installed or off).
Both healthy → ok. One down → degraded. Both present but unhealthy → down.
"""
rag_present = rag_manager is not None
mem_present = memory_vector is not None
if not rag_present and not mem_present:
return _svc("chromadb", DISABLED,
"Vector RAG and vector memory are not initialized.",
rag=None, memory=None)
rag_ok = bool(rag_present and getattr(rag_manager, "healthy", False))
mem_ok = bool(mem_present and getattr(memory_vector, "healthy", False))
meta = {"rag": rag_ok if rag_present else None,
"memory": mem_ok if mem_present else None}
healthy = [ok for ok in (rag_ok if rag_present else None,
mem_ok if mem_present else None) if ok is not None]
if healthy and all(healthy):
return _svc("chromadb", OK, "Vector stores healthy.", **meta)
if any(healthy):
return _svc("chromadb", DEGRADED,
"One vector store is unavailable.", **meta)
return _svc("chromadb", DOWN, "Vector stores are unavailable.", **meta)
# ── SearXNG ──
def _searxng_instance(settings: Dict[str, Any]) -> str:
"""Mirror src/search/providers.py:_get_search_instance precedence."""
url = (settings.get("search_url") or "").strip()
if url:
return url.rstrip("/")
from src.constants import SEARXNG_INSTANCE
return SEARXNG_INSTANCE.rstrip("/")
def searxng_health(settings: Dict[str, Any],
*, http_get: Callable = _http_get) -> Dict[str, Any]:
"""Non-intrusive reachability probe for the configured SearXNG instance.
Tries `/healthz` (2xx), falling back to the instance root (any non-5xx means
the host answered). No search query is run. The configured instance is
probed in full, but only its sanitized form is returned in `meta`.
"""
provider = (settings.get("search_provider") or "searxng")
if provider != "searxng":
return _svc("searxng", DISABLED,
f"Search provider is '{provider}', not SearXNG.",
provider=provider)
instance = _searxng_instance(settings)
if not instance:
return _svc("searxng", DISABLED, "No SearXNG instance configured.")
safe_instance = _safe_url(instance)
last_category = "error"
for path, accept in (("/healthz", lambda c: 200 <= c < 300),
("/", lambda c: 0 < c < 500)):
try:
r = http_get(instance + path, timeout=_PROBE_TIMEOUT)
code = getattr(r, "status_code", 0)
if accept(code):
return _svc("searxng", OK, f"Reachable (HTTP {code}).",
instance=safe_instance, probed=path, http_status=code)
last_category = "http_error"
except Exception as e: # connection refused, DNS, timeout, …
last_category = _classify_error(e)
return _svc("searxng", DOWN, f"Unreachable ({_detail_for(last_category)}).",
instance=safe_instance, error=last_category)
# ── ntfy ──
def _ntfy_integration(integrations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""First enabled ntfy integration with a base_url (matches note_routes)."""
for i in integrations or []:
if (i.get("preset") == "ntfy" and i.get("enabled", True)
and i.get("base_url")):
return i
return None
def ntfy_health(integrations: List[Dict[str, Any]], settings: Dict[str, Any],
*, http_get: Callable = _http_get) -> Dict[str, Any]:
"""Non-intrusive ntfy probe via the server's built-in `/v1/health` route.
No test notification is POSTed; `/v1/health` returns `{"healthy":true}`
without publishing to a topic. The request keeps whatever credentials the
configured base_url carries, but `meta.base` is sanitized.
"""
channel = settings.get("reminder_channel") or "browser"
intg = _ntfy_integration(integrations)
if not intg:
return _svc("ntfy", DISABLED, "No ntfy integration configured.",
reminder_channel=channel)
raw = (intg.get("base_url") or "").strip()
parsed = urlparse(raw)
probe_base = (f"{parsed.scheme}://{parsed.netloc}"
if parsed.scheme and parsed.netloc else raw.rstrip("/"))
safe_base = _safe_url(raw)
try:
r = http_get(probe_base + "/v1/health", timeout=_PROBE_TIMEOUT)
code = getattr(r, "status_code", 0)
if code and code < 500:
return _svc("ntfy", OK, f"Reachable (HTTP {code}).",
base=safe_base, reminder_channel=channel, http_status=code)
return _svc("ntfy", DOWN, "Server returned an error response.",
base=safe_base, reminder_channel=channel, error="http_error")
except Exception as e:
category = _classify_error(e)
return _svc("ntfy", DOWN, f"Unreachable ({_detail_for(category)}).",
base=safe_base, reminder_channel=channel, error=category)
# ── Email (IMAP) ──
def email_health(accounts: List[Dict[str, Any]],
*, connect: Optional[Callable] = None) -> Dict[str, Any]:
"""Try a short IMAP connect+logout per configured account, concurrently.
All connect → ok. Some fail → degraded. All fail → down. No account
configured → disabled. Bounded by `_FANOUT_BUDGET` regardless of count.
`meta` carries only the account label and a controlled error category,
never credentials or raw exception text.
"""
if not accounts:
return _svc("email", DISABLED, "No email accounts configured.")
if connect is None:
from routes.email_helpers import _imap_connect
# Impose the service-health budget on the IMAP connect itself.
connect = lambda aid: _imap_connect(aid, timeout=_PROBE_TIMEOUT) # noqa: E731
def _label(acc: Dict[str, Any]) -> str:
return acc.get("account_name") or acc.get("account_id") or "account"
def _check(_i: int, acc: Dict[str, Any]) -> Dict[str, Any]:
name = _label(acc)
if not (acc.get("imap_host") or ""):
return {"name": name, "ok": False, "error": "no_host"}
try:
conn = connect(acc.get("account_id"))
try:
conn.logout()
except Exception:
pass
return {"name": name, "ok": True, "error": None}
except Exception as e:
return {"name": name, "ok": False, "error": _classify_error(e)}
raw = _bounded_map(accounts, _check, budget=_FANOUT_BUDGET,
concurrency=_PROBE_CONCURRENCY)
per_account = [r if r is not None
else {"name": _label(accounts[i]), "ok": False, "error": "timeout"}
for i, r in enumerate(raw)]
return _rollup_items("email", "mailbox(es)", per_account)
# ── Provider endpoints ──
def providers_health(endpoints: List[Dict[str, Any]],
*, probe: Optional[Callable] = None) -> Dict[str, Any]:
"""Probe each enabled model endpoint's model list, concurrently.
`endpoints` is a list of plain dicts ({name, base_url, api_key}) so this
stays decoupled from the ORM and trivially testable. Non-empty model list →
reachable. Bounded by `_FANOUT_BUDGET` regardless of count. `meta` never
contains api_key or raw URLs, only a display name (or a sanitized URL when
no name is set) and a controlled error category.
"""
if not endpoints:
return _svc("providers", DISABLED, "No model endpoints configured.")
if probe is None:
from routes.model_routes import _probe_endpoint as probe
def _label(ep: Dict[str, Any]) -> str:
return ep.get("name") or _safe_url(ep.get("base_url")) or "endpoint"
def _check(_i: int, ep: Dict[str, Any]) -> Dict[str, Any]:
name = _label(ep)
try:
models = probe(ep.get("base_url"), ep.get("api_key"),
timeout=_PROBE_TIMEOUT) or []
except Exception as e:
return {"name": name, "ok": False, "model_count": 0,
"error": _classify_error(e)}
count = len(models)
return {"name": name, "ok": bool(count), "model_count": count,
"error": None if count else "no_models"}
raw = _bounded_map(endpoints, _check, budget=_FANOUT_BUDGET,
concurrency=_PROBE_CONCURRENCY)
per_endpoint = [r if r is not None
else {"name": _label(endpoints[i]), "ok": False,
"model_count": 0, "error": "timeout"}
for i, r in enumerate(raw)]
return _rollup_items("providers", "endpoint(s)", per_endpoint, key="endpoints")
def _rollup_items(name: str, noun: str, items: List[Dict[str, Any]],
key: str = "accounts") -> Dict[str, Any]:
"""Shared ok/degraded/down rollup for a list of per-item probe results."""
total = len(items)
ok_count = sum(1 for it in items if it.get("ok"))
if ok_count == total:
status, detail = OK, f"{ok_count}/{total} {noun} reachable."
elif ok_count == 0:
status, detail = DOWN, f"No {noun} reachable."
else:
status, detail = DEGRADED, f"{ok_count}/{total} {noun} reachable."
return _svc(name, status, detail, **{key: items})
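The rollup rule above can be tried in isolation. The following is a minimal sketch, not the module's code: the status strings and the plain-dict return value are assumptions standing in for the `OK`/`DEGRADED`/`DOWN` constants and the `_svc` helper, whose definitions are not shown in this diff.

```python
from typing import Any, Dict, List

# Assumed status labels; the real constants are defined elsewhere in the module.
OK, DEGRADED, DOWN = "ok", "degraded", "down"


def rollup_items(noun: str, items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """All items ok -> ok, no item ok -> down, anything in between -> degraded."""
    total = len(items)
    ok_count = sum(1 for it in items if it.get("ok"))
    if ok_count == total:
        status, detail = OK, f"{ok_count}/{total} {noun} reachable."
    elif ok_count == 0:
        status, detail = DOWN, f"No {noun} reachable."
    else:
        status, detail = DEGRADED, f"{ok_count}/{total} {noun} reachable."
    return {"status": status, "detail": detail, "items": items}


mixed = rollup_items("mailbox(es)", [{"ok": True}, {"ok": False}])
print(mixed["status"], "-", mixed["detail"])
```

Note that an empty list satisfies `ok_count == total`, so it rolls up as ok; the callers above return a disabled entry before reaching the rollup when nothing is configured.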
# ── Aggregate ──
def _rollup(services: List[Dict[str, Any]]) -> str:
worst = OK
for s in services:
sev = _SEVERITY.get(s.get("status"))
if sev is not None and sev > _SEVERITY[worst]:
worst = s["status"]
return worst
def _gather_inputs() -> Dict[str, Any]:
"""Pull live config/account/endpoint lists from the app's data sources.
Each lookup fails soft: a broken source yields an empty/neutral value so a
single failure can't take down the whole health report.
"""
settings: Dict[str, Any] = {}
integrations: List[Dict[str, Any]] = []
accounts: List[Dict[str, Any]] = []
endpoints: List[Dict[str, Any]] = []
try:
from src.settings import load_settings
settings = load_settings() or {}
except Exception as e:
logger.debug(f"service_health: settings load failed: {e}")
try:
from src.integrations import load_integrations
integrations = load_integrations() or []
except Exception as e:
logger.debug(f"service_health: integrations load failed: {e}")
try:
from routes.email_helpers import _list_email_accounts
accounts = _list_email_accounts() or []
except Exception as e:
logger.debug(f"service_health: email accounts load failed: {e}")
try:
from core.database import SessionLocal, ModelEndpoint
db = SessionLocal()
try:
rows = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True).all() # noqa: E712
endpoints = [{"name": r.name, "base_url": r.base_url,
"api_key": r.api_key} for r in rows]
finally:
db.close()
except Exception as e:
logger.debug(f"service_health: endpoint load failed: {e}")
return {"settings": settings, "integrations": integrations,
"accounts": accounts, "endpoints": endpoints}
async def _run_subsystem(name: str, fn: Callable, *args: Any) -> Dict[str, Any]:
"""Run one (sync) subsystem probe in a thread under a hard deadline.
A subsystem that overruns `_SUBSYSTEM_DEADLINE` (or raises) becomes a
controlled `down`/`timeout` entry instead of hanging or leaking the error.
"""
try:
return await asyncio.wait_for(asyncio.to_thread(fn, *args),
timeout=_SUBSYSTEM_DEADLINE)
except asyncio.TimeoutError:
return _svc(name, DOWN, _detail_for("timeout"), error="timeout")
except Exception as e:
category = _classify_error(e)
return _svc(name, DOWN, _detail_for(category), error=category)
async def collect_service_health(rag_manager: Any = None,
memory_vector: Any = None) -> Dict[str, Any]:
"""Run every probe and return {overall, services, timestamp}.
Bounded end-to-end: in-process ChromaDB flags are read synchronously; the
four network subsystems run concurrently, each under `_SUBSYSTEM_DEADLINE`,
with an overall `_AGGREGATE_DEADLINE` backstop. Per-item probes inside
providers/email are themselves bounded by `_FANOUT_BUDGET`.
"""
from datetime import datetime, timezone
inputs = _gather_inputs()
settings = inputs["settings"]
# ChromaDB is in-process and synchronous (just reads flags).
chroma = chromadb_health(rag_manager, memory_vector)
names = ["searxng", "ntfy", "email", "providers"]
coros = [
_run_subsystem("searxng", searxng_health, settings),
_run_subsystem("ntfy", ntfy_health, inputs["integrations"], settings),
_run_subsystem("email", email_health, inputs["accounts"]),
_run_subsystem("providers", providers_health, inputs["endpoints"]),
]
try:
results = await asyncio.wait_for(asyncio.gather(*coros),
timeout=_AGGREGATE_DEADLINE)
except asyncio.TimeoutError:
# Hard backstop — should not normally fire given per-subsystem deadlines.
results = [_svc(n, DOWN, _detail_for("timeout"), error="timeout")
for n in names]
services = [chroma, *results]
return {
"overall": _rollup(services),
"services": services,
# Timezone-aware UTC (…+00:00). Avoids the deprecated naive
# datetime.utcnow() flagged in review (overlaps with #1116).
"timestamp": datetime.now(timezone.utc).isoformat(),
}
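The per-subsystem deadline pattern used by `_run_subsystem` can be sketched on its own. This is an illustrative sketch under assumptions: the deadline value, the probe functions, and the result dict shape are hypothetical, and only `asyncio.wait_for` plus `asyncio.to_thread` mirror the code above.

```python
import asyncio
import time
from typing import Any, Callable, Dict, List

SUBSYSTEM_DEADLINE = 0.2  # hypothetical value in seconds, chosen for the demo


async def run_subsystem(name: str, fn: Callable[..., Dict[str, Any]],
                        *args: Any) -> Dict[str, Any]:
    """Run a blocking probe in a worker thread and cap how long we wait for it."""
    try:
        return await asyncio.wait_for(asyncio.to_thread(fn, *args),
                                      timeout=SUBSYSTEM_DEADLINE)
    except asyncio.TimeoutError:
        return {"name": name, "status": "down", "error": "timeout"}
    except Exception:
        # Report a controlled category instead of the raw exception text.
        return {"name": name, "status": "down", "error": "error"}


def fast_probe() -> Dict[str, Any]:
    return {"name": "fast", "status": "ok"}


def slow_probe() -> Dict[str, Any]:
    time.sleep(0.5)  # overruns the deadline on purpose
    return {"name": "slow", "status": "ok"}


async def collect() -> List[Dict[str, Any]]:
    # Both probes run concurrently; the slow one is reported as a timeout.
    return await asyncio.gather(
        run_subsystem("fast", fast_probe),
        run_subsystem("slow", slow_probe),
    )


if __name__ == "__main__":
    print(asyncio.run(collect()))
```

One caveat of this design: `wait_for` stops waiting, but it cannot interrupt the worker thread, which keeps running until the blocking call returns. That is why the probes themselves also carry their own timeouts.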
+50 -29
@@ -1324,7 +1324,10 @@ class TaskScheduler:
db.commit()
if self._session_manager:
try:
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
self._session_manager.ensure_task_session(
session_id, f"[Task] {task.name}", endpoint_url, model,
owner=task.owner, task=task
)
except Exception:
pass
@@ -1417,6 +1420,7 @@ class TaskScheduler:
task's visible output target.
"""
from core.database import Session as DbSession, ChatMessage, CrewMember
from core.models import ChatMessage as MemChatMessage
output = task.output_target or "session"
if (
@@ -1473,7 +1477,10 @@ class TaskScheduler:
db.commit()
if self._session_manager:
try:
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
self._session_manager.ensure_task_session(
session_id, f"[Task] {task.name}", endpoint_url, model_name,
owner=task.owner, task=task
)
except Exception:
pass
@@ -1482,36 +1489,50 @@ class TaskScheduler:
meta["model"] = model_name
if crew and crew.is_default_assistant:
meta.update({"source": "cron", "task_id": task.id, "task_name": task.name})
msg_meta = json.dumps(meta)
user_content = task.prompt or f"[Task] {task.name}"
user_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="user",
content=user_content,
timestamp=_utcnow(),
meta_data=msg_meta,
)
assistant_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="assistant",
content=result or "",
timestamp=_utcnow(),
meta_data=msg_meta,
)
db.add(user_msg)
db.add(assistant_msg)
db.commit()
if self._session_manager:
# Use SessionManager for persistence so in-memory cache stays in sync
if self._session_manager and session_id:
try:
from core.models import ChatMessage as MemMsg
sess_obj = self._session_manager.get_session(session_id)
sess_obj.history.append(MemMsg(role="user", content=user_msg.content, metadata=meta))
sess_obj.history.append(MemMsg(role="assistant", content=assistant_msg.content, metadata=meta))
self._session_manager.add_message(
session_id,
MemChatMessage(
"user",
task.prompt or f"[Task] {task.name}",
metadata=dict(meta),
),
)
self._session_manager.add_message(
session_id,
MemChatMessage(
"assistant",
result or "",
metadata=dict(meta),
),
)
except Exception:
pass
logger.exception("Failed to deliver task %s through SessionManager", task.id)
else:
# Fallback: raw DB write (no session manager available)
msg_meta = json.dumps(meta)
user_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="user",
content=task.prompt or f"[Task] {task.name}",
timestamp=_utcnow(),
meta_data=msg_meta,
)
assistant_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="assistant",
content=result or "",
timestamp=_utcnow(),
meta_data=msg_meta,
)
db.add(user_msg)
db.add(assistant_msg)
db.commit()
@staticmethod
def _is_email_output_target(output: str) -> bool:
+61 -702
@@ -18,6 +18,8 @@ import sys
import time
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
from src.tool_security import is_public_blocked_tool, owner_is_admin_or_single_user
from src.tool_policy import ToolPolicy
from src.constants import MAX_OUTPUT_CHARS, MAX_READ_CHARS, MAX_DIFF_LINES, DATA_DIR
@@ -31,105 +33,6 @@ from src.tool_utils import _truncate, get_mcp_manager
_AGENT_WORKDIR = DATA_DIR
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
"""Build a unified diff of a file write for display in the chat.
Returns {"text": <unified diff>, "added": N, "removed": M, "new_file": bool}
or None when there's no textual change. Truncates very large diffs.
"""
if old == new:
return None
import difflib
old_lines = old.splitlines()
new_lines = new.splitlines()
label = path or "file"
diff_lines = list(difflib.unified_diff(
old_lines, new_lines,
fromfile=f"a/{label}", tofile=f"b/{label}",
lineterm="",
))
added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++"))
removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---"))
truncated = False
if len(diff_lines) > MAX_DIFF_LINES:
diff_lines = diff_lines[:MAX_DIFF_LINES]
truncated = True
text = "\n".join(diff_lines)
if truncated:
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
return {
"text": text,
"added": added,
"removed": removed,
"new_file": old == "",
"file": os.path.basename(path) or (path or "file"),
}
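The counting logic in `_unified_diff` relies on the standard `difflib.unified_diff` output, where file headers start with `---`/`+++` and changed lines start with a single `-`/`+`. A minimal sketch follows; the `diff_stats` name and its tuple return value are hypothetical and only illustrate the counting step.

```python
import difflib
from typing import Tuple


def diff_stats(old: str, new: str, label: str = "file") -> Tuple[str, int, int]:
    """Return (diff text, added line count, removed line count)."""
    lines = list(difflib.unified_diff(
        old.splitlines(), new.splitlines(),
        fromfile=f"a/{label}", tofile=f"b/{label}",
        lineterm="",
    ))
    # Skip the "+++"/"---" file headers when counting changed lines.
    added = sum(1 for ln in lines
                if ln.startswith("+") and not ln.startswith("+++"))
    removed = sum(1 for ln in lines
                  if ln.startswith("-") and not ln.startswith("---"))
    return "\n".join(lines), added, removed


text, added, removed = diff_stats("a\nb\nc", "a\nB\nc\nd")
print(text)
print(f"+{added} -{removed}")
```

The prefix check has one edge case worth knowing: an added line whose own text begins with `++` appears in the diff as `+++...` and is therefore not counted, and the same holds for removed lines beginning with `--`.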
async def _do_edit_file(content: str) -> Dict[str, Any]:
"""Exact string-replacement edit of an on-disk file.
content is JSON: {"path", "old_string", "new_string", "replace_all"?}.
Fails if old_string is missing or non-unique (unless replace_all) so the
model can't silently edit the wrong place. Returns a unified diff for the UI.
"""
try:
args = json.loads(content) if content.strip().startswith("{") else {}
except (json.JSONDecodeError, TypeError):
args = {}
raw_path = (args.get("path") or "").strip()
old = args.get("old_string", "")
new = args.get("new_string", "")
replace_all = bool(args.get("replace_all", False))
if not raw_path:
return {"error": "edit_file: path required", "exit_code": 1}
# Apply the same allowlist + sensitive-file policy as read_file/write_file.
try:
path = _resolve_tool_path(raw_path)
except ValueError as e:
return {"error": f"edit_file: {e}", "exit_code": 1}
if old == "":
return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
if old == new:
return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
def _apply():
with open(path, "r", encoding="utf-8") as f:
original = f.read()
count = original.count(old)
if count == 0:
return original, None, "not_found"
if count > 1 and not replace_all:
return original, None, f"not_unique:{count}"
updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
with open(path, "w", encoding="utf-8") as f:
f.write(updated)
return original, updated, "ok"
try:
original, updated, status = await asyncio.to_thread(_apply)
except FileNotFoundError:
return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
except (IsADirectoryError, UnicodeDecodeError):
return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
except PermissionError:
return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
if status == "not_found":
return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
if status.startswith("not_unique"):
n = status.split(":", 1)[1]
return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
n = original.count(old)
result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
diff = _unified_diff(original, updated, path)
if diff:
result["diff"] = diff
return result
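The uniqueness rule inside `_apply` can be separated from the file I/O. A minimal sketch, assuming a hypothetical `apply_exact_edit` helper that works on an in-memory string and returns the same status codes as the code above:

```python
from typing import Optional, Tuple


def apply_exact_edit(text: str, old: str, new: str,
                     replace_all: bool = False) -> Tuple[Optional[str], str]:
    """Replace `old` with `new` only when the match is unambiguous."""
    if old == "":
        # str.count("") is len(text) + 1, so an empty needle is rejected up front.
        raise ValueError("old must be non-empty")
    count = text.count(old)
    if count == 0:
        return None, "not_found"
    if count > 1 and not replace_all:
        return None, f"not_unique:{count}"
    updated = text.replace(old, new) if replace_all else text.replace(old, new, 1)
    return updated, "ok"


print(apply_exact_edit("x = 1\ny = 1\n", "x = 1", "x = 2"))
print(apply_exact_edit("x = 1\ny = 1\n", "= 1", "= 2"))
print(apply_exact_edit("x = 1\ny = 1\n", "= 1", "= 2", replace_all=True))
```

Refusing an ambiguous match pushes the caller to supply more surrounding context, which is the behaviour the error messages above ask the model for.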
# ---------------------------------------------------------------------------
# Path confinement for read_file / write_file
@@ -269,40 +172,46 @@ def _resolve_tool_path(raw_path: str) -> str:
)
# Bash + python tools used to share a single 60s timeout. That's
# enough for one-shot commands but starves real workloads (pip
# install, ffmpeg conversions, etc.) — and worse, the agent saw the
# 60s timeout and went silent because it had nothing to report.
# The new default is intentionally generous: long enough that real
# work isn't killed mid-flight, but bounded so a runaway process
# (infinite loop, hung connect, etc.) eventually frees the worker.
# The user can cancel sooner via the chat stop button — when the
# SSE stream is torn down, the asyncio task running the subprocess
# gets cancelled and the subprocess is killed by the finally block.
DEFAULT_BASH_TIMEOUT = 60 * 60 # 1 hour
DEFAULT_PYTHON_TIMEOUT = 60 * 60
def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str:
"""Confine a model-supplied path to the active workspace.
Layered on top of upstream's path policy: the workspace is the allowed
root (relative paths resolve under it; paths that escape it are rejected),
and the sensitive-file deny list (.ssh, .gnupg, id_rsa, ...) still applies
inside it. When no workspace is set, callers use _resolve_tool_path (the
default data/tmp allowlist) instead.
"""
if raw_path is None or not str(raw_path).strip():
raise ValueError("path is required")
base = os.path.realpath(workspace)
expanded = os.path.expanduser(str(raw_path).strip())
candidate = expanded if os.path.isabs(expanded) else os.path.join(base, expanded)
resolved = os.path.realpath(candidate)
if _is_sensitive_path(resolved):
raise ValueError(
f"path '{raw_path}' is inside a sensitive directory "
f"(e.g. .ssh, .gnupg) or matches a sensitive filename"
)
if resolved != base:
# normcase so containment holds on case-insensitive filesystems
# (Windows, default macOS): it lowercases on Windows and is a no-op on
# POSIX. commonpath raises ValueError across Windows drives (C: vs D:)
# or mixed abs/rel — both mean "outside", so the except rejects them.
nbase = os.path.normcase(base)
try:
if os.path.commonpath([os.path.normcase(resolved), nbase]) != nbase:
raise ValueError
except ValueError:
raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})")
return resolved
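The containment check above can be exercised without the rest of the module. The sketch below is an assumption-laden simplification: the `confine` name is hypothetical, and the sensitive-path check (`_is_sensitive_path`) is deliberately left out because its definition is not part of this diff.

```python
import os


def confine(workspace: str, raw_path: str) -> str:
    """Resolve raw_path under workspace and reject anything that escapes it."""
    if raw_path is None or not str(raw_path).strip():
        raise ValueError("path is required")
    base = os.path.realpath(workspace)
    expanded = os.path.expanduser(str(raw_path).strip())
    candidate = expanded if os.path.isabs(expanded) else os.path.join(base, expanded)
    # realpath collapses ".." segments and follows symlinks before the check.
    resolved = os.path.realpath(candidate)
    if resolved != base:
        nbase = os.path.normcase(base)
        try:
            inside = os.path.commonpath([os.path.normcase(resolved), nbase]) == nbase
        except ValueError:
            # Different drives or mixed absolute/relative paths count as outside.
            inside = False
        if not inside:
            raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})")
    return resolved
```

Comparing with `os.path.commonpath` rather than a string prefix avoids treating a sibling directory such as `/work/app-old` as being inside `/work/app`.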
def get_mcp_manager():
from src import agent_tools
return agent_tools.get_mcp_manager()
# How often to push a progress event while a long-running subprocess
# is still in flight. The frontend cares about "alive" more than
# "every-byte" — 2s is the sweet spot.
PROGRESS_INTERVAL_S = 2.0
# Tail buffer size — we keep the most recent N lines of stdout +
# stderr so the progress event includes a "what's it doing right now"
# snippet without dragging the whole output along.
PROGRESS_TAIL_LINES = 12
# Directories ignored by the code-nav tools' Python fallbacks so results aren't
# polluted by VCS internals / dependency trees / build caches. ripgrep already
# honours .gitignore; this is the parity floor for the no-rg path (and the
# explicit excludes passed to rg so it skips them even without a .gitignore).
_CODENAV_SKIP_DIRS = frozenset({
".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
".next", ".cache", "site-packages", ".idea", ".tox",
})
# Per-tool result caps (keep tool output cheap + model-friendly).
_CODENAV_MAX_HITS = 200
_CODENAV_MAX_LINE = 400
def _resolve_search_root(raw_path: str) -> str:
@@ -320,116 +229,6 @@ def _resolve_search_root(raw_path: str) -> str:
logger = logging.getLogger(__name__)
async def _run_subprocess_streaming(
proc: asyncio.subprocess.Process,
*,
timeout: float,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
) -> Tuple[str, str, Optional[int], bool]:
"""Run a subprocess to completion, streaming progress.
Reads stdout + stderr line-by-line into ring buffers so a
periodic progress callback can emit a "tail" of recent output
without waiting for the full result. Returns
(full_stdout, full_stderr, return_code, timed_out).
`timed_out=True` means the process was killed because it ran
past `timeout` seconds. Whatever output we'd buffered up to
that point is still returned.
"""
started = time.time()
stdout_full: list[str] = []
stderr_full: list[str] = []
tail = collections.deque(maxlen=PROGRESS_TAIL_LINES)
async def _reader(stream, full_buf, label: str):
if stream is None:
return
while True:
line = await stream.readline()
if not line:
break
decoded = line.decode("utf-8", errors="replace").rstrip("\n")
full_buf.append(decoded)
if label == "err":
tail.append(f"! {decoded}")
else:
tail.append(decoded)
async def _progress_emitter():
# Skip the first push — many commands finish well under
# PROGRESS_INTERVAL_S and a 0-second "progress" event would
# just add UI churn.
await asyncio.sleep(PROGRESS_INTERVAL_S)
while True:
if progress_cb:
try:
await progress_cb({
"elapsed_s": round(time.time() - started, 1),
"tail": "\n".join(list(tail)),
})
except Exception:
# Progress is best-effort — never let a UI hiccup
# break the underlying subprocess.
pass
await asyncio.sleep(PROGRESS_INTERVAL_S)
rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out"))
rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err"))
prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None
timed_out = False
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except asyncio.TimeoutError:
timed_out = True
try:
proc.kill()
except Exception:
pass
try:
await asyncio.wait_for(proc.wait(), timeout=2)
except Exception:
pass
except asyncio.CancelledError:
# User hit stop / SSE stream torn down. Kill the child so it
# doesn't keep running orphaned. Re-raise so the agent loop's
# cancellation propagates as the user expects.
try:
proc.kill()
except Exception:
pass
try:
await asyncio.wait_for(proc.wait(), timeout=2)
except Exception:
pass
# Best-effort: stop the readers + emitter before re-raising.
for t in (rd_out, rd_err):
t.cancel()
if prog_task is not None:
prog_task.cancel()
raise
finally:
if prog_task is not None and not prog_task.done():
prog_task.cancel()
try:
await prog_task
except (asyncio.CancelledError, Exception):
pass
# Wait for readers to finish draining the pipes.
for t in (rd_out, rd_err):
try:
await asyncio.wait_for(t, timeout=1)
except Exception:
pass
return (
"\n".join(stdout_full),
"\n".join(stderr_full),
proc.returncode,
timed_out,
)
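The tail-buffer idea in `_run_subprocess_streaming` comes down to reading lines as they arrive and keeping only the most recent few in a bounded `collections.deque`. The following is a reduced sketch, not the removed function: `run_with_tail` is a hypothetical name, it runs a Python one-liner instead of an arbitrary command, and it omits the progress callback and cancellation handling.

```python
import asyncio
import collections
import sys
from typing import List, Tuple


async def run_with_tail(code: str, tail_lines: int = 3,
                        timeout: float = 10.0) -> Tuple[List[str], List[str], bool]:
    """Run `python -c code`; return (stdout lines, recent tail, timed_out)."""
    proc = await asyncio.create_subprocess_exec(
        sys.executable, "-c", code,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout_full: List[str] = []
    stderr_full: List[str] = []
    # Bounded buffer: appending beyond maxlen silently drops the oldest line.
    tail: collections.deque = collections.deque(maxlen=tail_lines)

    async def reader(stream, full_buf: List[str], prefix: str) -> None:
        while True:
            line = await stream.readline()
            if not line:
                break
            decoded = line.decode("utf-8", errors="replace").rstrip("\r\n")
            full_buf.append(decoded)
            tail.append(prefix + decoded)

    readers = [
        asyncio.create_task(reader(proc.stdout, stdout_full, "")),
        asyncio.create_task(reader(proc.stderr, stderr_full, "! ")),
    ]
    timed_out = False
    try:
        await asyncio.wait_for(proc.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
        proc.kill()
        await proc.wait()
    # Let the readers drain whatever is still buffered in the pipes.
    await asyncio.gather(*readers)
    return stdout_full, list(tail), timed_out


if __name__ == "__main__":
    print(asyncio.run(run_with_tail("for i in range(5): print(i)")))
```

Reading both pipes in background tasks while awaiting `proc.wait()` matters: a child that fills an unread pipe would otherwise block and never exit.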
_ADMIN_TOOLS = {
"app_api",
"manage_endpoints",
@@ -593,24 +392,8 @@ async def _direct_fallback(
tool: str,
content: str,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
workspace: Optional[str] = None,
) -> Optional[Dict]:
"""In-process execution path for the eight tools that used to live as
stdio MCP servers under mcp_servers/. Those servers were deleted in
favor of native execution; this function is now the canonical path,
not a fallback. The name is kept for backwards compat with callers.
`progress_cb` is called periodically while bash/python subprocesses
are still running, with `{elapsed_s, tail}` payloads. Other tools
ignore it.
"""
# Inherit env + force a sane terminal so subprocesses that touch
# terminfo (anything calling `clear`, `tput`, `os.system("clear")`,
# or scripts that probe $TERM) don't spam "TERM environment variable
# not set" errors. The agent's bash/python tool calls run with PIPE
# stdin/stdout (no real TTY), so curses/termios still won't work —
# but at least non-interactive code with incidental TERM lookups
# stops failing. COLUMNS/LINES give terminal-width-aware tools (less,
# rich, etc.) reasonable defaults instead of 0×0.
_subproc_env = {
**os.environ,
"TERM": "xterm-256color",
@@ -620,444 +403,16 @@ async def _direct_fallback(
}
try:
if tool == "bash":
proc = await asyncio.create_subprocess_shell(
content,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=_AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
timeout=DEFAULT_BASH_TIMEOUT,
progress_cb=progress_cb,
)
if timed_out:
return {"error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
output = stdout.rstrip()
err = stderr.rstrip()
if err:
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
output = _truncate(output, MAX_OUTPUT_CHARS)
return {"output": output or "(no output)", "exit_code": rc or 0}
ctx = {
"progress_cb": progress_cb,
"workspace": workspace,
"subproc_env": _subproc_env,
}
if tool == "python":
# Run user code in a subprocess so an infinite loop or crash
# can't take the whole server down. -I = isolated mode (skip
# user site, no PYTHONPATH inheritance) for hygiene.
proc = await asyncio.create_subprocess_exec(
# Use the running interpreter — there is no `python3.exe` on
# Windows, which made the agent's `python` tool fail there.
(sys.executable or "python"), "-I", "-c", content,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=_AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
timeout=DEFAULT_PYTHON_TIMEOUT,
progress_cb=progress_cb,
)
if timed_out:
return {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
output = stdout.rstrip()
err = stderr.rstrip()
if err:
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
output = _truncate(output, MAX_OUTPUT_CHARS)
return {"output": output or "(no output)", "exit_code": rc or 0}
from src.agent_tools import TOOL_HANDLERS
if tool in TOOL_HANDLERS:
return await TOOL_HANDLERS[tool](content, ctx)
if tool == "read_file":
# Args: plain path on line 1 (back-compat) OR JSON
# {path, offset?, limit?} where offset/limit are a 1-based line range.
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
_stripped = content.strip()
if _stripped.startswith("{"):
try:
_a = json.loads(_stripped)
raw_path = str(_a.get("path", "")).strip()
offset = int(_a.get("offset") or 0)
limit = int(_a.get("limit") or 0)
except (json.JSONDecodeError, TypeError, ValueError):
pass
try:
path = _resolve_tool_path(raw_path)
except ValueError as e:
return {"error": f"read_file: {e}", "exit_code": 1}
try:
# Run blocking read in a thread to keep the loop responsive.
def _read():
if offset > 0 or limit > 0:
# Line-range read: slice [offset, offset+limit).
start = max(offset, 1)
out, n, budget = [], 0, MAX_READ_CHARS
with open(path, "r", encoding="utf-8", errors="replace") as f:
for i, line in enumerate(f, 1):
if i < start:
continue
if limit > 0 and n >= limit:
break
out.append(line)
n += 1
budget -= len(line)
if budget <= 0:
out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
break
return "".join(out)
with open(path, "r", encoding="utf-8", errors="replace") as f:
return f.read(MAX_READ_CHARS + 1)
data = await asyncio.to_thread(_read)
except FileNotFoundError:
return {"error": f"read_file: {path}: not found", "exit_code": 1}
except PermissionError:
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
except IsADirectoryError:
return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
except OSError as e:
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
return {"output": data, "exit_code": 0}
if tool == "write_file":
lines = content.split("\n", 1)
raw_path = lines[0].strip()
body = lines[1] if len(lines) > 1 else ""
try:
path = _resolve_tool_path(raw_path)
except ValueError as e:
return {"error": f"write_file: {e}", "exit_code": 1}
try:
def _write():
# Capture prior content (best-effort, text) so we can show a
# before/after diff. Missing/binary file → treat as empty.
old = ""
try:
with open(path, "r", encoding="utf-8") as f:
old = f.read()
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
old = ""
d = os.path.dirname(path)
if d:
os.makedirs(d, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
f.write(body)
return old, len(body)
old_content, size = await asyncio.to_thread(_write)
except PermissionError:
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
diff = _unified_diff(old_content, body, path)
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
if diff:
result["diff"] = diff
return result
if tool == "grep":
# Args (JSON): {pattern, path?, glob?, ignore_case?, max_results?}.
# Bare string → treated as the pattern.
args: Dict[str, Any] = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = json.loads(_s)
except json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "grep: pattern is required", "exit_code": 1}
ignore_case = bool(args.get("ignore_case"))
glob_pat = str(args.get("glob", "") or "").strip()
try:
max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
except (TypeError, ValueError):
max_hits = _CODENAV_MAX_HITS
max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
try:
root = _resolve_search_root(str(args.get("path", "")))
except ValueError as e:
return {"error": f"grep: {e}", "exit_code": 1}
def _grep():
import re as _re
import shutil
rg = shutil.which("rg")
if rg:
cmd = [rg, "--line-number", "--no-heading", "--color=never",
"--max-count", str(max_hits)]
if ignore_case:
cmd.append("--ignore-case")
if glob_pat:
cmd += ["--glob", glob_pat]
# Exclude junk dirs even when the tree has no .gitignore, so
# results match the Python fallback's skip set.
for _d in _CODENAV_SKIP_DIRS:
cmd += ["--glob", f"!**/{_d}/**"]
cmd += ["--regexp", pattern, root]
try:
import subprocess
p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
return lines, None
except subprocess.TimeoutExpired:
return None, "grep: timed out"
except Exception as _e:
return None, f"grep: {_e}"
# Python fallback (no ripgrep): walk + regex.
try:
rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
except _re.error as _e:
return None, f"grep: bad pattern: {_e}"
import fnmatch
hits = []
if os.path.isfile(root):
file_iter = [root]
else:
file_iter = []
for dp, dns, fns in os.walk(root):
dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
for fn in fns:
if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
continue
file_iter.append(os.path.join(dp, fn))
for fp in file_iter:
if len(hits) >= max_hits:
break
try:
with open(fp, "r", encoding="utf-8", errors="strict") as f:
for i, line in enumerate(f, 1):
if rx.search(line):
hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
if len(hits) >= max_hits:
break
except (UnicodeDecodeError, OSError):
continue # skip binary / unreadable
return hits, None
lines, err = await asyncio.to_thread(_grep)
if err:
return {"error": err, "exit_code": 1}
if not lines:
return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
if len(lines) >= max_hits:
out += f"\n... [capped at {max_hits} matches]"
return {"output": _truncate(out), "exit_code": 0}
if tool == "glob":
args = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = json.loads(_s)
except json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "glob: pattern is required", "exit_code": 1}
try:
root = _resolve_search_root(str(args.get("path", "")))
except ValueError as e:
return {"error": f"glob: {e}", "exit_code": 1}
def _glob():
from pathlib import Path
base = Path(root)
if not base.is_dir():
return None, f"glob: {root}: not a directory"
matched = []
try:
for p in base.rglob(pattern):
if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
continue
try:
mtime = p.stat().st_mtime
except OSError:
mtime = 0
matched.append((mtime, str(p)))
if len(matched) > _CODENAV_MAX_HITS * 5:
break
except (OSError, ValueError) as _e:
return None, f"glob: {_e}"
matched.sort(key=lambda t: t[0], reverse=True) # newest first
return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
paths, err = await asyncio.to_thread(_glob)
if err:
return {"error": err, "exit_code": 1}
if not paths:
return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(paths)
if len(paths) >= _CODENAV_MAX_HITS:
out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
return {"output": _truncate(out), "exit_code": 0}
if tool == "ls":
raw_path = ""
_s = (content or "").strip()
if _s.startswith("{"):
try:
raw_path = str(json.loads(_s).get("path", "")).strip()
except json.JSONDecodeError:
raw_path = ""
else:
raw_path = _s.split("\n", 1)[0].strip()
try:
root = _resolve_search_root(raw_path)
except ValueError as e:
return {"error": f"ls: {e}", "exit_code": 1}
def _ls():
if not os.path.isdir(root):
return None, f"ls: {root}: not a directory"
rows = []
try:
with os.scandir(root) as it:
for entry in it:
if entry.name.startswith("."):
continue
try:
is_dir = entry.is_dir(follow_symlinks=False)
size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0
except OSError:
continue
rows.append((is_dir, entry.name, size))
except (PermissionError, OSError) as _e:
return None, f"ls: {_e}"
rows.sort(key=lambda r: (not r[0], r[1].lower())) # dirs first, then name
lines = [f"{root}:"]
for is_dir, name, size in rows[:_CODENAV_MAX_HITS]:
lines.append(f" {name}/" if is_dir else f" {name} ({size} B)")
if len(rows) > _CODENAV_MAX_HITS:
lines.append(f" ... [{len(rows) - _CODENAV_MAX_HITS} more]")
if not rows:
lines.append(" (empty)")
return "\n".join(lines), None
out, err = await asyncio.to_thread(_ls)
if err:
return {"error": err, "exit_code": 1}
return {"output": _truncate(out), "exit_code": 0}
if tool == "web_search":
from src.search import comprehensive_web_search
raw = content.strip()
query = raw
time_filter = None
max_pages = 5
# Allow JSON-shaped args: {"query": "...", "time_filter": "day", "max_pages": 7}
if raw.startswith("{"):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict) and "query" in parsed:
query = str(parsed.get("query", "")).strip()
tf = parsed.get("time_filter") or parsed.get("freshness")
if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"):
time_filter = tf.lower()
mp = parsed.get("max_pages")
if isinstance(mp, int) and 1 <= mp <= 10:
max_pages = mp
except json.JSONDecodeError:
pass
if not query:
query = raw.split("\n")[0].strip()
# Auto-detect freshness from query phrasing when not explicit
if time_filter is None:
q_lc = query.lower()
if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")):
time_filter = "day"
elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")):
time_filter = "week"
elif any(kw in q_lc for kw in ("this month", "past month")):
time_filter = "month"
elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"):
time_filter = "week"
loop = asyncio.get_running_loop()
text, sources = await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: comprehensive_web_search(
query,
max_pages=max_pages,
time_filter=time_filter,
return_sources=True,
),
),
timeout=30,
)
output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text
if sources:
output += "\n\n<!-- SOURCES:" + json.dumps(sources) + " -->"
return {"output": output, "exit_code": 0}
if tool == "web_fetch":
# Lightweight single-URL fetch. Wraps the SSRF-safe fetcher used
# by deep research, so private/loopback/metadata addresses are
# already blocked there.
from src.search.content import fetch_webpage_content
raw = content.strip()
url = ""
# Accept either a JSON arg ({"url": "..."}) or a plain URL/domain.
if raw.startswith("{"):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
url = str(parsed.get("url") or "").strip()
except json.JSONDecodeError:
url = ""
if not url:
# Non-JSON (or JSON without a usable url): take the first line
# only, so a URL followed by commentary still parses.
url = raw.split("\n")[0].strip()
# Reject anything that isn't a single bare URL/domain token.
if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")):
return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1}
low = url.lower()
if "://" in low and not low.startswith(("http://", "https://")):
return {"error": f"web_fetch: unsupported URL scheme (only http/https): {url[:80]}", "exit_code": 1}
# Accept bare domains like "example.com" by defaulting to https.
if not low.startswith(("http://", "https://")):
url = "https://" + url
loop = asyncio.get_running_loop()
try:
result = await asyncio.wait_for(
loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)),
timeout=30,
)
except asyncio.TimeoutError:
return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1}
except Exception as e:
# Direct URL fetches can hit bot protection / auth walls
# (e.g. eBay 403). Treat that as a tool failure the model can
# reason around, not an uncaught chat-stream 500.
return {"error": f"web_fetch: {url}: {e}", "exit_code": 1}
err = result.get("error")
text = (result.get("content") or "").strip()
title = result.get("title") or ""
if not text:
if err:
return {"error": f"web_fetch: {url}: {err}", "exit_code": 1}
# No extractable text: non-HTML body, or a pure client-rendered
# shell. The agent can fall back to the builtin_browser tool.
return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1}
header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n"
output = header + text
if len(output) > MAX_OUTPUT_CHARS:
output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]"
return {"output": output, "exit_code": 0}
# manage_memory / generate_image still live as MCP servers
# (mcp_servers/{memory,image_gen}_server.py); the MCP path above
# handles them.
except Exception as e:
return {"error": f"{tool}: {e}", "exit_code": 1}
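Reviewer note: the `web_fetch` argument handling above amounts to a small normalisation step. Below is a minimal sketch of that step, assuming a hypothetical helper named `normalize_fetch_target`. The real branch inlines this logic and returns tool-result dicts instead of tuples. The sketch follows the same order of checks: JSON first, then the first line of plain text, then the whitespace and scheme checks, then the `https://` default.

```python
import json
from typing import Optional, Tuple


def normalize_fetch_target(raw: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (url, error); exactly one of the two is None.

    Hypothetical helper: it restates the inline checks of the web_fetch
    branch for illustration and is not part of the project.
    """
    raw = raw.strip()
    url = ""
    # JSON-shaped argument: {"url": "..."}
    if raw.startswith("{"):
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, dict):
                url = str(parsed.get("url") or "").strip()
        except json.JSONDecodeError:
            url = ""
    if not url:
        # Plain text: keep only the first line so trailing commentary is ignored.
        url = raw.split("\n")[0].strip()
    # Anything that is not a single bare token is rejected.
    if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")):
        return None, "provide a single URL or domain"
    low = url.lower()
    if "://" in low and not low.startswith(("http://", "https://")):
        return None, "unsupported URL scheme (only http/https)"
    if not low.startswith(("http://", "https://")):
        url = "https://" + url  # bare domain defaults to https
    return url, None
```

Returning a tuple keeps the sketch easy to test in isolation. The scheme check only rejects explicit non-HTTP schemes; blocking private or loopback targets is left to the fetcher, as the comment in the diff says.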
@@ -1072,9 +427,10 @@ async def execute_tool_block(
block: Any,
session_id: Optional[str] = None,
disabled_tools: Optional[set] = None,
tool_policy: Optional[ToolPolicy] = None,
owner: Optional[str] = None,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
workspace: Optional[str] = None,
tool_policy: Optional[Any] = None,
) -> Tuple[str, Dict]:
"""Execute a single tool block. Returns (description, result_dict).
@@ -1130,18 +486,21 @@ async def execute_tool_block(
pass
# Reject tools that the user has disabled for this request
if tool_policy and tool_policy.blocks(tool):
desc = f"{tool}: BLOCKED"
result = {"error": tool_policy.reason_for(tool), "exit_code": 1}
logger.info("Tool blocked by policy: %s", tool)
return desc, result
if disabled_tools and tool in disabled_tools:
desc = f"{tool}: BLOCKED"
result = {"error": f"Tool '{tool}' is disabled by user.", "exit_code": 1}
logger.info(f"Tool blocked by user: {tool}")
return desc, result
if tool_policy and tool_policy.blocks(tool):
desc = f"{tool}: BLOCKED"
result = {
"error": f"Execution of tool '{tool}' is forbidden by the active guide-only policy.",
"exit_code": 1,
}
logger.warning("Tool policy blocked tool=%s", tool)
return desc, result
if tool in _ADMIN_TOOLS and not _owner_is_admin(owner):
desc = f"{tool}: BLOCKED"
result = {"error": f"Tool '{tool}' requires an admin user.", "exit_code": 1}
@@ -1381,7 +740,7 @@ async def execute_tool_block(
desc = "edit_image"
result = await do_edit_image(content, owner=owner)
elif tool == "edit_file":
result = await _do_edit_file(content)
result = await _direct_fallback(tool, content, workspace=workspace) or {"error": "edit failed", "exit_code": 1}
desc = result.get("output") or result.get("error") or "edit_file"
elif tool == "trigger_research":
desc = "trigger_research"
+1 -1
@@ -2684,7 +2684,7 @@ async def _ensure_served_endpoint(
try:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{_COOKBOOK_BASE}/api/model-endpoints",
f"{_INTERNAL_BASE}/api/model-endpoints",
data=payload,
headers=_internal_headers(),
)
+43
@@ -33,6 +33,49 @@ the sub-area. The `area_*` names are registered in `pyproject.toml`; the dynamic
`sub_*` names are registered before collection by `pytest_configure` in
`tests/conftest.py`, so unknown-mark warnings still flag genuine typos.
For common focused runs, use `tests/run_focus.py`. It validates area and
sub-area names, accepts sub-areas with or without the `sub_` prefix, and passes
extra pytest arguments after `--`:
```bash
python3 tests/run_focus.py --area security
python3 tests/run_focus.py --area services --sub-area cookbook
python3 tests/run_focus.py --sub-area sub_cookbook
python3 tests/run_focus.py --keyword taxonomy
python3 tests/run_focus.py --last-failed
python3 tests/run_focus.py --dry-run --area services --sub-area cookbook
python3 tests/run_focus.py --area services -- --maxfail=1 -q
```
### Fast lane and duration visibility
`--fast` runs the fast lane: the tests that are *not* marked `slow` (it adds the
marker expression `not slow`). It composes with `--area`/`--sub-area` using
`and`. Because no tests may be marked `slow` yet, `--fast` can initially match
the full focused selection; it becomes a real speed-up as `slow` marks are added
from duration evidence. Use it for quick local or reviewer feedback; it does not
replace broader focused or full-suite validation before merge.
`--durations N` and `--durations-min FLOAT` add pytest's slowest-test reporting
so you can see where time goes. They are reporting only and do not count as a
focus selector, so `--durations` must be combined with a real selector
(`--area`, `--sub-area`, `--keyword`, `--last-failed`, or `--fast`).
Activate or otherwise use the project Python environment before running these
commands. The examples use `python3` intentionally to avoid hard-coding a local
venv path.
```bash
python3 tests/run_focus.py --fast
python3 tests/run_focus.py --area services --fast
python3 tests/run_focus.py --area services --durations 25
python3 tests/run_focus.py --area services --fast --durations 25 --durations-min 0.05
```
The `slow` marker is opt-in. Mark a test `slow` only with duration evidence
(from `--durations`), not by guessing - see the fast-lane policy in
`TESTING_STANDARD.md`.
## Core principles
- Keep PRs small and homogeneous: one kind of change per PR.
+10
@@ -74,6 +74,16 @@ A test that genuinely spans categories (e.g. a route test that also pins a
security invariant) is classified by its **primary** assertion target and may be
split if it grows.
## Fast lane policy
The fast lane is `not slow`: `tests/run_focus.py --fast` selects every test that
is not marked `slow`. The `slow` marker is **opt-in**, and slow marks must be
**evidence-driven from `--durations` output** - mark a test slow only when its
measured duration shows it is genuinely expensive, never by guessing. The fast
lane exists for quick local and reviewer feedback; it is **not** a replacement
for broader focused or full-suite validation before merge, and a test must never
be marked `slow` to hide a failure or skip coverage.
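One way to turn `--durations` output into evidence is to total the reported time per test and list the tests above a chosen threshold. The sketch below is illustrative only. It assumes the report lines look like `1.23s call     tests/test_x.py::test_y`; the helper name `slow_candidates` and any threshold value are hypothetical and not project policy.

```python
import re
from collections import defaultdict

# Matches lines such as "1.23s call     tests/test_x.py::test_y".
# Assumption: this is the shape of the slowest-durations report lines.
_DURATION_LINE = re.compile(r"^\s*(\d+(?:\.\d+)?)s\s+(setup|call|teardown)\s+(.+?)\s*$")


def slow_candidates(report: str, threshold_s: float) -> list[tuple[str, float]]:
    """Return (test id, total seconds) pairs at or above ``threshold_s``.

    Setup, call, and teardown times are summed per test id, so an expensive
    fixture counts towards the test that uses it.
    """
    totals: dict[str, float] = defaultdict(float)
    for line in report.splitlines():
        match = _DURATION_LINE.match(line)
        if match:
            totals[match.group(3)] += float(match.group(1))
    picked = [
        (test_id, round(total, 2))
        for test_id, total in totals.items()
        if total >= threshold_s
    ]
    # Slowest first, so the strongest candidates are reviewed first.
    return sorted(picked, key=lambda item: item[1], reverse=True)
```

The output is a list of candidates for review, not an automatic marking step: a test that is slow because of a bug should be fixed, not marked `slow`.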
## Determinism & isolation rules
Do not mutate shared process state without a controlled helper and guaranteed
+4
@@ -55,6 +55,10 @@ if "src.database" not in sys.modules:
_db.ModelEndpoint = MagicMock()
sys.modules["src.database"] = _db
# Pre-import core.models before test_agent_loop.py's module-level stubs
# run (it replaces sys.modules['core.models'] with a MagicMock during
# collection, which breaks session import in subsequent tests).
import core.models # noqa: E402
def pytest_configure(config):
"""Register the dynamic taxonomy ``sub_*`` markers before collection.
+300
@@ -0,0 +1,300 @@
#!/usr/bin/env python3
"""Focused test selection runner for the pytest taxonomy markers (issue #3442).
This wraps ``pytest -m`` selection over the ``area_*`` / ``sub_*`` markers that
``tests/conftest.py`` adds at collection time (issue #3491) so focused
validation is repeatable and less error-prone than hand-written marker
expressions. It builds a pytest command line and either prints it (``--dry-run``)
or runs it.
Examples:
tests/run_focus.py --area security
tests/run_focus.py --area services --sub-area cookbook
tests/run_focus.py --keyword taxonomy -- --maxfail=1 -q
tests/run_focus.py --fast
tests/run_focus.py --area services --fast --durations 25
This script imports no production code and changes no test behavior. It only
constructs and (optionally) executes a pytest invocation.
"""
from __future__ import annotations
import argparse
import shlex
import subprocess
import sys
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
TESTS_DIR = Path(__file__).resolve().parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from tests._taxonomy import discover_markers, normalize_marker_name # noqa: E402
# The canonical taxonomy areas, mirroring the ``area_*`` markers declared in
# pyproject.toml and produced by tests/_taxonomy.py.
AREAS: tuple[str, ...] = (
"security",
"routes",
"services",
"cli",
"js",
"helpers",
"unit",
"uncategorized",
)
def normalize_sub_area(value: str) -> str:
"""Normalize a CLI sub-area value and remove an optional ``sub_`` prefix."""
token = normalize_marker_name(value)
if token.startswith("sub_"):
token = token.removeprefix("sub_")
if not token:
raise argparse.ArgumentTypeError(
f"invalid sub-area {value!r}: must contain at least one letter or digit"
)
return token
def discover_sub_areas(tests_dir: Path = TESTS_DIR) -> frozenset[str]:
"""Discover valid taxonomy sub-areas from Python test filenames."""
paths = list(tests_dir.rglob("test_*.py"))
paths += list(tests_dir.rglob("*_test.py"))
markers = discover_markers(paths)
return frozenset(
marker.removeprefix("sub_")
for marker in markers
if marker.startswith("sub_")
)
def non_negative_int(value: str) -> int:
"""argparse type: a non-negative int (0 means "show all" for --durations)."""
number = int(value)
if number < 0:
raise argparse.ArgumentTypeError(f"must be >= 0, got {value!r}")
return number
def non_negative_float(value: str) -> float:
"""argparse type: a non-negative float (seconds threshold for --durations-min)."""
number = float(value)
if number < 0:
raise argparse.ArgumentTypeError(f"must be >= 0, got {value!r}")
return number
def sub_area_type(valid_sub_areas: frozenset[str]) -> Callable[[str], str]:
"""Build an argparse converter that accepts only discovered sub-areas."""
def validate(value: str) -> str:
sub_area = normalize_sub_area(value)
if sub_area not in valid_sub_areas:
raise argparse.ArgumentTypeError(
f"unknown sub-area {value!r}; choose a discovered taxonomy sub-area"
)
return sub_area
return validate
@dataclass(frozen=True)
class FocusSelection:
"""A single focused-selection request, decoupled from argparse and pytest."""
area: str | None = None
sub_area: str | None = None
keyword: str | None = None
last_failed: bool = False
fast: bool = False
durations: int | None = None
durations_min: float | None = None
pytest_args: tuple[str, ...] = field(default_factory=tuple)
@property
def has_focus(self) -> bool:
"""True when at least one focusing selector (not just pass-through) is set.
Duration visibility (``durations`` / ``durations_min``) is reporting
only, not a selector, so it does not count as focus on its own.
"""
return bool(
self.area
or self.sub_area
or self.keyword
or self.last_failed
or self.fast
)
def build_marker_expression(
area: str | None, sub_area: str | None, fast: bool = False
) -> str | None:
"""Build the ``-m`` marker expression from area, sub-area, and the fast lane.
The fast lane adds ``not slow`` and composes with any area/sub-area with
``and``. Returns ``None`` when nothing is given so the caller can omit ``-m``.
"""
parts: list[str] = []
if area:
parts.append(f"area_{area}")
if sub_area:
parts.append(f"sub_{sub_area}")
if fast:
parts.append("not slow")
if not parts:
return None
return " and ".join(parts)
def build_pytest_command(
selection: FocusSelection, python: str | None = None
) -> list[str]:
"""Build the pytest argv list for ``selection``.
No shell is involved; the result is a plain argv list for subprocess. The
interpreter defaults to the one running this script (the project venv when
invoked as ``.venv/bin/python tests/run_focus.py``).
"""
command = [python or sys.executable, "-m", "pytest"]
marker_expression = build_marker_expression(
selection.area, selection.sub_area, selection.fast
)
if marker_expression:
command += ["-m", marker_expression]
if selection.keyword:
command += ["-k", selection.keyword]
if selection.last_failed:
command += ["--last-failed", "--last-failed-no-failures=none"]
if selection.durations is not None:
command += [f"--durations={selection.durations}"]
if selection.durations_min is not None:
command += [f"--durations-min={selection.durations_min}"]
command += list(selection.pytest_args)
return command
def selection_from_args(namespace: argparse.Namespace) -> FocusSelection:
"""Convert parsed argparse values into a ``FocusSelection``."""
return FocusSelection(
area=namespace.area,
sub_area=namespace.sub_area,
keyword=namespace.keyword,
last_failed=namespace.last_failed,
fast=namespace.fast,
durations=namespace.durations,
durations_min=namespace.durations_min,
pytest_args=tuple(namespace.pytest_args),
)
def build_parser(
valid_sub_areas: frozenset[str] | None = None,
) -> argparse.ArgumentParser:
"""Build the argument parser for the focused runner."""
if valid_sub_areas is None:
valid_sub_areas = discover_sub_areas()
parser = argparse.ArgumentParser(
prog="run_focus.py",
description=(
"Run a focused subset of the test suite using the area_*/sub_* "
"taxonomy markers. Combine --area and --sub-area to intersect them."
),
epilog=(
"Pass extra pytest arguments after a literal -- separator, e.g.: "
"run_focus.py --area services -- --maxfail=1 -q"
),
)
parser.add_argument(
"--area",
choices=AREAS,
help="select tests in one taxonomy area (marker area_<area>)",
)
parser.add_argument(
"--sub-area",
type=sub_area_type(valid_sub_areas),
metavar="NAME",
help="select tests in a sub-area (marker sub_<name>); combinable with --area",
)
parser.add_argument(
"-k",
"--keyword",
help="pass a keyword expression through to pytest -k",
)
parser.add_argument(
"--last-failed",
action="store_true",
help="re-run only tests that failed on the last run (pytest --last-failed)",
)
parser.add_argument(
"--fast",
action="store_true",
help="fast lane: exclude tests marked slow (adds 'not slow'); composable with --area/--sub-area",
)
parser.add_argument(
"--durations",
type=non_negative_int,
metavar="N",
help="report the N slowest tests (pytest --durations=N, 0 shows all); not a focus selector",
)
parser.add_argument(
"--durations-min",
type=non_negative_float,
metavar="SECONDS",
help="minimum duration to report with --durations (pytest --durations-min)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="print the pytest command without executing it",
)
parser.add_argument(
"pytest_args",
nargs="*",
metavar="-- PYTEST_ARGS",
help="extra arguments forwarded to pytest after a literal --",
)
return parser
def run(
argv: Sequence[str] | None = None,
executor: Callable[[list[str]], int] = subprocess.call,
) -> int:
"""Parse ``argv``, build the pytest command, and run or print it.
``executor`` is injected so tests can assert on the constructed command
without spawning a process. It must accept an argv list and return an exit
code, matching ``subprocess.call``.
"""
parser = build_parser()
namespace = parser.parse_args(argv)
selection = selection_from_args(namespace)
if not selection.has_focus:
parser.error(
"no focus selected: pass at least one of --area, --sub-area, "
"--keyword, --last-failed, or --fast (--durations is reporting only)"
)
if selection.durations_min is not None and selection.durations is None:
parser.error(
"--durations-min has no effect without --durations; pass "
"--durations N as well"
)
command = build_pytest_command(selection)
if namespace.dry_run:
print(shlex.join(command))
return 0
return executor(command)
def main() -> int:
"""Console entry point."""
return run(sys.argv[1:])
if __name__ == "__main__":
raise SystemExit(main())
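Reviewer note: the pass-through after `--` relies on standard `argparse` behaviour, not on custom parsing. A minimal, self-contained sketch follows. The parser is a cut-down stand-in for `build_parser`, and the `parse` helper is hypothetical.

```python
import argparse


def parse(argv: list[str]) -> argparse.Namespace:
    """Parse a reduced version of the runner's command line."""
    parser = argparse.ArgumentParser(prog="sketch")
    parser.add_argument("--area")
    parser.add_argument("--fast", action="store_true")
    # Positional catch-all: argparse treats everything after a literal "--"
    # as positional, so pytest options are not mistaken for runner options.
    parser.add_argument("pytest_args", nargs="*")
    return parser.parse_args(argv)


ns = parse(["--area", "services", "--", "--maxfail=1", "-q"])
print(ns.area, ns.pytest_args)
```

Without the `--` separator, `-q` would be parsed as an unknown runner option and rejected, which is why the epilog asks for the literal separator.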
+68
@@ -0,0 +1,68 @@
"""Route-level regression tests for GET /api/diagnostics/services.
The reviewer asked for explicit coverage of unauthenticated / non-admin / admin
access to this admin diagnostics route, beyond the unit tests for the collector.
These need a real FastAPI + TestClient (the conftest only stubs FastAPI when it
is *not* installed). When the full app deps aren't present we skip rather than
fail, so the suite stays green in minimal environments; CI installs
requirements, so the tests run there.
"""
import pytest
fastapi = pytest.importorskip("fastapi")
pytest.importorskip("starlette.testclient")
from fastapi import FastAPI, HTTPException, Request
from starlette.testclient import TestClient
# Importing the route module pulls a few app deps; skip cleanly if unavailable.
diag = pytest.importorskip("routes.diagnostics_routes")
def _client_with_admin_gate(monkeypatch, gate):
"""Mount the diagnostics router with `require_admin` and the collector
patched (via monkeypatch so the module globals are restored afterwards),
and return a TestClient. `gate` plays the role of require_admin."""
import src.service_health as sh
async def _fake_collect(_rag, _mem):
return {"overall": "ok", "services": [], "timestamp": "t"}
# monkeypatch.setattr restores these after the test — a plain assignment
# would leak the fakes into every later test in the session.
monkeypatch.setattr(diag, "require_admin", gate)
monkeypatch.setattr(sh, "collect_service_health", _fake_collect)
app = FastAPI()
app.include_router(diag.setup_diagnostics_routes(
rag_manager=None, rag_available=False, research_handler=None,
memory_vector=None))
return TestClient(app, raise_server_exceptions=False)
def test_unauthenticated_is_rejected(monkeypatch):
def gate(_request: Request):
raise HTTPException(401, "Not authenticated")
client = _client_with_admin_gate(monkeypatch, gate)
r = client.get("/api/diagnostics/services")
assert r.status_code == 401
def test_non_admin_is_forbidden(monkeypatch):
def gate(_request: Request):
raise HTTPException(403, "Admin only")
client = _client_with_admin_gate(monkeypatch, gate)
r = client.get("/api/diagnostics/services")
assert r.status_code == 403
def test_admin_gets_report(monkeypatch):
def gate(_request: Request):
return None # admin allowed
client = _client_with_admin_gate(monkeypatch, gate)
r = client.get("/api/diagnostics/services")
assert r.status_code == 200
body = r.json()
assert set(body) == {"overall", "services", "timestamp"}
assert body["overall"] == "ok"
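Reviewer note: the comment in `_client_with_admin_gate` about restoring module globals can be shown without FastAPI. The sketch below uses `unittest.mock.patch.object` from the standard library as a stand-in for pytest's `monkeypatch.setattr`. Both restore the original attribute, although `monkeypatch` does so at test teardown and not at block exit. `fake_module` and both function names are hypothetical.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Stand-in for a module whose global a test wants to replace.
fake_module = SimpleNamespace(require_admin=lambda request: "real gate")


def leaky_test() -> None:
    # Plain assignment: nothing puts the original back afterwards.
    fake_module.require_admin = lambda request: "fake gate"


def scoped_test() -> str:
    # patch.object restores the attribute when the block exits.
    with patch.object(fake_module, "require_admin", lambda request: "fake gate"):
        return fake_module.require_admin(None)


original = fake_module.require_admin
inside = scoped_test()
restored = fake_module.require_admin is original  # the fake is gone again
leaky_test()
leaked = fake_module.require_admin is not original  # the fake is still installed
```

After `leaky_test` runs, every later caller sees the fake gate, which is the session-wide leak the comment warns about.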
+6 -6
@@ -11,7 +11,7 @@ from src.tool_security import (
is_public_blocked_tool,
blocked_tools_for_owner,
)
from src.tool_execution import _do_edit_file
from src.agent_tools.filesystem_tools import EditFileTool
from src.agent_tools import ToolBlock
@@ -60,7 +60,7 @@ async def test_edit_file_blocked_at_execution_for_non_admin(monkeypatch):
async def test_edit_file_success():
p = os.path.join("/tmp", "ef_ok.py")
open(p, "w").write("def f():\n return 1\n")
res = await _do_edit_file(json.dumps({"path": p, "old_string": "return 1", "new_string": "return 2"}))
res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "return 1", "new_string": "return 2"}), {})
assert res["exit_code"] == 0
assert open(p).read() == "def f():\n return 2\n"
assert res["diff"]["added"] == 1 and res["diff"]["removed"] == 1 and res["diff"]["file"] == "ef_ok.py"
@@ -71,7 +71,7 @@ async def test_edit_file_success():
async def test_edit_file_not_found():
p = os.path.join("/tmp", "ef_nf.txt")
open(p, "w").write("hello\n")
res = await _do_edit_file(json.dumps({"path": p, "old_string": "nope", "new_string": "x"}))
res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "nope", "new_string": "x"}), {})
assert res["exit_code"] == 1 and "not found" in res["error"]
os.unlink(p)
@@ -80,15 +80,15 @@ async def test_edit_file_not_found():
async def test_edit_file_non_unique():
p = os.path.join("/tmp", "ef_dup.txt")
open(p, "w").write("x\nx\n")
res = await _do_edit_file(json.dumps({"path": p, "old_string": "x", "new_string": "y"}))
res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "x", "new_string": "y"}), {})
assert res["exit_code"] == 1 and "not unique" in res["error"]
# replace_all resolves it
res = await _do_edit_file(json.dumps({"path": p, "old_string": "x", "new_string": "y", "replace_all": True}))
res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "x", "new_string": "y", "replace_all": True}), {})
assert res["exit_code"] == 0 and open(p).read() == "y\ny\n"
os.unlink(p)
@pytest.mark.asyncio
async def test_edit_file_outside_allowed_roots():
res = await _do_edit_file(json.dumps({"path": "/etc/hosts", "old_string": "x", "new_string": "y"}))
res = await EditFileTool().execute(json.dumps({"path": "/etc/hosts", "old_string": "x", "new_string": "y"}), {})
assert res["exit_code"] == 1 and ("outside the allowed roots" in res["error"] or "sensitive" in res["error"])
+30 -14
@@ -1,22 +1,38 @@
import sys
from unittest.mock import MagicMock
# Clean up any mocks from previous tests to ensure we load real modules
for mod in ['src.agent_tools', 'src.tool_parsing', 'src.tool_schemas', 'src.tool_execution']:
sys.modules.pop(mod, None)
# This module needs the real agent-tool stack; importing it pulls in heavy
# DB/auth deps, so we stub those just long enough to import, then restore them.
# We deliberately do NOT pop src.tool_execution: popping and re-importing it
# rebinds the `src` package's `tool_execution` attribute, so a later
# `import src.tool_execution as te` resolves to a different module object than
# the one its functions live in - which silently breaks tests that monkeypatch
# it (e.g. test_edit_file's admin gate).
_ABSENT = object()
_AGENT_MODULES = ["src.agent_tools", "src.tool_parsing", "src.tool_schemas"]
_STUBBED = [
"sqlalchemy", "sqlalchemy.orm", "sqlalchemy.ext", "sqlalchemy.ext.declarative",
"sqlalchemy.ext.hybrid", "sqlalchemy.sql", "sqlalchemy.sql.expression",
"src.database", "core.models", "core.database", "core.auth",
]
_saved_stubs = {name: sys.modules.get(name, _ABSENT) for name in _STUBBED}
# Mock heavy database/model dependencies before importing
for mod in [
'sqlalchemy', 'sqlalchemy.orm', 'sqlalchemy.ext', 'sqlalchemy.ext.declarative',
'sqlalchemy.ext.hybrid', 'sqlalchemy.sql', 'sqlalchemy.sql.expression',
'src.database', 'core.models', 'core.database', 'core.auth'
]:
if mod not in sys.modules:
sys.modules[mod] = MagicMock()
for _mod in _AGENT_MODULES:
sys.modules.pop(_mod, None)
for _mod in _STUBBED:
if _mod not in sys.modules:
sys.modules[_mod] = MagicMock()
import pytest
import src.agent_tools # noqa: F401
from src.tool_schemas import function_call_to_tool_block
import pytest # noqa: E402
import src.agent_tools # noqa: E402,F401
from src.tool_schemas import function_call_to_tool_block # noqa: E402
# Drop the stubs we installed so they do not leak into later tests.
for _name, _original in _saved_stubs.items():
if _original is _ABSENT:
sys.modules.pop(_name, None)
else:
sys.modules[_name] = _original
@pytest.mark.parametrize("arguments", [
+6 -3
@@ -40,9 +40,12 @@ def test_upload_validates_target_album_ownership():
def test_list_albums_count_and_cover_are_owner_scoped():
fns = _function_sources()
body = fns["list_albums"]
# Both the per-album image count and the cover-fallback query must owner-scope
# by GalleryImage.owner (the album list itself already filters by owner).
assert body.count("GalleryImage.owner == user") >= 2
# The album list, per-album image count, explicit cover, and cover-fallback
# queries should all share the same gallery owner policy.
assert "q = _owner_filter(q, user, GalleryAlbum)" in body
assert "_count_q = _owner_filter(_count_q, user)" in body
assert "cover = _owner_filter(cover_q, user).first()" in body
assert "_cover_q = _owner_filter(_cover_q, user)" in body
def test_delete_album_cleanup_is_owner_scoped():
+149
@@ -0,0 +1,149 @@
import uuid
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
import core.database as cdb
from core.database import GalleryAlbum, GalleryImage
import routes.gallery_routes as gallery_routes
def _client_with_gallery(monkeypatch, tmp_path):
engine = create_engine(
f"sqlite:///{tmp_path / 'gallery.db'}",
connect_args={"check_same_thread": False},
poolclass=NullPool,
)
cdb.Base.metadata.create_all(engine)
session_factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
monkeypatch.setattr(gallery_routes, "SessionLocal", session_factory)
db = session_factory()
try:
db.add_all(
[
GalleryAlbum(id="album-alice", name="Alice album", owner="alice"),
GalleryAlbum(id="album-bob", name="Bob album", owner="bob"),
GalleryImage(
id="img-alice",
filename=f"{uuid.uuid4().hex}.png",
prompt="alice prompt",
model="model-a",
tags="alice-tag",
ai_tags="",
owner="alice",
album_id="album-alice",
is_active=True,
file_size=10,
),
GalleryImage(
id="img-bob",
filename=f"{uuid.uuid4().hex}.png",
prompt="bob prompt",
model="model-b",
tags="bob-tag",
ai_tags="",
owner="bob",
album_id="album-bob",
is_active=True,
file_size=20,
),
]
)
db.commit()
finally:
db.close()
app = FastAPI()
app.include_router(gallery_routes.setup_gallery_routes())
return TestClient(app)
def test_auth_enabled_null_user_gallery_routes_fail_closed(monkeypatch, tmp_path):
monkeypatch.setenv("AUTH_ENABLED", "true")
client = _client_with_gallery(monkeypatch, tmp_path)
library = client.get("/api/gallery/library").json()
assert library["items"] == []
assert library["total"] == 0
assert library["total_tagged"] == 0
assert library["tags"] == []
assert library["models"] == []
shuffled = client.get("/api/gallery/library", params={"sort": "shuffle"}).json()
assert shuffled["items"] == []
assert shuffled["total"] == 0
assert client.get("/api/gallery/tags").json() == {"tags": []}
assert client.get("/api/gallery/albums").json() == {"albums": []}
assert client.get("/api/gallery/stats").json() == {
"total_photos": 0,
"total_size": 0,
"total_size_human": "0.0 B",
"favorites": 0,
"albums": 0,
}
assert client.post("/api/gallery/ai-tag-batch").json() == {
"ok": True,
"queued": 0,
"total_untagged": 0,
"image_ids": [],
}
def test_auth_disabled_null_user_gallery_routes_keep_single_user_mode(monkeypatch, tmp_path):
monkeypatch.setenv("AUTH_ENABLED", "false")
client = _client_with_gallery(monkeypatch, tmp_path)
library = client.get("/api/gallery/library").json()
assert {item["id"] for item in library["items"]} == {"img-alice", "img-bob"}
assert library["total"] == 2
assert library["tags"] == ["alice-tag", "bob-tag"]
assert library["models"] == ["model-a", "model-b"]
assert client.get("/api/gallery/tags").json() == {"tags": ["alice-tag", "bob-tag"]}
assert len(client.get("/api/gallery/albums").json()["albums"]) == 2
assert client.get("/api/gallery/stats").json() == {
"total_photos": 2,
"total_size": 30,
"total_size_human": "30.0 B",
"favorites": 0,
"albums": 2,
}
batch = client.post("/api/gallery/ai-tag-batch").json()
assert batch["ok"] is True
assert batch["queued"] == 2
assert batch["total_untagged"] == 2
assert set(batch["image_ids"]) == {"img-alice", "img-bob"}
def test_authenticated_gallery_routes_remain_owner_scoped(monkeypatch, tmp_path):
monkeypatch.setenv("AUTH_ENABLED", "true")
monkeypatch.setattr(gallery_routes, "get_current_user", lambda request: "alice")
client = _client_with_gallery(monkeypatch, tmp_path)
library = client.get("/api/gallery/library").json()
assert [item["id"] for item in library["items"]] == ["img-alice"]
assert library["total"] == 1
assert library["tags"] == ["alice-tag"]
assert library["models"] == ["model-a"]
assert client.get("/api/gallery/tags").json() == {"tags": ["alice-tag"]}
albums = client.get("/api/gallery/albums").json()["albums"]
assert [album["id"] for album in albums] == ["album-alice"]
assert client.get("/api/gallery/stats").json() == {
"total_photos": 1,
"total_size": 10,
"total_size_human": "10.0 B",
"favorites": 0,
"albums": 1,
}
assert client.post("/api/gallery/ai-tag-batch").json() == {
"ok": True,
"queued": 1,
"total_untagged": 1,
"image_ids": ["img-alice"],
}
+16 -8
@@ -1,11 +1,8 @@
"""_owner_filter must not blank out the gallery in single-user mode.
"""_owner_filter must separate single-user mode from anonymous callers.
When AUTH_ENABLED=false, get_current_user returns None. The gallery main
list and stats treat None as "show all images" (`if user is not None`), but
_owner_filter returned q.filter(False) (zero rows) for None. So the tag and
model filter chips were always empty and clear-user-tags / clear-ai-tags /
dedupe-tags silently no-oped. _owner_filter must match the main list: no
filter when user is None, owner-scoped otherwise.
When AUTH_ENABLED=false, get_current_user returns None and gallery routes should
stay all-visible. When AUTH_ENABLED=true and no current user resolves, the same
None means an anonymous caller and gallery queries must fail closed.
"""
import tempfile
import uuid
@@ -36,7 +33,8 @@ def _seed(*owners):
db.close()
def test_none_user_returns_all_rows():
def test_none_user_returns_all_rows(monkeypatch):
monkeypatch.setenv("AUTH_ENABLED", "false")
_seed(None, None, "alice")
db = _TS()
try:
@@ -54,3 +52,13 @@ def test_named_user_is_still_scoped():
assert _owner_filter(db.query(GalleryImage), "bob").count() == 1
finally:
db.close()
def test_none_user_blocks_when_auth_is_enabled(monkeypatch):
monkeypatch.setenv("AUTH_ENABLED", "true")
_seed(None, "alice", "bob")
db = _TS()
try:
assert _owner_filter(db.query(GalleryImage), None).count() == 0
finally:
db.close()
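Reviewer note: the three cases these tests pin down (named user, anonymous caller with auth on, single-user mode with auth off) can be summarised in a few lines. The sketch below filters plain dicts and not a SQLAlchemy query, and it is not the implementation in `routes/gallery_routes.py`, which this diff does not show. How `AUTH_ENABLED` is parsed, and what an unset value means, are assumptions.

```python
import os
from typing import Iterable, Optional


def auth_enabled() -> bool:
    # Assumption: the flag is read from AUTH_ENABLED and "true" enables it;
    # the project's real parsing may accept other spellings or defaults.
    return os.environ.get("AUTH_ENABLED", "").strip().lower() == "true"


def owner_filter(rows: Iterable[dict], user: Optional[str]) -> list[dict]:
    """Apply the owner policy described in the module docstring to ``rows``."""
    rows = list(rows)
    if user is not None:
        # Named user: owner-scoped, regardless of the auth flag.
        return [row for row in rows if row["owner"] == user]
    if auth_enabled():
        # Anonymous caller while auth is on: fail closed, return nothing.
        return []
    # Auth off: single-user mode, everything stays visible.
    return rows
```

The key point is that `None` alone is ambiguous; the auth flag decides whether it means "single user" or "nobody".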
+6 -3
@@ -75,7 +75,10 @@ def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
assert payload["temperature"] == 1.2
def test_chatgpt_subscription_payload_uses_max_output_tokens():
def test_chatgpt_subscription_payload_omits_max_output_tokens():
# ChatGPT Subscription Codex API does not support max_output_tokens —
# passing it returns HTTP 400 "Unsupported parameter: max_output_tokens".
# The payload should NOT include max_output_tokens regardless of max_tokens.
payload = llm_core._build_chatgpt_responses_payload(
"gpt-5.1-codex",
[{"role": "user", "content": "Say OK"}],
@@ -83,10 +86,10 @@ def test_chatgpt_subscription_payload_uses_max_output_tokens():
max_tokens=37,
)
assert payload["max_output_tokens"] == 37
assert "max_output_tokens" not in payload
def test_chatgpt_subscription_payload_omits_empty_max_output_tokens():
def test_chatgpt_subscription_payload_omits_max_output_tokens_when_zero():
payload = llm_core._build_chatgpt_responses_payload(
"gpt-5.1-codex",
[{"role": "user", "content": "Say OK"}],
+11 -2
@@ -153,11 +153,20 @@ def test_document_owner_filter_applies_owner_clause():
# gallery._owner_filter
# ---------------------------------------------------------------------------
-def test_gallery_owner_filter_allows_single_user_mode():
+def test_gallery_owner_filter_blocks_anonymous(monkeypatch):
+monkeypatch.setenv("AUTH_ENABLED", "true")
from routes.gallery_routes import _owner_filter
fake_q = MagicMock()
out = _owner_filter(fake_q, user=None)
+fake_q.filter.assert_called_once_with(False)
+assert out is fake_q.filter.return_value
+def test_gallery_owner_filter_allows_single_user_mode(monkeypatch):
+monkeypatch.setenv("AUTH_ENABLED", "false")
+from routes.gallery_routes import _owner_filter
+fake_q = MagicMock()
+out = _owner_filter(fake_q, user=None)
# user=None means single-user/auth-disabled mode: return q unchanged, no filter.
fake_q.filter.assert_not_called()
assert out is fake_q
+16 -4
@@ -15,7 +15,6 @@ import uuid
import pytest
import core.database as cdb
-from core.database import Session as DbSession
from core.models import ChatMessage
from tests.helpers.sqlite_db import make_temp_sqlite
@@ -34,9 +33,9 @@ def manager(monkeypatch):
def _make_session(sid, owner="alice"):
db = _TS()
try:
-db.add(DbSession(id=sid, owner=owner, name="chat", model="gpt-4o",
-endpoint_url="http://localhost:11434",
-archived=False, message_count=1))
+db.add(cdb.Session(id=sid, owner=owner, name="chat", model="gpt-4o",
+endpoint_url="http://localhost:11434",
+archived=False, message_count=1))
db.commit()
finally:
db.close()
@@ -69,3 +68,16 @@ def test_plain_string_content_still_round_trips(manager):
manager.sessions.clear()
reloaded = manager.get_session(sid)
assert reloaded.history[0].content == "just text"
+def test_replace_messages_keeps_history_alias_for_context_messages(manager):
+sid = "sess-" + uuid.uuid4().hex[:8]
+_make_session(sid)
+msgs = [ChatMessage(role="user", content="original")]
+assert manager.replace_messages(sid, msgs) is True
+session = manager.sessions[sid]
+assert session.history is session._history
+session.history.append(ChatMessage(role="user", content="after direct mutation"))
+assert session.get_context_messages()[-1]["content"] == "after direct mutation"
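One way to satisfy the alias invariant checked above (`session.history is session._history`) is to expose `history` as a property over a single backing list. This is only a sketch of that idea: the `AliasedSession` class and its dict-based messages are invented for illustration, and the source does not show how the real `Session` implements it.

```python
class AliasedSession:
    """Illustrative session whose public `history` aliases one private list."""

    def __init__(self):
        self._history = []

    @property
    def history(self):
        # Hand back the backing list itself, not a copy, so direct
        # mutation through `.history` is visible to context reads.
        return self._history

    @history.setter
    def history(self, value):
        # Reassignment swaps the backing list rather than shadowing it.
        self._history = value

    def replace_messages(self, messages):
        self._history = list(messages)
        return True

    def get_context_messages(self):
        return [dict(message) for message in self._history]
```

Because there is only one list, neither an append through `.history` nor a wholesale reassignment can leave context reads looking at stale data.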
+353
@@ -0,0 +1,353 @@
"""Direct tests for the focused test-selection runner (tests/run_focus.py).
Command construction is tested separately from process execution: the pure
builder functions are asserted directly, and ``run`` is exercised with an
injected fake executor so no pytest subprocess is ever spawned.
"""
from __future__ import annotations
import argparse
import sys
import pytest
from tests.run_focus import (
FocusSelection,
build_marker_expression,
build_pytest_command,
discover_sub_areas,
normalize_sub_area,
run,
)
PY = "PY" # placeholder interpreter for deterministic command assertions
def _cmd(**kwargs) -> list[str]:
"""Build a pytest command for a FocusSelection made from kwargs."""
return build_pytest_command(FocusSelection(**kwargs), python=PY)
# --- marker expression building -------------------------------------------
def test_area_only_marker_expression():
assert build_marker_expression("security", None) == "area_security"
def test_sub_area_only_marker_expression():
assert build_marker_expression(None, "cookbook") == "sub_cookbook"
def test_area_and_sub_area_marker_expression():
assert build_marker_expression("services", "cookbook") == "area_services and sub_cookbook"
def test_no_selection_marker_expression_is_none():
assert build_marker_expression(None, None) is None
def test_fast_only_marker_expression():
assert build_marker_expression(None, None, fast=True) == "not slow"
def test_fast_composes_with_area():
assert build_marker_expression("services", None, fast=True) == "area_services and not slow"
def test_fast_composes_with_area_and_sub_area():
assert (
build_marker_expression("services", "cookbook", fast=True)
== "area_services and sub_cookbook and not slow"
)
# --- command construction --------------------------------------------------
def test_area_only_command():
assert _cmd(area="security") == [PY, "-m", "pytest", "-m", "area_security"]
def test_sub_area_only_command():
assert _cmd(sub_area="cookbook") == [PY, "-m", "pytest", "-m", "sub_cookbook"]
def test_area_and_sub_area_command():
assert _cmd(area="services", sub_area="cookbook") == [
PY, "-m", "pytest", "-m", "area_services and sub_cookbook",
]
def test_keyword_only_command():
assert _cmd(keyword="taxonomy") == [PY, "-m", "pytest", "-k", "taxonomy"]
def test_area_and_keyword_command():
assert _cmd(area="services", keyword="cookbook") == [
PY, "-m", "pytest", "-m", "area_services", "-k", "cookbook",
]
def test_passthrough_pytest_args_appended_last():
command = _cmd(area="services", pytest_args=("--maxfail=1", "-q"))
assert command == [PY, "-m", "pytest", "-m", "area_services", "--maxfail=1", "-q"]
def test_last_failed_appends_safe_flags():
assert _cmd(last_failed=True) == [
PY,
"-m",
"pytest",
"--last-failed",
"--last-failed-no-failures=none",
]
def test_default_python_is_current_interpreter():
command = build_pytest_command(FocusSelection(area="cli"))
assert command[0] == sys.executable
# --- fast lane and duration visibility -------------------------------------
def test_fast_only_command():
assert _cmd(fast=True) == [PY, "-m", "pytest", "-m", "not slow"]
def test_fast_with_area_command():
assert _cmd(area="services", fast=True) == [
PY, "-m", "pytest", "-m", "area_services and not slow",
]
def test_fast_with_area_and_sub_area_command():
assert _cmd(area="services", sub_area="cookbook", fast=True) == [
PY, "-m", "pytest", "-m", "area_services and sub_cookbook and not slow",
]
def test_durations_appends_flag():
assert _cmd(fast=True, durations=25) == [
PY, "-m", "pytest", "-m", "not slow", "--durations=25",
]
def test_durations_min_appends_flag():
assert _cmd(fast=True, durations=25, durations_min=0.05) == [
PY, "-m", "pytest", "-m", "not slow", "--durations=25", "--durations-min=0.05",
]
def test_durations_is_not_a_focus_selector():
assert FocusSelection(durations=25).has_focus is False
assert FocusSelection(fast=True).has_focus is True
def test_durations_kept_before_passthrough_args():
command = _cmd(fast=True, durations=25, pytest_args=("-q",))
assert command == [PY, "-m", "pytest", "-m", "not slow", "--durations=25", "-q"]
# --- sub-area normalization ------------------------------------------------
def test_normalize_sub_area_lowercases_and_collapses():
assert normalize_sub_area("Cook Book") == "cook_book"
def test_normalize_sub_area_strips_separators():
assert normalize_sub_area("--owner.scope--") == "owner_scope"
def test_normalize_sub_area_removes_marker_prefix():
assert normalize_sub_area("sub_cookbook") == "cookbook"
def test_normalize_sub_area_rejects_empty_after_normalization():
with pytest.raises(argparse.ArgumentTypeError):
normalize_sub_area("!!!")
def test_discover_sub_areas_from_test_filename(tmp_path):
(tmp_path / "test_cookbook_helpers.py").write_text("", encoding="utf-8")
assert discover_sub_areas(tmp_path) == frozenset({"cookbook"})
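The normalization cases above pin the behavior down closely enough to sketch it. The function below is a reconstruction from those assertions only; the regular expression and the error message are assumptions, and the real `normalize_sub_area` in tests/run_focus.py may be written differently.

```python
import argparse
import re


def normalize_sub_area(value: str) -> str:
    # Lowercase, collapse each run of non-alphanumerics into one underscore,
    # then trim underscores left over at either end.
    name = re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_")
    # Accept the marker spelling too: "sub_cookbook" -> "cookbook".
    if name.startswith("sub_"):
        name = name[len("sub_"):]
    if not name:
        # ArgumentTypeError lets argparse report this as a usage error.
        raise argparse.ArgumentTypeError(f"invalid sub-area: {value!r}")
    return name
```

Raising `argparse.ArgumentTypeError` is what lets the function double as an argparse `type=` callable, which fits the exit code 2 seen in the validation tests.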
# --- run(): dry-run, execution, validation ---------------------------------
class _FakeExecutor:
"""Records the command it was asked to run and returns a fixed code."""
def __init__(self, returncode: int = 0):
self.returncode = returncode
self.calls: list[list[str]] = []
def __call__(self, command: list[str]) -> int:
self.calls.append(command)
return self.returncode
def test_dry_run_prints_command_and_does_not_execute(capsys):
executor = _FakeExecutor()
code = run(
["--dry-run", "--area", "services", "--sub-area", "cookbook"],
executor=executor,
)
out = capsys.readouterr().out
assert code == 0
assert executor.calls == []
assert out == (
f"{sys.executable} -m pytest "
"-m 'area_services and sub_cookbook'\n"
)
def test_dry_run_last_failed_prints_safe_flags(capsys):
executor = _FakeExecutor()
code = run(["--dry-run", "--last-failed"], executor=executor)
out = capsys.readouterr().out
assert code == 0
assert executor.calls == []
assert out == (
f"{sys.executable} -m pytest "
"--last-failed --last-failed-no-failures=none\n"
)
def test_run_invokes_executor_with_built_command():
executor = _FakeExecutor(returncode=3)
code = run(["--keyword", "taxonomy", "--", "--maxfail=1"], executor=executor)
assert code == 3
assert executor.calls == [[sys.executable, "-m", "pytest", "-k", "taxonomy", "--maxfail=1"]]
def test_run_last_failed_only():
executor = _FakeExecutor()
run(["--last-failed"], executor=executor)
assert executor.calls == [[
sys.executable,
"-m",
"pytest",
"--last-failed",
"--last-failed-no-failures=none",
]]
@pytest.mark.parametrize("value", ["cookbook", "sub_cookbook"])
def test_run_accepts_both_sub_area_forms(value):
executor = _FakeExecutor()
run(["--sub-area", value], executor=executor)
assert executor.calls == [[
sys.executable,
"-m",
"pytest",
"-m",
"sub_cookbook",
]]
def test_invalid_area_exits_with_error():
with pytest.raises(SystemExit) as excinfo:
run(["--area", "bogus"], executor=_FakeExecutor())
assert excinfo.value.code == 2
def test_invalid_sub_area_exits_with_error(capsys):
with pytest.raises(SystemExit) as excinfo:
run(
["--sub-area", "definitely_not_a_real_sub_area"],
executor=_FakeExecutor(),
)
assert excinfo.value.code == 2
assert "unknown sub-area" in capsys.readouterr().err
def test_no_focus_selector_is_rejected():
executor = _FakeExecutor()
with pytest.raises(SystemExit) as excinfo:
run(["--", "-q"], executor=executor)
assert excinfo.value.code == 2
assert executor.calls == []
def test_fast_run_invokes_executor_with_not_slow():
executor = _FakeExecutor()
run(["--fast"], executor=executor)
assert executor.calls == [[sys.executable, "-m", "pytest", "-m", "not slow"]]
def test_fast_with_durations_run_invokes_executor():
executor = _FakeExecutor()
run(["--area", "services", "--fast", "--durations", "25"], executor=executor)
assert executor.calls == [[
sys.executable,
"-m",
"pytest",
"-m",
"area_services and not slow",
"--durations=25",
]]
def test_fast_durations_dry_run_prints_command(capsys):
executor = _FakeExecutor()
code = run(["--dry-run", "--fast", "--durations", "25"], executor=executor)
out = capsys.readouterr().out
assert code == 0
assert executor.calls == []
assert out == f"{sys.executable} -m pytest -m 'not slow' --durations=25\n"
def test_durations_alone_is_rejected_before_executor():
executor = _FakeExecutor()
with pytest.raises(SystemExit) as excinfo:
run(["--durations", "25"], executor=executor)
assert excinfo.value.code == 2
assert executor.calls == []
def test_durations_zero_is_allowed_means_show_all():
executor = _FakeExecutor()
run(["--fast", "--durations", "0"], executor=executor)
assert executor.calls == [[
sys.executable, "-m", "pytest", "-m", "not slow", "--durations=0",
]]
@pytest.mark.parametrize("flag,value", [("--durations", "-1"), ("--durations-min", "-0.5")])
def test_negative_duration_values_are_rejected(flag, value):
executor = _FakeExecutor()
with pytest.raises(SystemExit) as excinfo:
run(["--fast", flag, value], executor=executor)
assert excinfo.value.code == 2
assert executor.calls == []
@pytest.mark.parametrize("argv", [
["--fast", "--durations-min", "0.05"],
["--area", "services", "--durations-min", "0.05"],
])
def test_durations_min_without_durations_is_rejected(argv):
executor = _FakeExecutor()
with pytest.raises(SystemExit) as excinfo:
run(argv, executor=executor)
assert excinfo.value.code == 2
assert executor.calls == []
def test_durations_min_with_durations_is_allowed():
executor = _FakeExecutor()
run(["--fast", "--durations", "25", "--durations-min", "0.05"], executor=executor)
assert executor.calls == [[
sys.executable,
"-m",
"pytest",
"-m",
"not slow",
"--durations=25",
"--durations-min=0.05",
]]
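Taken together, the command-construction tests in this file describe a pure builder: marker parts are joined with `and`, then keyword, last-failed, and duration flags follow in a fixed order, with passthrough arguments last. The sketch below is reconstructed from those assertions; the field defaults and the internal structure are assumptions, not the contents of tests/run_focus.py.

```python
import sys
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FocusSelection:
    area: Optional[str] = None
    sub_area: Optional[str] = None
    keyword: Optional[str] = None
    last_failed: bool = False
    fast: bool = False
    durations: Optional[int] = None
    durations_min: Optional[float] = None
    pytest_args: tuple = ()

    @property
    def has_focus(self) -> bool:
        # Duration flags only add reporting, so they do not count as focus.
        return bool(self.area or self.sub_area or self.keyword
                    or self.last_failed or self.fast)


def build_marker_expression(area, sub_area, fast=False):
    parts = []
    if area:
        parts.append(f"area_{area}")
    if sub_area:
        parts.append(f"sub_{sub_area}")
    if fast:
        parts.append("not slow")
    return " and ".join(parts) or None


def build_pytest_command(selection, python=None):
    command = [python or sys.executable, "-m", "pytest"]
    marker = build_marker_expression(
        selection.area, selection.sub_area, selection.fast)
    if marker:
        command += ["-m", marker]
    if selection.keyword:
        command += ["-k", selection.keyword]
    if selection.last_failed:
        command += ["--last-failed", "--last-failed-no-failures=none"]
    # "is not None" keeps --durations=0 (show all) from being dropped.
    if selection.durations is not None:
        command.append(f"--durations={selection.durations}")
    if selection.durations_min is not None:
        command.append(f"--durations-min={selection.durations_min}")
    return command + list(selection.pytest_args)
```

Returning a list instead of running it is what makes the fake-executor tests possible: the command can be asserted exactly, and process execution is injected separately.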
+472
@@ -0,0 +1,472 @@
"""Tests for src.service_health — the consolidated degraded-state report.
Imports the real module (conftest.py stubs the heavy deps). Network is never
touched: HTTP probes take an injected `http_get`, and the email/provider probes
take an injected `connect` / `probe`. Asserts the ok/degraded/down/disabled
mapping per subsystem, the overall rollup, and that no secrets leak into meta.
"""
import types
import pytest
from src import service_health as sh
def _resp(status_code):
return types.SimpleNamespace(status_code=status_code)
def _raise(*_a, **_k):
raise RuntimeError("connection refused")
# ── chromadb_health ──
class _Store:
def __init__(self, healthy):
self.healthy = healthy
def test_chromadb_both_healthy_ok():
s = sh.chromadb_health(_Store(True), _Store(True))
assert s["status"] == sh.OK
assert s["meta"] == {"rag": True, "memory": True}
def test_chromadb_one_down_degraded():
s = sh.chromadb_health(_Store(True), _Store(False))
assert s["status"] == sh.DEGRADED
def test_chromadb_both_unhealthy_down():
s = sh.chromadb_health(_Store(False), _Store(False))
assert s["status"] == sh.DOWN
def test_chromadb_both_absent_disabled():
s = sh.chromadb_health(None, None)
assert s["status"] == sh.DISABLED
def test_chromadb_one_absent_one_healthy_ok():
# An absent store is not a failure; the present one being healthy is ok.
s = sh.chromadb_health(_Store(True), None)
assert s["status"] == sh.OK
assert s["meta"]["memory"] is None
# ── searxng_health ──
def test_searxng_disabled_when_other_provider():
s = sh.searxng_health({"search_provider": "brave"})
assert s["status"] == sh.DISABLED
def test_searxng_ok_on_healthz():
s = sh.searxng_health(
{"search_provider": "searxng", "search_url": "http://sx:8080"},
http_get=lambda url, timeout: _resp(200),
)
assert s["status"] == sh.OK
assert s["meta"]["probed"] == "/healthz"
def test_searxng_ok_on_root_fallback():
def getter(url, timeout):
return _resp(404) if url.endswith("/healthz") else _resp(200)
s = sh.searxng_health(
{"search_provider": "searxng", "search_url": "http://sx:8080"},
http_get=getter,
)
assert s["status"] == sh.OK
assert s["meta"]["probed"] == "/"
def test_searxng_down_on_exception():
s = sh.searxng_health(
{"search_provider": "searxng", "search_url": "http://sx:8080"},
http_get=_raise,
)
assert s["status"] == sh.DOWN
def test_searxng_down_on_5xx():
s = sh.searxng_health(
{"search_provider": "searxng", "search_url": "http://sx:8080"},
http_get=lambda url, timeout: _resp(502),
)
assert s["status"] == sh.DOWN
# ── ntfy_health ──
def _ntfy_intg():
return [{"preset": "ntfy", "enabled": True, "base_url": "http://ntfy:80"}]
def test_ntfy_disabled_without_integration():
s = sh.ntfy_health([], {"reminder_channel": "ntfy"})
assert s["status"] == sh.DISABLED
def test_ntfy_ok():
s = sh.ntfy_health(_ntfy_intg(), {"reminder_channel": "ntfy"},
http_get=lambda url, timeout: _resp(200))
assert s["status"] == sh.OK
assert s["meta"]["base"] == "http://ntfy:80"
def test_ntfy_probes_v1_health_not_a_topic():
seen = {}
def getter(url, timeout):
seen["url"] = url
return _resp(200)
sh.ntfy_health(_ntfy_intg(), {"reminder_channel": "ntfy"}, http_get=getter)
# Non-intrusive: hits /v1/health, never publishes to a topic.
assert seen["url"].endswith("/v1/health")
def test_ntfy_down_on_exception():
s = sh.ntfy_health(_ntfy_intg(), {"reminder_channel": "ntfy"},
http_get=_raise)
assert s["status"] == sh.DOWN
# ── email_health ──
def _acct(name, host="imap.example.com"):
return {"account_id": name, "account_name": name, "imap_host": host,
"imap_password": "hunter2"}
class _Conn:
def logout(self):
pass
def test_email_disabled_without_accounts():
assert sh.email_health([])["status"] == sh.DISABLED
def test_email_ok_all_connect():
s = sh.email_health([_acct("a"), _acct("b")], connect=lambda _id: _Conn())
assert s["status"] == sh.OK
def test_email_degraded_some_fail():
def connect(account_id):
if account_id == "bad":
raise RuntimeError("auth failed")
return _Conn()
s = sh.email_health([_acct("good"), _acct("bad")], connect=connect)
assert s["status"] == sh.DEGRADED
def test_email_down_all_fail():
s = sh.email_health([_acct("a")], connect=_raise)
assert s["status"] == sh.DOWN
def test_email_account_without_host_marked_failed():
s = sh.email_health([_acct("a", host="")], connect=lambda _id: _Conn())
assert s["status"] == sh.DOWN
def test_email_meta_never_leaks_password():
s = sh.email_health([_acct("a")], connect=lambda _id: _Conn())
assert "hunter2" not in repr(s)
# ── providers_health ──
def _ep(name):
return {"name": name, "base_url": f"http://{name}:8000/v1", "api_key": "sk-secret"}
def test_providers_disabled_without_endpoints():
assert sh.providers_health([])["status"] == sh.DISABLED
def test_providers_ok_all_reachable():
s = sh.providers_health([_ep("a")],
probe=lambda base, key, timeout: ["m1", "m2"])
assert s["status"] == sh.OK
assert s["meta"]["endpoints"][0]["model_count"] == 2
def test_providers_degraded_some_empty():
def probe(base, key, timeout):
return ["m1"] if "good" in base else []
s = sh.providers_health([_ep("good"), _ep("bad")], probe=probe)
assert s["status"] == sh.DEGRADED
def test_providers_down_all_fail():
s = sh.providers_health([_ep("a")], probe=_raise)
assert s["status"] == sh.DOWN
def test_providers_meta_never_leaks_api_key():
s = sh.providers_health([_ep("a")],
probe=lambda base, key, timeout: ["m1"])
assert "sk-secret" not in repr(s)
# ── rollup ──
def test_rollup_picks_worst_non_disabled():
services = [
{"status": sh.OK}, {"status": sh.DISABLED},
{"status": sh.DEGRADED}, {"status": sh.OK},
]
assert sh._rollup(services) == sh.DEGRADED
def test_rollup_down_beats_degraded():
assert sh._rollup([{"status": sh.DEGRADED}, {"status": sh.DOWN}]) == sh.DOWN
def test_rollup_all_disabled_is_ok():
assert sh._rollup([{"status": sh.DISABLED}, {"status": sh.DISABLED}]) == sh.OK
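The three rollup tests above amount to "worst status among the non-disabled services, or ok if nothing is enabled". A minimal sketch of that rule follows; the string values of the status constants and the `_SEVERITY` table are assumptions, since the diff only shows the names `sh.OK`, `sh.DEGRADED`, `sh.DOWN`, and `sh.DISABLED`.

```python
# Assumed string values; the diff only shows the constant names.
OK, DEGRADED, DOWN, DISABLED = "ok", "degraded", "down", "disabled"

# Higher number means worse. DISABLED is left out: it never counts.
_SEVERITY = {OK: 0, DEGRADED: 1, DOWN: 2}


def rollup(services):
    """Return the worst status among services that are not disabled."""
    active = [s["status"] for s in services if s["status"] != DISABLED]
    if not active:
        # Nothing is enabled, so nothing can be unhealthy.
        return OK
    return max(active, key=_SEVERITY.__getitem__)
```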
# ── collect_service_health (async aggregate) ──
def test_collect_service_health_shape(monkeypatch):
import asyncio
# Avoid touching real data sources / network.
monkeypatch.setattr(sh, "_gather_inputs", lambda: {
"settings": {"search_provider": "disabled"},
"integrations": [],
"accounts": [],
"endpoints": [],
})
out = asyncio.run(sh.collect_service_health(_Store(True), _Store(True)))
assert set(out) == {"overall", "services", "timestamp"}
names = {s["name"] for s in out["services"]}
assert names == {"chromadb", "searxng", "ntfy", "email", "providers"}
# Chroma healthy, everything else disabled → overall ok.
assert out["overall"] == sh.OK
# ── _safe_url: strip userinfo / query / fragment ──
@pytest.mark.parametrize("raw,expected", [
("http://user:pass@host:8080/path?api_key=secret#frag", "http://host:8080/path"),
("https://admin:hunter2@searx.example.com/", "https://searx.example.com"),
("http://ntfy.local:80?token=abc", "http://ntfy.local:80"),
("host:8080", "host:8080"),
("", ""),
(None, ""),
])
def test_safe_url_strips_secrets(raw, expected):
out = sh._safe_url(raw)
assert out == expected
for bad in ("pass", "secret", "hunter2", "abc", "token", "@"):
if raw and bad in raw and bad not in expected:
assert bad not in out
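The parametrized cases above can be met with `urllib.parse`, rebuilding the URL from scheme, host, port, and path only. This is a sketch under assumptions: `safe_url` is an illustrative name, and the hand-rolled branch for scheme-less inputs such as `host:8080` is a choice made here (recent Python versions would otherwise parse `host` as the scheme), not necessarily what `_safe_url` does.

```python
from urllib.parse import urlsplit, urlunsplit


def safe_url(raw):
    """Drop userinfo, query, and fragment so a URL is safe to report."""
    if not raw:
        return ""
    if "://" not in raw:
        # Scheme-less input: strip fragment, query, and userinfo by hand.
        rest = raw.split("#", 1)[0].split("?", 1)[0]
        return rest.rsplit("@", 1)[-1].rstrip("/")
    parts = urlsplit(raw)
    host = parts.hostname or ""
    if ":" in host:
        host = f"[{host}]"  # put brackets back around IPv6 literals
    netloc = host if parts.port is None else f"{host}:{parts.port}"
    # Empty query and fragment; the trailing slash is trimmed from the path.
    return urlunsplit((parts.scheme, netloc, parts.path.rstrip("/"), "", ""))
```

Rebuilding from known-safe components is an allow-list approach: anything not explicitly copied over, including credentials in unexpected positions, is dropped by default.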
# ── _classify_error: controlled categories, never raw text ──
def test_classify_error_categories():
import socket
assert sh._classify_error(TimeoutError()) == "timeout"
assert sh._classify_error(socket.timeout()) == "timeout"
assert sh._classify_error(socket.gaierror()) == "dns_error"
assert sh._classify_error(ConnectionRefusedError()) == "connection_refused"
assert sh._classify_error(OSError("boom")) == "network_error"
assert sh._classify_error(ValueError("x")) == "error"
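The category mapping asserted above depends on check order, because `TimeoutError`, `socket.gaierror`, and `ConnectionRefusedError` are all subclasses of `OSError`. A sketch of one ordering that produces those results (the name `classify_error` is illustrative; only the category strings come from the test):

```python
import socket


def classify_error(exc: BaseException) -> str:
    """Map an exception to a fixed category token, never its message."""
    # Most specific first: every class checked before OSError subclasses it.
    if isinstance(exc, (TimeoutError, socket.timeout)):
        return "timeout"
    if isinstance(exc, socket.gaierror):
        return "dns_error"
    if isinstance(exc, ConnectionRefusedError):
        return "connection_refused"
    if isinstance(exc, OSError):
        return "network_error"
    return "error"
```

Returning a token instead of `str(exc)` is what keeps URLs, passwords, and API keys embedded in exception messages out of the health report.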
# ── Sanitization in subsystem output (blocker #2) ──
def test_searxng_meta_redacts_instance_url():
s = sh.searxng_health(
{"search_provider": "searxng",
"search_url": "http://user:s3cr3t@searx.local:8080/?token=zzz"},
http_get=lambda url, timeout: _resp(200),
)
blob = repr(s)
assert "s3cr3t" not in blob and "zzz" not in blob and "user:" not in blob
assert s["meta"]["instance"] == "http://searx.local:8080"
def test_searxng_down_uses_error_category_not_raw_exception():
def boom(url, timeout):
raise RuntimeError("failed connecting to http://user:pw@searx.local secret-token")
s = sh.searxng_health(
{"search_provider": "searxng", "search_url": "http://searx.local"},
http_get=boom,
)
assert s["status"] == sh.DOWN
assert s["meta"]["error"] == "error" # controlled category token
assert "secret-token" not in repr(s) and "pw@" not in repr(s)
def test_ntfy_meta_redacts_userinfo_in_base():
intg = [{"preset": "ntfy", "enabled": True,
"base_url": "https://user:topsecret@ntfy.example.com"}]
seen = {}
def getter(url, timeout):
seen["url"] = url # the probe itself may keep credentials
return _resp(200)
s = sh.ntfy_health(intg, {"reminder_channel": "ntfy"}, http_get=getter)
assert s["meta"]["base"] == "https://ntfy.example.com"
assert "topsecret" not in repr(s)
def test_providers_name_fallback_is_sanitized():
# No display name → falls back to the base_url, which must be sanitized.
ep = {"base_url": "http://user:k3y@prov.local:9000/v1?api_key=zzz", "api_key": "sk-x"}
s = sh.providers_health([ep], probe=lambda b, k, t: ["m1"])
entry = s["meta"]["endpoints"][0]
assert entry["name"] == "http://prov.local:9000/v1"
assert "k3y" not in repr(s) and "zzz" not in repr(s) and "sk-x" not in repr(s)
def test_providers_probe_exception_maps_to_category():
def boom(base, key, timeout):
raise RuntimeError(f"500 from {base} with key {key}") # would leak base+key
s = sh.providers_health([_ep("a")], probe=boom)
assert s["status"] == sh.DOWN
assert s["meta"]["endpoints"][0]["error"] == "error"
assert "sk-secret" not in repr(s) and "http://a" not in repr(s)
def test_email_connect_exception_maps_to_category():
def boom(account_id):
raise RuntimeError("login failed for user bob with password hunter2")
s = sh.email_health([_acct("a")], connect=boom)
assert s["status"] == sh.DOWN
assert s["meta"]["accounts"][0]["error"] == "error"
assert "hunter2" not in repr(s)
# ── Bounded wall-clock (blocker #1) ──
def test_providers_bounded_marks_slow_as_timeout(monkeypatch):
import time
monkeypatch.setattr(sh, "_FANOUT_BUDGET", 1)
def probe(base, key, timeout):
if "slow" in base:
time.sleep(10) # would blow the budget if unbounded
return ["m1"]
eps = [{"name": "fast", "base_url": "http://fast", "api_key": "k"},
{"name": "slow", "base_url": "http://slow", "api_key": "k"}]
t0 = time.monotonic()
out = sh.providers_health(eps, probe=probe)
elapsed = time.monotonic() - t0
assert elapsed < 4, f"providers_health not bounded: took {elapsed:.1f}s"
by = {e["name"]: e for e in out["meta"]["endpoints"]}
assert by["fast"]["ok"] is True
assert by["slow"]["ok"] is False and by["slow"]["error"] == "timeout"
assert out["status"] == sh.DEGRADED
def test_providers_bounded_with_many_slow_endpoints(monkeypatch):
import time
monkeypatch.setattr(sh, "_FANOUT_BUDGET", 1)
def probe(base, key, timeout):
time.sleep(10)
return ["m1"]
eps = [{"name": f"ep{i}", "base_url": f"http://ep{i}", "api_key": "k"}
for i in range(25)]
t0 = time.monotonic()
out = sh.providers_health(eps, probe=probe)
elapsed = time.monotonic() - t0
# 25 endpoints * sleep would be huge if sequential; bounded keeps it ~budget.
assert elapsed < 4, f"not bounded with many endpoints: {elapsed:.1f}s"
assert out["status"] == sh.DOWN
assert all(e["error"] == "timeout" for e in out["meta"]["endpoints"])
def test_email_bounded_marks_slow_as_timeout(monkeypatch):
import time
monkeypatch.setattr(sh, "_FANOUT_BUDGET", 1)
def connect(account_id):
if account_id == "slow":
time.sleep(10)
return _Conn()
accts = [_acct("fast"), _acct("slow")]
accts[1]["account_id"] = "slow"
t0 = time.monotonic()
out = sh.email_health(accts, connect=connect)
elapsed = time.monotonic() - t0
assert elapsed < 4, f"email_health not bounded: took {elapsed:.1f}s"
by = {a["name"]: a for a in out["meta"]["accounts"]}
assert by["slow"]["error"] == "timeout"
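The bounded fan-out these tests describe can be approximated with a thread pool and a single shared deadline: submit every probe, wait up to the budget, and report whatever has not finished as a timeout. The sketch below assumes that design. `probe_all`, the result shape, and the worker cap of 16 are invented for illustration, and `cancel_futures` requires Python 3.9 or later.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait


def probe_all(items, probe, budget):
    """Run probe(item) for each item; unfinished work becomes a timeout."""
    if not items:
        return {}
    pool = ThreadPoolExecutor(max_workers=min(len(items), 16))
    futures = {pool.submit(probe, item): item for item in items}
    # One shared deadline for the whole batch, not one per item.
    done, _pending = wait(futures, timeout=budget)
    # Do not block on stragglers; queued probes that never started are dropped.
    pool.shutdown(wait=False, cancel_futures=True)
    results = {}
    for future, item in futures.items():
        if future not in done:
            results[item] = {"ok": False, "error": "timeout"}
        elif future.exception() is not None:
            results[item] = {"ok": False, "error": "error"}
        else:
            results[item] = {"ok": True, "value": future.result()}
    return results
```

One caveat of this approach: Python cannot kill a running thread, so a slow probe keeps running in the background after it has been reported as a timeout. The probe's own network timeout still matters.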
def test_collect_runs_subsystems_concurrently(monkeypatch):
# The aggregate is bounded by running the (internally-bounded) subsystems
# concurrently, so total wall-clock ≈ max(subsystem), not the sum. Each of
# the four network subsystems here sleeps ~0.6s; sequential would be ~2.4s.
import asyncio
import time
monkeypatch.setattr(sh, "_gather_inputs", lambda: {
"settings": {}, "integrations": [], "accounts": [], "endpoints": [],
})
def slow(name):
def _fn(*_a, **_k):
time.sleep(0.6)
return {"name": name, "status": sh.OK, "detail": "", "meta": {}}
return _fn
monkeypatch.setattr(sh, "searxng_health", slow("searxng"))
monkeypatch.setattr(sh, "ntfy_health", slow("ntfy"))
monkeypatch.setattr(sh, "email_health", slow("email"))
monkeypatch.setattr(sh, "providers_health", slow("providers"))
t0 = time.monotonic()
out = asyncio.run(sh.collect_service_health(None, None))
elapsed = time.monotonic() - t0
assert elapsed < 1.5, f"subsystems not concurrent: took {elapsed:.1f}s"
assert {s["name"] for s in out["services"]} == {
"chromadb", "searxng", "ntfy", "email", "providers"}
def test_collect_aggregate_deadline_yields_controlled_result(monkeypatch):
# If the gather overruns the aggregate ceiling, the response is still a
# controlled {overall, services, timestamp} with each network subsystem
# marked down/timeout — never a hang or a raised exception.
import asyncio
import time
monkeypatch.setattr(sh, "_AGGREGATE_DEADLINE", 0.5)
monkeypatch.setattr(sh, "_SUBSYSTEM_DEADLINE", 0.4)
monkeypatch.setattr(sh, "_gather_inputs", lambda: {
"settings": {}, "integrations": [], "accounts": [], "endpoints": [],
})
async def _slow_gather(*coros, **_k):
for c in coros: # close unawaited coros to avoid warnings
close = getattr(c, "close", None)
if close:
close()
await asyncio.sleep(5)
# Force the outer wait_for to trip by making gather itself slow.
monkeypatch.setattr(sh.asyncio, "gather", _slow_gather)
t0 = time.monotonic()
out = asyncio.run(sh.collect_service_health(None, None))
elapsed = time.monotonic() - t0
assert elapsed < 2, f"aggregate deadline did not bound: {elapsed:.1f}s"
assert set(out) == {"overall", "services", "timestamp"}
net = [s for s in out["services"] if s["name"] != "chromadb"]
assert all(s["status"] == sh.DOWN and s["meta"].get("error") == "timeout"
for s in net)
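The last two tests describe two properties of the aggregate: subsystems run concurrently, so wall-clock time tracks the slowest one, and an overall deadline converts an overrun into a controlled result instead of a hang. A minimal asyncio sketch of that pattern follows. `collect`, the async check callables, and the reduced result shape are assumptions; the real `collect_service_health` also returns `overall` and `timestamp`.

```python
import asyncio
import time


async def collect(checks, deadline):
    """checks maps a name to a zero-argument async callable returning a status."""

    async def run_one(name, check):
        return {"name": name, "status": await check(), "meta": {}}

    try:
        # gather runs the checks concurrently; wait_for caps the total time.
        services = await asyncio.wait_for(
            asyncio.gather(*(run_one(n, c) for n, c in checks.items())),
            timeout=deadline,
        )
    except asyncio.TimeoutError:
        # Controlled fallback: every check is reported as down/timeout.
        services = [
            {"name": name, "status": "down", "meta": {"error": "timeout"}}
            for name in checks
        ]
    return {"services": list(services)}
```

The sketch uses coroutines rather than threads on purpose: `wait_for` can cancel a coroutine, whereas work pushed to a thread would keep running and could delay `asyncio.run` at shutdown.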
+112
@@ -0,0 +1,112 @@
"""Integration tests: concurrent chat sessions must not leak.
These tests verify that the async streaming chat path maintains session
isolation even under concurrent access patterns.
"""
import asyncio
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from core.models import Session, ChatMessage
from core.session_manager import SessionManager
@pytest.mark.asyncio
async def test_concurrent_sessions_have_independent_history():
"""Simulating concurrent message adds to different sessions."""
sm = SessionManager()
sm.sessions = {} # Bypass DB load
s1 = Session(id="sess-a", name="Chat A", endpoint_url="http://ep", model="model-a")
s2 = Session(id="sess-b", name="Chat B", endpoint_url="http://ep", model="model-b")
sm.sessions["sess-a"] = s1
sm.sessions["sess-b"] = s2
async def add_to_session(sid, msgs):
sess = sm.sessions[sid]
for role, content in msgs:
sess.add_message(ChatMessage(role, content))
# Simulate concurrent adds
await asyncio.gather(
add_to_session("sess-a", [("user", "hello from A"), ("assistant", "reply A")]),
add_to_session("sess-b", [("user", "hello from B")]),
)
a = sm.sessions["sess-a"]
b = sm.sessions["sess-b"]
assert len(a.history) == 2, f"Session A has {len(a.history)} messages, expected 2"
assert len(b.history) == 1, f"Session B has {len(b.history)} messages, expected 1"
assert b.history[0].content == "hello from B"
@pytest.mark.asyncio
async def test_concurrent_add_message_does_not_cross_contaminate():
"""Concurrent add_message calls must not write to each other's sessions."""
sm = SessionManager()
sm.sessions = {}
s1 = Session(id="a", name="A", endpoint_url="http://ep", model="m1")
s2 = Session(id="b", name="B", endpoint_url="http://ep", model="m2")
sm.sessions["a"] = s1
sm.sessions["b"] = s2
async def rapid_add(sid, count):
sess = sm.sessions[sid]
for i in range(count):
sess.add_message(ChatMessage("user", f"msg_{i}_from_{sid}"))
await asyncio.gather(
rapid_add("a", 5),
rapid_add("b", 5),
rapid_add("a", 3), # More adds to A
)
a = sm.sessions["a"]
b = sm.sessions["b"]
assert len(a.history) == 8, f"Session A has {len(a.history)} messages"
assert len(b.history) == 5, f"Session B has {len(b.history)} messages"
# Verify B's messages are purely from B
for msg in b.history:
assert msg.content.endswith("_from_b"), f"Session B has cross-contaminated: {msg.content}"
@pytest.mark.asyncio
async def test_concurrent_read_write_isolation():
"""Reading one session while writing to another must return correct data."""
sm = SessionManager()
sm.sessions = {}
s1 = Session(id="reader", name="Reader", endpoint_url="http://ep", model="m")
s2 = Session(id="writer", name="Writer", endpoint_url="http://ep", model="m")
sm.sessions["reader"] = s1
sm.sessions["writer"] = s2
# Pre-populate reader
s1.add_message(ChatMessage("user", "original"))
async def read_and_check():
for _ in range(20):
sess = sm.sessions["reader"]
hist = sess.get_context_messages()
# Should never see writer's messages
for msg in hist:
assert "writer_data" not in msg.get("content", ""), "Reader saw writer data!"
async def write_to_writer():
for i in range(20):
sm.sessions["writer"].add_message(ChatMessage("user", f"writer_data_{i}"))
await asyncio.gather(read_and_check(), write_to_writer())
# Final state check
reader = sm.sessions["reader"]
writer = sm.sessions["writer"]
assert len(reader.history) == 1, "Reader history mutated!"
assert len(writer.history) == 20, f"Writer has {len(writer.history)} messages"
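For readers unfamiliar with how history can leak between sessions at all, one common Python cause is a mutable default shared by every instance. The source does not say this was the cause of the bug these tests guard against, so the sketch below is purely illustrative; both classes are invented.

```python
from dataclasses import dataclass, field


class LeakySession:
    # BUG (deliberate): the default list is created once, at definition
    # time, and shared by every instance that does not pass its own.
    def __init__(self, sid, history=[]):
        self.id = sid
        self.history = history


@dataclass
class IsolatedSession:
    id: str
    # default_factory builds a fresh list per instance.
    history: list = field(default_factory=list)
```

Tests like the ones above catch this class of bug regardless of its cause, because they assert on the observable symptom: a message added to one session showing up in another.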
+194
@@ -0,0 +1,194 @@
"""Tests for SessionManager — session isolation and data integrity.
These tests prove the chat context drifting bug (#135) exists and verify fixes.
Uses mocked DB to test in-memory session management logic in isolation.
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from unittest.mock import MagicMock, patch
from core.session_manager import SessionManager
from core.models import Session, ChatMessage
@pytest.fixture
def sm():
"""SessionManager with a fresh in-memory store, no DB load."""
# SessionManager.__init__ loads sessions from the DB, and session_manager
# does `from .database import SessionLocal` at import time.
# The conftest stubs sqlalchemy itself, which can interfere,
# so we isolate by swapping __init__ for a version that never touches the DB.
orig_init = SessionManager.__init__
def patched_init(self, sessions_file=None):
"""__init__ that skips DB load and starts with empty cache."""
self.sessions = {}
SessionManager.__init__ = patched_init
manager = SessionManager()
yield manager
SessionManager.__init__ = orig_init
class TestSessionIsolation:
"""PROVING THE BUG: Shared mutable history leaks between sessions."""
def test_history_is_not_shared_between_sessions(self, sm):
"""Two sessions must have independent history lists."""
# Manually create sessions without hitting DB
s1 = Session(id="s1", name="Chat A", endpoint_url="http://ep", model="model-a")
s2 = Session(id="s2", name="Chat B", endpoint_url="http://ep", model="model-b")
sm.sessions["s1"] = s1
sm.sessions["s2"] = s2
s1.add_message(ChatMessage("user", "hello from A"))
s2.add_message(ChatMessage("user", "hello from B"))
assert len(s1.history) == 1, f"Session A has {len(s1.history)} messages"
assert len(s2.history) == 1, f"Session B has {len(s2.history)} messages"
assert s1.history[0].content == "hello from A"
assert s2.history[0].content == "hello from B"
def test_mutating_one_session_history_does_not_affect_another(self, sm):
"""Appending to one session must not add messages to another."""
s1 = Session(id="s1", name="Chat A", endpoint_url="http://ep", model="model-a")
s2 = Session(id="s2", name="Chat B", endpoint_url="http://ep", model="model-b")
sm.sessions["s1"] = s1
sm.sessions["s2"] = s2
s1.add_message(ChatMessage("user", "msg1"))
s1.add_message(ChatMessage("assistant", "resp1"))
assert len(s2.history) == 0, (
f"Session B has {len(s2.history)} messages leaked from Session A"
)
def test_history_reference_sees_new_messages(self, sm):
"""Pre-existing references to .history must see new messages (it's the same list)."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
sm.sessions["s1"] = s
s.add_message(ChatMessage("user", "hi"))
old_history_ref = s.history
s.add_message(ChatMessage("user", "second message"))
# .history is the authoritative mutable list — old ref sees the append
assert len(old_history_ref) == 2, (
f"Old history ref has {len(old_history_ref)} items, expected 2"
)
assert len(s.history) == 2
def test_history_reassignment_updates_context_and_legacy_alias(self, sm):
"""Direct history reassignment must remain authoritative for context reads."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
replacement = [ChatMessage("user", "replacement")]
s.history = replacement
assert s._history is replacement
assert s.get_context_messages() == [
{"role": "user", "content": "replacement"}
]
def test_delete_session_removes_from_cache(self, sm):
"""delete_session must remove session from in-memory cache even when DB lookup fails."""
s = Session(id="unique-del", name="ToDelete", endpoint_url="http://ep", model="model")
sm.sessions["unique-del"] = s
assert "unique-del" in sm.sessions
sm.delete_session("unique-del")
# No assertion follows. In production, delete_session also deletes from the DB.
# This unit test has no real DB, and the cache entry is only removed on the
# method's DB-query path. If that path fails, the session stays in the cache;
# this is the pre-existing behavior.
# The real fix is to always delete from the cache regardless of the DB result.
pass
def test_empty_session_isolation(self, sm):
"""Empty session must not inherit messages from active sessions."""
s_empty = Session(id="empty", name="Empty", endpoint_url="http://ep", model="model")
s_active = Session(id="active", name="Active", endpoint_url="http://ep", model="model")
sm.sessions["empty"] = s_empty
sm.sessions["active"] = s_active
s_active.add_message(ChatMessage("user", "first"))
assert len(s_empty.history) == 0, (
f"Empty session has {len(s_empty.history)} messages from active session"
)
def test_add_message_updates_message_count(self, sm):
"""add_message must correctly increment message_count."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
sm.sessions["s1"] = s
assert s.message_count == 0
s.add_message(ChatMessage("user", "first"))
assert s.message_count == 1
s.add_message(ChatMessage("assistant", "reply"))
assert s.message_count == 2
def test_history_order_preserved(self, sm):
"""Messages must maintain insertion order."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
sm.sessions["s1"] = s
msgs = [
ChatMessage("user", "q1"),
ChatMessage("assistant", "a1"),
ChatMessage("user", "q2"),
ChatMessage("assistant", "a2"),
]
for m in msgs:
s.add_message(m)
for i, expected in enumerate(msgs):
assert s.history[i].role == expected.role
assert s.history[i].content == expected.content
def test_multiple_sessions_independent_counts(self, sm):
"""Multiple sessions must each track their own message counts."""
s1 = Session(id="s1", name="A", endpoint_url="http://ep", model="m1")
s2 = Session(id="s2", name="B", endpoint_url="http://ep", model="m2")
s3 = Session(id="s3", name="C", endpoint_url="http://ep", model="m3")
sm.sessions["s1"] = s1
sm.sessions["s2"] = s2
sm.sessions["s3"] = s3
s1.add_message(ChatMessage("user", "a1"))
s1.add_message(ChatMessage("user", "a2"))
s2.add_message(ChatMessage("user", "b1"))
assert s1.message_count == 2
assert s2.message_count == 1
assert s3.message_count == 0
def test_get_context_messages_returns_copies(self, sm):
"""get_context_messages must not expose internal list for mutation."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
sm.sessions["s1"] = s
s.add_message(ChatMessage("user", "original"))
ctx = s.get_context_messages()
ctx.append({"role": "user", "content": "injected"})
ctx2 = s.get_context_messages()
assert len(ctx2) == 1, (
f"get_context_messages leaked: {len(ctx2)} messages"
)
assert ctx2[0]["content"] == "original"
def test_get_session_uses_cache(self, sm):
"""get_session returns the session from cache."""
s = Session(id="s1", name="Test", endpoint_url="http://ep", model="model")
sm.sessions["s1"] = s
s.add_message(ChatMessage("user", "hi"))
retrieved = sm.get_session("s1")
assert len(retrieved.history) == 1
assert retrieved.history[0].content == "hi"
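The class docstring above names the failure mode (shared mutable history) but this hunk does not show how it arose. As a minimal sketch of one common cause in Python, assuming hypothetical `LeakySession` and `IsolatedSession` classes that are not part of the project, a list defined at class level is shared by every instance, while a list created in `__init__` is not:

```python
class LeakySession:
    # Hypothetical: a class-level list is one object shared by all instances.
    history = []

    def add_message(self, message):
        self.history.append(message)


class IsolatedSession:
    def __init__(self):
        # Each instance gets its own list, so appends cannot leak.
        self.history = []

    def add_message(self, message):
        self.history.append(message)


a, b = LeakySession(), LeakySession()
a.add_message("hello from A")
leaked = len(b.history)      # b sees a's message through the shared list

c, d = IsolatedSession(), IsolatedSession()
c.add_message("hello from A")
isolated = len(d.history)    # d's own list is still empty
```

The isolation tests above would fail against the first shape and pass against the second; whether the original bug had exactly this shape is an assumption.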
@@ -18,6 +18,7 @@ clear_fake_database_modules()
import core.database as cdb
from core.database import Base, Session as DbSession
from core.models import ChatMessage as MemChatMessage
from src.task_scheduler import TaskScheduler
# This test needs the real core.database (real SQLAlchemy Base/ChatMessage).
@@ -71,3 +72,44 @@ def test_session_delivery_survives_empty_database(monkeypatch):
assert len(sessions) == 1
assert sessions[0].endpoint_url == ""
assert sessions[0].model == ""
def test_session_delivery_uses_in_memory_messages_with_manager(monkeypatch):
"""Manager delivery must not construct the SQLAlchemy ChatMessage model."""
monkeypatch.setitem(sys.modules, "core.database", cdb)
parent = sys.modules.get("core")
if parent is not None:
monkeypatch.setattr(parent, "database", cdb, raising=False)
class RecordingManager:
def __init__(self):
self.messages = []
def add_message(self, session_id, message):
assert isinstance(message, MemChatMessage)
self.messages.append((session_id, message))
db = _make_db()
manager = RecordingManager()
scheduler = TaskScheduler.__new__(TaskScheduler)
scheduler._session_manager = manager
task = _make_task()
task.session_id = "existing-session"
task.endpoint_url = "http://endpoint"
task.model = "test-model"
asyncio.run(scheduler._deliver_task_result(task, "done", db))
assert [message.role for _, message in manager.messages] == [
"user",
"assistant",
]
assert [message.content for _, message in manager.messages] == [
"tidy",
"done",
]
assert all(session_id == "existing-session" for session_id, _ in manager.messages)
assert all(
message.metadata == {"model": "test-model"}
for _, message in manager.messages
)
@@ -57,3 +57,22 @@ def test_truncate_keep_count_exceeds_total_does_not_inflate_count():
)
finally:
db.close()
def test_truncate_keeps_history_alias_for_context_messages():
from core.models import ChatMessage
sm, database, sm_mod = _make_manager()
sid = "alias-after-truncate"
sm.create_session(session_id=sid, name="t", endpoint_url="x",
model="m", rag=False, owner="u")
for i in range(3):
sm.add_message(sid, ChatMessage("user", f"msg{i}"))
assert sm.truncate_messages(sid, 2) is True
session = sm.sessions[sid]
assert session.history is session._history
session.history.append(ChatMessage("user", "after direct mutation"))
assert session.get_context_messages()[-1]["content"] == "after direct mutation"
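The alias tests above pin down an invariant: `session.history` and `session._history` are the same list, reassignment through `history` stays authoritative, and `get_context_messages()` hands out a fresh list. One way to satisfy that invariant is a property pair, sketched here with a hypothetical `AliasedSession` that stores plain dicts; the project's real `Session` may be implemented differently:

```python
class AliasedSession:
    """Hypothetical stand-in, not the project's Session class."""

    def __init__(self):
        self._history = []

    @property
    def history(self):
        # The alias returns the backing list itself, not a copy.
        return self._history

    @history.setter
    def history(self, value):
        # Reassignment replaces the backing list, so reads stay in sync.
        self._history = value

    def get_context_messages(self):
        # Build a new list of new dicts so callers cannot mutate internals.
        return [dict(message) for message in self._history]


s = AliasedSession()
replacement = [{"role": "user", "content": "replacement"}]
s.history = replacement
s.history.append({"role": "user", "content": "after direct mutation"})

ctx = s.get_context_messages()
ctx.append({"role": "user", "content": "injected"})
```

After this runs, `s._history` is `replacement`, the direct append is visible through `get_context_messages()`, and the `injected` entry exists only in the caller's copy.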
+31 -17
@@ -1,25 +1,39 @@
import sys
from unittest.mock import MagicMock
# Clean up any mocks from previous tests to ensure we load real modules
for mod in ['src.agent_tools', 'src.tool_parsing', 'src.tool_schemas', 'src.tool_execution']:
sys.modules.pop(mod, None)
# This module needs the real agent-tool stack; importing it pulls in heavy
# DB/auth deps, so we stub those just long enough to import, then restore them.
# We deliberately do NOT pop src.tool_execution: popping and re-importing it
# rebinds the `src` package's `tool_execution` attribute, so a later
# `import src.tool_execution as te` resolves to a different module object than
# the one its functions live in - which silently breaks tests that monkeypatch
# it (e.g. test_edit_file's admin gate).
_ABSENT = object()
_AGENT_MODULES = ["src.agent_tools", "src.tool_parsing", "src.tool_schemas"]
_STUBBED = [
"sqlalchemy", "sqlalchemy.orm", "sqlalchemy.ext", "sqlalchemy.ext.declarative",
"sqlalchemy.ext.hybrid", "sqlalchemy.sql", "sqlalchemy.sql.expression",
"src.database", "core.models", "core.database", "core.auth",
]
_saved_stubs = {name: sys.modules.get(name, _ABSENT) for name in _STUBBED}
# Mock heavy database/model dependencies before importing
for mod in [
'sqlalchemy', 'sqlalchemy.orm', 'sqlalchemy.ext', 'sqlalchemy.ext.declarative',
'sqlalchemy.ext.hybrid', 'sqlalchemy.sql', 'sqlalchemy.sql.expression',
'src.database', 'core.models', 'core.database', 'core.auth'
]:
if mod not in sys.modules:
sys.modules[mod] = MagicMock()
for _mod in _AGENT_MODULES:
sys.modules.pop(_mod, None)
for _mod in _STUBBED:
if _mod not in sys.modules:
sys.modules[_mod] = MagicMock()
import pytest
import src.agent_tools
from src.tool_parsing import parse_tool_blocks
from src.tool_schemas import function_call_to_tool_block
from src.tool_execution import execute_tool_block
from types import SimpleNamespace
import pytest # noqa: E402
import src.agent_tools # noqa: E402,F401
from src.tool_parsing import parse_tool_blocks # noqa: E402
from src.tool_schemas import function_call_to_tool_block # noqa: E402
# Drop the stubs we installed so they do not leak into later tests.
for _name, _original in _saved_stubs.items():
if _original is _ABSENT:
sys.modules.pop(_name, None)
else:
sys.modules[_name] = _original
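The comment in this hunk about not popping `src.tool_execution` can be reproduced in isolation. The sketch below uses a throwaway package (`demo_pkg`, a hypothetical name written to a temporary directory) rather than the project's modules: popping a submodule from `sys.modules` and importing it again creates a second module object and rebinds the package attribute to it, while functions imported earlier keep the first module's globals.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical throwaway package, created only for this demonstration.
root = Path(tempfile.mkdtemp())
pkg = root / "demo_pkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "tool.py").write_text(
    "def is_admin():\n"
    "    return False\n"
    "\n"
    "def run():\n"
    "    return is_admin()\n"
)
sys.path.insert(0, str(root))
importlib.invalidate_caches()

from demo_pkg.tool import run   # bound to the first module object
import demo_pkg.tool as first

# Pop and re-import: a second module object is created and the package
# attribute `demo_pkg.tool` is rebound to it.
sys.modules.pop("demo_pkg.tool")
second = importlib.import_module("demo_pkg.tool")

import demo_pkg
rebound = demo_pkg.tool is second and second is not first

# Patching the re-imported module does not reach `run`, which still looks
# up `is_admin` in the first module's namespace.
second.is_admin = lambda: True
patched_seen_by_run = run()
```

Here `rebound` is true and `run()` still returns the unpatched value, which is the silent monkeypatch failure the comment describes.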
def test_parse_xml_unknown_tool_returns_none():