mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 17:55:26 -04:00
Merge branch 'dev'
# Conflicts: # routes/task_routes.py # src/caldav_sync.py
This commit is contained in:
@@ -31,7 +31,7 @@ from core.database import (
|
||||
CalendarEvent,
|
||||
CalendarCal,
|
||||
)
|
||||
from src.constants import DATA_DIR
|
||||
from src.constants import DATA_DIR, SKILLS_DIR, SKILLS_FILE, GALLERY_DIR, GALLERY_UPLOADS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,7 +107,7 @@ def setup_admin_wipe_routes(session_manager):
|
||||
# Skills live as SKILL.md files under data/skills/. Drop
|
||||
# the entire directory; the SkillsManager re-creates the
|
||||
# tree on next write.
|
||||
skills_dir = os.path.join(DATA_DIR, "skills")
|
||||
skills_dir = SKILLS_DIR
|
||||
count = 0
|
||||
if os.path.isdir(skills_dir):
|
||||
# Count SKILL.md files for the response — quick walk.
|
||||
@@ -115,7 +115,7 @@ def setup_admin_wipe_routes(session_manager):
|
||||
count += sum(1 for f in files if f == "SKILL.md")
|
||||
_rmtree_quiet(skills_dir)
|
||||
# Legacy fallback file
|
||||
legacy = os.path.join(DATA_DIR, "skills.json")
|
||||
legacy = SKILLS_FILE
|
||||
if os.path.exists(legacy):
|
||||
try:
|
||||
os.remove(legacy)
|
||||
@@ -151,8 +151,8 @@ def setup_admin_wipe_routes(session_manager):
|
||||
db.query(GalleryAlbum).delete()
|
||||
db.commit()
|
||||
# Also drop the upload dir so disk doesn't keep orphans.
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
|
||||
_rmtree_quiet(GALLERY_DIR)
|
||||
_rmtree_quiet(GALLERY_UPLOADS_DIR)
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "calendar":
|
||||
|
||||
@@ -155,22 +155,30 @@ def setup_api_token_routes() -> APIRouter:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
scope_list = _normalize_scopes(payload.get("scopes"))
|
||||
scopes_value = ",".join(scope_list)
|
||||
with get_db_session() as db:
|
||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||
if not token:
|
||||
raise HTTPException(404, "Token not found")
|
||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||
token.scopes = scopes_value
|
||||
# Only touch scopes when the caller actually sent them. A partial
|
||||
# update such as a rename ({"name": ...} with no "scopes" key) must
|
||||
# not silently reset the token to the default scope — that dropped
|
||||
# every previously granted scope.
|
||||
if "scopes" in payload:
|
||||
token.scopes = ",".join(_normalize_scopes(payload.get("scopes")))
|
||||
db.add(token)
|
||||
current_scopes = [
|
||||
s.strip()
|
||||
for s in (getattr(token, "scopes", "") or DEFAULT_SCOPES).split(",")
|
||||
if s.strip()
|
||||
]
|
||||
response = {
|
||||
"id": token_id,
|
||||
"name": getattr(token, "name", ""),
|
||||
"owner": getattr(token, "owner", None),
|
||||
"token_prefix": getattr(token, "token_prefix", ""),
|
||||
"scopes": scope_list,
|
||||
"scopes": current_scopes,
|
||||
}
|
||||
_invalidate_cache(request)
|
||||
return response
|
||||
|
||||
+23
-4
@@ -131,10 +131,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
return {"ok": False, "requires_totp": True, "username": username}
|
||||
if not auth_manager.totp_verify(username, body.totp_code):
|
||||
raise HTTPException(401, "Invalid 2FA code")
|
||||
# All checks passed — create session
|
||||
token = await asyncio.to_thread(auth_manager.create_session, username, body.password)
|
||||
if not token:
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
# All checks passed — create session (password already verified above)
|
||||
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
|
||||
cookie_kwargs = dict(
|
||||
key=SESSION_COOKIE,
|
||||
value=token,
|
||||
@@ -585,6 +583,27 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
||||
|
||||
if preset == "discord_webhook":
|
||||
import httpx
|
||||
webhook_url = (integ.get("base_url") or "").strip()
|
||||
if not webhook_url:
|
||||
return {"ok": False, "message": "No webhook URL set — paste the full Discord webhook URL into the Base URL field."}
|
||||
payload = {
|
||||
"embeds": [{
|
||||
"title": "Odysseus connectivity test",
|
||||
"description": "If you see this, your Discord Webhook integration is wired up correctly.",
|
||||
"color": 5793266,
|
||||
}]
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.post(webhook_url, json=payload)
|
||||
if r.is_success:
|
||||
return {"ok": True, "message": "Test embed sent — check your Discord channel to confirm it arrived."}
|
||||
return {"ok": False, "message": f"Discord returned HTTP {r.status_code}: {r.text[:200]}"}
|
||||
except Exception as e:
|
||||
return {"ok": False, "message": f"Request failed: {e}"[:400]}
|
||||
|
||||
# All other presets: GET against a known health endpoint.
|
||||
# Fall back to detecting from name if preset is missing.
|
||||
health_paths = {
|
||||
|
||||
+56
-12
@@ -101,24 +101,68 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
||||
# ── Skills ──
|
||||
if "skills" in body and isinstance(body["skills"], list):
|
||||
existing = skills_manager.load_all()
|
||||
existing_ids = {s.get("id") for s in existing}
|
||||
existing_titles = {s.get("title", "").strip().lower() for s in existing}
|
||||
existing_names = {s.get("name") for s in existing if s.get("name")}
|
||||
existing_ids = {s.get("id") for s in existing if s.get("id")}
|
||||
existing_titles = {
|
||||
(s.get("title") or s.get("description") or "").strip().lower()
|
||||
for s in existing
|
||||
}
|
||||
added = 0
|
||||
for skill in body["skills"]:
|
||||
if not isinstance(skill, dict) or not skill.get("title"):
|
||||
if not isinstance(skill, dict):
|
||||
continue
|
||||
# Skip if same id or same title already exists
|
||||
if skill.get("id") in existing_ids:
|
||||
title = (
|
||||
skill.get("title") or skill.get("description")
|
||||
or skill.get("name") or ""
|
||||
).strip()
|
||||
if not title:
|
||||
continue
|
||||
if skill["title"].strip().lower() in existing_titles:
|
||||
sid = skill.get("id") or skill.get("name")
|
||||
if sid and sid in existing_ids:
|
||||
continue
|
||||
if user and not skill.get("owner"):
|
||||
skill["owner"] = user
|
||||
existing.append(skill)
|
||||
existing_ids.add(skill.get("id"))
|
||||
existing_titles.add(skill["title"].strip().lower())
|
||||
nm = skill.get("name")
|
||||
if nm and nm in existing_names:
|
||||
continue
|
||||
if title.lower() in existing_titles:
|
||||
continue
|
||||
owner = skill.get("owner")
|
||||
if user and not owner:
|
||||
owner = user
|
||||
# Skills live on disk as SKILL.md files; the old JSON-era
|
||||
# skills_manager.save() no longer exists. Write each new skill
|
||||
# via add_skill (source="user" skips auto-dedup — this is an
|
||||
# explicit backup restore).
|
||||
result = skills_manager.add_skill(
|
||||
title=title,
|
||||
name=skill.get("name"),
|
||||
description=skill.get("description"),
|
||||
problem=skill.get("problem", ""),
|
||||
solution=skill.get("solution", ""),
|
||||
steps=skill.get("steps"),
|
||||
tags=skill.get("tags"),
|
||||
source="user",
|
||||
teacher_model=skill.get("teacher_model"),
|
||||
confidence=skill.get("confidence", 0.8),
|
||||
owner=owner,
|
||||
category=skill.get("category", "general"),
|
||||
when_to_use=skill.get("when_to_use"),
|
||||
procedure=skill.get("procedure"),
|
||||
pitfalls=skill.get("pitfalls"),
|
||||
verification=skill.get("verification"),
|
||||
platforms=skill.get("platforms"),
|
||||
requires_toolsets=skill.get("requires_toolsets"),
|
||||
fallback_for_toolsets=skill.get("fallback_for_toolsets"),
|
||||
status=skill.get("status", "draft"),
|
||||
version=skill.get("version", "1.0.0"),
|
||||
)
|
||||
if result.get("_deduped"):
|
||||
continue
|
||||
if result.get("name"):
|
||||
existing_names.add(result["name"])
|
||||
if result.get("id"):
|
||||
existing_ids.add(result["id"])
|
||||
existing_titles.add(title.lower())
|
||||
added += 1
|
||||
skills_manager.save(existing)
|
||||
imported.append(f"{added} skills")
|
||||
|
||||
# ── Presets ──
|
||||
|
||||
+254
-69
@@ -1,6 +1,7 @@
|
||||
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional, List
|
||||
@@ -12,7 +13,7 @@ from dateutil.rrule import rrulestr
|
||||
|
||||
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
||||
from src.auth_helpers import require_user
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, ICS_MAX_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,6 +101,15 @@ def _ics_escape(text: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _safe_ics_filename(name: str) -> str:
|
||||
"""Return a conservative .ics filename safe for Content-Disposition."""
|
||||
stem = name if isinstance(name, str) else ""
|
||||
stem = re.sub(r"[^A-Za-z0-9._-]", "_", stem).strip("._-")
|
||||
if not stem:
|
||||
stem = "calendar"
|
||||
return f"{stem[:128]}.ics"
|
||||
|
||||
|
||||
def _resolve_base_uid(uid: str) -> str:
|
||||
"""Extract the base series UID from a compound occurrence UID.
|
||||
|
||||
@@ -248,6 +258,17 @@ def parse_due_for_user(s: str) -> str:
|
||||
if t is not None:
|
||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||
|
||||
# Time-first: "3pm today", "11pm today", "9am tomorrow"
|
||||
m = _re.match(r'^(.+?)\s+(today|tonight|tomorrow|tmrw|yesterday)$', lower)
|
||||
if m:
|
||||
time_part, word = m.group(1).strip(), m.group(2)
|
||||
base = today
|
||||
if word in ("tomorrow", "tmrw"): base = today + _td(days=1)
|
||||
elif word == "yesterday": base = today - _td(days=1)
|
||||
t = _parse_time(time_part)
|
||||
if t is not None:
|
||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||
|
||||
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
||||
if m:
|
||||
n = int(m.group(1)); unit = m.group(2)
|
||||
@@ -399,7 +420,17 @@ def _parse_dt(s: str) -> datetime:
|
||||
# Last resort: dateutil's fuzzy parser
|
||||
try:
|
||||
from dateutil import parser as _du
|
||||
return _du.parse(s)
|
||||
parsed = _du.parse(s)
|
||||
# Strip tz like every other return path above — this function's
|
||||
# contract is naive datetimes (CalendarEvent.dtstart is naive). An
|
||||
# offset-bearing non-ISO input (e.g. RFC-2822 "Mon, 05 Jan 2026
|
||||
# 14:00:00 +0900") otherwise leaked tz-aware into the naive column and
|
||||
# crashed read-back comparisons in _expand_rrule with "can't compare
|
||||
# offset-naive and offset-aware datetimes".
|
||||
if parsed.tzinfo is not None:
|
||||
from datetime import timezone as _tz
|
||||
return parsed.astimezone(_tz.utc).replace(tzinfo=None)
|
||||
return parsed
|
||||
except Exception:
|
||||
raise ValueError(f"could not parse datetime: {s!r}")
|
||||
|
||||
@@ -440,6 +471,9 @@ def _event_to_dict(ev: CalendarEvent) -> dict:
|
||||
|
||||
# ── Recurrence expansion ──
|
||||
|
||||
_RRULE_EXPANSION_LIMIT = 1000
|
||||
|
||||
|
||||
def _expand_rrule(
|
||||
ev: CalendarEvent, start: datetime, end: datetime
|
||||
) -> List[dict]:
|
||||
@@ -462,6 +496,7 @@ def _expand_rrule(
|
||||
d = _event_to_dict(ev)
|
||||
d["is_recurrence"] = False
|
||||
d["series_uid"] = ev.uid
|
||||
d["truncated"] = False
|
||||
return [d]
|
||||
|
||||
# Parse the rrule, applying it to the base dtstart.
|
||||
@@ -487,6 +522,7 @@ def _expand_rrule(
|
||||
d = _event_to_dict(ev)
|
||||
d["is_recurrence"] = False
|
||||
d["series_uid"] = ev.uid
|
||||
d["truncated"] = False
|
||||
# Malformed RRULE rows are fetched by the recurring SQL branch
|
||||
# with only dtstart < end_dt — the base event may not actually
|
||||
# overlap the window. Only return if it does.
|
||||
@@ -499,22 +535,26 @@ def _expand_rrule(
|
||||
# (matching non-recurring overlap semantics: dtstart < end AND
|
||||
# dtend > start).
|
||||
expand_start = start - duration
|
||||
occurrences = rule.between(expand_start, end, inc=True)
|
||||
if not occurrences:
|
||||
return []
|
||||
|
||||
results = []
|
||||
truncated = False
|
||||
base = _event_to_dict(ev)
|
||||
|
||||
for occ_start in occurrences:
|
||||
for occ_start in rule.xafter(expand_start, inc=True):
|
||||
if occ_start >= end:
|
||||
break
|
||||
|
||||
occ_end = occ_start + duration
|
||||
|
||||
# Overlap filter: occurrence must intersect [start, end).
|
||||
# This enforces exclusive-end semantics (occ_start >= end is
|
||||
# excluded) and includes multi-day crossings (occ_end > start).
|
||||
if occ_start >= end or occ_end <= start:
|
||||
if occ_end <= start:
|
||||
continue
|
||||
|
||||
if len(results) >= _RRULE_EXPANSION_LIMIT:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
||||
if ev.all_day:
|
||||
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
||||
@@ -525,6 +565,7 @@ def _expand_rrule(
|
||||
d["uid"] = occ_uid
|
||||
d["series_uid"] = ev.uid
|
||||
d["is_recurrence"] = True
|
||||
d["truncated"] = False
|
||||
|
||||
if ev.all_day:
|
||||
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
||||
@@ -537,6 +578,10 @@ def _expand_rrule(
|
||||
|
||||
results.append(d)
|
||||
|
||||
if truncated:
|
||||
for d in results:
|
||||
d["truncated"] = True
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -545,72 +590,178 @@ def _expand_rrule(
|
||||
def setup_calendar_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
||||
|
||||
# CalDAV connect form (Integrations → Calendar). Storage is local
|
||||
# SQLite; sync (src/caldav_sync.py) pulls remote events into it on
|
||||
# calendar open and periodically via the scheduler.
|
||||
# ── CalDAV multi-account helpers ─────────────────────────────────────────
|
||||
|
||||
def _get_caldav_accounts(owner: str) -> list:
|
||||
from src.caldav_sync import _load_caldav_accounts
|
||||
return _load_caldav_accounts(owner)
|
||||
|
||||
def _save_caldav_accounts(owner: str, accounts: list) -> None:
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
prefs = _load_for_user(owner) or {}
|
||||
prefs["caldav_accounts"] = accounts
|
||||
prefs.pop("caldav", None)
|
||||
_save_for_user(owner, prefs)
|
||||
|
||||
# ── CalDAV config routes (backward-compat single-account API) ────────────
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config(request: Request):
|
||||
"""Legacy single-account endpoint — returns the first configured account."""
|
||||
owner = _require_user(request)
|
||||
from routes.prefs_routes import _load_for_user
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
caldav_password = cfg.get("password") or ""
|
||||
if caldav_password:
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
if not accounts:
|
||||
return {"url": "", "username": "", "password": "", "has_password": False, "local": True}
|
||||
first = accounts[0]
|
||||
pw = first.get("password") or ""
|
||||
has_pw = False
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
caldav_password = decrypt(caldav_password)
|
||||
has_pw = bool(decrypt(pw))
|
||||
except Exception:
|
||||
pass
|
||||
# Surface url+username but never hand the password back to the
|
||||
# client — saved-state UI shouldn't leak the credential.
|
||||
has_pw = bool(pw)
|
||||
return {
|
||||
"url": cfg.get("url", "") or "",
|
||||
"username": cfg.get("username", "") or "",
|
||||
"url": first.get("url", "") or "",
|
||||
"username": first.get("username", "") or "",
|
||||
"password": "",
|
||||
"has_password": bool(caldav_password),
|
||||
"local": not bool(cfg.get("url")),
|
||||
"has_password": has_pw,
|
||||
"local": not bool(first.get("url")),
|
||||
}
|
||||
|
||||
@router.post("/config")
|
||||
async def save_config(request: Request):
|
||||
"""Legacy single-account endpoint — upserts the first account."""
|
||||
owner = _require_user(request)
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
prefs = _load_for_user(owner) or {}
|
||||
cfg = dict(prefs.get("caldav") or {})
|
||||
# Empty url => clear the whole entry (treat as "remove integration").
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
if not (body.get("url") or "").strip():
|
||||
prefs.pop("caldav", None)
|
||||
_save_for_user(owner, prefs)
|
||||
_save_caldav_accounts(owner, [])
|
||||
return {"ok": True, "cleared": True}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
cfg["url"] = validate_caldav_url(body.get("url", ""))
|
||||
validated_url = validate_caldav_url(body.get("url", ""))
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
cfg["username"] = (body.get("username") or "").strip()
|
||||
# Preserve the stored password when the client sends an empty
|
||||
# one (edit form re-submitted without re-typing the password).
|
||||
# cfg already holds the existing (already-encrypted) password from
|
||||
# prefs, so we only touch it when a new password is supplied —
|
||||
# re-encrypting the stored value would double-encrypt it.
|
||||
if accounts:
|
||||
acc = dict(accounts[0])
|
||||
else:
|
||||
import uuid as _uuid
|
||||
acc = {"id": str(_uuid.uuid4()), "label": "CalDAV"}
|
||||
acc["url"] = validated_url
|
||||
acc["username"] = (body.get("username") or "").strip()
|
||||
if body.get("password"):
|
||||
from src.secret_storage import encrypt
|
||||
cfg["password"] = encrypt(body["password"])
|
||||
prefs["caldav"] = cfg
|
||||
_save_for_user(owner, prefs)
|
||||
acc["password"] = encrypt(body["password"])
|
||||
new_accounts = [acc] + (accounts[1:] if len(accounts) > 1 else [])
|
||||
_save_caldav_accounts(owner, new_accounts)
|
||||
return {"ok": True}
|
||||
|
||||
# ── CalDAV multi-account CRUD ─────────────────────────────────────────────
|
||||
|
||||
@router.get("/config/accounts")
|
||||
async def list_caldav_accounts(request: Request):
|
||||
"""Return all configured CalDAV accounts (passwords never returned)."""
|
||||
owner = _require_user(request)
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
safe = []
|
||||
for acc in accounts:
|
||||
pw = acc.get("password") or ""
|
||||
has_pw = False
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
has_pw = bool(decrypt(pw))
|
||||
except Exception:
|
||||
has_pw = bool(pw)
|
||||
safe.append({
|
||||
"id": acc.get("id", ""),
|
||||
"label": acc.get("label", "") or acc.get("url", ""),
|
||||
"url": acc.get("url", "") or "",
|
||||
"username": acc.get("username", "") or "",
|
||||
"has_password": has_pw,
|
||||
})
|
||||
return {"accounts": safe}
|
||||
|
||||
@router.post("/config/accounts")
|
||||
async def add_caldav_account(request: Request):
|
||||
"""Add a new CalDAV account."""
|
||||
import uuid as _uuid
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
url = validate_caldav_url(body.get("url", ""))
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if not body.get("password"):
|
||||
raise HTTPException(400, "Password is required")
|
||||
from src.secret_storage import encrypt
|
||||
new_acc = {
|
||||
"id": str(_uuid.uuid4()),
|
||||
"label": (body.get("label") or "").strip() or "CalDAV",
|
||||
"url": url,
|
||||
"username": (body.get("username") or "").strip(),
|
||||
"password": encrypt(body["password"]),
|
||||
}
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
accounts.append(new_acc)
|
||||
_save_caldav_accounts(owner, accounts)
|
||||
return {"ok": True, "id": new_acc["id"]}
|
||||
|
||||
@router.put("/config/accounts/{account_id}")
|
||||
async def update_caldav_account(account_id: str, request: Request):
|
||||
"""Update an existing CalDAV account by id."""
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
idx = next((i for i, a in enumerate(accounts) if a.get("id") == account_id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, "Account not found")
|
||||
acc = dict(accounts[idx])
|
||||
if body.get("url"):
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
acc["url"] = validate_caldav_url(body["url"])
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if body.get("label") is not None:
|
||||
acc["label"] = (body.get("label") or "").strip() or "CalDAV"
|
||||
if body.get("username") is not None:
|
||||
acc["username"] = (body.get("username") or "").strip()
|
||||
if body.get("password"):
|
||||
from src.secret_storage import encrypt
|
||||
acc["password"] = encrypt(body["password"])
|
||||
accounts[idx] = acc
|
||||
_save_caldav_accounts(owner, accounts)
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/config/accounts/{account_id}")
|
||||
async def delete_caldav_account(account_id: str, request: Request):
|
||||
"""Remove a CalDAV account by id."""
|
||||
owner = _require_user(request)
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
new_accounts = [a for a in accounts if a.get("id") != account_id]
|
||||
if len(new_accounts) == len(accounts):
|
||||
raise HTTPException(404, "Account not found")
|
||||
_save_caldav_accounts(owner, new_accounts)
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/test")
|
||||
async def test_connection(request: Request):
|
||||
"""Actually probe the configured CalDAV server with a PROPFIND
|
||||
request (the same handshake every CalDAV client uses). Accepts
|
||||
an optional {url, username, password} body so the user can test
|
||||
a configuration BEFORE saving it; falls back to the stored
|
||||
creds otherwise. Returns {ok, error?} with a useful message on
|
||||
failure (status code, auth issue, network error)."""
|
||||
"""Probe a CalDAV server with a PROPFIND. Accepts an optional body:
|
||||
{url, username, password} to test before saving, or {account_id} to
|
||||
test an already-saved account. Falls back to the first saved account
|
||||
when nothing is provided."""
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
@@ -620,19 +771,24 @@ def setup_calendar_routes() -> APIRouter:
|
||||
user = (body.get("username") or "").strip()
|
||||
pw = body.get("password") or ""
|
||||
if not (url and user and pw):
|
||||
# Fall back to saved settings for this user.
|
||||
from routes.prefs_routes import _load_for_user
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
url = url or (cfg.get("url") or "")
|
||||
user = user or (cfg.get("username") or "")
|
||||
if not pw:
|
||||
pw = cfg.get("password") or ""
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
# Look up a saved account: by id if supplied, else first account.
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
acc = None
|
||||
if body.get("account_id"):
|
||||
acc = next((a for a in accounts if a.get("id") == body["account_id"]), None)
|
||||
if acc is None and accounts:
|
||||
acc = accounts[0]
|
||||
if acc:
|
||||
url = url or (acc.get("url") or "")
|
||||
user = user or (acc.get("username") or "")
|
||||
if not pw:
|
||||
pw = acc.get("password") or ""
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
if not (url and user and pw):
|
||||
return {"ok": False, "error": "Missing URL, username, or password"}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
@@ -695,6 +851,28 @@ def setup_calendar_routes() -> APIRouter:
|
||||
from src.caldav_sync import sync_caldav
|
||||
return await sync_caldav(owner)
|
||||
|
||||
@router.delete("/calendars/{cal_id}")
|
||||
async def delete_calendar(cal_id: str, request: Request):
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cal = db.query(CalendarCal).filter(
|
||||
CalendarCal.id == cal_id,
|
||||
CalendarCal.owner == owner,
|
||||
).first()
|
||||
if not cal:
|
||||
raise HTTPException(404, "Calendar not found")
|
||||
db.delete(cal)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete calendar %s: %s", cal_id, e)
|
||||
raise HTTPException(500, "Failed to delete calendar")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/calendars")
|
||||
async def list_calendars(request: Request):
|
||||
owner = _require_user(request)
|
||||
@@ -703,7 +881,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
_ensure_default_calendar(db, owner)
|
||||
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
||||
return {"calendars": [
|
||||
{"name": c.name, "href": c.id, "color": c.color}
|
||||
{"name": c.name, "href": c.id, "color": c.color, "source": c.source}
|
||||
for c in cals
|
||||
]}
|
||||
except HTTPException:
|
||||
@@ -766,8 +944,12 @@ def setup_calendar_routes() -> APIRouter:
|
||||
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
||||
|
||||
# Sort by occurrence start time for consistent frontend ordering.
|
||||
truncated = any(e.get("truncated") for e in expanded)
|
||||
expanded.sort(key=lambda d: d["dtstart"])
|
||||
return {"events": expanded}
|
||||
response: dict = {"events": expanded}
|
||||
if truncated:
|
||||
response["truncated"] = True
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -988,9 +1170,9 @@ def setup_calendar_routes() -> APIRouter:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 10 MB hard cap on ICS upload. Loading the whole file into memory is
|
||||
# unavoidable with python-icalendar, so an unbounded upload would OOM.
|
||||
_ICS_MAX_BYTES = 10 * 1024 * 1024
|
||||
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
|
||||
# file into memory is unavoidable with python-icalendar, so an unbounded
|
||||
# upload would OOM.
|
||||
|
||||
@router.post("/import")
|
||||
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
||||
@@ -1000,7 +1182,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
content = await read_upload_limited(file, _ICS_MAX_BYTES, "ICS file")
|
||||
content = await read_upload_limited(file, ICS_MAX_BYTES, "ICS file")
|
||||
try:
|
||||
cal_data = iCal.from_ical(content)
|
||||
except Exception as e:
|
||||
@@ -1168,11 +1350,14 @@ def setup_calendar_routes() -> APIRouter:
|
||||
lines.append("END:VCALENDAR")
|
||||
|
||||
ics_data = "\r\n".join(lines)
|
||||
safe_name = cal.name.replace(" ", "_").replace("/", "_")
|
||||
download_name = _safe_ics_filename(cal.name)
|
||||
return Response(
|
||||
content=ics_data,
|
||||
media_type="text/calendar",
|
||||
headers={"Content-Disposition": f'attachment; filename="{safe_name}.ics"'},
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_name}"',
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1194,7 +1379,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
||||
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
||||
"""
|
||||
_require_user(request)
|
||||
owner = _require_user(request)
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
from src.text_helpers import strip_think
|
||||
@@ -1220,9 +1405,9 @@ def setup_calendar_routes() -> APIRouter:
|
||||
if tz_hint:
|
||||
set_user_tz_name(tz_hint)
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||
if not url or not model:
|
||||
return {"ok": False, "error": "No LLM endpoint configured"}
|
||||
|
||||
|
||||
+130
-38
@@ -75,7 +75,7 @@ def _enforce_chat_privileges(request, sess) -> None:
|
||||
allowlist, or HTTPException(429) if the user has hit their daily message
|
||||
cap. No-op for unauthenticated callers or when auth_manager is absent
|
||||
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
||||
which means empty allowed_models / zero cap → no-op for them.
|
||||
which means unrestricted allowed_models / zero cap -> no-op for them.
|
||||
"""
|
||||
try:
|
||||
user = get_current_user(request)
|
||||
@@ -88,8 +88,18 @@ def _enforce_chat_privileges(request, sess) -> None:
|
||||
return
|
||||
|
||||
privs = auth_manager.get_privileges(user) or {}
|
||||
allowed = privs.get("allowed_models") or []
|
||||
if allowed and sess.model and sess.model not in allowed:
|
||||
|
||||
# Explicit "block everything" sentinel takes precedence over the
|
||||
# allowlist — it's the only way to distinguish "user clicked [None]"
|
||||
# (block all) from "user clicked [All]" (no restriction), since both
|
||||
# otherwise produce an empty `allowed_models` list.
|
||||
if privs.get("block_all_models"):
|
||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||
|
||||
allowed_raw = privs.get("allowed_models")
|
||||
allowed = allowed_raw if isinstance(allowed_raw, list) else []
|
||||
restricted = bool(privs.get("allowed_models_restricted")) or bool(allowed)
|
||||
if restricted and sess.model and sess.model not in allowed:
|
||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||
|
||||
cap = int(privs.get("max_messages_per_day") or 0)
|
||||
@@ -194,14 +204,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||
"""
|
||||
import requests as _req
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
||||
from src.endpoint_resolver import (
|
||||
build_chat_url,
|
||||
build_headers,
|
||||
build_models_url,
|
||||
normalize_base,
|
||||
resolve_endpoint_runtime,
|
||||
)
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
|
||||
current_url = sess.endpoint_url or ""
|
||||
owner = getattr(sess, "owner", None)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(
|
||||
q = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True
|
||||
).all()
|
||||
)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -210,26 +232,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
# Skip current endpoint
|
||||
if current_url and base in current_url:
|
||||
continue
|
||||
# Quick ping
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
if ping_url:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
models = json.loads(ep.cached_models or "[]")
|
||||
if not models:
|
||||
continue
|
||||
# Found a working endpoint — update session
|
||||
new_model = models[0]
|
||||
chat_url = build_chat_url(base)
|
||||
new_headers = build_headers(ep.api_key, base)
|
||||
new_headers = build_headers(api_key, base)
|
||||
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||
|
||||
sess.model = new_model
|
||||
sess.endpoint_url = chat_url
|
||||
@@ -241,7 +270,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||
"model": new_model,
|
||||
"endpoint_url": chat_url,
|
||||
"headers": json.dumps(new_headers),
|
||||
"headers": persisted_headers,
|
||||
})
|
||||
_db.commit()
|
||||
finally:
|
||||
@@ -275,11 +304,16 @@ def extract_preset(chat_handler, preset_id) -> PresetInfo:
|
||||
async def preprocess(
|
||||
chat_handler, message, att_ids, sess,
|
||||
auto_opened_docs: Optional[list] = None,
|
||||
allow_tool_preprocessing: bool = True,
|
||||
) -> PreprocessedMessage:
|
||||
"""Run chat_handler.preprocess_message and wrap the result."""
|
||||
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
||||
await chat_handler.preprocess_message(
|
||||
message, att_ids, sess, auto_opened_docs=auto_opened_docs
|
||||
message,
|
||||
att_ids,
|
||||
sess,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
)
|
||||
return PreprocessedMessage(
|
||||
@@ -329,16 +363,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _has_auth_keys(headers) -> bool:
|
||||
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||
return isinstance(headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||
)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
)
|
||||
if has_auth:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
has_auth = _has_auth_keys(sess.headers)
|
||||
if has_auth and not is_chatgpt_subscription:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||
db = SessionLocal()
|
||||
try:
|
||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||
@@ -354,10 +398,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
for ep in q.all():
|
||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||
continue
|
||||
if not ep.api_key:
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||
return
|
||||
if not api_key:
|
||||
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||
return
|
||||
sess.headers = build_headers(api_key, base)
|
||||
if is_chatgpt_subscription:
|
||||
# The bearer is short-lived and re-resolved per request, so it
|
||||
# stays request-local and is never written to the plaintext
|
||||
# sessions.headers column. Proactively strip any bearer an
|
||||
# older code path may have persisted so it does not linger.
|
||||
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||
stored = stale_q.first()
|
||||
if stored is not None and _has_auth_keys(stored.headers):
|
||||
stale_q.update({"headers": {}})
|
||||
db.commit()
|
||||
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||
return
|
||||
base = normalize_base(ep.base_url or "")
|
||||
sess.headers = build_headers(ep.api_key, base)
|
||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
update_q = update_q.filter(DBSession.owner == owner)
|
||||
@@ -401,7 +465,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
owner = getattr(sess, "owner", None)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for ep in endpoints:
|
||||
try:
|
||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||
@@ -448,6 +517,7 @@ async def build_chat_context(
|
||||
webhook_manager=None,
|
||||
use_enhanced_message: bool = False,
|
||||
agent_mode: bool = False,
|
||||
allow_tool_preprocessing: bool = True,
|
||||
) -> ChatContext:
|
||||
"""Build the full context (preface + messages) for an LLM call.
|
||||
|
||||
@@ -465,6 +535,7 @@ async def build_chat_context(
|
||||
preprocessed = await preprocess(
|
||||
chat_handler, message, att_ids or [], sess,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
# Add user message to history
|
||||
@@ -483,6 +554,9 @@ async def build_chat_context(
|
||||
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
||||
# When off, the "Available skills" index is not added to the prompt.
|
||||
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
||||
if not allow_tool_preprocessing:
|
||||
mem_enabled = False
|
||||
skills_enabled = False
|
||||
logger.debug(
|
||||
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
||||
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
||||
@@ -490,11 +564,11 @@ async def build_chat_context(
|
||||
|
||||
# Use RAG?
|
||||
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
||||
if incognito:
|
||||
if incognito or not allow_tool_preprocessing:
|
||||
use_rag_val = False
|
||||
|
||||
# If pre-fetched search context was provided (compare mode), skip live web search
|
||||
skip_web = bool(search_context)
|
||||
skip_web = bool(search_context) or not allow_tool_preprocessing
|
||||
|
||||
# Build context preface
|
||||
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
||||
@@ -521,7 +595,7 @@ async def build_chat_context(
|
||||
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
||||
|
||||
# Inject pre-fetched search context (compare mode)
|
||||
if search_context:
|
||||
if search_context and allow_tool_preprocessing:
|
||||
preface.append(untrusted_context_message("prefetched search context", search_context))
|
||||
|
||||
# YouTube transcripts
|
||||
@@ -530,7 +604,11 @@ async def build_chat_context(
|
||||
|
||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||
# re-hit slow local /models endpoints on every participant turn.
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||
sess.endpoint_url,
|
||||
sess.model,
|
||||
owner=getattr(sess, "owner", None),
|
||||
)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
@@ -539,7 +617,7 @@ async def build_chat_context(
|
||||
|
||||
# Auto-compact
|
||||
messages, context_length, was_compacted = await maybe_compact(
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||
)
|
||||
messages = trim_for_context(messages, context_length)
|
||||
|
||||
@@ -772,7 +850,19 @@ def save_assistant_response(
|
||||
):
|
||||
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
||||
md = dict(last_metrics) if last_metrics else {}
|
||||
md["model"] = sess.model
|
||||
def _model_value(value) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
return value.strip()
|
||||
|
||||
requested_model = _model_value(md.get("requested_model") or md.get("selected_model") or getattr(sess, "model", ""))
|
||||
actual_model = _model_value(md.get("model") or md.get("actual_model") or requested_model)
|
||||
if requested_model:
|
||||
md["requested_model"] = requested_model
|
||||
if actual_model:
|
||||
md["model"] = actual_model
|
||||
if character_name:
|
||||
md["character_name"] = character_name
|
||||
if web_sources:
|
||||
@@ -841,12 +931,13 @@ def run_post_response_tasks(
|
||||
skills_manager=None,
|
||||
owner: str = None,
|
||||
extract_skills: bool = True,
|
||||
allow_background_extraction: bool = True,
|
||||
):
|
||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||
if allow_background_extraction and not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||
from services.memory.memory_extractor import extract_and_store
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
@@ -873,6 +964,7 @@ def run_post_response_tasks(
|
||||
)
|
||||
if (
|
||||
extract_skills
|
||||
and allow_background_extraction
|
||||
and auto_skills_enabled
|
||||
and not incognito
|
||||
and not compare_mode
|
||||
|
||||
+206
-72
@@ -20,6 +20,7 @@ from src import agent_runs
|
||||
from src.model_context import estimate_tokens
|
||||
from src.chat_helpers import coerce_message_and_session
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||
from src.session_search import search_session_messages
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
@@ -39,6 +40,7 @@ from routes.chat_helpers import (
|
||||
_enforce_chat_privileges,
|
||||
)
|
||||
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
||||
from src.tool_policy import build_effective_tool_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -167,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
Covers the window between endpoint setup and the first chat send: the
|
||||
picker showed a model in the dropdown but the session record never got
|
||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||
Without this, we'd POST the upstream with model="" and get a generic
|
||||
401/503 instead of using the model the user already picked.
|
||||
|
||||
Returns True iff sess.model was repaired.
|
||||
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||
"""
|
||||
if getattr(sess, "model", None):
|
||||
return False
|
||||
current_model = (getattr(sess, "model", "") or "").strip()
|
||||
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||
is_chatgpt_subscription = False
|
||||
if current_model:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||
if not is_chatgpt_subscription:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Prefer the endpoint whose base URL matches the session — we know the
|
||||
@@ -192,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
break
|
||||
if not ep:
|
||||
return False
|
||||
if not is_chatgpt_subscription:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||
except Exception:
|
||||
cached = []
|
||||
if not cached:
|
||||
visible = []
|
||||
else:
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if is_chatgpt_subscription:
|
||||
live_models = []
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
if api_key:
|
||||
live_models = fetch_available_models(api_key)
|
||||
if live_models:
|
||||
ep.cached_models = json.dumps(live_models)
|
||||
db.commit()
|
||||
except Exception:
|
||||
live_models = []
|
||||
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||
# Cached rows are only trusted above to avoid revalidating a model
|
||||
# that is already present in the visible picker list.
|
||||
cached = live_models
|
||||
if not cached:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
if not visible:
|
||||
return False
|
||||
model = visible[0]
|
||||
@@ -211,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
# Persist so the next request, websocket reconnect, or page reload
|
||||
# picks up the same model (we'd otherwise re-pick on every send
|
||||
# and silently switch on the user if the cached order shifts).
|
||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
||||
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||
db_session = db_session_q.first()
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.model = model
|
||||
logger.info(
|
||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
||||
"Recovered session model for %s — picked %r from endpoint %s",
|
||||
session_id, model, ep.id,
|
||||
)
|
||||
return True
|
||||
@@ -304,8 +351,13 @@ def setup_chat_routes(
|
||||
# non-streaming path can't be used to bypass).
|
||||
_enforce_chat_privileges(request, sess)
|
||||
|
||||
tool_policy = build_effective_tool_policy(last_user_message=message)
|
||||
allow_tool_preprocessing = not tool_policy.block_all_tool_calls
|
||||
|
||||
# Inline memory command
|
||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||
memory_response = None
|
||||
if not tool_policy.blocks("manage_memory"):
|
||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||
if memory_response:
|
||||
return {"response": memory_response}
|
||||
|
||||
@@ -319,10 +371,15 @@ def setup_chat_routes(
|
||||
use_web=use_web,
|
||||
time_filter=time_filter,
|
||||
webhook_manager=webhook_manager,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
# Research injection
|
||||
if use_research:
|
||||
research_blocked_by_policy = (
|
||||
tool_policy.blocks("trigger_research")
|
||||
or tool_policy.blocks("manage_research")
|
||||
)
|
||||
if use_research and not research_blocked_by_policy:
|
||||
try:
|
||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||
research_ctx = await research_handler.call_research_service(
|
||||
@@ -357,6 +414,7 @@ def setup_chat_routes(
|
||||
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||
character_name=ctx.preset.character_name,
|
||||
owner=ctx.user,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
|
||||
return {"response": reply}
|
||||
@@ -394,6 +452,7 @@ def setup_chat_routes(
|
||||
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
|
||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||
# it's a real directory; ignore (no confinement) otherwise.
|
||||
@@ -401,6 +460,17 @@ def setup_chat_routes(
|
||||
if workspace:
|
||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||
if plan_mode:
|
||||
chat_mode = "agent"
|
||||
# An approved plan being EXECUTED: the frontend sends the checklist back
|
||||
# on each turn so we can pin it in context. This way a long plan on a
|
||||
# weak model survives history truncation — the agent can always re-read
|
||||
# the plan. Ignored while still proposing (plan_mode on). Capped so a
|
||||
# huge plan can't blow the prompt.
|
||||
approved_plan = ""
|
||||
if not plan_mode:
|
||||
approved_plan = (form_data.get("approved_plan") or "").strip()[:8192]
|
||||
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
||||
# below). Skill extraction should only learn from real agent sessions,
|
||||
# not chats we quietly promoted for a notes/calendar intent.
|
||||
@@ -479,11 +549,6 @@ def setup_chat_routes(
|
||||
do_research = True
|
||||
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
||||
|
||||
# Persist session mode (research > agent > chat)
|
||||
_effective_mode = 'research' if do_research else (chat_mode or 'chat')
|
||||
if _effective_mode in ('agent', 'research', 'chat'):
|
||||
set_session_mode(session, _effective_mode)
|
||||
|
||||
att_ids = []
|
||||
if body and isinstance(body.get("attachments"), list):
|
||||
att_ids = [str(x) for x in body["attachments"]]
|
||||
@@ -494,6 +559,10 @@ def setup_chat_routes(
|
||||
pass
|
||||
|
||||
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
||||
pre_context_tool_policy = build_effective_tool_policy(
|
||||
last_user_message=message,
|
||||
)
|
||||
allow_tool_preprocessing = not pre_context_tool_policy.block_all_tool_calls
|
||||
|
||||
# Build shared context (stream path uses enhanced_message for context preface)
|
||||
ctx = await build_chat_context(
|
||||
@@ -515,6 +584,7 @@ def setup_chat_routes(
|
||||
# manage_skills (agent mode). In plain chat or incognito the
|
||||
# index would be useless / unwanted noise.
|
||||
agent_mode=(chat_mode == "agent"),
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
_research_flags = {"do": do_research} # Mutable container for generator scope
|
||||
@@ -659,6 +729,32 @@ def setup_chat_routes(
|
||||
if chat_mode == 'chat':
|
||||
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
||||
|
||||
# Plan mode: investigate read-only, propose a plan, don't mutate. Block
|
||||
# every tool not on the read-only allowlist. (stream_agent_loop enforces
|
||||
# this again + drops MCP, so this is belt-and-suspenders.)
|
||||
if plan_mode:
|
||||
from src.tool_security import plan_mode_disabled_tools
|
||||
disabled_tools.update(plan_mode_disabled_tools())
|
||||
|
||||
tool_policy = build_effective_tool_policy(
|
||||
disabled_tools=disabled_tools,
|
||||
last_user_message=message,
|
||||
)
|
||||
disabled_tools = tool_policy.all_disabled_names()
|
||||
research_blocked_by_policy = bool(
|
||||
tool_policy.blocks("trigger_research")
|
||||
or tool_policy.blocks("manage_research")
|
||||
)
|
||||
effective_do_research = bool(
|
||||
do_research and _research_flags["do"] and not research_blocked_by_policy
|
||||
)
|
||||
|
||||
# Persist session mode after policy/privilege gates so blocked research
|
||||
# turns remain ordinary chat/agent streams and saved messages.
|
||||
_effective_mode = 'research' if effective_do_research else (chat_mode or 'chat')
|
||||
if _effective_mode in ('agent', 'research', 'chat'):
|
||||
set_session_mode(session, _effective_mode)
|
||||
|
||||
async def stream_with_save() -> AsyncGenerator[str, None]:
|
||||
# _effective_mode is read-only here; closure captures it from
|
||||
# the outer scope. (Was `nonlocal` but never reassigned.)
|
||||
@@ -666,7 +762,7 @@ def setup_chat_routes(
|
||||
web_sources = ctx.web_sources
|
||||
|
||||
# Register active stream for partial-save safety net
|
||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": do_research, "mode": _effective_mode}
|
||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||
|
||||
if ctx.preprocessed.attachment_meta:
|
||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||
@@ -690,7 +786,7 @@ def setup_chat_routes(
|
||||
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
||||
|
||||
# Run research as a background task (survives page refresh)
|
||||
if do_research and _research_flags["do"]:
|
||||
if effective_do_research:
|
||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
||||
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
||||
@@ -829,7 +925,7 @@ def setup_chat_routes(
|
||||
_fallback_candidates = []
|
||||
|
||||
# Send model name early so the frontend can show it during streaming
|
||||
_model_suffix = "Research" if do_research else None
|
||||
_model_suffix = "Research" if effective_do_research else None
|
||||
_model_info = {"type": "model_info", "model": sess.model}
|
||||
if _model_suffix:
|
||||
_model_info["suffix"] = _model_suffix
|
||||
@@ -839,6 +935,12 @@ def setup_chat_routes(
|
||||
|
||||
if _is_image_generation_session(sess, owner=_user):
|
||||
from src.settings import get_setting
|
||||
if tool_policy.blocks("generate_image"):
|
||||
_blocked_msg = tool_policy.reason_for("generate_image")
|
||||
yield f'data: {json.dumps({"delta": _blocked_msg})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
_active_streams.pop(session, None)
|
||||
return
|
||||
if not get_setting("image_gen_enabled", True):
|
||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -873,6 +975,8 @@ def setup_chat_routes(
|
||||
elif chat_mode == "chat":
|
||||
_chat_start = time.time()
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
_requested_model = sess.model
|
||||
_actual_model = None
|
||||
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
||||
try:
|
||||
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
||||
@@ -905,10 +1009,18 @@ def setup_chat_routes(
|
||||
# Selected model failed; a fallback answered.
|
||||
# Forward the notice and remember the real model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
_actual_model = _actual_model or _answered_by
|
||||
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||
yield chunk
|
||||
elif data.get("type") == "model_actual":
|
||||
_actual_model = data.get("model") or _actual_model
|
||||
data["requested_model"] = _requested_model
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
elif data.get("type") == "usage":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
_reported_model = last_metrics.get("model")
|
||||
last_metrics["requested_model"] = _requested_model
|
||||
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||
if ctx.context_length and last_metrics.get("input_tokens"):
|
||||
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
||||
last_metrics["context_percent"] = pct
|
||||
@@ -945,7 +1057,8 @@ def setup_chat_routes(
|
||||
"tokens_per_second": _tps,
|
||||
"context_percent": _ctx_pct,
|
||||
"context_length": ctx.context_length,
|
||||
"model": sess.model,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
"usage_source": "estimated",
|
||||
}
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
@@ -957,7 +1070,7 @@ def setup_chat_routes(
|
||||
rag_sources=ctx.rag_sources,
|
||||
research_sources=research_sources,
|
||||
used_memories=ctx.used_memories,
|
||||
do_research=do_research,
|
||||
do_research=effective_do_research,
|
||||
incognito=incognito,
|
||||
)
|
||||
if _saved_id:
|
||||
@@ -967,14 +1080,22 @@ def setup_chat_routes(
|
||||
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||
incognito=incognito, compare_mode=compare_mode,
|
||||
character_name=ctx.preset.character_name,
|
||||
owner=_user,
|
||||
owner=_user,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
_stream_set(session, status="done")
|
||||
yield chunk
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
if full_response:
|
||||
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
||||
_stopped_content, _stopped_md = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
||||
_stopped_content, _stopped_md = clean_thinking_for_save(
|
||||
full_response,
|
||||
{
|
||||
"stopped": True,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
},
|
||||
)
|
||||
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
||||
if not incognito:
|
||||
session_manager.save_sessions()
|
||||
@@ -986,6 +1107,8 @@ def setup_chat_routes(
|
||||
_agent_rounds = 0
|
||||
_agent_tool_calls = 0
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
_requested_model = sess.model
|
||||
_actual_model = None
|
||||
try:
|
||||
from src.settings import get_setting
|
||||
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
||||
@@ -1012,9 +1135,12 @@ def setup_chat_routes(
|
||||
active_document=active_doc,
|
||||
session_id=session,
|
||||
disabled_tools=disabled_tools if disabled_tools else None,
|
||||
tool_policy=tool_policy,
|
||||
owner=_user,
|
||||
fallbacks=_fallback_candidates,
|
||||
workspace=workspace or None,
|
||||
plan_mode=plan_mode,
|
||||
approved_plan=approved_plan or None,
|
||||
):
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
try:
|
||||
@@ -1035,6 +1161,8 @@ def setup_chat_routes(
|
||||
"doc_stream_open", "doc_stream_delta",
|
||||
"doc_update", "doc_suggestions", "ui_control",
|
||||
"rounds_exhausted",
|
||||
"ask_user",
|
||||
"plan_update",
|
||||
):
|
||||
if data.get("type") == "agent_step":
|
||||
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
||||
@@ -1047,10 +1175,18 @@ def setup_chat_routes(
|
||||
# model so metrics reflect it, not the masked
|
||||
# selected model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
_actual_model = _actual_model or _answered_by
|
||||
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||
yield chunk
|
||||
elif data.get("type") == "model_actual":
|
||||
_actual_model = data.get("model") or _actual_model
|
||||
data["requested_model"] = _requested_model
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
elif data.get("type") == "metrics":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
_reported_model = last_metrics.get("model")
|
||||
last_metrics["requested_model"] = last_metrics.get("requested_model") or _requested_model
|
||||
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
except json.JSONDecodeError:
|
||||
yield chunk
|
||||
@@ -1078,6 +1214,7 @@ def setup_chat_routes(
|
||||
skills_manager=skills_manager,
|
||||
owner=_user,
|
||||
extract_skills=user_requested_agent,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
_stream_set(session, status="done")
|
||||
yield chunk
|
||||
@@ -1091,7 +1228,14 @@ def setup_chat_routes(
|
||||
try:
|
||||
if full_response:
|
||||
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(
|
||||
full_response,
|
||||
{
|
||||
"stopped": True,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
},
|
||||
)
|
||||
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
||||
if not incognito:
|
||||
session_manager.save_sessions()
|
||||
@@ -1110,11 +1254,30 @@ def setup_chat_routes(
|
||||
finally:
|
||||
_active_streams.pop(session, None)
|
||||
|
||||
# Run the stream as a DETACHED background task so it survives the client
|
||||
# closing the tab / navigating away (true terminal-agent behavior). The
|
||||
# SSE response just subscribes (replay buffered output + live); dropping
|
||||
# the SSE only removes a subscriber — the run keeps going and saves the
|
||||
# assistant message on completion regardless. Reconnect via /api/chat/resume.
|
||||
# Compare panes are short-lived, single-shot generations whose sessions
|
||||
# exist only to drive that one pane — there's nothing to "resume" and
|
||||
# the user expects the pane's Stop button (which aborts the fetch,
|
||||
# closing this SSE) to promptly cancel the upstream LLM call. Detaching
|
||||
# them would keep burning upstream tokens/compute after the pane is
|
||||
# stopped or the comparison is abandoned, and would surface a stale
|
||||
# "still streaming" /resume target for a session nobody will revisit.
|
||||
#
|
||||
# So: stream them directly (no agent_runs wrapping). Starlette cancels
|
||||
# the underlying async generator (raising CancelledError/GeneratorExit
|
||||
# inside it) as soon as it notices the client disconnected — which the
|
||||
# mode-specific except blocks above already handle by saving the
|
||||
# partial response exactly once. This stops the upstream call promptly
|
||||
# without waiting on the next streamed chunk.
|
||||
#
|
||||
# Normal chat/agent streams keep the DETACHED behavior below: they
|
||||
# survive the client closing the tab / navigating away (true
|
||||
# terminal-agent semantics). The SSE response just subscribes (replay
|
||||
# buffered output + live); dropping the SSE only removes a subscriber —
|
||||
# the run keeps going and saves the assistant message on completion
|
||||
# regardless. Reconnect via /api/chat/resume.
|
||||
if compare_mode:
|
||||
return StreamingResponse(_safe_stream(), media_type="text/event-stream")
|
||||
|
||||
agent_runs.start(session, _safe_stream())
|
||||
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
||||
|
||||
@@ -1185,45 +1348,16 @@ def setup_chat_routes(
|
||||
return []
|
||||
|
||||
_user = get_current_user(request)
|
||||
query_term = q.strip()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
base_q = (
|
||||
db.query(DBChatMessage, DBSession.name)
|
||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
||||
.filter(
|
||||
DBSession.archived == False,
|
||||
DBChatMessage.content.ilike(f"%{query_term}%"),
|
||||
DBChatMessage.role.in_(["user", "assistant"]),
|
||||
)
|
||||
return [
|
||||
result.to_dict()
|
||||
for result in search_session_messages(
|
||||
q,
|
||||
limit=limit,
|
||||
owner=_user,
|
||||
restrict_owner=_user is not None,
|
||||
include_legacy_owner=False,
|
||||
)
|
||||
if _user:
|
||||
base_q = base_q.filter(DBSession.owner == _user)
|
||||
rows = base_q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
|
||||
|
||||
results = []
|
||||
for msg, session_name in rows:
|
||||
content = msg.content or ""
|
||||
lower_content = content.lower()
|
||||
idx = lower_content.find(query_term.lower())
|
||||
if idx == -1:
|
||||
snippet = content[:120]
|
||||
else:
|
||||
start = max(0, idx - 50)
|
||||
end = min(len(content), idx + len(query_term) + 50)
|
||||
snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
|
||||
|
||||
results.append({
|
||||
"session_id": msg.session_id,
|
||||
"session_name": session_name or "Untitled",
|
||||
"role": msg.role,
|
||||
"content_snippet": snippet,
|
||||
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
|
||||
})
|
||||
|
||||
return results
|
||||
finally:
|
||||
db.close()
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""ChatGPT Subscription device-flow setup routes."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import chatgpt_subscription
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
|
||||
access_token = tokens.get("access_token")
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
if not access_token or not refresh_token:
|
||||
raise ValueError("ChatGPT token response was missing access_token or refresh_token")
|
||||
|
||||
base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
|
||||
models = chatgpt_subscription.fetch_available_models(access_token)
|
||||
if not models:
|
||||
raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
auth = (
|
||||
db.query(ProviderAuthSession)
|
||||
.filter(
|
||||
ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
ProviderAuthSession.owner == owner,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if auth is None:
|
||||
auth = ProviderAuthSession(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
owner=owner,
|
||||
label="ChatGPT Subscription",
|
||||
base_url=base,
|
||||
auth_mode="chatgpt",
|
||||
)
|
||||
db.add(auth)
|
||||
auth.base_url = base
|
||||
auth.access_token = access_token
|
||||
auth.refresh_token = refresh_token
|
||||
auth.last_refresh = utcnow_naive()
|
||||
auth.auth_mode = "chatgpt"
|
||||
|
||||
ep = (
|
||||
db.query(ModelEndpoint)
|
||||
.filter(
|
||||
ModelEndpoint.base_url == base,
|
||||
ModelEndpoint.provider_auth_id == auth.id,
|
||||
ModelEndpoint.owner == owner,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if ep is None:
|
||||
ep = ModelEndpoint(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
name="ChatGPT Subscription",
|
||||
base_url=base,
|
||||
model_type="llm",
|
||||
endpoint_kind="api",
|
||||
owner=owner,
|
||||
)
|
||||
db.add(ep)
|
||||
ep.name = "ChatGPT Subscription"
|
||||
ep.base_url = base
|
||||
ep.api_key = None
|
||||
ep.provider_auth_id = auth.id
|
||||
ep.is_enabled = True
|
||||
ep.supports_tools = False
|
||||
ep.model_type = "llm"
|
||||
ep.endpoint_kind = "api"
|
||||
ep.model_refresh_mode = "manual"
|
||||
ep.cached_models = json.dumps(models)
|
||||
db.commit()
|
||||
result = {
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"models": models,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
from routes.model_routes import _invalidate_models_cache
|
||||
|
||||
_invalidate_models_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
|
||||
try:
|
||||
data = chatgpt_subscription.request_device_code()
|
||||
except Exception as exc:
|
||||
raise chatgpt_subscription.to_http_exception(exc)
|
||||
|
||||
device_auth_id = data.get("device_auth_id")
|
||||
user_code = data.get("user_code")
|
||||
if not device_auth_id or not user_code:
|
||||
raise HTTPException(502, "ChatGPT did not return a complete device code")
|
||||
verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"
|
||||
return DeviceFlowStart(
|
||||
pending={
|
||||
"device_auth_id": device_auth_id,
|
||||
"user_code": user_code,
|
||||
"owner": get_current_user(request) or None,
|
||||
},
|
||||
response={
|
||||
"user_code": user_code,
|
||||
"verification_uri": verification_uri,
|
||||
},
|
||||
interval=int(data.get("interval") or 5),
|
||||
expires_in=int(data.get("expires_in") or 900),
|
||||
)
|
||||
|
||||
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||
try:
|
||||
data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
|
||||
except Exception as exc:
|
||||
logger.debug("ChatGPT device poll failed: %s", exc)
|
||||
return DeviceFlowPoll.pending(str(exc))
|
||||
|
||||
authorization_code = data.get("authorization_code")
|
||||
code_verifier = data.get("code_verifier")
|
||||
if authorization_code and code_verifier:
|
||||
try:
|
||||
tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
|
||||
result = _provision_endpoint(tokens, pending["owner"])
|
||||
except Exception as exc:
|
||||
logger.exception("ChatGPT Subscription endpoint provisioning failed")
|
||||
raise chatgpt_subscription.to_http_exception(exc)
|
||||
return DeviceFlowPoll.authorized(result)
|
||||
|
||||
err = data.get("error") or data.get("status")
|
||||
if err in ("authorization_pending", "pending", None):
|
||||
return DeviceFlowPoll.pending()
|
||||
if err == "slow_down":
|
||||
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||
if err in ("expired_token", "access_denied", "denied"):
|
||||
return DeviceFlowPoll.failed(err)
|
||||
return DeviceFlowPoll.pending(err or "unknown")
|
||||
|
||||
|
||||
def setup_chatgpt_subscription_routes():
|
||||
return create_device_flow_router(
|
||||
prefix="/api/chatgpt-subscription",
|
||||
tags=["chatgpt-subscription"],
|
||||
store=_DEVICE_FLOW_STORE,
|
||||
start_flow=_start_device_flow,
|
||||
poll_flow=_poll_device_flow,
|
||||
)
|
||||
+16
-6
@@ -15,8 +15,9 @@ from typing import Any
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from src.auth_helpers import require_authenticated_request, require_user
|
||||
from src.tool_implementations import do_manage_notes
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
|
||||
|
||||
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
|
||||
@@ -41,7 +42,9 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
||||
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
||||
Restores the original value when done. Works for sync and async handlers."""
|
||||
orig = getattr(request.state, "current_user", None)
|
||||
orig_api_token = getattr(request.state, "api_token", None)
|
||||
request.state.current_user = owner
|
||||
request.state.api_token = False
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
@@ -49,6 +52,13 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
||||
return result
|
||||
finally:
|
||||
request.state.current_user = orig
|
||||
if orig_api_token is None:
|
||||
try:
|
||||
delattr(request.state, "api_token")
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
request.state.api_token = orig_api_token
|
||||
|
||||
|
||||
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
||||
@@ -146,7 +156,7 @@ def setup_codex_routes(
|
||||
|
||||
@router.get("/plugin.zip")
|
||||
def plugin_zip(request: Request):
|
||||
require_user(request)
|
||||
require_authenticated_request(request)
|
||||
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
||||
if not root.exists():
|
||||
raise HTTPException(404, "Codex plugin bundle not found")
|
||||
@@ -415,8 +425,8 @@ def setup_codex_routes(
|
||||
|
||||
def _read_cookbook_state() -> dict:
|
||||
from pathlib import Path as _Path
|
||||
import os as _os, json as _json
|
||||
p = _Path(_os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
||||
import json as _json
|
||||
p = _Path(COOKBOOK_STATE_FILE)
|
||||
if not p.exists():
|
||||
return {}
|
||||
try:
|
||||
@@ -724,7 +734,7 @@ def setup_codex_routes(
|
||||
import time as _t, json as _json
|
||||
from core.atomic_io import atomic_write_json
|
||||
from pathlib import Path as _Path
|
||||
cookbook_state_path = _Path("/app/data/cookbook_state.json")
|
||||
cookbook_state_path = _Path(COOKBOOK_STATE_FILE)
|
||||
try:
|
||||
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
@@ -762,7 +772,7 @@ def setup_claude_routes() -> APIRouter:
|
||||
|
||||
@router.get("/plugin.zip")
|
||||
def plugin_zip(request: Request):
|
||||
require_user(request)
|
||||
require_authenticated_request(request)
|
||||
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
||||
# README.md or other bundle metadata into the user's claude config dir.
|
||||
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
||||
|
||||
+110
-22
@@ -12,6 +12,7 @@ import logging
|
||||
from core.database import Comparison, SessionLocal
|
||||
from core.session_manager import SessionManager
|
||||
from src.auth_helpers import get_current_user
|
||||
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner):
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
def _owned_endpoint_by_id(db, endpoint_id, owner):
|
||||
"""ModelEndpoint whose id == `endpoint_id` and is VISIBLE to `owner` (their
|
||||
own rows + legacy null-owner "shared" rows); None otherwise.
|
||||
|
||||
Preferred over _owned_endpoint_by_url for credential resolution: two visible
|
||||
endpoints can share the same base_url but hold DIFFERENT api_keys (e.g. two
|
||||
accounts on the same provider). A base_url-only match returns whichever row
|
||||
sorts first, so it can copy the WRONG owner-scoped key into the [CMP] session.
|
||||
An id pins the exact registered endpoint, so /api/compare/start prefers it and
|
||||
only falls back to URL matching for legacy / admin raw-URL callers. Owner
|
||||
scoping is identical to _owned_endpoint_by_url (a null/empty owner is a no-op).
|
||||
"""
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id)
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
class RecordVoteRequest(BaseModel):
|
||||
prompt: str
|
||||
models: List[str]
|
||||
@@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
prompt: str = Form(...),
|
||||
model_a: str = Form(...),
|
||||
model_b: str = Form(...),
|
||||
endpoint_a: str = Form(...),
|
||||
endpoint_b: str = Form(...),
|
||||
endpoint_a: str = Form(""),
|
||||
endpoint_b: str = Form(""),
|
||||
endpoint_a_id: str = Form(""),
|
||||
endpoint_b_id: str = Form(""),
|
||||
is_blind: str = Form("true"),
|
||||
):
|
||||
"""Create two ephemeral sessions and a comparison record.
|
||||
@@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
Returns the comparison ID and the two session IDs so the client
|
||||
can fire two independent SSE streams to /api/chat_stream.
|
||||
"""
|
||||
user = getattr(request.state, 'current_user', None)
|
||||
comp_id = str(uuid.uuid4())
|
||||
sid_a = str(uuid.uuid4())
|
||||
sid_b = str(uuid.uuid4())
|
||||
user = getattr(request.state, 'current_user', None)
|
||||
|
||||
# Blind mapping: randomly assign left/right
|
||||
blind = str(is_blind).lower() == "true"
|
||||
@@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
# de-anonymizing the comparison before the user votes (issue #1285).
|
||||
slot_name = {session_left: "Model A", session_right: "Model B"}
|
||||
|
||||
# Create ephemeral sessions (prefixed [CMP])
|
||||
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
|
||||
# SECURITY: resolve and validate BOTH endpoints before creating any
|
||||
# session. Compare copies a registered endpoint's Authorization header
|
||||
# into the [CMP] session, so validating one endpoint while creating its
|
||||
# session, then rejecting the other, would leave a partial compare
|
||||
# session behind with that header attached. Doing all the owner-scope
|
||||
# resolution + raw-URL rejection up front means a 403 on either endpoint
|
||||
# aborts the whole request with nothing created and no header copied.
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
|
||||
resolved = []
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for sid, model, endpoint, endpoint_id in [
|
||||
(sid_a, model_a, endpoint_a, endpoint_a_id),
|
||||
(sid_b, model_b, endpoint_b, endpoint_b_id),
|
||||
]:
|
||||
# Prefer an explicit endpoint id: it pins the EXACT registered
|
||||
# endpoint (and its api_key), even when two endpoints visible to
|
||||
# the caller share a base_url with different keys — a URL-only
|
||||
# match would copy whichever row sorts first, i.e. possibly the
|
||||
# wrong key. Fall back to URL resolution only for legacy / admin
|
||||
# raw-URL callers that don't send an id.
|
||||
eid = endpoint_id.strip() if isinstance(endpoint_id, str) else ""
|
||||
if eid:
|
||||
ep = _owned_endpoint_by_id(db, eid, user)
|
||||
if ep is None:
|
||||
# An id the caller can't see (wrong owner / deleted) must
|
||||
# NOT silently fall back to a same-URL row with a different
|
||||
# key — that's exactly the mix-up ids exist to prevent.
|
||||
raise HTTPException(404, "Model endpoint not found")
|
||||
# The id already resolved the endpoint; ignore any raw URL the
|
||||
# caller also sent and dial the stored config instead.
|
||||
endpoint = ep.base_url
|
||||
elif not endpoint:
|
||||
raise HTTPException(
|
||||
422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required"
|
||||
)
|
||||
else:
|
||||
# Resolve the supplied URL to a ModelEndpoint the caller owns
|
||||
# (their own rows + legacy null-owner shared rows), scoped so a
|
||||
# comparison can't borrow another user's private endpoint key.
|
||||
base = normalize_base(endpoint)
|
||||
ep = _owned_endpoint_by_url(db, base, user)
|
||||
# Reject *unregistered* raw URLs for signed-in non-admins; a
|
||||
# matched registered endpoint supplies an id so the caller can
|
||||
# still compare endpoints they own. Blanket-rejecting here (the
|
||||
# earlier `endpoint_id=None` call) locked non-admins out of
|
||||
# compare entirely, since compare resolves endpoints by URL with
|
||||
# no endpoint_id. Mirrors the gallery inpaint/harmonize checks.
|
||||
# Raised here (phase 1), before any session exists.
|
||||
_reject_raw_endpoint_url_for_non_admin(
|
||||
request, user, str(ep.id) if ep is not None else None, endpoint
|
||||
)
|
||||
# Bind the [CMP] session to the RESOLVED endpoint, not the raw
|
||||
# caller-supplied string. When the URL matches a registered
|
||||
# endpoint visible to the caller, use that row's own normalized
|
||||
# base URL (the same value owner scoping + endpoint validation
|
||||
# already vetted) so the session dials exactly where the stored
|
||||
# config points. The raw `endpoint` only survives for callers
|
||||
# allowed to pass one — admins / single-user mode, where
|
||||
# `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep`
|
||||
# is None. Mirrors the registered-endpoint path in session_routes.
|
||||
session_endpoint_url = (
|
||||
build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint
|
||||
)
|
||||
# Headers come only from a matched endpoint's key; None when
|
||||
# `ep` is None (raw admin URL or no match), so a comparison can
|
||||
# never inherit another user's key/headers.
|
||||
headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None
|
||||
resolved.append((sid, model, session_endpoint_url, headers))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Both endpoints validated — only now create the ephemeral [CMP]
|
||||
# sessions and copy any resolved headers.
|
||||
for sid, model, session_endpoint_url, headers in resolved:
|
||||
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
||||
session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=name,
|
||||
endpoint_url=endpoint,
|
||||
endpoint_url=session_endpoint_url,
|
||||
model=model,
|
||||
rag=False,
|
||||
owner=user,
|
||||
)
|
||||
# Copy API key from endpoint config
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
# Find matching endpoint by URL, scoped to the caller so a
|
||||
# comparison can't borrow another user's private endpoint key.
|
||||
base = normalize_base(endpoint)
|
||||
ep = _owned_endpoint_by_url(db, base, user)
|
||||
if ep and ep.api_key:
|
||||
s = session_manager.sessions.get(sid)
|
||||
if s:
|
||||
s.headers = build_headers(ep.api_key, ep.base_url)
|
||||
finally:
|
||||
db.close()
|
||||
if headers:
|
||||
s = session_manager.sessions.get(sid)
|
||||
if s:
|
||||
s.headers = headers
|
||||
|
||||
# Store comparison record
|
||||
db = SessionLocal()
|
||||
@@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
prompt=prompt,
|
||||
model_a=model_a,
|
||||
model_b=model_b,
|
||||
endpoint_a=endpoint_a,
|
||||
endpoint_b=endpoint_b,
|
||||
# Record the URL the session actually dials. For URL callers this
|
||||
# is their raw input; for id-only callers (empty endpoint_a/_b)
|
||||
# fall back to the resolved endpoint URL so the column stays
|
||||
# meaningful and non-null. resolved is in [a, b] order.
|
||||
endpoint_a=endpoint_a or resolved[0][2],
|
||||
endpoint_b=endpoint_b or resolved[1][2],
|
||||
is_blind=blind,
|
||||
blind_mapping=json.dumps(mapping),
|
||||
owner=user,
|
||||
|
||||
+53
-18
@@ -11,20 +11,24 @@ import uuid
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Query, Depends, Response
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
|
||||
from fastapi import APIRouter, Query, Depends, Response, HTTPException
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from core.middleware import require_admin
|
||||
from src.url_safety import check_outbound_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
||||
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
|
||||
from src.constants import DATA_DIR as _DATA_DIR, SETTINGS_FILE as _SETTINGS_FILE, CONTACTS_FILE as _CONTACTS_FILE
|
||||
DATA_DIR = Path(_DATA_DIR)
|
||||
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||
LOCAL_CONTACTS_FILE = Path(_CONTACTS_FILE)
|
||||
|
||||
|
||||
def _load_settings():
|
||||
@@ -53,6 +57,21 @@ def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
|
||||
return bool((cfg.get("url") or "").strip())
|
||||
|
||||
|
||||
def _validate_carddav_url(url: str) -> str:
|
||||
cleaned = (url if isinstance(url, str) else "").strip().rstrip("/")
|
||||
ok, reason = check_outbound_url(
|
||||
cleaned,
|
||||
block_private=os.getenv("CARDDAV_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError(f"Rejected CardDAV URL: {reason}")
|
||||
return cleaned
|
||||
|
||||
|
||||
def _carddav_base_url(cfg: Dict) -> str:
|
||||
return _validate_carddav_url(cfg.get("url") or "")
|
||||
|
||||
|
||||
def _normalize_contact(contact: Dict) -> Dict:
|
||||
emails = []
|
||||
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
||||
@@ -219,14 +238,18 @@ _contact_cache = {"contacts": [], "fetched_at": None}
|
||||
def _abs_url(href: str) -> str:
|
||||
"""Combine a multistatus <href> (an absolute path like
|
||||
/user/contacts/x.vcf) with the configured CardDAV server origin so we
|
||||
get a fully-qualified URL to PUT/DELETE. If href is already absolute
|
||||
(http...), return it as-is."""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
if href.startswith("http://") or href.startswith("https://"):
|
||||
return href
|
||||
get a fully-qualified URL to PUT/DELETE. Absolute hrefs are accepted only
|
||||
for the configured origin; a cross-origin href is treated as a path on the
|
||||
configured server so a malicious CardDAV response cannot redirect later
|
||||
writes/deletes to cloud metadata or another host."""
|
||||
cfg = _get_carddav_config()
|
||||
p = urlparse(cfg["url"])
|
||||
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
|
||||
base = _carddav_base_url(cfg)
|
||||
base_p = urlparse(base)
|
||||
joined = urljoin(base.rstrip("/") + "/", href or "")
|
||||
joined_p = urlparse(joined)
|
||||
if (joined_p.scheme, joined_p.netloc) != (base_p.scheme, base_p.netloc):
|
||||
joined = urlunparse((base_p.scheme, base_p.netloc, joined_p.path or "/", "", joined_p.query, ""))
|
||||
return _validate_carddav_url(joined)
|
||||
|
||||
|
||||
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
||||
@@ -297,6 +320,7 @@ def _fetch_contacts(force=False):
|
||||
return contacts
|
||||
|
||||
try:
|
||||
cfg["url"] = _carddav_base_url(cfg)
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
@@ -353,8 +377,8 @@ def _create_contact(name: str, email: str) -> bool:
|
||||
|
||||
contact_uid = str(uuid.uuid4())
|
||||
vcard = _build_vcard(name, email, contact_uid)
|
||||
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
|
||||
try:
|
||||
url = _carddav_base_url(cfg) + "/" + contact_uid + ".vcf"
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
@@ -382,7 +406,7 @@ def _vcard_url(uid: str) -> str:
|
||||
escape the collection and target an arbitrary CardDAV resource."""
|
||||
from urllib.parse import quote
|
||||
cfg = _get_carddav_config()
|
||||
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
return _carddav_base_url(cfg) + "/" + quote(uid, safe="") + ".vcf"
|
||||
|
||||
|
||||
def _import_vcards(text: str) -> Dict:
|
||||
@@ -413,6 +437,11 @@ def _import_vcards(text: str) -> Dict:
|
||||
if imported:
|
||||
_save_local_contacts(contacts)
|
||||
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
||||
try:
|
||||
base_url = _carddav_base_url(cfg)
|
||||
except ValueError as e:
|
||||
logger.warning("CardDAV import URL rejected: %s", e)
|
||||
return {"imported": 0, "failed": 0, "total": 0, "error": str(e)}
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
# Split into individual cards. re.split drops the BEGIN line, so we
|
||||
# re-add it. Normalize CRLF.
|
||||
@@ -441,7 +470,7 @@ def _import_vcards(text: str) -> Dict:
|
||||
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
||||
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
||||
vcard = block.replace("\n", "\r\n") + "\r\n"
|
||||
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
url = base_url + "/" + quote(uid, safe="") + ".vcf"
|
||||
try:
|
||||
r = httpx.put(
|
||||
url, data=vcard.encode("utf-8"),
|
||||
@@ -601,8 +630,8 @@ def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -
|
||||
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
||||
# Use the real resource href (handles externally-created contacts whose
|
||||
# filename != UID); falls back to the <uid>.vcf guess.
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
url = _resolve_resource_url(uid)
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.put(
|
||||
url,
|
||||
@@ -630,8 +659,8 @@ def _delete_contact(uid: str) -> bool:
|
||||
_save_local_contacts(remaining)
|
||||
return True
|
||||
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
url = _resolve_resource_url(uid)
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.delete(url, auth=auth, timeout=10)
|
||||
if r.status_code in (200, 204):
|
||||
@@ -747,7 +776,13 @@ def setup_contacts_routes():
|
||||
settings = _load_settings()
|
||||
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
||||
if key in data:
|
||||
settings[key] = data[key]
|
||||
if key == "carddav_url" and str(data[key] or "").strip():
|
||||
try:
|
||||
settings[key] = _validate_carddav_url(data[key])
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
else:
|
||||
settings[key] = data[key]
|
||||
_save_settings(settings)
|
||||
# Force re-fetch
|
||||
_contact_cache["fetched_at"] = None
|
||||
|
||||
+312
-14
@@ -11,6 +11,8 @@ import shlex
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -195,6 +197,20 @@ def _pip_install_attempt(pip_cmd: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _pip_command(python_cmd: str) -> str:
|
||||
"""Return a pip command for either a pip executable or a Python executable."""
|
||||
cmd = python_cmd.strip()
|
||||
if " -m pip" in cmd or cmd in {"pip", "pip3"}:
|
||||
return python_cmd
|
||||
if cmd in {"python", "python3", "python.exe"} or cmd.endswith(("/python", "/python3", "\\python.exe")):
|
||||
return f"{python_cmd} -m pip"
|
||||
return python_cmd
|
||||
|
||||
|
||||
def _pip_break_system_packages_check(pip_cmd: str) -> str:
|
||||
return f"{pip_cmd} install --help 2>/dev/null | grep -q -- --break-system-packages"
|
||||
|
||||
|
||||
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
||||
"""Build a bash pip install fallback chain that surfaces errors.
|
||||
|
||||
@@ -206,33 +222,44 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
|
||||
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
||||
pip output appear in the Cookbook log on failure.
|
||||
"""
|
||||
from core.platform_compat import IS_WINDOWS
|
||||
upgrade_flag = " -U" if upgrade else ""
|
||||
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
||||
# contains brackets that bash would treat as a glob, so it must be quoted
|
||||
# before being embedded in the install command. Plain names (e.g.
|
||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||
pkg = shlex.quote(package)
|
||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
||||
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||
# llama-cpp-python source builds are brittle on older distro pip/packaging
|
||||
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
|
||||
# this package is requested so dependency-install tasks are reliable.
|
||||
if "llama-cpp-python" in package:
|
||||
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
|
||||
pip_cmd = _pip_command(python_cmd)
|
||||
base = _pip_install_attempt(f"{pip_cmd} install -q{upgrade_flag} {pkg}")
|
||||
user = _pip_install_attempt(f"{pip_cmd} install --user -q{upgrade_flag} {pkg}")
|
||||
user_break_system = _pip_install_attempt(f"{pip_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||
user_fallback = f"( {user} || {{ {_pip_break_system_packages_check(pip_cmd)} && {user_break_system}; }} )"
|
||||
# Derive the python executable for the venv detection check.
|
||||
# Must use the same interpreter that pip belongs to; hardcoding
|
||||
# python3 breaks when pip lives in a venv that only has "python".
|
||||
if " -m pip" in python_cmd:
|
||||
python_exe = python_cmd.replace(" -m pip", "")
|
||||
elif python_cmd.strip() == "pip":
|
||||
if " -m pip" in pip_cmd:
|
||||
python_exe = pip_cmd.replace(" -m pip", "")
|
||||
elif pip_cmd.strip() == "pip":
|
||||
python_exe = "python"
|
||||
elif python_cmd.strip() == "pip3":
|
||||
elif pip_cmd.strip() == "pip3":
|
||||
python_exe = "python3"
|
||||
else:
|
||||
python_exe = "python3"
|
||||
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv → `&&` tries
|
||||
# --user. When IN a venv `! venv_check` fails → `&&` skips --user and the
|
||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv -> `&&` tries
|
||||
# --user. When IN a venv `! venv_check` fails -> `&&` skips --user and the
|
||||
# group exits non-zero, propagating the base-install failure instead of
|
||||
# masking it as success (the `|| { venv_check || … }` shape from #903
|
||||
# swallowed the exit code because venv_check's exit-0 became the group's
|
||||
# result).
|
||||
return f"{base} || {{ ! {venv_check} && {user}; }}"
|
||||
# result). `--break-system-packages` is only attempted when the active pip
|
||||
# supports it; older pip versions abort with "no such option" otherwise.
|
||||
return f"{base} || {{ ! {venv_check} && {user_fallback}; }}"
|
||||
|
||||
|
||||
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
||||
@@ -263,6 +290,55 @@ def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) ->
|
||||
return shlex.join(stripped)
|
||||
|
||||
|
||||
def _pip_install_command_without_break_system_packages(cmd: str) -> str:
|
||||
try:
|
||||
parts = shlex.split(cmd)
|
||||
except ValueError:
|
||||
return cmd
|
||||
stripped = [part for part in parts if part != "--break-system-packages"]
|
||||
return shlex.join(stripped)
|
||||
|
||||
|
||||
def _pip_install_help_check_from_cmd(cmd: str) -> str | None:
|
||||
try:
|
||||
parts = shlex.split(cmd)
|
||||
except ValueError:
|
||||
return None
|
||||
try:
|
||||
install_index = parts.index("install")
|
||||
except ValueError:
|
||||
return None
|
||||
if install_index <= 0:
|
||||
return None
|
||||
pip_prefix = parts[:install_index]
|
||||
return f"{shlex.join(pip_prefix + ['install', '--help'])} 2>/dev/null | grep -q -- --break-system-packages"
|
||||
|
||||
|
||||
def _append_pip_install_runner_lines(runner_lines: list[str], cmd: str) -> None:
|
||||
"""Append a pip install command, guarding --break-system-packages support.
|
||||
|
||||
The Dependencies UI may submit ``python3 -m pip install --user
|
||||
--break-system-packages ...`` for non-venv installs. That flag is useful on
|
||||
PEP-668-locked distros, but older pip (including Ubuntu 22.04's apt pip in
|
||||
the NVIDIA CUDA base image) aborts with "no such option". Branch at runner
|
||||
time so stale browser JS and remote targets are handled by the server too.
|
||||
"""
|
||||
if "--break-system-packages" not in (cmd or ""):
|
||||
runner_lines.append(cmd)
|
||||
return
|
||||
help_check = _pip_install_help_check_from_cmd(cmd)
|
||||
without_break = _pip_install_command_without_break_system_packages(cmd)
|
||||
if not help_check or without_break == cmd:
|
||||
runner_lines.append(cmd)
|
||||
return
|
||||
runner_lines.append(f"if {help_check}; then")
|
||||
runner_lines.append(f" {cmd}")
|
||||
runner_lines.append("else")
|
||||
runner_lines.append(' echo "[odysseus] pip does not support --break-system-packages; installing without it."')
|
||||
runner_lines.append(f" {without_break}")
|
||||
runner_lines.append("fi")
|
||||
|
||||
|
||||
def _user_shell_path_bootstrap() -> list[str]:
|
||||
return [
|
||||
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
||||
@@ -271,11 +347,14 @@ def _user_shell_path_bootstrap() -> list[str]:
|
||||
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
||||
'fi',
|
||||
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
||||
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
|
||||
]
|
||||
|
||||
|
||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
||||
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
|
||||
"""Build the standalone Python scanner used by /api/model/cached.
|
||||
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
|
||||
"""
|
||||
lines = [
|
||||
"import json, os, re, shutil, subprocess, urllib.request",
|
||||
"models = []",
|
||||
@@ -338,6 +417,15 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||
" if f.name.endswith('.incomplete'): ic = True",
|
||||
" snap = os.path.join(cache, d, 'snapshots')",
|
||||
" # Windows HF cache stores files directly in snapshots/; blobs/ may be empty.",
|
||||
" # Fallback: scan snapshots for real files when blobs yielded nothing.",
|
||||
" if sz == 0 and os.path.isdir(snap):",
|
||||
" for sd in os.listdir(snap):",
|
||||
" sf = os.path.join(snap, sd)",
|
||||
" if not os.path.isdir(sf): continue",
|
||||
" for f in os.scandir(sf):",
|
||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||
" if f.name.endswith('.incomplete'): ic = True",
|
||||
" is_diffusion = False; gguf_files = []",
|
||||
" if os.path.isdir(snap):",
|
||||
" for sd in os.listdir(snap):",
|
||||
@@ -346,6 +434,21 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
||||
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||
"def hf_cache_paths():",
|
||||
" candidates = []",
|
||||
" def add(p):",
|
||||
" if not p: return",
|
||||
" p = os.path.expanduser(p)",
|
||||
" if p not in candidates: candidates.append(p)",
|
||||
" add(os.environ.get('HUGGINGFACE_HUB_CACHE'))",
|
||||
" hf_home = os.environ.get('HF_HOME')",
|
||||
" if hf_home: add(os.path.join(hf_home, 'hub'))",
|
||||
" add('~/.cache/huggingface/hub')",
|
||||
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
||||
" # When HOME is /root, expanduser() misses that persisted cache.",
|
||||
" add('/app/.cache/huggingface/hub')",
|
||||
f" add({add_hf_cache!r})" if add_hf_cache else "",
|
||||
" return candidates",
|
||||
"def scan_dir(p):",
|
||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||
" for d in sorted(os.listdir(p)):",
|
||||
@@ -409,7 +512,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" seen.add(name)",
|
||||
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
||||
" return",
|
||||
"scan_hf(os.path.expanduser('~/.cache/huggingface/hub'))",
|
||||
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
|
||||
"scan_ollama()",
|
||||
"scan_ollama_api()",
|
||||
]
|
||||
@@ -525,6 +628,7 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
||||
# Backticks and raw newlines are never legitimate here.
|
||||
if any(c in v for c in ("`", "\n", "\r")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
|
||||
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
||||
m = _GGUF_PRELUDE_RE.match(v)
|
||||
if m:
|
||||
@@ -533,9 +637,19 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
||||
for part in rest.split("||"):
|
||||
_check_serve_binary(part.strip())
|
||||
return v
|
||||
|
||||
# Otherwise: a single invocation — no shell metacharacters allowed.
|
||||
# Temporarily replace safe $(printf %s ...) expressions with a placeholder
|
||||
# to avoid triggering the metacharacter/command-injection checks.
|
||||
cleaned_v = v
|
||||
printf_matches = list(re.finditer(r"\$\(\s*printf\s+%s\s+([^\n()]*?)\)", v))
|
||||
for match in printf_matches:
|
||||
inner = match.group(1)
|
||||
if not any(c in inner for c in (";", "&&", "||", "$(", "`")):
|
||||
cleaned_v = cleaned_v.replace(match.group(0), "/placeholder/safe/path.gguf")
|
||||
|
||||
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
||||
if any(c in v for c in (";", "&&", "||", "$(")):
|
||||
if any(c in cleaned_v for c in (";", "&&", "||", "$(")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
_check_serve_binary(v)
|
||||
return v
|
||||
@@ -559,6 +673,21 @@ def _append_serve_preflight_exit_lines(runner_lines: list[str], *, keep_shell_op
|
||||
runner_lines.append('fi')
|
||||
|
||||
|
||||
def _append_vllm_linux_preflight_lines(runner_lines: list[str]) -> None:
|
||||
"""Append Linux vLLM readiness lines that identify the runtime being used."""
|
||||
# Keep the user install bin visible for Odysseus-managed `pip install --user`
|
||||
# installs, but then report the actual CLI path so external runtimes are clear.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('ODYSSEUS_VLLM_BIN="$(command -v vllm 2>/dev/null || true)"')
|
||||
runner_lines.append('if [ -z "$ODYSSEUS_VLLM_BIN" ]; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('else')
|
||||
runner_lines.append(' echo "[odysseus] vLLM CLI: $ODYSSEUS_VLLM_BIN"')
|
||||
runner_lines.append(' ODYSSEUS_VLLM_VERSION="$("$ODYSSEUS_VLLM_BIN" --version 2>&1 | head -n 1 || true)"')
|
||||
runner_lines.append(' if [ -n "$ODYSSEUS_VLLM_VERSION" ]; then echo "[odysseus] vLLM version: $ODYSSEUS_VLLM_VERSION"; fi')
|
||||
runner_lines.append('fi')
|
||||
|
||||
def _append_serve_exit_code_lines(
|
||||
runner_lines: list[str],
|
||||
*,
|
||||
@@ -804,3 +933,172 @@ def _ssh_ps(host, script_path, port=None):
|
||||
|
||||
# Windows session dir — stored in user's temp on the remote
|
||||
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
||||
|
||||
|
||||
def _diagnose_serve_output(text: str) -> dict | None:
|
||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||
|
||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||
the agent/tool path the same structured signal so it can retry with an
|
||||
adjusted command instead of guessing from raw tmux output.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
tail = text[-6000:]
|
||||
patterns = [
|
||||
(
|
||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||
"No GPU memory left for KV cache after loading model.",
|
||||
[
|
||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||
"GPU ran out of memory during startup or warmup.",
|
||||
[
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"not divisib|must be divisible|attention heads.*divisible",
|
||||
"Tensor parallel size is incompatible with the model.",
|
||||
[
|
||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||
"Context length is too large for available GPU memory.",
|
||||
[
|
||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||
"Auto tool choice requires an explicit tool call parser.",
|
||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||
),
|
||||
(
|
||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||
"Model requires custom code or newer model support.",
|
||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||
),
|
||||
(
|
||||
r"There is no module or parameter named ['\"]lm_head\.input_scale['\"]|lm_head\.input_scale|weight_scale_2",
|
||||
"vLLM cannot load this ModelOpt LM-head quantized checkpoint with the current runtime.",
|
||||
[
|
||||
{
|
||||
"label": "upgrade vLLM through the environment that provides this CLI, or use a compatible checkpoint",
|
||||
"op": "manual",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||
"vLLM/Transformers kernel package mismatch.",
|
||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||
),
|
||||
(
|
||||
r"Address already in use|bind.*address.*in use",
|
||||
"Port is already in use.",
|
||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||
),
|
||||
(
|
||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||
"No GPUs are visible to the serve process.",
|
||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||
),
|
||||
(
|
||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||
"This machine may have integrated or unsupported graphics only.",
|
||||
[
|
||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||
"vLLM is not installed or not in PATH on this server.",
|
||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||
),
|
||||
(
|
||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||
"SGLang is not installed or not in PATH on this server.",
|
||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||
),
|
||||
(
|
||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||
),
|
||||
(
|
||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||
),
|
||||
(
|
||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||
"Diffusion serving requires PyTorch and diffusers.",
|
||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||
),
|
||||
(
|
||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||
"Model access is gated or unauthorized.",
|
||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||
),
|
||||
]
|
||||
for pattern, message, suggestions in patterns:
|
||||
if re.search(pattern, tail, re.I):
|
||||
return {"message": message, "suggestions": suggestions}
|
||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||
):
|
||||
return {
|
||||
"message": "Python traceback detected during serve startup.",
|
||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def run_ssh_command_async(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
remote_cmd: str,
|
||||
*,
|
||||
timeout: float,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
stdin_data: bytes | None = None,
|
||||
) -> tuple[int, bytes, bytes]:
|
||||
"""Run an ssh command with centralized timeout and stderr/stdout capture.
|
||||
Async version of core.platform_compat.run_ssh_command_sync.
|
||||
"""
|
||||
import asyncio
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*_ssh_exec_argv(
|
||||
remote,
|
||||
ssh_port,
|
||||
remote_cmd=remote_cmd,
|
||||
connect_timeout=connect_timeout,
|
||||
strict_host_key_checking=strict_host_key_checking,
|
||||
),
|
||||
stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(input=stdin_data), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.communicate()
|
||||
raise
|
||||
return proc.returncode or 0, stdout, stderr
|
||||
|
||||
+189
-206
@@ -15,19 +15,26 @@ from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
SSH_PATH_OVERRIDE,
|
||||
NVIDIA_PATH_CANDIDATES,
|
||||
detached_popen_kwargs,
|
||||
find_bash,
|
||||
git_bash_path,
|
||||
kill_process_tree,
|
||||
pid_alive,
|
||||
safe_chmod,
|
||||
which_tool,
|
||||
translate_path,
|
||||
get_wsl_windows_user_profile,
|
||||
)
|
||||
from routes.shell_routes import TMUX_LOG_DIR
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,8 +45,10 @@ from routes.cookbook_helpers import (
|
||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache,
|
||||
_user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
||||
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||
_append_pip_install_runner_lines,
|
||||
_diagnose_serve_output, run_ssh_command_async,
|
||||
ModelDownloadRequest, ServeRequest,
|
||||
)
|
||||
|
||||
@@ -54,7 +63,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
|
||||
|
||||
def setup_cookbook_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["cookbook"])
|
||||
_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
||||
_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
|
||||
|
||||
def _mask_secret(value: str) -> str:
|
||||
if not value:
|
||||
@@ -81,127 +90,6 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
task["payload"].pop("hf_token", None)
|
||||
return state
|
||||
|
||||
def _diagnose_serve_output(text: str) -> dict | None:
|
||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||
|
||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||
the agent/tool path the same structured signal so it can retry with an
|
||||
adjusted command instead of guessing from raw tmux output.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
tail = text[-6000:]
|
||||
patterns = [
|
||||
(
|
||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||
"No GPU memory left for KV cache after loading model.",
|
||||
[
|
||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||
"GPU ran out of memory during startup or warmup.",
|
||||
[
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"not divisib|must be divisible|attention heads.*divisible",
|
||||
"Tensor parallel size is incompatible with the model.",
|
||||
[
|
||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||
"Context length is too large for available GPU memory.",
|
||||
[
|
||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||
"Auto tool choice requires an explicit tool call parser.",
|
||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||
),
|
||||
(
|
||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||
"Model requires custom code or newer model support.",
|
||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||
),
|
||||
(
|
||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||
"vLLM/Transformers kernel package mismatch.",
|
||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||
),
|
||||
(
|
||||
r"Address already in use|bind.*address.*in use",
|
||||
"Port is already in use.",
|
||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||
),
|
||||
(
|
||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||
"No GPUs are visible to the serve process.",
|
||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||
),
|
||||
(
|
||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||
"This machine may have integrated or unsupported graphics only.",
|
||||
[
|
||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||
"vLLM is not installed or not in PATH on this server.",
|
||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||
),
|
||||
(
|
||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||
"SGLang is not installed or not in PATH on this server.",
|
||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||
),
|
||||
(
|
||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||
),
|
||||
(
|
||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||
),
|
||||
(
|
||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||
"Diffusion serving requires PyTorch and diffusers.",
|
||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||
),
|
||||
(
|
||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||
"Model access is gated or unauthorized.",
|
||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||
),
|
||||
]
|
||||
for pattern, message, suggestions in patterns:
|
||||
if re.search(pattern, tail, re.I):
|
||||
return {"message": message, "suggestions": suggestions}
|
||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||
):
|
||||
return {
|
||||
"message": "Python traceback detected during serve startup.",
|
||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||
}
|
||||
return None
|
||||
|
||||
def _state_for_client(state):
|
||||
"""Return cookbook state without raw secrets for browser clients."""
|
||||
_strip_task_secrets(state)
|
||||
@@ -295,6 +183,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
||||
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
||||
|
||||
|
||||
def _needs_binary(cmd: str, binary: str) -> bool:
|
||||
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
||||
|
||||
@@ -355,8 +244,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# POSIX form + shell-quoting so drive paths / spaces survive.
|
||||
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
||||
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
||||
lp = shlex.quote(log_path.as_posix())
|
||||
ip = shlex.quote(inner.as_posix())
|
||||
lp = shlex.quote(git_bash_path(log_path))
|
||||
ip = shlex.quote(git_bash_path(inner))
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"bash {ip} > {lp} 2>&1\n",
|
||||
@@ -472,6 +361,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines = []
|
||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||
if req.hf_token:
|
||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||
if req.env_prefix:
|
||||
@@ -545,7 +436,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
||||
# hf_transfer because the Rust parallel path is fast but has been
|
||||
# flaky near the end of very large multi-file downloads.
|
||||
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
|
||||
# The helper tries active pip first, then guarded user-site fallbacks.
|
||||
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
||||
if req.disable_hf_transfer:
|
||||
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||
@@ -673,24 +564,35 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
for d in model_dir.split(','):
|
||||
d = d.strip()
|
||||
if d:
|
||||
model_dirs.append(d)
|
||||
paths_code = _cached_model_scan_script(model_dirs)
|
||||
translated_d = translate_path(d) if not host else d
|
||||
model_dirs.append(translated_d)
|
||||
win_hf_hub = None
|
||||
if not host:
|
||||
win_profile = get_wsl_windows_user_profile()
|
||||
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
|
||||
|
||||
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
|
||||
|
||||
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
||||
scan_py.write_text(paths_code, encoding="utf-8")
|
||||
scan_payload = scan_py.read_bytes()
|
||||
|
||||
if host:
|
||||
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||
if platform == "windows":
|
||||
# Windows: use 'python' and pipe via stdin with double-quote wrapping
|
||||
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
|
||||
remote_cmd = "python -"
|
||||
else:
|
||||
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
|
||||
remote_cmd = (
|
||||
"if command -v python3 >/dev/null 2>&1; then python3 -; "
|
||||
"elif command -v python >/dev/null 2>&1; then python -; "
|
||||
"else echo \"python3/python not found\" >&2; exit 127; fi"
|
||||
)
|
||||
rc, stdout_b, stderr_b = await run_ssh_command_async(
|
||||
host,
|
||||
ssh_port,
|
||||
remote_cmd,
|
||||
timeout=60,
|
||||
stdin_data=scan_payload,
|
||||
)
|
||||
else:
|
||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||
@@ -710,7 +612,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||
|
||||
models = []
|
||||
try:
|
||||
@@ -915,6 +817,10 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
existing.name = display_name
|
||||
if supports_tools is not None:
|
||||
existing.supports_tools = supports_tools
|
||||
# Wipe stale model lists so the picker re-probes and discovers
|
||||
# the newly-served model instead of showing the old one.
|
||||
existing.cached_models = None
|
||||
existing.hidden_models = None
|
||||
db.commit()
|
||||
logger.info(f"Updated existing local model endpoint: {base_url}")
|
||||
return existing.id
|
||||
@@ -971,11 +877,27 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
in_venv=sys.prefix != sys.base_prefix,
|
||||
)
|
||||
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
||||
remote = req.remote_host
|
||||
is_windows = req.platform == "windows"
|
||||
local_windows = IS_WINDOWS and not remote
|
||||
if is_windows or local_windows:
|
||||
if req.cmd.startswith("python3 "):
|
||||
req.cmd = "python " + req.cmd[len("python3 "):]
|
||||
if is_pip_install and ("llama-cpp-python" in req.cmd or "llama_cpp" in req.cmd) and (is_windows or local_windows):
|
||||
if "--extra-index-url" not in req.cmd:
|
||||
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
|
||||
if is_pip_install:
|
||||
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||
# and leave the dep installed-but-unusable (#1459).
|
||||
req.cmd = _pip_install_no_cache(req.cmd)
|
||||
# Accept common aliases and enforce server extras for llama-cpp so
|
||||
# `python -m llama_cpp.server` has all runtime dependencies.
|
||||
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
|
||||
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
|
||||
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
|
||||
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||
# v2 review HIGH-14: tightened from the previous regex which
|
||||
@@ -1028,6 +950,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines = []
|
||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||
if req.hf_token:
|
||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||
if req.gpus:
|
||||
@@ -1046,7 +970,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
||||
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
||||
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
||||
ps_lines.append(' python -m pip install llama-cpp-python[server]')
|
||||
ps_lines.append(' python -m pip install llama-cpp-python[server] --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu')
|
||||
ps_lines.append('}')
|
||||
elif "vllm" in req.cmd:
|
||||
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
||||
@@ -1121,45 +1045,57 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# ollama is found (otherwise macOS falls back to a slow source build).
|
||||
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||
runner_lines.append(' mkdir -p ~/bin')
|
||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||
# enables it automatically, no flag), CUDA on Linux when present,
|
||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||
# a prebuilt llama-server and skips this whole source build.)
|
||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('fi')
|
||||
if local_windows:
|
||||
# LOCAL Windows: no native source compilation (no cmake/compiler on Git Bash).
|
||||
# Just check python bindings (using native `python` binary) and fall back to pip install.
|
||||
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server not found — installing Python bindings..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='python')} || true")
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
else:
|
||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||
runner_lines.append(' mkdir -p ~/bin')
|
||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||
# enables it automatically, no flag), CUDA on Linux when present,
|
||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||
# a prebuilt llama-server and skips this whole source build.)
|
||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||
runner_lines.append(' fi')
|
||||
# If the native build failed, fall back to the Python bindings.
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('fi')
|
||||
elif "ollama" in req.cmd:
|
||||
handled_ollama_serve = True
|
||||
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||
@@ -1181,13 +1117,23 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
|
||||
runner_lines.append(' break')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' exec 3<&-; exec 3>&-')
|
||||
runner_lines.append('done')
|
||||
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
|
||||
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
|
||||
if local_windows:
|
||||
# Windows detached process has no TTY; exec bash -i crashes.
|
||||
# Keep the monitoring task alive with a sleep loop.
|
||||
runner_lines.append(' while true; do sleep 60; done')
|
||||
else:
|
||||
runner_lines.append(' exec bash -i')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
|
||||
runner_lines.append(' echo')
|
||||
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
||||
runner_lines.append(' exec bash -i')
|
||||
if local_windows:
|
||||
runner_lines.append(' exit 127')
|
||||
else:
|
||||
runner_lines.append(' exec bash -i')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
||||
if remote and _ollama_host in ("0.0.0.0", "::"):
|
||||
@@ -1195,24 +1141,20 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
||||
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||
runner_lines.append('_ody_exit=$?')
|
||||
runner_lines.append('echo')
|
||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||
runner_lines.append('exec bash -i')
|
||||
if local_windows:
|
||||
_append_serve_exit_code_lines(runner_lines, keep_shell_open=False)
|
||||
else:
|
||||
runner_lines.append('_ody_exit=$?')
|
||||
runner_lines.append('echo')
|
||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||
runner_lines.append('exec bash -i')
|
||||
elif "vllm serve" in req.cmd:
|
||||
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
||||
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
||||
runner_lines.append('fi')
|
||||
# Put ~/.local/bin on PATH first — without a venv, vllm installs
|
||||
# there via --user and the non-login serve shell otherwise can't
|
||||
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! command -v vllm &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
_append_vllm_linux_preflight_lines(runner_lines)
|
||||
elif "sglang.launch_server" in req.cmd:
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
||||
@@ -1236,7 +1178,10 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines,
|
||||
keep_shell_open=not local_windows,
|
||||
)
|
||||
runner_lines.append(req.cmd)
|
||||
if is_pip_install:
|
||||
_append_pip_install_runner_lines(runner_lines, req.cmd)
|
||||
else:
|
||||
runner_lines.append(req.cmd)
|
||||
if local_windows:
|
||||
# Detached background process — no interactive shell to keep open.
|
||||
# Print the exit marker the status poller looks for, then stop.
|
||||
@@ -1397,8 +1342,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||
else:
|
||||
# Linux: auto-install tmux (via whichever package manager is available)
|
||||
# and huggingface_hub + hf_transfer (falling back to --user/--break-system-packages
|
||||
# on PEP-668 locked distros like Arch / newer Debian).
|
||||
# and huggingface_hub + hf_transfer (falling back to --user, then
|
||||
# guarded --break-system-packages on PEP-668 locked distros).
|
||||
setup_script = (
|
||||
# Install tmux if missing — try common package managers; skip if no sudo
|
||||
"if ! command -v tmux >/dev/null 2>&1; then "
|
||||
@@ -1410,10 +1355,15 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
" fi; "
|
||||
"fi; "
|
||||
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
||||
# Install Python bits. Try system install first; fall back to --user --break-system-packages on PEP 668 systems.
|
||||
# Install Python bits. Try system install first; fall back to --user,
|
||||
# then use --break-system-packages only when pip supports it.
|
||||
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null; "
|
||||
"pip install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"( pip install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ) || "
|
||||
"pip3 install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"( pip3 install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ); "
|
||||
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
||||
)
|
||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||
@@ -1436,11 +1386,38 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
||||
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
||||
if host:
|
||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
candidates = [query]
|
||||
stripped = query.strip()
|
||||
if stripped.startswith("nvidia-smi "):
|
||||
args = stripped[len("nvidia-smi "):]
|
||||
candidates.append(
|
||||
"bash -lc "
|
||||
+ shlex.quote(
|
||||
f"{SSH_PATH_OVERRIDE}"
|
||||
f"nvidia-smi {args}"
|
||||
)
|
||||
)
|
||||
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||
candidates.append(f"{nvidia_path} {args}")
|
||||
|
||||
last_err = "nvidia-smi failed"
|
||||
for candidate in candidates:
|
||||
try:
|
||||
rc, stdout, stderr = await run_ssh_command_async(
|
||||
host,
|
||||
ssh_port,
|
||||
candidate,
|
||||
connect_timeout=5,
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return None, "nvidia-smi timed out"
|
||||
if rc == 0:
|
||||
return stdout.decode("utf-8", errors="replace"), None
|
||||
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
|
||||
if err:
|
||||
last_err = err
|
||||
return None, last_err
|
||||
else:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*shlex.split(query),
|
||||
@@ -2203,7 +2180,13 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
||||
"sys.exit(0 if ok and not inc else 1)"
|
||||
)
|
||||
cmd = ["python3", "-c", py, repo_id]
|
||||
if remote_host:
|
||||
cmd = ["python3", "-c", py, repo_id]
|
||||
else:
|
||||
# Local Windows: python3 can hit the Microsoft Store stub. Use the
|
||||
# real Python Odysseus is running under (guaranteed to exist).
|
||||
import sys as _sys_local
|
||||
cmd = [_sys_local.executable, "-c", py, repo_id]
|
||||
try:
|
||||
if remote_host:
|
||||
ssh_base = ["ssh"]
|
||||
|
||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Form, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from core.middleware import require_admin
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import copilot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
||||
# browser. Entries expire with the GitHub device code.
|
||||
#
|
||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
||||
# on a worker that never saw the start, returning "Unknown or expired login
|
||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
||||
_PENDING: Dict[str, Dict] = {}
|
||||
_PENDING_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _prune_expired() -> None:
|
||||
now = time.time()
|
||||
with _PENDING_LOCK:
|
||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
||||
_PENDING.pop(k, None)
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
return result
|
||||
|
||||
|
||||
def setup_copilot_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
||||
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = str(form.get("enterprise_url") or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
|
||||
@router.post("/device/start")
|
||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = (enterprise_url or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
interval = int(data.get("interval") or 5)
|
||||
expires_in = int(data.get("expires_in") or 900)
|
||||
poll_id = uuid.uuid4().hex
|
||||
with _PENDING_LOCK:
|
||||
_PENDING[poll_id] = {
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"interval": interval,
|
||||
"owner": get_current_user(request) or None,
|
||||
"expires_at": time.time() + expires_in,
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return {
|
||||
"poll_id": poll_id,
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return DeviceFlowStart(
|
||||
pending={
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"owner": get_current_user(request) or None,
|
||||
},
|
||||
response={
|
||||
"user_code": data.get("user_code"),
|
||||
"verification_uri": data.get("verification_uri"),
|
||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||
"interval": interval,
|
||||
"expires_in": expires_in,
|
||||
}
|
||||
},
|
||||
interval=int(data.get("interval") or 5),
|
||||
expires_in=int(data.get("expires_in") or 900),
|
||||
)
|
||||
|
||||
@router.post("/device/poll")
|
||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
with _PENDING_LOCK:
|
||||
pending = _PENDING.get(poll_id)
|
||||
if not pending:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
|
||||
# Enforce GitHub's polling interval server-side so a chatty client
|
||||
# can't trip slow_down.
|
||||
now = time.time()
|
||||
if now < pending.get("next_poll_at", 0):
|
||||
return {"status": "pending"}
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
except Exception as e:
|
||||
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
return DeviceFlowPoll.authorized(result)
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "authorized", "endpoint": result}
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
return DeviceFlowPoll.pending()
|
||||
if err == "slow_down":
|
||||
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||
if err in ("expired_token", "access_denied"):
|
||||
return DeviceFlowPoll.failed(err)
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return DeviceFlowPoll.pending(err or "unknown")
|
||||
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
||||
return {"status": "pending"}
|
||||
if err == "slow_down":
|
||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["interval"] = new_interval
|
||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
||||
return {"status": "pending"}
|
||||
if err in ("expired_token", "access_denied"):
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "failed", "error": err}
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return {"status": "pending", "detail": err or "unknown"}
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
def setup_copilot_routes():
|
||||
return create_device_flow_router(
|
||||
prefix="/api/copilot",
|
||||
tags=["copilot"],
|
||||
store=_DEVICE_FLOW_STORE,
|
||||
start_flow=_start_device_flow,
|
||||
poll_flow=_poll_device_flow,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
|
||||
from core.middleware import require_admin
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceFlowStart:
|
||||
"""Provider-specific start result consumed by the shared route wrapper."""
|
||||
|
||||
pending: Mapping[str, Any]
|
||||
response: Mapping[str, Any]
|
||||
interval: int = 5
|
||||
expires_in: int = 900
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceFlowPoll:
|
||||
"""Normalized provider poll outcome."""
|
||||
|
||||
status: str
|
||||
endpoint: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
interval: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||
return cls(status="pending", detail=detail)
|
||||
|
||||
@classmethod
|
||||
def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||
return cls(status="slow_down", interval=interval, detail=detail)
|
||||
|
||||
@classmethod
|
||||
def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
|
||||
return cls(status="authorized", endpoint=endpoint)
|
||||
|
||||
@classmethod
|
||||
def failed(cls, error: str) -> "DeviceFlowPoll":
|
||||
return cls(status="failed", error=error)
|
||||
|
||||
|
||||
class PendingDeviceFlowStore:
|
||||
"""Thread-safe in-memory pending device-flow store.
|
||||
|
||||
Device codes and provider-side secrets stay inside this process. Each entry
|
||||
stores provider payload separately from poll metadata so provider callbacks
|
||||
only receive the fields they created.
|
||||
"""
|
||||
|
||||
def __init__(self, *, time_func: Callable[[], float] = time.time):
|
||||
self._pending: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._time = time_func
|
||||
|
||||
def _now(self) -> float:
|
||||
return float(self._time())
|
||||
|
||||
def prune_expired(self) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
for key in [k for k, v in self._pending.items() if v.get("expires_at", 0) < now]:
|
||||
self._pending.pop(key, None)
|
||||
|
||||
def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
|
||||
self.prune_expired()
|
||||
poll_id = uuid.uuid4().hex
|
||||
with self._lock:
|
||||
self._pending[poll_id] = {
|
||||
"payload": dict(payload),
|
||||
"interval": max(int(interval or 5), 1),
|
||||
"expires_at": self._now() + max(int(expires_in or 900), 1),
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
return poll_id
|
||||
|
||||
def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
|
||||
self.prune_expired()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is None:
|
||||
return None
|
||||
return dict(entry.get("payload") or {})
|
||||
|
||||
def is_throttled(self, poll_id: str) -> bool:
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
return bool(entry and self._now() < float(entry.get("next_poll_at") or 0))
|
||||
|
||||
def schedule_next(self, poll_id: str) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is not None:
|
||||
entry["next_poll_at"] = now + int(entry.get("interval") or 5)
|
||||
|
||||
def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is not None:
|
||||
new_interval = int(interval or (int(entry.get("interval") or 5) + 5))
|
||||
entry["interval"] = max(new_interval, 1)
|
||||
entry["next_poll_at"] = now + entry["interval"]
|
||||
|
||||
def pop(self, poll_id: str) -> None:
|
||||
with self._lock:
|
||||
self._pending.pop(poll_id, None)
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
|
||||
if inspect.isawaitable(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
|
||||
response: dict[str, Any] = {"status": "pending"}
|
||||
if detail:
|
||||
response["detail"] = detail
|
||||
return response
|
||||
|
||||
|
||||
def create_device_flow_router(
|
||||
*,
|
||||
prefix: str,
|
||||
tags: Iterable[str],
|
||||
store: PendingDeviceFlowStore,
|
||||
start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
|
||||
poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
|
||||
) -> APIRouter:
|
||||
"""Create standard `/device/start|poll|cancel` routes for a provider."""
|
||||
|
||||
router = APIRouter(prefix=prefix, tags=list(tags))
|
||||
|
||||
@router.post("/device/start")
|
||||
async def device_start(request: Request):
|
||||
require_admin(request)
|
||||
form = await request.form()
|
||||
start = await _maybe_await(start_flow(request, form))
|
||||
interval = int(start.interval or 5)
|
||||
expires_in = int(start.expires_in or 900)
|
||||
poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
|
||||
response = dict(start.response)
|
||||
response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in})
|
||||
return response
|
||||
|
||||
@router.post("/device/poll")
|
||||
async def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
payload = store.get_payload(poll_id)
|
||||
if payload is None:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
if store.is_throttled(poll_id):
|
||||
return {"status": "pending"}
|
||||
|
||||
try:
|
||||
outcome = await _maybe_await(poll_flow(request, payload))
|
||||
except Exception:
|
||||
store.pop(poll_id)
|
||||
raise
|
||||
|
||||
if outcome.status == "authorized":
|
||||
store.pop(poll_id)
|
||||
return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
|
||||
if outcome.status == "failed":
|
||||
store.pop(poll_id)
|
||||
return {"status": "failed", "error": outcome.error or "denied"}
|
||||
if outcome.status == "slow_down":
|
||||
store.slow_down(poll_id, outcome.interval)
|
||||
return _pending_response(outcome.detail)
|
||||
|
||||
store.schedule_next(poll_id)
|
||||
return _pending_response(outcome.detail)
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
store.pop(poll_id)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
+66
-36
@@ -7,14 +7,24 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import case, func, or_
|
||||
from core.database import SessionLocal, Document, DocumentVersion
|
||||
from core.database import Session as DbSession
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import MAIL_ATTACHMENTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_session_or_404(db, session_id: str, user: Optional[str]):
|
||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and session.owner != user:
|
||||
raise HTTPException(404, "Session not found")
|
||||
return session
|
||||
|
||||
|
||||
def _aggregate_language_facets(lang_rows):
|
||||
"""Sum document counts per display language for the library facet.
|
||||
|
||||
@@ -30,6 +40,19 @@ def _aggregate_language_facets(lang_rows):
|
||||
return out
|
||||
|
||||
|
||||
def _library_language_for_document(doc: Document) -> str:
|
||||
"""Return the display language used by the document library.
|
||||
|
||||
PDF documents are stored as markdown wrappers so the editor can preserve
|
||||
extracted text, form fields, and annotations. The library should still
|
||||
identify them as PDFs instead of exposing that internal wrapper format.
|
||||
"""
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
|
||||
if find_source_upload_id(doc.current_content or ""):
|
||||
return "pdf"
|
||||
return doc.language or "text"
|
||||
|
||||
|
||||
from routes.document_helpers import (
|
||||
DocumentCreate, DocumentUpdate, DocumentPatch,
|
||||
@@ -69,17 +92,12 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# the doc is owner-stamped, so it lives in the library on its own.
|
||||
session = None
|
||||
if req.session_id:
|
||||
session = db.query(DbSession).filter(DbSession.id == req.session_id).first()
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
# Match the lenient ownership model the rest of the app uses
|
||||
# (see _owner_filter): only block when an AUTHENTICATED user is
|
||||
# writing into a DIFFERENT user's session. In single-user /
|
||||
# unconfigured / localhost-bypass mode the middleware leaves
|
||||
# current_user unset (None), and those sessions are already
|
||||
# served freely everywhere else.
|
||||
if user and session.owner and session.owner != user:
|
||||
raise HTTPException(403, "Cannot create document in another user's session")
|
||||
# unconfigured / localhost-bypass mode, falsey users preserve
|
||||
# the existing lenient path.
|
||||
session = _get_session_or_404(db, req.session_id, user)
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
ver_id = str(uuid.uuid4())
|
||||
@@ -171,11 +189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
if session_id:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if not sess:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and sess.owner and sess.owner != user:
|
||||
raise HTTPException(403, "Cannot import into another user's session")
|
||||
_get_session_or_404(db, session_id, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -198,7 +212,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
|
||||
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
||||
try:
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||
except Exception:
|
||||
body_text = None
|
||||
|
||||
@@ -260,18 +274,29 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from sqlalchemy import or_
|
||||
pdf_marker_cond = or_(
|
||||
Document.current_content.like('%<!-- pdf_source upload_id="%'),
|
||||
Document.current_content.like('%<!-- pdf_form_source upload_id="%'),
|
||||
)
|
||||
library_language_expr = case(
|
||||
(pdf_marker_cond, "pdf"),
|
||||
(Document.language.is_(None), "text"),
|
||||
else_=Document.language,
|
||||
)
|
||||
# Archived view shows ONLY archived docs; the default view excludes
|
||||
# them (NULL = legacy rows that predate the column = not archived).
|
||||
_arch_cond = (Document.archived == True) if archived else or_(
|
||||
Document.archived == False, Document.archived.is_(None))
|
||||
# Language facet counts (owner-filtered)
|
||||
# Language facet counts (owner-filtered). PDF documents are stored
|
||||
# as markdown wrappers, so group by the library display language
|
||||
# instead of the raw stored language.
|
||||
lang_q = (
|
||||
db.query(Document.language, func.count(Document.id))
|
||||
db.query(library_language_expr, func.count(Document.id))
|
||||
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
||||
.filter(Document.is_active == True).filter(_arch_cond)
|
||||
)
|
||||
lang_q = _owner_session_filter(lang_q, user)
|
||||
lang_rows = lang_q.group_by(Document.language).all()
|
||||
lang_rows = lang_q.group_by(library_language_expr).all()
|
||||
languages = _aggregate_language_facets(lang_rows)
|
||||
|
||||
# Session count (owner-filtered)
|
||||
@@ -303,12 +328,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
Document.title.ilike(term) | Document.current_content.ilike(term)
|
||||
)
|
||||
|
||||
# Language filter
|
||||
# Language filter. "pdf" is a display language derived from the
|
||||
# source marker; "markdown" excludes those wrappers.
|
||||
if language:
|
||||
if language == "text":
|
||||
q = q.filter((Document.language == None) | (Document.language == "text"))
|
||||
elif language == "pdf":
|
||||
q = q.filter(pdf_marker_cond)
|
||||
else:
|
||||
q = q.filter(Document.language == language)
|
||||
if language == "markdown":
|
||||
q = q.filter(~pdf_marker_cond)
|
||||
|
||||
# Total before pagination
|
||||
total = q.count()
|
||||
@@ -332,7 +362,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
"session_id": doc.session_id,
|
||||
"session_name": session_name,
|
||||
"title": doc.title,
|
||||
"language": doc.language or "text",
|
||||
"language": _library_language_for_document(doc),
|
||||
"preview": (doc.current_content or "")[:500],
|
||||
"version_count": doc.version_count,
|
||||
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
||||
@@ -359,18 +389,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
try:
|
||||
if not user:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
# v2 review HIGH-9: raise 403 explicitly when the caller
|
||||
# can't see this session, instead of returning [] which the
|
||||
# UI treats identically to "no docs" and silently masks
|
||||
# auth failures.
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and session.owner and session.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
docs = db.query(Document).filter(
|
||||
_get_session_or_404(db, session_id, user)
|
||||
q = db.query(Document).filter(
|
||||
Document.session_id == session_id
|
||||
).order_by(Document.created_at.desc()).all()
|
||||
)
|
||||
if user:
|
||||
q = q.filter(or_(Document.owner == user, Document.owner.is_(None)))
|
||||
docs = q.order_by(Document.created_at.desc()).all()
|
||||
return [_doc_to_dict(d) for d in docs]
|
||||
finally:
|
||||
db.close()
|
||||
@@ -437,7 +466,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
raise HTTPException(404, "Source PDF could not be located")
|
||||
|
||||
try:
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||
except Exception as e:
|
||||
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
||||
raise HTTPException(500, f"Extraction failed: {e}")
|
||||
@@ -606,6 +635,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
doc.language = req.language
|
||||
if req.session_id is not None:
|
||||
# Empty string = unlink from session
|
||||
if req.session_id:
|
||||
_get_session_or_404(db, req.session_id, user)
|
||||
doc.session_id = req.session_id if req.session_id else None
|
||||
if not req.session_id:
|
||||
# Tab closed / doc detached from its session — drop the
|
||||
@@ -855,10 +886,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
user = get_current_user(request)
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=user or None)
|
||||
if not url or not model:
|
||||
# Fall back to default endpoint
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||
if not url or not model:
|
||||
raise HTTPException(500, "No endpoint configured for AI tidy")
|
||||
|
||||
@@ -1158,7 +1189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
settings = _load_vl_settings()
|
||||
vl_model = settings.get("vision_model", "")
|
||||
try:
|
||||
url, model_id, headers = _resolve_vl_model(vl_model)
|
||||
url, model_id, headers = _resolve_vl_model(vl_model, owner=user)
|
||||
except Exception as e:
|
||||
raise HTTPException(503, f"No vision model available: {e}")
|
||||
|
||||
@@ -1512,10 +1543,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# don't import from a routes file (cycle-prone). Same env override
|
||||
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
||||
from pathlib import Path as _Path
|
||||
import os as _os
|
||||
_DATA_DIR = _Path(__file__).resolve().parent.parent / "data"
|
||||
_BASE = _os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(_DATA_DIR / "mail-attachments"))
|
||||
_COMPOSE_DIR = _Path(_BASE) / "_compose"
|
||||
_COMPOSE_DIR = _Path(MAIL_ATTACHMENTS_DIR) / "_compose"
|
||||
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
user = get_current_user(request)
|
||||
@@ -1631,9 +1659,11 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# context (To/Subject/In-Reply-To/References).
|
||||
try:
|
||||
from routes.email_routes import _imap, _decode_header
|
||||
from routes.email_helpers import _q
|
||||
except Exception:
|
||||
_imap = None
|
||||
_decode_header = lambda x: x or ""
|
||||
_q = lambda x: x or ""
|
||||
|
||||
to_addr = ""
|
||||
from_name = ""
|
||||
@@ -1643,7 +1673,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
if _imap:
|
||||
try:
|
||||
with _imap(doc.source_email_account_id or None) as conn:
|
||||
conn.select(doc.source_email_folder, readonly=True)
|
||||
conn.select(_q(doc.source_email_folder), readonly=True)
|
||||
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
||||
if status == "OK" and data and data[0]:
|
||||
raw_hdr = data[0][1]
|
||||
|
||||
+98
-33
@@ -71,6 +71,38 @@ def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message
|
||||
smtp.sendmail(from_addr, recipients, message)
|
||||
|
||||
|
||||
def _friendly_email_auth_error(protocol: str, host: str, error: object) -> str:
|
||||
"""Return a clearer setup error for known provider auth policies."""
|
||||
raw = str(error or "")
|
||||
lower = raw.lower()
|
||||
host_lower = (host or "").lower()
|
||||
microsoft_host = any(
|
||||
marker in host_lower
|
||||
for marker in (
|
||||
"outlook.office365.com",
|
||||
"smtp.office365.com",
|
||||
"office365.com",
|
||||
"outlook.com",
|
||||
"hotmail.com",
|
||||
"live.com",
|
||||
)
|
||||
)
|
||||
microsoft_basic_auth_failure = (
|
||||
"5.7.139" in lower
|
||||
or "basic authentication is disabled" in lower
|
||||
or ("authenticate failed" in lower and microsoft_host)
|
||||
or ("authentication unsuccessful" in lower and microsoft_host)
|
||||
)
|
||||
if microsoft_basic_auth_failure:
|
||||
return (
|
||||
"Microsoft no longer accepts normal mailbox passwords for "
|
||||
"Outlook/Office 365 IMAP/SMTP in most accounts. Odysseus "
|
||||
"does not support Microsoft OAuth/Graph mail yet, so Outlook "
|
||||
"accounts cannot be added with this password form."
|
||||
)
|
||||
return raw[:200]
|
||||
|
||||
|
||||
def _strip_think(text: str) -> str:
|
||||
"""Email-flavored think strip — thin wrapper over the central helper.
|
||||
|
||||
@@ -254,16 +286,17 @@ def _cleanup_compose_uploads(tokens) -> None:
|
||||
pass
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
||||
from src.constants import DATA_DIR as _DATA_DIR, MAIL_ATTACHMENTS_DIR, SETTINGS_FILE as _SETTINGS_FILE, SCHEDULED_EMAILS_DB
|
||||
DATA_DIR = Path(_DATA_DIR)
|
||||
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
||||
# subdir of the install's data/ tree so the app works out-of-the-box without
|
||||
# a hardcoded /home/<user>/ path.
|
||||
ATTACHMENTS_DIR = Path(os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(DATA_DIR / "mail-attachments")))
|
||||
ATTACHMENTS_DIR = Path(MAIL_ATTACHMENTS_DIR)
|
||||
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
||||
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
|
||||
SCHEDULED_DB = Path(SCHEDULED_EMAILS_DB)
|
||||
|
||||
|
||||
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||
@@ -705,7 +738,16 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
||||
port = int(port or 993)
|
||||
if starttls:
|
||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||
conn.starttls()
|
||||
try:
|
||||
conn.starttls()
|
||||
except Exception:
|
||||
# Don't leak the open plain socket if the STARTTLS upgrade is
|
||||
# rejected; close it before propagating. (#3174)
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
elif port == 993:
|
||||
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
||||
else:
|
||||
@@ -714,6 +756,10 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
||||
conn.sock.settimeout(timeout)
|
||||
except Exception:
|
||||
pass
|
||||
# Raise the IMAP line-length limit from the default 1 MB to 50 MB so that
|
||||
# large mailboxes (tens of thousands of messages) don't crash with
|
||||
# "got more than 1000000 bytes" on UID SEARCH ALL. (#2883)
|
||||
imaplib._MAXLINE = 50_000_000
|
||||
return conn
|
||||
|
||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
@@ -734,7 +780,18 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
starttls=bool(cfg.get("imap_starttls")),
|
||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||
)
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
try:
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
except Exception:
|
||||
# A failed AUTHENTICATE (e.g. an Office 365 app password on an
|
||||
# MFA-enabled tenant, #3174) otherwise orphans the already-connected
|
||||
# socket; close it before propagating so a misconfigured account
|
||||
# can't leak one descriptor per retry / background poller pass.
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return conn
|
||||
|
||||
|
||||
@@ -798,20 +855,28 @@ def _imap(account_id: str | None = None, owner: str = ""):
|
||||
def _decode_header(raw):
|
||||
if not raw:
|
||||
return ""
|
||||
parts = email.header.decode_header(raw)
|
||||
decoded = []
|
||||
for data, charset in parts:
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
except (LookupError, ValueError):
|
||||
# Unknown/invalid MIME charset (e.g. a malformed or spam header
|
||||
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
|
||||
# byte-decode errors, not codec lookup, so fall back to utf-8.
|
||||
decoded.append(data.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(data)
|
||||
return " ".join(decoded)
|
||||
try:
|
||||
# make_header concatenates per RFC 2047: no spurious space between an
|
||||
# encoded-word and adjacent plain text (plain runs keep their own
|
||||
# whitespace), and the whitespace between two adjacent encoded-words is
|
||||
# dropped. The old " ".join produced "Re: Jose"-style double spaces on
|
||||
# every non-ASCII subject or sender.
|
||||
return str(email.header.make_header(email.header.decode_header(raw)))
|
||||
except Exception:
|
||||
# Malformed header or unknown/invalid MIME charset (e.g. a spam header
|
||||
# like =?x-unknown-charset?B?...?=) makes make_header raise LookupError;
|
||||
# fall back to a lossy per-part decode. errors="replace" only covers
|
||||
# byte-decode errors, not codec lookup, hence the explicit utf-8 retry.
|
||||
decoded = []
|
||||
for data, charset in email.header.decode_header(raw):
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
except (LookupError, ValueError):
|
||||
decoded.append(data.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(data)
|
||||
return "".join(decoded)
|
||||
|
||||
|
||||
def _detect_sent_folder(conn):
|
||||
@@ -1136,13 +1201,9 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
if exclude_uid:
|
||||
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account_id, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning(f"sender-thread-context: imap connect failed: {e}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||
if len(blocks) >= limit:
|
||||
break
|
||||
@@ -1209,11 +1270,14 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
if atts_text:
|
||||
lines.append(atts_text)
|
||||
blocks.append("\n".join(lines))
|
||||
except Exception as e:
|
||||
logger.warning(f"sender-thread-context: imap failed: {e}")
|
||||
finally:
|
||||
try: conn.close()
|
||||
except Exception: pass
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
if conn:
|
||||
try: conn.close()
|
||||
except Exception: pass
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
if not blocks:
|
||||
return ""
|
||||
@@ -1316,6 +1380,7 @@ def _pre_retrieve_context(
|
||||
if not terms_list:
|
||||
return context_snippets, terms_list
|
||||
|
||||
ctx_conn = None
|
||||
try:
|
||||
ctx_conn = _imap_connect(account_id, owner=owner)
|
||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||
@@ -1352,12 +1417,12 @@ def _pre_retrieve_context(
|
||||
except Exception as _e:
|
||||
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
||||
continue
|
||||
try:
|
||||
ctx_conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.warning(f"IMAP context search failed: {_e}")
|
||||
finally:
|
||||
if ctx_conn:
|
||||
try: ctx_conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
try:
|
||||
from routes.contacts_routes import _fetch_contacts
|
||||
|
||||
@@ -210,7 +210,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
if auto_cal:
|
||||
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
||||
try:
|
||||
st, _ = conn.select(sent_name, readonly=True)
|
||||
st, _ = conn.select(_q(sent_name), readonly=True)
|
||||
if st == "OK":
|
||||
folders_to_scan.append(sent_name)
|
||||
break
|
||||
@@ -1046,7 +1046,7 @@ def _scheduled_poll_once() -> dict:
|
||||
try:
|
||||
with _imap(row_account_id, owner=row_owner) as imap:
|
||||
sent_folder = _detect_sent_folder(imap)
|
||||
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
|
||||
imap.append(_q(sent_folder), "\\Seen", None, outer.as_bytes())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,10 @@ from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from src.constants import DATA_DIR
|
||||
|
||||
from src.llm_core import llm_call_async
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, EMAIL_COMPOSE_UPLOAD_MAX_BYTES
|
||||
|
||||
from routes.email_helpers import (
|
||||
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
||||
@@ -47,6 +48,7 @@ from routes.email_helpers import (
|
||||
_extract_attachment_to_disk, _extract_html, _extract_text,
|
||||
_fetch_sender_thread_context, _pre_retrieve_context,
|
||||
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
||||
_friendly_email_auth_error,
|
||||
SendEmailRequest, ExtractStyleRequest,
|
||||
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
||||
attachment_extract_dir, _email_cache_owner_clause,
|
||||
@@ -56,7 +58,6 @@ from routes.email_pollers import _start_poller
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
||||
EMAIL_COMPOSE_UPLOAD_MAX_BYTES = 25 * 1024 * 1024
|
||||
|
||||
|
||||
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
||||
@@ -2904,7 +2905,7 @@ def setup_email_routes():
|
||||
from pathlib import Path as _P
|
||||
import json as _json
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
path = _P(f"data/email_urgency_state_{_slug}.json")
|
||||
path = _P(DATA_DIR) / f"email_urgency_state_{_slug}.json"
|
||||
if not path.exists():
|
||||
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
||||
try:
|
||||
@@ -3162,7 +3163,7 @@ def setup_email_routes():
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
except Exception as e:
|
||||
imap_result = {"ok": False, "error": str(e)[:200]}
|
||||
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
|
||||
|
||||
smtp_host = (body.get("smtp_host") or "").strip()
|
||||
if smtp_host:
|
||||
@@ -3184,7 +3185,7 @@ def setup_email_routes():
|
||||
try: smtp.quit()
|
||||
except Exception: pass
|
||||
except Exception as e:
|
||||
smtp_result = {"ok": False, "error": str(e)[:200]}
|
||||
smtp_result = {"ok": False, "error": _friendly_email_auth_error("SMTP", smtp_host, e)}
|
||||
|
||||
return {
|
||||
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
||||
|
||||
+65
-22
@@ -7,12 +7,12 @@ import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Form, Depends
|
||||
from core.constants import BASE_DIR
|
||||
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
|
||||
from core.middleware import require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
|
||||
_ENDPOINT_FILE = EMBEDDING_ENDPOINT_FILE
|
||||
|
||||
# Track in-progress downloads
|
||||
_downloading: dict = {}
|
||||
@@ -35,13 +35,7 @@ def _cache_dir() -> str:
|
||||
default lived in /tmp, which many systems wipe on reboot — forcing a
|
||||
full re-download of the embedding model after every restart.
|
||||
"""
|
||||
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
||||
if env:
|
||||
return env
|
||||
return os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"data", "fastembed_cache",
|
||||
)
|
||||
return FASTEMBED_CACHE_DIR
|
||||
|
||||
|
||||
def _model_cache_name(hf_source: str) -> str:
|
||||
@@ -49,19 +43,35 @@ def _model_cache_name(hf_source: str) -> str:
|
||||
return "models--" + hf_source.replace("/", "--")
|
||||
|
||||
|
||||
def _model_cache_path(hf_source: str) -> Path:
|
||||
"""Return a confined cache path for a fastembed HF source."""
|
||||
root = Path(_cache_dir()).expanduser().resolve()
|
||||
raw_path = root / _model_cache_name(hf_source)
|
||||
if raw_path.is_symlink():
|
||||
raise ValueError("Model cache path must not be a symlink")
|
||||
path = raw_path.resolve(strict=False)
|
||||
try:
|
||||
path.relative_to(root)
|
||||
except ValueError:
|
||||
raise ValueError("Model cache path escapes cache root")
|
||||
return path
|
||||
|
||||
|
||||
def _is_downloaded(hf_source: str) -> bool:
|
||||
"""Check if a model is already cached."""
|
||||
cache = _cache_dir()
|
||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
||||
if not os.path.isdir(model_dir):
|
||||
try:
|
||||
model_dir = _model_cache_path(hf_source)
|
||||
except ValueError:
|
||||
return False
|
||||
if not model_dir.is_dir():
|
||||
return False
|
||||
# Check for actual model files (not just empty dir)
|
||||
snapshots = os.path.join(model_dir, "snapshots")
|
||||
if os.path.isdir(snapshots):
|
||||
return any(os.listdir(snapshots))
|
||||
snapshots = model_dir / "snapshots"
|
||||
if snapshots.is_dir():
|
||||
return any(snapshots.iterdir())
|
||||
# Also check for blobs (older cache format)
|
||||
blobs = os.path.join(model_dir, "blobs")
|
||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
||||
blobs = model_dir / "blobs"
|
||||
return blobs.is_dir() and any(blobs.iterdir())
|
||||
|
||||
|
||||
def _active_model() -> str:
|
||||
@@ -119,8 +129,10 @@ def setup_embedding_routes():
|
||||
|
||||
cached_size = None
|
||||
if downloaded and hf_src:
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
cached_size = _dir_size_mb(model_path)
|
||||
try:
|
||||
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
|
||||
except ValueError:
|
||||
cached_size = None
|
||||
|
||||
result.append({
|
||||
"model": m["model"],
|
||||
@@ -217,8 +229,11 @@ def setup_embedding_routes():
|
||||
if not hf_src:
|
||||
raise HTTPException(400, "No cache source for this model")
|
||||
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
if not os.path.isdir(model_path):
|
||||
try:
|
||||
model_path = _model_cache_path(hf_src)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if not model_path.is_dir():
|
||||
return {"deleted": False, "message": "Model not cached"}
|
||||
|
||||
shutil.rmtree(model_path)
|
||||
@@ -237,7 +252,7 @@ def setup_embedding_routes():
|
||||
}
|
||||
|
||||
@router.post("/endpoint")
|
||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
||||
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||
"""Save a custom embedding endpoint URL."""
|
||||
url = url.strip()
|
||||
if not url:
|
||||
@@ -261,6 +276,7 @@ def setup_embedding_routes():
|
||||
resp = httpx.post(
|
||||
url,
|
||||
json={"input": ["test"], "model": model or "test"},
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -271,10 +287,16 @@ def setup_embedding_routes():
|
||||
data = {"url": url}
|
||||
if model:
|
||||
data["model"] = model
|
||||
if api_key:
|
||||
from src.secret_storage import encrypt
|
||||
data["api_key"] = encrypt(api_key)
|
||||
|
||||
_save_custom_endpoint(data)
|
||||
os.environ["EMBEDDING_URL"] = url
|
||||
if model:
|
||||
os.environ["EMBEDDING_MODEL"] = model
|
||||
if api_key:
|
||||
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||
|
||||
# Reset the RAG singleton so it picks up the new endpoint
|
||||
import src.rag_singleton as _rs
|
||||
@@ -288,6 +310,16 @@ def setup_embedding_routes():
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.embedding_lanes import reset_embedding_lane_state
|
||||
reset_embedding_lane_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.tool_index import reset_tool_index
|
||||
reset_tool_index()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||
try:
|
||||
@@ -308,6 +340,7 @@ def setup_embedding_routes():
|
||||
# Remove from environment
|
||||
os.environ.pop("EMBEDDING_URL", None)
|
||||
os.environ.pop("EMBEDDING_MODEL", None)
|
||||
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||
|
||||
# Reset the RAG singleton so it falls back to fastembed
|
||||
import src.rag_singleton as _rs
|
||||
@@ -318,6 +351,16 @@ def setup_embedding_routes():
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.embedding_lanes import reset_embedding_lane_state
|
||||
reset_embedding_lane_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.tool_index import reset_tool_index
|
||||
reset_tool_index()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client
|
||||
try:
|
||||
|
||||
+45
-6
@@ -16,22 +16,54 @@ from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.responses import Response
|
||||
|
||||
from src.constants import EMOJI_CACHE_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
|
||||
_CACHE_DIR = Path(EMOJI_CACHE_DIR)
|
||||
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
||||
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
||||
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
||||
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
||||
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
||||
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
|
||||
_MAX_SVG_BYTES = 256 * 1024
|
||||
_BLOCKED_SVG_RE = re.compile(
|
||||
br"<\s*(?:script|foreignObject|iframe|object|embed|image)\b|"
|
||||
br"\bon[a-z0-9_-]+\s*=",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_EXTERNAL_REF_RE = re.compile(
|
||||
br"\b(?:href|xlink:href)\s*=\s*['\"](?:https?:|//|data:|javascript:)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_SVG_SECURITY_HEADERS = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Security-Policy": "sandbox",
|
||||
"Cross-Origin-Resource-Policy": "same-origin",
|
||||
}
|
||||
_SVG_HEADERS = {
|
||||
"Cache-Control": "public, max-age=31536000, immutable",
|
||||
**_SVG_SECURITY_HEADERS,
|
||||
}
|
||||
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
||||
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
||||
# request can still pick up the real glyph once the CDN is reachable.
|
||||
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
||||
_BLANK_HEADERS = {"Cache-Control": "no-store"}
|
||||
_BLANK_HEADERS = {"Cache-Control": "no-store", **_SVG_SECURITY_HEADERS}
|
||||
|
||||
|
||||
def _is_safe_svg(content: bytes) -> bool:
|
||||
if not isinstance(content, bytes) or not content:
|
||||
return False
|
||||
if len(content) > _MAX_SVG_BYTES:
|
||||
return False
|
||||
if b"<svg" not in content[:256].lower():
|
||||
return False
|
||||
if _BLOCKED_SVG_RE.search(content) or _EXTERNAL_REF_RE.search(content):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def setup_emoji_routes() -> APIRouter:
|
||||
@@ -49,14 +81,21 @@ def setup_emoji_routes() -> APIRouter:
|
||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
fp = _CACHE_DIR / f"{code}.svg"
|
||||
if fp.exists():
|
||||
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
try:
|
||||
content = fp.read_bytes()
|
||||
if _is_safe_svg(content):
|
||||
return Response(content, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
fp.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.warning("emoji cache read %s failed: %s", code, e)
|
||||
return _blank()
|
||||
|
||||
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
||||
# it. OpenMoji filenames are the codepoints uppercased.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
||||
if r.status_code == 200 and b"<svg" in r.content[:256]:
|
||||
if r.status_code == 200 and _is_safe_svg(r.content):
|
||||
try:
|
||||
fp.write_bytes(r.content)
|
||||
except Exception:
|
||||
|
||||
+144
-54
@@ -12,8 +12,13 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
||||
from core.database import Session as DbSession
|
||||
from src.auth_helpers import get_current_user, require_privilege
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.auth_helpers import get_current_user, owner_filter, require_privilege
|
||||
from src.upload_limits import (
|
||||
read_upload_limited,
|
||||
GALLERY_UPLOAD_MAX_BYTES,
|
||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
||||
)
|
||||
from src.constants import GENERATED_IMAGES_DIR
|
||||
|
||||
from routes.gallery_helpers import (
|
||||
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
||||
@@ -21,17 +26,88 @@ from routes.gallery_helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", str(100 * 1024 * 1024)))
|
||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024)))
|
||||
|
||||
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||
if not callable(is_admin):
|
||||
return False
|
||||
try:
|
||||
return bool(is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_gallery_filename(filename: str) -> str:
|
||||
"""Return a local filename safe to join under generated_images."""
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128]
|
||||
if not safe_name or safe_name in {".", ".."}:
|
||||
safe_name = uuid.uuid4().hex[:12]
|
||||
return safe_name
|
||||
|
||||
|
||||
GALLERY_IMAGE_DIR = Path(GENERATED_IMAGES_DIR)
|
||||
|
||||
|
||||
def _gallery_image_path(filename: str) -> Path:
|
||||
"""Resolve a stored gallery filename without leaving generated_images."""
|
||||
if not isinstance(filename, str):
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
safe_name = _sanitize_gallery_filename(filename)
|
||||
original = str(filename or "")
|
||||
root = GALLERY_IMAGE_DIR.resolve()
|
||||
path = (GALLERY_IMAGE_DIR / safe_name).resolve()
|
||||
try:
|
||||
if os.path.commonpath([str(root), str(path)]) != str(root):
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
if safe_name != original:
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
return path
|
||||
|
||||
|
||||
def _normalize_image_endpoint_base(url: str) -> str:
|
||||
base = (url or "").strip().rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3].rstrip("/")
|
||||
return base
|
||||
|
||||
|
||||
def _visible_image_endpoint_query(db, owner: str | None):
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.model_type == "image",
|
||||
ModelEndpoint.is_enabled == True, # noqa: E712
|
||||
)
|
||||
return owner_filter(q, ModelEndpoint, owner)
|
||||
|
||||
|
||||
def _first_visible_image_endpoint(db, owner: str | None):
|
||||
endpoints = _visible_image_endpoint_query(db, owner).all()
|
||||
if owner:
|
||||
for ep in endpoints:
|
||||
if getattr(ep, "owner", None) == owner:
|
||||
return ep
|
||||
return endpoints[0] if endpoints else None
|
||||
|
||||
|
||||
def _visible_image_endpoint_for_base(db, base: str, owner: str | None):
|
||||
target = _normalize_image_endpoint_base(base)
|
||||
if not target:
|
||||
return None
|
||||
fallback = None
|
||||
for ep in _visible_image_endpoint_query(db, owner).all():
|
||||
if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == target:
|
||||
if owner and getattr(ep, "owner", None) == owner:
|
||||
return ep
|
||||
if fallback is None:
|
||||
fallback = ep
|
||||
return fallback
|
||||
|
||||
|
||||
def setup_gallery_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["gallery"])
|
||||
|
||||
@@ -55,6 +131,9 @@ def setup_gallery_routes() -> APIRouter:
|
||||
file_hash = hashlib.sha256(content).hexdigest()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if album_id and user is not None:
|
||||
_get_or_404_album(db, album_id, user)
|
||||
|
||||
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
||||
# caller can probe whether someone else uploaded the same
|
||||
# file (the response leaks the existing row's id+filename).
|
||||
@@ -69,7 +148,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
||||
"id": existing.id, "message": "Duplicate photo skipped"}
|
||||
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
||||
@@ -135,7 +214,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
raise HTTPException(400, "No image provided")
|
||||
|
||||
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
||||
img_path.write_bytes(content)
|
||||
@@ -211,7 +290,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not user or img.owner != user:
|
||||
raise HTTPException(403, "Not your image")
|
||||
|
||||
img_path = Path("data/generated_images") / img.filename
|
||||
img_path = _gallery_image_path(img.filename)
|
||||
if not img_path.exists():
|
||||
raise HTTPException(404, "Image file not found")
|
||||
|
||||
@@ -248,7 +327,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
"""AI upscale using img2img with the diffusion server."""
|
||||
import base64, httpx
|
||||
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
form = await request.form()
|
||||
file = form.get("image")
|
||||
if not file: raise HTTPException(400, "No image")
|
||||
@@ -260,7 +339,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
# Find image endpoint
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -291,7 +370,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
"""Style transfer using img2img with the diffusion server."""
|
||||
import base64, httpx
|
||||
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
form = await request.form()
|
||||
file = form.get("image")
|
||||
prompt = form.get("prompt", "")
|
||||
@@ -303,7 +382,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -505,18 +584,24 @@ def setup_gallery_routes() -> APIRouter:
|
||||
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
||||
result = []
|
||||
for a in albums:
|
||||
count = db.query(GalleryImage).filter(
|
||||
_count_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
).count()
|
||||
)
|
||||
if user:
|
||||
_count_q = _count_q.filter(GalleryImage.owner == user)
|
||||
count = _count_q.count()
|
||||
cover_url = None
|
||||
if a.cover_id:
|
||||
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
||||
if cover:
|
||||
cover_url = f"/api/generated-image/{cover.filename}"
|
||||
elif count > 0:
|
||||
first = db.query(GalleryImage).filter(
|
||||
_cover_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
).order_by(GalleryImage.created_at.desc()).first()
|
||||
)
|
||||
if user:
|
||||
_cover_q = _cover_q.filter(GalleryImage.owner == user)
|
||||
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
|
||||
if first:
|
||||
cover_url = f"/api/generated-image/{first.filename}"
|
||||
result.append({
|
||||
@@ -649,7 +734,14 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if req.favorite is not None:
|
||||
img.favorite = req.favorite
|
||||
if req.album_id is not None:
|
||||
img.album_id = req.album_id if req.album_id else None
|
||||
if req.album_id:
|
||||
# Validate the target album belongs to the caller before
|
||||
# moving the image into it — mirrors add_to_album, so you
|
||||
# cannot file your image into another user's album.
|
||||
_get_or_404_album(db, req.album_id, user)
|
||||
img.album_id = req.album_id
|
||||
else:
|
||||
img.album_id = None
|
||||
db.commit()
|
||||
db.refresh(img)
|
||||
return _image_to_dict(img)
|
||||
@@ -692,11 +784,11 @@ def setup_gallery_routes() -> APIRouter:
|
||||
used = set()
|
||||
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for img in imgs:
|
||||
src = os.path.join("data", "generated_images", img.filename)
|
||||
if not os.path.exists(src):
|
||||
src = _gallery_image_path(img.filename)
|
||||
if not src.exists():
|
||||
continue
|
||||
ext = os.path.splitext(img.filename)[1] or ".png"
|
||||
base = (img.prompt or "").strip() or os.path.splitext(img.filename)[0]
|
||||
ext = src.suffix or ".png"
|
||||
base = (img.prompt or "").strip() or src.stem
|
||||
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
||||
name = f"{base}{ext}"
|
||||
i = 1
|
||||
@@ -818,9 +910,9 @@ def setup_gallery_routes() -> APIRouter:
|
||||
|
||||
img_filename = img.filename
|
||||
# Remove the file from disk
|
||||
img_path = os.path.join("data", "generated_images", img_filename)
|
||||
if os.path.exists(img_path):
|
||||
os.remove(img_path)
|
||||
img_path = _gallery_image_path(img_filename)
|
||||
if img_path.exists():
|
||||
img_path.unlink()
|
||||
|
||||
# Soft-delete the record
|
||||
img.is_active = False
|
||||
@@ -923,7 +1015,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
||||
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
||||
import httpx
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
body = await request.json()
|
||||
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
||||
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
||||
@@ -942,14 +1034,11 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not base:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
eps = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.model_type == "image",
|
||||
).all()
|
||||
if not eps:
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
if not ep:
|
||||
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
||||
base = eps[0].base_url.rstrip("/")
|
||||
api_key = eps[0].api_key
|
||||
base = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -966,10 +1055,12 @@ def setup_gallery_routes() -> APIRouter:
|
||||
_target = _norm_url(base)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for ep in db.query(ModelEndpoint).all():
|
||||
if _norm_url(ep.base_url) == _target:
|
||||
api_key = ep.api_key
|
||||
break
|
||||
ep = _visible_image_endpoint_for_base(db, _target, user)
|
||||
if ep:
|
||||
base = (ep.base_url or base).rstrip("/")
|
||||
api_key = ep.api_key
|
||||
elif user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered image endpoint")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1121,7 +1212,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
you get edge blending + lighting unification while keeping the
|
||||
composition recognisable."""
|
||||
import httpx, base64 as _b64
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
body = await request.json()
|
||||
|
||||
image_b64 = body.get("image")
|
||||
@@ -1148,23 +1239,22 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not base:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
eps = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.model_type == "image",
|
||||
).all()
|
||||
if not eps:
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
if not ep:
|
||||
raise HTTPException(400, "No image generation endpoint configured.")
|
||||
base = eps[0].base_url.rstrip("/")
|
||||
api_key = eps[0].api_key
|
||||
base = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for ep in db.query(ModelEndpoint).all():
|
||||
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
|
||||
api_key = ep.api_key
|
||||
break
|
||||
ep = _visible_image_endpoint_for_base(db, base, user)
|
||||
if ep:
|
||||
base = (ep.base_url or base).rstrip("/")
|
||||
api_key = ep.api_key
|
||||
elif user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered image endpoint")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1636,9 +1726,10 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
album = _get_or_404_album(db, album_id, user)
|
||||
db.query(GalleryImage).filter(GalleryImage.album_id == album_id).update(
|
||||
{"album_id": None}, synchronize_session=False
|
||||
)
|
||||
q = db.query(GalleryImage).filter(GalleryImage.album_id == album_id)
|
||||
if user is not None:
|
||||
q = q.filter(GalleryImage.owner == user)
|
||||
q.update({"album_id": None}, synchronize_session=False)
|
||||
db.delete(album)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -1709,7 +1800,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
try:
|
||||
img = _get_or_404_image(db, image_id, user)
|
||||
|
||||
img_path = Path("data/generated_images") / img.filename
|
||||
img_path = _gallery_image_path(img.filename)
|
||||
if not img_path.exists():
|
||||
raise HTTPException(404, "Image file not found")
|
||||
|
||||
@@ -1727,7 +1818,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
||||
configured = vl_settings.get("vision_model", "")
|
||||
try:
|
||||
chat_url, model_name, headers = _resolve_vl_model(configured)
|
||||
chat_url, model_name, headers = _resolve_vl_model(configured, owner=user)
|
||||
except ValueError:
|
||||
return {"error": "No vision model configured — set one in Settings → Vision"}
|
||||
if not chat_url:
|
||||
@@ -1808,4 +1899,3 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -490,7 +490,13 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
# Copy messages up to keep_count
|
||||
msgs_to_copy = source.history[:keep_count]
|
||||
for msg in msgs_to_copy:
|
||||
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
|
||||
# Copy the metadata dict. Sharing it would let the fork's
|
||||
# persistence (add_message -> _persist_message stamps
|
||||
# _db_id/timestamp onto the dict) mutate the SOURCE session's
|
||||
# in-memory messages, corrupting their _db_id and breaking
|
||||
# edit/delete-by-id on the original conversation.
|
||||
meta = dict(msg.metadata) if isinstance(msg.metadata, dict) else None
|
||||
new_session.add_message(ChatMessage(msg.role, msg.content, meta))
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", getattr(source, 'owner', None))
|
||||
@@ -522,6 +528,8 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
async def compact_session(request: Request, session_id: str):
|
||||
"""Manually trigger context compaction for a session."""
|
||||
_verify_session_owner(request, session_id)
|
||||
from src.auth_helpers import effective_user
|
||||
owner = effective_user(request)
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
@@ -555,7 +563,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
)
|
||||
|
||||
# Use utility model if available
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner or None)
|
||||
compact_url = util_url or session.endpoint_url
|
||||
compact_model = util_model or session.model
|
||||
compact_headers = util_headers if util_url else session.headers
|
||||
|
||||
@@ -13,7 +13,7 @@ import httpx
|
||||
|
||||
from core.database import McpServer, SessionLocal
|
||||
from core.middleware import require_admin
|
||||
from src.constants import DATA_DIR
|
||||
from src.constants import DATA_DIR, MCP_OAUTH_DIR
|
||||
from src.mcp_manager import McpManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
||||
|
||||
def _mcp_oauth_base_dir() -> Path:
|
||||
"""Directory that may contain OAuth files managed by Odysseus."""
|
||||
return (Path(DATA_DIR) / "mcp_oauth").resolve(strict=False)
|
||||
return Path(MCP_OAUTH_DIR).resolve(strict=False)
|
||||
|
||||
|
||||
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
||||
|
||||
@@ -29,11 +29,10 @@ from src.llm_core import llm_call_async
|
||||
from services.memory.memory_extractor import audit_memories
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_IMPORT_MAX_BYTES = int(os.getenv("ODYSSEUS_MEMORY_IMPORT_MAX_BYTES", str(10 * 1024 * 1024)))
|
||||
|
||||
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
||||
"""Set up memory-related routes."""
|
||||
@@ -371,7 +370,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
text = _process_pdf(tmp_path)
|
||||
text = _process_pdf(tmp_path, owner=_owner(request))
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
else:
|
||||
|
||||
+188
-26
@@ -5,6 +5,7 @@ import re
|
||||
import uuid
|
||||
import json
|
||||
import socket
|
||||
import hashlib
|
||||
import time as _time
|
||||
import logging
|
||||
import httpx
|
||||
@@ -282,8 +283,11 @@ _HOST_TO_CURATED = (
|
||||
("fireworks.ai", "fireworks"),
|
||||
("googleapis.com", "google"),
|
||||
("x.ai", "xai"),
|
||||
|
||||
("openrouter.ai", "openrouter"),
|
||||
("ollama.com", "ollama"),
|
||||
("opencode.ai/zen/go", "opencode-go"),
|
||||
("opencode.ai/zen", "opencode-zen"),
|
||||
)
|
||||
|
||||
|
||||
@@ -490,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
||||
def _is_chat_model(model_id: str) -> bool:
|
||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||
mid = model_id.lower()
|
||||
if mid in {"gpt-5.1-codex"}:
|
||||
return True
|
||||
for prefix in _NON_CHAT_PREFIXES:
|
||||
if mid.startswith(prefix):
|
||||
return False
|
||||
@@ -502,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||
|
||||
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||
endpoint backed by that auth row is removed, the stored credentials should
|
||||
be cleared instead of lingering. Returns True if a row was deleted.
|
||||
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||
reference count so it does not keep its own auth alive.
|
||||
"""
|
||||
if not auth_id:
|
||||
return False
|
||||
from core.database import ProviderAuthSession
|
||||
still_referenced = db.query(ModelEndpoint.id).filter(
|
||||
ModelEndpoint.provider_auth_id == auth_id,
|
||||
ModelEndpoint.id != exclude_ep_id,
|
||||
).first()
|
||||
if still_referenced is not None:
|
||||
return False
|
||||
auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
|
||||
if auth_row is None:
|
||||
return False
|
||||
db.delete(auth_row)
|
||||
return True
|
||||
|
||||
|
||||
def _is_discovery_only_provider(provider: str) -> bool:
|
||||
"""Provider that only supports model discovery, not live probing.
|
||||
|
||||
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||
chat-completions or general health endpoint, so completion probes and
|
||||
reachability pings are skipped — status is derived from cached models.
|
||||
"""
|
||||
return provider == "chatgpt-subscription"
|
||||
|
||||
|
||||
def _resolve_probe_key(ep) -> Optional[str]:
|
||||
"""API key/bearer to probe an endpoint with.
|
||||
|
||||
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||
"""
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||
return key
|
||||
except Exception as e:
|
||||
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||
return None
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||
provider = _detect_provider(base)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# Responses/Codex API, not chat-completions: a completion probe would
|
||||
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Say OK"},
|
||||
@@ -618,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
if _detect_provider(base) == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
if api_key:
|
||||
return fetch_available_models(api_key, timeout=timeout)
|
||||
return []
|
||||
if _detect_provider(base) == "anthropic":
|
||||
# Try Anthropic's /v1/models endpoint first
|
||||
url = build_models_url(base)
|
||||
@@ -644,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||
return list(ANTHROPIC_MODELS)
|
||||
url = build_models_url(base)
|
||||
if not url:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||
return list(fallback or [])
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
@@ -697,7 +770,6 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return list(fallback)
|
||||
return []
|
||||
|
||||
|
||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
@@ -713,6 +785,10 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
# APFEL-specific detection
|
||||
host = (parsed_base.hostname or "").lower()
|
||||
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
|
||||
|
||||
def _result_from_response(r) -> Dict[str, Any]:
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
@@ -734,7 +810,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
last_error: Optional[str] = None
|
||||
|
||||
try:
|
||||
if looks_like_ollama:
|
||||
# APFEL does not behave like Ollama; use its health endpoint.
|
||||
if looks_like_apfel:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
root = root[: -len(suffix)].rstrip("/")
|
||||
break
|
||||
try:
|
||||
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
|
||||
result = _result_from_response(r)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
last_error = result.get("error")
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
elif looks_like_ollama:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
@@ -754,14 +846,31 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
|
||||
try:
|
||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
return _result_from_response(r)
|
||||
result = _result_from_response(r)
|
||||
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
|
||||
# as a fallback. OpenAI-compatible servers like llama-swap return 404
|
||||
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
|
||||
# are definitive — probing /models would just repeat the same rejection.
|
||||
if (
|
||||
not result["reachable"]
|
||||
and result.get("status_code") is not None
|
||||
and 400 <= result["status_code"] < 500
|
||||
and result["status_code"] not in (401, 403)
|
||||
):
|
||||
models_url = build_models_url(base)
|
||||
try:
|
||||
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
result2 = _result_from_response(r2)
|
||||
if result2["reachable"]:
|
||||
return result2
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
|
||||
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
||||
"""Return a provider-aware error message for failed endpoint probes."""
|
||||
ping = ping or {}
|
||||
@@ -850,6 +959,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
|
||||
return [m for m in merged if m not in hidden]
|
||||
|
||||
|
||||
def _api_key_fingerprint(api_key: Optional[str]) -> str:
|
||||
"""Stable, non-secret label for distinguishing same-URL credentials."""
|
||||
key = (api_key or "").strip()
|
||||
if not key:
|
||||
return ""
|
||||
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
|
||||
def setup_model_routes(model_discovery):
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
@@ -951,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||
ep,
|
||||
owner=getattr(ep, "owner", None),
|
||||
)
|
||||
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||
except Exception as e:
|
||||
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||
continue
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
@@ -1104,8 +1232,9 @@ def setup_model_routes(model_discovery):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
|
||||
raise HTTPException(status_code=500, detail='Internal error')
|
||||
# Admins see every endpoint (they manage the global pool); regular
|
||||
# users get the owner-scoped view.
|
||||
_is_admin = False
|
||||
@@ -1219,12 +1348,20 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
try:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# No general health endpoint — an unauthenticated GET just
|
||||
# 401s. Report status from cached models instead of pinging.
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = None
|
||||
entry["model_count"] = cached_count
|
||||
else:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
@@ -1257,7 +1394,7 @@ def setup_model_routes(model_discovery):
|
||||
if ep_id and ep_id not in endpoints_cache:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep:
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
ep_data = endpoints_cache.get(ep_id)
|
||||
if not ep_data:
|
||||
# Try to find by base_url from the model's endpoint field
|
||||
@@ -1296,7 +1433,7 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"api_key": ep.api_key,
|
||||
"api_key": _resolve_probe_key(ep),
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1385,18 +1522,21 @@ def setup_model_routes(model_discovery):
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
# admin-pinned IDs that a probe would never surface.
|
||||
status = "online" if (all_models or pinned) else "offline"
|
||||
base = _normalize_base(r.base_url)
|
||||
ping = None
|
||||
if not all_models and not pinned and r.is_enabled:
|
||||
# Discovery-only providers have no health endpoint — an
|
||||
# unauthenticated ping just 401s, so don't bother.
|
||||
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
"base_url": r.base_url,
|
||||
"has_key": bool(r.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
|
||||
"is_enabled": r.is_enabled,
|
||||
"models": visible,
|
||||
"pinned_models": pinned,
|
||||
@@ -1463,21 +1603,34 @@ def setup_model_routes(model_discovery):
|
||||
)
|
||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||
|
||||
# Dedupe: if an endpoint with the same base_url already exists and
|
||||
# is reachable by the caller (shared or owned by them), return it
|
||||
# instead of creating a duplicate row. Fixes "Scan for Servers"
|
||||
# re-adding manually-added endpoints under their host:port name.
|
||||
# Dedupe: if an endpoint with the same base_url and compatible
|
||||
# credentials already exists and is reachable by the caller (shared or
|
||||
# owned by them), return it instead of creating a duplicate row. Keep
|
||||
# same-url/different-key rows distinct so users can group the same
|
||||
# provider URL under multiple credentials.
|
||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||
_caller = _gcu_dedup(request) or None
|
||||
_incoming_api_key = api_key.strip()
|
||||
_db_dedup = SessionLocal()
|
||||
try:
|
||||
existing = (
|
||||
_same_url_rows = (
|
||||
_db_dedup.query(ModelEndpoint)
|
||||
.filter(ModelEndpoint.base_url == base_url)
|
||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
||||
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
||||
.first()
|
||||
.all()
|
||||
)
|
||||
existing = None
|
||||
_empty_key_existing = None
|
||||
for _candidate in _same_url_rows:
|
||||
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
|
||||
if _candidate_key == _incoming_api_key:
|
||||
existing = _candidate
|
||||
break
|
||||
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
|
||||
_empty_key_existing = _candidate
|
||||
if existing is None and _incoming_api_key and _empty_key_existing is not None:
|
||||
existing = _empty_key_existing
|
||||
if existing:
|
||||
changed = False
|
||||
# Persist any incoming pinned IDs onto the existing row. An
|
||||
@@ -1526,6 +1679,8 @@ def setup_model_routes(model_discovery):
|
||||
"id": existing.id,
|
||||
"name": existing.name,
|
||||
"base_url": existing.base_url,
|
||||
"has_key": bool(existing.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
|
||||
"models": _visible_models(
|
||||
existing_models,
|
||||
getattr(existing, "hidden_models", None),
|
||||
@@ -1599,6 +1754,8 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep_id,
|
||||
"name": name.strip(),
|
||||
"base_url": base_url,
|
||||
"has_key": bool(api_key.strip()),
|
||||
"api_key_fingerprint": _api_key_fingerprint(api_key),
|
||||
"models": _merge_model_ids(model_ids, _pinned),
|
||||
"pinned_models": _pinned,
|
||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||
@@ -1648,7 +1805,7 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1712,7 +1869,7 @@ def setup_model_routes(model_discovery):
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
@@ -1948,6 +2105,8 @@ def setup_model_routes(model_discovery):
|
||||
"name": ep.name,
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"has_key": bool(ep.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||
@@ -2049,7 +2208,9 @@ def setup_model_routes(model_discovery):
|
||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||
auth_id = getattr(ep, "provider_auth_id", None)
|
||||
db.delete(ep)
|
||||
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
@@ -2059,6 +2220,7 @@ def setup_model_routes(model_discovery):
|
||||
"cleared_user_preferences": cleared_user_preferences,
|
||||
"cleared_sessions": cleared_sessions,
|
||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||
"cleared_provider_auth": cleared_provider_auth,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
+161
-16
@@ -11,6 +11,7 @@ from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, Note
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import DATA_DIR
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -95,6 +96,32 @@ def _note_to_dict(note: Note) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _reminder_text_from_note(note: Note) -> tuple[str, str]:
|
||||
"""Return the reminder title/body from a stored note row."""
|
||||
title = (note.title or "Note reminder").strip() or "Note reminder"
|
||||
if note.items:
|
||||
try:
|
||||
items = json.loads(note.items)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
items = None
|
||||
if isinstance(items, list):
|
||||
pending: list[str] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("done") or item.get("checked"):
|
||||
continue
|
||||
text = str(item.get("text") or "").strip()
|
||||
if text:
|
||||
pending.append(text)
|
||||
if pending:
|
||||
shown = "\n".join(f"- {text}" for text in pending[:8])
|
||||
extra = f"\n...and {len(pending) - 8} more" if len(pending) > 8 else ""
|
||||
return title, f"Pending ({len(pending)}):\n{shown}{extra}"
|
||||
return title, f"{len(items)} item{'s' if len(items) != 1 else ''}"
|
||||
return title, (note.content or "").strip()[:400]
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reminder dispatch — module-level so background tasks (built-in actions)
|
||||
@@ -114,8 +141,9 @@ async def dispatch_reminder(
|
||||
note_id: str,
|
||||
owner: str = "",
|
||||
queue_browser: bool = True,
|
||||
settings_override: dict | None = None,
|
||||
) -> dict:
|
||||
"""Fire a reminder via the configured channel (browser/email/ntfy).
|
||||
"""Fire a reminder via the configured channel (browser/email/ntfy/webhook).
|
||||
|
||||
Args:
|
||||
title: short headline shown to the user
|
||||
@@ -129,7 +157,7 @@ async def dispatch_reminder(
|
||||
nothing is "sent" synchronously for it — the channel just routes there.
|
||||
"""
|
||||
from src.settings import load_settings
|
||||
settings = load_settings()
|
||||
settings = {**load_settings(), **(settings_override or {})}
|
||||
channel = settings.get("reminder_channel", "browser")
|
||||
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
||||
title = (title or "").strip()
|
||||
@@ -143,7 +171,7 @@ async def dispatch_reminder(
|
||||
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
||||
from pathlib import Path as _P
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
cache_path = _P(f"data/note_pings_{_slug}.json")
|
||||
cache_path = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||
if cache_path.exists():
|
||||
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
||||
last = cache.get(cache_key)
|
||||
@@ -160,13 +188,14 @@ async def dispatch_reminder(
|
||||
# Treat those as browser-only dedupe so email reminders can be
|
||||
# retried by the backend scanner after a failed frontend path.
|
||||
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
||||
if should_skip and channel in ("email", "ntfy"):
|
||||
if should_skip and channel in ("email", "ntfy", "webhook"):
|
||||
should_skip = last_channel == channel
|
||||
if should_skip:
|
||||
return {
|
||||
"synthesis": None,
|
||||
"email_sent": False,
|
||||
"ntfy_sent": False,
|
||||
"webhook_sent": False,
|
||||
"browser_sent": True,
|
||||
"skipped": True,
|
||||
}
|
||||
@@ -179,9 +208,9 @@ async def dispatch_reminder(
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||
if url and model:
|
||||
raw = await llm_call_async(
|
||||
url=url, model=model,
|
||||
@@ -360,6 +389,76 @@ async def dispatch_reminder(
|
||||
email_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder email send failed: {e}")
|
||||
|
||||
webhook_sent = False
|
||||
webhook_error = ""
|
||||
if channel == "webhook":
|
||||
try:
|
||||
import httpx
|
||||
import json as _wjson
|
||||
from src.integrations import load_integrations
|
||||
# Built-in payload defaults for known presets so users don't have
|
||||
# to configure a template just to use a standard service.
|
||||
_PRESET_TEMPLATE_DEFAULTS = {
|
||||
"discord_webhook": '{"embeds": [{"title": "{{title}}", "description": "{{message}}", "color": 5793266}]}',
|
||||
}
|
||||
intg_id = settings.get("reminder_webhook_integration_id", "").strip()
|
||||
template = settings.get("reminder_webhook_payload_template", "").strip()
|
||||
if not intg_id:
|
||||
webhook_error = "No webhook integration selected"
|
||||
else:
|
||||
intg = next(
|
||||
(i for i in load_integrations()
|
||||
if i.get("id") == intg_id and i.get("base_url")),
|
||||
None,
|
||||
)
|
||||
if not intg:
|
||||
webhook_error = f"Integration {intg_id!r} not found or missing base URL"
|
||||
else:
|
||||
# Fall back to a built-in default for known presets so
|
||||
# users don't have to configure a template for standard
|
||||
# services like Discord.
|
||||
if not template:
|
||||
template = _PRESET_TEMPLATE_DEFAULTS.get(intg.get("preset", ""), "")
|
||||
if not template:
|
||||
webhook_error = "No payload template configured"
|
||||
else:
|
||||
# Render template: JSON-escape the values so the result
|
||||
# is always valid JSON regardless of special characters.
|
||||
# dumps() returns `"value"` — strip outer quotes.
|
||||
msg = (synthesis or note_body or title or "Reminder")[:4000]
|
||||
_t = _wjson.dumps(title or "Reminder")[1:-1]
|
||||
_m = _wjson.dumps(msg)[1:-1]
|
||||
rendered = template.replace("{{title}}", _t).replace("{{message}}", _m)
|
||||
hdrs = {"Content-Type": "application/json"}
|
||||
api_key = intg.get("api_key", "")
|
||||
auth_type = (intg.get("auth_type") or "none").lower()
|
||||
if api_key:
|
||||
if auth_type == "bearer":
|
||||
hdrs["Authorization"] = f"Bearer {api_key}"
|
||||
elif auth_type == "header":
|
||||
hdrs[intg.get("auth_header") or "Authorization"] = api_key
|
||||
url = intg["base_url"].rstrip("/")
|
||||
# SSRF guard — matches the pattern used by webhook_routes,
|
||||
# CalDAV, search, and embeddings. Blocks link-local / metadata
|
||||
# addresses (169.254.x.x) by default; set
|
||||
# REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS=true to also block
|
||||
# RFC-1918 ranges for locked-down deployments.
|
||||
import os as _os
|
||||
from src.url_safety import check_outbound_url as _chk
|
||||
_block = _os.getenv("REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS", "false").lower() == "true"
|
||||
_ok, _reason = _chk(url, block_private=_block)
|
||||
if not _ok:
|
||||
webhook_error = f"Webhook URL rejected: {_reason}"
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(url, content=rendered.encode(), headers=hdrs)
|
||||
webhook_sent = resp.is_success
|
||||
if not webhook_sent:
|
||||
webhook_error = f"Webhook returned HTTP {resp.status_code}"
|
||||
except Exception as e:
|
||||
webhook_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder webhook send failed: {e}")
|
||||
|
||||
ntfy_sent = False
|
||||
ntfy_error = ""
|
||||
if channel == "ntfy":
|
||||
@@ -415,7 +514,7 @@ async def dispatch_reminder(
|
||||
# second send for the same note within 25 min. Without this, a note
|
||||
# whose due_date fires while the user has the app open got TWO emails
|
||||
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
||||
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
|
||||
if (email_sent or ntfy_sent or webhook_sent or browser_sent or local_browser_sent) and note_id:
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime as _dt, timezone as _tz
|
||||
@@ -425,13 +524,13 @@ async def dispatch_reminder(
|
||||
_STATE = cache_path
|
||||
if _STATE is None:
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
_STATE = _P(f"data/note_pings_{_slug}.json")
|
||||
_STATE = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
||||
except Exception:
|
||||
_cache = {}
|
||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
|
||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "webhook" if webhook_sent else "browser"
|
||||
_cache[cache_key or str(note_id)] = {
|
||||
"at": _dt.now(_tz.utc).isoformat(),
|
||||
"channel": sent_channel,
|
||||
@@ -441,11 +540,14 @@ async def dispatch_reminder(
|
||||
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
||||
|
||||
return {
|
||||
"channel": channel,
|
||||
"synthesis": synthesis,
|
||||
"email_sent": email_sent,
|
||||
"email_error": email_error,
|
||||
"ntfy_sent": ntfy_sent,
|
||||
"ntfy_error": ntfy_error,
|
||||
"webhook_sent": webhook_sent,
|
||||
"webhook_error": webhook_error,
|
||||
"browser_sent": browser_sent or local_browser_sent,
|
||||
}
|
||||
|
||||
@@ -467,6 +569,23 @@ def setup_note_routes(task_scheduler=None):
|
||||
def _owner(request: Request) -> Optional[str]:
|
||||
return get_current_user(request)
|
||||
|
||||
def _is_admin_or_single_user(request: Request, user: str | None) -> bool:
|
||||
if user == "internal-tool":
|
||||
return True
|
||||
if not user:
|
||||
# require_user() already admitted this request, which only happens
|
||||
# for auth-disabled, loopback-bypass, or unconfigured single-user
|
||||
# modes. There is no separate non-admin account boundary there.
|
||||
return True
|
||||
try:
|
||||
from core.auth import AuthManager
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None) or AuthManager()
|
||||
if not getattr(auth_mgr, "is_configured", True):
|
||||
return True
|
||||
return bool(auth_mgr.is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# --- LIST ---
|
||||
@router.get("")
|
||||
def list_notes(
|
||||
@@ -684,20 +803,46 @@ def setup_note_routes(task_scheduler=None):
|
||||
"""
|
||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||
from src.auth_helpers import require_user as _ru
|
||||
_ru(request)
|
||||
user = _ru(request)
|
||||
body = await request.json()
|
||||
note_id = body.get("note_id")
|
||||
title = (body.get("title") or "").strip()
|
||||
note_body = (body.get("body") or "").strip()
|
||||
note_id = str(body.get("note_id") or "").strip()
|
||||
if not note_id:
|
||||
raise HTTPException(400, "note_id required")
|
||||
|
||||
# Delegate to the module-level helper so background tasks can reuse
|
||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
||||
caller = _owner(request)
|
||||
is_test = note_id.startswith("test-")
|
||||
is_admin = _is_admin_or_single_user(request, user or caller)
|
||||
_override: dict = {}
|
||||
if is_test:
|
||||
if not is_admin:
|
||||
raise HTTPException(403, "Admin only")
|
||||
title = (body.get("title") or "Test Reminder").strip() or "Test Reminder"
|
||||
note_body = (body.get("body") or "").strip()
|
||||
# Optional overrides let the admin settings test button pass the
|
||||
# current UI values directly so it never races a pending save.
|
||||
if body.get("channel"):
|
||||
_override["reminder_channel"] = body["channel"]
|
||||
if body.get("webhook_integration_id"):
|
||||
_override["reminder_webhook_integration_id"] = body["webhook_integration_id"]
|
||||
if body.get("webhook_payload_template"):
|
||||
_override["reminder_webhook_payload_template"] = body["webhook_payload_template"]
|
||||
else:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
if caller is not None and note.owner != caller:
|
||||
raise HTTPException(404, "Note not found")
|
||||
title, note_body = _reminder_text_from_note(note)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return await dispatch_reminder(
|
||||
title=title, note_body=note_body, note_id=note_id,
|
||||
owner=_owner(request) or "",
|
||||
owner=caller or "",
|
||||
queue_browser=False,
|
||||
settings_override=_override or None,
|
||||
)
|
||||
|
||||
# --- REORDER NOTES ---
|
||||
|
||||
+13
-12
@@ -6,16 +6,14 @@ import uuid
|
||||
from typing import List, Tuple
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||
from src.request_models import DirectoryRequest
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
|
||||
from src.rag_singleton import get_rag_manager
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from src.auth_helpers import require_privilege, require_user
|
||||
from core.middleware import require_admin
|
||||
from src.upload_handler import secure_filename
|
||||
from src.upload_limits import PERSONAL_UPLOAD_MAX_BYTES
|
||||
|
||||
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
|
||||
MAX_PERSONAL_UPLOAD_BYTES = int(
|
||||
os.getenv("ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))
|
||||
)
|
||||
UPLOADS_DIR = PERSONAL_UPLOADS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +192,7 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
@router.post("/upload")
|
||||
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
||||
"""Upload files directly into RAG. Supports text and PDF."""
|
||||
user = get_current_user(request)
|
||||
user = require_privilege(request, "can_use_documents")
|
||||
rag = _rag()
|
||||
if not rag:
|
||||
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
||||
@@ -208,8 +206,8 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
for upload in files:
|
||||
try:
|
||||
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
||||
content_bytes = await upload.read(MAX_PERSONAL_UPLOAD_BYTES + 1)
|
||||
if len(content_bytes) > MAX_PERSONAL_UPLOAD_BYTES:
|
||||
content_bytes = await upload.read(PERSONAL_UPLOAD_MAX_BYTES + 1)
|
||||
if len(content_bytes) > PERSONAL_UPLOAD_MAX_BYTES:
|
||||
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
||||
total_failed += 1
|
||||
continue
|
||||
@@ -286,9 +284,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
except ValueError:
|
||||
# commonpath raises on mixed drives / non-comparable paths
|
||||
in_uploads = False
|
||||
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
|
||||
os.remove(abs_target)
|
||||
deleted_from_disk = True
|
||||
if in_uploads and abs_target != base_abs:
|
||||
try:
|
||||
os.remove(abs_target)
|
||||
deleted_from_disk = True
|
||||
except FileNotFoundError:
|
||||
pass # already gone — race with another request or cleanup
|
||||
|
||||
# Exclude the file from the listing (persists across restarts)
|
||||
personal_docs_manager.exclude_file(filepath)
|
||||
|
||||
@@ -4,8 +4,9 @@ import os
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import USER_PREFS_FILE
|
||||
|
||||
PREFS_FILE = os.path.join("data", "user_prefs.json")
|
||||
PREFS_FILE = USER_PREFS_FILE
|
||||
|
||||
|
||||
def _load():
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from src.request_models import PresetUpdateRequest
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import effective_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter:
|
||||
|
||||
try:
|
||||
model_spec = data.get("model") or ""
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
user = effective_user(request)
|
||||
url, model, headers = _resolve_model(model_spec, owner=user)
|
||||
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
||||
return {"success": True, "prompt": result.strip()}
|
||||
except Exception as e:
|
||||
|
||||
+61
-46
@@ -14,6 +14,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.auth_helpers import _auth_disabled, get_current_user
|
||||
from src.constants import DEEP_RESEARCH_DIR
|
||||
|
||||
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
||||
|
||||
@@ -37,13 +38,15 @@ def _first_chat_model(models) -> str:
|
||||
return (models[0] if models else "")
|
||||
|
||||
|
||||
def _resolve_research_endpoint(sess) -> tuple:
|
||||
def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple:
|
||||
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
||||
owner = owner or getattr(sess, "owner", None) or None
|
||||
url, model, headers = resolve_endpoint(
|
||||
"research",
|
||||
fallback_url=sess.endpoint_url,
|
||||
fallback_model=sess.model,
|
||||
fallback_headers=sess.headers,
|
||||
owner=owner,
|
||||
)
|
||||
return url, model, headers
|
||||
|
||||
@@ -72,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
|
||||
"""Resolve a ModelEndpoint row into (chat_url, model, headers).
|
||||
|
||||
Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
|
||||
panel-selected research endpoints. ChatGPT Subscription endpoints keep
|
||||
OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.
|
||||
"""
|
||||
from src.endpoint_resolver import (
|
||||
build_chat_url,
|
||||
build_headers,
|
||||
resolve_endpoint_runtime as resolve_model_endpoint_runtime,
|
||||
)
|
||||
|
||||
try:
|
||||
base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint credentials for research: %s", e)
|
||||
return None
|
||||
|
||||
ep_model = (model or "").strip()
|
||||
if not ep_model:
|
||||
try:
|
||||
models = json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
if not ep_model:
|
||||
return None
|
||||
return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||
|
||||
|
||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
router = APIRouter(tags=["research"])
|
||||
|
||||
@@ -98,7 +133,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
if entry is not None:
|
||||
return entry.get("owner", "") == user
|
||||
# Task no longer in memory — check the persisted JSON.
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
@@ -162,7 +197,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
def _assert_owns_research(session_id: str, user: str) -> None:
|
||||
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
||||
Use BEFORE returning any data or mutating the file."""
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -225,7 +260,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
):
|
||||
user = _require_user(request)
|
||||
"""List all completed research for the Library panel."""
|
||||
data_dir = Path("data/deep_research")
|
||||
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||
items = []
|
||||
for p in data_dir.glob("*.json"):
|
||||
try:
|
||||
@@ -275,7 +310,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
summary, stats — used by the Library preview panel."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -292,7 +327,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -312,7 +347,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Delete a research result from disk."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
data_dir = Path("data/deep_research")
|
||||
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||
json_path = data_dir / f"{session_id}.json"
|
||||
deleted = False
|
||||
if json_path.exists():
|
||||
@@ -368,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
|
||||
if body.endpoint_id:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped: never resolve another user's private endpoint
|
||||
@@ -377,35 +411,26 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found or disabled")
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = body.model or ""
|
||||
if not ep_model:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||
if not resolved:
|
||||
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
|
||||
# When neither research nor utility is configured, use the user's
|
||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||
if not ep_url:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||
@@ -414,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = ""
|
||||
if ep.cached_models:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models)
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||
if resolved:
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
if not ep_url:
|
||||
@@ -494,7 +510,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
result = research_handler.get_result(session_id)
|
||||
if result is None:
|
||||
p = Path("data/deep_research") / f"{session_id}.json"
|
||||
p = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if p.exists():
|
||||
d = json.loads(p.read_text(encoding="utf-8"))
|
||||
return {
|
||||
@@ -534,7 +550,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
sources = research_handler.get_sources(session_id) or []
|
||||
query = ""
|
||||
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
disk = json.loads(path.read_text(encoding="utf-8"))
|
||||
@@ -572,19 +588,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep_headers = dict(r_headers)
|
||||
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("chat"))
|
||||
_merge(*resolve_endpoint("chat", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("research"))
|
||||
_merge(*resolve_endpoint("research", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("utility"))
|
||||
_merge(*resolve_endpoint("utility", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
# Last resort: any enabled endpoint
|
||||
# Last resort: this user's enabled endpoint, plus legacy shared rows.
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
fallback_url = build_chat_url(base)
|
||||
@@ -594,7 +609,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
try:
|
||||
models = json.loads(ep.cached_models)
|
||||
if models:
|
||||
fallback_model = models[0]
|
||||
fallback_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
_merge(fallback_url, fallback_model, fallback_headers)
|
||||
|
||||
+48
-33
@@ -10,8 +10,9 @@ import logging
|
||||
from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
||||
from src.auth_helpers import get_current_user, effective_user
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||
from src.session_actions import is_session_recently_active
|
||||
|
||||
|
||||
def _sanitize_export_filename(name: str) -> str:
|
||||
@@ -92,35 +93,30 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||
"""Verify the current user owns the session. Raises 404 if not.
|
||||
"""Verify the current user owns the session, honoring single-user modes.
|
||||
|
||||
Ownership is checked against the DB row when one exists (unchanged). If
|
||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
||||
that lives only in ``session_manager`` because it was never persisted, or
|
||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
||||
user can still manage and delete it. Without this fallback such sessions are
|
||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
||||
|
||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
||||
that only care about persisted sessions keep their exact prior behavior.
|
||||
Authenticated requests must match the stored DB or in-memory owner. When
|
||||
auth is disabled and no user is present, treat the app as single-user mode:
|
||||
verify that the session exists, but do not compare its stored owner. This
|
||||
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||
rows created while auth was previously enabled.
|
||||
"""
|
||||
user = effective_user(request)
|
||||
if not user:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
if not user and not _auth_disabled():
|
||||
raise HTTPException(401, "Authentication required")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
||||
finally:
|
||||
db.close()
|
||||
if row is not None:
|
||||
if row.owner != user:
|
||||
if user and row.owner != user:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
return
|
||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||
if session_manager is not None:
|
||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
||||
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||
return
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
|
||||
@@ -262,7 +258,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
last_msg_map = {}
|
||||
mode_map = {}
|
||||
msg_count_map = {}
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False).all()
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
for row in rows:
|
||||
folder_map[row.id] = row.folder
|
||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||
@@ -284,12 +280,14 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
r[0] for r in db.query(Document.session_id)
|
||||
.filter(Document.is_active == True,
|
||||
Document.current_content != None,
|
||||
func.trim(Document.current_content) != "")
|
||||
func.trim(Document.current_content) != "",
|
||||
Document.owner == user)
|
||||
.distinct().all()
|
||||
)
|
||||
img_session_ids = set(
|
||||
r[0] for r in db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None)
|
||||
.filter(GalleryImage.session_id != None,
|
||||
GalleryImage.owner == user)
|
||||
.distinct().all()
|
||||
)
|
||||
finally:
|
||||
@@ -370,8 +368,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
pass
|
||||
elif not model_to_use:
|
||||
from src.llm_core import list_model_ids
|
||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
ids = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not ids:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
# Default to the first CHAT model — endpoints often list embedding/
|
||||
@@ -385,8 +388,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.llm_core import list_model_ids
|
||||
import os as _os
|
||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
avail = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not avail:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
if model_to_use not in avail:
|
||||
@@ -543,22 +551,25 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
ids = body.get("ids", [])
|
||||
except Exception:
|
||||
ids = []
|
||||
deleted_count = 0
|
||||
for sid in ids:
|
||||
try:
|
||||
_verify_session_owner(request, sid, session_manager)
|
||||
session_manager.delete_session(sid)
|
||||
|
||||
# Enforce "starred" protection consistent with single-session delete
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db.query(_CM).filter(_CM.session_id == sid).delete()
|
||||
db.query(DbSession).filter(DbSession.id == sid).delete()
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
db_sess = db.query(DbSession).filter(DbSession.id == sid).first()
|
||||
if db_sess and db_sess.is_important:
|
||||
continue
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if session_manager.delete_session(sid):
|
||||
deleted_count += 1
|
||||
except Exception:
|
||||
pass
|
||||
return {"deleted": len(ids)}
|
||||
return {"deleted": deleted_count}
|
||||
|
||||
@router.delete("/session/{sid}")
|
||||
def delete_session(request: Request, sid: str):
|
||||
@@ -924,7 +935,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||
owner = getattr(session, "owner", None) or effective_user(request)
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||
if not url or not model:
|
||||
@@ -1006,7 +1018,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
}
|
||||
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
||||
try:
|
||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
|
||||
folder_map = {r.id: r.folder for r in rows}
|
||||
# Precompute per-session message counts in TWO aggregate queries
|
||||
# instead of 1–3 queries PER session — with many chats the per-row
|
||||
@@ -1017,6 +1029,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
|
||||
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
|
||||
)
|
||||
cleanup_now = utcnow_naive()
|
||||
for row in rows:
|
||||
# Never delete important sessions
|
||||
if getattr(row, 'is_important', False):
|
||||
@@ -1029,6 +1042,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if hasattr(session_manager, 'delete_session'):
|
||||
session_manager.delete_session(row.id)
|
||||
continue
|
||||
if is_session_recently_active(row, now=cleanup_now):
|
||||
continue
|
||||
msg_count = _counts.get(row.id, 0)
|
||||
should_delete = False
|
||||
if msg_count == 0:
|
||||
|
||||
+279
-58
@@ -13,6 +13,7 @@ import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
||||
|
||||
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
||||
# on Windows, so importing them unconditionally crashed app startup there
|
||||
@@ -37,6 +38,7 @@ from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
detached_popen_kwargs,
|
||||
find_bash,
|
||||
git_bash_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,6 +94,7 @@ def _venv_activate_prefix(venv: str | None) -> str:
|
||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||
return f". {act} && "
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||
@@ -169,7 +172,10 @@ def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
||||
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
||||
)
|
||||
if name == "hf_transfer":
|
||||
return bool(dists.get("hf-transfer") or modules.get("hf_transfer", {}).get("real_module"))
|
||||
return bool(
|
||||
dists.get("hf-transfer")
|
||||
or modules.get("hf_transfer", {}).get("real_module")
|
||||
)
|
||||
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
||||
|
||||
|
||||
@@ -194,8 +200,14 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
if binaries.get("llama-server"):
|
||||
parts.append(f"native llama-server: {binaries['llama-server']}")
|
||||
if dists.get("llama-cpp-python"):
|
||||
parts.append(f"python package: llama-cpp-python {dists['llama-cpp-python']}")
|
||||
return "; ".join(parts) if parts else "No native llama-server or llama-cpp-python server package found."
|
||||
parts.append(
|
||||
f"python package: llama-cpp-python {dists['llama-cpp-python']}"
|
||||
)
|
||||
return (
|
||||
"; ".join(parts)
|
||||
if parts
|
||||
else "No native llama-server or llama-cpp-python server package found."
|
||||
)
|
||||
if name == "diffusers":
|
||||
if _package_installed_from_probe(name, probe):
|
||||
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
||||
@@ -205,7 +217,9 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
|
||||
def _package_pip_update_status(
|
||||
pkg: dict, probe: dict | None = None
|
||||
) -> PackageUpdateStatus:
|
||||
"""Return whether the Dependencies UI should offer a generic pip update.
|
||||
|
||||
"Installed" means Cookbook can use the dependency. It does not always mean
|
||||
@@ -213,12 +227,28 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
||||
native llama-server can come from a package manager/source build, and a CLI
|
||||
may be on PATH without matching Python package metadata.
|
||||
"""
|
||||
if pkg.get("name") == "APFEL":
|
||||
return PackageUpdateStatus(
|
||||
False,
|
||||
"", # Note is empty because IT DOES allow for updates outside of PIP.
|
||||
)
|
||||
|
||||
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
||||
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
|
||||
return PackageUpdateStatus(
|
||||
False, "Update this system dependency outside Odysseus."
|
||||
)
|
||||
|
||||
name = pkg.get("name")
|
||||
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
|
||||
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
|
||||
binaries = (
|
||||
probe.get("binaries")
|
||||
if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict)
|
||||
else {}
|
||||
)
|
||||
dists = (
|
||||
probe.get("dists")
|
||||
if isinstance(probe, dict) and isinstance(probe.get("dists"), dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
if name == "llama_cpp" and binaries.get("llama-server"):
|
||||
return PackageUpdateStatus(
|
||||
@@ -231,7 +261,9 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
||||
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
||||
)
|
||||
|
||||
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
|
||||
return PackageUpdateStatus(
|
||||
True, "Update uses pip in the selected Python environment."
|
||||
)
|
||||
|
||||
|
||||
def _prepend_user_install_bins_to_path() -> None:
|
||||
@@ -250,7 +282,9 @@ def _prepend_user_install_bins_to_path() -> None:
|
||||
candidates = []
|
||||
candidates.append(os.path.expanduser("~/.local/bin"))
|
||||
|
||||
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||
parts = (
|
||||
os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||
)
|
||||
changed = False
|
||||
for path in reversed([p for p in candidates if p]):
|
||||
if path not in parts:
|
||||
@@ -357,9 +391,11 @@ PTY_UNSUPPORTED_ERROR = "pty_unsupported"
|
||||
|
||||
class ShellExecRequest(BaseModel):
|
||||
command: str
|
||||
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
|
||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||
timeout: int | None = (
|
||||
None # optional override; 0 = no timeout (run until client disconnects)
|
||||
)
|
||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||
|
||||
|
||||
async def _create_shell(command: str, **kwargs):
|
||||
@@ -368,8 +404,16 @@ async def _create_shell(command: str, **kwargs):
|
||||
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
||||
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
||||
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
||||
Powershell commands are executed directly via cmd.exe /c to avoid quoting
|
||||
and env variable expansion errors under Git Bash.
|
||||
"""
|
||||
if IS_WINDOWS:
|
||||
# PowerShell commands (used by the frontend for Windows log-file polling
|
||||
# and session management) must run directly — passing them through
|
||||
# bash -c mangles $env:VAR syntax and breaks the command.
|
||||
cmd_trim = command.strip()
|
||||
if cmd_trim.startswith("powershell") or cmd_trim.startswith("cmd "):
|
||||
return await asyncio.create_subprocess_shell(command, **kwargs)
|
||||
bash = find_bash()
|
||||
if bash:
|
||||
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
||||
@@ -386,9 +430,7 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
||||
@@ -399,7 +441,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout}s",
|
||||
"exit_code": -1,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
||||
|
||||
@@ -481,7 +527,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
|
||||
@@ -503,7 +549,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
if buf:
|
||||
@@ -534,6 +580,7 @@ def _pty_read(fd: int) -> bytes | None:
|
||||
"""Blocking read from PTY fd. Called via run_in_executor.
|
||||
Returns bytes on data, None on timeout (no data yet)."""
|
||||
import select
|
||||
|
||||
r, _, _ = select.select([fd], [], [], 1.0)
|
||||
if r:
|
||||
try:
|
||||
@@ -557,10 +604,10 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"#!/bin/bash\n"
|
||||
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
|
||||
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
|
||||
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
|
||||
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
|
||||
f'ODYSSEUS_USER_SHELL="${{SHELL:-}}"\n'
|
||||
f'if [ -n "$ODYSSEUS_USER_SHELL" ] && [ -x "$ODYSSEUS_USER_SHELL" ]; then\n'
|
||||
f' ODYSSEUS_USER_PATH="$("$ODYSSEUS_USER_SHELL" -ic \'printf "__ODYSSEUS_PATH__%s\\n" "$PATH"\' 2>/dev/null | sed -n \'s/^__ODYSSEUS_PATH__//p\' | tail -n 1 || true)"\n'
|
||||
f' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi\n'
|
||||
f"fi\n"
|
||||
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
||||
f"EC=${{PIPESTATUS[0]}}\n"
|
||||
@@ -570,7 +617,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
encoding="utf-8",
|
||||
)
|
||||
script_path.chmod(0o755)
|
||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
||||
logger.info(
|
||||
"tmux wrapper script created: session=%s path=%s", session_id, script_path
|
||||
)
|
||||
|
||||
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
||||
|
||||
@@ -602,7 +651,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Read new lines from log
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
new_lines = lines[lines_sent:]
|
||||
for line in new_lines:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
@@ -630,7 +681,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Session ended — do one final read
|
||||
await asyncio.sleep(0.5)
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
try:
|
||||
@@ -672,8 +725,8 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
if bash:
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"{cmd} > {shlex.quote(str(log_path))} 2>&1\n"
|
||||
f"echo $? > {shlex.quote(str(exit_path))}\n",
|
||||
f"{cmd} > {shlex.quote(git_bash_path(log_path))} 2>&1\n"
|
||||
f"echo $? > {shlex.quote(git_bash_path(exit_path))}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
argv = [bash, str(script_path)]
|
||||
@@ -711,7 +764,9 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
return
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
lines_sent = len(lines)
|
||||
@@ -723,11 +778,18 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
lines_sent = len(lines)
|
||||
exit_code = int((exit_path.read_text(encoding="utf-8", errors="replace").strip() or "0"))
|
||||
exit_code = int(
|
||||
(
|
||||
exit_path.read_text(encoding="utf-8", errors="replace").strip()
|
||||
or "0"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
exit_code = 0
|
||||
break
|
||||
@@ -753,7 +815,9 @@ def setup_shell_routes() -> APIRouter:
|
||||
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
||||
|
||||
logger.info("User shell exec requested: length=%d", len(cmd))
|
||||
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
|
||||
result = await _exec_shell(
|
||||
cmd, timeout=req.timeout if req.timeout is not None else EXEC_TIMEOUT
|
||||
)
|
||||
return result
|
||||
|
||||
@router.post("/api/shell/stream")
|
||||
@@ -762,9 +826,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
_require_admin(request)
|
||||
cmd = req.command.strip()
|
||||
if not cmd:
|
||||
|
||||
async def empty():
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
||||
|
||||
return StreamingResponse(empty(), media_type="text/event-stream")
|
||||
|
||||
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
||||
@@ -781,7 +847,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
if use_tmux:
|
||||
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
||||
# that preserves the "survives disconnect" behaviour.
|
||||
gen = _generate_win_detached(cmd, request) if IS_WINDOWS else _generate_tmux(cmd, request)
|
||||
gen = (
|
||||
_generate_win_detached(cmd, request)
|
||||
if IS_WINDOWS
|
||||
else _generate_tmux(cmd, request)
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
|
||||
if use_pty and not IS_WINDOWS:
|
||||
@@ -813,7 +883,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
if buf:
|
||||
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
|
||||
await q.put(
|
||||
(
|
||||
name,
|
||||
buf.decode(errors="replace").rstrip("\r\n"),
|
||||
)
|
||||
)
|
||||
break
|
||||
buf += chunk
|
||||
while True:
|
||||
@@ -821,7 +896,7 @@ def setup_shell_routes() -> APIRouter:
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
await q.put((name, line))
|
||||
finally:
|
||||
@@ -880,7 +955,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
@router.get("/api/cookbook/packages")
|
||||
async def list_packages(request: Request, host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
|
||||
async def list_packages(
|
||||
request: Request,
|
||||
host: str | None = None,
|
||||
ssh_port: str | None = None,
|
||||
venv: str | None = None,
|
||||
):
|
||||
"""Check which optional packages are installed.
|
||||
|
||||
Local-target packages are checked in-process. Remote-target packages
|
||||
@@ -890,7 +970,13 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""
|
||||
_require_admin(request)
|
||||
_reject_cross_site(request)
|
||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
|
||||
import importlib
|
||||
import importlib.metadata as importlib_metadata
|
||||
import shlex
|
||||
import json as _json
|
||||
import site
|
||||
import sys
|
||||
|
||||
_prepend_user_install_bins_to_path()
|
||||
importlib.invalidate_caches()
|
||||
try:
|
||||
@@ -905,26 +991,115 @@ def setup_shell_routes() -> APIRouter:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
packages = [
|
||||
# ── System ── OS binaries, not pip packages
|
||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
||||
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
|
||||
{
|
||||
"name": "tmux",
|
||||
"pip": "",
|
||||
"desc": "Required for Linux/Termux Cookbook background downloads and serves",
|
||||
"category": "System",
|
||||
"target": "remote",
|
||||
"kind": "system",
|
||||
"install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper.",
|
||||
},
|
||||
{
|
||||
"name": "docker",
|
||||
"pip": "",
|
||||
"desc": "Required only for Docker-backed launch commands",
|
||||
"category": "System",
|
||||
"target": "remote",
|
||||
"kind": "system",
|
||||
"install_hint": "Install Docker on the selected server and allow this user to run docker.",
|
||||
},
|
||||
# ── LLM ── installs on GPU servers for model serving/downloading
|
||||
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
|
||||
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
|
||||
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
|
||||
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
|
||||
{
|
||||
"name": "hf_transfer",
|
||||
"pip": "hf_transfer",
|
||||
"desc": "Fast model downloads from HuggingFace",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "llama_cpp",
|
||||
"pip": "llama-cpp-python[server]",
|
||||
"desc": "Serve GGUF models via llama.cpp",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "sglang",
|
||||
"pip": "sglang[all]",
|
||||
"desc": "Serve HF safetensors models via SGLang",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "vllm",
|
||||
"pip": "vllm",
|
||||
"desc": "High-throughput LLM serving engine",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "APFEL",
|
||||
"pip": "",
|
||||
"desc": "OpenAI-compatible API for Apple Foundational Models on Apple Silicon",
|
||||
"category": "LLM",
|
||||
"target": "local",
|
||||
"kind": "system",
|
||||
"install_cmd": "brew install apfel",
|
||||
"update_cmd": "brew upgrade apfel",
|
||||
"install_hint": "Requires a native Apple Silicon Mac with Apple Foundational Models support. Installable via Homebrew on supported Macs.",
|
||||
},
|
||||
# ── Image ── editor + diffusion model serving
|
||||
{"name": "diffusers", "pip": "diffusers[torch]", "desc": "Image generation pipelines (SD, Flux) with PyTorch", "category": "Image", "target": "remote"},
|
||||
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
|
||||
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
|
||||
{
|
||||
"name": "diffusers",
|
||||
"pip": "diffusers[torch]",
|
||||
"desc": "Image generation pipelines (SD, Flux) with PyTorch",
|
||||
"category": "Image",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "rembg",
|
||||
"pip": "rembg[gpu]",
|
||||
"desc": "AI background removal for image editor",
|
||||
"category": "Image",
|
||||
"target": "local",
|
||||
},
|
||||
{
|
||||
"name": "realesrgan",
|
||||
"pip": "realesrgan",
|
||||
"desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.",
|
||||
"category": "Image",
|
||||
"target": "local",
|
||||
},
|
||||
# ── Tools ──
|
||||
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
|
||||
{
|
||||
"name": "playwright",
|
||||
"pip": "playwright",
|
||||
"desc": "Browser automation for web tools",
|
||||
"category": "Tools",
|
||||
"target": "local",
|
||||
},
|
||||
]
|
||||
|
||||
# Most packages should not be installed through external means. Hence, set the default of the
|
||||
# install_cmd and update_cmd to None, which indicates that the recommended way to install/update is through the Cookbook # server setup or pip. Only system packages, should have explicit install/update commands provided.
|
||||
for pkg in packages:
|
||||
pkg.setdefault("install_cmd", None)
|
||||
pkg.setdefault("update_cmd", None)
|
||||
# Remote check: for remote-target packages, probe the selected server's
|
||||
# venv over SSH so a remote `pip install` actually reflects here.
|
||||
remote_status: dict = {}
|
||||
remote_details: dict = {}
|
||||
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
|
||||
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
|
||||
remote_names = [
|
||||
p["name"]
|
||||
for p in packages
|
||||
if p.get("target") == "remote" and p.get("kind") != "system"
|
||||
]
|
||||
remote_system_names = [
|
||||
p["name"]
|
||||
for p in packages
|
||||
if p.get("target") == "remote" and p.get("kind") == "system"
|
||||
]
|
||||
if host and remote_names:
|
||||
try:
|
||||
py = _package_probe_script(remote_names)
|
||||
@@ -934,7 +1109,9 @@ def setup_shell_routes() -> APIRouter:
|
||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
*argv,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -958,11 +1135,15 @@ def setup_shell_routes() -> APIRouter:
|
||||
checks = []
|
||||
for name in remote_system_names:
|
||||
qn = shlex.quote(name)
|
||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
||||
checks.append(
|
||||
f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi"
|
||||
)
|
||||
inner = " ; ".join(checks)
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
*argv,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -987,11 +1168,25 @@ def setup_shell_routes() -> APIRouter:
|
||||
if note:
|
||||
pkg["status_note"] = note
|
||||
elif pkg.get("kind") == "system":
|
||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||
if pkg["name"] == "APFEL":
|
||||
pkg["applicable"] = IS_APPLE_SILICON
|
||||
pkg["installed"] = which_tool("apfel") is not None
|
||||
pkg["status_note"] = (
|
||||
"Available on Apple Silicon (arm64) devices; exposed through a local OpenAI-compatible API."
|
||||
if IS_APPLE_SILICON
|
||||
else "Requires a native Apple Silicon Mac with Apple Foundational Models support."
|
||||
)
|
||||
else:
|
||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||
pkg["installed"] = True
|
||||
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
|
||||
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
|
||||
pkg["status_note"] = (
|
||||
f"native llama-server: {shutil.which('llama-server')}"
|
||||
)
|
||||
probe = {
|
||||
"binaries": {"llama-server": shutil.which("llama-server")},
|
||||
"dists": {},
|
||||
}
|
||||
elif pkg["name"] == "vllm":
|
||||
_vllm_cli = shutil.which("vllm")
|
||||
pkg["installed"] = _vllm_cli is not None
|
||||
@@ -1014,6 +1209,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
pkg["installed"] = False
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pkg["installed"] = False
|
||||
except Exception:
|
||||
# Installed but crashes on import — e.g. a CUDA build of
|
||||
# llama-cpp-python raising FileNotFoundError when the CUDA
|
||||
# toolkit dir is absent. One broken optional package must not
|
||||
# 500 the entire packages panel; report it as not usable.
|
||||
pkg["installed"] = False
|
||||
|
||||
if pkg.get("installed"):
|
||||
update_status = _package_pip_update_status(pkg, probe)
|
||||
@@ -1037,15 +1238,30 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
||||
_require_admin(request)
|
||||
import sys as _sys
|
||||
|
||||
body = await request.json()
|
||||
pip_name = body.get("pip")
|
||||
if not pip_name:
|
||||
return {"ok": False, "error": "No package specified"}
|
||||
# Validate against known packages to prevent arbitrary pip install
|
||||
known = {
|
||||
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers", "diffusers[torch]",
|
||||
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
|
||||
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan", "vllm",
|
||||
"rembg[gpu]",
|
||||
"hf_transfer",
|
||||
"llama-cpp-python[server]",
|
||||
"sglang[all]",
|
||||
"diffusers",
|
||||
"diffusers[torch]",
|
||||
"TTS",
|
||||
"bark",
|
||||
"faster-whisper",
|
||||
"playwright",
|
||||
"realesrgan",
|
||||
"gfpgan",
|
||||
"insightface",
|
||||
"onnxruntime-gpu",
|
||||
"onnxruntime",
|
||||
"hdbscan",
|
||||
"vllm",
|
||||
}
|
||||
if pip_name not in known:
|
||||
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
||||
@@ -1071,6 +1287,7 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""
|
||||
_require_admin(request)
|
||||
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
||||
|
||||
body = await request.json()
|
||||
engine = str(body.get("engine") or "llamacpp").strip()
|
||||
if engine != "llamacpp":
|
||||
@@ -1079,7 +1296,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
ssh_port = body.get("ssh_port")
|
||||
cmd = _llama_cpp_rebuild_cmd()
|
||||
try:
|
||||
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
|
||||
argv = (
|
||||
(_ssh_base_argv(host, ssh_port) + [cmd])
|
||||
if host
|
||||
else ["bash", "-lc", cmd]
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
try:
|
||||
|
||||
+44
-16
@@ -21,10 +21,44 @@ from src.auth_helpers import get_current_user
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DATA_URL_RE = re.compile(
|
||||
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_DATA_URL_RE = re.compile(r"^data:image/png;base64,(?P<data>.+)$", re.IGNORECASE | re.DOTALL)
|
||||
_ANY_IMAGE_DATA_URL_RE = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
|
||||
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
|
||||
_MAX_SIGNATURE_BYTES = 2 * 1024 * 1024
|
||||
_MAX_SIGNATURE_B64 = ((_MAX_SIGNATURE_BYTES + 2) // 3) * 4
|
||||
_MAX_SIGNATURE_DIMENSION = 4096
|
||||
|
||||
|
||||
def _normalize_signature_png(raw: str) -> str:
|
||||
raw = (raw or "").strip()
|
||||
m = _DATA_URL_RE.match(raw)
|
||||
if m:
|
||||
b64 = m.group("data")
|
||||
elif _ANY_IMAGE_DATA_URL_RE.match(raw):
|
||||
raise HTTPException(400, "Signature data must be a PNG image")
|
||||
else:
|
||||
b64 = raw
|
||||
if len(b64) > _MAX_SIGNATURE_B64:
|
||||
raise HTTPException(400, "Signature PNG is too large")
|
||||
try:
|
||||
payload = base64.b64decode(b64, validate=True)
|
||||
except Exception:
|
||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||
if not payload:
|
||||
raise HTTPException(400, "Signature PNG is empty")
|
||||
if len(payload) > _MAX_SIGNATURE_BYTES:
|
||||
raise HTTPException(400, "Signature PNG is too large")
|
||||
if not payload.startswith(_PNG_MAGIC):
|
||||
raise HTTPException(400, "Signature data must be a PNG image")
|
||||
return base64.b64encode(payload).decode("ascii")
|
||||
|
||||
|
||||
def _signature_dimension(value: Optional[int]) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int) or value < 1 or value > _MAX_SIGNATURE_DIMENSION:
|
||||
raise HTTPException(400, "Signature dimensions are invalid")
|
||||
return value
|
||||
|
||||
|
||||
class SignatureCreate(BaseModel):
|
||||
@@ -67,24 +101,18 @@ def setup_signature_routes() -> APIRouter:
|
||||
@router.post("/api/signatures")
|
||||
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
raw = (req.data or "").strip()
|
||||
m = _DATA_URL_RE.match(raw)
|
||||
b64 = m.group("data") if m else raw
|
||||
try:
|
||||
payload = base64.b64decode(b64, validate=True)
|
||||
if not payload:
|
||||
raise ValueError("empty payload")
|
||||
except Exception:
|
||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||
b64 = _normalize_signature_png(req.data)
|
||||
width = _signature_dimension(req.width)
|
||||
height = _signature_dimension(req.height)
|
||||
|
||||
sig = Signature(
|
||||
id=str(uuid.uuid4()),
|
||||
owner=user,
|
||||
name=(req.name or "Signature").strip() or "Signature",
|
||||
data_png=b64,
|
||||
width=req.width,
|
||||
height=req.height,
|
||||
svg=req.svg,
|
||||
width=width,
|
||||
height=height,
|
||||
svg=None,
|
||||
)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
+107
-1
@@ -11,6 +11,8 @@ import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -51,6 +53,10 @@ class SkillAddRequest(BaseModel):
|
||||
steps: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SkillImportUrlRequest(BaseModel):
|
||||
url: str = Field(..., min_length=8, max_length=2000)
|
||||
|
||||
|
||||
class SkillUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -1014,7 +1020,7 @@ def _resolve_audit_models(owner=None):
|
||||
spec = (get_setting("teacher_model", "") or "").strip()
|
||||
if spec:
|
||||
from src.ai_interaction import _resolve_model
|
||||
t_url, t_model, t_headers = _resolve_model(spec)
|
||||
t_url, t_model, t_headers = _resolve_model(spec, owner=owner)
|
||||
if t_url and t_model:
|
||||
teacher = (t_url, t_model, t_headers)
|
||||
except Exception as e:
|
||||
@@ -1103,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
idx = skills_manager.index_for(owner=user)
|
||||
return {"index": idx, "count": len(idx)}
|
||||
|
||||
@router.get("/slash-catalog")
|
||||
async def get_slash_catalog(request: Request):
|
||||
"""Return skills that are available as slash commands.
|
||||
|
||||
Mirrors the agent prompt's published-skill index so the UI never offers
|
||||
a slash command the model would not normally be allowed to discover.
|
||||
"""
|
||||
user = _owner(request)
|
||||
all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)}
|
||||
entries = []
|
||||
for s in skills_manager.index_for(owner=user):
|
||||
name = (s.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
full = all_skills.get(name) or {}
|
||||
category = (s.get("category") or full.get("category") or "general").strip() or "general"
|
||||
entries.append({
|
||||
"type": "skill",
|
||||
"token": f"/{name}",
|
||||
"name": name,
|
||||
"category": f"Skills / {category}",
|
||||
"help": s.get("description") or full.get("description") or "",
|
||||
"usage": f"/{name} <request>",
|
||||
"uses": int(full.get("uses") or 0),
|
||||
"last_used": full.get("last_used"),
|
||||
})
|
||||
entries.sort(key=lambda row: row["name"])
|
||||
return {"skills": entries, "count": len(entries)}
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_skills(request: Request):
|
||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||
@@ -1203,6 +1238,36 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
save_settings(settings)
|
||||
return {"ok": True, "name": name, "is_overridden": False}
|
||||
|
||||
@router.post("/import-from-url")
|
||||
async def import_skill_from_url(request: Request, body: SkillImportUrlRequest):
|
||||
"""Install a SKILL.md bundle from a public GitHub URL (skills.sh links supported)."""
|
||||
require_admin(request)
|
||||
user = _owner(request)
|
||||
from services.memory.skill_importer import (
|
||||
SkillImportError,
|
||||
fetch_skill_bundle,
|
||||
)
|
||||
|
||||
try:
|
||||
files, _src = fetch_skill_bundle(body.url.strip())
|
||||
entry = skills_manager.import_bundle_from_files(
|
||||
files,
|
||||
owner=user,
|
||||
source_url=body.url.strip(),
|
||||
)
|
||||
except SkillImportError as e:
|
||||
raise HTTPException(400, str(e)) from e
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning("skill import fetch failed: %s", e)
|
||||
detail = str(e).strip() or "Could not download skill from URL"
|
||||
raise HTTPException(502, detail) from e
|
||||
except Exception as e:
|
||||
logger.error("skill import failed: %s", e)
|
||||
raise HTTPException(500, "Skill import failed") from e
|
||||
|
||||
_fire_skill_added(user)
|
||||
return {"ok": True, "skill": entry, "files": len(files)}
|
||||
|
||||
@router.post("/add")
|
||||
async def add_skill(request: Request, body: SkillAddRequest):
|
||||
user = _owner(request)
|
||||
@@ -1236,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
_fire_skill_added(user)
|
||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||
|
||||
@router.post("/{skill_id}/invoke")
|
||||
async def invoke_skill(request: Request, skill_id: str):
|
||||
"""Build a skill-pinned prompt for slash-command invocation.
|
||||
|
||||
This is intentionally server-side so availability, ownership, and usage
|
||||
accounting use the same rules as the SkillsManager.
|
||||
"""
|
||||
user = _owner(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
request_text = (body.get("request") or "").strip() if isinstance(body, dict) else ""
|
||||
|
||||
invokable = {
|
||||
s.get("name"): s for s in skills_manager.index_for(owner=user)
|
||||
if (s.get("name") or "").strip()
|
||||
}
|
||||
match = invokable.get(skill_id)
|
||||
if not match:
|
||||
raise HTTPException(404, "Skill is not available for slash invocation")
|
||||
|
||||
name = match.get("name")
|
||||
md = skills_manager.read_skill_md(name, owner=user)
|
||||
if md is None:
|
||||
raise HTTPException(404, "Skill source unavailable")
|
||||
|
||||
skills_manager.record_use(name, owner=user)
|
||||
message = (
|
||||
"Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
|
||||
f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
|
||||
+ (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "skill",
|
||||
"name": name,
|
||||
"command": f"/{name}",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
@router.get("/{skill_id}")
|
||||
async def get_skill(request: Request, skill_id: str):
|
||||
user = _owner(request)
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File
|
||||
import logging
|
||||
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, STT_MAX_AUDIO_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STT_MAX_AUDIO_BYTES = 25 * 1024 * 1024
|
||||
|
||||
|
||||
def setup_stt_routes(stt_service):
|
||||
"""Setup STT routes with the provided STT service"""
|
||||
|
||||
+37
-22
@@ -11,7 +11,9 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||
from core.constants import internal_api_base
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import DATA_DIR, EMAIL_URGENCY_CACHE_DIR
|
||||
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
|
||||
@@ -56,7 +58,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=10) as client:
|
||||
r = client.delete(
|
||||
f"http://localhost:7000/api/calendar/events/{uid}",
|
||||
f"{internal_api_base()}/api/calendar/events/{uid}",
|
||||
headers=headers,
|
||||
)
|
||||
if r.status_code >= 400:
|
||||
@@ -81,7 +83,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=10) as client:
|
||||
# Find the Cookbook calendar.
|
||||
cal_r = client.get("http://localhost:7000/api/calendar/calendars", headers=headers)
|
||||
cal_r = client.get(f"{internal_api_base()}/api/calendar/calendars", headers=headers)
|
||||
if cal_r.status_code >= 400:
|
||||
return
|
||||
cals = (cal_r.json() or {}).get("calendars", [])
|
||||
@@ -98,7 +100,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
start = (now - _td(days=30)).isoformat()
|
||||
end = (now + _td(days=365)).isoformat()
|
||||
ev_r = client.get(
|
||||
"http://localhost:7000/api/calendar/events",
|
||||
f"{internal_api_base()}/api/calendar/events",
|
||||
params={"start": start, "end": end, "calendar": cal_href},
|
||||
headers=headers,
|
||||
)
|
||||
@@ -291,20 +293,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
def _owner(request: Request):
|
||||
return get_current_user(request)
|
||||
|
||||
async def _generate_task_name(prompt: str) -> str:
|
||||
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
|
||||
"""Use LLM to generate a short task name from the prompt."""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import Session as DbSession
|
||||
db = SessionLocal()
|
||||
try:
|
||||
recent = db.query(DbSession).filter(
|
||||
q = db.query(DbSession).filter(
|
||||
DbSession.endpoint_url.isnot(None),
|
||||
DbSession.model.isnot(None),
|
||||
).order_by(DbSession.created_at.desc()).first()
|
||||
)
|
||||
if owner:
|
||||
q = q.filter(DbSession.owner == owner)
|
||||
recent = q.order_by(DbSession.created_at.desc()).first()
|
||||
if not recent:
|
||||
return prompt[:50].strip()
|
||||
url, model = recent.endpoint_url, recent.model
|
||||
headers = recent.headers or {}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -315,6 +321,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
{"role": "user", "content": prompt[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
)
|
||||
title = result.strip().strip('"\'').strip()
|
||||
@@ -429,6 +436,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
|
||||
target_id = (then_task_id or "").strip()
|
||||
if not target_id:
|
||||
return None
|
||||
if current_task_id and target_id == current_task_id:
|
||||
raise HTTPException(400, "Task cannot chain to itself")
|
||||
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
|
||||
if user:
|
||||
q = q.filter(ScheduledTask.owner == user)
|
||||
target = q.first()
|
||||
if not target:
|
||||
raise HTTPException(404, "Chained task not found")
|
||||
return target.id
|
||||
|
||||
@router.post("")
|
||||
async def create_task(request: Request, req: TaskCreate):
|
||||
user = _owner(request)
|
||||
@@ -465,7 +486,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||
elif req.prompt:
|
||||
name = await _generate_task_name(req.prompt)
|
||||
name = await _generate_task_name(req.prompt, owner=user)
|
||||
else:
|
||||
name = "Untitled Task"
|
||||
|
||||
@@ -492,6 +513,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
task_id = str(uuid.uuid4())
|
||||
db = SessionLocal()
|
||||
try:
|
||||
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
|
||||
notifications_enabled = (
|
||||
False if req.task_type == "action" and req.notifications_enabled is None
|
||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||
@@ -527,7 +549,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
output_target=req.output_target,
|
||||
model=req.model or None,
|
||||
endpoint_url=req.endpoint_url or None,
|
||||
then_task_id=req.then_task_id or None,
|
||||
then_task_id=then_task_id,
|
||||
webhook_token=webhook_token,
|
||||
notifications_enabled=notifications_enabled,
|
||||
)
|
||||
@@ -609,7 +631,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
|
||||
removed_files = 0
|
||||
if action == "check_email_urgency":
|
||||
cache_dir = Path("data/email_urgency_cache")
|
||||
cache_dir = Path(EMAIL_URGENCY_CACHE_DIR)
|
||||
if cache_dir.exists():
|
||||
for child in cache_dir.glob("*.json"):
|
||||
try:
|
||||
@@ -618,7 +640,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
except Exception:
|
||||
pass
|
||||
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
||||
for state_path in [Path(f"data/email_urgency_state_{owner_slug}.json")]:
|
||||
for state_path in [Path(DATA_DIR) / f"email_urgency_state_{owner_slug}.json"]:
|
||||
try:
|
||||
if state_path.exists():
|
||||
state_path.unlink()
|
||||
@@ -680,15 +702,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
if req.trigger_count is not None:
|
||||
task.trigger_count = req.trigger_count
|
||||
if req.then_task_id is not None:
|
||||
if req.then_task_id:
|
||||
chain_target = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == req.then_task_id
|
||||
).first()
|
||||
if not chain_target:
|
||||
raise HTTPException(400, "Chained task not found")
|
||||
if chain_target.owner != user:
|
||||
raise HTTPException(403, "Cannot chain to another user's task")
|
||||
task.then_task_id = req.then_task_id or None
|
||||
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
|
||||
if req.notifications_enabled is not None:
|
||||
task.notifications_enabled = bool(req.notifications_enabled)
|
||||
if req.cron_expression is not None:
|
||||
@@ -969,7 +983,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
||||
)
|
||||
try:
|
||||
from src.agent_tools import get_mcp_manager
|
||||
from src.tool_utils import get_mcp_manager
|
||||
mcp = get_mcp_manager()
|
||||
if mcp:
|
||||
for tool in mcp.get_all_tools():
|
||||
@@ -1064,6 +1078,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
desc = (body.get("description") or "").strip()
|
||||
if not desc:
|
||||
return {"success": False, "message": "Nothing to parse"}
|
||||
user = _owner(request)
|
||||
|
||||
now = _dt.now()
|
||||
# Give the model the current date/time + weekday so relative phrasing
|
||||
@@ -1090,9 +1105,9 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
||||
)
|
||||
try:
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=user or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||
if not (url and model):
|
||||
return {"success": False, "message": "No model endpoint configured"}
|
||||
raw = await llm_call_async(
|
||||
|
||||
+51
-34
@@ -13,9 +13,43 @@ from src.upload_handler import count_recent_uploads
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
||||
UPLOAD_RESPONSE_HEADERS = {"X-Content-Type-Options": "nosniff"}
|
||||
|
||||
def setup_upload_routes(upload_handler):
|
||||
"""Setup upload routes with the provided handler"""
|
||||
|
||||
def _upload_root() -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
return os.path.realpath(getattr(upload_handler, "upload_dir", UPLOAD_DIR))
|
||||
|
||||
def _path_inside_upload_dir(path: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([_upload_root(), os.path.realpath(path)]) == _upload_root()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _resolve_upload_path(file_id: str) -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
upload_root = getattr(upload_handler, "upload_dir", UPLOAD_DIR)
|
||||
direct = os.path.join(upload_root, file_id)
|
||||
if os.path.lexists(direct):
|
||||
if not _path_inside_upload_dir(direct):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if os.path.isfile(direct):
|
||||
return direct
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
for root, _dirs, files in os.walk(upload_root, followlinks=False):
|
||||
if file_id not in files:
|
||||
continue
|
||||
path = os.path.join(root, file_id)
|
||||
if not _path_inside_upload_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
@router.post("")
|
||||
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
||||
@@ -91,23 +125,11 @@ def setup_upload_routes(upload_handler):
|
||||
client isn't downloading the full-resolution photo just to show it tiny."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
# Search upload directories for the file
|
||||
from src.constants import UPLOAD_DIR
|
||||
import mimetypes as _mt
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
# Look up original filename and owner from uploads.json
|
||||
original_name = file_id
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
@@ -123,13 +145,14 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
mime = _mt.guess_type(path)[0] or "application/octet-stream"
|
||||
path = _resolve_upload_path(file_id)
|
||||
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or "application/octet-stream"
|
||||
from fastapi.responses import FileResponse
|
||||
# Downscaled thumbnail for image previews — generated once and cached.
|
||||
if thumb and mime.startswith("image/"):
|
||||
try:
|
||||
from PIL import Image, ImageOps
|
||||
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
|
||||
thumb_dir = os.path.join(_upload_root(), ".thumbs")
|
||||
os.makedirs(thumb_dir, exist_ok=True)
|
||||
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
||||
if (not os.path.exists(thumb_path)
|
||||
@@ -145,17 +168,21 @@ def setup_upload_routes(upload_handler):
|
||||
if im.mode not in ("RGB", "L"):
|
||||
im = im.convert("RGB")
|
||||
im.save(thumb_path, "JPEG", quality=80)
|
||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
||||
return FileResponse(thumb_path, media_type="image/jpeg", headers=UPLOAD_RESPONSE_HEADERS)
|
||||
except Exception as e:
|
||||
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
||||
# Fall through to the full image.
|
||||
return FileResponse(path, media_type=mime, filename=original_name)
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type=mime,
|
||||
filename=original_name,
|
||||
headers=UPLOAD_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
def _load_upload_info(file_id: str):
|
||||
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
||||
from src.constants import UPLOAD_DIR
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
@@ -163,8 +190,7 @@ def setup_upload_routes(upload_handler):
|
||||
return info
|
||||
|
||||
def _vision_cache_path(file_id: str) -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
|
||||
cache_dir = os.path.join(_upload_root(), ".vision")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return os.path.join(cache_dir, file_id + ".txt")
|
||||
|
||||
@@ -175,17 +201,6 @@ def setup_upload_routes(upload_handler):
|
||||
subsequent loads are instant. Pass force=1 to recompute."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
from src.constants import UPLOAD_DIR
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
info = _load_upload_info(file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
@@ -196,8 +211,9 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
path = _resolve_upload_path(file_id)
|
||||
import mimetypes as _mt
|
||||
mime = _mt.guess_type(path)[0] or ""
|
||||
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or ""
|
||||
if not mime.startswith("image/"):
|
||||
raise HTTPException(400, "Not an image")
|
||||
cache_path = _vision_cache_path(file_id)
|
||||
@@ -209,7 +225,7 @@ def setup_upload_routes(upload_handler):
|
||||
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
||||
from src.document_processor import analyze_image_with_vl
|
||||
try:
|
||||
text = analyze_image_with_vl(path) or ""
|
||||
text = analyze_image_with_vl(path, owner=current_user) or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
||||
raise HTTPException(500, f"Vision analysis failed: {e}")
|
||||
@@ -238,6 +254,7 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
_resolve_upload_path(file_id)
|
||||
body = await request.json()
|
||||
text = (body or {}).get("text", "")
|
||||
if not isinstance(text, str):
|
||||
|
||||
@@ -17,10 +17,11 @@ from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
||||
from src.constants import VAULT_FILE as _VAULT_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VAULT_FILE = Path("data/vault.json")
|
||||
VAULT_FILE = Path(_VAULT_FILE)
|
||||
|
||||
|
||||
def _find_bw() -> str:
|
||||
|
||||
+23
-10
@@ -194,6 +194,8 @@ def setup_webhook_routes(
|
||||
"together": "https://api.together.xyz/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"ollama": "https://ollama.com/api",
|
||||
"opencode-zen": "https://opencode.ai/zen/v1",
|
||||
"opencode-go": "https://opencode.ai/zen/go/v1",
|
||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||
"venice": "https://api.venice.ai/api/v1",
|
||||
}
|
||||
@@ -323,22 +325,33 @@ def setup_webhook_routes(
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
model = body.model or "auto"
|
||||
api_key = ep.api_key
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||
|
||||
if model == "auto":
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
models_url = build_models_url(base_url)
|
||||
hdrs = build_headers(api_key, base_url)
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
if models_url:
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
import json as _json
|
||||
ids = _json.loads(ep.cached_models or "[]")
|
||||
model = ids[0] if ids else "auto"
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not discover models from endpoint")
|
||||
|
||||
Reference in New Issue
Block a user