mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
Merge remote-tracking branch 'origin/dev'
This commit is contained in:
+25
-2
@@ -340,6 +340,14 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
ok = auth_manager.rename_user(old_username, new_username, user)
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot rename user")
|
||||
# The owner-rename loop above updated ApiToken.owner in the DB, but the
|
||||
# bearer-token cache still maps each token to the OLD owner. Without
|
||||
# refreshing it, the renamed user's API tokens resolve to the old (now
|
||||
# non-existent) owner and stop reaching their data until the cache next
|
||||
# goes dirty. Invalidate it now, like the token CRUD routes do.
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if callable(invalidator):
|
||||
invalidator()
|
||||
return {"ok": True, "username": new_username, "renamed_self": old_username == user}
|
||||
|
||||
@router.post("/signup-toggle", deprecated=True)
|
||||
@@ -430,9 +438,24 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
current = _load_settings()
|
||||
# Per-key validation for numeric settings: coerce to int and clamp to a
|
||||
# sane range so a bad value can't disable the agent or let it run away.
|
||||
_INT_RANGES = {
|
||||
"agent_max_rounds": (1, 200),
|
||||
"agent_max_tool_calls": (0, 1000), # 0 = unlimited
|
||||
}
|
||||
for key in DEFAULT_SETTINGS:
|
||||
if key in body:
|
||||
current[key] = body[key]
|
||||
if key not in body:
|
||||
continue
|
||||
val = body[key]
|
||||
if key in _INT_RANGES:
|
||||
lo, hi = _INT_RANGES[key]
|
||||
try:
|
||||
val = int(val)
|
||||
except (TypeError, ValueError):
|
||||
raise HTTPException(400, f"{key} must be an integer")
|
||||
val = max(lo, min(val, hi))
|
||||
current[key] = val
|
||||
_save_settings(current)
|
||||
return current
|
||||
|
||||
|
||||
+15
-3
@@ -589,6 +589,8 @@ def _normalize_thinking(text: str) -> str:
|
||||
import re
|
||||
if not text:
|
||||
return text
|
||||
from src.text_helpers import normalize_thinking_markup
|
||||
text = normalize_thinking_markup(text)
|
||||
reasoning_prefix_re = re.compile(
|
||||
r'^\s*(?:thinking(?:\s+process)?\s*:|the user |i need |i should |i will |they are |the question |i can )',
|
||||
re.IGNORECASE,
|
||||
@@ -699,6 +701,10 @@ def _extract_thinking_meta(text: str) -> dict | None:
|
||||
import re
|
||||
if not text:
|
||||
return None
|
||||
from src.text_helpers import normalize_thinking_markup
|
||||
original_text = text
|
||||
text = normalize_thinking_markup(text)
|
||||
normalized_changed = text != original_text
|
||||
|
||||
# Check for <think> tags (native or injected)
|
||||
time_match = re.search(r'<think(?:ing)?\s+time="([\d.]+)"', text)
|
||||
@@ -729,6 +735,9 @@ def _extract_thinking_meta(text: str) -> dict | None:
|
||||
if thinking and reply:
|
||||
return {"thinking": thinking, "reply": reply, "time": think_time}
|
||||
|
||||
if normalized_changed and text.strip() and text.strip() != original_text.strip():
|
||||
return {"thinking": "", "reply": text.strip(), "time": think_time}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -737,7 +746,8 @@ def clean_thinking_for_save(content: str, metadata: dict | None = None) -> tuple
|
||||
md = dict(metadata) if metadata else {}
|
||||
info = _extract_thinking_meta(content)
|
||||
if info:
|
||||
md["thinking"] = info["thinking"]
|
||||
if info.get("thinking"):
|
||||
md["thinking"] = info["thinking"]
|
||||
if info.get("time"):
|
||||
md["thinking_time"] = info["time"]
|
||||
return info["reply"], md
|
||||
@@ -781,8 +791,10 @@ def save_assistant_response(
|
||||
# Extract thinking into metadata (don't pollute message content with <think> tags)
|
||||
_think_info = _extract_thinking_meta(full_response)
|
||||
if _think_info:
|
||||
md["thinking"] = _think_info["thinking"]
|
||||
md["thinking_time"] = _think_info.get("time")
|
||||
if _think_info.get("thinking"):
|
||||
md["thinking"] = _think_info["thinking"]
|
||||
if _think_info.get("time"):
|
||||
md["thinking_time"] = _think_info.get("time")
|
||||
_content = _think_info["reply"]
|
||||
else:
|
||||
_content = full_response
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
@@ -394,6 +395,12 @@ def setup_chat_routes(
|
||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||
# it's a real directory; ignore (no confinement) otherwise.
|
||||
workspace = (form_data.get("workspace") or "").strip()
|
||||
if workspace:
|
||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
||||
# below). Skill extraction should only learn from real agent sessions,
|
||||
# not chats we quietly promoted for a notes/calendar intent.
|
||||
@@ -981,7 +988,15 @@ def setup_chat_routes(
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
try:
|
||||
from src.settings import get_setting
|
||||
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
||||
_tool_budget = int(get_setting("agent_max_tool_calls", 0))
|
||||
# Per-message round cap from settings; clamp defensively in
|
||||
# case settings.json was hand-edited to a bad value.
|
||||
try:
|
||||
_max_rounds = int(get_setting("agent_max_rounds", _DEFAULT_ROUNDS) or _DEFAULT_ROUNDS)
|
||||
except (TypeError, ValueError):
|
||||
_max_rounds = _DEFAULT_ROUNDS
|
||||
_max_rounds = max(1, min(_max_rounds, 200))
|
||||
|
||||
async for chunk in stream_agent_loop(
|
||||
sess.endpoint_url,
|
||||
@@ -992,12 +1007,14 @@ def setup_chat_routes(
|
||||
max_tokens=ctx.preset.max_tokens,
|
||||
prompt_type=preset_id,
|
||||
max_tool_calls=_tool_budget,
|
||||
max_rounds=_max_rounds,
|
||||
context_length=ctx.context_length,
|
||||
active_document=active_doc,
|
||||
session_id=session,
|
||||
disabled_tools=disabled_tools if disabled_tools else None,
|
||||
owner=_user,
|
||||
fallbacks=_fallback_candidates,
|
||||
workspace=workspace or None,
|
||||
):
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
try:
|
||||
@@ -1017,6 +1034,7 @@ def setup_chat_routes(
|
||||
"tool_start", "tool_output", "agent_step",
|
||||
"doc_stream_open", "doc_stream_delta",
|
||||
"doc_update", "doc_suggestions", "ui_control",
|
||||
"rounds_exhausted",
|
||||
):
|
||||
if data.get("type") == "agent_step":
|
||||
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
# routes/copilot_routes.py
|
||||
"""GitHub Copilot device-flow login.
|
||||
|
||||
Drives the GitHub OAuth *device flow* and, on success, creates (or refreshes)
|
||||
an owner-scoped ``ModelEndpoint`` pointing at the Copilot API with the
|
||||
device-flow access token stored as its (encrypted) ``api_key``. After that the
|
||||
endpoint behaves like any other OpenAI-compatible provider — the Copilot-
|
||||
specific request headers are injected centrally by ``build_headers`` /
|
||||
``_provider_headers`` (see :mod:`src.copilot`).
|
||||
|
||||
Flow:
|
||||
1. ``POST /api/copilot/device/start`` → returns a ``poll_id`` plus the
|
||||
``user_code`` + ``verification_uri`` to show the user. The secret
|
||||
``device_code`` is kept server-side, never sent to the browser.
|
||||
2. The browser polls ``POST /api/copilot/device/poll`` with ``poll_id``.
|
||||
While pending it returns ``{status: "pending"}``; once the user authorises
|
||||
it provisions the endpoint and returns ``{status: "authorized", ...}``.
|
||||
|
||||
All routes are admin-gated (endpoint/provider management is an admin action).
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Form, HTTPException
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import copilot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
||||
# browser. Entries expire with the GitHub device code.
|
||||
#
|
||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
||||
# on a worker that never saw the start, returning "Unknown or expired login
|
||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
||||
_PENDING: Dict[str, Dict] = {}
|
||||
_PENDING_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _prune_expired() -> None:
|
||||
now = time.time()
|
||||
with _PENDING_LOCK:
|
||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
||||
_PENDING.pop(k, None)
|
||||
|
||||
|
||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
"""Create or update the owner's Copilot endpoint with a fresh token."""
|
||||
try:
|
||||
models = copilot.fetch_models(base, token)
|
||||
except Exception as e:
|
||||
logger.warning(f"Copilot model fetch failed during provisioning: {e}")
|
||||
models = []
|
||||
model_ids = [m["id"] for m in models]
|
||||
# Copilot picker models support OpenAI-style tool calling; mark the endpoint
|
||||
# tool-capable so the agent loop sends native tool schemas.
|
||||
# Tool-capable if any picker model advertises tool_calls. When the model
|
||||
# fetch failed (empty list) default to True, since Copilot picker models
|
||||
# support OpenAI-style tool calling.
|
||||
supports_tools = bool(not models or any(m.get("tool_calls") for m in models))
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = (
|
||||
db.query(ModelEndpoint)
|
||||
.filter(ModelEndpoint.base_url == base)
|
||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == owner))
|
||||
.order_by(ModelEndpoint.owner.desc())
|
||||
.first()
|
||||
)
|
||||
if ep is None:
|
||||
ep = ModelEndpoint(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
name="GitHub Copilot",
|
||||
base_url=base,
|
||||
model_type="llm",
|
||||
owner=owner,
|
||||
)
|
||||
db.add(ep)
|
||||
ep.api_key = token
|
||||
ep.is_enabled = True
|
||||
ep.supports_tools = supports_tools
|
||||
if model_ids:
|
||||
ep.cached_models = json.dumps(model_ids)
|
||||
db.commit()
|
||||
result = {
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"models": model_ids,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Best-effort: refresh the model cache so the new endpoint shows up.
|
||||
try:
|
||||
from routes.model_routes import _invalidate_models_cache
|
||||
_invalidate_models_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def setup_copilot_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
||||
|
||||
@router.post("/device/start")
|
||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = (enterprise_url or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
interval = int(data.get("interval") or 5)
|
||||
expires_in = int(data.get("expires_in") or 900)
|
||||
poll_id = uuid.uuid4().hex
|
||||
with _PENDING_LOCK:
|
||||
_PENDING[poll_id] = {
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"interval": interval,
|
||||
"owner": get_current_user(request) or None,
|
||||
"expires_at": time.time() + expires_in,
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return {
|
||||
"poll_id": poll_id,
|
||||
"user_code": data.get("user_code"),
|
||||
"verification_uri": data.get("verification_uri"),
|
||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||
"interval": interval,
|
||||
"expires_in": expires_in,
|
||||
}
|
||||
|
||||
@router.post("/device/poll")
|
||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
with _PENDING_LOCK:
|
||||
pending = _PENDING.get(poll_id)
|
||||
if not pending:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
|
||||
# Enforce GitHub's polling interval server-side so a chatty client
|
||||
# can't trip slow_down.
|
||||
now = time.time()
|
||||
if now < pending.get("next_poll_at", 0):
|
||||
return {"status": "pending"}
|
||||
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
except Exception as e:
|
||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "authorized", "endpoint": result}
|
||||
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
||||
return {"status": "pending"}
|
||||
if err == "slow_down":
|
||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["interval"] = new_interval
|
||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
||||
return {"status": "pending"}
|
||||
if err in ("expired_token", "access_denied"):
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "failed", "error": err}
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return {"status": "pending", "detail": err or "unknown"}
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
@@ -153,7 +153,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
with a `pdf_source` marker so the viewer renders the pages without
|
||||
overlays.
|
||||
"""
|
||||
from src.constants import UPLOAD_DIR
|
||||
from src.pdf_forms import has_form_fields, extract_fields
|
||||
from src.pdf_form_doc import (
|
||||
save_field_sidecar,
|
||||
@@ -950,7 +949,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
any wrong values before triggering the actual download.
|
||||
"""
|
||||
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar
|
||||
from src.constants import UPLOAD_DIR
|
||||
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
@@ -1015,7 +1013,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
Frontend overlays HTML form controls at those positions.
|
||||
"""
|
||||
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar
|
||||
from src.constants import UPLOAD_DIR
|
||||
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
@@ -1083,7 +1080,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
frontend overlays HTML form inputs on top)."""
|
||||
from fastapi.responses import Response
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
from src.constants import UPLOAD_DIR
|
||||
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
@@ -1132,7 +1128,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
import json
|
||||
import fitz
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
from src.constants import UPLOAD_DIR
|
||||
from src.document_processor import _resolve_vl_model, _load_vl_settings
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
@@ -1275,7 +1270,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
from starlette.background import BackgroundTask
|
||||
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, parse_markdown_annotations
|
||||
from src.pdf_forms import fill_fields, stamp_annotations
|
||||
from src.constants import UPLOAD_DIR
|
||||
from core.database import Signature
|
||||
|
||||
# Track temp files for this request so they get unlinked AFTER
|
||||
@@ -1370,7 +1364,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
from starlette.background import BackgroundTask
|
||||
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar, parse_markdown_annotations
|
||||
from src.pdf_forms import fill_fields, stamp_signatures, stamp_annotations
|
||||
from src.constants import UPLOAD_DIR
|
||||
from core.database import Signature
|
||||
|
||||
_to_unlink: list[str] = []
|
||||
@@ -1512,7 +1505,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
load_field_sidecar, parse_markdown_annotations,
|
||||
)
|
||||
from src.pdf_forms import fill_fields, stamp_signatures, stamp_annotations
|
||||
from src.constants import UPLOAD_DIR
|
||||
from core.database import Signature
|
||||
# COMPOSE_UPLOADS_DIR lives in email_routes — re-derive here so we
|
||||
# don't import from a routes file (cycle-prone). Same env override
|
||||
|
||||
+68
-17
@@ -266,6 +266,48 @@ COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
|
||||
|
||||
|
||||
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||
"email_summaries",
|
||||
"email_ai_replies",
|
||||
"email_calendar_extractions",
|
||||
"email_urgency_alerts",
|
||||
}
|
||||
|
||||
|
||||
def _email_cache_owner_clause(owner: str = "") -> tuple[str, tuple[str, ...]]:
|
||||
owner = (owner or "").strip()
|
||||
if owner:
|
||||
return "owner = ?", (owner,)
|
||||
return "(owner = '' OR owner IS NULL)", ()
|
||||
|
||||
|
||||
def _ensure_owner_scoped_email_cache_table(conn, table: str, create_sql: str, columns: list[str]):
|
||||
"""Rebuild legacy Message-ID-only cache tables with owner in the PK."""
|
||||
conn.execute(create_sql)
|
||||
try:
|
||||
info = conn.execute(f"PRAGMA table_info({table})").fetchall()
|
||||
cols = [r[1] for r in info]
|
||||
pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
|
||||
if "owner" in cols and pk_cols == ["message_id", "owner"]:
|
||||
return
|
||||
|
||||
conn.execute(f"ALTER TABLE {table} RENAME TO {table}__old")
|
||||
conn.execute(create_sql)
|
||||
old_cols = [r[1] for r in conn.execute(f"PRAGMA table_info({table}__old)").fetchall()]
|
||||
copy_cols = [c for c in columns if c != "owner" and c in old_cols]
|
||||
source_owner = "COALESCE(owner, '')" if "owner" in old_cols else "''"
|
||||
target_cols = ["owner", *copy_cols]
|
||||
select_exprs = [source_owner, *copy_cols]
|
||||
conn.execute(
|
||||
f"INSERT OR IGNORE INTO {table} ({', '.join(target_cols)}) "
|
||||
f"SELECT {', '.join(select_exprs)} FROM {table}__old"
|
||||
)
|
||||
conn.execute(f"DROP TABLE {table}__old")
|
||||
except Exception as _mig_e:
|
||||
import logging as _lg
|
||||
_lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}")
|
||||
|
||||
|
||||
def attachment_extract_dir(folder: str, uid: str) -> Path:
|
||||
"""Containment-safe extraction directory for an attachment.
|
||||
|
||||
@@ -301,30 +343,35 @@ def _init_scheduled_db():
|
||||
owner TEXT DEFAULT ''
|
||||
)
|
||||
""")
|
||||
# Email summary cache (keyed by Message-ID)
|
||||
conn.execute("""
|
||||
# Email summary cache. SECURITY: Message-IDs are global, so AI-derived
|
||||
# cache rows must be owner-scoped just like email_tags.
|
||||
_ensure_owner_scoped_email_cache_table(conn, "email_summaries", """
|
||||
CREATE TABLE IF NOT EXISTS email_summaries (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
message_id TEXT,
|
||||
owner TEXT DEFAULT '',
|
||||
uid TEXT,
|
||||
folder TEXT,
|
||||
subject TEXT,
|
||||
sender TEXT,
|
||||
summary TEXT NOT NULL,
|
||||
model_used TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (message_id, owner)
|
||||
)
|
||||
""")
|
||||
""", ["message_id", "owner", "uid", "folder", "subject", "sender", "summary", "model_used", "created_at"])
|
||||
# Email AI reply cache (pre-generated draft replies)
|
||||
conn.execute("""
|
||||
_ensure_owner_scoped_email_cache_table(conn, "email_ai_replies", """
|
||||
CREATE TABLE IF NOT EXISTS email_ai_replies (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
message_id TEXT,
|
||||
owner TEXT DEFAULT '',
|
||||
uid TEXT,
|
||||
folder TEXT,
|
||||
reply TEXT NOT NULL,
|
||||
model_used TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (message_id, owner)
|
||||
)
|
||||
""")
|
||||
""", ["message_id", "owner", "uid", "folder", "reply", "model_used", "created_at"])
|
||||
# Email tags / spam classification cache. SECURITY: keyed by
|
||||
# (message_id, owner) because Message-IDs are GLOBAL (a newsletter goes
|
||||
# to many users with the same Message-ID). Without owner-scoping, a
|
||||
@@ -384,17 +431,20 @@ def _init_scheduled_db():
|
||||
# Best-effort — log via the module logger if available
|
||||
import logging as _lg
|
||||
_lg.getLogger(__name__).warning(f"email_tags owner-migration skipped: {_mig_e}")
|
||||
conn.execute("""
|
||||
_ensure_owner_scoped_email_cache_table(conn, "email_calendar_extractions", """
|
||||
CREATE TABLE IF NOT EXISTS email_calendar_extractions (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
message_id TEXT,
|
||||
owner TEXT DEFAULT '',
|
||||
uid TEXT,
|
||||
events_created INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (message_id, owner)
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
""", ["message_id", "owner", "uid", "events_created", "created_at"])
|
||||
_ensure_owner_scoped_email_cache_table(conn, "email_urgency_alerts", """
|
||||
CREATE TABLE IF NOT EXISTS email_urgency_alerts (
|
||||
message_id TEXT PRIMARY KEY,
|
||||
message_id TEXT,
|
||||
owner TEXT DEFAULT '',
|
||||
uid TEXT,
|
||||
folder TEXT,
|
||||
subject TEXT,
|
||||
@@ -402,9 +452,10 @@ def _init_scheduled_db():
|
||||
urgency TEXT,
|
||||
reason TEXT,
|
||||
alerted INTEGER DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
created_at TEXT NOT NULL,
|
||||
PRIMARY KEY (message_id, owner)
|
||||
)
|
||||
""")
|
||||
""", ["message_id", "owner", "uid", "folder", "subject", "sender", "urgency", "reason", "alerted", "created_at"])
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS email_event_seen (
|
||||
owner TEXT NOT NULL,
|
||||
|
||||
+29
-16
@@ -39,7 +39,7 @@ from routes.email_helpers import (
|
||||
_extract_attachment_text, _extract_text,
|
||||
_pre_retrieve_context,
|
||||
_attach_compose_uploads, _cleanup_compose_uploads, _q,
|
||||
SCHEDULED_DB, _EMAIL_REPLY_SYS_PROMPT_BASE,
|
||||
SCHEDULED_DB, _EMAIL_REPLY_SYS_PROMPT_BASE, _email_cache_owner_clause,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -243,8 +243,15 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
await _emit_progress(progress_cb, f"Found {len(uid_list)} recent email(s); checking cache…")
|
||||
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_sum_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_summaries").fetchall()}
|
||||
_reply_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_ai_replies").fetchall()}
|
||||
_cache_owner_clause, _cache_owner_params = _email_cache_owner_clause(account_owner)
|
||||
_sum_existing = {r[0] for r in _c.execute(
|
||||
f"SELECT message_id FROM email_summaries WHERE {_cache_owner_clause}",
|
||||
_cache_owner_params,
|
||||
).fetchall()}
|
||||
_reply_existing = {r[0] for r in _c.execute(
|
||||
f"SELECT message_id FROM email_ai_replies WHERE {_cache_owner_clause}",
|
||||
_cache_owner_params,
|
||||
).fetchall()}
|
||||
if auto_tag or auto_spam:
|
||||
if account_owner:
|
||||
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner=?", (account_owner,)).fetchall()}
|
||||
@@ -252,12 +259,18 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner='' OR owner IS NULL").fetchall()}
|
||||
else:
|
||||
_tag_existing = set()
|
||||
_cal_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_calendar_extractions").fetchall()} if auto_cal else set()
|
||||
_cal_existing = {r[0] for r in _c.execute(
|
||||
f"SELECT message_id FROM email_calendar_extractions WHERE {_cache_owner_clause}",
|
||||
_cache_owner_params,
|
||||
).fetchall()} if auto_cal else set()
|
||||
# Urgency is handled by the built-in `check_email_urgency` task. Keep
|
||||
# this legacy poller path disabled so users don't get two independent
|
||||
# urgent-email systems.
|
||||
auto_urgent = False
|
||||
_urgent_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_urgency_alerts").fetchall()} if auto_urgent else set()
|
||||
_urgent_existing = {r[0] for r in _c.execute(
|
||||
f"SELECT message_id FROM email_urgency_alerts WHERE {_cache_owner_clause}",
|
||||
_cache_owner_params,
|
||||
).fetchall()} if auto_urgent else set()
|
||||
_c.close()
|
||||
|
||||
# Hoist the self-address lookup OUT of the per-email loop — fetching
|
||||
@@ -415,9 +428,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute("""
|
||||
INSERT OR REPLACE INTO email_summaries
|
||||
(message_id, uid, folder, subject, sender, summary, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, uid.decode() if isinstance(uid, bytes) else str(uid), _folder, subject, sender, summary, model, datetime.utcnow().isoformat()))
|
||||
(message_id, owner, uid, folder, subject, sender, summary, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid), _folder, subject, sender, summary, model, datetime.utcnow().isoformat()))
|
||||
_c.commit()
|
||||
_c.close()
|
||||
_sum_existing.add(message_id)
|
||||
@@ -458,9 +471,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute("""
|
||||
INSERT OR REPLACE INTO email_ai_replies
|
||||
(message_id, uid, folder, reply, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, uid.decode() if isinstance(uid, bytes) else str(uid), _folder, reply, model, datetime.utcnow().isoformat()))
|
||||
(message_id, owner, uid, folder, reply, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid), _folder, reply, model, datetime.utcnow().isoformat()))
|
||||
_c.commit()
|
||||
_c.close()
|
||||
_reply_existing.add(message_id)
|
||||
@@ -675,8 +688,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
_cc = _sql3.connect(SCHEDULED_DB)
|
||||
_cc.execute(
|
||||
"INSERT OR REPLACE INTO email_calendar_extractions "
|
||||
"(message_id, uid, events_created, created_at) VALUES (?, ?, ?, ?)",
|
||||
(message_id, uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
"(message_id, owner, uid, events_created, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
(message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
_cal_run_count, datetime.utcnow().isoformat())
|
||||
)
|
||||
_cc.commit()
|
||||
@@ -733,9 +746,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
_uc = _sql3.connect(SCHEDULED_DB)
|
||||
_uc.execute(
|
||||
"INSERT OR REPLACE INTO email_urgency_alerts "
|
||||
"(message_id, uid, folder, subject, sender, urgency, reason, alerted, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(message_id, uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
"(message_id, owner, uid, folder, subject, sender, urgency, reason, alerted, created_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
_folder, subject, sender, urgency, reason,
|
||||
1 if urgency in ("critical", "high") else 0,
|
||||
datetime.utcnow().isoformat())
|
||||
|
||||
+19
-15
@@ -49,7 +49,7 @@ from routes.email_helpers import (
|
||||
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
||||
SendEmailRequest, ExtractStyleRequest,
|
||||
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
||||
attachment_extract_dir,
|
||||
attachment_extract_dir, _email_cache_owner_clause,
|
||||
)
|
||||
from routes.email_pollers import _start_poller
|
||||
|
||||
@@ -934,9 +934,11 @@ def setup_email_routes():
|
||||
import sqlite3 as _sql3
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
placeholders = ",".join("?" * len(ids))
|
||||
owner_clause, owner_params = _email_cache_owner_clause(owner)
|
||||
rows = _c.execute(
|
||||
f"SELECT message_id, summary FROM email_summaries WHERE message_id IN ({placeholders})",
|
||||
ids,
|
||||
f"SELECT message_id, summary FROM email_summaries "
|
||||
f"WHERE message_id IN ({placeholders}) AND {owner_clause}",
|
||||
(*ids, *owner_params),
|
||||
).fetchall()
|
||||
_c.close()
|
||||
by_id = {r[0]: r[1] for r in rows}
|
||||
@@ -1219,15 +1221,16 @@ def setup_email_routes():
|
||||
try:
|
||||
import sqlite3 as _sql3
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
owner_clause, owner_params = _email_cache_owner_clause(owner)
|
||||
_row = _c.execute(
|
||||
"SELECT summary FROM email_summaries WHERE message_id = ?",
|
||||
(message_id.strip(),),
|
||||
f"SELECT summary FROM email_summaries WHERE message_id = ? AND {owner_clause}",
|
||||
(message_id.strip(), *owner_params),
|
||||
).fetchone()
|
||||
if _row:
|
||||
cached_summary = _row[0]
|
||||
_row2 = _c.execute(
|
||||
"SELECT reply FROM email_ai_replies WHERE message_id = ?",
|
||||
(message_id.strip(),),
|
||||
f"SELECT reply FROM email_ai_replies WHERE message_id = ? AND {owner_clause}",
|
||||
(message_id.strip(), *owner_params),
|
||||
).fetchone()
|
||||
if _row2:
|
||||
cached_ai_reply = _apply_email_style_mechanics(_extract_reply(_row2[0] or ""))
|
||||
@@ -2549,10 +2552,10 @@ def setup_email_routes():
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute("""
|
||||
INSERT OR REPLACE INTO email_summaries
|
||||
(message_id, uid, folder, subject, sender, summary, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
(message_id, owner, uid, folder, subject, sender, summary, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
mid, data.get("uid", ""), data.get("folder", ""),
|
||||
mid, owner, data.get("uid", ""), data.get("folder", ""),
|
||||
subject, sender, content, model, datetime.utcnow().isoformat(),
|
||||
))
|
||||
_c.commit()
|
||||
@@ -2587,9 +2590,10 @@ def setup_email_routes():
|
||||
if message_id:
|
||||
try:
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
owner_clause, owner_params = _email_cache_owner_clause(owner)
|
||||
_row = _c.execute(
|
||||
"SELECT reply, model_used FROM email_ai_replies WHERE message_id = ?",
|
||||
(message_id,),
|
||||
f"SELECT reply, model_used FROM email_ai_replies WHERE message_id = ? AND {owner_clause}",
|
||||
(message_id, *owner_params),
|
||||
).fetchone()
|
||||
_c.close()
|
||||
if _row and _row[0]:
|
||||
@@ -2791,9 +2795,9 @@ def setup_email_routes():
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute("""
|
||||
INSERT OR REPLACE INTO email_ai_replies
|
||||
(message_id, uid, folder, reply, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, source_uid, source_folder, reply, model, datetime.utcnow().isoformat()))
|
||||
(message_id, owner, uid, folder, reply, model_used, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, owner, source_uid, source_folder, reply, model, datetime.utcnow().isoformat()))
|
||||
_c.commit()
|
||||
_c.close()
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,11 +10,36 @@ from fastapi import APIRouter, Request, HTTPException
|
||||
from core.models import ChatMessage
|
||||
from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession
|
||||
from src.topic_analyzer import analyze_topics
|
||||
from routes.session_routes import _verify_session_owner
|
||||
from routes.session_routes import (
|
||||
_message_role,
|
||||
_message_text,
|
||||
_reject_compact_during_active_run,
|
||||
_verify_session_owner,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _merge_continue_rows_to_delete(db_messages, db1, db2):
|
||||
"""DB rows to delete when merging the last two assistant messages.
|
||||
|
||||
Always the second assistant message (db2), plus ONLY the single
|
||||
intervening "continue" user message (the one carrying "previous response
|
||||
was interrupted") — matching the in-memory merge. The previous code
|
||||
deleted the whole index range between the two assistant rows, destroying
|
||||
any tool/system/user messages in between and desyncing the DB from the
|
||||
in-memory history.
|
||||
"""
|
||||
to_delete = [db2]
|
||||
i1 = next((i for i, m in enumerate(db_messages) if m is db1), None)
|
||||
i2 = next((i for i, m in enumerate(db_messages) if m is db2), None)
|
||||
if i1 is not None and i2 is not None and i2 - 1 > i1:
|
||||
between = db_messages[i2 - 1]
|
||||
if getattr(between, "role", "") == "user" and "previous response was interrupted" in (getattr(between, "content", "") or ""):
|
||||
to_delete.append(between)
|
||||
return to_delete
|
||||
|
||||
|
||||
def setup_history_routes(session_manager) -> APIRouter:
|
||||
router = APIRouter(tags=["history"])
|
||||
|
||||
@@ -418,11 +443,13 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
db1.content = merged_content
|
||||
db1.meta_data = _json.dumps(merged_meta)
|
||||
|
||||
# Remove the continue user message if between them
|
||||
db_idx2 = db_messages.index(db2)
|
||||
db_idx1 = db_messages.index(db1)
|
||||
for di in range(db_idx2, db_idx1, -1):
|
||||
db.delete(db_messages[di])
|
||||
# Mirror the in-memory deletion: remove the second assistant
|
||||
# message and ONLY the "continue" user message between them
|
||||
# (not arbitrary tool/system/user rows). The old
|
||||
# range-delete destroyed every row between the two assistant
|
||||
# messages, desyncing the DB from the in-memory history.
|
||||
for _row in _merge_continue_rows_to_delete(db_messages, db1, db2):
|
||||
db.delete(_row)
|
||||
|
||||
db.commit()
|
||||
finally:
|
||||
@@ -499,6 +526,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
_reject_compact_during_active_run(session_id)
|
||||
|
||||
try:
|
||||
from src.model_context import estimate_tokens, get_context_length
|
||||
@@ -521,8 +549,8 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
|
||||
# Build text to summarize
|
||||
convo_text = "\n".join(
|
||||
f"{(m.role if isinstance(m, ChatMessage) else m.get('role', '')).upper()}: "
|
||||
f"{(m.content if isinstance(m, ChatMessage) else m.get('content', ''))[:2000]}"
|
||||
f"{_message_role(m).upper()}: "
|
||||
f"{_message_text(m)[:2000]}"
|
||||
for m in older
|
||||
)
|
||||
|
||||
|
||||
+123
-17
@@ -5,6 +5,7 @@ import os
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import html
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse, HTMLResponse
|
||||
import logging
|
||||
@@ -12,6 +13,7 @@ import httpx
|
||||
|
||||
from core.database import McpServer, SessionLocal
|
||||
from core.middleware import require_admin
|
||||
from src.constants import DATA_DIR
|
||||
from src.mcp_manager import McpManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -19,6 +21,75 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
||||
|
||||
|
||||
def _mcp_oauth_base_dir() -> Path:
|
||||
"""Directory that may contain OAuth files managed by Odysseus."""
|
||||
return (Path(DATA_DIR) / "mcp_oauth").resolve(strict=False)
|
||||
|
||||
|
||||
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
||||
"""Resolve an MCP OAuth path and keep it under DATA_DIR/mcp_oauth."""
|
||||
raw = str(raw_path or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
base = _mcp_oauth_base_dir()
|
||||
path = Path(os.path.expanduser(raw))
|
||||
if not path.is_absolute():
|
||||
path = base / path
|
||||
resolved = path.resolve(strict=False)
|
||||
|
||||
try:
|
||||
resolved.relative_to(base)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"Invalid OAuth {field_name}: path must stay under {base}",
|
||||
) from exc
|
||||
return str(resolved)
|
||||
|
||||
|
||||
def _sanitize_mcp_oauth_config(oauth_cfg):
|
||||
"""Return an OAuth config copy with file paths confined to mcp_oauth."""
|
||||
if not oauth_cfg:
|
||||
return oauth_cfg
|
||||
if not isinstance(oauth_cfg, dict):
|
||||
return {}
|
||||
sanitized = dict(oauth_cfg)
|
||||
for field_name in ("keys_file", "token_file"):
|
||||
if sanitized.get(field_name):
|
||||
sanitized[field_name] = _resolve_mcp_oauth_path(
|
||||
sanitized[field_name],
|
||||
field_name,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _mcp_oauth_token_missing(oauth_cfg, *, strict: bool = True) -> bool:
|
||||
"""Check token existence without letting legacy bad paths break listing."""
|
||||
if not isinstance(oauth_cfg, dict):
|
||||
return False
|
||||
try:
|
||||
token_file = _resolve_mcp_oauth_path(oauth_cfg.get("token_file", ""), "token_file")
|
||||
except HTTPException:
|
||||
if strict:
|
||||
raise
|
||||
logger.warning("Ignoring MCP OAuth config with unsafe token_file")
|
||||
return True
|
||||
return bool(token_file and not os.path.exists(token_file))
|
||||
|
||||
|
||||
def _apply_mcp_oauth_env(env: dict, oauth_cfg) -> None:
|
||||
"""Pass sanitized Gmail package paths to MCP servers that honor them."""
|
||||
if not oauth_cfg or not isinstance(env, dict):
|
||||
return
|
||||
keys_file = oauth_cfg.get("keys_file")
|
||||
token_file = oauth_cfg.get("token_file")
|
||||
if keys_file:
|
||||
env["GMAIL_OAUTH_PATH"] = keys_file
|
||||
if token_file:
|
||||
env["GMAIL_CREDENTIALS_PATH"] = token_file
|
||||
|
||||
|
||||
def _load_disabled_map():
|
||||
"""Load per-server disabled tool sets from DB."""
|
||||
db = SessionLocal()
|
||||
@@ -53,8 +124,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
oauth_cfg = json.loads(srv.oauth_config) if srv.oauth_config else None
|
||||
needs_oauth = False
|
||||
if oauth_cfg:
|
||||
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
|
||||
needs_oauth = token_file and not os.path.exists(token_file)
|
||||
needs_oauth = _mcp_oauth_token_missing(oauth_cfg, strict=False)
|
||||
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
|
||||
total_tools = status.get("tool_count", 0)
|
||||
result.append({
|
||||
@@ -71,6 +141,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"disabled_tool_count": len(disabled_list),
|
||||
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
|
||||
"error": status.get("error"),
|
||||
"auth_url": status.get("auth_url"),
|
||||
"has_oauth": oauth_cfg is not None,
|
||||
"needs_oauth": needs_oauth,
|
||||
})
|
||||
@@ -101,6 +172,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
raise HTTPException(400, "command is required for stdio transport")
|
||||
if transport == "sse" and not url:
|
||||
raise HTTPException(400, "url is required for SSE transport")
|
||||
if transport == "http" and not url:
|
||||
raise HTTPException(400, "url is required for HTTP transport")
|
||||
|
||||
# Parse JSON fields
|
||||
try:
|
||||
@@ -111,26 +184,33 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
parsed_env = json.loads(env) if env else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_env = {}
|
||||
if not isinstance(parsed_env, dict):
|
||||
parsed_env = {}
|
||||
|
||||
# Parse OAuth config
|
||||
parsed_oauth_config = None
|
||||
if oauth_config:
|
||||
try:
|
||||
parsed_oauth_config = json.loads(oauth_config)
|
||||
parsed_oauth_config = _sanitize_mcp_oauth_config(json.loads(oauth_config))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
_apply_mcp_oauth_env(parsed_env, parsed_oauth_config)
|
||||
|
||||
# Write OAuth credentials file if provided (for Google MCP servers)
|
||||
logger.info(f"MCP add_server: oauth_file={oauth_file!r}")
|
||||
if oauth_file:
|
||||
try:
|
||||
oauth_data = json.loads(oauth_file)
|
||||
oauth_dir = os.path.expanduser(oauth_data.get("dir", ""))
|
||||
oauth_dir = _resolve_mcp_oauth_path(oauth_data.get("dir", ""), "dir")
|
||||
oauth_filename = oauth_data.get("filename", "")
|
||||
client_id = oauth_data.get("client_id", "")
|
||||
client_secret = oauth_data.get("client_secret", "")
|
||||
if oauth_dir and oauth_filename and client_id and client_secret:
|
||||
os.makedirs(oauth_dir, exist_ok=True)
|
||||
filepath = _resolve_mcp_oauth_path(
|
||||
Path(oauth_dir) / str(oauth_filename),
|
||||
"filename",
|
||||
)
|
||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||
creds = {
|
||||
"installed": {
|
||||
"client_id": client_id,
|
||||
@@ -140,7 +220,6 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"token_uri": "https://accounts.google.com/o/oauth2/token",
|
||||
}
|
||||
}
|
||||
filepath = os.path.join(oauth_dir, oauth_filename)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(creds, f, indent=2)
|
||||
logger.info(f"Wrote OAuth credentials to {filepath}")
|
||||
@@ -171,9 +250,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
# Check if OAuth token already exists — skip connection attempt if not
|
||||
needs_oauth = False
|
||||
if parsed_oauth_config:
|
||||
token_file = os.path.expanduser(parsed_oauth_config.get("token_file", ""))
|
||||
if token_file and not os.path.exists(token_file):
|
||||
needs_oauth = True
|
||||
needs_oauth = _mcp_oauth_token_missing(parsed_oauth_config)
|
||||
|
||||
connected = False
|
||||
if not needs_oauth:
|
||||
@@ -188,6 +265,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
)
|
||||
|
||||
status = mcp_manager.get_server_status(server_id)
|
||||
needs_auth = status.get("status") == "needs_auth"
|
||||
return {
|
||||
"id": server_id,
|
||||
"name": name,
|
||||
@@ -196,6 +274,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"tool_count": status.get("tool_count", 0),
|
||||
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
|
||||
"needs_oauth": needs_oauth,
|
||||
"needs_auth": needs_auth,
|
||||
"auth_url": status.get("auth_url"),
|
||||
}
|
||||
|
||||
@router.post("/servers/{server_id}/reconnect")
|
||||
@@ -228,6 +308,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"status": status.get("status", "disconnected"),
|
||||
"tool_count": status.get("tool_count", 0),
|
||||
"error": status.get("error"),
|
||||
"auth_url": status.get("auth_url"),
|
||||
"needs_auth": status.get("status") == "needs_auth",
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
@@ -349,8 +431,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
if not srv.oauth_config:
|
||||
raise HTTPException(400, "Server has no OAuth config")
|
||||
|
||||
oauth_cfg = json.loads(srv.oauth_config)
|
||||
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
|
||||
oauth_cfg = _sanitize_mcp_oauth_config(json.loads(srv.oauth_config))
|
||||
keys_file = oauth_cfg.get("keys_file", "")
|
||||
if not keys_file or not os.path.exists(keys_file):
|
||||
raise HTTPException(400, "OAuth keys file not found")
|
||||
|
||||
@@ -393,10 +475,18 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
|
||||
@router.get("/oauth/callback")
|
||||
async def oauth_callback(code: str, state: str, request: Request):
|
||||
"""Handle OAuth callback from Google — exchange code for tokens."""
|
||||
"""Handle OAuth callback. Generic MCP OAuth flows resolve via the
|
||||
pending-state registry; Google flows fall through to the legacy path."""
|
||||
require_admin(request)
|
||||
server_id = state
|
||||
return await _exchange_and_connect(server_id, code, request)
|
||||
from src.mcp_oauth import resolve_pending
|
||||
if resolve_pending(state, code):
|
||||
return HTMLResponse(_oauth_result_page(
|
||||
"Authorization Successful",
|
||||
"The MCP server is connecting. You can close this window and return to Odysseus.",
|
||||
success=True,
|
||||
))
|
||||
# Legacy Google path: state is the server_id
|
||||
return await _exchange_and_connect(state, code, request)
|
||||
|
||||
@router.post("/oauth/exchange/{server_id}")
|
||||
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
|
||||
@@ -411,6 +501,17 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
except Exception:
|
||||
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
|
||||
|
||||
# Generic MCP OAuth: if the pasted URL carries a state we are waiting on,
|
||||
# resolve it directly (the background connect finishes the handshake).
|
||||
state = params.get("state", [None])[0]
|
||||
from src.mcp_oauth import resolve_pending
|
||||
if state and resolve_pending(state, code):
|
||||
return HTMLResponse(_oauth_result_page(
|
||||
"Authorization Successful",
|
||||
"The MCP server is connecting. You can close this window and return to Odysseus.",
|
||||
success=True,
|
||||
))
|
||||
|
||||
return await _exchange_and_connect(server_id, code, request)
|
||||
|
||||
async def _exchange_and_connect(server_id: str, code: str, request: Request):
|
||||
@@ -423,9 +524,11 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
if not srv.oauth_config:
|
||||
return HTMLResponse(_oauth_result_page("Error", "No OAuth config."), status_code=400)
|
||||
|
||||
oauth_cfg = json.loads(srv.oauth_config)
|
||||
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
|
||||
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
|
||||
oauth_cfg = _sanitize_mcp_oauth_config(json.loads(srv.oauth_config))
|
||||
keys_file = oauth_cfg.get("keys_file", "")
|
||||
token_file = oauth_cfg.get("token_file", "")
|
||||
if not keys_file or not token_file:
|
||||
raise HTTPException(400, "OAuth keys/token file not configured")
|
||||
|
||||
with open(keys_file, encoding="utf-8") as f:
|
||||
keys_data = json.load(f)
|
||||
@@ -488,6 +591,9 @@ def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"Authorized but Connection Failed",
|
||||
f"Tokens saved, but the server failed to connect: {status.get('error', 'unknown error')}. Try reconnecting from Settings.",
|
||||
))
|
||||
except HTTPException as e:
|
||||
logger.warning(f"OAuth callback rejected: {e.detail}")
|
||||
return HTMLResponse(_oauth_result_page("Error", str(e.detail)), status_code=e.status_code)
|
||||
except Exception as e:
|
||||
logger.exception(f"OAuth callback error: {e}")
|
||||
return HTMLResponse(_oauth_result_page("Error", str(e)), status_code=500)
|
||||
|
||||
+16
-7
@@ -1029,12 +1029,13 @@ def setup_model_routes(model_discovery):
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
# Use cached models — background refresh keeps them updated
|
||||
model_ids = _cached_model_ids(ep)
|
||||
# Merge cached + pinned models, then filter out hidden ones
|
||||
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
||||
# Filter out hidden (probe-failed) models
|
||||
hidden = _hidden_model_ids(ep)
|
||||
model_ids = [m for m in model_ids if m not in hidden]
|
||||
model_ids = _visible_models(
|
||||
_cached_model_ids(ep),
|
||||
ep.hidden_models,
|
||||
getattr(ep, "pinned_models", None),
|
||||
)
|
||||
# Build correct URL based on provider
|
||||
chat_url = build_chat_url(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
@@ -1043,6 +1044,13 @@ def setup_model_routes(model_discovery):
|
||||
if model_ids:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
curated, extra = _curate_models(model_ids, curated_key)
|
||||
# Pinned models are admin-selected — they always belong in the
|
||||
# primary curated list, not buried in extras.
|
||||
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
|
||||
for m in pinned:
|
||||
if m not in curated:
|
||||
curated.append(m)
|
||||
extra = [m for m in extra if m not in pinned]
|
||||
items.append({
|
||||
"host": "custom",
|
||||
"port": 0,
|
||||
@@ -1891,9 +1899,10 @@ def setup_model_routes(model_discovery):
|
||||
if body:
|
||||
if "supports_tools" in body:
|
||||
v = body["supports_tools"]
|
||||
ep.supports_tools = bool(v) if v in (True, False, "true", "false", 1, 0) else None
|
||||
ep.supports_tools = {True: True, False: False, 'true': True, 'false': False, 1: True, 0: False}.get(v)
|
||||
if "is_enabled" in body:
|
||||
ep.is_enabled = bool(body["is_enabled"])
|
||||
v_ie = body['is_enabled']
|
||||
ep.is_enabled = v_ie.lower() in ('true', '1', 'yes') if isinstance(v_ie, str) else bool(v_ie)
|
||||
if "name" in body and isinstance(body["name"], str):
|
||||
ep.name = body["name"].strip() or ep.name
|
||||
if "model_type" in body and isinstance(body["model_type"], str):
|
||||
|
||||
@@ -57,6 +57,40 @@ def _content_to_text(content) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _message_role(message) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
return message.role or ""
|
||||
if isinstance(message, dict):
|
||||
return message.get("role", "") or ""
|
||||
return getattr(message, "role", "") or ""
|
||||
|
||||
|
||||
def _message_text(message) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
content = message.content
|
||||
elif isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
else:
|
||||
content = getattr(message, "content", None)
|
||||
return _content_to_text(content)
|
||||
|
||||
|
||||
def _message_metadata(message) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
metadata = message.metadata
|
||||
elif isinstance(message, dict):
|
||||
metadata = message.get("metadata")
|
||||
else:
|
||||
metadata = getattr(message, "metadata", None)
|
||||
return metadata if isinstance(metadata, dict) else {}
|
||||
|
||||
|
||||
def _reject_compact_during_active_run(session_id: str) -> None:
|
||||
from src import agent_runs
|
||||
if agent_runs.is_active(session_id):
|
||||
raise HTTPException(409, "Session has an active run; try compacting after it finishes")
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||
"""Verify the current user owns the session. Raises 404 if not.
|
||||
|
||||
@@ -872,6 +906,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
_reject_compact_during_active_run(session_id)
|
||||
|
||||
history = list(session.history or [])
|
||||
if len(history) < 6:
|
||||
@@ -897,7 +932,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
prior_compactions = sum(
|
||||
1 for m in history
|
||||
if (m.metadata or {}).get("compacted") or "[Conversation summary" in (m.content or "")
|
||||
if _message_metadata(m).get("compacted") or "[Conversation summary" in _message_text(m)
|
||||
)
|
||||
prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace(
|
||||
"{count}", str(len(older))
|
||||
@@ -905,7 +940,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
"{n}", str(prior_compactions + 1)
|
||||
)
|
||||
convo_text = "\n".join(
|
||||
f"{m.role.upper()}: {(m.content or '')[:2000]}"
|
||||
f"{_message_role(m).upper()}: {_message_text(m)[:2000]}"
|
||||
for m in older
|
||||
)
|
||||
try:
|
||||
|
||||
@@ -455,7 +455,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from routes.email_helpers import SCHEDULED_DB
|
||||
from routes.email_helpers import SCHEDULED_DB, OWNER_SCOPED_EMAIL_CACHE_TABLES, _email_cache_owner_clause
|
||||
|
||||
cleared = {}
|
||||
conn = sqlite3.connect(SCHEDULED_DB)
|
||||
@@ -468,6 +468,13 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
(user,),
|
||||
).fetchone()[0]
|
||||
conn.execute("DELETE FROM email_tags WHERE owner = ? OR owner = ''", (user,))
|
||||
elif table in OWNER_SCOPED_EMAIL_CACHE_TABLES and user:
|
||||
owner_clause, owner_params = _email_cache_owner_clause(user)
|
||||
before = conn.execute(
|
||||
f"SELECT COUNT(*) FROM {table} WHERE {owner_clause}",
|
||||
owner_params,
|
||||
).fetchone()[0]
|
||||
conn.execute(f"DELETE FROM {table} WHERE {owner_clause}", owner_params)
|
||||
else:
|
||||
before = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
|
||||
conn.execute(f"DELETE FROM {table}")
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Workspace API — browse server directories to pick a tool workspace folder."""
|
||||
import os
|
||||
from fastapi import APIRouter, Request, HTTPException, Query
|
||||
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.tool_security import owner_is_admin_or_single_user
|
||||
|
||||
|
||||
def setup_workspace_routes():
|
||||
router = APIRouter(prefix="/api/workspace", tags=["workspace"])
|
||||
|
||||
@router.get("/browse")
|
||||
def browse(request: Request, path: str = Query(default="")):
|
||||
"""List subdirectories of `path` (default: home) so the UI can navigate
|
||||
the server filesystem and pick a workspace folder. Directories only.
|
||||
|
||||
ADMIN-ONLY: this enumerates the server filesystem, so it is gated the
|
||||
same way the file/shell tools are (read_file/write_file/bash are in
|
||||
NON_ADMIN_BLOCKED_TOOLS). A non-admin who can't use those tools must not
|
||||
be able to map the host's directory tree either.
|
||||
"""
|
||||
owner = get_current_user(request)
|
||||
if not owner_is_admin_or_single_user(owner):
|
||||
raise HTTPException(status_code=403, detail="Workspace browsing is admin-only")
|
||||
|
||||
# Resolve symlinks so the reported path is canonical and the UI navigates
|
||||
# real directories (defends against symlink games in displayed paths).
|
||||
target = os.path.realpath(os.path.expanduser(path.strip() or "~"))
|
||||
if not os.path.isdir(target):
|
||||
target = os.path.realpath(os.path.expanduser("~"))
|
||||
|
||||
dirs = []
|
||||
try:
|
||||
with os.scandir(target) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
# Don't follow symlinks when classifying — a symlinked
|
||||
# dir is skipped rather than letting the browser wander
|
||||
# off via a link. Hidden entries are omitted.
|
||||
if entry.is_dir(follow_symlinks=False) and not entry.name.startswith("."):
|
||||
# Build the child path server-side with os.path.join
|
||||
# so it's correct on Windows (backslashes) and Linux.
|
||||
dirs.append({"name": entry.name, "path": os.path.join(target, entry.name)})
|
||||
except OSError:
|
||||
continue
|
||||
except (PermissionError, OSError):
|
||||
dirs = []
|
||||
|
||||
parent = os.path.dirname(target)
|
||||
return {
|
||||
"path": target,
|
||||
"parent": parent if parent and parent != target else None,
|
||||
"dirs": sorted(dirs, key=lambda d: d["name"].lower()),
|
||||
}
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user