Merge branch 'dev'

# Conflicts:
#	routes/task_routes.py
#	src/caldav_sync.py
This commit is contained in:
pewdiepie-archdaemon
2026-06-09 09:36:01 +09:00
351 changed files with 28143 additions and 3932 deletions
+5 -5
View File
@@ -31,7 +31,7 @@ from core.database import (
CalendarEvent,
CalendarCal,
)
from src.constants import DATA_DIR
from src.constants import DATA_DIR, SKILLS_DIR, SKILLS_FILE, GALLERY_DIR, GALLERY_UPLOADS_DIR
logger = logging.getLogger(__name__)
@@ -107,7 +107,7 @@ def setup_admin_wipe_routes(session_manager):
# Skills live as SKILL.md files under data/skills/. Drop
# the entire directory; the SkillsManager re-creates the
# tree on next write.
skills_dir = os.path.join(DATA_DIR, "skills")
skills_dir = SKILLS_DIR
count = 0
if os.path.isdir(skills_dir):
# Count SKILL.md files for the response — quick walk.
@@ -115,7 +115,7 @@ def setup_admin_wipe_routes(session_manager):
count += sum(1 for f in files if f == "SKILL.md")
_rmtree_quiet(skills_dir)
# Legacy fallback file
legacy = os.path.join(DATA_DIR, "skills.json")
legacy = SKILLS_FILE
if os.path.exists(legacy):
try:
os.remove(legacy)
@@ -151,8 +151,8 @@ def setup_admin_wipe_routes(session_manager):
db.query(GalleryAlbum).delete()
db.commit()
# Also drop the upload dir so disk doesn't keep orphans.
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
_rmtree_quiet(GALLERY_DIR)
_rmtree_quiet(GALLERY_UPLOADS_DIR)
return {"status": "deleted", "kind": kind, "count": count}
if kind == "calendar":
+12 -4
View File
@@ -155,22 +155,30 @@ def setup_api_token_routes() -> APIRouter:
payload = await request.json()
except Exception:
payload = {}
scope_list = _normalize_scopes(payload.get("scopes"))
scopes_value = ",".join(scope_list)
with get_db_session() as db:
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
if not token:
raise HTTPException(404, "Token not found")
if isinstance(payload.get("name"), str) and payload["name"].strip():
token.name = payload["name"].strip()[:MAX_NAME_LEN]
token.scopes = scopes_value
# Only touch scopes when the caller actually sent them. A partial
# update such as a rename ({"name": ...} with no "scopes" key) must
# not silently reset the token to the default scope — that dropped
# every previously granted scope.
if "scopes" in payload:
token.scopes = ",".join(_normalize_scopes(payload.get("scopes")))
db.add(token)
current_scopes = [
s.strip()
for s in (getattr(token, "scopes", "") or DEFAULT_SCOPES).split(",")
if s.strip()
]
response = {
"id": token_id,
"name": getattr(token, "name", ""),
"owner": getattr(token, "owner", None),
"token_prefix": getattr(token, "token_prefix", ""),
"scopes": scope_list,
"scopes": current_scopes,
}
_invalidate_cache(request)
return response
+23 -4
View File
@@ -131,10 +131,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
return {"ok": False, "requires_totp": True, "username": username}
if not auth_manager.totp_verify(username, body.totp_code):
raise HTTPException(401, "Invalid 2FA code")
# All checks passed — create session
token = await asyncio.to_thread(auth_manager.create_session, username, body.password)
if not token:
raise HTTPException(401, "Invalid credentials")
# All checks passed — create session (password already verified above)
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
cookie_kwargs = dict(
key=SESSION_COOKIE,
value=token,
@@ -585,6 +583,27 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
if preset == "discord_webhook":
import httpx
webhook_url = (integ.get("base_url") or "").strip()
if not webhook_url:
return {"ok": False, "message": "No webhook URL set — paste the full Discord webhook URL into the Base URL field."}
payload = {
"embeds": [{
"title": "Odysseus connectivity test",
"description": "If you see this, your Discord Webhook integration is wired up correctly.",
"color": 5793266,
}]
}
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.post(webhook_url, json=payload)
if r.is_success:
return {"ok": True, "message": "Test embed sent — check your Discord channel to confirm it arrived."}
return {"ok": False, "message": f"Discord returned HTTP {r.status_code}: {r.text[:200]}"}
except Exception as e:
return {"ok": False, "message": f"Request failed: {e}"[:400]}
# All other presets: GET against a known health endpoint.
# Fall back to detecting from name if preset is missing.
health_paths = {
+56 -12
View File
@@ -101,24 +101,68 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
# ── Skills ──
if "skills" in body and isinstance(body["skills"], list):
existing = skills_manager.load_all()
existing_ids = {s.get("id") for s in existing}
existing_titles = {s.get("title", "").strip().lower() for s in existing}
existing_names = {s.get("name") for s in existing if s.get("name")}
existing_ids = {s.get("id") for s in existing if s.get("id")}
existing_titles = {
(s.get("title") or s.get("description") or "").strip().lower()
for s in existing
}
added = 0
for skill in body["skills"]:
if not isinstance(skill, dict) or not skill.get("title"):
if not isinstance(skill, dict):
continue
# Skip if same id or same title already exists
if skill.get("id") in existing_ids:
title = (
skill.get("title") or skill.get("description")
or skill.get("name") or ""
).strip()
if not title:
continue
if skill["title"].strip().lower() in existing_titles:
sid = skill.get("id") or skill.get("name")
if sid and sid in existing_ids:
continue
if user and not skill.get("owner"):
skill["owner"] = user
existing.append(skill)
existing_ids.add(skill.get("id"))
existing_titles.add(skill["title"].strip().lower())
nm = skill.get("name")
if nm and nm in existing_names:
continue
if title.lower() in existing_titles:
continue
owner = skill.get("owner")
if user and not owner:
owner = user
# Skills live on disk as SKILL.md files; the old JSON-era
# skills_manager.save() no longer exists. Write each new skill
# via add_skill (source="user" skips auto-dedup — this is an
# explicit backup restore).
result = skills_manager.add_skill(
title=title,
name=skill.get("name"),
description=skill.get("description"),
problem=skill.get("problem", ""),
solution=skill.get("solution", ""),
steps=skill.get("steps"),
tags=skill.get("tags"),
source="user",
teacher_model=skill.get("teacher_model"),
confidence=skill.get("confidence", 0.8),
owner=owner,
category=skill.get("category", "general"),
when_to_use=skill.get("when_to_use"),
procedure=skill.get("procedure"),
pitfalls=skill.get("pitfalls"),
verification=skill.get("verification"),
platforms=skill.get("platforms"),
requires_toolsets=skill.get("requires_toolsets"),
fallback_for_toolsets=skill.get("fallback_for_toolsets"),
status=skill.get("status", "draft"),
version=skill.get("version", "1.0.0"),
)
if result.get("_deduped"):
continue
if result.get("name"):
existing_names.add(result["name"])
if result.get("id"):
existing_ids.add(result["id"])
existing_titles.add(title.lower())
added += 1
skills_manager.save(existing)
imported.append(f"{added} skills")
# ── Presets ──
+254 -69
View File
@@ -1,6 +1,7 @@
"""Calendar routes — local SQLite-backed calendar CRUD."""
import logging
import re
import uuid
from datetime import datetime, date, timedelta
from typing import Optional, List
@@ -12,7 +13,7 @@ from dateutil.rrule import rrulestr
from core.database import SessionLocal, CalendarCal, CalendarEvent
from src.auth_helpers import require_user
from src.upload_limits import read_upload_limited
from src.upload_limits import read_upload_limited, ICS_MAX_BYTES
logger = logging.getLogger(__name__)
@@ -100,6 +101,15 @@ def _ics_escape(text: str) -> str:
)
def _safe_ics_filename(name: str) -> str:
"""Return a conservative .ics filename safe for Content-Disposition."""
stem = name if isinstance(name, str) else ""
stem = re.sub(r"[^A-Za-z0-9._-]", "_", stem).strip("._-")
if not stem:
stem = "calendar"
return f"{stem[:128]}.ics"
def _resolve_base_uid(uid: str) -> str:
"""Extract the base series UID from a compound occurrence UID.
@@ -248,6 +258,17 @@ def parse_due_for_user(s: str) -> str:
if t is not None:
return base.replace(hour=t[0], minute=t[1]).isoformat()
# Time-first: "3pm today", "11pm today", "9am tomorrow"
m = _re.match(r'^(.+?)\s+(today|tonight|tomorrow|tmrw|yesterday)$', lower)
if m:
time_part, word = m.group(1).strip(), m.group(2)
base = today
if word in ("tomorrow", "tmrw"): base = today + _td(days=1)
elif word == "yesterday": base = today - _td(days=1)
t = _parse_time(time_part)
if t is not None:
return base.replace(hour=t[0], minute=t[1]).isoformat()
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
if m:
n = int(m.group(1)); unit = m.group(2)
@@ -399,7 +420,17 @@ def _parse_dt(s: str) -> datetime:
# Last resort: dateutil's fuzzy parser
try:
from dateutil import parser as _du
return _du.parse(s)
parsed = _du.parse(s)
# Strip tz like every other return path above — this function's
# contract is naive datetimes (CalendarEvent.dtstart is naive). An
# offset-bearing non-ISO input (e.g. RFC-2822 "Mon, 05 Jan 2026
# 14:00:00 +0900") otherwise leaked tz-aware into the naive column and
# crashed read-back comparisons in _expand_rrule with "can't compare
# offset-naive and offset-aware datetimes".
if parsed.tzinfo is not None:
from datetime import timezone as _tz
return parsed.astimezone(_tz.utc).replace(tzinfo=None)
return parsed
except Exception:
raise ValueError(f"could not parse datetime: {s!r}")
@@ -440,6 +471,9 @@ def _event_to_dict(ev: CalendarEvent) -> dict:
# ── Recurrence expansion ──
_RRULE_EXPANSION_LIMIT = 1000
def _expand_rrule(
ev: CalendarEvent, start: datetime, end: datetime
) -> List[dict]:
@@ -462,6 +496,7 @@ def _expand_rrule(
d = _event_to_dict(ev)
d["is_recurrence"] = False
d["series_uid"] = ev.uid
d["truncated"] = False
return [d]
# Parse the rrule, applying it to the base dtstart.
@@ -487,6 +522,7 @@ def _expand_rrule(
d = _event_to_dict(ev)
d["is_recurrence"] = False
d["series_uid"] = ev.uid
d["truncated"] = False
# Malformed RRULE rows are fetched by the recurring SQL branch
# with only dtstart < end_dt — the base event may not actually
# overlap the window. Only return if it does.
@@ -499,22 +535,26 @@ def _expand_rrule(
# (matching non-recurring overlap semantics: dtstart < end AND
# dtend > start).
expand_start = start - duration
occurrences = rule.between(expand_start, end, inc=True)
if not occurrences:
return []
results = []
truncated = False
base = _event_to_dict(ev)
for occ_start in occurrences:
for occ_start in rule.xafter(expand_start, inc=True):
if occ_start >= end:
break
occ_end = occ_start + duration
# Overlap filter: occurrence must intersect [start, end).
# This enforces exclusive-end semantics (occ_start >= end is
# excluded) and includes multi-day crossings (occ_end > start).
if occ_start >= end or occ_end <= start:
if occ_end <= start:
continue
if len(results) >= _RRULE_EXPANSION_LIMIT:
truncated = True
break
# Build the compound uid: {base_uid}::{date} or ::{datetime}
if ev.all_day:
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
@@ -525,6 +565,7 @@ def _expand_rrule(
d["uid"] = occ_uid
d["series_uid"] = ev.uid
d["is_recurrence"] = True
d["truncated"] = False
if ev.all_day:
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
@@ -537,6 +578,10 @@ def _expand_rrule(
results.append(d)
if truncated:
for d in results:
d["truncated"] = True
return results
@@ -545,72 +590,178 @@ def _expand_rrule(
def setup_calendar_routes() -> APIRouter:
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
# CalDAV connect form (Integrations → Calendar). Storage is local
# SQLite; sync (src/caldav_sync.py) pulls remote events into it on
# calendar open and periodically via the scheduler.
# ── CalDAV multi-account helpers ─────────────────────────────────────────
# Return the CalDAV accounts stored for `owner`, exactly as produced by
# src.caldav_sync._load_caldav_accounts.
# NOTE(review): the import is function-local — presumably to avoid an import
# cycle at module load; confirm before hoisting it to the top of the file.
def _get_caldav_accounts(owner: str) -> list:
from src.caldav_sync import _load_caldav_accounts
return _load_caldav_accounts(owner)
# Persist `accounts` as the owner's "caldav_accounts" preference, replacing
# whatever list was stored before.
def _save_caldav_accounts(owner: str, accounts: list) -> None:
from routes.prefs_routes import _load_for_user, _save_for_user
# Start from the owner's current prefs, or an empty dict when the loader
# returns nothing.
prefs = _load_for_user(owner) or {}
prefs["caldav_accounts"] = accounts
# Drop the single-account "caldav" entry so only the list form remains after
# a save (looks like a one-way migration away from that key — confirm).
prefs.pop("caldav", None)
_save_for_user(owner, prefs)
# ── CalDAV config routes (backward-compat single-account API) ────────────
@router.get("/config")
async def get_config(request: Request):
"""Legacy single-account endpoint — returns the first configured account."""
owner = _require_user(request)
from routes.prefs_routes import _load_for_user
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
caldav_password = cfg.get("password") or ""
if caldav_password:
accounts = _get_caldav_accounts(owner)
if not accounts:
return {"url": "", "username": "", "password": "", "has_password": False, "local": True}
first = accounts[0]
pw = first.get("password") or ""
has_pw = False
if pw:
try:
from src.secret_storage import decrypt
caldav_password = decrypt(caldav_password)
has_pw = bool(decrypt(pw))
except Exception:
pass
# Surface url+username but never hand the password back to the
# client — saved-state UI shouldn't leak the credential.
has_pw = bool(pw)
return {
"url": cfg.get("url", "") or "",
"username": cfg.get("username", "") or "",
"url": first.get("url", "") or "",
"username": first.get("username", "") or "",
"password": "",
"has_password": bool(caldav_password),
"local": not bool(cfg.get("url")),
"has_password": has_pw,
"local": not bool(first.get("url")),
}
@router.post("/config")
async def save_config(request: Request):
"""Legacy single-account endpoint — upserts the first account."""
owner = _require_user(request)
from routes.prefs_routes import _load_for_user, _save_for_user
try:
body = await request.json()
except Exception:
body = {}
prefs = _load_for_user(owner) or {}
cfg = dict(prefs.get("caldav") or {})
# Empty url => clear the whole entry (treat as "remove integration").
accounts = _get_caldav_accounts(owner)
if not (body.get("url") or "").strip():
prefs.pop("caldav", None)
_save_for_user(owner, prefs)
_save_caldav_accounts(owner, [])
return {"ok": True, "cleared": True}
from src.caldav_sync import validate_caldav_url
try:
cfg["url"] = validate_caldav_url(body.get("url", ""))
validated_url = validate_caldav_url(body.get("url", ""))
except ValueError as e:
raise HTTPException(400, str(e))
cfg["username"] = (body.get("username") or "").strip()
# Preserve the stored password when the client sends an empty
# one (edit form re-submitted without re-typing the password).
# cfg already holds the existing (already-encrypted) password from
# prefs, so we only touch it when a new password is supplied —
# re-encrypting the stored value would double-encrypt it.
if accounts:
acc = dict(accounts[0])
else:
import uuid as _uuid
acc = {"id": str(_uuid.uuid4()), "label": "CalDAV"}
acc["url"] = validated_url
acc["username"] = (body.get("username") or "").strip()
if body.get("password"):
from src.secret_storage import encrypt
cfg["password"] = encrypt(body["password"])
prefs["caldav"] = cfg
_save_for_user(owner, prefs)
acc["password"] = encrypt(body["password"])
new_accounts = [acc] + (accounts[1:] if len(accounts) > 1 else [])
_save_caldav_accounts(owner, new_accounts)
return {"ok": True}
# ── CalDAV multi-account CRUD ─────────────────────────────────────────────
@router.get("/config/accounts")
async def list_caldav_accounts(request: Request):
    """List every CalDAV account saved for the caller.

    The stored password is never included in the response; each entry only
    carries a ``has_password`` flag.
    """
    owner = _require_user(request)

    def _password_present(stored: str) -> bool:
        # An empty stored value means no password at all.
        if not stored:
            return False
        try:
            from src.secret_storage import decrypt
            return bool(decrypt(stored))
        except Exception:
            # Decryption failed: fall back to "something is stored".
            return bool(stored)

    return {
        "accounts": [
            {
                "id": entry.get("id", ""),
                "label": entry.get("label", "") or entry.get("url", ""),
                "url": entry.get("url", "") or "",
                "username": entry.get("username", "") or "",
                "has_password": _password_present(entry.get("password") or ""),
            }
            for entry in _get_caldav_accounts(owner)
        ]
    }
@router.post("/config/accounts")
async def add_caldav_account(request: Request):
    """Add a new CalDAV account for the caller.

    Expects a JSON object with ``url`` and ``password`` (both required) and
    optional ``label`` / ``username``. The password is passed through
    ``src.secret_storage.encrypt`` before being saved.

    Returns:
        ``{"ok": True, "id": <id of the new account>}``.

    Raises:
        HTTPException: 400 when the URL fails ``validate_caldav_url`` or the
            password is missing or not a string.
    """
    import uuid as _uuid
    owner = _require_user(request)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # A syntactically valid but non-object payload (list, string, number) has
    # no .get(); treat it as an empty body so the checks below answer 400
    # instead of the request dying with an AttributeError (500).
    if not isinstance(body, dict):
        body = {}
    from src.caldav_sync import validate_caldav_url
    try:
        url = validate_caldav_url(body.get("url", ""))
    except ValueError as e:
        raise HTTPException(400, str(e))
    password = body.get("password")
    # Only a non-empty string is a usable password; anything else is rejected
    # here rather than handed to encrypt().
    if not password or not isinstance(password, str):
        raise HTTPException(400, "Password is required")
    from src.secret_storage import encrypt
    new_acc = {
        "id": str(_uuid.uuid4()),
        "label": (body.get("label") or "").strip() or "CalDAV",
        "url": url,
        "username": (body.get("username") or "").strip(),
        "password": encrypt(password),
    }
    accounts = _get_caldav_accounts(owner)
    accounts.append(new_acc)
    _save_caldav_accounts(owner, accounts)
    return {"ok": True, "id": new_acc["id"]}
@router.put("/config/accounts/{account_id}")
async def update_caldav_account(account_id: str, request: Request):
    """Update an existing CalDAV account by id.

    Only the fields present in the JSON body are changed: ``url`` (validated
    with ``validate_caldav_url``), ``label``, ``username`` and ``password``
    (encrypted before saving). An empty or absent ``url`` / ``password``
    leaves the stored value untouched.

    Returns:
        ``{"ok": True}``.

    Raises:
        HTTPException: 404 when no saved account has ``account_id``; 400 when
            a supplied URL fails validation.
    """
    owner = _require_user(request)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # A valid-JSON but non-object payload (list, string, number) has no
    # .get(); treat it as "no fields supplied" instead of letting the
    # AttributeError surface as a 500.
    if not isinstance(body, dict):
        body = {}
    accounts = _get_caldav_accounts(owner)
    idx = next((i for i, a in enumerate(accounts) if a.get("id") == account_id), None)
    if idx is None:
        raise HTTPException(404, "Account not found")
    # Work on a copy so a validation failure below leaves the loaded list as-is.
    acc = dict(accounts[idx])
    if body.get("url"):
        from src.caldav_sync import validate_caldav_url
        try:
            acc["url"] = validate_caldav_url(body["url"])
        except ValueError as e:
            raise HTTPException(400, str(e))
    if body.get("label") is not None:
        acc["label"] = (body.get("label") or "").strip() or "CalDAV"
    if body.get("username") is not None:
        acc["username"] = (body.get("username") or "").strip()
    if body.get("password"):
        from src.secret_storage import encrypt
        acc["password"] = encrypt(body["password"])
    accounts[idx] = acc
    _save_caldav_accounts(owner, accounts)
    return {"ok": True}
@router.delete("/config/accounts/{account_id}")
async def delete_caldav_account(account_id: str, request: Request):
    """Delete the saved CalDAV account whose id equals ``account_id``.

    Responds with 404 when no saved account carries that id.
    """
    owner = _require_user(request)
    existing = _get_caldav_accounts(owner)
    remaining = [entry for entry in existing if entry.get("id") != account_id]
    # Same length after filtering means nothing matched the requested id.
    if len(remaining) == len(existing):
        raise HTTPException(404, "Account not found")
    _save_caldav_accounts(owner, remaining)
    return {"ok": True}
@router.post("/test")
async def test_connection(request: Request):
"""Actually probe the configured CalDAV server with a PROPFIND
request (the same handshake every CalDAV client uses). Accepts
an optional {url, username, password} body so the user can test
a configuration BEFORE saving it; falls back to the stored
creds otherwise. Returns {ok, error?} with a useful message on
failure (status code, auth issue, network error)."""
"""Probe a CalDAV server with a PROPFIND. Accepts an optional body:
{url, username, password} to test before saving, or {account_id} to
test an already-saved account. Falls back to the first saved account
when nothing is provided."""
owner = _require_user(request)
try:
body = await request.json()
@@ -620,19 +771,24 @@ def setup_calendar_routes() -> APIRouter:
user = (body.get("username") or "").strip()
pw = body.get("password") or ""
if not (url and user and pw):
# Fall back to saved settings for this user.
from routes.prefs_routes import _load_for_user
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
url = url or (cfg.get("url") or "")
user = user or (cfg.get("username") or "")
if not pw:
pw = cfg.get("password") or ""
if pw:
try:
from src.secret_storage import decrypt
pw = decrypt(pw)
except Exception:
pass
# Look up a saved account: by id if supplied, else first account.
accounts = _get_caldav_accounts(owner)
acc = None
if body.get("account_id"):
acc = next((a for a in accounts if a.get("id") == body["account_id"]), None)
if acc is None and accounts:
acc = accounts[0]
if acc:
url = url or (acc.get("url") or "")
user = user or (acc.get("username") or "")
if not pw:
pw = acc.get("password") or ""
if pw:
try:
from src.secret_storage import decrypt
pw = decrypt(pw)
except Exception:
pass
if not (url and user and pw):
return {"ok": False, "error": "Missing URL, username, or password"}
from src.caldav_sync import validate_caldav_url
@@ -695,6 +851,28 @@ def setup_calendar_routes() -> APIRouter:
from src.caldav_sync import sync_caldav
return await sync_caldav(owner)
@router.delete("/calendars/{cal_id}")
async def delete_calendar(cal_id: str, request: Request):
    """Delete one of the caller's calendars.

    The lookup is scoped to both the calendar id and the owner, so a
    calendar belonging to another user is reported as 404. Unexpected
    failures are logged and reported as a 500.
    """
    owner = _require_user(request)
    db = SessionLocal()
    try:
        lookup = db.query(CalendarCal).filter(
            CalendarCal.id == cal_id,
            CalendarCal.owner == owner,
        )
        target = lookup.first()
        if target is None:
            raise HTTPException(404, "Calendar not found")
        db.delete(target)
        db.commit()
        return {"ok": True}
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as e:
        logger.error("Failed to delete calendar %s: %s", cal_id, e)
        raise HTTPException(500, "Failed to delete calendar")
    finally:
        db.close()
@router.get("/calendars")
async def list_calendars(request: Request):
owner = _require_user(request)
@@ -703,7 +881,7 @@ def setup_calendar_routes() -> APIRouter:
_ensure_default_calendar(db, owner)
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
return {"calendars": [
{"name": c.name, "href": c.id, "color": c.color}
{"name": c.name, "href": c.id, "color": c.color, "source": c.source}
for c in cals
]}
except HTTPException:
@@ -766,8 +944,12 @@ def setup_calendar_routes() -> APIRouter:
expanded.extend(_expand_rrule(e, start_dt, end_dt))
# Sort by occurrence start time for consistent frontend ordering.
truncated = any(e.get("truncated") for e in expanded)
expanded.sort(key=lambda d: d["dtstart"])
return {"events": expanded}
response: dict = {"events": expanded}
if truncated:
response["truncated"] = True
return response
except HTTPException:
raise
except Exception as e:
@@ -988,9 +1170,9 @@ def setup_calendar_routes() -> APIRouter:
finally:
db.close()
# 10 MB hard cap on ICS upload. Loading the whole file into memory is
# unavoidable with python-icalendar, so an unbounded upload would OOM.
_ICS_MAX_BYTES = 10 * 1024 * 1024
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
# file into memory is unavoidable with python-icalendar, so an unbounded
# upload would OOM.
@router.post("/import")
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
@@ -1000,7 +1182,7 @@ def setup_calendar_routes() -> APIRouter:
owner = _require_user(request)
db = SessionLocal()
try:
content = await read_upload_limited(file, _ICS_MAX_BYTES, "ICS file")
content = await read_upload_limited(file, ICS_MAX_BYTES, "ICS file")
try:
cal_data = iCal.from_ical(content)
except Exception as e:
@@ -1168,11 +1350,14 @@ def setup_calendar_routes() -> APIRouter:
lines.append("END:VCALENDAR")
ics_data = "\r\n".join(lines)
safe_name = cal.name.replace(" ", "_").replace("/", "_")
download_name = _safe_ics_filename(cal.name)
return Response(
content=ics_data,
media_type="text/calendar",
headers={"Content-Disposition": f'attachment; filename="{safe_name}.ics"'},
headers={
"Content-Disposition": f'attachment; filename="{download_name}"',
"X-Content-Type-Options": "nosniff",
},
)
except HTTPException:
raise
@@ -1194,7 +1379,7 @@ def setup_calendar_routes() -> APIRouter:
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
Uses the "utility" endpoint (small / fast model) to keep latency low.
"""
_require_user(request)
owner = _require_user(request)
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
from src.text_helpers import strip_think
@@ -1220,9 +1405,9 @@ def setup_calendar_routes() -> APIRouter:
if tz_hint:
set_user_tz_name(tz_hint)
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=owner or None)
if not url:
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=owner or None)
if not url or not model:
return {"ok": False, "error": "No LLM endpoint configured"}
+130 -38
View File
@@ -75,7 +75,7 @@ def _enforce_chat_privileges(request, sess) -> None:
allowlist, or HTTPException(429) if the user has hit their daily message
cap. No-op for unauthenticated callers or when auth_manager is absent
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
which means empty allowed_models / zero cap no-op for them.
which means unrestricted allowed_models / zero cap -> no-op for them.
"""
try:
user = get_current_user(request)
@@ -88,8 +88,18 @@ def _enforce_chat_privileges(request, sess) -> None:
return
privs = auth_manager.get_privileges(user) or {}
allowed = privs.get("allowed_models") or []
if allowed and sess.model and sess.model not in allowed:
# Explicit "block everything" sentinel takes precedence over the
# allowlist — it's the only way to distinguish "user clicked [None]"
# (block all) from "user clicked [All]" (no restriction), since both
# otherwise produce an empty `allowed_models` list.
if privs.get("block_all_models"):
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
allowed_raw = privs.get("allowed_models")
allowed = allowed_raw if isinstance(allowed_raw, list) else []
restricted = bool(privs.get("allowed_models_restricted")) or bool(allowed)
if restricted and sess.model and sess.model not in allowed:
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
cap = int(privs.get("max_messages_per_day") or 0)
@@ -194,14 +204,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
"""
import requests as _req
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
from src.endpoint_resolver import (
build_chat_url,
build_headers,
build_models_url,
normalize_base,
resolve_endpoint_runtime,
)
from src.chatgpt_subscription import is_chatgpt_subscription_base
current_url = sess.endpoint_url or ""
owner = getattr(sess, "owner", None)
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(
q = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True
).all()
)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
finally:
db.close()
@@ -210,26 +232,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
# Skip current endpoint
if current_url and base in current_url:
continue
# Quick ping
ping_url = build_models_url(base)
headers = build_headers(ep.api_key, base)
try:
r = _req.get(ping_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
ping_url = build_models_url(base)
headers = build_headers(api_key, base)
try:
if ping_url:
r = _req.get(ping_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
models = json.loads(ep.cached_models or "[]")
if not models:
continue
# Found a working endpoint — update session
new_model = models[0]
chat_url = build_chat_url(base)
new_headers = build_headers(ep.api_key, base)
new_headers = build_headers(api_key, base)
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
sess.model = new_model
sess.endpoint_url = chat_url
@@ -241,7 +270,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
_db.query(DBSession).filter(DBSession.id == session_id).update({
"model": new_model,
"endpoint_url": chat_url,
"headers": json.dumps(new_headers),
"headers": persisted_headers,
})
_db.commit()
finally:
@@ -275,11 +304,16 @@ def extract_preset(chat_handler, preset_id) -> PresetInfo:
async def preprocess(
chat_handler, message, att_ids, sess,
auto_opened_docs: Optional[list] = None,
allow_tool_preprocessing: bool = True,
) -> PreprocessedMessage:
"""Run chat_handler.preprocess_message and wrap the result."""
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
await chat_handler.preprocess_message(
message, att_ids, sess, auto_opened_docs=auto_opened_docs
message,
att_ids,
sess,
auto_opened_docs=auto_opened_docs,
allow_tool_preprocessing=allow_tool_preprocessing,
)
)
return PreprocessedMessage(
@@ -329,16 +363,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
return False
def _has_auth_keys(headers) -> bool:
"""True if a headers dict carries an Authorization/x-api-key entry."""
return isinstance(headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in headers
)
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
)
if has_auth:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
except Exception:
is_chatgpt_subscription = False
has_auth = _has_auth_keys(sess.headers)
if has_auth and not is_chatgpt_subscription:
return
try:
from src.endpoint_resolver import build_headers, normalize_base
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
db = SessionLocal()
try:
target_url = getattr(sess, "endpoint_url", "") or ""
@@ -354,10 +398,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
for ep in q.all():
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
continue
if not ep.api_key:
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
return
if not api_key:
# No usable key (e.g. ChatGPT Subscription needs re-auth).
return
sess.headers = build_headers(api_key, base)
if is_chatgpt_subscription:
# The bearer is short-lived and re-resolved per request, so it
# stays request-local and is never written to the plaintext
# sessions.headers column. Proactively strip any bearer an
# older code path may have persisted so it does not linger.
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
stale_q = stale_q.filter(DBSession.owner == owner)
stored = stale_q.first()
if stored is not None and _has_auth_keys(stored.headers):
stale_q.update({"headers": {}})
db.commit()
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
return
base = normalize_base(ep.base_url or "")
sess.headers = build_headers(ep.api_key, base)
update_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
update_q = update_q.filter(DBSession.owner == owner)
@@ -401,7 +465,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
owner = getattr(sess, "owner", None)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for ep in endpoints:
try:
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
@@ -448,6 +517,7 @@ async def build_chat_context(
webhook_manager=None,
use_enhanced_message: bool = False,
agent_mode: bool = False,
allow_tool_preprocessing: bool = True,
) -> ChatContext:
"""Build the full context (preface + messages) for an LLM call.
@@ -465,6 +535,7 @@ async def build_chat_context(
preprocessed = await preprocess(
chat_handler, message, att_ids or [], sess,
auto_opened_docs=auto_opened_docs,
allow_tool_preprocessing=allow_tool_preprocessing,
)
# Add user message to history
@@ -483,6 +554,9 @@ async def build_chat_context(
# Skills injection respects its own enable toggle (mirrors memory_enabled).
# When off, the "Available skills" index is not added to the prompt.
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
if not allow_tool_preprocessing:
mem_enabled = False
skills_enabled = False
logger.debug(
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
@@ -490,11 +564,11 @@ async def build_chat_context(
# Use RAG?
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
if incognito:
if incognito or not allow_tool_preprocessing:
use_rag_val = False
# If pre-fetched search context was provided (compare mode), skip live web search
skip_web = bool(search_context)
skip_web = bool(search_context) or not allow_tool_preprocessing
# Build context preface
# The stream path uses enhanced_message (with CoT/preprocessing applied),
@@ -521,7 +595,7 @@ async def build_chat_context(
used_memories = getattr(chat_processor, '_last_used_memories', [])
# Inject pre-fetched search context (compare mode)
if search_context:
if search_context and allow_tool_preprocessing:
preface.append(untrusted_context_message("prefetched search context", search_context))
# YouTube transcripts
@@ -530,7 +604,11 @@ async def build_chat_context(
# Normalize model ID. Prefer cached endpoint models so group chat does not
# re-hit slow local /models endpoints on every participant turn.
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
sess.endpoint_url,
sess.model,
owner=getattr(sess, "owner", None),
)
if norm:
sess.model = norm
@@ -539,7 +617,7 @@ async def build_chat_context(
# Auto-compact
messages, context_length, was_compacted = await maybe_compact(
sess, sess.endpoint_url, sess.model, messages, sess.headers,
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
)
messages = trim_for_context(messages, context_length)
@@ -772,7 +850,19 @@ def save_assistant_response(
):
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
md = dict(last_metrics) if last_metrics else {}
md["model"] = sess.model
def _model_value(value) -> str:
if value is None:
return ""
if not isinstance(value, str):
value = str(value)
return value.strip()
requested_model = _model_value(md.get("requested_model") or md.get("selected_model") or getattr(sess, "model", ""))
actual_model = _model_value(md.get("model") or md.get("actual_model") or requested_model)
if requested_model:
md["requested_model"] = requested_model
if actual_model:
md["model"] = actual_model
if character_name:
md["character_name"] = character_name
if web_sources:
@@ -841,12 +931,13 @@ def run_post_response_tasks(
skills_manager=None,
owner: str = None,
extract_skills: bool = True,
allow_background_extraction: bool = True,
):
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
# Memory extraction — only every 4th message pair to avoid excess LLM calls
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
if allow_background_extraction and not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
from services.memory.memory_extractor import extract_and_store
from src.task_endpoint import resolve_task_endpoint
t_url, t_model, t_headers = resolve_task_endpoint(
@@ -873,6 +964,7 @@ def run_post_response_tasks(
)
if (
extract_skills
and allow_background_extraction
and auto_skills_enabled
and not incognito
and not compare_mode
+206 -72
View File
@@ -20,6 +20,7 @@ from src import agent_runs
from src.model_context import estimate_tokens
from src.chat_helpers import coerce_message_and_session
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
from src.session_search import search_session_messages
from src.prompt_security import untrusted_context_message
from core.exceptions import SessionNotFoundError
from src.auth_helpers import get_current_user
@@ -39,6 +40,7 @@ from routes.chat_helpers import (
_enforce_chat_privileges,
)
from src.action_intents import classify_tool_intent as _classify_tool_intent
from src.tool_policy import build_effective_tool_policy
logger = logging.getLogger(__name__)
@@ -167,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
Covers the window between endpoint setup and the first chat send: the
picker showed a model in the dropdown but the session record never got
written (Issue #587 — UI uses the cached endpoint list, not s.model).
Without this, we'd POST the upstream with model="" and get a generic
401/503 instead of using the model the user already picked.
Returns True iff sess.model was repaired.
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
"""
if getattr(sess, "model", None):
return False
current_model = (getattr(sess, "model", "") or "").strip()
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
is_chatgpt_subscription = False
if current_model:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
if not is_chatgpt_subscription:
return False
except Exception:
return False
db = SessionLocal()
try:
# Prefer the endpoint whose base URL matches the session — we know the
@@ -192,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
break
if not ep:
return False
if not is_chatgpt_subscription:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
except Exception:
is_chatgpt_subscription = False
try:
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
except Exception:
cached = []
if not cached:
visible = []
else:
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if current_model and current_model in {str(item).strip() for item in visible}:
return False
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if is_chatgpt_subscription:
live_models = []
if getattr(ep, "provider_auth_id", None):
try:
from src.chatgpt_subscription import fetch_available_models
from src.endpoint_resolver import resolve_endpoint_runtime
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
if api_key:
live_models = fetch_available_models(api_key)
if live_models:
ep.cached_models = json.dumps(live_models)
db.commit()
except Exception:
live_models = []
# ChatGPT Subscription recovery must use the live Codex catalog.
# Cached rows are only trusted above to avoid revalidating a model
# that is already present in the visible picker list.
cached = live_models
if not cached:
return False
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if current_model and current_model in {str(item).strip() for item in visible}:
return False
if not visible:
return False
model = visible[0]
@@ -211,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
# Persist so the next request, websocket reconnect, or page reload
# picks up the same model (we'd otherwise re-pick on every send
# and silently switch on the user if the cached order shifts).
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
db_session_q = db_session_q.filter(DBSession.owner == owner)
db_session = db_session_q.first()
if db_session:
db_session.model = model
db_session.updated_at = datetime.utcnow()
db.commit()
sess.model = model
logger.info(
"Recovered empty session model for %s — picked %r from endpoint %s",
"Recovered session model for %s — picked %r from endpoint %s",
session_id, model, ep.id,
)
return True
@@ -304,8 +351,13 @@ def setup_chat_routes(
# non-streaming path can't be used to bypass).
_enforce_chat_privileges(request, sess)
tool_policy = build_effective_tool_policy(last_user_message=message)
allow_tool_preprocessing = not tool_policy.block_all_tool_calls
# Inline memory command
memory_response = await chat_handler.handle_memory_command(sess, message)
memory_response = None
if not tool_policy.blocks("manage_memory"):
memory_response = await chat_handler.handle_memory_command(sess, message)
if memory_response:
return {"response": memory_response}
@@ -319,10 +371,15 @@ def setup_chat_routes(
use_web=use_web,
time_filter=time_filter,
webhook_manager=webhook_manager,
allow_tool_preprocessing=allow_tool_preprocessing,
)
# Research injection
if use_research:
research_blocked_by_policy = (
tool_policy.blocks("trigger_research")
or tool_policy.blocks("manage_research")
)
if use_research and not research_blocked_by_policy:
try:
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
research_ctx = await research_handler.call_research_service(
@@ -357,6 +414,7 @@ def setup_chat_routes(
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
character_name=ctx.preset.character_name,
owner=ctx.user,
allow_background_extraction=not tool_policy.block_all_tool_calls,
)
return {"response": reply}
@@ -394,6 +452,7 @@ def setup_chat_routes(
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
incognito = str(form_data.get("incognito", "")).lower() == "true"
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
# Workspace: confine the agent's file/shell tools to this folder. Validate
# it's a real directory; ignore (no confinement) otherwise.
@@ -401,6 +460,17 @@ def setup_chat_routes(
if workspace:
_ws_real = os.path.realpath(os.path.expanduser(workspace))
workspace = _ws_real if os.path.isdir(_ws_real) else ""
# Plan mode is a modifier on agent mode — it only makes sense with tools.
if plan_mode:
chat_mode = "agent"
# An approved plan being EXECUTED: the frontend sends the checklist back
# on each turn so we can pin it in context. This way a long plan on a
# weak model survives history truncation — the agent can always re-read
# the plan. Ignored while still proposing (plan_mode on). Capped so a
# huge plan can't blow the prompt.
approved_plan = ""
if not plan_mode:
approved_plan = (form_data.get("approved_plan") or "").strip()[:8192]
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
# below). Skill extraction should only learn from real agent sessions,
# not chats we quietly promoted for a notes/calendar intent.
@@ -479,11 +549,6 @@ def setup_chat_routes(
do_research = True
logger.info(f"Session {session} in research_pending — auto-triggering research")
# Persist session mode (research > agent > chat)
_effective_mode = 'research' if do_research else (chat_mode or 'chat')
if _effective_mode in ('agent', 'research', 'chat'):
set_session_mode(session, _effective_mode)
att_ids = []
if body and isinstance(body.get("attachments"), list):
att_ids = [str(x) for x in body["attachments"]]
@@ -494,6 +559,10 @@ def setup_chat_routes(
pass
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
pre_context_tool_policy = build_effective_tool_policy(
last_user_message=message,
)
allow_tool_preprocessing = not pre_context_tool_policy.block_all_tool_calls
# Build shared context (stream path uses enhanced_message for context preface)
ctx = await build_chat_context(
@@ -515,6 +584,7 @@ def setup_chat_routes(
# manage_skills (agent mode). In plain chat or incognito the
# index would be useless / unwanted noise.
agent_mode=(chat_mode == "agent"),
allow_tool_preprocessing=allow_tool_preprocessing,
)
_research_flags = {"do": do_research} # Mutable container for generator scope
@@ -659,6 +729,32 @@ def setup_chat_routes(
if chat_mode == 'chat':
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
# Plan mode: investigate read-only, propose a plan, don't mutate. Block
# every tool not on the read-only allowlist. (stream_agent_loop enforces
# this again + drops MCP, so this is belt-and-suspenders.)
if plan_mode:
from src.tool_security import plan_mode_disabled_tools
disabled_tools.update(plan_mode_disabled_tools())
tool_policy = build_effective_tool_policy(
disabled_tools=disabled_tools,
last_user_message=message,
)
disabled_tools = tool_policy.all_disabled_names()
research_blocked_by_policy = bool(
tool_policy.blocks("trigger_research")
or tool_policy.blocks("manage_research")
)
effective_do_research = bool(
do_research and _research_flags["do"] and not research_blocked_by_policy
)
# Persist session mode after policy/privilege gates so blocked research
# turns remain ordinary chat/agent streams and saved messages.
_effective_mode = 'research' if effective_do_research else (chat_mode or 'chat')
if _effective_mode in ('agent', 'research', 'chat'):
set_session_mode(session, _effective_mode)
async def stream_with_save() -> AsyncGenerator[str, None]:
# _effective_mode is read-only here; closure captures it from
# the outer scope. (Was `nonlocal` but never reassigned.)
@@ -666,7 +762,7 @@ def setup_chat_routes(
web_sources = ctx.web_sources
# Register active stream for partial-save safety net
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": do_research, "mode": _effective_mode}
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
if ctx.preprocessed.attachment_meta:
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
@@ -690,7 +786,7 @@ def setup_chat_routes(
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
# Run research as a background task (survives page refresh)
if do_research and _research_flags["do"]:
if effective_do_research:
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
_auth_keys = list(_r_headers.keys()) if _r_headers else []
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
@@ -829,7 +925,7 @@ def setup_chat_routes(
_fallback_candidates = []
# Send model name early so the frontend can show it during streaming
_model_suffix = "Research" if do_research else None
_model_suffix = "Research" if effective_do_research else None
_model_info = {"type": "model_info", "model": sess.model}
if _model_suffix:
_model_info["suffix"] = _model_suffix
@@ -839,6 +935,12 @@ def setup_chat_routes(
if _is_image_generation_session(sess, owner=_user):
from src.settings import get_setting
if tool_policy.blocks("generate_image"):
_blocked_msg = tool_policy.reason_for("generate_image")
yield f'data: {json.dumps({"delta": _blocked_msg})}\n\n'
yield "data: [DONE]\n\n"
_active_streams.pop(session, None)
return
if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
yield "data: [DONE]\n\n"
@@ -873,6 +975,8 @@ def setup_chat_routes(
elif chat_mode == "chat":
_chat_start = time.time()
_answered_by = None # set if the selected model failed and a fallback answered
_requested_model = sess.model
_actual_model = None
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
try:
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
@@ -905,10 +1009,18 @@ def setup_chat_routes(
# Selected model failed; a fallback answered.
# Forward the notice and remember the real model.
_answered_by = data.get("answered_by") or _answered_by
_actual_model = _actual_model or _answered_by
data["selected_model"] = data.get("selected_model") or _requested_model
yield chunk
elif data.get("type") == "model_actual":
_actual_model = data.get("model") or _actual_model
data["requested_model"] = _requested_model
yield f'data: {json.dumps(data)}\n\n'
elif data.get("type") == "usage":
last_metrics = data.get("data", {})
last_metrics["model"] = _answered_by or sess.model
_reported_model = last_metrics.get("model")
last_metrics["requested_model"] = _requested_model
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
if ctx.context_length and last_metrics.get("input_tokens"):
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
last_metrics["context_percent"] = pct
@@ -945,7 +1057,8 @@ def setup_chat_routes(
"tokens_per_second": _tps,
"context_percent": _ctx_pct,
"context_length": ctx.context_length,
"model": sess.model,
"model": _actual_model or _answered_by or _requested_model,
"requested_model": _requested_model,
"usage_source": "estimated",
}
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
@@ -957,7 +1070,7 @@ def setup_chat_routes(
rag_sources=ctx.rag_sources,
research_sources=research_sources,
used_memories=ctx.used_memories,
do_research=do_research,
do_research=effective_do_research,
incognito=incognito,
)
if _saved_id:
@@ -967,14 +1080,22 @@ def setup_chat_routes(
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
incognito=incognito, compare_mode=compare_mode,
character_name=ctx.preset.character_name,
owner=_user,
owner=_user,
allow_background_extraction=not tool_policy.block_all_tool_calls,
)
_stream_set(session, status="done")
yield chunk
except (asyncio.CancelledError, GeneratorExit):
if full_response:
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
_stopped_content, _stopped_md = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
_stopped_content, _stopped_md = clean_thinking_for_save(
full_response,
{
"stopped": True,
"model": _actual_model or _answered_by or _requested_model,
"requested_model": _requested_model,
},
)
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
if not incognito:
session_manager.save_sessions()
@@ -986,6 +1107,8 @@ def setup_chat_routes(
_agent_rounds = 0
_agent_tool_calls = 0
_answered_by = None # set if the selected model failed and a fallback answered
_requested_model = sess.model
_actual_model = None
try:
from src.settings import get_setting
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
@@ -1012,9 +1135,12 @@ def setup_chat_routes(
active_document=active_doc,
session_id=session,
disabled_tools=disabled_tools if disabled_tools else None,
tool_policy=tool_policy,
owner=_user,
fallbacks=_fallback_candidates,
workspace=workspace or None,
plan_mode=plan_mode,
approved_plan=approved_plan or None,
):
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
try:
@@ -1035,6 +1161,8 @@ def setup_chat_routes(
"doc_stream_open", "doc_stream_delta",
"doc_update", "doc_suggestions", "ui_control",
"rounds_exhausted",
"ask_user",
"plan_update",
):
if data.get("type") == "agent_step":
_agent_rounds = max(_agent_rounds, data.get("round", 1))
@@ -1047,10 +1175,18 @@ def setup_chat_routes(
# model so metrics reflect it, not the masked
# selected model.
_answered_by = data.get("answered_by") or _answered_by
_actual_model = _actual_model or _answered_by
data["selected_model"] = data.get("selected_model") or _requested_model
yield chunk
elif data.get("type") == "model_actual":
_actual_model = data.get("model") or _actual_model
data["requested_model"] = _requested_model
yield f'data: {json.dumps(data)}\n\n'
elif data.get("type") == "metrics":
last_metrics = data.get("data", {})
last_metrics["model"] = _answered_by or sess.model
_reported_model = last_metrics.get("model")
last_metrics["requested_model"] = last_metrics.get("requested_model") or _requested_model
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
except json.JSONDecodeError:
yield chunk
@@ -1078,6 +1214,7 @@ def setup_chat_routes(
skills_manager=skills_manager,
owner=_user,
extract_skills=user_requested_agent,
allow_background_extraction=not tool_policy.block_all_tool_calls,
)
_stream_set(session, status="done")
yield chunk
@@ -1091,7 +1228,14 @@ def setup_chat_routes(
try:
if full_response:
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
_stopped_content2, _stopped_md2 = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
_stopped_content2, _stopped_md2 = clean_thinking_for_save(
full_response,
{
"stopped": True,
"model": _actual_model or _answered_by or _requested_model,
"requested_model": _requested_model,
},
)
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
if not incognito:
session_manager.save_sessions()
@@ -1110,11 +1254,30 @@ def setup_chat_routes(
finally:
_active_streams.pop(session, None)
# Run the stream as a DETACHED background task so it survives the client
# closing the tab / navigating away (true terminal-agent behavior). The
# SSE response just subscribes (replay buffered output + live); dropping
# the SSE only removes a subscriber — the run keeps going and saves the
# assistant message on completion regardless. Reconnect via /api/chat/resume.
# Compare panes are short-lived, single-shot generations whose sessions
# exist only to drive that one pane — there's nothing to "resume" and
# the user expects the pane's Stop button (which aborts the fetch,
# closing this SSE) to promptly cancel the upstream LLM call. Detaching
# them would keep burning upstream tokens/compute after the pane is
# stopped or the comparison is abandoned, and would surface a stale
# "still streaming" /resume target for a session nobody will revisit.
#
# So: stream them directly (no agent_runs wrapping). Starlette cancels
# the underlying async generator (raising CancelledError/GeneratorExit
# inside it) as soon as it notices the client disconnected — which the
# mode-specific except blocks above already handle by saving the
# partial response exactly once. This stops the upstream call promptly
# without waiting on the next streamed chunk.
#
# Normal chat/agent streams keep the DETACHED behavior below: they
# survive the client closing the tab / navigating away (true
# terminal-agent semantics). The SSE response just subscribes (replay
# buffered output + live); dropping the SSE only removes a subscriber —
# the run keeps going and saves the assistant message on completion
# regardless. Reconnect via /api/chat/resume.
if compare_mode:
return StreamingResponse(_safe_stream(), media_type="text/event-stream")
agent_runs.start(session, _safe_stream())
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
@@ -1185,45 +1348,16 @@ def setup_chat_routes(
return []
_user = get_current_user(request)
query_term = q.strip()
db = SessionLocal()
try:
base_q = (
db.query(DBChatMessage, DBSession.name)
.join(DBSession, DBChatMessage.session_id == DBSession.id)
.filter(
DBSession.archived == False,
DBChatMessage.content.ilike(f"%{query_term}%"),
DBChatMessage.role.in_(["user", "assistant"]),
)
return [
result.to_dict()
for result in search_session_messages(
q,
limit=limit,
owner=_user,
restrict_owner=_user is not None,
include_legacy_owner=False,
)
if _user:
base_q = base_q.filter(DBSession.owner == _user)
rows = base_q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
results = []
for msg, session_name in rows:
content = msg.content or ""
lower_content = content.lower()
idx = lower_content.find(query_term.lower())
if idx == -1:
snippet = content[:120]
else:
start = max(0, idx - 50)
end = min(len(content), idx + len(query_term) + 50)
snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
results.append({
"session_id": msg.session_id,
"session_name": session_name or "Untitled",
"role": msg.role,
"content_snippet": snippet,
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
})
return results
finally:
db.close()
]
# ------------------------------------------------------------------ #
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
+170
View File
@@ -0,0 +1,170 @@
"""ChatGPT Subscription device-flow setup routes."""
import json
import logging
import uuid
from typing import Dict, Optional
from fastapi import HTTPException, Request
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
from routes.device_flow import (
DeviceFlowPoll,
DeviceFlowStart,
PendingDeviceFlowStore,
create_device_flow_router,
)
from src.auth_helpers import get_current_user
from src import chatgpt_subscription
logger = logging.getLogger(__name__)
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
def _upsert_auth_session(db, owner: Optional[str], base: str, access_token: str, refresh_token: str):
    """Return the owner's ChatGPT Subscription auth row, creating it if absent.

    The row is looked up by provider + owner. Its base URL, tokens, auth mode
    and ``last_refresh`` are then overwritten with the freshly issued values.
    The caller is responsible for committing.
    """
    provider = chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER
    auth = (
        db.query(ProviderAuthSession)
        .filter(
            ProviderAuthSession.provider == provider,
            ProviderAuthSession.owner == owner,
        )
        .first()
    )
    if auth is None:
        auth = ProviderAuthSession(
            id=str(uuid.uuid4())[:8],
            provider=provider,
            owner=owner,
            label="ChatGPT Subscription",
            base_url=base,
            auth_mode="chatgpt",
        )
        db.add(auth)
    # Always refresh the mutable fields, whether the row is new or reused.
    auth.base_url = base
    auth.access_token = access_token
    auth.refresh_token = refresh_token
    auth.last_refresh = utcnow_naive()
    auth.auth_mode = "chatgpt"
    return auth


def _upsert_endpoint(db, owner: Optional[str], base: str, auth_id: str, models):
    """Return the owner's ChatGPT Subscription endpoint row, creating it if absent.

    The row is matched on base URL + provider auth id + owner and then reset to
    the settings this flow always applies (no static api_key, enabled, manual
    model refresh, ``models`` stored as the cached model list). The caller is
    responsible for committing.
    """
    ep = (
        db.query(ModelEndpoint)
        .filter(
            ModelEndpoint.base_url == base,
            ModelEndpoint.provider_auth_id == auth_id,
            ModelEndpoint.owner == owner,
        )
        .first()
    )
    if ep is None:
        ep = ModelEndpoint(
            id=str(uuid.uuid4())[:8],
            name="ChatGPT Subscription",
            base_url=base,
            model_type="llm",
            endpoint_kind="api",
            owner=owner,
        )
        db.add(ep)
    ep.name = "ChatGPT Subscription"
    ep.base_url = base
    # The endpoint carries no key of its own; it points at the auth row instead.
    ep.api_key = None
    ep.provider_auth_id = auth_id
    ep.is_enabled = True
    ep.supports_tools = False
    ep.model_type = "llm"
    ep.endpoint_kind = "api"
    ep.model_refresh_mode = "manual"
    ep.cached_models = json.dumps(models)
    return ep


def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
    """Persist freshly issued ChatGPT tokens and the matching model endpoint.

    Args:
        tokens: Token response; must contain non-empty ``access_token`` and
            ``refresh_token`` entries.
        owner: User the auth session and endpoint belong to, or ``None``.

    Returns:
        A dict with the endpoint's ``id``, ``name``, ``base_url`` and the
        discovered ``models``.

    Raises:
        ValueError: If either token is missing, or if no models were
            discovered for the account.
    """
    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("ChatGPT token response was missing access_token or refresh_token")

    base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    # Discover models before touching the database so a failed lookup leaves
    # no half-provisioned rows behind.
    models = chatgpt_subscription.fetch_available_models(access_token)
    if not models:
        raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")

    db = SessionLocal()
    try:
        auth = _upsert_auth_session(db, owner, base, access_token, refresh_token)
        ep = _upsert_endpoint(db, owner, base, auth.id, models)
        db.commit()
        result = {
            "id": ep.id,
            "name": ep.name,
            "base_url": ep.base_url,
            "models": models,
        }
    finally:
        db.close()

    # Best-effort: provisioning already succeeded, so a failure here must not
    # fail the request -- but record it instead of discarding it silently.
    try:
        from routes.model_routes import _invalidate_models_cache
        _invalidate_models_cache()
    except Exception:
        logger.warning(
            "Could not invalidate the models cache after ChatGPT Subscription provisioning",
            exc_info=True,
        )
    return result
def _int_or_default(value, default: int) -> int:
    """Coerce *value* to ``int``; return *default* when it is falsy or not numeric."""
    if not value:
        return default
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
    """Request a ChatGPT device code and describe the pending flow.

    Args:
        request: Incoming request; used only to resolve the current user, who
            becomes the ``owner`` stored with the pending flow.
        _form: Unused by this provider.

    Returns:
        A ``DeviceFlowStart`` carrying the private pending state, the
        user-facing response (user code + verification URI), and the
        ``interval`` / ``expires_in`` values (defaulting to 5 and 900 when the
        provider omits them or sends something non-numeric).

    Raises:
        HTTPException: 502 when the provider response lacks ``device_auth_id``
            or ``user_code``; otherwise whatever
            ``chatgpt_subscription.to_http_exception`` maps a request failure to.
    """
    try:
        data = chatgpt_subscription.request_device_code()
    except Exception as exc:
        # Chain the cause so the original failure stays visible in tracebacks.
        raise chatgpt_subscription.to_http_exception(exc) from exc

    device_auth_id = data.get("device_auth_id")
    user_code = data.get("user_code")
    if not device_auth_id or not user_code:
        raise HTTPException(502, "ChatGPT did not return a complete device code")

    verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"
    return DeviceFlowStart(
        pending={
            "device_auth_id": device_auth_id,
            "user_code": user_code,
            "owner": get_current_user(request) or None,
        },
        response={
            "user_code": user_code,
            "verification_uri": verification_uri,
        },
        # A malformed upstream value must not crash the start request; fall
        # back to the same defaults used when the field is missing.
        interval=_int_or_default(data.get("interval"), 5),
        expires_in=_int_or_default(data.get("expires_in"), 900),
    )
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
    """Poll the provider once for a pending device flow and report its state.

    Args:
        _request: Unused.
        pending: State saved when the flow started; ``device_auth_id``,
            ``user_code`` and ``owner`` are read from it.

    Returns:
        ``DeviceFlowPoll.authorized`` with the provisioned endpoint once the
        provider returns an authorization code and verifier; otherwise a
        pending, slow-down or failed poll result derived from the provider's
        ``error`` / ``status`` field.

    Raises:
        HTTPException: Whatever ``chatgpt_subscription.to_http_exception``
            maps a token-exchange or provisioning failure to.
    """
    try:
        data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
    except Exception as exc:
        # A failed poll is reported as "still pending" rather than ending the flow.
        logger.debug("ChatGPT device poll failed: %s", exc)
        return DeviceFlowPoll.pending(str(exc))

    authorization_code = data.get("authorization_code")
    code_verifier = data.get("code_verifier")
    if authorization_code and code_verifier:
        try:
            tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
            result = _provision_endpoint(tokens, pending["owner"])
        except Exception as exc:
            logger.exception("ChatGPT Subscription endpoint provisioning failed")
            raise chatgpt_subscription.to_http_exception(exc) from exc
        return DeviceFlowPoll.authorized(result)

    err = data.get("error") or data.get("status")
    if err in ("authorization_pending", "pending", None):
        return DeviceFlowPoll.pending()
    if err == "slow_down":
        # A missing, zero or malformed interval becomes None instead of raising.
        try:
            interval = int(data.get("interval") or 0) or None
        except (TypeError, ValueError, OverflowError):
            interval = None
        return DeviceFlowPoll.slow_down(interval)
    if err in ("expired_token", "access_denied", "denied"):
        return DeviceFlowPoll.failed(err)
    # Unrecognized status: keep polling, but pass the reason along.
    return DeviceFlowPoll.pending(err or "unknown")
def setup_chatgpt_subscription_routes():
    """Create the device-flow router for ChatGPT Subscription setup.

    Wires this module's start/poll callbacks and its pending-flow store into
    ``create_device_flow_router`` under the ``/api/chatgpt-subscription`` prefix.
    """
    router_options = {
        "prefix": "/api/chatgpt-subscription",
        "tags": ["chatgpt-subscription"],
        "store": _DEVICE_FLOW_STORE,
        "start_flow": _start_device_flow,
        "poll_flow": _poll_device_flow,
    }
    return create_device_flow_router(**router_options)
+16 -6
View File
@@ -15,8 +15,9 @@ from typing import Any
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
from fastapi.responses import StreamingResponse
from src.auth_helpers import require_user
from src.auth_helpers import require_authenticated_request, require_user
from src.tool_implementations import do_manage_notes
from src.constants import COOKBOOK_STATE_FILE
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
@@ -41,7 +42,9 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
Restores the original value when done. Works for sync and async handlers."""
orig = getattr(request.state, "current_user", None)
orig_api_token = getattr(request.state, "api_token", None)
request.state.current_user = owner
request.state.api_token = False
try:
result = fn(*args, **kwargs)
if asyncio.iscoroutine(result):
@@ -49,6 +52,13 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
return result
finally:
request.state.current_user = orig
if orig_api_token is None:
try:
delattr(request.state, "api_token")
except AttributeError:
pass
else:
request.state.api_token = orig_api_token
def _scope_owner(request: Request, allowed: set[str]) -> str:
@@ -146,7 +156,7 @@ def setup_codex_routes(
@router.get("/plugin.zip")
def plugin_zip(request: Request):
require_user(request)
require_authenticated_request(request)
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
if not root.exists():
raise HTTPException(404, "Codex plugin bundle not found")
@@ -415,8 +425,8 @@ def setup_codex_routes(
def _read_cookbook_state() -> dict:
from pathlib import Path as _Path
import os as _os, json as _json
p = _Path(_os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
import json as _json
p = _Path(COOKBOOK_STATE_FILE)
if not p.exists():
return {}
try:
@@ -724,7 +734,7 @@ def setup_codex_routes(
import time as _t, json as _json
from core.atomic_io import atomic_write_json
from pathlib import Path as _Path
cookbook_state_path = _Path("/app/data/cookbook_state.json")
cookbook_state_path = _Path(COOKBOOK_STATE_FILE)
try:
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
except Exception:
@@ -762,7 +772,7 @@ def setup_claude_routes() -> APIRouter:
@router.get("/plugin.zip")
def plugin_zip(request: Request):
require_user(request)
require_authenticated_request(request)
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
# README.md or other bundle metadata into the user's claude config dir.
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
+110 -22
View File
@@ -12,6 +12,7 @@ import logging
from core.database import Comparison, SessionLocal
from core.session_manager import SessionManager
from src.auth_helpers import get_current_user
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
logger = logging.getLogger(__name__)
@@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner):
return owner_filter(q, ModelEndpoint, owner).first()
def _owned_endpoint_by_id(db, endpoint_id, owner):
    """Fetch the ModelEndpoint with id `endpoint_id`, limited to rows `owner` can see.

    Visibility follows `owner_filter`: the caller's own rows plus legacy
    null-owner "shared" rows; a null/empty owner leaves the query unscoped.
    Returns None when no visible row has that id.

    Credential resolution should use this lookup in preference to
    _owned_endpoint_by_url. Two endpoints visible to the same caller may share
    a base_url while holding DIFFERENT api_keys (two accounts on one provider),
    and a URL-only match returns whichever row sorts first, which can put the
    WRONG key into the [CMP] session. An id identifies exactly one registered
    endpoint, so /api/compare/start tries it first and keeps URL matching only
    for legacy / admin raw-URL callers.
    """
    from core.database import ModelEndpoint
    from src.auth_helpers import owner_filter

    by_id = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id)
    visible = owner_filter(by_id, ModelEndpoint, owner)
    return visible.first()
class RecordVoteRequest(BaseModel):
prompt: str
models: List[str]
@@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager):
prompt: str = Form(...),
model_a: str = Form(...),
model_b: str = Form(...),
endpoint_a: str = Form(...),
endpoint_b: str = Form(...),
endpoint_a: str = Form(""),
endpoint_b: str = Form(""),
endpoint_a_id: str = Form(""),
endpoint_b_id: str = Form(""),
is_blind: str = Form("true"),
):
"""Create two ephemeral sessions and a comparison record.
@@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager):
Returns the comparison ID and the two session IDs so the client
can fire two independent SSE streams to /api/chat_stream.
"""
user = getattr(request.state, 'current_user', None)
comp_id = str(uuid.uuid4())
sid_a = str(uuid.uuid4())
sid_b = str(uuid.uuid4())
user = getattr(request.state, 'current_user', None)
# Blind mapping: randomly assign left/right
blind = str(is_blind).lower() == "true"
@@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager):
# de-anonymizing the comparison before the user votes (issue #1285).
slot_name = {session_left: "Model A", session_right: "Model B"}
# Create ephemeral sessions (prefixed [CMP])
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
# SECURITY: resolve and validate BOTH endpoints before creating any
# session. Compare copies a registered endpoint's Authorization header
# into the [CMP] session, so validating one endpoint while creating its
# session, then rejecting the other, would leave a partial compare
# session behind with that header attached. Doing all the owner-scope
# resolution + raw-URL rejection up front means a 403 on either endpoint
# aborts the whole request with nothing created and no header copied.
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
resolved = []
db = SessionLocal()
try:
for sid, model, endpoint, endpoint_id in [
(sid_a, model_a, endpoint_a, endpoint_a_id),
(sid_b, model_b, endpoint_b, endpoint_b_id),
]:
# Prefer an explicit endpoint id: it pins the EXACT registered
# endpoint (and its api_key), even when two endpoints visible to
# the caller share a base_url with different keys — a URL-only
# match would copy whichever row sorts first, i.e. possibly the
# wrong key. Fall back to URL resolution only for legacy / admin
# raw-URL callers that don't send an id.
eid = endpoint_id.strip() if isinstance(endpoint_id, str) else ""
if eid:
ep = _owned_endpoint_by_id(db, eid, user)
if ep is None:
# An id the caller can't see (wrong owner / deleted) must
# NOT silently fall back to a same-URL row with a different
# key — that's exactly the mix-up ids exist to prevent.
raise HTTPException(404, "Model endpoint not found")
# The id already resolved the endpoint; ignore any raw URL the
# caller also sent and dial the stored config instead.
endpoint = ep.base_url
elif not endpoint:
raise HTTPException(
422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required"
)
else:
# Resolve the supplied URL to a ModelEndpoint the caller owns
# (their own rows + legacy null-owner shared rows), scoped so a
# comparison can't borrow another user's private endpoint key.
base = normalize_base(endpoint)
ep = _owned_endpoint_by_url(db, base, user)
# Reject *unregistered* raw URLs for signed-in non-admins; a
# matched registered endpoint supplies an id so the caller can
# still compare endpoints they own. Blanket-rejecting here (the
# earlier `endpoint_id=None` call) locked non-admins out of
# compare entirely, since compare resolves endpoints by URL with
# no endpoint_id. Mirrors the gallery inpaint/harmonize checks.
# Raised here (phase 1), before any session exists.
_reject_raw_endpoint_url_for_non_admin(
request, user, str(ep.id) if ep is not None else None, endpoint
)
# Bind the [CMP] session to the RESOLVED endpoint, not the raw
# caller-supplied string. When the URL matches a registered
# endpoint visible to the caller, use that row's own normalized
# base URL (the same value owner scoping + endpoint validation
# already vetted) so the session dials exactly where the stored
# config points. The raw `endpoint` only survives for callers
# allowed to pass one — admins / single-user mode, where
# `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep`
# is None. Mirrors the registered-endpoint path in session_routes.
session_endpoint_url = (
build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint
)
# Headers come only from a matched endpoint's key; None when
# `ep` is None (raw admin URL or no match), so a comparison can
# never inherit another user's key/headers.
headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None
resolved.append((sid, model, session_endpoint_url, headers))
finally:
db.close()
# Both endpoints validated — only now create the ephemeral [CMP]
# sessions and copy any resolved headers.
for sid, model, session_endpoint_url, headers in resolved:
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
session_manager.create_session(
session_id=sid,
name=name,
endpoint_url=endpoint,
endpoint_url=session_endpoint_url,
model=model,
rag=False,
owner=user,
)
# Copy API key from endpoint config
db = SessionLocal()
try:
from src.endpoint_resolver import build_headers, normalize_base
# Find matching endpoint by URL, scoped to the caller so a
# comparison can't borrow another user's private endpoint key.
base = normalize_base(endpoint)
ep = _owned_endpoint_by_url(db, base, user)
if ep and ep.api_key:
s = session_manager.sessions.get(sid)
if s:
s.headers = build_headers(ep.api_key, ep.base_url)
finally:
db.close()
if headers:
s = session_manager.sessions.get(sid)
if s:
s.headers = headers
# Store comparison record
db = SessionLocal()
@@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager):
prompt=prompt,
model_a=model_a,
model_b=model_b,
endpoint_a=endpoint_a,
endpoint_b=endpoint_b,
# Record the URL the session actually dials. For URL callers this
# is their raw input; for id-only callers (empty endpoint_a/_b)
# fall back to the resolved endpoint URL so the column stays
# meaningful and non-null. resolved is in [a, b] order.
endpoint_a=endpoint_a or resolved[0][2],
endpoint_b=endpoint_b or resolved[1][2],
is_blind=blind,
blind_mapping=json.dumps(mapping),
owner=user,
+53 -18
View File
@@ -11,20 +11,24 @@ import uuid
import json
import csv
import io
import os
import httpx
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, Query, Depends, Response
from urllib.parse import urljoin, urlparse, urlunparse
from fastapi import APIRouter, Query, Depends, Response, HTTPException
from typing import List, Dict, Optional
from src.auth_helpers import require_user
from core.middleware import require_admin
from src.url_safety import check_outbound_url
logger = logging.getLogger(__name__)
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
SETTINGS_FILE = DATA_DIR / "settings.json"
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
from src.constants import DATA_DIR as _DATA_DIR, SETTINGS_FILE as _SETTINGS_FILE, CONTACTS_FILE as _CONTACTS_FILE
DATA_DIR = Path(_DATA_DIR)
SETTINGS_FILE = Path(_SETTINGS_FILE)
LOCAL_CONTACTS_FILE = Path(_CONTACTS_FILE)
def _load_settings():
@@ -53,6 +57,21 @@ def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
return bool((cfg.get("url") or "").strip())
def _validate_carddav_url(url: str) -> str:
    """Normalize a CardDAV URL and run it through the outbound-URL safety check.

    Non-string input is treated as an empty string. Surrounding whitespace and
    trailing slashes are removed before the check. Private-address blocking is
    enabled only when the CARDDAV_BLOCK_PRIVATE_IPS environment variable is
    "true" (case-insensitive); it defaults to off.

    Returns the normalized URL.
    Raises ValueError carrying the checker's reason when the URL is rejected.
    """
    candidate = url if isinstance(url, str) else ""
    candidate = candidate.strip().rstrip("/")
    block_private = os.getenv("CARDDAV_BLOCK_PRIVATE_IPS", "false").lower() == "true"
    ok, reason = check_outbound_url(candidate, block_private=block_private)
    if ok:
        return candidate
    raise ValueError(f"Rejected CardDAV URL: {reason}")
def _carddav_base_url(cfg: Dict) -> str:
    """Return the validated, normalized CardDAV URL stored under cfg["url"].

    A missing or falsy "url" entry is passed to the validator as "".
    Propagates ValueError from _validate_carddav_url when the URL is rejected.
    """
    configured_url = cfg.get("url") or ""
    return _validate_carddav_url(configured_url)
def _normalize_contact(contact: Dict) -> Dict:
emails = []
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
@@ -219,14 +238,18 @@ _contact_cache = {"contacts": [], "fetched_at": None}
def _abs_url(href: str) -> str:
"""Combine a multistatus <href> (an absolute path like
/user/contacts/x.vcf) with the configured CardDAV server origin so we
get a fully-qualified URL to PUT/DELETE. If href is already absolute
(http...), return it as-is."""
from urllib.parse import urlparse, urlunparse
if href.startswith("http://") or href.startswith("https://"):
return href
get a fully-qualified URL to PUT/DELETE. Absolute hrefs are accepted only
for the configured origin; a cross-origin href is treated as a path on the
configured server so a malicious CardDAV response cannot redirect later
writes/deletes to cloud metadata or another host."""
cfg = _get_carddav_config()
p = urlparse(cfg["url"])
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
base = _carddav_base_url(cfg)
base_p = urlparse(base)
joined = urljoin(base.rstrip("/") + "/", href or "")
joined_p = urlparse(joined)
if (joined_p.scheme, joined_p.netloc) != (base_p.scheme, base_p.netloc):
joined = urlunparse((base_p.scheme, base_p.netloc, joined_p.path or "/", "", joined_p.query, ""))
return _validate_carddav_url(joined)
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
@@ -297,6 +320,7 @@ def _fetch_contacts(force=False):
return contacts
try:
cfg["url"] = _carddav_base_url(cfg)
auth = None
if cfg["username"]:
auth = (cfg["username"], cfg["password"])
@@ -353,8 +377,8 @@ def _create_contact(name: str, email: str) -> bool:
contact_uid = str(uuid.uuid4())
vcard = _build_vcard(name, email, contact_uid)
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
try:
url = _carddav_base_url(cfg) + "/" + contact_uid + ".vcf"
auth = None
if cfg["username"]:
auth = (cfg["username"], cfg["password"])
@@ -382,7 +406,7 @@ def _vcard_url(uid: str) -> str:
escape the collection and target an arbitrary CardDAV resource."""
from urllib.parse import quote
cfg = _get_carddav_config()
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
return _carddav_base_url(cfg) + "/" + quote(uid, safe="") + ".vcf"
def _import_vcards(text: str) -> Dict:
@@ -413,6 +437,11 @@ def _import_vcards(text: str) -> Dict:
if imported:
_save_local_contacts(contacts)
return {"imported": imported, "failed": 0, "total": len(parsed)}
try:
base_url = _carddav_base_url(cfg)
except ValueError as e:
logger.warning("CardDAV import URL rejected: %s", e)
return {"imported": 0, "failed": 0, "total": 0, "error": str(e)}
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
# Split into individual cards. re.split drops the BEGIN line, so we
# re-add it. Normalize CRLF.
@@ -441,7 +470,7 @@ def _import_vcards(text: str) -> Dict:
elif not re.search(r"^VERSION:", block, re.MULTILINE):
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
vcard = block.replace("\n", "\r\n") + "\r\n"
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
url = base_url + "/" + quote(uid, safe="") + ".vcf"
try:
r = httpx.put(
url, data=vcard.encode("utf-8"),
@@ -601,8 +630,8 @@ def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
# Use the real resource href (handles externally-created contacts whose
# filename != UID); falls back to the <uid>.vcf guess.
url = _resolve_resource_url(uid)
try:
url = _resolve_resource_url(uid)
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
r = httpx.put(
url,
@@ -630,8 +659,8 @@ def _delete_contact(uid: str) -> bool:
_save_local_contacts(remaining)
return True
url = _resolve_resource_url(uid)
try:
url = _resolve_resource_url(uid)
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
r = httpx.delete(url, auth=auth, timeout=10)
if r.status_code in (200, 204):
@@ -747,7 +776,13 @@ def setup_contacts_routes():
settings = _load_settings()
for key in ("carddav_url", "carddav_username", "carddav_password"):
if key in data:
settings[key] = data[key]
if key == "carddav_url" and str(data[key] or "").strip():
try:
settings[key] = _validate_carddav_url(data[key])
except ValueError as e:
raise HTTPException(400, str(e))
else:
settings[key] = data[key]
_save_settings(settings)
# Force re-fetch
_contact_cache["fetched_at"] = None
+312 -14
View File
@@ -11,6 +11,8 @@ import shlex
from fastapi import HTTPException
from pydantic import BaseModel
from core.platform_compat import _ssh_exec_argv
logger = logging.getLogger(__name__)
@@ -195,6 +197,20 @@ def _pip_install_attempt(pip_cmd: str) -> str:
)
def _pip_command(python_cmd: str) -> str:
"""Return a pip command for either a pip executable or a Python executable."""
cmd = python_cmd.strip()
if " -m pip" in cmd or cmd in {"pip", "pip3"}:
return python_cmd
if cmd in {"python", "python3", "python.exe"} or cmd.endswith(("/python", "/python3", "\\python.exe")):
return f"{python_cmd} -m pip"
return python_cmd
def _pip_break_system_packages_check(pip_cmd: str) -> str:
return f"{pip_cmd} install --help 2>/dev/null | grep -q -- --break-system-packages"
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
"""Build a bash pip install fallback chain that surfaces errors.
@@ -206,33 +222,44 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
pip output appear in the Cookbook log on failure.
"""
from core.platform_compat import IS_WINDOWS
upgrade_flag = " -U" if upgrade else ""
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
# contains brackets that bash would treat as a glob, so it must be quoted
# before being embedded in the install command. Plain names (e.g.
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
pkg = shlex.quote(package)
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
# llama-cpp-python source builds are brittle on older distro pip/packaging
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
# this package is requested so dependency-install tasks are reliable.
if "llama-cpp-python" in package:
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
pip_cmd = _pip_command(python_cmd)
base = _pip_install_attempt(f"{pip_cmd} install -q{upgrade_flag} {pkg}")
user = _pip_install_attempt(f"{pip_cmd} install --user -q{upgrade_flag} {pkg}")
user_break_system = _pip_install_attempt(f"{pip_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
user_fallback = f"( {user} || {{ {_pip_break_system_packages_check(pip_cmd)} && {user_break_system}; }} )"
# Derive the python executable for the venv detection check.
# Must use the same interpreter that pip belongs to; hardcoding
# python3 breaks when pip lives in a venv that only has "python".
if " -m pip" in python_cmd:
python_exe = python_cmd.replace(" -m pip", "")
elif python_cmd.strip() == "pip":
if " -m pip" in pip_cmd:
python_exe = pip_cmd.replace(" -m pip", "")
elif pip_cmd.strip() == "pip":
python_exe = "python"
elif python_cmd.strip() == "pip3":
elif pip_cmd.strip() == "pip3":
python_exe = "python3"
else:
python_exe = "python3"
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv `&&` tries
# --user. When IN a venv `! venv_check` fails `&&` skips --user and the
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv -> `&&` tries
# --user. When IN a venv `! venv_check` fails -> `&&` skips --user and the
# group exits non-zero, propagating the base-install failure instead of
# masking it as success (the `|| { venv_check || … }` shape from #903
# swallowed the exit code because venv_check's exit-0 became the group's
# result).
return f"{base} || {{ ! {venv_check} && {user}; }}"
# result). `--break-system-packages` is only attempted when the active pip
# supports it; older pip versions abort with "no such option" otherwise.
return f"{base} || {{ ! {venv_check} && {user_fallback}; }}"
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
@@ -263,6 +290,55 @@ def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) ->
return shlex.join(stripped)
def _pip_install_command_without_break_system_packages(cmd: str) -> str:
try:
parts = shlex.split(cmd)
except ValueError:
return cmd
stripped = [part for part in parts if part != "--break-system-packages"]
return shlex.join(stripped)
def _pip_install_help_check_from_cmd(cmd: str) -> str | None:
try:
parts = shlex.split(cmd)
except ValueError:
return None
try:
install_index = parts.index("install")
except ValueError:
return None
if install_index <= 0:
return None
pip_prefix = parts[:install_index]
return f"{shlex.join(pip_prefix + ['install', '--help'])} 2>/dev/null | grep -q -- --break-system-packages"
def _append_pip_install_runner_lines(runner_lines: list[str], cmd: str) -> None:
    """Append ``cmd`` to ``runner_lines``, guarding any --break-system-packages use.

    The Dependencies UI may submit ``python3 -m pip install --user
    --break-system-packages ...`` for non-venv installs. The flag helps on
    PEP-668-locked distros, but older pip (including Ubuntu 22.04's apt pip in
    the NVIDIA CUDA base image) aborts with "no such option". When the flag is
    present, a shell ``if`` is emitted that probes pip at runner time and falls
    back to the same command without the flag, so stale browser JS and remote
    targets are handled by the server too.

    The command is appended verbatim when it does not mention the flag, when no
    probe can be derived from it, or when stripping the flag changes nothing.
    """
    if "--break-system-packages" in (cmd or ""):
        probe = _pip_install_help_check_from_cmd(cmd)
        fallback_cmd = _pip_install_command_without_break_system_packages(cmd)
        if probe and fallback_cmd != cmd:
            runner_lines.extend(
                [
                    f"if {probe}; then",
                    f" {cmd}",
                    "else",
                    ' echo "[odysseus] pip does not support --break-system-packages; installing without it."',
                    f" {fallback_cmd}",
                    "fi",
                ]
            )
            return
    runner_lines.append(cmd)
def _user_shell_path_bootstrap() -> list[str]:
return [
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
@@ -271,11 +347,14 @@ def _user_shell_path_bootstrap() -> list[str]:
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
'fi',
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
]
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
"""Build the standalone Python scanner used by /api/model/cached."""
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
"""Build the standalone Python scanner used by /api/model/cached.
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
"""
lines = [
"import json, os, re, shutil, subprocess, urllib.request",
"models = []",
@@ -338,6 +417,15 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" if f.is_file(): nf += 1; sz += f.stat().st_size",
" if f.name.endswith('.incomplete'): ic = True",
" snap = os.path.join(cache, d, 'snapshots')",
" # Windows HF cache stores files directly in snapshots/; blobs/ may be empty.",
" # Fallback: scan snapshots for real files when blobs yielded nothing.",
" if sz == 0 and os.path.isdir(snap):",
" for sd in os.listdir(snap):",
" sf = os.path.join(snap, sd)",
" if not os.path.isdir(sf): continue",
" for f in os.scandir(sf):",
" if f.is_file(): nf += 1; sz += f.stat().st_size",
" if f.name.endswith('.incomplete'): ic = True",
" is_diffusion = False; gguf_files = []",
" if os.path.isdir(snap):",
" for sd in os.listdir(snap):",
@@ -346,6 +434,21 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
"def hf_cache_paths():",
" candidates = []",
" def add(p):",
" if not p: return",
" p = os.path.expanduser(p)",
" if p not in candidates: candidates.append(p)",
" add(os.environ.get('HUGGINGFACE_HUB_CACHE'))",
" hf_home = os.environ.get('HF_HOME')",
" if hf_home: add(os.path.join(hf_home, 'hub'))",
" add('~/.cache/huggingface/hub')",
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
" # When HOME is /root, expanduser() misses that persisted cache.",
" add('/app/.cache/huggingface/hub')",
f" add({add_hf_cache!r})" if add_hf_cache else "",
" return candidates",
"def scan_dir(p):",
" if not os.path.isdir(p) or not safe_path(p): return",
" for d in sorted(os.listdir(p)):",
@@ -409,7 +512,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" seen.add(name)",
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
" return",
"scan_hf(os.path.expanduser('~/.cache/huggingface/hub'))",
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
"scan_ollama()",
"scan_ollama_api()",
]
@@ -525,6 +628,7 @@ def _validate_serve_cmd(v: str | None) -> str | None:
# Backticks and raw newlines are never legitimate here.
if any(c in v for c in ("`", "\n", "\r")):
raise HTTPException(400, "Invalid characters in cmd")
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
m = _GGUF_PRELUDE_RE.match(v)
if m:
@@ -533,9 +637,19 @@ def _validate_serve_cmd(v: str | None) -> str | None:
for part in rest.split("||"):
_check_serve_binary(part.strip())
return v
# Otherwise: a single invocation — no shell metacharacters allowed.
# Temporarily replace safe $(printf %s ...) expressions with a placeholder
# to avoid triggering the metacharacter/command-injection checks.
cleaned_v = v
printf_matches = list(re.finditer(r"\$\(\s*printf\s+%s\s+([^\n()]*?)\)", v))
for match in printf_matches:
inner = match.group(1)
if not any(c in inner for c in (";", "&&", "||", "$(", "`")):
cleaned_v = cleaned_v.replace(match.group(0), "/placeholder/safe/path.gguf")
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
if any(c in v for c in (";", "&&", "||", "$(")):
if any(c in cleaned_v for c in (";", "&&", "||", "$(")):
raise HTTPException(400, "Invalid characters in cmd")
_check_serve_binary(v)
return v
@@ -559,6 +673,21 @@ def _append_serve_preflight_exit_lines(runner_lines: list[str], *, keep_shell_op
runner_lines.append('fi')
def _append_vllm_linux_preflight_lines(runner_lines: list[str]) -> None:
"""Append Linux vLLM readiness lines that identify the runtime being used."""
# Keep the user install bin visible for Odysseus-managed `pip install --user`
# installs, but then report the actual CLI path so external runtimes are clear.
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('ODYSSEUS_VLLM_BIN="$(command -v vllm 2>/dev/null || true)"')
runner_lines.append('if [ -z "$ODYSSEUS_VLLM_BIN" ]; then')
runner_lines.append(' echo "ERROR: vLLM is not installed."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('else')
runner_lines.append(' echo "[odysseus] vLLM CLI: $ODYSSEUS_VLLM_BIN"')
runner_lines.append(' ODYSSEUS_VLLM_VERSION="$("$ODYSSEUS_VLLM_BIN" --version 2>&1 | head -n 1 || true)"')
runner_lines.append(' if [ -n "$ODYSSEUS_VLLM_VERSION" ]; then echo "[odysseus] vLLM version: $ODYSSEUS_VLLM_VERSION"; fi')
runner_lines.append('fi')
def _append_serve_exit_code_lines(
runner_lines: list[str],
*,
@@ -804,3 +933,172 @@ def _ssh_ps(host, script_path, port=None):
# Windows session dir — stored in user's temp on the remote
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
def _diagnose_serve_output(text: str) -> dict | None:
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
the agent/tool path the same structured signal so it can retry with an
adjusted command instead of guessing from raw tmux output.
"""
if not text:
return None
tail = text[-6000:]
patterns = [
(
r"No available memory for the cache blocks|Available KV cache memory:.*-",
"No GPU memory left for KV cache after loading model.",
[
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
],
),
(
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
"GPU ran out of memory during startup or warmup.",
[
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
],
),
(
r"not divisib|must be divisible|attention heads.*divisible",
"Tensor parallel size is incompatible with the model.",
[
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
],
),
(
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
"Context length is too large for available GPU memory.",
[
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
],
),
(
r"enable-auto-tool-choice requires --tool-call-parser",
"Auto tool choice requires an explicit tool call parser.",
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
),
(
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
"Model requires custom code or newer model support.",
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
),
(
r"There is no module or parameter named ['\"]lm_head\.input_scale['\"]|lm_head\.input_scale|weight_scale_2",
"vLLM cannot load this ModelOpt LM-head quantized checkpoint with the current runtime.",
[
{
"label": "upgrade vLLM through the environment that provides this CLI, or use a compatible checkpoint",
"op": "manual",
}
],
),
(
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
"vLLM/Transformers kernel package mismatch.",
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
),
(
r"Address already in use|bind.*address.*in use",
"Port is already in use.",
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
),
(
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
"No GPUs are visible to the serve process.",
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
),
(
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
"vLLM could not find a supported GPU (CUDA or ROCm). "
"This machine may have integrated or unsupported graphics only.",
[
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
],
),
(
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
"vLLM is not installed or not in PATH on this server.",
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
),
(
r"sglang.*command not found|No module named sglang|SGLang is not installed",
"SGLang is not installed or not in PATH on this server.",
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
),
(
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
"llama.cpp / llama-cpp-python dependencies are missing.",
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
),
(
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
),
(
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
"Diffusion serving requires PyTorch and diffusers.",
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
),
(
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
"Model access is gated or unauthorized.",
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
),
]
for pattern, message, suggestions in patterns:
if re.search(pattern, tail, re.I):
return {"message": message, "suggestions": suggestions}
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
):
return {
"message": "Python traceback detected during serve startup.",
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
}
return None
async def run_ssh_command_async(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    stdin_data: bytes | None = None,
) -> tuple[int, bytes, bytes]:
    """Run an ssh command with centralized timeout and stderr/stdout capture.

    Async version of core.platform_compat.run_ssh_command_sync.

    Args:
        remote: SSH destination, forwarded to ``_ssh_exec_argv``.
        ssh_port: Optional SSH port, forwarded to ``_ssh_exec_argv``.
        remote_cmd: Command string to run on the remote side.
        timeout: Seconds to wait for the whole exchange before giving up.
        connect_timeout: Optional value forwarded to ``_ssh_exec_argv``.
        strict_host_key_checking: Optional value forwarded to
            ``_ssh_exec_argv``.
        stdin_data: Bytes written to the child's stdin. When ``None`` no
            stdin pipe is created at all.

    Returns:
        ``(returncode, stdout, stderr)``. A missing return code is reported
        as ``0``.

    Raises:
        asyncio.TimeoutError: The command did not finish within ``timeout``.
            The child is killed and reaped before the error propagates.
    """
    import asyncio

    proc = await asyncio.create_subprocess_exec(
        *_ssh_exec_argv(
            remote,
            ssh_port,
            remote_cmd=remote_cmd,
            connect_timeout=connect_timeout,
            strict_host_key_checking=strict_host_key_checking,
        ),
        # Only open a stdin pipe when there is something to send.
        stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input=stdin_data), timeout=timeout
        )
    except asyncio.TimeoutError:
        # The child can exit on its own between the timeout firing and this
        # handler running. kill() then raises ProcessLookupError, which would
        # replace the TimeoutError that callers are written to catch.
        try:
            proc.kill()
        except ProcessLookupError:
            pass
        # Reap the child and drain its pipes so nothing is left behind.
        await proc.communicate()
        raise
    return proc.returncode or 0, stdout, stderr
+189 -206
View File
@@ -15,19 +15,26 @@ from pathlib import Path
from fastapi import APIRouter, HTTPException, Request, Depends
from src.auth_helpers import require_user
from src.constants import COOKBOOK_STATE_FILE
from pydantic import BaseModel
from core.middleware import require_admin
from core.platform_compat import (
IS_WINDOWS,
SSH_PATH_OVERRIDE,
NVIDIA_PATH_CANDIDATES,
detached_popen_kwargs,
find_bash,
git_bash_path,
kill_process_tree,
pid_alive,
safe_chmod,
which_tool,
translate_path,
get_wsl_windows_user_profile,
)
from routes.shell_routes import TMUX_LOG_DIR
from src.constants import COOKBOOK_STATE_FILE
logger = logging.getLogger(__name__)
@@ -38,8 +45,10 @@ from routes.cookbook_helpers import (
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache,
_user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
_append_pip_install_runner_lines,
_diagnose_serve_output, run_ssh_command_async,
ModelDownloadRequest, ServeRequest,
)
@@ -54,7 +63,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
def setup_cookbook_routes() -> APIRouter:
router = APIRouter(tags=["cookbook"])
_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
def _mask_secret(value: str) -> str:
if not value:
@@ -81,127 +90,6 @@ def setup_cookbook_routes() -> APIRouter:
task["payload"].pop("hf_token", None)
return state
def _diagnose_serve_output(text: str) -> dict | None:
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
the agent/tool path the same structured signal so it can retry with an
adjusted command instead of guessing from raw tmux output.
"""
if not text:
return None
tail = text[-6000:]
patterns = [
(
r"No available memory for the cache blocks|Available KV cache memory:.*-",
"No GPU memory left for KV cache after loading model.",
[
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
],
),
(
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
"GPU ran out of memory during startup or warmup.",
[
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
],
),
(
r"not divisib|must be divisible|attention heads.*divisible",
"Tensor parallel size is incompatible with the model.",
[
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
],
),
(
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
"Context length is too large for available GPU memory.",
[
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
],
),
(
r"enable-auto-tool-choice requires --tool-call-parser",
"Auto tool choice requires an explicit tool call parser.",
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
),
(
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
"Model requires custom code or newer model support.",
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
),
(
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
"vLLM/Transformers kernel package mismatch.",
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
),
(
r"Address already in use|bind.*address.*in use",
"Port is already in use.",
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
),
(
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
"No GPUs are visible to the serve process.",
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
),
(
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
"vLLM could not find a supported GPU (CUDA or ROCm). "
"This machine may have integrated or unsupported graphics only.",
[
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
],
),
(
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
"vLLM is not installed or not in PATH on this server.",
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
),
(
r"sglang.*command not found|No module named sglang|SGLang is not installed",
"SGLang is not installed or not in PATH on this server.",
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
),
(
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
"llama.cpp / llama-cpp-python dependencies are missing.",
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
),
(
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
),
(
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
"Diffusion serving requires PyTorch and diffusers.",
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
),
(
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
"Model access is gated or unauthorized.",
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
),
]
for pattern, message, suggestions in patterns:
if re.search(pattern, tail, re.I):
return {"message": message, "suggestions": suggestions}
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
):
return {
"message": "Python traceback detected during serve startup.",
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
}
return None
def _state_for_client(state):
"""Return cookbook state without raw secrets for browser clients."""
_strip_task_secrets(state)
@@ -295,6 +183,7 @@ def setup_cookbook_routes() -> APIRouter:
safe_chmod(key_path.with_suffix(".pub"), 0o644)
return {"ok": True, "public_key": _read_cookbook_public_key()}
def _needs_binary(cmd: str, binary: str) -> bool:
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
@@ -355,8 +244,8 @@ def setup_cookbook_routes() -> APIRouter:
# POSIX form + shell-quoting so drive paths / spaces survive.
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
lp = shlex.quote(log_path.as_posix())
ip = shlex.quote(inner.as_posix())
lp = shlex.quote(git_bash_path(log_path))
ip = shlex.quote(git_bash_path(inner))
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
script_path.write_text(
f"bash {ip} > {lp} 2>&1\n",
@@ -472,6 +361,8 @@ def setup_cookbook_routes() -> APIRouter:
ps_lines = []
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
ps_lines.append('$env:PYTHONUTF8 = "1"')
if req.hf_token:
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
if req.env_prefix:
@@ -545,7 +436,7 @@ def setup_cookbook_routes() -> APIRouter:
# Install hf CLI + optional hf_transfer best-effort. Retries disable
# hf_transfer because the Rust parallel path is fast but has been
# flaky near the end of very large multi-file downloads.
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
# The helper tries active pip first, then guarded user-site fallbacks.
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
if req.disable_hf_transfer:
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
@@ -673,24 +564,35 @@ def setup_cookbook_routes() -> APIRouter:
for d in model_dir.split(','):
d = d.strip()
if d:
model_dirs.append(d)
paths_code = _cached_model_scan_script(model_dirs)
translated_d = translate_path(d) if not host else d
model_dirs.append(translated_d)
win_hf_hub = None
if not host:
win_profile = get_wsl_windows_user_profile()
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
scan_py = TMUX_LOG_DIR / "scan_cache.py"
scan_py.write_text(paths_code, encoding="utf-8")
scan_payload = scan_py.read_bytes()
if host:
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
if platform == "windows":
# Windows: use 'python' and pipe via stdin with double-quote wrapping
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
remote_cmd = "python -"
else:
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
proc = await asyncio.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()),
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
remote_cmd = (
"if command -v python3 >/dev/null 2>&1; then python3 -; "
"elif command -v python >/dev/null 2>&1; then python -; "
"else echo \"python3/python not found\" >&2; exit 127; fi"
)
rc, stdout_b, stderr_b = await run_ssh_command_async(
host,
ssh_port,
remote_cmd,
timeout=60,
stdin_data=scan_payload,
)
else:
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
@@ -710,7 +612,7 @@ def setup_cookbook_routes() -> APIRouter:
stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()),
)
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
models = []
try:
@@ -915,6 +817,10 @@ def setup_cookbook_routes() -> APIRouter:
existing.name = display_name
if supports_tools is not None:
existing.supports_tools = supports_tools
# Wipe stale model lists so the picker re-probes and discovers
# the newly-served model instead of showing the old one.
existing.cached_models = None
existing.hidden_models = None
db.commit()
logger.info(f"Updated existing local model endpoint: {base_url}")
return existing.id
@@ -971,11 +877,27 @@ def setup_cookbook_routes() -> APIRouter:
in_venv=sys.prefix != sys.base_prefix,
)
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
remote = req.remote_host
is_windows = req.platform == "windows"
local_windows = IS_WINDOWS and not remote
if is_windows or local_windows:
if req.cmd.startswith("python3 "):
req.cmd = "python " + req.cmd[len("python3 "):]
if is_pip_install and ("llama-cpp-python" in req.cmd or "llama_cpp" in req.cmd) and (is_windows or local_windows):
if "--extra-index-url" not in req.cmd:
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
if is_pip_install:
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
# pip cache so they don't fail mid-build with "No space left" (#1219)
# and leave the dep installed-but-unusable (#1459).
req.cmd = _pip_install_no_cache(req.cmd)
# Accept common aliases and enforce server extras for llama-cpp so
# `python -m llama_cpp.server` has all runtime dependencies.
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
# PEP-508-style package spec — letters, digits, `.-_` for the
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
# v2 review HIGH-14: tightened from the previous regex which
@@ -1028,6 +950,8 @@ def setup_cookbook_routes() -> APIRouter:
ps_lines = []
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
ps_lines.append('$env:PYTHONUTF8 = "1"')
if req.hf_token:
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
if req.gpus:
@@ -1046,7 +970,7 @@ def setup_cookbook_routes() -> APIRouter:
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
ps_lines.append(' python -m pip install llama-cpp-python[server]')
ps_lines.append(' python -m pip install llama-cpp-python[server] --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu')
ps_lines.append('}')
elif "vllm" in req.cmd:
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
@@ -1121,45 +1045,57 @@ def setup_cookbook_routes() -> APIRouter:
# ollama is found (otherwise macOS falls back to a slow source build).
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
runner_lines.append('if [ -d /data/data/com.termux ]; then')
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' pkg install -y cmake 2>/dev/null')
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
runner_lines.append(' fi')
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
runner_lines.append(' mkdir -p ~/bin')
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
# Build with the right accelerator: Metal on macOS (llama.cpp
# enables it automatically, no flag), CUDA on Linux when present,
# else a plain CPU build. nproc is Linux-only — fall back to
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
# a prebuilt llama-server and skips this whole source build.)
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
# Start from a clean cache: a prior failed configure (e.g. a CUDA
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
# explicit so the binary is optimized (Metal auto-enables on macOS).
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
runner_lines.append(' else')
_append_llama_cpp_linux_accel_build_lines(runner_lines)
runner_lines.append(' fi')
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
runner_lines.append(' fi')
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append(' fi')
runner_lines.append('fi')
if local_windows:
# LOCAL Windows: no native source compilation (no cmake/compiler on Git Bash).
# Just check python bindings (using native `python` binary) and fall back to pip install.
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "llama-server not found — installing Python bindings..."')
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='python')} || true")
runner_lines.append('fi')
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install attempts."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('fi')
else:
runner_lines.append('if [ -d /data/data/com.termux ]; then')
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' pkg install -y cmake 2>/dev/null')
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
runner_lines.append(' fi')
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
runner_lines.append(' mkdir -p ~/bin')
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
# Build with the right accelerator: Metal on macOS (llama.cpp
# enables it automatically, no flag), CUDA on Linux when present,
# else a plain CPU build. nproc is Linux-only — fall back to
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
# a prebuilt llama-server and skips this whole source build.)
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
# Start from a clean cache: a prior failed configure (e.g. a CUDA
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
# explicit so the binary is optimized (Metal auto-enables on macOS).
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
runner_lines.append(' else')
_append_llama_cpp_linux_accel_build_lines(runner_lines)
runner_lines.append(' fi')
# If the native build failed, fall back to the Python bindings.
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
runner_lines.append(' fi')
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append(' fi')
runner_lines.append('fi')
elif "ollama" in req.cmd:
handled_ollama_serve = True
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
@@ -1181,13 +1117,23 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
runner_lines.append(' break')
runner_lines.append(' fi')
runner_lines.append(' exec 3<&-; exec 3>&-')
runner_lines.append('done')
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
if local_windows:
# Windows detached process has no TTY; exec bash -i crashes.
# Keep the monitoring task alive with a sleep loop.
runner_lines.append(' while true; do sleep 60; done')
else:
runner_lines.append(' exec bash -i')
runner_lines.append('fi')
runner_lines.append('if ! command -v ollama &>/dev/null; then')
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
runner_lines.append(' echo')
runner_lines.append(' echo "=== Process exited with code 127 ==="')
runner_lines.append(' exec bash -i')
if local_windows:
runner_lines.append(' exit 127')
else:
runner_lines.append(' exec bash -i')
runner_lines.append('fi')
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
if remote and _ollama_host in ("0.0.0.0", "::"):
@@ -1195,24 +1141,20 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
runner_lines.append('_ody_exit=$?')
runner_lines.append('echo')
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
runner_lines.append('exec bash -i')
if local_windows:
_append_serve_exit_code_lines(runner_lines, keep_shell_open=False)
else:
runner_lines.append('_ody_exit=$?')
runner_lines.append('echo')
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
runner_lines.append('exec bash -i')
elif "vllm serve" in req.cmd:
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
runner_lines.append('fi')
# Put ~/.local/bin on PATH first — without a venv, vllm installs
# there via --user and the non-login serve shell otherwise can't
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('if ! command -v vllm &>/dev/null; then')
runner_lines.append(' echo "ERROR: vLLM is not installed."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('fi')
_append_vllm_linux_preflight_lines(runner_lines)
elif "sglang.launch_server" in req.cmd:
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('if ! command -v sglang &>/dev/null; then')
@@ -1236,7 +1178,10 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines,
keep_shell_open=not local_windows,
)
runner_lines.append(req.cmd)
if is_pip_install:
_append_pip_install_runner_lines(runner_lines, req.cmd)
else:
runner_lines.append(req.cmd)
if local_windows:
# Detached background process — no interactive shell to keep open.
# Print the exit marker the status poller looks for, then stop.
@@ -1397,8 +1342,8 @@ def setup_cookbook_routes() -> APIRouter:
cmd = f"ssh {pf}{host} '{setup_script}'"
else:
# Linux: auto-install tmux (via whichever package manager is available)
# and huggingface_hub + hf_transfer (falling back to --user/--break-system-packages
# on PEP-668 locked distros like Arch / newer Debian).
# and huggingface_hub + hf_transfer (falling back to --user, then
# guarded --break-system-packages on PEP-668 locked distros).
setup_script = (
# Install tmux if missing — try common package managers; skip if no sudo
"if ! command -v tmux >/dev/null 2>&1; then "
@@ -1410,10 +1355,15 @@ def setup_cookbook_routes() -> APIRouter:
" fi; "
"fi; "
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
# Install Python bits. Try system install first; fall back to --user --break-system-packages on PEP 668 systems.
# Install Python bits. Try system install first; fall back to --user,
# then use --break-system-packages only when pip supports it.
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null || "
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null; "
"pip install --user -q huggingface_hub hf_transfer 2>/dev/null || "
"( pip install --help 2>/dev/null | grep -q -- --break-system-packages && "
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ) || "
"pip3 install --user -q huggingface_hub hf_transfer 2>/dev/null || "
"( pip3 install --help 2>/dev/null | grep -q -- --break-system-packages && "
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ); "
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
)
cmd = f"ssh {pf}{host} '{setup_script}'"
@@ -1436,11 +1386,38 @@ def setup_cookbook_routes() -> APIRouter:
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
if host:
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
candidates = [query]
stripped = query.strip()
if stripped.startswith("nvidia-smi "):
args = stripped[len("nvidia-smi "):]
candidates.append(
"bash -lc "
+ shlex.quote(
f"{SSH_PATH_OVERRIDE}"
f"nvidia-smi {args}"
)
)
for nvidia_path in NVIDIA_PATH_CANDIDATES:
candidates.append(f"{nvidia_path} {args}")
last_err = "nvidia-smi failed"
for candidate in candidates:
try:
rc, stdout, stderr = await run_ssh_command_async(
host,
ssh_port,
candidate,
connect_timeout=5,
timeout=timeout,
)
except asyncio.TimeoutError:
return None, "nvidia-smi timed out"
if rc == 0:
return stdout.decode("utf-8", errors="replace"), None
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
if err:
last_err = err
return None, last_err
else:
proc = await asyncio.create_subprocess_exec(
*shlex.split(query),
@@ -2203,7 +2180,13 @@ def setup_cookbook_routes() -> APIRouter:
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
"sys.exit(0 if ok and not inc else 1)"
)
cmd = ["python3", "-c", py, repo_id]
if remote_host:
cmd = ["python3", "-c", py, repo_id]
else:
# Local Windows: python3 can hit the Microsoft Store stub. Use the
# real Python Odysseus is running under (guaranteed to exist).
import sys as _sys_local
cmd = [_sys_local.executable, "-c", py, repo_id]
try:
if remote_host:
ssh_base = ["ssh"]
+67 -117
View File
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
"""
import json
import time
import uuid
import logging
import threading
from typing import Dict, Optional
import httpx
from fastapi import APIRouter, Request, Form, HTTPException
from fastapi import HTTPException, Request
from core.database import SessionLocal, ModelEndpoint
from core.middleware import require_admin
from routes.device_flow import (
DeviceFlowPoll,
DeviceFlowStart,
PendingDeviceFlowStore,
create_device_flow_router,
)
from src.auth_helpers import get_current_user
from src import copilot
logger = logging.getLogger(__name__)
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
# bearer-like secret, so it lives here (server memory) rather than in the
# browser. Entries expire with the GitHub device code.
#
# NOTE: this is per-process state. The device flow assumes a single worker
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
# on a worker that never saw the start, returning "Unknown or expired login
# session". Move this to a shared store (DB/Redis) if running multi-worker.
_PENDING: Dict[str, Dict] = {}
_PENDING_LOCK = threading.Lock()
def _prune_expired() -> None:
    """Remove every pending entry whose ``expires_at`` is earlier than now.

    Entries without an ``expires_at`` key are treated as expiring at 0 and
    are therefore removed as well. The whole scan-and-remove pass runs while
    holding ``_PENDING_LOCK``.
    """
    cutoff = time.time()
    with _PENDING_LOCK:
        # Collect the stale ids first so the dict is not mutated mid-iteration.
        stale_ids = [
            poll_id
            for poll_id, entry in _PENDING.items()
            if entry.get("expires_at", 0) < cutoff
        ]
        for poll_id in stale_ids:
            _PENDING.pop(poll_id, None)
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
return result
def setup_copilot_routes() -> APIRouter:
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
host = copilot.GITHUB_HOST
ent = str(form.get("enterprise_url") or "").strip()
if ent:
host = copilot.normalize_domain(ent)
try:
data = copilot.request_device_code(host)
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else "unknown"
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
except Exception as e:
raise HTTPException(502, f"GitHub device-code request failed: {e}")
@router.post("/device/start")
def device_start(request: Request, enterprise_url: str = Form("")):
require_admin(request)
_prune_expired()
host = copilot.GITHUB_HOST
ent = (enterprise_url or "").strip()
if ent:
host = copilot.normalize_domain(ent)
try:
data = copilot.request_device_code(host)
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else "unknown"
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
except Exception as e:
raise HTTPException(502, f"GitHub device-code request failed: {e}")
device_code = data.get("device_code")
if not device_code:
raise HTTPException(502, "GitHub did not return a device code")
device_code = data.get("device_code")
if not device_code:
raise HTTPException(502, "GitHub did not return a device code")
interval = int(data.get("interval") or 5)
expires_in = int(data.get("expires_in") or 900)
poll_id = uuid.uuid4().hex
with _PENDING_LOCK:
_PENDING[poll_id] = {
"device_code": device_code,
"host": host,
"enterprise_url": ent,
"interval": interval,
"owner": get_current_user(request) or None,
"expires_at": time.time() + expires_in,
"next_poll_at": 0.0,
}
# verification_uri_complete embeds the user code, so the browser tab we
# open lands the user straight on GitHub's "Authorize" screen with the
# code pre-filled — one click, no manual code entry.
return {
"poll_id": poll_id,
# verification_uri_complete embeds the user code, so the browser tab we
# open lands the user straight on GitHub's "Authorize" screen with the
# code pre-filled — one click, no manual code entry.
return DeviceFlowStart(
pending={
"device_code": device_code,
"host": host,
"enterprise_url": ent,
"owner": get_current_user(request) or None,
},
response={
"user_code": data.get("user_code"),
"verification_uri": data.get("verification_uri"),
"verification_uri_complete": data.get("verification_uri_complete"),
"interval": interval,
"expires_in": expires_in,
}
},
interval=int(data.get("interval") or 5),
expires_in=int(data.get("expires_in") or 900),
)
@router.post("/device/poll")
def device_poll(request: Request, poll_id: str = Form(...)):
require_admin(request)
_prune_expired()
with _PENDING_LOCK:
pending = _PENDING.get(poll_id)
if not pending:
raise HTTPException(404, "Unknown or expired login session")
# Enforce GitHub's polling interval server-side so a chatty client
# can't trip slow_down.
now = time.time()
if now < pending.get("next_poll_at", 0):
return {"status": "pending"}
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
try:
data = copilot.poll_access_token(pending["host"], pending["device_code"])
except Exception as e:
return DeviceFlowPoll.pending(f"poll error: {e}")
token = data.get("access_token")
if token:
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
try:
data = copilot.poll_access_token(pending["host"], pending["device_code"])
result = _provision_endpoint(token, base, pending["owner"])
except Exception as e:
return {"status": "pending", "detail": f"poll error: {e}"}
logger.exception("Copilot endpoint provisioning failed")
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
return DeviceFlowPoll.authorized(result)
token = data.get("access_token")
if token:
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
try:
result = _provision_endpoint(token, base, pending["owner"])
except Exception as e:
logger.exception("Copilot endpoint provisioning failed")
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "authorized", "endpoint": result}
err = data.get("error")
if err == "authorization_pending":
return DeviceFlowPoll.pending()
if err == "slow_down":
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
if err in ("expired_token", "access_denied"):
return DeviceFlowPoll.failed(err)
# Unknown error — surface but keep the session for another try.
return DeviceFlowPoll.pending(err or "unknown")
err = data.get("error")
if err == "authorization_pending":
with _PENDING_LOCK:
if poll_id in _PENDING:
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
return {"status": "pending"}
if err == "slow_down":
new_interval = int(data.get("interval") or (pending["interval"] + 5))
with _PENDING_LOCK:
if poll_id in _PENDING:
_PENDING[poll_id]["interval"] = new_interval
_PENDING[poll_id]["next_poll_at"] = now + new_interval
return {"status": "pending"}
if err in ("expired_token", "access_denied"):
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "failed", "error": err}
# Unknown error — surface but keep the session for another try.
return {"status": "pending", "detail": err or "unknown"}
@router.post("/device/cancel")
def device_cancel(request: Request, poll_id: str = Form(...)):
require_admin(request)
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "cancelled"}
return router
def setup_copilot_routes():
    """Build the Copilot device-flow router.

    Delegates to ``create_device_flow_router`` using the ``/api/copilot``
    prefix, the module-level ``_DEVICE_FLOW_STORE`` for pending logins, and
    the Copilot-specific start/poll callbacks defined in this module.
    """
    return create_device_flow_router(
        prefix="/api/copilot",
        tags=["copilot"],
        store=_DEVICE_FLOW_STORE,
        start_flow=_start_device_flow,
        poll_flow=_poll_device_flow,
    )
+193
View File
@@ -0,0 +1,193 @@
"""Shared OAuth/device-flow route scaffolding for provider setup."""
from __future__ import annotations
import inspect
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Mapping, Optional
from fastapi import APIRouter, Form, HTTPException, Request
from core.middleware import require_admin
@dataclass(frozen=True)
class DeviceFlowStart:
    """Provider-specific start result consumed by the shared route wrapper.

    The shared ``/device/start`` route stores ``pending`` in the pending
    store and returns ``response`` to the caller, extended with ``poll_id``,
    ``interval`` and ``expires_in``.
    """

    # Provider state kept server-side; handed back to the provider's poll
    # callback on each poll.
    pending: Mapping[str, Any]
    # Fields copied into the JSON body returned by the start route.
    response: Mapping[str, Any]
    # Seconds between polls; falsy values fall back to 5 in the route.
    interval: int = 5
    # Lifetime of the pending login in seconds; falsy values fall back to 900.
    expires_in: int = 900
@dataclass(frozen=True)
class DeviceFlowPoll:
    """Provider-neutral result of a single device-flow poll.

    ``status`` is one of ``"pending"``, ``"slow_down"``, ``"authorized"`` or
    ``"failed"`` when built through the factory methods below; the remaining
    fields are optional extras that only some statuses populate.
    """

    status: str
    endpoint: Optional[Mapping[str, Any]] = None
    error: Optional[str] = None
    detail: Optional[str] = None
    interval: Optional[int] = None

    @classmethod
    def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Build a "still waiting" outcome, optionally carrying a detail note."""
        return cls("pending", detail=detail)

    @classmethod
    def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Build a back-off outcome with an optional new polling interval."""
        return cls("slow_down", detail=detail, interval=interval)

    @classmethod
    def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
        """Build a success outcome carrying the provisioned endpoint data."""
        return cls("authorized", endpoint=endpoint)

    @classmethod
    def failed(cls, error: str) -> "DeviceFlowPoll":
        """Build a terminal failure outcome carrying the error code."""
        return cls("failed", error=error)
class PendingDeviceFlowStore:
    """In-memory table of device-flow logins that are awaiting completion.

    Every read or write of the table happens under a single lock. Each entry
    keeps the provider's payload apart from the polling metadata (interval,
    expiry, next allowed poll time), and ``get_payload`` returns only a copy
    of the provider payload.
    """

    def __init__(self, *, time_func: Callable[[], float] = time.time):
        # ``time_func`` is injectable so callers can supply their own clock.
        self._time = time_func
        self._lock = threading.Lock()
        self._pending: dict[str, dict[str, Any]] = {}

    def _now(self) -> float:
        """Return the current clock reading as a float."""
        return float(self._time())

    def prune_expired(self) -> None:
        """Discard every entry whose ``expires_at`` lies before the current time."""
        cutoff = self._now()
        with self._lock:
            # Snapshot the stale keys first; deleting while iterating is unsafe.
            stale_ids = [
                poll_id
                for poll_id, entry in self._pending.items()
                if entry.get("expires_at", 0) < cutoff
            ]
            for poll_id in stale_ids:
                self._pending.pop(poll_id, None)

    def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
        """Register a new pending login and return its opaque poll id.

        Falsy ``interval`` / ``expires_in`` fall back to 5 and 900; both are
        clamped to at least 1.
        """
        self.prune_expired()
        poll_id = uuid.uuid4().hex
        entry = {
            "payload": dict(payload),
            "interval": max(int(interval or 5), 1),
            "expires_at": self._now() + max(int(expires_in or 900), 1),
            "next_poll_at": 0.0,
        }
        with self._lock:
            self._pending[poll_id] = entry
        return poll_id

    def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
        """Return a copy of the provider payload, or ``None`` if unknown/expired."""
        self.prune_expired()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return None
            return dict(entry.get("payload") or {})

    def is_throttled(self, poll_id: str) -> bool:
        """Tell whether ``poll_id`` exists and its next allowed poll is still ahead."""
        with self._lock:
            entry = self._pending.get(poll_id)
            if not entry:
                return False
            return self._now() < float(entry.get("next_poll_at") or 0)

    def schedule_next(self, poll_id: str) -> None:
        """Push the next allowed poll one interval past the current time."""
        reference = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            entry["next_poll_at"] = reference + int(entry.get("interval") or 5)

    def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
        """Adopt a longer polling interval and reschedule the next poll.

        A truthy ``interval`` is used as given; otherwise the stored interval
        grows by 5 seconds. The result is clamped to at least 1.
        """
        reference = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            if interval:
                requested = int(interval)
            else:
                requested = int(entry.get("interval") or 5) + 5
            entry["interval"] = max(requested, 1)
            entry["next_poll_at"] = reference + entry["interval"]

    def pop(self, poll_id: str) -> None:
        """Forget ``poll_id``; unknown ids are ignored."""
        with self._lock:
            self._pending.pop(poll_id, None)
async def _maybe_await(value: Any) -> Any:
    """Resolve ``value`` when it is awaitable; otherwise return it unchanged."""
    return (await value) if inspect.isawaitable(value) else value
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
    """Build the JSON body for a "pending" poll result.

    The ``detail`` key is included only when ``detail`` is truthy.
    """
    if not detail:
        return {"status": "pending"}
    return {"status": "pending", "detail": detail}
def create_device_flow_router(
    *,
    prefix: str,
    tags: Iterable[str],
    store: PendingDeviceFlowStore,
    start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
    poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
) -> APIRouter:
    """Create standard `/device/start|poll|cancel` routes for a provider.

    Args:
        prefix: URL prefix for the router (e.g. ``"/api/copilot"``).
        tags: OpenAPI tags attached to the router.
        store: Pending-login store shared by the three routes.
        start_flow: Provider callback ``(request, form) -> DeviceFlowStart``.
            May be a plain function or a coroutine function.
        poll_flow: Provider callback ``(request, payload) -> DeviceFlowPoll``.
            May be a plain function or a coroutine function.

    Returns:
        An ``APIRouter`` exposing ``POST /device/start``, ``/device/poll``
        and ``/device/cancel``. All three call ``require_admin`` first.

    Plain (non-coroutine) callbacks are executed in the worker threadpool so
    that blocking provider I/O inside them does not stall the event loop that
    runs these ``async`` routes.
    """
    # Imported locally so this function does not depend on a module-level
    # import being added alongside it.
    from fastapi.concurrency import run_in_threadpool

    router = APIRouter(prefix=prefix, tags=list(tags))

    async def _invoke(callback: Callable[..., Any], *args: Any) -> Any:
        """Run a provider callback without blocking the event loop."""
        if inspect.iscoroutinefunction(callback):
            return await callback(*args)
        # A sync callable may still hand back an awaitable (e.g. a partial
        # around a coroutine function), so resolve the result if needed.
        return await _maybe_await(await run_in_threadpool(callback, *args))

    @router.post("/device/start")
    async def device_start(request: Request):
        require_admin(request)
        form = await request.form()
        start = await _invoke(start_flow, request, form)
        interval = int(start.interval or 5)
        expires_in = int(start.expires_in or 900)
        poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
        # The provider's response fields are returned as-is; the shared
        # fields below overwrite any same-named keys.
        response = dict(start.response)
        response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in})
        return response

    @router.post("/device/poll")
    async def device_poll(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        payload = store.get_payload(poll_id)
        if payload is None:
            raise HTTPException(404, "Unknown or expired login session")
        # Enforce the polling interval server-side: a client polling too
        # early gets "pending" without the provider being contacted.
        if store.is_throttled(poll_id):
            return {"status": "pending"}

        try:
            outcome = await _invoke(poll_flow, request, payload)
        except Exception:
            # Any error escaping the provider callback ends the session.
            store.pop(poll_id)
            raise

        if outcome.status == "authorized":
            store.pop(poll_id)
            return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
        if outcome.status == "failed":
            store.pop(poll_id)
            return {"status": "failed", "error": outcome.error or "denied"}
        if outcome.status == "slow_down":
            store.slow_down(poll_id, outcome.interval)
            return _pending_response(outcome.detail)

        # Any other status is treated as "pending": keep the session and
        # schedule the next allowed poll.
        store.schedule_next(poll_id)
        return _pending_response(outcome.detail)

    @router.post("/device/cancel")
    def device_cancel(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        store.pop(poll_id)
        return {"status": "cancelled"}

    return router
+66 -36
View File
@@ -7,14 +7,24 @@ from typing import Dict, Any, List, Optional
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
from sqlalchemy import func
from sqlalchemy import case, func, or_
from core.database import SessionLocal, Document, DocumentVersion
from core.database import Session as DbSession
from src.auth_helpers import get_current_user
from src.constants import MAIL_ATTACHMENTS_DIR
logger = logging.getLogger(__name__)
def _get_session_or_404(db, session_id: str, user: Optional[str]):
    """Load a session row by id, or raise a 404.

    Raises ``HTTPException(404, "Session not found")`` both when no row
    matches ``session_id`` and when a truthy ``user`` differs from the
    session's ``owner``. A falsy ``user`` skips the ownership check.
    """
    record = db.query(DbSession).filter(DbSession.id == session_id).first()
    # Same status and message for "missing" and "not yours".
    if not record or (user and record.owner != user):
        raise HTTPException(404, "Session not found")
    return record
def _aggregate_language_facets(lang_rows):
"""Sum document counts per display language for the library facet.
@@ -30,6 +40,19 @@ def _aggregate_language_facets(lang_rows):
return out
def _library_language_for_document(doc: Document) -> str:
    """Return the language label the document library shows for ``doc``.

    PDF documents are stored as markdown wrappers so the editor can preserve
    extracted text, form fields, and annotations. When the content carries a
    source-upload marker the label is ``"pdf"``; otherwise it is the stored
    language, defaulting to ``"text"``.
    """
    from src.pdf_form_doc import find_source_upload_id

    content = doc.current_content or ""
    has_pdf_source = bool(find_source_upload_id(content))
    return "pdf" if has_pdf_source else (doc.language or "text")
from routes.document_helpers import (
DocumentCreate, DocumentUpdate, DocumentPatch,
@@ -69,17 +92,12 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
# the doc is owner-stamped, so it lives in the library on its own.
session = None
if req.session_id:
session = db.query(DbSession).filter(DbSession.id == req.session_id).first()
if not session:
raise HTTPException(404, "Session not found")
# Match the lenient ownership model the rest of the app uses
# (see _owner_filter): only block when an AUTHENTICATED user is
# writing into a DIFFERENT user's session. In single-user /
# unconfigured / localhost-bypass mode the middleware leaves
# current_user unset (None), and those sessions are already
# served freely everywhere else.
if user and session.owner and session.owner != user:
raise HTTPException(403, "Cannot create document in another user's session")
# unconfigured / localhost-bypass mode, falsey users preserve
# the existing lenient path.
session = _get_session_or_404(db, req.session_id, user)
doc_id = str(uuid.uuid4())
ver_id = str(uuid.uuid4())
@@ -171,11 +189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
if session_id:
db = SessionLocal()
try:
sess = db.query(DbSession).filter(DbSession.id == session_id).first()
if not sess:
raise HTTPException(404, "Session not found")
if user and sess.owner and sess.owner != user:
raise HTTPException(403, "Cannot import into another user's session")
_get_session_or_404(db, session_id, user)
finally:
db.close()
@@ -198,7 +212,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
try:
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
except Exception:
body_text = None
@@ -260,18 +274,29 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
db = SessionLocal()
try:
from sqlalchemy import or_
pdf_marker_cond = or_(
Document.current_content.like('%<!-- pdf_source upload_id="%'),
Document.current_content.like('%<!-- pdf_form_source upload_id="%'),
)
library_language_expr = case(
(pdf_marker_cond, "pdf"),
(Document.language.is_(None), "text"),
else_=Document.language,
)
# Archived view shows ONLY archived docs; the default view excludes
# them (NULL = legacy rows that predate the column = not archived).
_arch_cond = (Document.archived == True) if archived else or_(
Document.archived == False, Document.archived.is_(None))
# Language facet counts (owner-filtered)
# Language facet counts (owner-filtered). PDF documents are stored
# as markdown wrappers, so group by the library display language
# instead of the raw stored language.
lang_q = (
db.query(Document.language, func.count(Document.id))
db.query(library_language_expr, func.count(Document.id))
.outerjoin(DbSession, Document.session_id == DbSession.id)
.filter(Document.is_active == True).filter(_arch_cond)
)
lang_q = _owner_session_filter(lang_q, user)
lang_rows = lang_q.group_by(Document.language).all()
lang_rows = lang_q.group_by(library_language_expr).all()
languages = _aggregate_language_facets(lang_rows)
# Session count (owner-filtered)
@@ -303,12 +328,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
Document.title.ilike(term) | Document.current_content.ilike(term)
)
# Language filter
# Language filter. "pdf" is a display language derived from the
# source marker; "markdown" excludes those wrappers.
if language:
if language == "text":
q = q.filter((Document.language == None) | (Document.language == "text"))
elif language == "pdf":
q = q.filter(pdf_marker_cond)
else:
q = q.filter(Document.language == language)
if language == "markdown":
q = q.filter(~pdf_marker_cond)
# Total before pagination
total = q.count()
@@ -332,7 +362,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
"session_id": doc.session_id,
"session_name": session_name,
"title": doc.title,
"language": doc.language or "text",
"language": _library_language_for_document(doc),
"preview": (doc.current_content or "")[:500],
"version_count": doc.version_count,
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
@@ -359,18 +389,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
try:
if not user:
raise HTTPException(403, "Authentication required")
session = db.query(DbSession).filter(DbSession.id == session_id).first()
# v2 review HIGH-9: raise 403 explicitly when the caller
# can't see this session, instead of returning [] which the
# UI treats identically to "no docs" and silently masks
# auth failures.
if not session:
raise HTTPException(404, "Session not found")
if user and session.owner and session.owner != user:
raise HTTPException(403, "Access denied")
docs = db.query(Document).filter(
_get_session_or_404(db, session_id, user)
q = db.query(Document).filter(
Document.session_id == session_id
).order_by(Document.created_at.desc()).all()
)
if user:
q = q.filter(or_(Document.owner == user, Document.owner.is_(None)))
docs = q.order_by(Document.created_at.desc()).all()
return [_doc_to_dict(d) for d in docs]
finally:
db.close()
@@ -437,7 +466,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
raise HTTPException(404, "Source PDF could not be located")
try:
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
except Exception as e:
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
raise HTTPException(500, f"Extraction failed: {e}")
@@ -606,6 +635,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
doc.language = req.language
if req.session_id is not None:
# Empty string = unlink from session
if req.session_id:
_get_session_or_404(db, req.session_id, user)
doc.session_id = req.session_id if req.session_id else None
if not req.session_id:
# Tab closed / doc detached from its session — drop the
@@ -855,10 +886,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
from src.llm_core import llm_call_async
user = get_current_user(request)
url, model, headers = resolve_task_endpoint()
url, model, headers = resolve_task_endpoint(owner=user or None)
if not url or not model:
# Fall back to default endpoint
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=user or None)
if not url or not model:
raise HTTPException(500, "No endpoint configured for AI tidy")
@@ -1158,7 +1189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
settings = _load_vl_settings()
vl_model = settings.get("vision_model", "")
try:
url, model_id, headers = _resolve_vl_model(vl_model)
url, model_id, headers = _resolve_vl_model(vl_model, owner=user)
except Exception as e:
raise HTTPException(503, f"No vision model available: {e}")
@@ -1512,10 +1543,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
# don't import from a routes file (cycle-prone). Same env override
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
from pathlib import Path as _Path
import os as _os
_DATA_DIR = _Path(__file__).resolve().parent.parent / "data"
_BASE = _os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(_DATA_DIR / "mail-attachments"))
_COMPOSE_DIR = _Path(_BASE) / "_compose"
_COMPOSE_DIR = _Path(MAIL_ATTACHMENTS_DIR) / "_compose"
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
user = get_current_user(request)
@@ -1631,9 +1659,11 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
# context (To/Subject/In-Reply-To/References).
try:
from routes.email_routes import _imap, _decode_header
from routes.email_helpers import _q
except Exception:
_imap = None
_decode_header = lambda x: x or ""
_q = lambda x: x or ""
to_addr = ""
from_name = ""
@@ -1643,7 +1673,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
if _imap:
try:
with _imap(doc.source_email_account_id or None) as conn:
conn.select(doc.source_email_folder, readonly=True)
conn.select(_q(doc.source_email_folder), readonly=True)
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
if status == "OK" and data and data[0]:
raw_hdr = data[0][1]
+98 -33
View File
@@ -71,6 +71,38 @@ def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message
smtp.sendmail(from_addr, recipients, message)
def _friendly_email_auth_error(protocol: str, host: str, error: object) -> str:
    """Return a clearer setup error for known provider auth policies.

    Args:
        protocol: Protocol label supplied by the caller (e.g. ``"IMAP"``);
            not consulted by the current rules.
        host: Server host the connection test targeted.
        error: The exception (or message) raised by the failed attempt.

    Returns:
        A fixed explanatory message when the failure looks like Microsoft's
        basic-auth rejection; otherwise the raw error text truncated to 200
        characters.
    """
    raw = str(error or "")
    lower = raw.lower()

    # Reduce the host to a bare lowercase hostname: drop any scheme, path,
    # ":port" suffix and trailing root dot before comparing domains.
    host_name = (host or "").strip().lower()
    if "://" in host_name:
        host_name = host_name.split("://", 1)[1]
    host_name = host_name.split("/", 1)[0].split(":", 1)[0].rstrip(".")

    # Match on the domain boundary rather than by substring, so unrelated
    # hosts that merely contain one of these names (e.g. "mail.olive.com"
    # containing "live.com") are not mistaken for Microsoft servers.
    microsoft_domains = ("office365.com", "outlook.com", "hotmail.com", "live.com")
    microsoft_host = any(
        host_name == domain or host_name.endswith("." + domain)
        for domain in microsoft_domains
    )

    # The two explicit markers identify Microsoft's policy on any host; the
    # generic auth-failure phrases only count when the host is Microsoft's.
    microsoft_basic_auth_failure = (
        "5.7.139" in lower
        or "basic authentication is disabled" in lower
        or ("authenticate failed" in lower and microsoft_host)
        or ("authentication unsuccessful" in lower and microsoft_host)
    )
    if microsoft_basic_auth_failure:
        return (
            "Microsoft no longer accepts normal mailbox passwords for "
            "Outlook/Office 365 IMAP/SMTP in most accounts. Odysseus "
            "does not support Microsoft OAuth/Graph mail yet, so Outlook "
            "accounts cannot be added with this password form."
        )
    return raw[:200]
def _strip_think(text: str) -> str:
"""Email-flavored think strip — thin wrapper over the central helper.
@@ -254,16 +286,17 @@ def _cleanup_compose_uploads(tokens) -> None:
pass
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
SETTINGS_FILE = DATA_DIR / "settings.json"
from src.constants import DATA_DIR as _DATA_DIR, MAIL_ATTACHMENTS_DIR, SETTINGS_FILE as _SETTINGS_FILE, SCHEDULED_EMAILS_DB
DATA_DIR = Path(_DATA_DIR)
SETTINGS_FILE = Path(_SETTINGS_FILE)
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
# subdir of the install's data/ tree so the app works out-of-the-box without
# a hardcoded /home/<user>/ path.
ATTACHMENTS_DIR = Path(os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(DATA_DIR / "mail-attachments")))
ATTACHMENTS_DIR = Path(MAIL_ATTACHMENTS_DIR)
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
SCHEDULED_DB = Path(SCHEDULED_EMAILS_DB)
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
@@ -705,7 +738,16 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
port = int(port or 993)
if starttls:
conn = imaplib.IMAP4(host, port, timeout=timeout)
conn.starttls()
try:
conn.starttls()
except Exception:
# Don't leak the open plain socket if the STARTTLS upgrade is
# rejected; close it before propagating. (#3174)
try:
conn.shutdown()
except Exception:
pass
raise
elif port == 993:
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
else:
@@ -714,6 +756,10 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
conn.sock.settimeout(timeout)
except Exception:
pass
# Raise the IMAP line-length limit from the default 1 MB to 50 MB so that
# large mailboxes (tens of thousands of messages) don't crash with
# "got more than 1000000 bytes" on UID SEARCH ALL. (#2883)
imaplib._MAXLINE = 50_000_000
return conn
def _imap_connect(account_id: str | None = None, owner: str = ""):
@@ -734,7 +780,18 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
starttls=bool(cfg.get("imap_starttls")),
timeout=_IMAP_TIMEOUT_SECONDS,
)
conn.login(cfg["imap_user"], cfg["imap_password"])
try:
conn.login(cfg["imap_user"], cfg["imap_password"])
except Exception:
# A failed AUTHENTICATE (e.g. an Office 365 app password on an
# MFA-enabled tenant, #3174) otherwise orphans the already-connected
# socket; close it before propagating so a misconfigured account
# can't leak one descriptor per retry / background poller pass.
try:
conn.shutdown()
except Exception:
pass
raise
return conn
@@ -798,20 +855,28 @@ def _imap(account_id: str | None = None, owner: str = ""):
def _decode_header(raw):
if not raw:
return ""
parts = email.header.decode_header(raw)
decoded = []
for data, charset in parts:
if isinstance(data, bytes):
try:
decoded.append(data.decode(charset or "utf-8", errors="replace"))
except (LookupError, ValueError):
# Unknown/invalid MIME charset (e.g. a malformed or spam header
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
# byte-decode errors, not codec lookup, so fall back to utf-8.
decoded.append(data.decode("utf-8", errors="replace"))
else:
decoded.append(data)
return " ".join(decoded)
try:
# make_header concatenates per RFC 2047: no spurious space between an
# encoded-word and adjacent plain text (plain runs keep their own
# whitespace), and the whitespace between two adjacent encoded-words is
# dropped. The old " ".join produced "Re: Jose"-style double spaces on
# every non-ASCII subject or sender.
return str(email.header.make_header(email.header.decode_header(raw)))
except Exception:
# Malformed header or unknown/invalid MIME charset (e.g. a spam header
# like =?x-unknown-charset?B?...?=) makes make_header raise LookupError;
# fall back to a lossy per-part decode. errors="replace" only covers
# byte-decode errors, not codec lookup, hence the explicit utf-8 retry.
decoded = []
for data, charset in email.header.decode_header(raw):
if isinstance(data, bytes):
try:
decoded.append(data.decode(charset or "utf-8", errors="replace"))
except (LookupError, ValueError):
decoded.append(data.decode("utf-8", errors="replace"))
else:
decoded.append(data)
return "".join(decoded)
def _detect_sent_folder(conn):
@@ -1136,13 +1201,9 @@ def _fetch_sender_thread_context(sender_addr: str,
if exclude_uid:
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
conn = None
try:
conn = _imap_connect(account_id, owner=owner)
except Exception as e:
logger.warning(f"sender-thread-context: imap connect failed: {e}")
return ""
try:
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
if len(blocks) >= limit:
break
@@ -1209,11 +1270,14 @@ def _fetch_sender_thread_context(sender_addr: str,
if atts_text:
lines.append(atts_text)
blocks.append("\n".join(lines))
except Exception as e:
logger.warning(f"sender-thread-context: imap failed: {e}")
finally:
try: conn.close()
except Exception: pass
try: conn.logout()
except Exception: pass
if conn:
try: conn.close()
except Exception: pass
try: conn.logout()
except Exception: pass
if not blocks:
return ""
@@ -1316,6 +1380,7 @@ def _pre_retrieve_context(
if not terms_list:
return context_snippets, terms_list
ctx_conn = None
try:
ctx_conn = _imap_connect(account_id, owner=owner)
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
@@ -1352,12 +1417,12 @@ def _pre_retrieve_context(
except Exception as _e:
logger.warning(f" search {folder} {term!r} failed: {_e}")
continue
try:
ctx_conn.logout()
except Exception:
pass
except Exception as _e:
logger.warning(f"IMAP context search failed: {_e}")
finally:
if ctx_conn:
try: ctx_conn.logout()
except Exception: pass
try:
from routes.contacts_routes import _fetch_contacts
+2 -2
View File
@@ -210,7 +210,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
if auto_cal:
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
try:
st, _ = conn.select(sent_name, readonly=True)
st, _ = conn.select(_q(sent_name), readonly=True)
if st == "OK":
folders_to_scan.append(sent_name)
break
@@ -1046,7 +1046,7 @@ def _scheduled_poll_once() -> dict:
try:
with _imap(row_account_id, owner=row_owner) as imap:
sent_folder = _detect_sent_folder(imap)
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
imap.append(_q(sent_folder), "\\Seen", None, outer.as_bytes())
except Exception as e:
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
+6 -5
View File
@@ -32,9 +32,10 @@ from email.mime.multipart import MIMEMultipart
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
from fastapi.responses import FileResponse
from src.constants import DATA_DIR
from src.llm_core import llm_call_async
from src.upload_limits import read_upload_limited
from src.upload_limits import read_upload_limited, EMAIL_COMPOSE_UPLOAD_MAX_BYTES
from routes.email_helpers import (
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
@@ -47,6 +48,7 @@ from routes.email_helpers import (
_extract_attachment_to_disk, _extract_html, _extract_text,
_fetch_sender_thread_context, _pre_retrieve_context,
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
_friendly_email_auth_error,
SendEmailRequest, ExtractStyleRequest,
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
attachment_extract_dir, _email_cache_owner_clause,
@@ -56,7 +58,6 @@ from routes.email_pollers import _start_poller
logger = logging.getLogger(__name__)
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
EMAIL_COMPOSE_UPLOAD_MAX_BYTES = 25 * 1024 * 1024
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
@@ -2904,7 +2905,7 @@ def setup_email_routes():
from pathlib import Path as _P
import json as _json
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
path = _P(f"data/email_urgency_state_{_slug}.json")
path = _P(DATA_DIR) / f"email_urgency_state_{_slug}.json"
if not path.exists():
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
try:
@@ -3162,7 +3163,7 @@ def setup_email_routes():
try: conn.logout()
except Exception: pass
except Exception as e:
imap_result = {"ok": False, "error": str(e)[:200]}
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
smtp_host = (body.get("smtp_host") or "").strip()
if smtp_host:
@@ -3184,7 +3185,7 @@ def setup_email_routes():
try: smtp.quit()
except Exception: pass
except Exception as e:
smtp_result = {"ok": False, "error": str(e)[:200]}
smtp_result = {"ok": False, "error": _friendly_email_auth_error("SMTP", smtp_host, e)}
return {
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
+65 -22
View File
@@ -7,12 +7,12 @@ import logging
import asyncio
from pathlib import Path
from fastapi import APIRouter, HTTPException, Form, Depends
from core.constants import BASE_DIR
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
from core.middleware import require_admin
logger = logging.getLogger(__name__)
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
_ENDPOINT_FILE = EMBEDDING_ENDPOINT_FILE
# Track in-progress downloads
_downloading: dict = {}
@@ -35,13 +35,7 @@ def _cache_dir() -> str:
default lived in /tmp, which many systems wipe on reboot forcing a
full re-download of the embedding model after every restart.
"""
env = os.environ.get("FASTEMBED_CACHE_PATH")
if env:
return env
return os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"data", "fastembed_cache",
)
return FASTEMBED_CACHE_DIR
def _model_cache_name(hf_source: str) -> str:
@@ -49,19 +43,35 @@ def _model_cache_name(hf_source: str) -> str:
return "models--" + hf_source.replace("/", "--")
def _model_cache_path(hf_source: str) -> Path:
    """Resolve the cache directory for ``hf_source``, confined to the cache root.

    Raises:
        ValueError: if the candidate directory is itself a symlink, or if
            its resolved location is not inside the resolved cache root.
    """
    cache_root = Path(_cache_dir()).expanduser().resolve()
    candidate = cache_root / _model_cache_name(hf_source)
    # Reject a symlinked model directory before resolving it.
    if candidate.is_symlink():
        raise ValueError("Model cache path must not be a symlink")
    resolved = candidate.resolve(strict=False)
    try:
        # relative_to raises ValueError when ``resolved`` is outside the root.
        resolved.relative_to(cache_root)
    except ValueError:
        raise ValueError("Model cache path escapes cache root")
    return resolved
def _is_downloaded(hf_source: str) -> bool:
    """Check if a model is already cached.

    The model directory is located through ``_model_cache_path`` so the
    lookup is confined to the cache root; a path that fails that confinement
    check is reported as "not downloaded" rather than raising.

    Returns:
        True when the model directory holds at least one entry under
        ``snapshots`` or, failing that, under ``blobs``.
    """
    try:
        model_dir = _model_cache_path(hf_source)
    except ValueError:
        # Symlinked or escaping path: treat as not cached.
        return False
    if not model_dir.is_dir():
        return False
    # Check for actual model files (not just empty dir)
    snapshots = model_dir / "snapshots"
    if snapshots.is_dir():
        return any(snapshots.iterdir())
    # Also check for blobs (older cache format)
    blobs = model_dir / "blobs"
    return blobs.is_dir() and any(blobs.iterdir())
def _active_model() -> str:
@@ -119,8 +129,10 @@ def setup_embedding_routes():
cached_size = None
if downloaded and hf_src:
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
cached_size = _dir_size_mb(model_path)
try:
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
except ValueError:
cached_size = None
result.append({
"model": m["model"],
@@ -217,8 +229,11 @@ def setup_embedding_routes():
if not hf_src:
raise HTTPException(400, "No cache source for this model")
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
if not os.path.isdir(model_path):
try:
model_path = _model_cache_path(hf_src)
except ValueError as e:
raise HTTPException(400, str(e))
if not model_path.is_dir():
return {"deleted": False, "message": "Model not cached"}
shutil.rmtree(model_path)
@@ -237,7 +252,7 @@ def setup_embedding_routes():
}
@router.post("/endpoint")
def set_endpoint(url: str = Form(...), model: str = Form("")):
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
"""Save a custom embedding endpoint URL."""
url = url.strip()
if not url:
@@ -261,6 +276,7 @@ def setup_embedding_routes():
resp = httpx.post(
url,
json={"input": ["test"], "model": model or "test"},
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
timeout=10,
)
resp.raise_for_status()
@@ -271,10 +287,16 @@ def setup_embedding_routes():
data = {"url": url}
if model:
data["model"] = model
if api_key:
from src.secret_storage import encrypt
data["api_key"] = encrypt(api_key)
_save_custom_endpoint(data)
os.environ["EMBEDDING_URL"] = url
if model:
os.environ["EMBEDDING_MODEL"] = model
if api_key:
os.environ["EMBEDDING_API_KEY"] = api_key
# Reset the RAG singleton so it picks up the new endpoint
import src.rag_singleton as _rs
@@ -288,6 +310,16 @@ def setup_embedding_routes():
reset_http_embed_state()
except Exception:
pass
try:
from src.embedding_lanes import reset_embedding_lane_state
reset_embedding_lane_state()
except Exception:
pass
try:
from src.tool_index import reset_tool_index
reset_tool_index()
except Exception:
pass
# Reset ChromaDB client (collections will be recreated with new embeddings)
try:
@@ -308,6 +340,7 @@ def setup_embedding_routes():
# Remove from environment
os.environ.pop("EMBEDDING_URL", None)
os.environ.pop("EMBEDDING_MODEL", None)
os.environ.pop("EMBEDDING_API_KEY", None)
# Reset the RAG singleton so it falls back to fastembed
import src.rag_singleton as _rs
@@ -318,6 +351,16 @@ def setup_embedding_routes():
reset_http_embed_state()
except Exception:
pass
try:
from src.embedding_lanes import reset_embedding_lane_state
reset_embedding_lane_state()
except Exception:
pass
try:
from src.tool_index import reset_tool_index
reset_tool_index()
except Exception:
pass
# Reset ChromaDB client
try:
+45 -6
View File
@@ -16,22 +16,54 @@ from pathlib import Path
import httpx
from fastapi import APIRouter
from fastapi.responses import FileResponse, Response
from fastapi.responses import Response
from src.constants import EMOJI_CACHE_DIR
logger = logging.getLogger(__name__)
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
_CACHE_DIR = Path(EMOJI_CACHE_DIR)
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
_MAX_SVG_BYTES = 256 * 1024
_BLOCKED_SVG_RE = re.compile(
br"<\s*(?:script|foreignObject|iframe|object|embed|image)\b|"
br"\bon[a-z0-9_-]+\s*=",
re.IGNORECASE,
)
_EXTERNAL_REF_RE = re.compile(
br"\b(?:href|xlink:href)\s*=\s*['\"](?:https?:|//|data:|javascript:)",
re.IGNORECASE,
)
_SVG_SECURITY_HEADERS = {
"X-Content-Type-Options": "nosniff",
"Content-Security-Policy": "sandbox",
"Cross-Origin-Resource-Policy": "same-origin",
}
_SVG_HEADERS = {
"Cache-Control": "public, max-age=31536000, immutable",
**_SVG_SECURITY_HEADERS,
}
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
# request can still pick up the real glyph once the CDN is reachable.
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
_BLANK_HEADERS = {"Cache-Control": "no-store"}
_BLANK_HEADERS = {"Cache-Control": "no-store", **_SVG_SECURITY_HEADERS}
def _is_safe_svg(content: bytes) -> bool:
    """Decide whether ``content`` may be cached and served as an SVG.

    The payload is accepted only when all of the following hold:
    it is a non-empty ``bytes`` object, it is no larger than
    ``_MAX_SVG_BYTES``, an ``<svg`` tag appears (case-insensitively) within
    the first 256 bytes, and neither ``_BLOCKED_SVG_RE`` nor
    ``_EXTERNAL_REF_RE`` matches anywhere in it.
    """
    if not (isinstance(content, bytes) and content):
        return False
    if len(content) > _MAX_SVG_BYTES:
        return False
    # Only the leading 256 bytes are inspected for the opening tag.
    head = content[:256].lower()
    if b"<svg" not in head:
        return False
    flagged = _BLOCKED_SVG_RE.search(content) or _EXTERNAL_REF_RE.search(content)
    return not flagged
def setup_emoji_routes() -> APIRouter:
@@ -49,14 +81,21 @@ def setup_emoji_routes() -> APIRouter:
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
fp = _CACHE_DIR / f"{code}.svg"
if fp.exists():
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
try:
content = fp.read_bytes()
if _is_safe_svg(content):
return Response(content, media_type="image/svg+xml", headers=_SVG_HEADERS)
fp.unlink(missing_ok=True)
except Exception as e:
logger.warning("emoji cache read %s failed: %s", code, e)
return _blank()
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
# it. OpenMoji filenames are the codepoints uppercased.
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
if r.status_code == 200 and b"<svg" in r.content[:256]:
if r.status_code == 200 and _is_safe_svg(r.content):
try:
fp.write_bytes(r.content)
except Exception:
+144 -54
View File
@@ -12,8 +12,13 @@ from fastapi import APIRouter, HTTPException, Query, Request
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
from core.database import Session as DbSession
from src.auth_helpers import get_current_user, require_privilege
from src.upload_limits import read_upload_limited
from src.auth_helpers import get_current_user, owner_filter, require_privilege
from src.upload_limits import (
read_upload_limited,
GALLERY_UPLOAD_MAX_BYTES,
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
)
from src.constants import GENERATED_IMAGES_DIR
from routes.gallery_helpers import (
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
@@ -21,17 +26,88 @@ from routes.gallery_helpers import (
logger = logging.getLogger(__name__)
GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", str(100 * 1024 * 1024)))
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024)))
def _current_user_is_admin(request: Request, user: str | None) -> bool:
if not user:
return False
auth_mgr = getattr(request.app.state, "auth_manager", None)
is_admin = getattr(auth_mgr, "is_admin", None)
if not callable(is_admin):
return False
try:
return bool(is_admin(user))
except Exception:
return False
def _sanitize_gallery_filename(filename: str) -> str:
"""Return a local filename safe to join under generated_images."""
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128]
if not safe_name or safe_name in {".", ".."}:
safe_name = uuid.uuid4().hex[:12]
return safe_name
GALLERY_IMAGE_DIR = Path(GENERATED_IMAGES_DIR)
def _gallery_image_path(filename: str) -> Path:
    """Map a stored gallery filename to its path under GALLERY_IMAGE_DIR.

    The name is accepted only if it is a string, is already in sanitized
    form (``_sanitize_gallery_filename`` leaves it unchanged), and resolves
    to a location inside the resolved gallery directory.

    Raises:
        HTTPException: 400 "Unsafe gallery filename" for any name that fails
            one of those checks.
    """
    if not isinstance(filename, str):
        raise HTTPException(400, "Unsafe gallery filename")
    cleaned = _sanitize_gallery_filename(filename)
    base_dir = GALLERY_IMAGE_DIR.resolve()
    candidate = (GALLERY_IMAGE_DIR / cleaned).resolve()
    try:
        # commonpath can itself raise (e.g. mixed drives); treat that as unsafe.
        contained = os.path.commonpath([str(base_dir), str(candidate)]) == str(base_dir)
    except Exception:
        contained = False
    if not contained:
        raise HTTPException(400, "Unsafe gallery filename")
    # A name that needed sanitizing was never a name this module produced.
    if cleaned != str(filename or ""):
        raise HTTPException(400, "Unsafe gallery filename")
    return candidate
def _normalize_image_endpoint_base(url: str) -> str:
base = (url or "").strip().rstrip("/")
if base.endswith("/v1"):
base = base[:-3].rstrip("/")
return base
def _visible_image_endpoint_query(db, owner: str | None):
    """Build the query for enabled image-type endpoints, scoped by owner.

    Selects ``ModelEndpoint`` rows with ``model_type == "image"`` and
    ``is_enabled`` true, then passes the query through ``owner_filter``
    together with ``owner``. The (unexecuted) query is returned.
    """
    from src.auth_helpers import owner_filter

    enabled_image_endpoints = (
        db.query(ModelEndpoint)
        .filter(ModelEndpoint.model_type == "image")
        .filter(ModelEndpoint.is_enabled == True)  # noqa: E712
    )
    return owner_filter(enabled_image_endpoints, ModelEndpoint, owner)
def _first_visible_image_endpoint(db, owner: str | None):
    """Pick the image endpoint to use by default for ``owner``.

    Among the rows returned by ``_visible_image_endpoint_query``, an endpoint
    whose ``owner`` attribute equals ``owner`` wins; otherwise the first row
    is used. Returns None when there are no rows.
    """
    candidates = _visible_image_endpoint_query(db, owner).all()
    if owner:
        owned = next(
            (ep for ep in candidates if getattr(ep, "owner", None) == owner),
            None,
        )
        if owned is not None:
            return owned
    if not candidates:
        return None
    return candidates[0]
def _visible_image_endpoint_for_base(db, base: str, owner: str | None):
    """Find a visible image endpoint whose base URL matches ``base``.

    Both sides are compared after ``_normalize_image_endpoint_base``. Among
    the matches, one whose ``owner`` attribute equals ``owner`` is preferred;
    otherwise the first match is returned. Returns None when ``base``
    normalizes to an empty string or nothing matches.
    """
    wanted = _normalize_image_endpoint_base(base)
    if not wanted:
        return None
    matches = [
        ep
        for ep in _visible_image_endpoint_query(db, owner).all()
        if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == wanted
    ]
    if owner:
        for ep in matches:
            if getattr(ep, "owner", None) == owner:
                return ep
    return matches[0] if matches else None
def setup_gallery_routes() -> APIRouter:
router = APIRouter(tags=["gallery"])
@@ -55,6 +131,9 @@ def setup_gallery_routes() -> APIRouter:
file_hash = hashlib.sha256(content).hexdigest()
db = SessionLocal()
try:
if album_id and user is not None:
_get_or_404_album(db, album_id, user)
# SECURITY: scope the dup-detect to THIS user — otherwise a
# caller can probe whether someone else uploaded the same
# file (the response leaks the existing row's id+filename).
@@ -69,7 +148,7 @@ def setup_gallery_routes() -> APIRouter:
return {"ok": False, "duplicate": True, "filename": existing.filename,
"id": existing.id, "message": "Duplicate photo skipped"}
img_dir = Path("data/generated_images")
img_dir = Path(GENERATED_IMAGES_DIR)
img_dir.mkdir(parents=True, exist_ok=True)
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
@@ -135,7 +214,7 @@ def setup_gallery_routes() -> APIRouter:
raise HTTPException(400, "No image provided")
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
img_dir = Path("data/generated_images")
img_dir = Path(GENERATED_IMAGES_DIR)
img_dir.mkdir(parents=True, exist_ok=True)
img_path = img_dir / _sanitize_gallery_filename(img.filename)
img_path.write_bytes(content)
@@ -211,7 +290,7 @@ def setup_gallery_routes() -> APIRouter:
if not user or img.owner != user:
raise HTTPException(403, "Not your image")
img_path = Path("data/generated_images") / img.filename
img_path = _gallery_image_path(img.filename)
if not img_path.exists():
raise HTTPException(404, "Image file not found")
@@ -248,7 +327,7 @@ def setup_gallery_routes() -> APIRouter:
"""AI upscale using img2img with the diffusion server."""
import base64, httpx
require_privilege(request, "can_generate_images")
user = require_privilege(request, "can_generate_images")
form = await request.form()
file = form.get("image")
if not file: raise HTTPException(400, "No image")
@@ -260,7 +339,7 @@ def setup_gallery_routes() -> APIRouter:
# Find image endpoint
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
ep = _first_visible_image_endpoint(db, user)
finally:
db.close()
@@ -291,7 +370,7 @@ def setup_gallery_routes() -> APIRouter:
"""Style transfer using img2img with the diffusion server."""
import base64, httpx
require_privilege(request, "can_generate_images")
user = require_privilege(request, "can_generate_images")
form = await request.form()
file = form.get("image")
prompt = form.get("prompt", "")
@@ -303,7 +382,7 @@ def setup_gallery_routes() -> APIRouter:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
ep = _first_visible_image_endpoint(db, user)
finally:
db.close()
@@ -505,18 +584,24 @@ def setup_gallery_routes() -> APIRouter:
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
result = []
for a in albums:
count = db.query(GalleryImage).filter(
_count_q = db.query(GalleryImage).filter(
GalleryImage.album_id == a.id, GalleryImage.is_active == True
).count()
)
if user:
_count_q = _count_q.filter(GalleryImage.owner == user)
count = _count_q.count()
cover_url = None
if a.cover_id:
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
if cover:
cover_url = f"/api/generated-image/{cover.filename}"
elif count > 0:
first = db.query(GalleryImage).filter(
_cover_q = db.query(GalleryImage).filter(
GalleryImage.album_id == a.id, GalleryImage.is_active == True
).order_by(GalleryImage.created_at.desc()).first()
)
if user:
_cover_q = _cover_q.filter(GalleryImage.owner == user)
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
if first:
cover_url = f"/api/generated-image/{first.filename}"
result.append({
@@ -649,7 +734,14 @@ def setup_gallery_routes() -> APIRouter:
if req.favorite is not None:
img.favorite = req.favorite
if req.album_id is not None:
img.album_id = req.album_id if req.album_id else None
if req.album_id:
# Validate the target album belongs to the caller before
# moving the image into it — mirrors add_to_album, so you
# cannot file your image into another user's album.
_get_or_404_album(db, req.album_id, user)
img.album_id = req.album_id
else:
img.album_id = None
db.commit()
db.refresh(img)
return _image_to_dict(img)
@@ -692,11 +784,11 @@ def setup_gallery_routes() -> APIRouter:
used = set()
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
for img in imgs:
src = os.path.join("data", "generated_images", img.filename)
if not os.path.exists(src):
src = _gallery_image_path(img.filename)
if not src.exists():
continue
ext = os.path.splitext(img.filename)[1] or ".png"
base = (img.prompt or "").strip() or os.path.splitext(img.filename)[0]
ext = src.suffix or ".png"
base = (img.prompt or "").strip() or src.stem
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
name = f"{base}{ext}"
i = 1
@@ -818,9 +910,9 @@ def setup_gallery_routes() -> APIRouter:
img_filename = img.filename
# Remove the file from disk
img_path = os.path.join("data", "generated_images", img_filename)
if os.path.exists(img_path):
os.remove(img_path)
img_path = _gallery_image_path(img_filename)
if img_path.exists():
img_path.unlink()
# Soft-delete the record
img.is_active = False
@@ -923,7 +1015,7 @@ def setup_gallery_routes() -> APIRouter:
the request for /v1/images/edits (multipart, inverted mask). Otherwise
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
import httpx
require_privilege(request, "can_generate_images")
user = require_privilege(request, "can_generate_images")
body = await request.json()
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
base = (body.pop("_endpoint", "") or "").rstrip("/")
@@ -942,14 +1034,11 @@ def setup_gallery_routes() -> APIRouter:
if not base:
db = SessionLocal()
try:
eps = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True,
ModelEndpoint.model_type == "image",
).all()
if not eps:
ep = _first_visible_image_endpoint(db, user)
if not ep:
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
base = eps[0].base_url.rstrip("/")
api_key = eps[0].api_key
base = ep.base_url.rstrip("/")
api_key = ep.api_key
finally:
db.close()
else:
@@ -966,10 +1055,12 @@ def setup_gallery_routes() -> APIRouter:
_target = _norm_url(base)
db = SessionLocal()
try:
for ep in db.query(ModelEndpoint).all():
if _norm_url(ep.base_url) == _target:
api_key = ep.api_key
break
ep = _visible_image_endpoint_for_base(db, _target, user)
if ep:
base = (ep.base_url or base).rstrip("/")
api_key = ep.api_key
elif user and not _current_user_is_admin(request, user):
raise HTTPException(403, "Choose a registered image endpoint")
finally:
db.close()
@@ -1121,7 +1212,7 @@ def setup_gallery_routes() -> APIRouter:
you get edge blending + lighting unification while keeping the
composition recognisable."""
import httpx, base64 as _b64
require_privilege(request, "can_generate_images")
user = require_privilege(request, "can_generate_images")
body = await request.json()
image_b64 = body.get("image")
@@ -1148,23 +1239,22 @@ def setup_gallery_routes() -> APIRouter:
if not base:
db = SessionLocal()
try:
eps = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True,
ModelEndpoint.model_type == "image",
).all()
if not eps:
ep = _first_visible_image_endpoint(db, user)
if not ep:
raise HTTPException(400, "No image generation endpoint configured.")
base = eps[0].base_url.rstrip("/")
api_key = eps[0].api_key
base = ep.base_url.rstrip("/")
api_key = ep.api_key
finally:
db.close()
else:
db = SessionLocal()
try:
for ep in db.query(ModelEndpoint).all():
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
api_key = ep.api_key
break
ep = _visible_image_endpoint_for_base(db, base, user)
if ep:
base = (ep.base_url or base).rstrip("/")
api_key = ep.api_key
elif user and not _current_user_is_admin(request, user):
raise HTTPException(403, "Choose a registered image endpoint")
finally:
db.close()
@@ -1636,9 +1726,10 @@ def setup_gallery_routes() -> APIRouter:
db = SessionLocal()
try:
album = _get_or_404_album(db, album_id, user)
db.query(GalleryImage).filter(GalleryImage.album_id == album_id).update(
{"album_id": None}, synchronize_session=False
)
q = db.query(GalleryImage).filter(GalleryImage.album_id == album_id)
if user is not None:
q = q.filter(GalleryImage.owner == user)
q.update({"album_id": None}, synchronize_session=False)
db.delete(album)
db.commit()
return {"ok": True}
@@ -1709,7 +1800,7 @@ def setup_gallery_routes() -> APIRouter:
try:
img = _get_or_404_image(db, image_id, user)
img_path = Path("data/generated_images") / img.filename
img_path = _gallery_image_path(img.filename)
if not img_path.exists():
raise HTTPException(404, "Image file not found")
@@ -1727,7 +1818,7 @@ def setup_gallery_routes() -> APIRouter:
return {"error": "Vision is disabled — enable it in Settings → Vision"}
configured = vl_settings.get("vision_model", "")
try:
chat_url, model_name, headers = _resolve_vl_model(configured)
chat_url, model_name, headers = _resolve_vl_model(configured, owner=user)
except ValueError:
return {"error": "No vision model configured — set one in Settings → Vision"}
if not chat_url:
@@ -1808,4 +1899,3 @@ def setup_gallery_routes() -> APIRouter:
db.close()
return router
+10 -2
View File
@@ -490,7 +490,13 @@ def setup_history_routes(session_manager) -> APIRouter:
# Copy messages up to keep_count
msgs_to_copy = source.history[:keep_count]
for msg in msgs_to_copy:
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
# Copy the metadata dict. Sharing it would let the fork's
# persistence (add_message -> _persist_message stamps
# _db_id/timestamp onto the dict) mutate the SOURCE session's
# in-memory messages, corrupting their _db_id and breaking
# edit/delete-by-id on the original conversation.
meta = dict(msg.metadata) if isinstance(msg.metadata, dict) else None
new_session.add_message(ChatMessage(msg.role, msg.content, meta))
try:
from src.event_bus import fire_event
fire_event("session_created", getattr(source, 'owner', None))
@@ -522,6 +528,8 @@ def setup_history_routes(session_manager) -> APIRouter:
async def compact_session(request: Request, session_id: str):
"""Manually trigger context compaction for a session."""
_verify_session_owner(request, session_id)
from src.auth_helpers import effective_user
owner = effective_user(request)
try:
session = session_manager.get_session(session_id)
except KeyError:
@@ -555,7 +563,7 @@ def setup_history_routes(session_manager) -> APIRouter:
)
# Use utility model if available
util_url, util_model, util_headers = resolve_endpoint("utility")
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner or None)
compact_url = util_url or session.endpoint_url
compact_model = util_model or session.model
compact_headers = util_headers if util_url else session.headers
+2 -2
View File
@@ -13,7 +13,7 @@ import httpx
from core.database import McpServer, SessionLocal
from core.middleware import require_admin
from src.constants import DATA_DIR
from src.constants import DATA_DIR, MCP_OAUTH_DIR
from src.mcp_manager import McpManager
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
def _mcp_oauth_base_dir() -> Path:
    """Directory that may contain OAuth files managed by Odysseus.

    Returns the resolved form of ``MCP_OAUTH_DIR``; ``strict=False`` so the
    directory does not have to exist yet.
    """
    return Path(MCP_OAUTH_DIR).resolve(strict=False)
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
+2 -3
View File
@@ -29,11 +29,10 @@ from src.llm_core import llm_call_async
from services.memory.memory_extractor import audit_memories
from src.auth_helpers import get_current_user, require_user
from src.endpoint_resolver import resolve_endpoint
from src.upload_limits import read_upload_limited
from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES
logger = logging.getLogger(__name__)
MEMORY_IMPORT_MAX_BYTES = int(os.getenv("ODYSSEUS_MEMORY_IMPORT_MAX_BYTES", str(10 * 1024 * 1024)))
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
"""Set up memory-related routes."""
@@ -371,7 +370,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
tmp.write(content)
tmp_path = tmp.name
try:
text = _process_pdf(tmp_path)
text = _process_pdf(tmp_path, owner=_owner(request))
finally:
os.unlink(tmp_path)
else:
+188 -26
View File
@@ -5,6 +5,7 @@ import re
import uuid
import json
import socket
import hashlib
import time as _time
import logging
import httpx
@@ -282,8 +283,11 @@ _HOST_TO_CURATED = (
("fireworks.ai", "fireworks"),
("googleapis.com", "google"),
("x.ai", "xai"),
("openrouter.ai", "openrouter"),
("ollama.com", "ollama"),
("opencode.ai/zen/go", "opencode-go"),
("opencode.ai/zen", "opencode-zen"),
)
@@ -490,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
def _is_chat_model(model_id: str) -> bool:
"""Return True if the model ID looks like a chat/completions-capable model."""
mid = model_id.lower()
if mid in {"gpt-5.1-codex"}:
return True
for prefix in _NON_CHAT_PREFIXES:
if mid.startswith(prefix):
return False
@@ -502,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
return True
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
    """Delete a ProviderAuthSession once no endpoint still references it.

    Subscription providers (e.g. ChatGPT Subscription) keep their refresh
    token in ProviderAuthSession rather than ModelEndpoint.api_key. When the
    last endpoint backed by that auth row is removed, the stored credentials
    should be cleared instead of lingering.

    Args:
        db: database session used for the lookups and the delete. This
            function does not commit.
        auth_id: id of the ProviderAuthSession row; falsy means nothing to do.
        exclude_ep_id: endpoint currently being deleted, left out of the
            reference check so it does not keep its own auth alive.

    Returns:
        True if a row was marked for deletion, False otherwise.
    """
    if not auth_id:
        return False
    from core.database import ProviderAuthSession

    other_reference = (
        db.query(ModelEndpoint.id)
        .filter(
            ModelEndpoint.provider_auth_id == auth_id,
            ModelEndpoint.id != exclude_ep_id,
        )
        .first()
    )
    if other_reference is None:
        session_row = (
            db.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == auth_id)
            .first()
        )
        if session_row is not None:
            db.delete(session_row)
            return True
    return False
def _is_discovery_only_provider(provider: str) -> bool:
"""Provider that only supports model discovery, not live probing.
ChatGPT Subscription speaks the Responses/Codex API and has no
chat-completions or general health endpoint, so completion probes and
reachability pings are skipped status is derived from cached models.
"""
return provider == "chatgpt-subscription"
def _resolve_probe_key(ep) -> Optional[str]:
    """Return the API key/bearer that should be used to probe ``ep``.

    The work is delegated to ``resolve_endpoint_runtime`` (called with the
    endpoint and its ``owner`` attribute); the second element of its result
    is returned. Per the original note, that helper returns the static
    ``ModelEndpoint.api_key`` for keyed endpoints and resolves (and
    refreshes) the runtime bearer for session-backed providers.

    Any exception during import or resolution (e.g. re-auth required) is
    logged as a warning and None is returned, so probing skips the endpoint
    rather than raising.
    """
    try:
        from src.endpoint_resolver import resolve_endpoint_runtime

        _resolved_base, bearer = resolve_endpoint_runtime(
            ep, owner=getattr(ep, "owner", None)
        )
    except Exception as e:
        logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
        return None
    return bearer
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
provider = _detect_provider(base)
if _is_discovery_only_provider(provider):
# Responses/Codex API, not chat-completions: a completion probe would
# 400 and the re-probe flow would then hide every model. Discovery-only.
return {"status": "ok", "latency_ms": 0, "skipped": True}
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Say OK"},
@@ -618,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
from src.endpoint_resolver import resolve_url
base = resolve_url(_normalize_base(base_url))
if _detect_provider(base) == "chatgpt-subscription":
from src.chatgpt_subscription import fetch_available_models
if api_key:
return fetch_available_models(api_key, timeout=timeout)
return []
if _detect_provider(base) == "anthropic":
# Try Anthropic's /v1/models endpoint first
url = build_models_url(base)
@@ -644,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
return list(ANTHROPIC_MODELS)
url = build_models_url(base)
if not url:
curated_key = _match_provider_curated(base, None)
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
return list(fallback or [])
headers = build_headers(api_key, base)
try:
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
@@ -697,7 +770,6 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
return list(fallback)
return []
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
"""Reachability probe that does not require installed/listed models."""
from src.endpoint_resolver import resolve_url
@@ -713,6 +785,10 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
or "ollama" in (parsed_base.hostname or "").lower()
)
# APFEL-specific detection
host = (parsed_base.hostname or "").lower()
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
def _result_from_response(r) -> Dict[str, Any]:
if 300 <= r.status_code < 400:
loc = r.headers.get("location", "")
@@ -734,7 +810,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
last_error: Optional[str] = None
try:
if looks_like_ollama:
# APFEL does not behave like Ollama; use its health endpoint.
if looks_like_apfel:
root = base
for suffix in ("/v1", "/api"):
if root.endswith(suffix):
root = root[: -len(suffix)].rstrip("/")
break
try:
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
result = _result_from_response(r)
if result["reachable"]:
return result
last_error = result.get("error")
except Exception as e:
last_error = str(e)[:120]
elif looks_like_ollama:
root = base
for suffix in ("/v1", "/api"):
if root.endswith(suffix):
@@ -754,14 +846,31 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
try:
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
return _result_from_response(r)
result = _result_from_response(r)
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
# as a fallback. OpenAI-compatible servers like llama-swap return 404
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
# are definitive — probing /models would just repeat the same rejection.
if (
not result["reachable"]
and result.get("status_code") is not None
and 400 <= result["status_code"] < 500
and result["status_code"] not in (401, 403)
):
models_url = build_models_url(base)
try:
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
result2 = _result_from_response(r2)
if result2["reachable"]:
return result2
except Exception:
pass
return result
except Exception as e:
last_error = str(e)[:120]
return {"reachable": False, "status_code": None, "error": last_error}
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
"""Return a provider-aware error message for failed endpoint probes."""
ping = ping or {}
@@ -850,6 +959,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
return [m for m in merged if m not in hidden]
def _api_key_fingerprint(api_key: Optional[str]) -> str:
"""Stable, non-secret label for distinguishing same-URL credentials."""
key = (api_key or "").strip()
if not key:
return ""
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
def setup_model_routes(model_discovery):
router = APIRouter(prefix="/api")
@@ -951,6 +1068,17 @@ def setup_model_routes(model_discovery):
ok, info = _should_refresh_endpoint(ep, now, force=force)
if not ok:
continue
if getattr(ep, "provider_auth_id", None):
try:
from src.endpoint_resolver import resolve_endpoint_runtime
info["base"], info["api_key"] = resolve_endpoint_runtime(
ep,
owner=getattr(ep, "owner", None),
)
info["key"] = _refresh_key(info["base"], info["api_key"])
except Exception as e:
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
continue
groups.setdefault(info["key"], {
"base": info["base"],
"api_key": info["api_key"],
@@ -1104,8 +1232,9 @@ def setup_model_routes(model_discovery):
raise HTTPException(401, "Not authenticated")
except HTTPException:
raise
except Exception:
pass
except Exception as e:
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
raise HTTPException(status_code=500, detail='Internal error')
# Admins see every endpoint (they manage the global pool); regular
# users get the owner-scoped view.
_is_admin = False
@@ -1219,12 +1348,20 @@ def setup_model_routes(model_discovery):
"endpoint_kind": kind,
}
try:
t0 = _time.time()
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["error"] = ping.get("error")
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
if _is_discovery_only_provider(provider):
# No general health endpoint — an unauthenticated GET just
# 401s. Report status from cached models instead of pinging.
entry["latency_ms"] = None
entry["status"] = "online" if cached_count else "offline"
entry["error"] = None
entry["model_count"] = cached_count
else:
t0 = _time.time()
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["error"] = ping.get("error")
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
except Exception as e:
entry["latency_ms"] = None
entry["status"] = "online" if cached_count else "offline"
@@ -1257,7 +1394,7 @@ def setup_model_routes(model_discovery):
if ep_id and ep_id not in endpoints_cache:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep:
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
ep_data = endpoints_cache.get(ep_id)
if not ep_data:
# Try to find by base_url from the model's endpoint field
@@ -1296,7 +1433,7 @@ def setup_model_routes(model_discovery):
"id": ep.id,
"name": ep.name,
"base_url": ep.base_url,
"api_key": ep.api_key,
"api_key": _resolve_probe_key(ep),
})
finally:
db.close()
@@ -1385,18 +1522,21 @@ def setup_model_routes(model_discovery):
# Endpoint counts as reachable if it has any model — including
# admin-pinned IDs that a probe would never surface.
status = "online" if (all_models or pinned) else "offline"
base = _normalize_base(r.base_url)
ping = None
if not all_models and not pinned and r.is_enabled:
# Discovery-only providers have no health endpoint — an
# unauthenticated ping just 401s, so don't bother.
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"):
status = "empty"
base = _normalize_base(r.base_url)
kind = _effective_endpoint_kind(r, base)
results.append({
"id": r.id,
"name": r.name,
"base_url": r.base_url,
"has_key": bool(r.api_key),
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
"is_enabled": r.is_enabled,
"models": visible,
"pinned_models": pinned,
@@ -1463,21 +1603,34 @@ def setup_model_routes(model_discovery):
)
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
# Dedupe: if an endpoint with the same base_url already exists and
# is reachable by the caller (shared or owned by them), return it
# instead of creating a duplicate row. Fixes "Scan for Servers"
# re-adding manually-added endpoints under their host:port name.
# Dedupe: if an endpoint with the same base_url and compatible
# credentials already exists and is reachable by the caller (shared or
# owned by them), return it instead of creating a duplicate row. Keep
# same-url/different-key rows distinct so users can group the same
# provider URL under multiple credentials.
from src.auth_helpers import get_current_user as _gcu_dedup
_caller = _gcu_dedup(request) or None
_incoming_api_key = api_key.strip()
_db_dedup = SessionLocal()
try:
existing = (
_same_url_rows = (
_db_dedup.query(ModelEndpoint)
.filter(ModelEndpoint.base_url == base_url)
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
.first()
.all()
)
existing = None
_empty_key_existing = None
for _candidate in _same_url_rows:
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
if _candidate_key == _incoming_api_key:
existing = _candidate
break
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
_empty_key_existing = _candidate
if existing is None and _incoming_api_key and _empty_key_existing is not None:
existing = _empty_key_existing
if existing:
changed = False
# Persist any incoming pinned IDs onto the existing row. An
@@ -1526,6 +1679,8 @@ def setup_model_routes(model_discovery):
"id": existing.id,
"name": existing.name,
"base_url": existing.base_url,
"has_key": bool(existing.api_key),
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
"models": _visible_models(
existing_models,
getattr(existing, "hidden_models", None),
@@ -1599,6 +1754,8 @@ def setup_model_routes(model_discovery):
"id": ep_id,
"name": name.strip(),
"base_url": base_url,
"has_key": bool(api_key.strip()),
"api_key_fingerprint": _api_key_fingerprint(api_key),
"models": _merge_model_ids(model_ids, _pinned),
"pinned_models": _pinned,
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
@@ -1648,7 +1805,7 @@ def setup_model_routes(model_discovery):
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if not ep:
raise HTTPException(404, "Endpoint not found")
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
finally:
db.close()
@@ -1712,7 +1869,7 @@ def setup_model_routes(model_discovery):
category = _classify_endpoint(base, kind)
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
try:
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
except Exception as exc:
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
probed = []
@@ -1948,6 +2105,8 @@ def setup_model_routes(model_discovery):
"name": ep.name,
"model_type": ep.model_type,
"base_url": ep.base_url,
"has_key": bool(ep.api_key),
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
@@ -2049,7 +2208,9 @@ def setup_model_routes(model_discovery):
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
auth_id = getattr(ep, "provider_auth_id", None)
db.delete(ep)
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
db.commit()
_invalidate_models_cache()
_local_probe_cache["data"] = None
@@ -2059,6 +2220,7 @@ def setup_model_routes(model_discovery):
"cleared_user_preferences": cleared_user_preferences,
"cleared_sessions": cleared_sessions,
"cleared_loaded_sessions": cleared_loaded_sessions,
"cleared_provider_auth": cleared_provider_auth,
}
finally:
db.close()
+161 -16
View File
@@ -11,6 +11,7 @@ from pydantic import BaseModel
from core.database import SessionLocal, Note
from src.auth_helpers import get_current_user
from src.constants import DATA_DIR
from sqlalchemy.orm.attributes import flag_modified
logger = logging.getLogger(__name__)
@@ -95,6 +96,32 @@ def _note_to_dict(note: Note) -> Dict[str, Any]:
}
def _reminder_text_from_note(note: Note) -> tuple[str, str]:
"""Return the reminder title/body from a stored note row."""
title = (note.title or "Note reminder").strip() or "Note reminder"
if note.items:
try:
items = json.loads(note.items)
except (json.JSONDecodeError, TypeError):
items = None
if isinstance(items, list):
pending: list[str] = []
for item in items:
if not isinstance(item, dict):
continue
if item.get("done") or item.get("checked"):
continue
text = str(item.get("text") or "").strip()
if text:
pending.append(text)
if pending:
shown = "\n".join(f"- {text}" for text in pending[:8])
extra = f"\n...and {len(pending) - 8} more" if len(pending) > 8 else ""
return title, f"Pending ({len(pending)}):\n{shown}{extra}"
return title, f"{len(items)} item{'s' if len(items) != 1 else ''}"
return title, (note.content or "").strip()[:400]
# ---------------------------------------------------------------------------
# Reminder dispatch — module-level so background tasks (built-in actions)
@@ -114,8 +141,9 @@ async def dispatch_reminder(
note_id: str,
owner: str = "",
queue_browser: bool = True,
settings_override: dict | None = None,
) -> dict:
"""Fire a reminder via the configured channel (browser/email/ntfy).
"""Fire a reminder via the configured channel (browser/email/ntfy/webhook).
Args:
title: short headline shown to the user
@@ -129,7 +157,7 @@ async def dispatch_reminder(
nothing is "sent" synchronously for it; the channel just routes there.
"""
from src.settings import load_settings
settings = load_settings()
settings = {**load_settings(), **(settings_override or {})}
channel = settings.get("reminder_channel", "browser")
llm_on = bool(settings.get("reminder_llm_synthesis", False))
title = (title or "").strip()
@@ -143,7 +171,7 @@ async def dispatch_reminder(
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
from pathlib import Path as _P
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
cache_path = _P(f"data/note_pings_{_slug}.json")
cache_path = _P(DATA_DIR) / f"note_pings_{_slug}.json"
if cache_path.exists():
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
last = cache.get(cache_key)
@@ -160,13 +188,14 @@ async def dispatch_reminder(
# Treat those as browser-only dedupe so email reminders can be
# retried by the backend scanner after a failed frontend path.
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
if should_skip and channel in ("email", "ntfy"):
if should_skip and channel in ("email", "ntfy", "webhook"):
should_skip = last_channel == channel
if should_skip:
return {
"synthesis": None,
"email_sent": False,
"ntfy_sent": False,
"webhook_sent": False,
"browser_sent": True,
"skipped": True,
}
@@ -179,9 +208,9 @@ async def dispatch_reminder(
try:
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=owner or None)
if not url:
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=owner or None)
if url and model:
raw = await llm_call_async(
url=url, model=model,
@@ -360,6 +389,76 @@ async def dispatch_reminder(
email_error = str(e) or e.__class__.__name__
logger.warning(f"Reminder email send failed: {e}")
webhook_sent = False
webhook_error = ""
if channel == "webhook":
try:
import httpx
import json as _wjson
from src.integrations import load_integrations
# Built-in payload defaults for known presets so users don't have
# to configure a template just to use a standard service.
_PRESET_TEMPLATE_DEFAULTS = {
"discord_webhook": '{"embeds": [{"title": "{{title}}", "description": "{{message}}", "color": 5793266}]}',
}
intg_id = settings.get("reminder_webhook_integration_id", "").strip()
template = settings.get("reminder_webhook_payload_template", "").strip()
if not intg_id:
webhook_error = "No webhook integration selected"
else:
intg = next(
(i for i in load_integrations()
if i.get("id") == intg_id and i.get("base_url")),
None,
)
if not intg:
webhook_error = f"Integration {intg_id!r} not found or missing base URL"
else:
# Fall back to a built-in default for known presets so
# users don't have to configure a template for standard
# services like Discord.
if not template:
template = _PRESET_TEMPLATE_DEFAULTS.get(intg.get("preset", ""), "")
if not template:
webhook_error = "No payload template configured"
else:
# Render template: JSON-escape the values so the result
# is always valid JSON regardless of special characters.
# dumps() returns `"value"` — strip outer quotes.
msg = (synthesis or note_body or title or "Reminder")[:4000]
_t = _wjson.dumps(title or "Reminder")[1:-1]
_m = _wjson.dumps(msg)[1:-1]
rendered = template.replace("{{title}}", _t).replace("{{message}}", _m)
hdrs = {"Content-Type": "application/json"}
api_key = intg.get("api_key", "")
auth_type = (intg.get("auth_type") or "none").lower()
if api_key:
if auth_type == "bearer":
hdrs["Authorization"] = f"Bearer {api_key}"
elif auth_type == "header":
hdrs[intg.get("auth_header") or "Authorization"] = api_key
url = intg["base_url"].rstrip("/")
# SSRF guard — matches the pattern used by webhook_routes,
# CalDAV, search, and embeddings. Blocks link-local / metadata
# addresses (169.254.x.x) by default; set
# REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS=true to also block
# RFC-1918 ranges for locked-down deployments.
import os as _os
from src.url_safety import check_outbound_url as _chk
_block = _os.getenv("REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS", "false").lower() == "true"
_ok, _reason = _chk(url, block_private=_block)
if not _ok:
webhook_error = f"Webhook URL rejected: {_reason}"
else:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(url, content=rendered.encode(), headers=hdrs)
webhook_sent = resp.is_success
if not webhook_sent:
webhook_error = f"Webhook returned HTTP {resp.status_code}"
except Exception as e:
webhook_error = str(e) or e.__class__.__name__
logger.warning(f"Reminder webhook send failed: {e}")
ntfy_sent = False
ntfy_error = ""
if channel == "ntfy":
@@ -415,7 +514,7 @@ async def dispatch_reminder(
# second send for the same note within 25 min. Without this, a note
# whose due_date fires while the user has the app open got TWO emails
# (frontend-fired here + background-fired by ping_notes 0-5 min later).
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
if (email_sent or ntfy_sent or webhook_sent or browser_sent or local_browser_sent) and note_id:
try:
import json as _json
from datetime import datetime as _dt, timezone as _tz
@@ -425,13 +524,13 @@ async def dispatch_reminder(
_STATE = cache_path
if _STATE is None:
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
_STATE = _P(f"data/note_pings_{_slug}.json")
_STATE = _P(DATA_DIR) / f"note_pings_{_slug}.json"
_STATE.parent.mkdir(parents=True, exist_ok=True)
try:
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
except Exception:
_cache = {}
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "webhook" if webhook_sent else "browser"
_cache[cache_key or str(note_id)] = {
"at": _dt.now(_tz.utc).isoformat(),
"channel": sent_channel,
@@ -441,11 +540,14 @@ async def dispatch_reminder(
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
return {
"channel": channel,
"synthesis": synthesis,
"email_sent": email_sent,
"email_error": email_error,
"ntfy_sent": ntfy_sent,
"ntfy_error": ntfy_error,
"webhook_sent": webhook_sent,
"webhook_error": webhook_error,
"browser_sent": browser_sent or local_browser_sent,
}
@@ -467,6 +569,23 @@ def setup_note_routes(task_scheduler=None):
def _owner(request: Request) -> Optional[str]:
return get_current_user(request)
def _is_admin_or_single_user(request: Request, user: str | None) -> bool:
if user == "internal-tool":
return True
if not user:
# require_user() already admitted this request, which only happens
# for auth-disabled, loopback-bypass, or unconfigured single-user
# modes. There is no separate non-admin account boundary there.
return True
try:
from core.auth import AuthManager
auth_mgr = getattr(request.app.state, "auth_manager", None) or AuthManager()
if not getattr(auth_mgr, "is_configured", True):
return True
return bool(auth_mgr.is_admin(user))
except Exception:
return False
# --- LIST ---
@router.get("")
def list_notes(
@@ -684,20 +803,46 @@ def setup_note_routes(task_scheduler=None):
"""
# Gate against anonymous callers — LLM synthesis can burn tokens.
from src.auth_helpers import require_user as _ru
_ru(request)
user = _ru(request)
body = await request.json()
note_id = body.get("note_id")
title = (body.get("title") or "").strip()
note_body = (body.get("body") or "").strip()
note_id = str(body.get("note_id") or "").strip()
if not note_id:
raise HTTPException(400, "note_id required")
# Delegate to the module-level helper so background tasks can reuse
# the same dispatch without an HTTP roundtrip + auth cookie.
caller = _owner(request)
is_test = note_id.startswith("test-")
is_admin = _is_admin_or_single_user(request, user or caller)
_override: dict = {}
if is_test:
if not is_admin:
raise HTTPException(403, "Admin only")
title = (body.get("title") or "Test Reminder").strip() or "Test Reminder"
note_body = (body.get("body") or "").strip()
# Optional overrides let the admin settings test button pass the
# current UI values directly so it never races a pending save.
if body.get("channel"):
_override["reminder_channel"] = body["channel"]
if body.get("webhook_integration_id"):
_override["reminder_webhook_integration_id"] = body["webhook_integration_id"]
if body.get("webhook_payload_template"):
_override["reminder_webhook_payload_template"] = body["webhook_payload_template"]
else:
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
if caller is not None and note.owner != caller:
raise HTTPException(404, "Note not found")
title, note_body = _reminder_text_from_note(note)
finally:
db.close()
return await dispatch_reminder(
title=title, note_body=note_body, note_id=note_id,
owner=_owner(request) or "",
owner=caller or "",
queue_browser=False,
settings_override=_override or None,
)
# --- REORDER NOTES ---
+13 -12
View File
@@ -6,16 +6,14 @@ import uuid
from typing import List, Tuple
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
from src.request_models import DirectoryRequest
from core.constants import BASE_DIR, PERSONAL_DIR
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
from src.rag_singleton import get_rag_manager
from src.auth_helpers import get_current_user, require_user
from src.auth_helpers import require_privilege, require_user
from core.middleware import require_admin
from src.upload_handler import secure_filename
from src.upload_limits import PERSONAL_UPLOAD_MAX_BYTES
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
MAX_PERSONAL_UPLOAD_BYTES = int(
os.getenv("ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))
)
UPLOADS_DIR = PERSONAL_UPLOADS_DIR
logger = logging.getLogger(__name__)
@@ -194,7 +192,7 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
@router.post("/upload")
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
"""Upload files directly into RAG. Supports text and PDF."""
user = get_current_user(request)
user = require_privilege(request, "can_use_documents")
rag = _rag()
if not rag:
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
@@ -208,8 +206,8 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
for upload in files:
try:
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
content_bytes = await upload.read(MAX_PERSONAL_UPLOAD_BYTES + 1)
if len(content_bytes) > MAX_PERSONAL_UPLOAD_BYTES:
content_bytes = await upload.read(PERSONAL_UPLOAD_MAX_BYTES + 1)
if len(content_bytes) > PERSONAL_UPLOAD_MAX_BYTES:
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
total_failed += 1
continue
@@ -286,9 +284,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
except ValueError:
# commonpath raises on mixed drives / non-comparable paths
in_uploads = False
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
os.remove(abs_target)
deleted_from_disk = True
if in_uploads and abs_target != base_abs:
try:
os.remove(abs_target)
deleted_from_disk = True
except FileNotFoundError:
pass # already gone — race with another request or cleanup
# Exclude the file from the listing (persists across restarts)
personal_docs_manager.exclude_file(filepath)
+2 -1
View File
@@ -4,8 +4,9 @@ import os
from typing import Optional
from fastapi import APIRouter, Request
from src.auth_helpers import get_current_user
from src.constants import USER_PREFS_FILE
PREFS_FILE = os.path.join("data", "user_prefs.json")
PREFS_FILE = USER_PREFS_FILE
def _load():
+3 -1
View File
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
from src.request_models import PresetUpdateRequest
from core.middleware import require_admin
from src.auth_helpers import effective_user
logger = logging.getLogger(__name__)
@@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter:
try:
model_spec = data.get("model") or ""
url, model, headers = _resolve_model(model_spec)
user = effective_user(request)
url, model, headers = _resolve_model(model_spec, owner=user)
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
return {"success": True, "prompt": result.strip()}
except Exception as e:
+61 -46
View File
@@ -14,6 +14,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from src.endpoint_resolver import resolve_endpoint
from src.auth_helpers import _auth_disabled, get_current_user
from src.constants import DEEP_RESEARCH_DIR
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
@@ -37,13 +38,15 @@ def _first_chat_model(models) -> str:
return (models[0] if models else "")
def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple:
    """Return (endpoint_url, model, headers) for Deep Research, checking admin overrides.

    Args:
        sess: Session object supplying ``endpoint_url``, ``model`` and
            ``headers`` as fallbacks, and optionally an ``owner`` attribute.
        owner: User to resolve the endpoint for. When falsy, the session's
            ``owner`` attribute is used; if that is also falsy, ``None`` is
            passed through.

    Returns:
        The ``(url, model, headers)`` triple produced by ``resolve_endpoint``
        for the ``"research"`` role.
    """
    # Explicit owner wins; otherwise fall back to the session's owner, and
    # collapse any falsy value (e.g. "") to None.
    owner = owner or getattr(sess, "owner", None) or None
    url, model, headers = resolve_endpoint(
        "research",
        fallback_url=sess.endpoint_url,
        fallback_model=sess.model,
        fallback_headers=sess.headers,
        owner=owner,
    )
    return url, model, headers
@@ -72,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
return owner_filter(q, ModelEndpoint, owner).first()
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
    """Turn a ModelEndpoint row into ``(chat_url, model, headers)``.

    Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
    panel-selected research endpoints. ChatGPT Subscription endpoints keep
    OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.

    Args:
        ep: Endpoint row; ``cached_models`` is read when no model is given.
        owner: Passed through to the credential resolver.
        model: Explicit model ID. When blank, the first chat model from the
            endpoint's cached model list is used instead.

    Returns:
        ``(chat_url, model, headers)``, or ``None`` when credentials cannot
        be resolved or no model can be determined.
    """
    from src.endpoint_resolver import build_chat_url, build_headers
    from src.endpoint_resolver import resolve_endpoint_runtime as resolve_model_endpoint_runtime

    try:
        base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
    except Exception as exc:
        logger.warning("Could not resolve endpoint credentials for research: %s", exc)
        return None

    chosen_model = (model or "").strip()
    if not chosen_model:
        # Best effort: a malformed cache just means no model was found.
        try:
            cached = json.loads(ep.cached_models) if ep.cached_models else []
            if cached:
                chosen_model = _first_chat_model(cached)
        except Exception:
            pass

    if chosen_model:
        return build_chat_url(base), chosen_model, build_headers(api_key, base)
    return None
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
router = APIRouter(tags=["research"])
@@ -98,7 +133,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
if entry is not None:
return entry.get("owner", "") == user
# Task no longer in memory — check the persisted JSON.
path = Path("data/deep_research") / f"{session_id}.json"
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if not path.exists():
return False
try:
@@ -162,7 +197,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
def _assert_owns_research(session_id: str, user: str) -> None:
"""404-not-403 ownership gate for a research session's on-disk JSON.
Use BEFORE returning any data or mutating the file."""
path = Path("data/deep_research") / f"{session_id}.json"
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
@@ -225,7 +260,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
):
user = _require_user(request)
"""List all completed research for the Library panel."""
data_dir = Path("data/deep_research")
data_dir = Path(DEEP_RESEARCH_DIR)
items = []
for p in data_dir.glob("*.json"):
try:
@@ -275,7 +310,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
summary, stats used by the Library preview panel."""
user = _require_user(request)
_validate_session_id(session_id)
path = Path("data/deep_research") / f"{session_id}.json"
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
@@ -292,7 +327,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
user = _require_user(request)
_validate_session_id(session_id)
path = Path("data/deep_research") / f"{session_id}.json"
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
@@ -312,7 +347,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
"""Delete a research result from disk."""
user = _require_user(request)
_validate_session_id(session_id)
data_dir = Path("data/deep_research")
data_dir = Path(DEEP_RESEARCH_DIR)
json_path = data_dir / f"{session_id}.json"
deleted = False
if json_path.exists():
@@ -368,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
if body.endpoint_id:
from src.database import SessionLocal
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
# Owner-scoped: never resolve another user's private endpoint
@@ -377,35 +411,26 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
if not ep:
raise HTTPException(404, "Endpoint not found or disabled")
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = body.model or ""
if not ep_model:
try:
import json as _json
models = _json.loads(ep.cached_models) if ep.cached_models else []
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
if not resolved:
raise HTTPException(400, "Endpoint is not configured with a usable model.")
ep_url, ep_model, ep_headers = resolved
finally:
db.close()
else:
ep_url, ep_model, ep_headers = resolve_endpoint("research")
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
# When neither research nor utility is configured, use the user's
# configured DEFAULT model (default_endpoint_id/default_model) rather
# than arbitrarily grabbing the first enabled endpoint's first model
# (which surfaced gpt-3.5). "Default" should mean the default model.
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("default")
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
if not ep_url:
from src.database import SessionLocal
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
# Owner-scoped first-enabled fallback: the caller's own rows
@@ -414,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
ep = _owned_enabled_endpoint(db, user)
if ep:
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = ""
if ep.cached_models:
try:
import json as _json
models = _json.loads(ep.cached_models)
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
resolved = _resolve_endpoint_runtime(ep, owner=user)
if resolved:
ep_url, ep_model, ep_headers = resolved
finally:
db.close()
if not ep_url:
@@ -494,7 +510,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
raise HTTPException(404, "No research found for this session")
result = research_handler.get_result(session_id)
if result is None:
p = Path("data/deep_research") / f"{session_id}.json"
p = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if p.exists():
d = json.loads(p.read_text(encoding="utf-8"))
return {
@@ -534,7 +550,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
sources = research_handler.get_sources(session_id) or []
query = ""
path = Path("data/deep_research") / f"{session_id}.json"
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
if path.exists():
try:
disk = json.loads(path.read_text(encoding="utf-8"))
@@ -572,19 +588,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
ep_headers = dict(r_headers)
if not ep_url or not ep_model:
_merge(*resolve_endpoint("chat"))
_merge(*resolve_endpoint("chat", owner=user))
if not ep_url or not ep_model:
_merge(*resolve_endpoint("research"))
_merge(*resolve_endpoint("research", owner=user))
if not ep_url or not ep_model:
_merge(*resolve_endpoint("utility"))
_merge(*resolve_endpoint("utility", owner=user))
if not ep_url or not ep_model:
# Last resort: any enabled endpoint
# Last resort: this user's enabled endpoint, plus legacy shared rows.
from src.database import SessionLocal
from src.database import ModelEndpoint
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
ep = _owned_enabled_endpoint(db, user)
if ep:
base = normalize_base(ep.base_url)
fallback_url = build_chat_url(base)
@@ -594,7 +609,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
try:
models = json.loads(ep.cached_models)
if models:
fallback_model = models[0]
fallback_model = _first_chat_model(models)
except Exception:
pass
_merge(fallback_url, fallback_model, fallback_headers)
+48 -33
View File
@@ -10,8 +10,9 @@ import logging
from core.session_manager import SessionManager
from core.models import ChatMessage
from src.request_models import SessionResponse
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
from src.auth_helpers import get_current_user, effective_user
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
from src.session_actions import is_session_recently_active
def _sanitize_export_filename(name: str) -> str:
@@ -92,35 +93,30 @@ def _reject_compact_during_active_run(session_id: str) -> None:
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
"""Verify the current user owns the session. Raises 404 if not.
"""Verify the current user owns the session, honoring single-user modes.
Ownership is checked against the DB row when one exists (unchanged). If
there is no DB row but the caller owns an in-memory "ghost" session one
that lives only in ``session_manager`` because it was never persisted, or
its DB row was removed out-of-band fall back to the in-memory owner so the
user can still manage and delete it. Without this fallback such sessions are
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
per-session operation 404s, making them impossible to delete (issue #1044).
``session_manager`` is optional and defaults to ``None`` so existing callers
that only care about persisted sessions keep their exact prior behavior.
Authenticated requests must match the stored DB or in-memory owner. When
auth is disabled and no user is present, treat the app as single-user mode:
verify that the session exists, but do not compare its stored owner. This
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
rows created while auth was previously enabled.
"""
user = effective_user(request)
if not user:
raise HTTPException(403, "Authentication required")
if not user and not _auth_disabled():
raise HTTPException(401, "Authentication required")
db = SessionLocal()
try:
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
finally:
db.close()
if row is not None:
if row.owner != user:
if user and row.owner != user:
raise HTTPException(404, f"Session {session_id} not found")
return
# No DB row — allow the caller to act on an in-memory ghost they own.
if session_manager is not None:
ghost = getattr(session_manager, "sessions", {}).get(session_id)
if ghost is not None and getattr(ghost, "owner", None) == user:
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
return
raise HTTPException(404, f"Session {session_id} not found")
@@ -262,7 +258,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
last_msg_map = {}
mode_map = {}
msg_count_map = {}
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False).all()
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
for row in rows:
folder_map[row.id] = row.folder
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
@@ -284,12 +280,14 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
r[0] for r in db.query(Document.session_id)
.filter(Document.is_active == True,
Document.current_content != None,
func.trim(Document.current_content) != "")
func.trim(Document.current_content) != "",
Document.owner == user)
.distinct().all()
)
img_session_ids = set(
r[0] for r in db.query(GalleryImage.session_id)
.filter(GalleryImage.session_id != None)
.filter(GalleryImage.session_id != None,
GalleryImage.owner == user)
.distinct().all()
)
finally:
@@ -370,8 +368,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
pass
elif not model_to_use:
from src.llm_core import list_model_ids
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers=validation_headers)
ids = list_model_ids(
endpoint_url,
timeout=REQUEST_TIMEOUT,
headers=validation_headers,
owner=user,
endpoint_id=endpoint_id.strip() if endpoint_id else None,
)
if not ids:
raise HTTPException(400, "Cannot reach /v1/models")
# Default to the first CHAT model — endpoints often list embedding/
@@ -385,8 +388,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
from src.llm_core import list_model_ids
import os as _os
req_base = _os.path.basename(model_to_use.rstrip("/"))
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers=validation_headers)
avail = list_model_ids(
endpoint_url,
timeout=REQUEST_TIMEOUT,
headers=validation_headers,
owner=user,
endpoint_id=endpoint_id.strip() if endpoint_id else None,
)
if not avail:
raise HTTPException(400, "Cannot reach /v1/models")
if model_to_use not in avail:
@@ -543,22 +551,25 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
ids = body.get("ids", [])
except Exception:
ids = []
deleted_count = 0
for sid in ids:
try:
_verify_session_owner(request, sid, session_manager)
session_manager.delete_session(sid)
# Enforce "starred" protection consistent with single-session delete
db = SessionLocal()
try:
db.query(_CM).filter(_CM.session_id == sid).delete()
db.query(DbSession).filter(DbSession.id == sid).delete()
db.commit()
except Exception:
db.rollback()
db_sess = db.query(DbSession).filter(DbSession.id == sid).first()
if db_sess and db_sess.is_important:
continue
finally:
db.close()
if session_manager.delete_session(sid):
deleted_count += 1
except Exception:
pass
return {"deleted": len(ids)}
return {"deleted": deleted_count}
@router.delete("/session/{sid}")
def delete_session(request: Request, sid: str):
@@ -924,7 +935,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
owner = getattr(session, "owner", None) or effective_user(request)
url, model, headers = resolve_endpoint("utility", owner=owner)
if not url or not model:
url, model, headers = session.endpoint_url, session.model, session.headers
if not url or not model:
@@ -1006,7 +1018,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
}
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
try:
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).all()
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
folder_map = {r.id: r.folder for r in rows}
# Precompute per-session message counts in TWO aggregate queries
# instead of 13 queries PER session — with many chats the per-row
@@ -1017,6 +1029,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
)
cleanup_now = utcnow_naive()
for row in rows:
# Never delete important sessions
if getattr(row, 'is_important', False):
@@ -1029,6 +1042,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
if hasattr(session_manager, 'delete_session'):
session_manager.delete_session(row.id)
continue
if is_session_recently_active(row, now=cleanup_now):
continue
msg_count = _counts.get(row.id, 0)
should_delete = False
if msg_count == 0:
+279 -58
View File
@@ -13,6 +13,7 @@ import tempfile
from collections import namedtuple
from pathlib import Path
from typing import Dict, Any
from core.platform_compat import IS_APPLE_SILICON, which_tool
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
# on Windows, so importing them unconditionally crashed app startup there
@@ -37,6 +38,7 @@ from core.platform_compat import (
IS_WINDOWS,
detached_popen_kwargs,
find_bash,
git_bash_path,
)
@@ -92,6 +94,7 @@ def _venv_activate_prefix(venv: str | None) -> str:
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
return f". {act} && "
logger = logging.getLogger(__name__)
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
@@ -169,7 +172,10 @@ def _package_installed_from_probe(name: str, probe: dict) -> bool:
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
)
if name == "hf_transfer":
return bool(dists.get("hf-transfer") or modules.get("hf_transfer", {}).get("real_module"))
return bool(
dists.get("hf-transfer")
or modules.get("hf_transfer", {}).get("real_module")
)
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
@@ -194,8 +200,14 @@ def _package_status_note(name: str, probe: dict) -> str:
if binaries.get("llama-server"):
parts.append(f"native llama-server: {binaries['llama-server']}")
if dists.get("llama-cpp-python"):
parts.append(f"python package: llama-cpp-python {dists['llama-cpp-python']}")
return "; ".join(parts) if parts else "No native llama-server or llama-cpp-python server package found."
parts.append(
f"python package: llama-cpp-python {dists['llama-cpp-python']}"
)
return (
"; ".join(parts)
if parts
else "No native llama-server or llama-cpp-python server package found."
)
if name == "diffusers":
if _package_installed_from_probe(name, probe):
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
@@ -205,7 +217,9 @@ def _package_status_note(name: str, probe: dict) -> str:
return ""
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
def _package_pip_update_status(
pkg: dict, probe: dict | None = None
) -> PackageUpdateStatus:
"""Return whether the Dependencies UI should offer a generic pip update.
"Installed" means Cookbook can use the dependency. It does not always mean
@@ -213,12 +227,28 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
native llama-server can come from a package manager/source build, and a CLI
may be on PATH without matching Python package metadata.
"""
if pkg.get("name") == "APFEL":
return PackageUpdateStatus(
False,
"", # Note is empty because IT DOES allow for updates outside of PIP.
)
if pkg.get("kind") == "system" or not pkg.get("pip"):
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
return PackageUpdateStatus(
False, "Update this system dependency outside Odysseus."
)
name = pkg.get("name")
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
binaries = (
probe.get("binaries")
if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict)
else {}
)
dists = (
probe.get("dists")
if isinstance(probe, dict) and isinstance(probe.get("dists"), dict)
else {}
)
if name == "llama_cpp" and binaries.get("llama-server"):
return PackageUpdateStatus(
@@ -231,7 +261,9 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
)
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
return PackageUpdateStatus(
True, "Update uses pip in the selected Python environment."
)
def _prepend_user_install_bins_to_path() -> None:
@@ -250,7 +282,9 @@ def _prepend_user_install_bins_to_path() -> None:
candidates = []
candidates.append(os.path.expanduser("~/.local/bin"))
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
parts = (
os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
)
changed = False
for path in reversed([p for p in candidates if p]):
if path not in parts:
@@ -357,9 +391,11 @@ PTY_UNSUPPORTED_ERROR = "pty_unsupported"
class ShellExecRequest(BaseModel):
command: str
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
use_pty: bool = False # use pseudo-TTY (for progress bars)
use_tmux: bool = False # run in tmux session (survives browser disconnect)
timeout: int | None = (
None # optional override; 0 = no timeout (run until client disconnects)
)
use_pty: bool = False # use pseudo-TTY (for progress bars)
use_tmux: bool = False # run in tmux session (survives browser disconnect)
async def _create_shell(command: str, **kwargs):
@@ -368,8 +404,16 @@ async def _create_shell(command: str, **kwargs):
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
the same as on Linux; fall back to cmd.exe when no bash is installed.
PowerShell commands (and explicit ``cmd`` invocations) are executed directly
via cmd.exe /c, bypassing bash, to avoid quoting and env variable expansion
errors under Git Bash.
"""
if IS_WINDOWS:
# PowerShell commands (used by the frontend for Windows log-file polling
# and session management) must run directly — passing them through
# bash -c mangles $env:VAR syntax and breaks the command.
cmd_trim = command.strip()
if cmd_trim.startswith("powershell") or cmd_trim.startswith("cmd "):
return await asyncio.create_subprocess_shell(command, **kwargs)
bash = find_bash()
if bash:
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
@@ -386,9 +430,7 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()),
)
stdout_b, stderr_b = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
@@ -399,7 +441,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
await proc.wait()
except ProcessLookupError:
pass
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
return {
"stdout": "",
"stderr": f"Command timed out after {timeout}s",
"exit_code": -1,
}
except Exception as e:
return {"stdout": "", "stderr": str(e), "exit_code": -1}
@@ -481,7 +527,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
buf = buf[idx + sep_len :]
if line:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
@@ -503,7 +549,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
buf = buf[idx + sep_len :]
if line:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
if buf:
@@ -534,6 +580,7 @@ def _pty_read(fd: int) -> bytes | None:
"""Blocking read from PTY fd. Called via run_in_executor.
Returns bytes on data, None on timeout (no data yet)."""
import select
r, _, _ = select.select([fd], [], [], 1.0)
if r:
try:
@@ -557,10 +604,10 @@ async def _generate_tmux(cmd: str, request: Request):
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
script_path.write_text(
f"#!/bin/bash\n"
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
f'ODYSSEUS_USER_SHELL="${{SHELL:-}}"\n'
f'if [ -n "$ODYSSEUS_USER_SHELL" ] && [ -x "$ODYSSEUS_USER_SHELL" ]; then\n'
f' ODYSSEUS_USER_PATH="$("$ODYSSEUS_USER_SHELL" -ic \'printf "__ODYSSEUS_PATH__%s\\n" "$PATH"\' 2>/dev/null | sed -n \'s/^__ODYSSEUS_PATH__//p\' | tail -n 1 || true)"\n'
f' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi\n'
f"fi\n"
f"{cmd} 2>&1 | tee '{log_path}'\n"
f"EC=${{PIPESTATUS[0]}}\n"
@@ -570,7 +617,9 @@ async def _generate_tmux(cmd: str, request: Request):
encoding="utf-8",
)
script_path.chmod(0o755)
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
logger.info(
"tmux wrapper script created: session=%s path=%s", session_id, script_path
)
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
@@ -602,7 +651,9 @@ async def _generate_tmux(cmd: str, request: Request):
# Read new lines from log
try:
if log_path.exists():
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
lines = log_path.read_text(
encoding="utf-8", errors="replace"
).splitlines()
new_lines = lines[lines_sent:]
for line in new_lines:
if line.startswith(":::EXIT_CODE:::"):
@@ -630,7 +681,9 @@ async def _generate_tmux(cmd: str, request: Request):
# Session ended — do one final read
await asyncio.sleep(0.5)
if log_path.exists():
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
lines = log_path.read_text(
encoding="utf-8", errors="replace"
).splitlines()
for line in lines[lines_sent:]:
if line.startswith(":::EXIT_CODE:::"):
try:
@@ -672,8 +725,8 @@ async def _generate_win_detached(cmd: str, request: Request):
if bash:
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
script_path.write_text(
f"{cmd} > {shlex.quote(str(log_path))} 2>&1\n"
f"echo $? > {shlex.quote(str(exit_path))}\n",
f"{cmd} > {shlex.quote(git_bash_path(log_path))} 2>&1\n"
f"echo $? > {shlex.quote(git_bash_path(exit_path))}\n",
encoding="utf-8",
)
argv = [bash, str(script_path)]
@@ -711,7 +764,9 @@ async def _generate_win_detached(cmd: str, request: Request):
return
try:
if log_path.exists():
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
lines = log_path.read_text(
encoding="utf-8", errors="replace"
).splitlines()
for line in lines[lines_sent:]:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
lines_sent = len(lines)
@@ -723,11 +778,18 @@ async def _generate_win_detached(cmd: str, request: Request):
await asyncio.sleep(0.3)
try:
if log_path.exists():
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
lines = log_path.read_text(
encoding="utf-8", errors="replace"
).splitlines()
for line in lines[lines_sent:]:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
lines_sent = len(lines)
exit_code = int((exit_path.read_text(encoding="utf-8", errors="replace").strip() or "0"))
exit_code = int(
(
exit_path.read_text(encoding="utf-8", errors="replace").strip()
or "0"
)
)
except Exception:
exit_code = 0
break
@@ -753,7 +815,9 @@ def setup_shell_routes() -> APIRouter:
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
logger.info("User shell exec requested: length=%d", len(cmd))
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
result = await _exec_shell(
cmd, timeout=req.timeout if req.timeout is not None else EXEC_TIMEOUT
)
return result
@router.post("/api/shell/stream")
@@ -762,9 +826,11 @@ def setup_shell_routes() -> APIRouter:
_require_admin(request)
cmd = req.command.strip()
if not cmd:
async def empty():
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
return StreamingResponse(empty(), media_type="text/event-stream")
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
@@ -781,7 +847,11 @@ def setup_shell_routes() -> APIRouter:
if use_tmux:
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
# that preserves the "survives disconnect" behaviour.
gen = _generate_win_detached(cmd, request) if IS_WINDOWS else _generate_tmux(cmd, request)
gen = (
_generate_win_detached(cmd, request)
if IS_WINDOWS
else _generate_tmux(cmd, request)
)
return StreamingResponse(gen, media_type="text/event-stream")
if use_pty and not IS_WINDOWS:
@@ -813,7 +883,12 @@ def setup_shell_routes() -> APIRouter:
chunk = await stream.read(4096)
if not chunk:
if buf:
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
await q.put(
(
name,
buf.decode(errors="replace").rstrip("\r\n"),
)
)
break
buf += chunk
while True:
@@ -821,7 +896,7 @@ def setup_shell_routes() -> APIRouter:
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
buf = buf[idx + sep_len :]
if line:
await q.put((name, line))
finally:
@@ -880,7 +955,12 @@ def setup_shell_routes() -> APIRouter:
return StreamingResponse(generate(), media_type="text/event-stream")
@router.get("/api/cookbook/packages")
async def list_packages(request: Request, host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
async def list_packages(
request: Request,
host: str | None = None,
ssh_port: str | None = None,
venv: str | None = None,
):
"""Check which optional packages are installed.
Local-target packages are checked in-process. Remote-target packages
@@ -890,7 +970,13 @@ def setup_shell_routes() -> APIRouter:
"""
_require_admin(request)
_reject_cross_site(request)
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
import importlib
import importlib.metadata as importlib_metadata
import shlex
import json as _json
import site
import sys
_prepend_user_install_bins_to_path()
importlib.invalidate_caches()
try:
@@ -905,26 +991,115 @@ def setup_shell_routes() -> APIRouter:
raise HTTPException(400, "Invalid ssh_port")
packages = [
# ── System ── OS binaries, not pip packages
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
{
"name": "tmux",
"pip": "",
"desc": "Required for Linux/Termux Cookbook background downloads and serves",
"category": "System",
"target": "remote",
"kind": "system",
"install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper.",
},
{
"name": "docker",
"pip": "",
"desc": "Required only for Docker-backed launch commands",
"category": "System",
"target": "remote",
"kind": "system",
"install_hint": "Install Docker on the selected server and allow this user to run docker.",
},
# ── LLM ── installs on GPU servers for model serving/downloading
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
{
"name": "hf_transfer",
"pip": "hf_transfer",
"desc": "Fast model downloads from HuggingFace",
"category": "LLM",
"target": "remote",
},
{
"name": "llama_cpp",
"pip": "llama-cpp-python[server]",
"desc": "Serve GGUF models via llama.cpp",
"category": "LLM",
"target": "remote",
},
{
"name": "sglang",
"pip": "sglang[all]",
"desc": "Serve HF safetensors models via SGLang",
"category": "LLM",
"target": "remote",
},
{
"name": "vllm",
"pip": "vllm",
"desc": "High-throughput LLM serving engine",
"category": "LLM",
"target": "remote",
},
{
"name": "APFEL",
"pip": "",
"desc": "OpenAI-compatible API for Apple Foundational Models on Apple Silicon",
"category": "LLM",
"target": "local",
"kind": "system",
"install_cmd": "brew install apfel",
"update_cmd": "brew upgrade apfel",
"install_hint": "Requires a native Apple Silicon Mac with Apple Foundational Models support. Installable via Homebrew on supported Macs.",
},
# ── Image ── editor + diffusion model serving
{"name": "diffusers", "pip": "diffusers[torch]", "desc": "Image generation pipelines (SD, Flux) with PyTorch", "category": "Image", "target": "remote"},
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
{
"name": "diffusers",
"pip": "diffusers[torch]",
"desc": "Image generation pipelines (SD, Flux) with PyTorch",
"category": "Image",
"target": "remote",
},
{
"name": "rembg",
"pip": "rembg[gpu]",
"desc": "AI background removal for image editor",
"category": "Image",
"target": "local",
},
{
"name": "realesrgan",
"pip": "realesrgan",
"desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.",
"category": "Image",
"target": "local",
},
# ── Tools ──
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
{
"name": "playwright",
"pip": "playwright",
"desc": "Browser automation for web tools",
"category": "Tools",
"target": "local",
},
]
# Most packages should not be installed through external means, so install_cmd
# and update_cmd default to None, which indicates that the recommended way to
# install/update is through the Cookbook server setup or pip. Only system
# packages should provide explicit install/update commands.
for pkg in packages:
pkg.setdefault("install_cmd", None)
pkg.setdefault("update_cmd", None)
# Remote check: for remote-target packages, probe the selected server's
# venv over SSH so a remote `pip install` actually reflects here.
remote_status: dict = {}
remote_details: dict = {}
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
remote_names = [
p["name"]
for p in packages
if p.get("target") == "remote" and p.get("kind") != "system"
]
remote_system_names = [
p["name"]
for p in packages
if p.get("target") == "remote" and p.get("kind") == "system"
]
if host and remote_names:
try:
py = _package_probe_script(remote_names)
@@ -934,7 +1109,9 @@ def setup_shell_routes() -> APIRouter:
inner = f"{src}python3 -c {shlex.quote(py)}"
argv = _ssh_base_argv(host, ssh_port) + [inner]
proc = await asyncio.create_subprocess_exec(
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
*argv,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
@@ -958,11 +1135,15 @@ def setup_shell_routes() -> APIRouter:
checks = []
for name in remote_system_names:
qn = shlex.quote(name)
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
checks.append(
f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi"
)
inner = " ; ".join(checks)
argv = _ssh_base_argv(host, ssh_port) + [inner]
proc = await asyncio.create_subprocess_exec(
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
*argv,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
@@ -987,11 +1168,25 @@ def setup_shell_routes() -> APIRouter:
if note:
pkg["status_note"] = note
elif pkg.get("kind") == "system":
pkg["installed"] = shutil.which(pkg["name"]) is not None
if pkg["name"] == "APFEL":
pkg["applicable"] = IS_APPLE_SILICON
pkg["installed"] = which_tool("apfel") is not None
pkg["status_note"] = (
"Available on Apple Silicon (arm64) devices; exposed through a local OpenAI-compatible API."
if IS_APPLE_SILICON
else "Requires a native Apple Silicon Mac with Apple Foundational Models support."
)
else:
pkg["installed"] = shutil.which(pkg["name"]) is not None
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
pkg["installed"] = True
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
pkg["status_note"] = (
f"native llama-server: {shutil.which('llama-server')}"
)
probe = {
"binaries": {"llama-server": shutil.which("llama-server")},
"dists": {},
}
elif pkg["name"] == "vllm":
_vllm_cli = shutil.which("vllm")
pkg["installed"] = _vllm_cli is not None
@@ -1014,6 +1209,12 @@ def setup_shell_routes() -> APIRouter:
pkg["installed"] = False
except importlib_metadata.PackageNotFoundError:
pkg["installed"] = False
except Exception:
# Installed but crashes on import — e.g. a CUDA build of
# llama-cpp-python raising FileNotFoundError when the CUDA
# toolkit dir is absent. One broken optional package must not
# 500 the entire packages panel; report it as not usable.
pkg["installed"] = False
if pkg.get("installed"):
update_status = _package_pip_update_status(pkg, probe)
@@ -1037,15 +1238,30 @@ def setup_shell_routes() -> APIRouter:
"""Install a package via pip. Admin only — pip install is effectively code exec."""
_require_admin(request)
import sys as _sys
body = await request.json()
pip_name = body.get("pip")
if not pip_name:
return {"ok": False, "error": "No package specified"}
# Validate against known packages to prevent arbitrary pip install
known = {
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers", "diffusers[torch]",
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan", "vllm",
"rembg[gpu]",
"hf_transfer",
"llama-cpp-python[server]",
"sglang[all]",
"diffusers",
"diffusers[torch]",
"TTS",
"bark",
"faster-whisper",
"playwright",
"realesrgan",
"gfpgan",
"insightface",
"onnxruntime-gpu",
"onnxruntime",
"hdbscan",
"vllm",
}
if pip_name not in known:
return {"ok": False, "error": f"Unknown package: {pip_name}"}
@@ -1071,6 +1287,7 @@ def setup_shell_routes() -> APIRouter:
"""
_require_admin(request)
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
body = await request.json()
engine = str(body.get("engine") or "llamacpp").strip()
if engine != "llamacpp":
@@ -1079,7 +1296,11 @@ def setup_shell_routes() -> APIRouter:
ssh_port = body.get("ssh_port")
cmd = _llama_cpp_rebuild_cmd()
try:
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
argv = (
(_ssh_base_argv(host, ssh_port) + [cmd])
if host
else ["bash", "-lc", cmd]
)
except ValueError as e:
raise HTTPException(400, str(e))
try:
+44 -16
View File
@@ -21,10 +21,44 @@ from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
_DATA_URL_RE = re.compile(
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
re.IGNORECASE | re.DOTALL,
)
_DATA_URL_RE = re.compile(r"^data:image/png;base64,(?P<data>.+)$", re.IGNORECASE | re.DOTALL)
_ANY_IMAGE_DATA_URL_RE = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
_MAX_SIGNATURE_BYTES = 2 * 1024 * 1024
_MAX_SIGNATURE_B64 = ((_MAX_SIGNATURE_BYTES + 2) // 3) * 4
_MAX_SIGNATURE_DIMENSION = 4096
def _normalize_signature_png(raw: str) -> str:
    """Validate client-supplied signature data and return canonical base64.

    The input may be a data URL matched by ``_DATA_URL_RE`` (its ``data``
    group is used) or a bare base64 string. Any other
    ``data:image/...;base64,`` URL is rejected as a non-PNG image.

    Args:
        raw: Signature data as sent by the client. A falsy value is treated
            as an empty string.

    Returns:
        The decoded payload re-encoded as an ASCII base64 string, without
        any data-URL prefix.

    Raises:
        HTTPException: 400 when the data is a non-PNG image data URL, is not
            valid base64, decodes to nothing, exceeds the size limits, or
            does not start with the PNG magic bytes.
    """
    raw = (raw or "").strip()

    match = _DATA_URL_RE.match(raw)
    if match:
        b64 = match.group("data")
    elif _ANY_IMAGE_DATA_URL_RE.match(raw):
        raise HTTPException(400, "Signature data must be a PNG image")
    else:
        b64 = raw

    # Reject oversized input on its encoded length, before spending time and
    # memory on decoding it.
    if len(b64) > _MAX_SIGNATURE_B64:
        raise HTTPException(400, "Signature PNG is too large")

    try:
        payload = base64.b64decode(b64, validate=True)
    except ValueError as exc:
        # binascii.Error (bad alphabet/padding) is a ValueError subclass, and
        # non-ASCII text input is reported as a plain ValueError.
        raise HTTPException(
            400, "Signature data must be base64-encoded PNG bytes"
        ) from exc

    if not payload:
        raise HTTPException(400, "Signature PNG is empty")
    if len(payload) > _MAX_SIGNATURE_BYTES:
        raise HTTPException(400, "Signature PNG is too large")
    if not payload.startswith(_PNG_MAGIC):
        raise HTTPException(400, "Signature data must be a PNG image")

    return base64.b64encode(payload).decode("ascii")
def _signature_dimension(value: Optional[int]) -> Optional[int]:
    """Validate an optional signature width/height.

    ``None`` passes through unchanged. Any other value must be an ``int``
    between 1 and ``_MAX_SIGNATURE_DIMENSION`` inclusive.

    Raises:
        HTTPException: 400 when the value is not an int or is out of range.
    """
    if value is None:
        return None
    acceptable = isinstance(value, int) and 1 <= value <= _MAX_SIGNATURE_DIMENSION
    if not acceptable:
        raise HTTPException(400, "Signature dimensions are invalid")
    return value
class SignatureCreate(BaseModel):
@@ -67,24 +101,18 @@ def setup_signature_routes() -> APIRouter:
@router.post("/api/signatures")
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
user = get_current_user(request)
raw = (req.data or "").strip()
m = _DATA_URL_RE.match(raw)
b64 = m.group("data") if m else raw
try:
payload = base64.b64decode(b64, validate=True)
if not payload:
raise ValueError("empty payload")
except Exception:
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
b64 = _normalize_signature_png(req.data)
width = _signature_dimension(req.width)
height = _signature_dimension(req.height)
sig = Signature(
id=str(uuid.uuid4()),
owner=user,
name=(req.name or "Signature").strip() or "Signature",
data_png=b64,
width=req.width,
height=req.height,
svg=req.svg,
width=width,
height=height,
svg=None,
)
db = SessionLocal()
try:
+107 -1
View File
@@ -11,6 +11,8 @@ import logging
import re
from typing import List, Optional
import httpx
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
@@ -51,6 +53,10 @@ class SkillAddRequest(BaseModel):
steps: List[str] = Field(default_factory=list)
class SkillImportUrlRequest(BaseModel):
    """Request body for the skill import-from-URL route."""

    # Source URL of the skill to import; must be 8 to 2000 characters long.
    url: str = Field(..., min_length=8, max_length=2000)
class SkillUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
@@ -1014,7 +1020,7 @@ def _resolve_audit_models(owner=None):
spec = (get_setting("teacher_model", "") or "").strip()
if spec:
from src.ai_interaction import _resolve_model
t_url, t_model, t_headers = _resolve_model(spec)
t_url, t_model, t_headers = _resolve_model(spec, owner=owner)
if t_url and t_model:
teacher = (t_url, t_model, t_headers)
except Exception as e:
@@ -1103,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
idx = skills_manager.index_for(owner=user)
return {"index": idx, "count": len(idx)}
@router.get("/slash-catalog")
async def get_slash_catalog(request: Request):
    """List the caller's skills that can be invoked as slash commands.

    The catalog is driven by ``skills_manager.index_for`` (the same
    published-skill index the agent prompt uses), so the UI never offers a
    slash command the model would not normally be allowed to discover.
    Full records from ``skills_manager.load`` only fill in a missing
    category/description and supply usage statistics.

    Returns:
        ``{"skills": [...], "count": N}`` with entries sorted by name.
    """
    owner = _owner(request)

    # Full skill records keyed by name, used as a fallback/detail source.
    details_by_name = {}
    for record in skills_manager.load(owner=owner):
        details_by_name[record.get("name")] = record

    catalog = []
    for item in skills_manager.index_for(owner=owner):
        skill_name = (item.get("name") or "").strip()
        if skill_name:
            detail = details_by_name.get(skill_name) or {}
            raw_category = item.get("category") or detail.get("category") or "general"
            category = raw_category.strip() or "general"
            help_text = item.get("description") or detail.get("description") or ""
            catalog.append({
                "type": "skill",
                "token": f"/{skill_name}",
                "name": skill_name,
                "category": f"Skills / {category}",
                "help": help_text,
                "usage": f"/{skill_name} <request>",
                "uses": int(detail.get("uses") or 0),
                "last_used": detail.get("last_used"),
            })

    catalog = sorted(catalog, key=lambda entry: entry["name"])
    return {"skills": catalog, "count": len(catalog)}
@router.get("/builtin")
async def list_builtin_skills(request: Request):
"""Read-only list of the agent's built-in tool capabilities (research,
@@ -1203,6 +1238,36 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
save_settings(settings)
return {"ok": True, "name": name, "is_overridden": False}
@router.post("/import-from-url")
async def import_skill_from_url(request: Request, body: SkillImportUrlRequest):
    """Install a SKILL.md bundle from a public GitHub URL (skills.sh links supported).

    Admin-only. The bundle is fetched with ``fetch_skill_bundle`` and handed
    to ``skills_manager.import_bundle_from_files`` for the calling user.

    Raises:
        HTTPException: 400 for a ``SkillImportError``, 502 when the download
            fails with an ``httpx.HTTPError``, 500 for any other failure.
    """
    require_admin(request)
    owner = _owner(request)
    from services.memory.skill_importer import SkillImportError, fetch_skill_bundle

    source_url = body.url.strip()
    try:
        bundle_files, _origin = fetch_skill_bundle(source_url)
        installed = skills_manager.import_bundle_from_files(
            bundle_files,
            owner=owner,
            source_url=source_url,
        )
    except SkillImportError as exc:
        # Importer-level rejection: surface its message to the client.
        raise HTTPException(400, str(exc)) from exc
    except httpx.HTTPError as exc:
        logger.warning("skill import fetch failed: %s", exc)
        message = str(exc).strip()
        if not message:
            message = "Could not download skill from URL"
        raise HTTPException(502, message) from exc
    except Exception as exc:
        logger.error("skill import failed: %s", exc)
        raise HTTPException(500, "Skill import failed") from exc

    _fire_skill_added(owner)
    return {"ok": True, "skill": installed, "files": len(bundle_files)}
@router.post("/add")
async def add_skill(request: Request, body: SkillAddRequest):
user = _owner(request)
@@ -1236,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
_fire_skill_added(user)
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
@router.post("/{skill_id}/invoke")
async def invoke_skill(request: Request, skill_id: str):
    """Build a skill-pinned prompt for slash-command invocation.

    This is intentionally server-side so availability, ownership, and usage
    accounting use the same rules as the SkillsManager.

    The optional JSON body may carry a ``request`` string. A missing or
    unparseable body, a non-object body, or a non-string ``request`` value
    is treated as "no request text" rather than failing the call.

    Returns:
        A dict with ``ok``, ``type``, ``name``, ``command`` and the
        assembled ``message`` prompt.

    Raises:
        HTTPException: 404 when ``skill_id`` is not in the caller's
            invokable skill index, or when its SKILL.md cannot be read.
    """
    user = _owner(request)
    try:
        body = await request.json()
    except Exception:
        # The body is optional; invocation without one is valid.
        body = {}

    # Client-controlled input: only a string can be stripped. Anything else
    # (number, list, object, null) counts as an empty request instead of
    # raising AttributeError and surfacing as a 500.
    raw_request = body.get("request") if isinstance(body, dict) else None
    request_text = raw_request.strip() if isinstance(raw_request, str) else ""

    invokable = {
        s.get("name"): s for s in skills_manager.index_for(owner=user)
        if (s.get("name") or "").strip()
    }
    match = invokable.get(skill_id)
    if not match:
        raise HTTPException(404, "Skill is not available for slash invocation")

    name = match.get("name")
    md = skills_manager.read_skill_md(name, owner=user)
    if md is None:
        raise HTTPException(404, "Skill source unavailable")

    # Usage is recorded only once the skill is confirmed readable.
    skills_manager.record_use(name, owner=user)
    message = (
        "Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
        f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
        + (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
    )
    return {
        "ok": True,
        "type": "skill",
        "name": name,
        "command": f"/{name}",
        "message": message,
    }
@router.get("/{skill_id}")
async def get_skill(request: Request, skill_id: str):
user = _owner(request)
+1 -3
View File
@@ -4,12 +4,10 @@
from fastapi import APIRouter, HTTPException, UploadFile, File
import logging
from src.upload_limits import read_upload_limited
from src.upload_limits import read_upload_limited, STT_MAX_AUDIO_BYTES
logger = logging.getLogger(__name__)
STT_MAX_AUDIO_BYTES = 25 * 1024 * 1024
def setup_stt_routes(stt_service):
"""Setup STT routes with the provided STT service"""
+37 -22
View File
@@ -11,7 +11,9 @@ from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import SessionLocal, ScheduledTask, TaskRun
from core.constants import internal_api_base
from src.auth_helpers import get_current_user
from src.constants import DATA_DIR, EMAIL_URGENCY_CACHE_DIR
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
from routes.prefs_routes import _load_for_user, _save_for_user
@@ -56,7 +58,7 @@ def _maybe_cascade_calendar_event(task) -> None:
try:
with httpx.Client(timeout=10) as client:
r = client.delete(
f"http://localhost:7000/api/calendar/events/{uid}",
f"{internal_api_base()}/api/calendar/events/{uid}",
headers=headers,
)
if r.status_code >= 400:
@@ -81,7 +83,7 @@ def _maybe_cascade_calendar_event(task) -> None:
try:
with httpx.Client(timeout=10) as client:
# Find the Cookbook calendar.
cal_r = client.get("http://localhost:7000/api/calendar/calendars", headers=headers)
cal_r = client.get(f"{internal_api_base()}/api/calendar/calendars", headers=headers)
if cal_r.status_code >= 400:
return
cals = (cal_r.json() or {}).get("calendars", [])
@@ -98,7 +100,7 @@ def _maybe_cascade_calendar_event(task) -> None:
start = (now - _td(days=30)).isoformat()
end = (now + _td(days=365)).isoformat()
ev_r = client.get(
"http://localhost:7000/api/calendar/events",
f"{internal_api_base()}/api/calendar/events",
params={"start": start, "end": end, "calendar": cal_href},
headers=headers,
)
@@ -291,20 +293,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
def _owner(request: Request):
    """Return the current user for *request*, as resolved by ``get_current_user``."""
    return get_current_user(request)
async def _generate_task_name(prompt: str) -> str:
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
"""Use LLM to generate a short task name from the prompt."""
try:
from src.llm_core import llm_call_async
from core.database import Session as DbSession
db = SessionLocal()
try:
recent = db.query(DbSession).filter(
q = db.query(DbSession).filter(
DbSession.endpoint_url.isnot(None),
DbSession.model.isnot(None),
).order_by(DbSession.created_at.desc()).first()
)
if owner:
q = q.filter(DbSession.owner == owner)
recent = q.order_by(DbSession.created_at.desc()).first()
if not recent:
return prompt[:50].strip()
url, model = recent.endpoint_url, recent.model
headers = recent.headers or {}
finally:
db.close()
@@ -315,6 +321,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
{"role": "user", "content": prompt[:500]},
],
max_tokens=20,
headers=headers,
timeout=15,
)
title = result.strip().strip('"\'').strip()
@@ -429,6 +436,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
except Exception:
return False
def _validate_then_task_id(
    db,
    then_task_id: Optional[str],
    user: Optional[str],
    current_task_id: Optional[str] = None,
) -> Optional[str]:
    """Validate a chained ("then") task reference and return its id.

    A blank or missing ``then_task_id`` means "no chain" and yields ``None``.
    Otherwise the target must exist and, when ``user`` is set, be owned by
    that user.

    Raises:
        HTTPException: 400 when the task would chain to itself
            (``current_task_id`` equals the target), 404 when no matching
            task is found.
    """
    wanted = (then_task_id or "").strip()
    if wanted == "":
        return None

    if current_task_id and current_task_id == wanted:
        raise HTTPException(400, "Task cannot chain to itself")

    lookup = db.query(ScheduledTask).filter(ScheduledTask.id == wanted)
    if user:
        # Scope the lookup to the caller's own tasks.
        lookup = lookup.filter(ScheduledTask.owner == user)

    found = lookup.first()
    if not found:
        raise HTTPException(404, "Chained task not found")
    return found.id
@router.post("")
async def create_task(request: Request, req: TaskCreate):
user = _owner(request)
@@ -465,7 +486,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
from src.builtin_actions import BUILTIN_ACTION_INFO
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
elif req.prompt:
name = await _generate_task_name(req.prompt)
name = await _generate_task_name(req.prompt, owner=user)
else:
name = "Untitled Task"
@@ -492,6 +513,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
task_id = str(uuid.uuid4())
db = SessionLocal()
try:
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
notifications_enabled = (
False if req.task_type == "action" and req.notifications_enabled is None
else bool(req.notifications_enabled) if req.notifications_enabled is not None
@@ -527,7 +549,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
output_target=req.output_target,
model=req.model or None,
endpoint_url=req.endpoint_url or None,
then_task_id=req.then_task_id or None,
then_task_id=then_task_id,
webhook_token=webhook_token,
notifications_enabled=notifications_enabled,
)
@@ -609,7 +631,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
removed_files = 0
if action == "check_email_urgency":
cache_dir = Path("data/email_urgency_cache")
cache_dir = Path(EMAIL_URGENCY_CACHE_DIR)
if cache_dir.exists():
for child in cache_dir.glob("*.json"):
try:
@@ -618,7 +640,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
except Exception:
pass
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
for state_path in [Path(f"data/email_urgency_state_{owner_slug}.json")]:
for state_path in [Path(DATA_DIR) / f"email_urgency_state_{owner_slug}.json"]:
try:
if state_path.exists():
state_path.unlink()
@@ -680,15 +702,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
if req.trigger_count is not None:
task.trigger_count = req.trigger_count
if req.then_task_id is not None:
if req.then_task_id:
chain_target = db.query(ScheduledTask).filter(
ScheduledTask.id == req.then_task_id
).first()
if not chain_target:
raise HTTPException(400, "Chained task not found")
if chain_target.owner != user:
raise HTTPException(403, "Cannot chain to another user's task")
task.then_task_id = req.then_task_id or None
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
if req.notifications_enabled is not None:
task.notifications_enabled = bool(req.notifications_enabled)
if req.cron_expression is not None:
@@ -969,7 +983,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
"tag", "label", "move", "archive", "delete", "mark", "schedule",
)
try:
from src.agent_tools import get_mcp_manager
from src.tool_utils import get_mcp_manager
mcp = get_mcp_manager()
if mcp:
for tool in mcp.get_all_tools():
@@ -1064,6 +1078,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
desc = (body.get("description") or "").strip()
if not desc:
return {"success": False, "message": "Nothing to parse"}
user = _owner(request)
now = _dt.now()
# Give the model the current date/time + weekday so relative phrasing
@@ -1090,9 +1105,9 @@ def setup_task_routes(task_scheduler) -> APIRouter:
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
)
try:
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=user or None)
if not url:
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=user or None)
if not (url and model):
return {"success": False, "message": "No model endpoint configured"}
raw = await llm_call_async(
+51 -34
View File
@@ -13,9 +13,43 @@ from src.upload_handler import count_recent_uploads
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/upload", tags=["upload"])
UPLOAD_RESPONSE_HEADERS = {"X-Content-Type-Options": "nosniff"}
def setup_upload_routes(upload_handler):
"""Setup upload routes with the provided handler"""
def _upload_root() -> str:
    """Return the real (symlink-resolved) path of the upload directory.

    Uses the handler's ``upload_dir`` attribute when it has one, otherwise
    falls back to the ``UPLOAD_DIR`` constant.
    """
    from src.constants import UPLOAD_DIR

    configured = getattr(upload_handler, "upload_dir", UPLOAD_DIR)
    return os.path.realpath(configured)
def _path_inside_upload_dir(path: str) -> bool:
    """Report whether *path*, once resolved, lies within the upload root.

    Any error while resolving or comparing the paths is treated as
    "not inside", so the check fails closed.
    """
    try:
        shared = os.path.commonpath([_upload_root(), os.path.realpath(path)])
        return shared == _upload_root()
    except Exception:
        return False
def _resolve_upload_path(file_id: str) -> str:
    """Locate the stored file for *file_id* under the upload directory.

    Looks for ``<upload dir>/<file_id>`` first, then walks the directory
    tree (without following symlinked directories) for a file of that
    name. The first candidate found decides the outcome.

    Returns:
        The path of the matching regular file.

    Raises:
        HTTPException: 403 when the candidate resolves outside the upload
            root, 404 when nothing matches or the match is not a file.
    """
    from src.constants import UPLOAD_DIR

    def _vet(candidate: str) -> str:
        # Containment is checked before existence, so a link that escapes
        # the upload root is reported as 403 rather than 404.
        if not _path_inside_upload_dir(candidate):
            raise HTTPException(403, "Access denied")
        if not os.path.isfile(candidate):
            raise HTTPException(404, "File not found")
        return candidate

    base = getattr(upload_handler, "upload_dir", UPLOAD_DIR)

    top_level = os.path.join(base, file_id)
    # lexists() is true for broken symlinks too, so they get vetted as well.
    if os.path.lexists(top_level):
        return _vet(top_level)

    for folder, _subdirs, names in os.walk(base, followlinks=False):
        if file_id in names:
            return _vet(os.path.join(folder, file_id))

    raise HTTPException(404, "File not found")
@router.post("")
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
@@ -91,23 +125,11 @@ def setup_upload_routes(upload_handler):
client isn't downloading the full-resolution photo just to show it tiny."""
if not upload_handler.validate_upload_id(file_id):
raise HTTPException(400, "Invalid file ID")
# Search upload directories for the file
from src.constants import UPLOAD_DIR
import mimetypes as _mt
path = os.path.join(UPLOAD_DIR, file_id)
if not os.path.exists(path):
for root, dirs, files in os.walk(UPLOAD_DIR):
if file_id in files:
path = os.path.join(root, file_id)
break
else:
raise HTTPException(404, "File not found")
if not upload_handler.inside_base_dir(path):
raise HTTPException(403, "Access denied")
# Look up original filename and owner from uploads.json
original_name = file_id
info = None
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
uploads_db = os.path.join(_upload_root(), "uploads.json")
if os.path.exists(uploads_db):
with open(uploads_db, encoding="utf-8") as f:
db = json.load(f)
@@ -123,13 +145,14 @@ def setup_upload_routes(upload_handler):
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
mime = _mt.guess_type(path)[0] or "application/octet-stream"
path = _resolve_upload_path(file_id)
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or "application/octet-stream"
from fastapi.responses import FileResponse
# Downscaled thumbnail for image previews — generated once and cached.
if thumb and mime.startswith("image/"):
try:
from PIL import Image, ImageOps
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
thumb_dir = os.path.join(_upload_root(), ".thumbs")
os.makedirs(thumb_dir, exist_ok=True)
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
if (not os.path.exists(thumb_path)
@@ -145,17 +168,21 @@ def setup_upload_routes(upload_handler):
if im.mode not in ("RGB", "L"):
im = im.convert("RGB")
im.save(thumb_path, "JPEG", quality=80)
return FileResponse(thumb_path, media_type="image/jpeg")
return FileResponse(thumb_path, media_type="image/jpeg", headers=UPLOAD_RESPONSE_HEADERS)
except Exception as e:
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
# Fall through to the full image.
return FileResponse(path, media_type=mime, filename=original_name)
return FileResponse(
path,
media_type=mime,
filename=original_name,
headers=UPLOAD_RESPONSE_HEADERS,
)
def _load_upload_info(file_id: str):
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
from src.constants import UPLOAD_DIR
info = None
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
uploads_db = os.path.join(_upload_root(), "uploads.json")
if os.path.exists(uploads_db):
with open(uploads_db, encoding="utf-8") as f:
db = json.load(f)
@@ -163,8 +190,7 @@ def setup_upload_routes(upload_handler):
return info
def _vision_cache_path(file_id: str) -> str:
from src.constants import UPLOAD_DIR
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
cache_dir = os.path.join(_upload_root(), ".vision")
os.makedirs(cache_dir, exist_ok=True)
return os.path.join(cache_dir, file_id + ".txt")
@@ -175,17 +201,6 @@ def setup_upload_routes(upload_handler):
subsequent loads are instant. Pass force=1 to recompute."""
if not upload_handler.validate_upload_id(file_id):
raise HTTPException(400, "Invalid file ID")
from src.constants import UPLOAD_DIR
path = os.path.join(UPLOAD_DIR, file_id)
if not os.path.exists(path):
for root, dirs, files in os.walk(UPLOAD_DIR):
if file_id in files:
path = os.path.join(root, file_id)
break
else:
raise HTTPException(404, "File not found")
if not upload_handler.inside_base_dir(path):
raise HTTPException(403, "Access denied")
info = _load_upload_info(file_id)
auth_mgr = getattr(request.app.state, "auth_manager", None)
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
@@ -196,8 +211,9 @@ def setup_upload_routes(upload_handler):
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
path = _resolve_upload_path(file_id)
import mimetypes as _mt
mime = _mt.guess_type(path)[0] or ""
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or ""
if not mime.startswith("image/"):
raise HTTPException(400, "Not an image")
cache_path = _vision_cache_path(file_id)
@@ -209,7 +225,7 @@ def setup_upload_routes(upload_handler):
logger.warning(f"Vision cache read failed for {file_id}: {e}")
from src.document_processor import analyze_image_with_vl
try:
text = analyze_image_with_vl(path) or ""
text = analyze_image_with_vl(path, owner=current_user) or ""
except Exception as e:
logger.error(f"Vision analysis failed for {file_id}: {e}")
raise HTTPException(500, f"Vision analysis failed: {e}")
@@ -238,6 +254,7 @@ def setup_upload_routes(upload_handler):
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
_resolve_upload_path(file_id)
body = await request.json()
text = (body or {}).get("text", "")
if not isinstance(text, str):
+2 -1
View File
@@ -17,10 +17,11 @@ from pydantic import BaseModel
from core.middleware import require_admin
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
from src.constants import VAULT_FILE as _VAULT_FILE
logger = logging.getLogger(__name__)
VAULT_FILE = Path("data/vault.json")
VAULT_FILE = Path(_VAULT_FILE)
def _find_bw() -> str:
+23 -10
View File
@@ -194,6 +194,8 @@ def setup_webhook_routes(
"together": "https://api.together.xyz/v1",
"openrouter": "https://openrouter.ai/api/v1",
"ollama": "https://ollama.com/api",
"opencode-zen": "https://opencode.ai/zen/v1",
"opencode-go": "https://opencode.ai/zen/go/v1",
"fireworks": "https://api.fireworks.ai/inference/v1",
"venice": "https://api.venice.ai/api/v1",
}
@@ -323,22 +325,33 @@ def setup_webhook_routes(
endpoint_url = build_chat_url(base_url)
model = body.model or "auto"
api_key = ep.api_key
if getattr(ep, "provider_auth_id", None):
try:
from src.endpoint_resolver import resolve_endpoint_runtime
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
endpoint_url = build_chat_url(base_url)
except Exception:
raise HTTPException(500, "Could not resolve endpoint credentials")
if model == "auto":
try:
async with httpx.AsyncClient(timeout=5) as client:
models_url = build_models_url(base_url)
hdrs = build_headers(api_key, base_url)
resp = await client.get(models_url, headers=hdrs)
resp.raise_for_status()
data = resp.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not ids:
ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
if models_url:
resp = await client.get(models_url, headers=hdrs)
resp.raise_for_status()
data = resp.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not ids:
ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
import json as _json
ids = _json.loads(ep.cached_models or "[]")
model = ids[0] if ids else "auto"
except Exception:
raise HTTPException(500, "Could not discover models from endpoint")