mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Merge remote-tracking branch 'origin/main' into visual-pr-playground
# Conflicts: # routes/cookbook_routes.py # routes/hwfit_routes.py # services/hwfit/fit.py # services/hwfit/models.py # static/js/cookbook-diagnosis.js # static/js/cookbook-hwfit.js # static/js/cookbook.js # static/js/cookbookRunning.js
This commit is contained in:
@@ -27,6 +27,7 @@ from core.database import (
|
||||
Document,
|
||||
DocumentVersion,
|
||||
GalleryImage,
|
||||
GalleryAlbum,
|
||||
CalendarEvent,
|
||||
CalendarCal,
|
||||
)
|
||||
@@ -145,8 +146,9 @@ def setup_admin_wipe_routes(session_manager):
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "gallery":
|
||||
count = db.query(GalleryImage).count()
|
||||
count = db.query(GalleryImage).count() + db.query(GalleryAlbum).count()
|
||||
db.query(GalleryImage).delete()
|
||||
db.query(GalleryAlbum).delete()
|
||||
db.commit()
|
||||
# Also drop the upload dir so disk doesn't keep orphans.
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
||||
|
||||
+31
-6
@@ -67,6 +67,8 @@ class DeleteUserRequest(BaseModel):
|
||||
class RenameUserRequest(BaseModel):
|
||||
username: str
|
||||
|
||||
class SetOpenRegistrationRequest(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
SESSION_COOKIE = "odysseus_session"
|
||||
|
||||
@@ -295,6 +297,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
# owner-scoped DB rows before changing auth so the account keeps
|
||||
# access to its sessions, docs, email accounts, tasks, etc.
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
from core.database import Base, SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -304,7 +307,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
continue
|
||||
(
|
||||
db.query(model)
|
||||
.filter(model.owner == old_username)
|
||||
.filter(func.lower(model.owner) == old_username)
|
||||
.update({"owner": new_username}, synchronize_session=False)
|
||||
)
|
||||
db.commit()
|
||||
@@ -322,9 +325,15 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
from routes.prefs_routes import _load as _load_prefs, _save as _save_prefs
|
||||
prefs = _load_prefs()
|
||||
users = prefs.get("_users") if isinstance(prefs, dict) else None
|
||||
if isinstance(users, dict) and old_username in users and new_username not in users:
|
||||
users[new_username] = users.pop(old_username)
|
||||
_save_prefs(prefs)
|
||||
if isinstance(users, dict):
|
||||
prefs_key = next(
|
||||
(k for k in users if str(k).strip().lower() == old_username),
|
||||
None,
|
||||
)
|
||||
new_taken = any(str(k).strip().lower() == new_username for k in users)
|
||||
if prefs_key is not None and not new_taken:
|
||||
users[new_username] = users.pop(prefs_key)
|
||||
_save_prefs(prefs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
@@ -333,15 +342,31 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
raise HTTPException(400, "Cannot rename user")
|
||||
return {"ok": True, "username": new_username, "renamed_self": old_username == user}
|
||||
|
||||
@router.post("/signup-toggle")
|
||||
@router.post("/signup-toggle", deprecated=True)
|
||||
async def toggle_signup(request: Request):
|
||||
"""Toggle open registration on/off. Admin only."""
|
||||
"""
|
||||
Toggle open registration on/off. Admin only.
|
||||
|
||||
DEPRECATED: This endpoint uses toggle semantics which can lead to unsafe state changes.
|
||||
Use PUT /open-signup instead.
|
||||
|
||||
This endpoint is kept for backward compatibility and may be removed in future versions.
|
||||
"""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
auth_manager.signup_enabled = not auth_manager.signup_enabled
|
||||
return {"ok": True, "signup_enabled": auth_manager.signup_enabled}
|
||||
|
||||
@router.put("/open-signup")
|
||||
async def set_signup_enabled(body: SetOpenRegistrationRequest, request: Request):
|
||||
"""Set open signup enabled state. Admin only."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
auth_manager.signup_enabled = body.enabled
|
||||
return {"ok": True,"signup_enabled": auth_manager.signup_enabled}
|
||||
|
||||
@router.delete("/users")
|
||||
async def admin_delete_user(body: DeleteUserRequest, request: Request):
|
||||
user = _get_current_user(request)
|
||||
|
||||
@@ -77,7 +77,12 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
||||
# ── Memories ──
|
||||
if "memories" in body and isinstance(body["memories"], list):
|
||||
existing = memory_manager.load_all()
|
||||
existing_texts = {e.get("text", "").strip().lower() for e in existing}
|
||||
# Dedup against THIS user's own memories only. Using every tenant's
|
||||
# rows (load_all) meant a memory whose text matched any other
|
||||
# user's was silently skipped, so the importing user lost their own
|
||||
# data. The full store is still saved back below.
|
||||
existing_texts = {e.get("text", "").strip().lower()
|
||||
for e in existing if e.get("owner") == user}
|
||||
added = 0
|
||||
for mem in body["memories"]:
|
||||
if not isinstance(mem, dict) or not mem.get("text"):
|
||||
|
||||
+147
-27
@@ -12,10 +12,27 @@ from dateutil.rrule import rrulestr, rruleset
|
||||
from dateutil.rrule import DAILY, WEEKLY, MONTHLY, YEARLY
|
||||
|
||||
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ics_naive_dtstart(dt):
|
||||
"""Naive value matching how import_ics STORES CalendarEvent.dtstart.
|
||||
|
||||
Timed tz-aware events are stored as UTC with tzinfo stripped, all-day
|
||||
dates as midnight datetimes, naive datetimes unchanged. The ICS dedup
|
||||
must compute the same value or a re-import never matches the stored row.
|
||||
"""
|
||||
if isinstance(dt, datetime):
|
||||
if dt.tzinfo is not None:
|
||||
from datetime import timezone as _tz
|
||||
return dt.astimezone(_tz.utc).replace(tzinfo=None)
|
||||
return dt
|
||||
if isinstance(dt, date):
|
||||
return datetime(dt.year, dt.month, dt.day)
|
||||
return dt
|
||||
|
||||
# Single-user fallback identity. Used only when:
|
||||
# 1. The app is configured for single-user (no auth middleware), AND
|
||||
# 2. The request didn't resolve to an authenticated user.
|
||||
@@ -28,16 +45,17 @@ _SINGLE_USER_MODE = _os.environ.get("ODYSSEUS_SINGLE_USER", "1") != "0"
|
||||
|
||||
|
||||
def _require_user(request: Request) -> str:
|
||||
"""Return the authenticated user. In multi-user mode an unauthenticated
|
||||
request raises 401; in single-user mode it falls through to
|
||||
FALLBACK_OWNER. Prevents the silent cross-user data write that would
|
||||
happen if a request slipped past auth middleware in a real deployment."""
|
||||
u = get_current_user(request)
|
||||
if u:
|
||||
return u
|
||||
if _SINGLE_USER_MODE:
|
||||
return FALLBACK_OWNER
|
||||
raise HTTPException(401, "Authentication required")
|
||||
"""Return the authenticated user. Uses require_user so AUTH_ENABLED=false
|
||||
and single-user mode both work: require_user returns "" when auth is
|
||||
disabled or unconfigured, and only raises 401 when auth is configured but
|
||||
the caller is unauthenticated. Falls back to FALLBACK_OWNER for calendar
|
||||
writes so data isn't stored under an empty owner in single-user mode."""
|
||||
user = require_user(request)
|
||||
if user:
|
||||
return user
|
||||
# require_user returned "" — auth is off or unconfigured (single-user).
|
||||
# Use FALLBACK_OWNER so calendar rows have a stable owner for filtering.
|
||||
return FALLBACK_OWNER
|
||||
|
||||
|
||||
def _get_or_404_calendar(db, cal_id: str, owner: str) -> CalendarCal:
|
||||
@@ -64,6 +82,24 @@ def _get_or_404_event(db, uid: str, owner: str) -> CalendarEvent:
|
||||
return ev
|
||||
|
||||
|
||||
def _ics_escape(text: str) -> str:
|
||||
"""Escape a value for an iCalendar TEXT field (RFC 5545 §3.3.11).
|
||||
|
||||
Backslash, semicolon and comma are structural in TEXT values and must be
|
||||
escaped, and newlines become a literal ``\\n``. Backslash is escaped first
|
||||
so the escapes we add aren't re-escaped.
|
||||
"""
|
||||
return (
|
||||
(text or "")
|
||||
.replace("\\", "\\\\")
|
||||
.replace(";", "\\;")
|
||||
.replace(",", "\\,")
|
||||
.replace("\r\n", "\\n")
|
||||
.replace("\n", "\\n")
|
||||
.replace("\r", "\\n")
|
||||
)
|
||||
|
||||
|
||||
def _resolve_base_uid(uid: str) -> str:
|
||||
"""Extract the base series UID from a compound occurrence UID.
|
||||
|
||||
@@ -319,8 +355,8 @@ def _parse_dt(s: str) -> datetime:
|
||||
return None
|
||||
return h, mn
|
||||
|
||||
# today/tomorrow/yesterday [at] TIME
|
||||
m = _re.match(r'^(today|tomorrow|tmrw|yesterday)(?:\s+at)?\s*(.*)$', lower)
|
||||
# today/tonight/tomorrow/yesterday [at] TIME
|
||||
m = _re.match(r'^(today|tonight|tomorrow|tmrw|yesterday)(?:\s+at)?\s*(.*)$', lower)
|
||||
if m:
|
||||
word, rest = m.group(1), m.group(2).strip()
|
||||
base = today
|
||||
@@ -434,8 +470,21 @@ def _expand_rrule(
|
||||
return [d]
|
||||
|
||||
# Parse the rrule, applying it to the base dtstart.
|
||||
rrule_str = ev.rrule
|
||||
if ev.dtstart is not None and getattr(ev.dtstart, "tzinfo", None) is None:
|
||||
# Events are stored with a naive (UTC) dtstart, but standard .ics
|
||||
# exporters (Google/Apple/Outlook/Fastmail) write the bound as an
|
||||
# absolute UTC value, e.g. UNTIL=20240105T090000Z. dateutil refuses to
|
||||
# mix a tz-aware UNTIL with a naive DTSTART ("RRULE UNTIL values must be
|
||||
# specified in UTC when DTSTART is timezone-aware"), so the except branch
|
||||
# below would silently collapse the whole series to a single event.
|
||||
# Drop the trailing Z so UNTIL matches the naive DTSTART.
|
||||
import re as _re
|
||||
rrule_str = _re.sub(
|
||||
r"(UNTIL=\d{8}(?:T\d{6})?)Z", r"\1", rrule_str, flags=_re.IGNORECASE
|
||||
)
|
||||
try:
|
||||
rule = rrulestr(ev.rrule, dtstart=ev.dtstart)
|
||||
rule = rrulestr(rrule_str, dtstart=ev.dtstart)
|
||||
except Exception as ex:
|
||||
logger.warning(
|
||||
"Failed to parse rrule=%r for event %s: %s", ev.rrule, ev.uid, ex
|
||||
@@ -509,13 +558,20 @@ def setup_calendar_routes() -> APIRouter:
|
||||
owner = _require_user(request)
|
||||
from routes.prefs_routes import _load_for_user
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
caldav_password = cfg.get("password") or ""
|
||||
if caldav_password:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
caldav_password = decrypt(caldav_password)
|
||||
except Exception:
|
||||
pass
|
||||
# Surface url+username but never hand the password back to the
|
||||
# client — saved-state UI shouldn't leak the credential.
|
||||
return {
|
||||
"url": cfg.get("url", "") or "",
|
||||
"username": cfg.get("username", "") or "",
|
||||
"password": "",
|
||||
"has_password": bool(cfg.get("password")),
|
||||
"has_password": bool(caldav_password),
|
||||
"local": not bool(cfg.get("url")),
|
||||
}
|
||||
|
||||
@@ -534,12 +590,20 @@ def setup_calendar_routes() -> APIRouter:
|
||||
prefs.pop("caldav", None)
|
||||
_save_for_user(owner, prefs)
|
||||
return {"ok": True, "cleared": True}
|
||||
cfg["url"] = body.get("url", "").strip()
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
cfg["url"] = validate_caldav_url(body.get("url", ""))
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
cfg["username"] = (body.get("username") or "").strip()
|
||||
# Preserve the stored password when the client sends an empty
|
||||
# one (edit form re-submitted without re-typing the password).
|
||||
if body.get("password"):
|
||||
cfg["password"] = body["password"]
|
||||
from src.secret_storage import encrypt
|
||||
cfg["password"] = encrypt(body["password"])
|
||||
elif cfg.get("password"):
|
||||
from src.secret_storage import encrypt
|
||||
cfg["password"] = encrypt(cfg["password"])
|
||||
prefs["caldav"] = cfg
|
||||
_save_for_user(owner, prefs)
|
||||
return {"ok": True}
|
||||
@@ -566,9 +630,21 @@ def setup_calendar_routes() -> APIRouter:
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
url = url or (cfg.get("url") or "")
|
||||
user = user or (cfg.get("username") or "")
|
||||
pw = pw or (cfg.get("password") or "")
|
||||
if not pw:
|
||||
pw = cfg.get("password") or ""
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
if not (url and user and pw):
|
||||
return {"ok": False, "error": "Missing URL, username, or password"}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
url = validate_caldav_url(url)
|
||||
except ValueError as e:
|
||||
return {"ok": False, "error": str(e)}
|
||||
import httpx
|
||||
propfind_body = (
|
||||
'<?xml version="1.0" encoding="UTF-8"?>\n'
|
||||
@@ -576,13 +652,25 @@ def setup_calendar_routes() -> APIRouter:
|
||||
'</d:prop></d:propfind>'
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0, follow_redirects=True) as cx:
|
||||
async with httpx.AsyncClient(timeout=8.0, follow_redirects=False, trust_env=False) as cx:
|
||||
r = await cx.request(
|
||||
"PROPFIND", url,
|
||||
auth=(user, pw),
|
||||
headers={"Depth": "0", "Content-Type": "application/xml"},
|
||||
content=propfind_body,
|
||||
)
|
||||
# If the server demands Digest (Baïkal default, SabreDAV-based
|
||||
# servers, Radicale with htdigest), the Basic attempt above
|
||||
# 401s. Retry once with httpx.DigestAuth so this test matches
|
||||
# what the real sync does via caldav.DAVClient in
|
||||
# src/caldav_sync.py (which negotiates the scheme).
|
||||
if r.status_code == 401 and "digest" in r.headers.get("www-authenticate", "").lower():
|
||||
r = await cx.request(
|
||||
"PROPFIND", url,
|
||||
auth=httpx.DigestAuth(user, pw),
|
||||
headers={"Depth": "0", "Content-Type": "application/xml"},
|
||||
content=propfind_body,
|
||||
)
|
||||
# 207 = Multi-Status — standard CalDAV success. 200 also
|
||||
# acceptable. Anything else (401/403/404/5xx) means trouble.
|
||||
if r.status_code in (200, 207):
|
||||
@@ -593,6 +681,8 @@ def setup_calendar_routes() -> APIRouter:
|
||||
return {"ok": False, "error": "Forbidden — user can't access that URL"}
|
||||
if r.status_code == 404:
|
||||
return {"ok": False, "error": "Not found — check the URL path"}
|
||||
if 300 <= r.status_code < 400:
|
||||
return {"ok": False, "error": "Redirects are not followed for CalDAV safety; use the final URL"}
|
||||
return {"ok": False, "error": f"HTTP {r.status_code}"}
|
||||
except httpx.ConnectError as e:
|
||||
return {"ok": False, "error": f"Connection refused: {e}"[:200]}
|
||||
@@ -739,6 +829,16 @@ def setup_calendar_routes() -> APIRouter:
|
||||
)
|
||||
db.add(ev)
|
||||
db.commit()
|
||||
if cal.source == "caldav":
|
||||
# Push the new event to the remote so it appears on the user's
|
||||
# other devices — the sync is otherwise pull-only (#800).
|
||||
from src.caldav_writeback import writeback_event
|
||||
await writeback_event(owner, cal.source, cal.id, {
|
||||
"uid": uid, "summary": data.summary, "description": data.description,
|
||||
"location": data.location, "dtstart": dtstart, "dtend": dtend,
|
||||
"all_day": data.all_day, "is_utc": _is_utc and not data.all_day,
|
||||
"rrule": data.rrule or "",
|
||||
})
|
||||
return {"ok": True, "uid": uid}
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -785,6 +885,14 @@ def setup_calendar_routes() -> APIRouter:
|
||||
if data.color is not None:
|
||||
ev.color = data.color if data.color else None
|
||||
db.commit()
|
||||
cal = db.query(CalendarCal).filter(CalendarCal.id == ev.calendar_id).first()
|
||||
if cal and cal.source == "caldav":
|
||||
from src.caldav_writeback import writeback_event
|
||||
await writeback_event(owner, cal.source, cal.id, {
|
||||
"uid": ev.uid, "summary": ev.summary, "description": ev.description,
|
||||
"location": ev.location, "dtstart": ev.dtstart, "dtend": ev.dtend,
|
||||
"all_day": ev.all_day, "is_utc": ev.is_utc, "rrule": ev.rrule or "",
|
||||
})
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -805,8 +913,15 @@ def setup_calendar_routes() -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ev = _get_or_404_event(db, base_uid, owner)
|
||||
# Capture what the remote push needs BEFORE the row is gone.
|
||||
_cal = db.query(CalendarCal).filter(CalendarCal.id == ev.calendar_id).first()
|
||||
_is_caldav = bool(_cal and _cal.source == "caldav")
|
||||
_cal_id, _ev_uid = ev.calendar_id, ev.uid
|
||||
db.delete(ev)
|
||||
db.commit()
|
||||
if _is_caldav:
|
||||
from src.caldav_writeback import writeback_event
|
||||
await writeback_event(owner, "caldav", _cal_id, {"uid": _ev_uid}, delete=True)
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -938,7 +1053,12 @@ def setup_calendar_routes() -> APIRouter:
|
||||
source_uid = str(comp.get("uid", "")) or None
|
||||
if source_uid:
|
||||
src_dtstart = dtstart.dt
|
||||
naive_src = src_dtstart.replace(tzinfo=None) if hasattr(src_dtstart, 'tzinfo') and src_dtstart.tzinfo else src_dtstart
|
||||
# Normalize to the SAME naive form import_ics stores, so a
|
||||
# re-import of a tz-aware event matches the existing row.
|
||||
# The old code stripped tzinfo WITHOUT converting to UTC
|
||||
# (wall clock), while storage converts to UTC first, so
|
||||
# every re-import of a TZID event created a duplicate.
|
||||
naive_src = _ics_naive_dtstart(src_dtstart)
|
||||
existing = (
|
||||
db.query(CalendarEvent)
|
||||
.filter(
|
||||
@@ -1032,23 +1152,23 @@ def setup_calendar_routes() -> APIRouter:
|
||||
"BEGIN:VCALENDAR",
|
||||
"VERSION:2.0",
|
||||
"PRODID:-//Odysseus//Calendar//EN",
|
||||
f"X-WR-CALNAME:{cal.name}",
|
||||
f"X-WR-CALNAME:{_ics_escape(cal.name)}",
|
||||
]
|
||||
for ev in events:
|
||||
lines.append("BEGIN:VEVENT")
|
||||
lines.append(f"UID:{ev.uid}")
|
||||
lines.append(f"SUMMARY:{ev.summary or ''}")
|
||||
lines.append(f"SUMMARY:{_ics_escape(ev.summary or '')}")
|
||||
if ev.all_day:
|
||||
lines.append(f"DTSTART;VALUE=DATE:{ev.dtstart.strftime('%Y%m%d')}")
|
||||
lines.append(f"DTEND;VALUE=DATE:{ev.dtend.strftime('%Y%m%d')}")
|
||||
else:
|
||||
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}")
|
||||
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}")
|
||||
_dt_suffix = "Z" if getattr(ev, "is_utc", False) else ""
|
||||
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}{_dt_suffix}")
|
||||
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}{_dt_suffix}")
|
||||
if ev.description:
|
||||
desc = ev.description.replace(chr(10), '\\n')
|
||||
lines.append(f"DESCRIPTION:{desc}")
|
||||
lines.append(f"DESCRIPTION:{_ics_escape(ev.description)}")
|
||||
if ev.location:
|
||||
lines.append(f"LOCATION:{ev.location}")
|
||||
lines.append(f"LOCATION:{_ics_escape(ev.location)}")
|
||||
if ev.rrule:
|
||||
lines.append(f"RRULE:{ev.rrule}")
|
||||
lines.append("END:VEVENT")
|
||||
|
||||
+113
-18
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
@@ -11,6 +12,7 @@ from core.models import ChatMessage
|
||||
from core.database import SessionLocal
|
||||
from core.database import Session as DBSession, ModelEndpoint
|
||||
from src.llm_core import normalize_model_id
|
||||
from src.endpoint_resolver import normalize_base
|
||||
from src.context_compactor import maybe_compact, trim_for_context
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.prompt_security import untrusted_context_message
|
||||
@@ -119,7 +121,7 @@ def needs_auto_name(name: str) -> bool:
|
||||
if name.startswith("Chat:") or name == "Chat":
|
||||
return True
|
||||
# Default frontend name: "modelname HH:MM:SS AM/PM"
|
||||
if re.match(r'^.+ \d{1,2}:\d{2}:\d{2}\s*(AM|PM)$', name):
|
||||
if re.match(r"^.+ \d{1,2}:\d{2}:\d{2}(\s*(AM|PM))?$", name, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -146,9 +148,13 @@ async def auto_name_session(session_manager, sess):
|
||||
if not first_msg:
|
||||
return
|
||||
|
||||
owner = getattr(sess, "owner", None)
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
if not t_model:
|
||||
logger.debug("[auto-name] No model provided, skipping")
|
||||
return
|
||||
|
||||
# max_tokens big enough that reasoning models (Minimax M2,
|
||||
# DeepSeek R1, QwQ, etc.) have headroom for <think>…</think>
|
||||
@@ -306,7 +312,24 @@ def fire_message_event(request, webhook_manager, session_id: str, sess, message:
|
||||
fire_event("message_sent", user)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str):
|
||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
if not session_url or not endpoint_base:
|
||||
return False
|
||||
try:
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
|
||||
sess_url = session_url.rstrip("/")
|
||||
base = normalize_base(endpoint_base).rstrip("/")
|
||||
return sess_url in {
|
||||
base,
|
||||
base + "/chat/completions",
|
||||
build_chat_url(base).rstrip("/"),
|
||||
}
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
@@ -315,25 +338,96 @@ def resolve_session_auth(sess, session_id: str):
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
db = SessionLocal()
|
||||
try:
|
||||
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
|
||||
if domain:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
|
||||
if ep and ep.api_key:
|
||||
sess.headers = build_headers(ep.api_key, ep.base_url)
|
||||
db.query(DBSession).filter(DBSession.id == session_id).update(
|
||||
{"headers": json.dumps(sess.headers)}
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||
if not target_url:
|
||||
return
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
# Missing headers usually means "recover from the saved endpoint".
|
||||
# Scope that lookup to the session owner, otherwise two users
|
||||
# with similar endpoint URLs can borrow each other's API key.
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
for ep in q.all():
|
||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||
continue
|
||||
if not ep.api_key:
|
||||
return
|
||||
base = normalize_base(ep.base_url or "")
|
||||
sess.headers = build_headers(ep.api_key, base)
|
||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
update_q = update_q.filter(DBSession.owner == owner)
|
||||
update_q.update({"headers": sess.headers})
|
||||
db.commit()
|
||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||
return
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve session headers: {e}")
|
||||
|
||||
|
||||
def _match_cached_model_id(requested: str, models) -> Optional[str]:
|
||||
if not requested or not models:
|
||||
return None
|
||||
model_ids = [str(m) for m in models if m]
|
||||
if requested in model_ids:
|
||||
return requested
|
||||
|
||||
req_base = os.path.basename(requested.rstrip("/"))
|
||||
for model_id in model_ids:
|
||||
if os.path.basename(model_id.rstrip("/")) == req_base:
|
||||
return model_id
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
||||
"""Use stored endpoint model IDs before falling back to a live /models probe."""
|
||||
endpoint_url = getattr(sess, "endpoint_url", "") or ""
|
||||
requested = getattr(sess, "model", "") or ""
|
||||
if not endpoint_url or not requested:
|
||||
return None
|
||||
|
||||
try:
|
||||
session_base = normalize_base(endpoint_url)
|
||||
except Exception:
|
||||
session_base = endpoint_url.rstrip("/")
|
||||
if not session_base:
|
||||
return None
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in endpoints:
|
||||
try:
|
||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raw_models = getattr(ep, "cached_models", None)
|
||||
if not raw_models:
|
||||
continue
|
||||
try:
|
||||
models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
matched = _match_cached_model_id(requested, models)
|
||||
if matched:
|
||||
return matched
|
||||
except Exception as e:
|
||||
logger.debug("Cached model normalization skipped: %s", e)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def build_chat_context(
|
||||
sess,
|
||||
request,
|
||||
@@ -434,8 +528,9 @@ async def build_chat_context(
|
||||
for transcript in preprocessed.youtube_transcripts:
|
||||
preface.append(untrusted_context_message("youtube transcript", transcript))
|
||||
|
||||
# Normalize model ID
|
||||
norm = normalize_model_id(sess.endpoint_url, sess.model)
|
||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||
# re-hit slow local /models endpoints on every participant turn.
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
@@ -743,7 +838,7 @@ def run_post_response_tasks(
|
||||
from services.memory.memory_extractor import extract_and_store
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
asyncio.create_task(extract_and_store(
|
||||
sess, memory_manager, memory_vector,
|
||||
@@ -780,7 +875,7 @@ def run_post_response_tasks(
|
||||
from services.memory.skill_extractor import maybe_extract_skill
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
s_url, s_model, s_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
||||
asyncio.create_task(maybe_extract_skill(
|
||||
|
||||
+222
-44
@@ -23,10 +23,12 @@ from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
from routes.session_routes import _verify_session_owner
|
||||
from routes.document_helpers import _owner_session_filter
|
||||
from core.database import SessionLocal, get_session_mode, set_session_mode
|
||||
from core.database import Session as DBSession, ChatMessage as DBChatMessage
|
||||
from core.database import Document as DBDocument, ModelEndpoint
|
||||
from routes.research_routes import _resolve_research_endpoint
|
||||
from routes.model_routes import _visible_models
|
||||
from routes.chat_helpers import (
|
||||
resolve_session_auth,
|
||||
build_chat_context,
|
||||
@@ -41,6 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Track active streams for partial-save safety net
|
||||
_active_streams: Dict[str, dict] = {}
|
||||
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
||||
|
||||
|
||||
def _stream_set(session_id: str, **fields) -> None:
|
||||
@@ -69,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
|
||||
def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
|
||||
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
||||
if not getattr(sess, "endpoint_url", ""):
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for ep in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
||||
return False
|
||||
@@ -96,6 +103,132 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
db.close()
|
||||
|
||||
|
||||
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
||||
"""Return True when a populated endpoint model cache includes ``model``.
|
||||
|
||||
Empty/malformed caches are treated as unknown rather than a negative match
|
||||
so older image endpoints without cached models still work.
|
||||
"""
|
||||
raw = getattr(endpoint, "cached_models", None)
|
||||
if not raw:
|
||||
return True
|
||||
try:
|
||||
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
return True
|
||||
if not isinstance(models, list) or not models:
|
||||
return True
|
||||
wanted = (model or "").strip()
|
||||
return wanted in {str(item).strip() for item in models}
|
||||
|
||||
|
||||
def _is_image_generation_session(sess, owner: str | None = None) -> bool:
|
||||
"""Whether this chat session should bypass text chat and generate images.
|
||||
|
||||
Model-name prefixes are explicit image models. Endpoint type is only used
|
||||
when the current session endpoint actually matches that image endpoint, and
|
||||
when a populated endpoint model cache includes the selected model. This
|
||||
prevents an image endpoint on the same host from misrouting ordinary text
|
||||
models into the image-generation path.
|
||||
"""
|
||||
model = (getattr(sess, "model", "") or "").strip()
|
||||
if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES):
|
||||
return True
|
||||
|
||||
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||
if not endpoint_url:
|
||||
return False
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for endpoint in endpoints:
|
||||
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
||||
continue
|
||||
if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""):
|
||||
continue
|
||||
if _endpoint_cache_contains_model(endpoint, model):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
return False
|
||||
|
||||
|
||||
def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
|
||||
"""Re-populate sess.model from the matching endpoint's cached models.
|
||||
|
||||
Covers the window between endpoint setup and the first chat send: the
|
||||
picker showed a model in the dropdown but the session record never got
|
||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||
Without this, we'd POST the upstream with model="" and get a generic
|
||||
401/503 instead of using the model the user already picked.
|
||||
|
||||
Returns True iff sess.model was repaired.
|
||||
"""
|
||||
if getattr(sess, "model", None):
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Prefer the endpoint whose base URL matches the session — we know the
|
||||
# user already pointed this session at that endpoint, so its first
|
||||
# cached model is the most defensible default.
|
||||
ep = None
|
||||
if getattr(sess, "endpoint_url", ""):
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for cand in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
|
||||
ep = cand
|
||||
break
|
||||
if not ep:
|
||||
return False
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||
except Exception:
|
||||
cached = []
|
||||
if not cached:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if not visible:
|
||||
return False
|
||||
model = visible[0]
|
||||
if not isinstance(model, str) or not model.strip():
|
||||
return False
|
||||
model = model.strip()
|
||||
# Persist so the next request, websocket reconnect, or page reload
|
||||
# picks up the same model (we'd otherwise re-pick on every send
|
||||
# and silently switch on the user if the cached order shifts).
|
||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.model = model
|
||||
logger.info(
|
||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
||||
session_id, model, ep.id,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.warning("Failed to recover empty session model for %s: %s", session_id, e)
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def setup_chat_routes(
|
||||
session_manager,
|
||||
chat_handler,
|
||||
@@ -130,9 +263,20 @@ def setup_chat_routes(
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session}' not found")
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
owner = get_current_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
|
||||
# Empty model + live endpoint = setup race (Issue #587). Repair from
|
||||
# the endpoint's cached model list before privilege checks, which
|
||||
# otherwise see "" and behave inconsistently with the allowlist.
|
||||
_recover_empty_session_model(sess, session, owner=owner)
|
||||
if not getattr(sess, "model", "").strip():
|
||||
raise HTTPException(
|
||||
400,
|
||||
"No model selected for this chat. Open the model picker and choose one before sending.",
|
||||
)
|
||||
|
||||
# Same allowed_models + daily-cap gate as chat_stream (mirror so the
|
||||
# non-streaming path can't be used to bypass).
|
||||
_enforce_chat_privileges(request, sess)
|
||||
@@ -270,8 +414,21 @@ def setup_chat_routes(
|
||||
# but BEFORE loading. Prevents cross-user session hijack.
|
||||
_verify_session_owner(request, session)
|
||||
sess = session_manager.get_session(session)
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
owner = get_current_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
# Issue #587: picker shows a model from the endpoint cache but
|
||||
# s.model never made it onto the DB row (first-send race after
|
||||
# endpoint setup, or a previous endpoint delete/recreate). Pull
|
||||
# the first cached model off the matching endpoint so the
|
||||
# upstream isn't called with model="" (which surfaces as a
|
||||
# generic 401/503).
|
||||
_recover_empty_session_model(sess, session, owner=owner)
|
||||
if not getattr(sess, "model", "").strip():
|
||||
raise HTTPException(
|
||||
400,
|
||||
"No model selected for this chat. Open the model picker and choose one before sending.",
|
||||
)
|
||||
except SessionNotFoundError as e:
|
||||
raise HTTPException(404, str(e))
|
||||
except (ValueError, ValidationError):
|
||||
@@ -288,7 +445,7 @@ def setup_chat_routes(
|
||||
_enforce_chat_privileges(request, sess)
|
||||
|
||||
# Ensure session has auth headers
|
||||
resolve_session_auth(sess, session)
|
||||
resolve_session_auth(sess, session, owner=get_current_user(request))
|
||||
|
||||
# Check for research_pending BEFORE mode persist overwrites it
|
||||
do_research = str(use_research).lower() == "true"
|
||||
@@ -343,18 +500,22 @@ def setup_chat_routes(
|
||||
try:
|
||||
if active_doc_id:
|
||||
logger.info(f"[doc-inject] active_doc_id from frontend: {active_doc_id}")
|
||||
active_doc = _doc_db.query(DBDocument).filter(
|
||||
DBDocument.id == active_doc_id,
|
||||
).first()
|
||||
# Scope to the caller's documents. The session and in-memory
|
||||
# fallbacks below are already owner/session-bound; this
|
||||
# explicit-id path looked up by id alone, so a user could
|
||||
# inject another user's document by passing its id.
|
||||
_doc_q = _doc_db.query(DBDocument).filter(DBDocument.id == active_doc_id)
|
||||
active_doc = _owner_session_filter(_doc_q, ctx.user).first()
|
||||
if active_doc:
|
||||
logger.info(f"[doc-inject] found by ID: title={active_doc.title!r}, lang={active_doc.language!r}, is_active={active_doc.is_active}, content_len={len(active_doc.current_content or '')}")
|
||||
else:
|
||||
logger.warning(f"[doc-inject] NOT FOUND by ID {active_doc_id}")
|
||||
if not active_doc:
|
||||
active_doc = _doc_db.query(DBDocument).filter(
|
||||
_session_doc_q = _doc_db.query(DBDocument).filter(
|
||||
DBDocument.session_id == session,
|
||||
DBDocument.is_active == True
|
||||
).order_by(DBDocument.updated_at.desc()).first()
|
||||
)
|
||||
active_doc = _owner_session_filter(_session_doc_q, ctx.user).order_by(DBDocument.updated_at.desc()).first()
|
||||
if active_doc:
|
||||
logger.info(f"[doc-inject] found by session fallback: title={active_doc.title!r}")
|
||||
# Last resort: the document the agent itself just created/edited
|
||||
@@ -368,7 +529,8 @@ def setup_chat_routes(
|
||||
from src.tool_implementations import get_active_document
|
||||
_mem_id = get_active_document()
|
||||
if _mem_id:
|
||||
cand = _doc_db.query(DBDocument).filter(DBDocument.id == _mem_id).first()
|
||||
_mem_q = _doc_db.query(DBDocument).filter(DBDocument.id == _mem_id)
|
||||
cand = _owner_session_filter(_mem_q, ctx.user).first()
|
||||
if cand and (not cand.session_id or cand.session_id == session):
|
||||
active_doc = cand
|
||||
logger.info(f"[doc-inject] found by in-memory active id: title={active_doc.title!r} (session_id={cand.session_id!r})")
|
||||
@@ -563,6 +725,7 @@ def setup_chat_routes(
|
||||
prior_findings=_prior_findings,
|
||||
prior_urls=_prior_urls,
|
||||
on_complete=_on_research_done,
|
||||
owner=_user,
|
||||
)
|
||||
|
||||
_heartbeat_counter = 0
|
||||
@@ -619,7 +782,7 @@ def setup_chat_routes(
|
||||
# output. Resolved once per request.
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_chat_fallback_candidates
|
||||
_fallback_candidates = resolve_chat_fallback_candidates()
|
||||
_fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
|
||||
except Exception:
|
||||
_fallback_candidates = []
|
||||
|
||||
@@ -632,28 +795,7 @@ def setup_chat_routes(
|
||||
_model_info["character_name"] = ctx.preset.character_name
|
||||
yield f'data: {json.dumps(_model_info)}\n\n'
|
||||
|
||||
# Detect image models and route directly to image generation
|
||||
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
||||
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
|
||||
|
||||
# Also check if the endpoint is registered as an image-type endpoint
|
||||
if not _is_image_model:
|
||||
try:
|
||||
from src.endpoint_resolver import normalize_base as _nb
|
||||
_ep_base = _nb(sess.endpoint_url)
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
_is_image_model = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.model_type == "image",
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
|
||||
).first() is not None
|
||||
finally:
|
||||
_db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _is_image_model:
|
||||
if _is_image_generation_session(sess, owner=_user):
|
||||
from src.settings import get_setting
|
||||
if not get_setting("image_gen_enabled", True):
|
||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||
@@ -664,7 +806,7 @@ def setup_chat_routes(
|
||||
_user_msg = message or ""
|
||||
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
|
||||
yield ": heartbeat\n\n"
|
||||
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session)
|
||||
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
|
||||
_img_output = _img_result.get("results", _img_result.get("error", ""))
|
||||
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
|
||||
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):
|
||||
@@ -688,6 +830,7 @@ def setup_chat_routes(
|
||||
return
|
||||
elif chat_mode == "chat":
|
||||
_chat_start = time.time()
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
||||
try:
|
||||
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
||||
@@ -708,16 +851,35 @@ def setup_chat_routes(
|
||||
try:
|
||||
data = json.loads(chunk[6:])
|
||||
if "delta" in data:
|
||||
full_response += data["delta"]
|
||||
_stream_set(session, partial=full_response)
|
||||
# Reasoning tokens arrive flagged thinking:true.
|
||||
# Forward them so the client can show a thinking
|
||||
# indicator, but don't fold them into the saved
|
||||
# reply (mirrors the rewrite path below).
|
||||
if not data.get("thinking"):
|
||||
full_response += data["delta"]
|
||||
_stream_set(session, partial=full_response)
|
||||
yield chunk
|
||||
elif data.get("type") == "fallback":
|
||||
# Selected model failed; a fallback answered.
|
||||
# Forward the notice and remember the real model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
yield chunk
|
||||
elif data.get("type") == "usage":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = sess.model
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
if ctx.context_length and last_metrics.get("input_tokens"):
|
||||
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
||||
last_metrics["context_percent"] = pct
|
||||
last_metrics["context_length"] = ctx.context_length
|
||||
# The frontend reads `tokens_per_second`; the raw usage event
|
||||
# carries the backend's true gen speed as `gen_tps` (llama.cpp
|
||||
# timings). Map it through so this direct-chat path shows real
|
||||
# t/s instead of "n/a" → falling back to a bare token count.
|
||||
if last_metrics.get("gen_tps") and not last_metrics.get("tokens_per_second"):
|
||||
last_metrics["tokens_per_second"] = last_metrics["gen_tps"]
|
||||
last_metrics["tps_source"] = "backend"
|
||||
# Wall-clock response time for the stats popup ("Time").
|
||||
last_metrics.setdefault("response_time", round(time.time() - _chat_start, 2))
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
except json.JSONDecodeError:
|
||||
yield chunk
|
||||
@@ -781,6 +943,7 @@ def setup_chat_routes(
|
||||
# ── Agent mode: full agent loop with tools ──
|
||||
_agent_rounds = 0
|
||||
_agent_tool_calls = 0
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
try:
|
||||
from src.settings import get_setting
|
||||
_tool_budget = int(get_setting("agent_max_tool_calls", 0))
|
||||
@@ -805,8 +968,12 @@ def setup_chat_routes(
|
||||
try:
|
||||
data = json.loads(chunk[6:])
|
||||
if "delta" in data:
|
||||
full_response += data["delta"]
|
||||
_stream_set(session, partial=full_response)
|
||||
# Reasoning tokens arrive flagged thinking:true.
|
||||
# Forward them for the live indicator, but keep
|
||||
# them out of the saved reply (same as chat mode).
|
||||
if not data.get("thinking"):
|
||||
full_response += data["delta"]
|
||||
_stream_set(session, partial=full_response)
|
||||
yield chunk
|
||||
elif data.get("type") == "web_sources":
|
||||
web_sources = data.get("data", [])
|
||||
@@ -821,9 +988,16 @@ def setup_chat_routes(
|
||||
elif data.get("type") == "tool_start":
|
||||
_agent_tool_calls += 1
|
||||
yield chunk
|
||||
elif data.get("type") == "fallback":
|
||||
# Selected model failed; a fallback answered.
|
||||
# Forward the notice and remember the real
|
||||
# model so metrics reflect it, not the masked
|
||||
# selected model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
yield chunk
|
||||
elif data.get("type") == "metrics":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = sess.model
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
except json.JSONDecodeError:
|
||||
yield chunk
|
||||
@@ -920,11 +1094,15 @@ def setup_chat_routes(
|
||||
_verify_session_owner(request, session_id)
|
||||
# A detached run can still be going even if _active_streams was popped;
|
||||
# report it as active so the client knows to reconnect via /resume.
|
||||
if session_id not in _active_streams:
|
||||
# Read once via .get() to avoid a KeyError race between the membership
|
||||
# check and the indexed read if a sibling stream's finally pops the
|
||||
# entry in between (same pattern _stream_set already uses).
|
||||
rec = _active_streams.get(session_id)
|
||||
if rec is None:
|
||||
if agent_runs.is_active(session_id):
|
||||
return {"status": "streaming", "detached": True}
|
||||
raise HTTPException(404, "No active stream for this session")
|
||||
return _active_streams[session_id]
|
||||
return rec
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# POST /api/inject_context
|
||||
@@ -1088,7 +1266,7 @@ def setup_chat_routes(
|
||||
db_msg = (
|
||||
db.query(DBChatMessage)
|
||||
.filter(DBChatMessage.session_id == session_id, DBChatMessage.role == 'assistant')
|
||||
.order_by(DBChatMessage.created_at.desc())
|
||||
.order_by(DBChatMessage.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
if db_msg:
|
||||
|
||||
+19
-12
@@ -130,21 +130,28 @@ def _parse_vcards(text: str) -> List[Dict]:
|
||||
contact = {"name": "", "emails": [], "phones": [], "uid": ""}
|
||||
for line in block.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("FN:") or line.startswith("FN;"):
|
||||
contact["name"] = _vunesc(line.split(":", 1)[1]) if ":" in line else ""
|
||||
elif line.startswith("EMAIL"):
|
||||
# Strip an optional RFC 6350 group prefix (e.g. "item1.EMAIL;...")
|
||||
# that Apple Contacts / iCloud / many CardDAV servers emit by
|
||||
# default — without this the property-name checks below miss those
|
||||
# lines and silently drop the email / phone. The group token only
|
||||
# precedes the property name, so it is safe to strip for matching
|
||||
# and value extraction, and a no-op for non-grouped lines.
|
||||
name_part = re.sub(r"^[A-Za-z0-9-]+\.", "", line, count=1)
|
||||
if name_part.startswith("FN:") or name_part.startswith("FN;"):
|
||||
contact["name"] = _vunesc(name_part.split(":", 1)[1]) if ":" in name_part else ""
|
||||
elif name_part.startswith("EMAIL"):
|
||||
# Handle EMAIL:foo@bar OR EMAIL;TYPE=...:foo@bar OR EMAIL;PREF=1:foo@bar
|
||||
if ":" in line:
|
||||
email_addr = _vunesc(line.split(":", 1)[1])
|
||||
if ":" in name_part:
|
||||
email_addr = _vunesc(name_part.split(":", 1)[1])
|
||||
if email_addr and email_addr not in contact["emails"]:
|
||||
contact["emails"].append(email_addr)
|
||||
elif line.startswith("TEL"):
|
||||
if ":" in line:
|
||||
phone = _vunesc(line.split(":", 1)[1])
|
||||
elif name_part.startswith("TEL"):
|
||||
if ":" in name_part:
|
||||
phone = _vunesc(name_part.split(":", 1)[1])
|
||||
if phone and phone not in contact["phones"]:
|
||||
contact["phones"].append(phone)
|
||||
elif line.startswith("UID:"):
|
||||
contact["uid"] = _vunesc(line[4:])
|
||||
elif name_part.startswith("UID:"):
|
||||
contact["uid"] = _vunesc(name_part[4:])
|
||||
if contact["name"] or contact["emails"]:
|
||||
contacts.append(contact)
|
||||
return contacts
|
||||
@@ -676,8 +683,8 @@ def setup_contacts_routes():
|
||||
@router.post("/add")
|
||||
async def add_contact(data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Add a new contact."""
|
||||
name = data.get("name", "").strip()
|
||||
email = data.get("email", "").strip()
|
||||
name = (data.get("name") or "").strip()
|
||||
email = (data.get("email") or "").strip()
|
||||
if not email:
|
||||
return {"success": False, "error": "Email required"}
|
||||
# Check if already exists
|
||||
|
||||
+250
-8
@@ -148,6 +148,108 @@ def _local_tooling_path_export(executable: str) -> str:
|
||||
return f'export PATH="{esc}:$PATH"'
|
||||
|
||||
|
||||
def _pip_install_no_cache(cmd: str) -> str:
|
||||
"""Add ``--no-cache-dir`` to a pip install command.
|
||||
|
||||
Cookbook dependency installs (vLLM, llama-cpp-python, …) build large wheels;
|
||||
pip's default cache lives under ``$HOME/.cache/pip`` and these builds can fill
|
||||
a small home filesystem with ``[Errno 28] No space left on device`` mid-build
|
||||
(issue #1219), leaving the dependency "installed" but unusable (#1459).
|
||||
Disabling the cache for these one-off installs keeps them off the home disk
|
||||
(the maintainer's suggested ``PIP_CACHE_DIR=`` workaround, made the default).
|
||||
Idempotent; leaves non-pip-install commands untouched."""
|
||||
if not cmd or "pip install" not in cmd or "--no-cache-dir" in cmd:
|
||||
return cmd
|
||||
return cmd.replace("pip install", "pip install --no-cache-dir", 1)
|
||||
|
||||
|
||||
def _pip_install_attempt(pip_cmd: str) -> str:
|
||||
"""Wrap a single pip install command so its exit status survives the
|
||||
fallback chain and its stderr is visible in the tmux log on failure.
|
||||
|
||||
Without this wrapper, `pip … 2>&1 | tail -5` returns ``tail``'s exit
|
||||
code (0), masking pip's real failure and preventing the next fallback
|
||||
from running. The generated snippet captures all output to a temp
|
||||
file, prints the last 5 lines on failure (so the Cookbook log panel
|
||||
shows useful diagnostics), cleans up, and exits with pip's original
|
||||
status.
|
||||
"""
|
||||
return (
|
||||
"bash -c '"
|
||||
f'_out=$(mktemp) && {pip_cmd} >"$_out" 2>&1; _rc=$?; '
|
||||
'tail -5 "$_out"; rm -f "$_out"; exit $_rc'
|
||||
"'"
|
||||
)
|
||||
|
||||
|
||||
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
||||
"""Build a bash pip install fallback chain that surfaces errors.
|
||||
|
||||
Try the active interpreter/environment first. ``--user`` is invalid
|
||||
inside many venvs, so only attempt the ``--user`` fallback when NOT
|
||||
inside a venv.
|
||||
|
||||
Each attempt is wrapped via :func:`_pip_install_attempt` so pip's real
|
||||
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
||||
pip output appear in the Cookbook log on failure.
|
||||
"""
|
||||
upgrade_flag = " -U" if upgrade else ""
|
||||
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
||||
# contains brackets that bash would treat as a glob, so it must be quoted
|
||||
# before being embedded in the install command. Plain names (e.g.
|
||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||
pkg = shlex.quote(package)
|
||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
||||
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||
# Derive the python executable for the venv detection check.
|
||||
# Must use the same interpreter that pip belongs to; hardcoding
|
||||
# python3 breaks when pip lives in a venv that only has "python".
|
||||
if " -m pip" in python_cmd:
|
||||
python_exe = python_cmd.replace(" -m pip", "")
|
||||
elif python_cmd.strip() == "pip":
|
||||
python_exe = "python"
|
||||
elif python_cmd.strip() == "pip3":
|
||||
python_exe = "python3"
|
||||
else:
|
||||
python_exe = "python3"
|
||||
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv → `&&` tries
|
||||
# --user. When IN a venv `! venv_check` fails → `&&` skips --user and the
|
||||
# group exits non-zero, propagating the base-install failure instead of
|
||||
# masking it as success (the `|| { venv_check || … }` shape from #903
|
||||
# swallowed the exit code because venv_check's exit-0 became the group's
|
||||
# result).
|
||||
return f"{base} || {{ ! {venv_check} && {user}; }}"
|
||||
|
||||
|
||||
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
||||
"""Drop pip user-install flags that are invalid for local venv installs.
|
||||
|
||||
Cookbook dependency installs run through the model-serve task path so users
|
||||
can watch progress in the same log UI. For local POSIX runs, that task
|
||||
prepends Odysseus' own interpreter directory to PATH. If Odysseus itself is
|
||||
running from a venv, `python3` resolves to the venv Python and pip rejects
|
||||
`--user` with "User site-packages are not visible in this virtualenv".
|
||||
|
||||
Keep remote and non-venv installs unchanged: remotes may intentionally use
|
||||
system Python, and Docker/non-venv installs still need user-site fallback.
|
||||
"""
|
||||
if not local or not in_venv:
|
||||
return cmd
|
||||
if "pip install" not in (cmd or ""):
|
||||
return cmd
|
||||
try:
|
||||
parts = shlex.split(cmd)
|
||||
except ValueError:
|
||||
return cmd
|
||||
stripped = [
|
||||
part
|
||||
for part in parts
|
||||
if part not in {"--user", "--break-system-packages"}
|
||||
]
|
||||
return shlex.join(stripped)
|
||||
|
||||
|
||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
||||
lines = [
|
||||
@@ -166,6 +268,38 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" for root, dirs, fns in os.walk(top, followlinks=False):",
|
||||
" dirs[:] = [d for d in dirs if not os.path.islink(os.path.join(root, d)) and safe_path(os.path.join(root, d))]",
|
||||
" yield root, dirs, fns",
|
||||
"def gguf_role(name):",
|
||||
" n = name.lower()",
|
||||
" if n.startswith('mmproj') or 'mmproj' in n: return 'projector'",
|
||||
" return 'model'",
|
||||
"def gguf_quant(name):",
|
||||
" m = re.search(r'(?i)(UD-)?(IQ[0-9]_[A-Z0-9_]+|Q[0-9](?:_[A-Z0-9]+)+|BF16|F16|FP16|F32|Q8_0)', name)",
|
||||
" return m.group(0).upper() if m else ''",
|
||||
"def collect_ggufs(base):",
|
||||
" files = []",
|
||||
" split_groups = {}",
|
||||
" if not os.path.isdir(base) or not safe_path(base): return files",
|
||||
" for root, dirs, fns in safe_walk(base):",
|
||||
" for fn in sorted(fns):",
|
||||
" if not fn.lower().endswith('.gguf'): continue",
|
||||
" fp = os.path.join(root, fn)",
|
||||
" try: size = os.path.getsize(fp)",
|
||||
" except Exception: size = 0",
|
||||
" try: rel = os.path.relpath(fp, base).replace(os.sep, '/')",
|
||||
" except Exception: rel = fn",
|
||||
" sm = re.match(r'(?i)^(.+)-(\\d+)-of-(\\d+)\\.gguf$', fn)",
|
||||
" if sm:",
|
||||
" prefix, part_s, total_s = sm.group(1), sm.group(2), sm.group(3)",
|
||||
" key = (root, prefix, total_s)",
|
||||
" g = split_groups.setdefault(key, {'name':fn,'rel_path':rel,'size_bytes':0,'role':gguf_role(fn),'quant':gguf_quant(fn),'parts':int(total_s),'split':True})",
|
||||
" g['size_bytes'] += size",
|
||||
" if int(part_s) == 1:",
|
||||
" g.update({'name':fn,'rel_path':rel,'role':gguf_role(fn),'quant':gguf_quant(fn)})",
|
||||
" continue",
|
||||
" files.append({'name':fn,'rel_path':rel,'size_bytes':size,'role':gguf_role(fn),'quant':gguf_quant(fn)})",
|
||||
" files.extend(split_groups.values())",
|
||||
" files.sort(key=lambda f: (f.get('role') != 'model', f.get('rel_path', '')))",
|
||||
" return files",
|
||||
"def scan_hf(cache):",
|
||||
" if not os.path.isdir(cache): return",
|
||||
" for d in sorted(os.listdir(cache)):",
|
||||
@@ -180,16 +314,14 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||
" if f.name.endswith('.incomplete'): ic = True",
|
||||
" snap = os.path.join(cache, d, 'snapshots')",
|
||||
" is_diffusion = False; is_gguf = False",
|
||||
" is_diffusion = False; gguf_files = []",
|
||||
" if os.path.isdir(snap):",
|
||||
" for sd in os.listdir(snap):",
|
||||
" sf = os.path.join(snap, sd)",
|
||||
" if not os.path.isdir(sf): continue",
|
||||
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
||||
" try:",
|
||||
" if any(x.endswith('.gguf') for x in os.listdir(sf)): is_gguf = True",
|
||||
" except Exception: pass",
|
||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':is_gguf})",
|
||||
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||
"def scan_dir(p):",
|
||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||
" for d in sorted(os.listdir(p)):",
|
||||
@@ -198,13 +330,14 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" fp = os.path.join(p, d)",
|
||||
" if not os.path.isdir(fp) or os.path.islink(fp) or not safe_path(fp): continue",
|
||||
" if d in seen: continue",
|
||||
" is_model = False; is_gguf = False",
|
||||
" is_model = False; gguf_files = []",
|
||||
" for root, dirs, fns in safe_walk(fp):",
|
||||
" for fn in fns:",
|
||||
" if fn.endswith('.gguf'): is_gguf = True; is_model = True",
|
||||
" if fn.lower().endswith('.gguf'): is_model = True",
|
||||
" elif fn == 'config.json' or fn.endswith('.safetensors') or fn.endswith('.bin'): is_model = True",
|
||||
" if is_model: break",
|
||||
" if not is_model: continue",
|
||||
" gguf_files = collect_ggufs(fp)",
|
||||
" seen.add(d)",
|
||||
" sz, nf = 0, 0",
|
||||
" for dp, _, fns in safe_walk(fp):",
|
||||
@@ -212,7 +345,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" try: nf += 1; sz += os.path.getsize(os.path.join(dp, fn))",
|
||||
" except Exception: pass",
|
||||
" is_diff = os.path.exists(os.path.join(fp, 'model_index.json'))",
|
||||
" models.append({'repo_id':d,'size_bytes':sz,'nb_files':nf,'has_incomplete':False,'path':p,'is_local_dir':True,'is_diffusion':is_diff,'is_gguf':is_gguf})",
|
||||
" models.append({'repo_id':d,'size_bytes':sz,'nb_files':nf,'has_incomplete':False,'path':p,'is_local_dir':True,'is_diffusion':is_diff,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||
"def parse_size(num, unit):",
|
||||
" try: n = float(num)",
|
||||
" except Exception: return 0",
|
||||
@@ -293,6 +426,38 @@ _SERVE_CMD_ALLOWLIST = {
|
||||
_GGUF_PRELUDE_RE = re.compile(
|
||||
r'^MODEL_FILE=\$\([^\n]*?\)\s*&&\s*\{[^{}]*\}\s*\|\|\s*\{[^{}]*\}\s*&&\s*'
|
||||
)
|
||||
_OLLAMA_HOST_ASSIGNMENT_RE = re.compile(r"(?:^|\s)OLLAMA_HOST=([^\s]+)")
|
||||
_OLLAMA_BIND_RE = re.compile(r"^\[([^\]]+)\]:(\d+)$|^([^:]+):(\d+)$")
|
||||
_OLLAMA_BIND_HOST_RE = re.compile(r"^[A-Za-z0-9._:-]+$")
|
||||
|
||||
|
||||
def _ollama_bind_from_cmd(cmd: str | None, *, default_host: str = "127.0.0.1") -> tuple[str, str]:
|
||||
"""Return the Ollama bind host/port requested by a serve command.
|
||||
|
||||
Plain local `ollama serve` defaults to loopback. Remote callers can pass a
|
||||
wider default host so the resulting API is reachable by Odysseus.
|
||||
"""
|
||||
if not cmd:
|
||||
return default_host, "11434"
|
||||
match = _OLLAMA_HOST_ASSIGNMENT_RE.search(cmd)
|
||||
if not match:
|
||||
return default_host, "11434"
|
||||
value = match.group(1).strip("'\"")
|
||||
bind_match = _OLLAMA_BIND_RE.match(value)
|
||||
if not bind_match:
|
||||
return "127.0.0.1", "11434"
|
||||
bracketed_host = bind_match.group(1)
|
||||
host = bracketed_host or bind_match.group(3) or "127.0.0.1"
|
||||
port = bind_match.group(2) or bind_match.group(4) or "11434"
|
||||
if not _OLLAMA_BIND_HOST_RE.match(host):
|
||||
return "127.0.0.1", "11434"
|
||||
try:
|
||||
port_num = int(port, 10)
|
||||
except ValueError:
|
||||
return "127.0.0.1", "11434"
|
||||
if port_num < 1 or port_num > 65535:
|
||||
return "127.0.0.1", "11434"
|
||||
return f"[{host}]" if bracketed_host else host, port
|
||||
|
||||
|
||||
def _check_serve_binary(seg: str) -> None:
|
||||
@@ -370,6 +535,83 @@ def _append_serve_exit_code_lines(runner_lines: list[str], *, keep_shell_open: b
|
||||
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="; exec "${SHELL:-/bin/bash}"')
|
||||
else:
|
||||
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="')
|
||||
runner_lines.append('exit "$ODYSSEUS_CMD_EXIT"')
|
||||
|
||||
|
||||
def _append_llama_cpp_linux_accel_build_lines(runner_lines: list[str]) -> None:
|
||||
"""Append Linux llama.cpp build lines that prefer ROCm/HIP when available.
|
||||
|
||||
Cookbook already detects AMD GPUs elsewhere, but the llama.cpp bootstrap used
|
||||
to hard-wire CUDA on Linux. That made ROCm hosts attempt a CUDA configure and
|
||||
fail with "CUDA Toolkit not found" instead of building with HIP.
|
||||
"""
|
||||
# Detect pip-installed nvcc (from vLLM/nvidia CUDA wheels) and put it on PATH
|
||||
# so cmake's CUDA configure can find it. We keep this after the ROCm/HIP
|
||||
# check — a machine with both stacks should honor the native HIP toolchain on
|
||||
# AMD hosts instead of accidentally preferring a stray nvcc wheel.
|
||||
runner_lines.append(' for _cudir in ~/.local/lib/python*/site-packages/nvidia/cu13 ~/.local/lib/python*/site-packages/nvidia/cu12 ~/.local/lib/python*/site-packages/nvidia/cuda_nvcc; do')
|
||||
runner_lines.append(' [ -x "$_cudir/bin/nvcc" ] && export CUDA_HOME="$_cudir" && export PATH="$_cudir/bin:$PATH" && break')
|
||||
runner_lines.append(' done')
|
||||
# rm -rf build so a prior poisoned CMakeCache.txt (e.g. from a failed CUDA
|
||||
# or HIP attempt) doesn't cause the next configure to reuse stale settings.
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build')
|
||||
runner_lines.append(' if command -v hipconfig &>/dev/null || [ -d /opt/rocm ] || [ -n "$ROCM_PATH" ] || [ -n "$HIP_PATH" ]; then')
|
||||
runner_lines.append(' if command -v hipconfig &>/dev/null; then')
|
||||
runner_lines.append(' export HIPCXX="${HIPCXX:-$(hipconfig -l)/clang}"')
|
||||
runner_lines.append(' export HIP_PATH="${HIP_PATH:-$(hipconfig -R)}"')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' echo "[odysseus] ROCm/HIP detected — building llama-server with HIP support..."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_HIP=ON && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' elif command -v nvcc &>/dev/null; then')
|
||||
# nvcc alone is not sufficient — pip-installed CUDA wheels or incomplete
|
||||
# tooling can expose nvcc without shipping libcudart, causing cmake to fail
|
||||
# mid-build with "CUDA runtime library not found". Check cudart explicitly
|
||||
# via a small helper so the guard stays readable.
|
||||
runner_lines.append(' _odysseus_has_cudart() {')
|
||||
runner_lines.append(' ldconfig -p 2>/dev/null | grep -q \'libcudart\\.so\' && return 0')
|
||||
runner_lines.append(' local _cuh="${CUDA_HOME:-/usr/local/cuda}"')
|
||||
runner_lines.append(' ls "$_cuh/lib64/libcudart.so"* &>/dev/null && return 0')
|
||||
runner_lines.append(' ls "$_cuh/lib/libcudart.so"* &>/dev/null && return 0')
|
||||
runner_lines.append(' ls /usr/local/cuda/lib64/libcudart.so* &>/dev/null && return 0')
|
||||
runner_lines.append(' ls /usr/local/cuda/lib/libcudart.so* &>/dev/null && return 0')
|
||||
runner_lines.append(' ls "${_cuh%/cuda_nvcc}/cuda_runtime/lib/libcudart.so"* &>/dev/null && return 0')
|
||||
runner_lines.append(' return 1')
|
||||
runner_lines.append(' }')
|
||||
runner_lines.append(' if _odysseus_has_cudart; then')
|
||||
runner_lines.append(' echo "[odysseus] CUDA nvcc + cudart found — building llama-server with CUDA (GPU) support..."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=ON && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
runner_lines.append(' echo "[odysseus] WARNING: nvcc found but CUDA runtime (libcudart.so) is not visible — building llama-server for CPU only."')
|
||||
runner_lines.append(' echo "[odysseus] GPU inference will not be available for this llama.cpp build."')
|
||||
runner_lines.append(' echo "[odysseus] Ensure libcudart is installed (e.g. cuda-runtime package) and visible via ldconfig or CUDA_HOME."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' else')
|
||||
runner_lines.append(' echo "[odysseus] WARNING: no HIP/CUDA toolchain found — building llama-server for CPU only."')
|
||||
runner_lines.append(' echo "[odysseus] GPU inference will not be available for this llama.cpp build."')
|
||||
runner_lines.append(' echo "[odysseus] Install ROCm for AMD GPUs or vLLM/CUDA tooling for NVIDIA, then re-launch this serve task."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' fi')
|
||||
|
||||
|
||||
def _llama_cpp_rebuild_cmd() -> str:
|
||||
"""Shell command that clears the Cookbook-managed llama.cpp build.
|
||||
|
||||
Removes the cached ``llama-server`` symlink and the ``~/llama.cpp/build``
|
||||
directory so the next llama.cpp serve recompiles from source, picking up a
|
||||
CUDA or HIP toolchain if one is now available. The serve bootstrap only
|
||||
builds when ``llama-server`` is missing from PATH, so without this an
|
||||
existing CPU-only build is reused forever. It deliberately installs and
|
||||
downloads nothing; the rebuild itself happens on the next serve.
|
||||
"""
|
||||
return (
|
||||
'mkdir -p "$HOME/bin" && '
|
||||
'rm -f "$HOME/bin/llama-server" && '
|
||||
'rm -rf "$HOME/llama.cpp/build" && '
|
||||
'echo "[odysseus] Cleared the cached llama.cpp build. '
|
||||
'Re-launch the serve task to rebuild llama-server from source '
|
||||
'(CUDA or HIP will be used if a toolchain is now available)."'
|
||||
)
|
||||
|
||||
|
||||
class ModelDownloadRequest(BaseModel):
|
||||
|
||||
+319
-73
@@ -37,7 +37,8 @@ from routes.cookbook_helpers import (
|
||||
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
|
||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||
_append_serve_exit_code_lines, _cached_model_scan_script,
|
||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache, _venv_safe_local_pip_install_cmd,
|
||||
ModelDownloadRequest, ServeRequest,
|
||||
)
|
||||
|
||||
@@ -148,6 +149,15 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"No GPUs are visible to the serve process.",
|
||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||
),
|
||||
(
|
||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||
"This machine may have integrated or unsupported graphics only.",
|
||||
[
|
||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||
"vLLM is not installed or not in PATH on this server.",
|
||||
@@ -163,6 +173,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||
),
|
||||
(
|
||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||
),
|
||||
(
|
||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||
"Diffusion serving requires PyTorch and diffusers.",
|
||||
@@ -368,11 +383,15 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
encoding="utf-8",
|
||||
)
|
||||
argv = [os.environ.get("ComSpec", "cmd.exe"), "/c", str(script_path)]
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUTF8"] = "1"
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
proc = subprocess.Popen(
|
||||
argv,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
stdin=subprocess.DEVNULL,
|
||||
env=env,
|
||||
**detached_popen_kwargs(),
|
||||
)
|
||||
pid_path.write_text(str(proc.pid), encoding="utf-8")
|
||||
@@ -432,12 +451,12 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# throughput. Retries set disable_hf_transfer to fall back to the plain,
|
||||
# slower-but-reliable downloader (resumes cleanly from the .incomplete files).
|
||||
# Use `python3 -m pip` not `pip` — macOS has no bare `pip` command.
|
||||
lines.append("command -v hf >/dev/null 2>&1 || python3 -m pip install --user --break-system-packages -q -U huggingface_hub 2>/dev/null || python3 -m pip install -q -U huggingface_hub 2>/dev/null")
|
||||
lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', upgrade=True)}")
|
||||
if req.disable_hf_transfer:
|
||||
lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||
lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=4")
|
||||
else:
|
||||
lines.append("python3 -c 'import hf_transfer' 2>/dev/null || python3 -m pip install --user --break-system-packages -q hf_transfer 2>/dev/null || python3 -m pip install -q hf_transfer 2>/dev/null")
|
||||
lines.append(f"python3 -c 'import hf_transfer' 2>/dev/null || {_pip_install_fallback_chain('hf_transfer')}")
|
||||
lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
|
||||
lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
|
||||
|
||||
@@ -531,12 +550,18 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
)
|
||||
# Ensure pip-user scripts (e.g. hf CLI installed via --user) are on PATH
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
# Install hf CLI + hf_transfer best-effort so future runs get the fast path.
|
||||
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
||||
# hf_transfer because the Rust parallel path is fast but has been
|
||||
# flaky near the end of very large multi-file downloads.
|
||||
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
|
||||
runner_lines.append("command -v hf >/dev/null 2>&1 || pip install --user --break-system-packages -q -U huggingface_hub 2>/dev/null || pip install -q -U huggingface_hub 2>/dev/null")
|
||||
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null || pip install --user --break-system-packages -q hf_transfer 2>/dev/null || pip install -q hf_transfer 2>/dev/null")
|
||||
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
|
||||
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
|
||||
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
||||
if req.disable_hf_transfer:
|
||||
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=4")
|
||||
else:
|
||||
runner_lines.append(f"python3 -c 'import hf_transfer' 2>/dev/null || {_pip_install_fallback_chain('hf_transfer', python_cmd='pip')}")
|
||||
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
|
||||
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
|
||||
# Surface whether the HF token actually reached THIS server, so a gated
|
||||
# download's "not authorized" failure can be told apart from a missing
|
||||
# token (the token is masked — we only print applied / not-set).
|
||||
@@ -547,15 +572,19 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(f' {hf_cmd} < /dev/null')
|
||||
runner_lines.append('elif python3 -c "import huggingface_hub" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "hf CLI not found, using Python huggingface_hub..."')
|
||||
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers=8)"')
|
||||
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers={4 if req.disable_hf_transfer else 8})"')
|
||||
runner_lines.append('else')
|
||||
runner_lines.append(' echo "Installing huggingface-hub and dependencies..."')
|
||||
runner_lines.append(' pip install --no-deps -q huggingface-hub 2>/dev/null')
|
||||
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests hf_transfer 2>/dev/null')
|
||||
runner_lines.append(" python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
|
||||
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers=8)"')
|
||||
if req.disable_hf_transfer:
|
||||
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests 2>/dev/null')
|
||||
runner_lines.append(' export HF_HUB_ENABLE_HF_TRANSFER=0')
|
||||
else:
|
||||
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests hf_transfer 2>/dev/null')
|
||||
runner_lines.append(" python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
|
||||
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers={4 if req.disable_hf_transfer else 8})"')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
|
||||
runner_lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
|
||||
runner_lines.append(f"rm -f {remote_runner}")
|
||||
runner_lines.append('exec "${SHELL:-/bin/bash}"')
|
||||
runner_path = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
||||
@@ -586,11 +615,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Detached path: no controlling TTY, so skip `< /dev/null`
|
||||
# (handled by Popen stdin=DEVNULL) and don't keep a shell open.
|
||||
lines.append(hf_cmd)
|
||||
lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
|
||||
lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
|
||||
else:
|
||||
# < /dev/null suppresses interactive "update available? [Y/n]" prompt
|
||||
lines.append(f"{hf_cmd} < /dev/null")
|
||||
lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
|
||||
lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
|
||||
lines.append(f"rm -f '{wrapper_script}'")
|
||||
lines.append('exec "${SHELL:-/bin/bash}"')
|
||||
wrapper_script.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
@@ -672,11 +701,14 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
else:
|
||||
# LOCAL scan: run the interpreter directly. `python3` isn't a thing on
|
||||
# Windows (it's `python`/`py`), and shell single-quoting of the path
|
||||
# doesn't survive cmd.exe — so resolve the interpreter and exec it
|
||||
# with the script path as an argv element (no shell quoting needed).
|
||||
local_py = (
|
||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||
# running under) — it's guaranteed real Python on all platforms.
|
||||
# Falling back to which_tool on Windows risks hitting the Microsoft
|
||||
# Store stub alias for "python3"/"python", which prints
|
||||
# "Python was not found; run without arguments to install from the
|
||||
# Microsoft Store" and exits 9009, producing empty stdout and a
|
||||
# JSON parse error. sys.executable bypasses PATH entirely.
|
||||
local_py = sys.executable or (
|
||||
which_tool("python3") or which_tool("python")
|
||||
or which_tool("py") or "python"
|
||||
)
|
||||
@@ -714,6 +746,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
entry["backend"] = m.get("backend")
|
||||
if m.get("is_ollama"):
|
||||
entry["is_ollama"] = True
|
||||
if isinstance(m.get("gguf_files"), list):
|
||||
entry["gguf_files"] = m["gguf_files"]
|
||||
models.append(entry)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached models: {e}")
|
||||
@@ -775,6 +809,80 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _auto_register_llm_endpoint(req: ServeRequest, remote: str | None) -> str | None:
|
||||
"""Register a freshly-served LLM as a model endpoint so it appears in the
|
||||
model picker without a manual /setup step — the text-model sibling of
|
||||
_auto_register_image_endpoint.
|
||||
|
||||
Cookbook serve commands launch an OpenAI-compatible server (llama.cpp's
|
||||
llama-server, vLLM, SGLang, or Ollama) on a known port. We point an
|
||||
endpoint at that server's /v1; the picker auto-discovers the model id by
|
||||
probing /v1/models and dims the endpoint until the server is reachable,
|
||||
so registering immediately (before the server finishes loading) is safe.
|
||||
"""
|
||||
import re
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
|
||||
# Port: an explicit --port wins. Otherwise fall back by backend — Ollama
|
||||
# is the only server in our generated commands that omits --port.
|
||||
port_match = re.search(r'--port\s+(\d+)', req.cmd)
|
||||
if port_match:
|
||||
port = int(port_match.group(1))
|
||||
elif "ollama" in req.cmd:
|
||||
port = 11434
|
||||
else:
|
||||
port = 8080 # llama.cpp's llama-server default — the Apple Silicon path
|
||||
|
||||
# Determine host (mirrors the image path: SSH alias for remote serves).
|
||||
if remote:
|
||||
host = remote.split("@")[-1] if "@" in remote else remote
|
||||
else:
|
||||
host = "localhost"
|
||||
|
||||
base_url = f"http://{host}:{port}/v1"
|
||||
|
||||
short_name = req.repo_id.split("/")[-1] if "/" in req.repo_id else req.repo_id
|
||||
display_name = short_name or "Local model"
|
||||
|
||||
# If the serve command opts models into OpenAI tool-calling, record it so
|
||||
# agent_loop trusts emitted tool_calls instead of the name heuristic.
|
||||
supports_tools = True if "--enable-auto-tool-choice" in req.cmd else None
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Reuse an endpoint already pointed at this URL instead of duplicating.
|
||||
existing = db.query(ModelEndpoint).filter(ModelEndpoint.base_url == base_url).first()
|
||||
if existing:
|
||||
existing.is_enabled = True
|
||||
existing.model_type = "llm"
|
||||
existing.name = display_name
|
||||
if supports_tools is not None:
|
||||
existing.supports_tools = supports_tools
|
||||
db.commit()
|
||||
logger.info(f"Updated existing local model endpoint: {base_url}")
|
||||
return existing.id
|
||||
|
||||
ep_id = f"local-{uuid.uuid4().hex[:8]}"
|
||||
ep = ModelEndpoint(
|
||||
id=ep_id,
|
||||
name=display_name,
|
||||
base_url=base_url,
|
||||
api_key=None,
|
||||
is_enabled=True,
|
||||
model_type="llm",
|
||||
supports_tools=supports_tools,
|
||||
)
|
||||
db.add(ep)
|
||||
db.commit()
|
||||
logger.info(f"Auto-registered local model endpoint: {display_name} @ {base_url}")
|
||||
return ep_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to auto-register local model endpoint: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/api/model/serve")
|
||||
async def model_serve(request: Request, req: ServeRequest):
|
||||
"""Launch a model server in a tmux session (or PowerShell background process on Windows).
|
||||
@@ -800,8 +908,17 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# many downstream `"engine" in req.cmd` membership checks can't hit
|
||||
# `TypeError: argument of type 'NoneType'` (a 500 instead of a clean 400).
|
||||
req.cmd = _validate_serve_cmd(req.cmd) or ""
|
||||
req.cmd = _venv_safe_local_pip_install_cmd(
|
||||
req.cmd,
|
||||
local=not bool(req.remote_host),
|
||||
in_venv=sys.prefix != sys.base_prefix,
|
||||
)
|
||||
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
||||
if is_pip_install:
|
||||
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||
# and leave the dep installed-but-unusable (#1459).
|
||||
req.cmd = _pip_install_no_cache(req.cmd)
|
||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||
# v2 review HIGH-14: tightened from the previous regex which
|
||||
@@ -922,7 +1039,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install llama-cpp-python --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||
@@ -944,61 +1061,45 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
# Detect pip-installed nvcc (from vLLM/nvidia CUDA wheels) and put
|
||||
# it on PATH so cmake's CUDA configure can find it. We check the
|
||||
# same three layouts as entrypoint.sh:
|
||||
# nvidia/cu13 — nvidia-nvcc-cu13
|
||||
# nvidia/cu12 — nvidia-nvcc-cu12
|
||||
# nvidia/cuda_nvcc — nvidia-cuda-nvcc-cu12 (sub-package style)
|
||||
runner_lines.append(' for _cudir in ~/.local/lib/python*/site-packages/nvidia/cu13 ~/.local/lib/python*/site-packages/nvidia/cu12 ~/.local/lib/python*/site-packages/nvidia/cuda_nvcc; do')
|
||||
runner_lines.append(' [ -x "$_cudir/bin/nvcc" ] && export CUDA_HOME="$_cudir" && export PATH="$_cudir/bin:$PATH" && break')
|
||||
runner_lines.append(' done')
|
||||
# rm -rf build so a prior poisoned CMakeCache.txt (e.g. from a
|
||||
# failed CUDA attempt) doesn't cause the next configure to reuse
|
||||
# stale settings and silently produce a CPU-only binary.
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build')
|
||||
runner_lines.append(' if command -v nvcc &>/dev/null; then')
|
||||
runner_lines.append(' echo "[odysseus] CUDA nvcc found — building llama-server with CUDA (GPU) support..."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=ON \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
runner_lines.append(' echo "[odysseus] WARNING: nvcc not found — building llama-server for CPU only."')
|
||||
runner_lines.append(' echo "[odysseus] GPU inference will not be available for this llama.cpp build."')
|
||||
runner_lines.append(' echo "[odysseus] To get a GPU build, first install vLLM via Cookbook -> Dependencies"')
|
||||
runner_lines.append(' echo "[odysseus] (its CUDA wheels include nvcc), then re-launch this serve task."')
|
||||
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' fi')
|
||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||
runner_lines.append(' pip install --user --break-system-packages -q llama-cpp-python 2>/dev/null || pip install -q llama-cpp-python 2>/dev/null || true')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('fi')
|
||||
elif "ollama" in req.cmd:
|
||||
handled_ollama_serve = True
|
||||
_ollama_port = "11434"
|
||||
_ollama_match = re.search(r"OLLAMA_HOST=[^\s:]+:(\d+)", req.cmd)
|
||||
if _ollama_match:
|
||||
_ollama_port = _ollama_match.group(1)
|
||||
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||
_ollama_host, _ollama_port = _ollama_bind_from_cmd(
|
||||
req.cmd,
|
||||
default_host=_ollama_default_host,
|
||||
)
|
||||
# Ollama can be a host binary, a system service, or a Docker
|
||||
# container. If the HTTP API is already reachable, the model is
|
||||
# already served and we should not require a host `ollama` CLI.
|
||||
runner_lines.append(f'ODYSSEUS_OLLAMA_HOST={_bash_squote(_ollama_host)}')
|
||||
runner_lines.append(f'ODYSSEUS_OLLAMA_PORT="{_ollama_port}"')
|
||||
runner_lines.append('ODYSSEUS_OLLAMA_URL=""')
|
||||
runner_lines.append('for _ody_ollama_port in "$ODYSSEUS_OLLAMA_PORT" 11434; do')
|
||||
runner_lines.append(' [ -z "$_ody_ollama_port" ] && continue')
|
||||
runner_lines.append(' for _ody_ollama_host in 127.0.0.1 localhost host.docker.internal; do')
|
||||
runner_lines.append(' _ody_ollama_url="http://${_ody_ollama_host}:${_ody_ollama_port}"')
|
||||
runner_lines.append(' if curl -sf "$_ody_ollama_url/api/tags" >/dev/null 2>&1; then')
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_URL="$_ody_ollama_url"')
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_ollama_port"')
|
||||
runner_lines.append(' break 2')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('for _ody_ollama_try in $(seq 1 20); do')
|
||||
runner_lines.append(' for _ody_ollama_port in "$ODYSSEUS_OLLAMA_PORT" 11434; do')
|
||||
runner_lines.append(' [ -z "$_ody_ollama_port" ] && continue')
|
||||
runner_lines.append(' for _ody_ollama_host in 127.0.0.1 localhost host.docker.internal; do')
|
||||
runner_lines.append(' _ody_ollama_url="http://${_ody_ollama_host}:${_ody_ollama_port}"')
|
||||
runner_lines.append(' if curl -sf "$_ody_ollama_url/api/tags" >/dev/null 2>&1; then')
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_URL="$_ody_ollama_url"')
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_ollama_port"')
|
||||
runner_lines.append(' break 3')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' done')
|
||||
runner_lines.append(' done')
|
||||
runner_lines.append(' [ "$_ody_ollama_try" -eq 1 ] && echo "[odysseus] Waiting for an existing Ollama API on ports ${ODYSSEUS_OLLAMA_PORT}/11434..."')
|
||||
runner_lines.append(' sleep 1')
|
||||
runner_lines.append('done')
|
||||
runner_lines.append('if [ -n "$ODYSSEUS_OLLAMA_URL" ]; then')
|
||||
runner_lines.append(' if [ "$ODYSSEUS_OLLAMA_PORT" != "' + _ollama_port + '" ]; then')
|
||||
@@ -1015,8 +1116,12 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
||||
runner_lines.append(' exec bash -i')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('echo "Starting ollama server on 0.0.0.0:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||
runner_lines.append('OLLAMA_HOST="0.0.0.0:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
||||
if remote and _ollama_host in ("0.0.0.0", "::"):
|
||||
runner_lines.append('echo "[odysseus] WARNING: remote Ollama will bind to ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT} so Odysseus can reach it from this host."')
|
||||
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
||||
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||
runner_lines.append('_ody_exit=$?')
|
||||
runner_lines.append('echo')
|
||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||
@@ -1032,19 +1137,24 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! command -v vllm &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed. Open Cookbook -> Dependencies and install vllm on this server, then launch again."')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
elif "sglang.launch_server" in req.cmd:
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! python3 -c "import sglang" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: SGLang is not installed. Open Cookbook -> Dependencies and install sglang on this server, then launch again."')
|
||||
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: SGLang is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('elif ! ODYSSEUS_SGLANG_IMPORT_ERROR="$(python3 -c "import sglang" 2>&1)"; then')
|
||||
runner_lines.append(' echo "ERROR: SGLang is installed but failed to import."')
|
||||
runner_lines.append(' printf "%s\\n" "$ODYSSEUS_SGLANG_IMPORT_ERROR"')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
elif "scripts/diffusion_server.py" in req.cmd or ".diffusion_server.py" in req.cmd:
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! python3 -c "import torch, diffusers" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: Diffusion serving requires PyTorch + diffusers. Open Cookbook -> Dependencies and install diffusers on this server, then launch again."')
|
||||
runner_lines.append('if ! ODYSSEUS_DIFFUSION_IMPORT_ERROR="$(python3 -c "import torch, diffusers" 2>&1)"; then')
|
||||
runner_lines.append(' echo "ERROR: Diffusion serving requires PyTorch + diffusers."')
|
||||
runner_lines.append(' printf "%s\\n" "$ODYSSEUS_DIFFUSION_IMPORT_ERROR"')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
|
||||
@@ -1116,11 +1226,16 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
stderr = (await proc.stderr.read()).decode(errors="replace")
|
||||
return {"ok": False, "error": stderr, "session_id": session_id}
|
||||
|
||||
# Auto-register as model endpoint if serving a diffusion model
|
||||
# Auto-register a model endpoint so the served model shows up in the model
|
||||
# picker with no manual /setup step. Diffusion models get an image
|
||||
# endpoint; any other real model serve (i.e. not a pip-install task) gets
|
||||
# a local LLM endpoint pointed at its /v1.
|
||||
endpoint_id = None
|
||||
is_diffusion = "diffusion_server.py" in req.cmd
|
||||
if is_diffusion:
|
||||
endpoint_id = _auto_register_image_endpoint(req, remote)
|
||||
elif not is_pip_install:
|
||||
endpoint_id = _auto_register_llm_endpoint(req, remote)
|
||||
|
||||
# Log to assistant
|
||||
try:
|
||||
@@ -1357,9 +1472,16 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
total_mb = max(0, int(total_bytes / (1024 * 1024)))
|
||||
used_mb = max(0, min(total_mb, int(used_bytes / (1024 * 1024))))
|
||||
free_mb = max(0, total_mb - used_mb)
|
||||
# GTT = the system-RAM pool the GPU pages into when VRAM is full.
|
||||
# On a discrete card a large gtt_used means the model spilled past
|
||||
# VRAM into RAM over PCIe — much slower. Surface it so the UI can
|
||||
# warn "spilling to RAM" instead of the user wondering why it's slow.
|
||||
gtt_used_raw = await _gpu_read_file(f"{base}/mem_info_gtt_used", host, ssh_port)
|
||||
gtt_used_mb = max(0, int(int(gtt_used_raw) / (1024 * 1024))) if (gtt_used_raw and gtt_used_raw.isdigit()) else 0
|
||||
gpus.append({
|
||||
"index": len(gpus), "name": name, "uuid": entry,
|
||||
"free_mb": free_mb, "total_mb": total_mb, "used_mb": used_mb,
|
||||
"gtt_used_mb": gtt_used_mb,
|
||||
"util_pct": 0, "busy": bool(total_mb and (free_mb / total_mb) < 0.85),
|
||||
"processes": [], "backend": "rocm", "source": "amd-sysfs",
|
||||
"unified_memory": unified,
|
||||
@@ -1461,6 +1583,46 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
if gpus:
|
||||
return {"ok": True, "gpus": gpus, "backend": "cuda", "source": "nvidia-smi"}
|
||||
|
||||
# Local Apple Silicon / Metal fallback. macOS has no nvidia-smi and no
|
||||
# Linux /sys/class/drm tree, but services.hwfit.hardware already knows
|
||||
# how to size the shared unified-memory GPU budget. Keep this route in
|
||||
# sync so Cookbook's GPU picker doesn't show "nvidia-smi not found" on
|
||||
# native Mac launches.
|
||||
if not host and sys.platform == "darwin":
|
||||
try:
|
||||
from services.hwfit.hardware import detect_system
|
||||
info = detect_system(fresh=True)
|
||||
backend = str(info.get("backend") or "").lower()
|
||||
if backend in {"metal", "mps", "apple"} and info.get("gpu_count", 0) > 0:
|
||||
total_mb = int(float(info.get("gpu_vram_gb") or info.get("total_ram_gb") or 0) * 1024)
|
||||
free_mb = int(float(info.get("available_ram_gb") or 0) * 1024)
|
||||
if total_mb and (free_mb <= 0 or free_mb > total_mb):
|
||||
free_mb = total_mb
|
||||
used_mb = max(0, total_mb - max(0, free_mb))
|
||||
return {
|
||||
"ok": True,
|
||||
"gpus": [{
|
||||
"index": 0,
|
||||
"name": info.get("gpu_name") or info.get("cpu_name") or "Apple Silicon GPU",
|
||||
"uuid": "apple-metal-0",
|
||||
"free_mb": max(0, free_mb),
|
||||
"total_mb": max(0, total_mb),
|
||||
"used_mb": used_mb,
|
||||
"util_pct": 0,
|
||||
"busy": bool(total_mb and (free_mb / total_mb) < 0.5),
|
||||
"processes": [],
|
||||
"backend": "metal",
|
||||
"source": "apple-metal",
|
||||
"unified_memory": True,
|
||||
}],
|
||||
"backend": "metal",
|
||||
"source": "apple-metal",
|
||||
"fallback_from": "nvidia-smi",
|
||||
"nvidia_error": nvidia_error,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning("Apple Metal GPU fallback failed: %s", e)
|
||||
|
||||
amd_gpus = await _probe_amd_sysfs(host, ssh_port)
|
||||
if amd_gpus:
|
||||
return {
|
||||
@@ -1607,6 +1769,33 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
|
||||
disk_tasks = on_disk.get("tasks") or [] if isinstance(on_disk, dict) else []
|
||||
incoming_tasks = data.get("tasks") if isinstance(data.get("tasks"), list) else []
|
||||
# Anti-poisoning guard: a stale browser tab can keep POSTing a
|
||||
# download task as status='done' from before the strict-finish
|
||||
# fix landed, undoing any server-side correction. For each
|
||||
# incoming "done" download, override to "running" if the last
|
||||
# shard pattern says N<total AND no DOWNLOAD_OK/DOWNLOAD_FAILED/
|
||||
# /snapshots/ sentinel is in the output.
|
||||
import re as _re_dl
|
||||
for _it in incoming_tasks:
|
||||
if (not isinstance(_it, dict)) or _it.get("type") != "download" or _it.get("status") != "done":
|
||||
continue
|
||||
_out = _it.get("output") or ""
|
||||
if ("DOWNLOAD_OK" in _out) or ("DOWNLOAD_FAILED" in _out) or ("/snapshots/" in _out):
|
||||
continue
|
||||
_shards = _re_dl.findall(r"model-(\d+)-of-(\d+)\.safetensors", _out)
|
||||
if _shards:
|
||||
_n, _tot = _shards[-1]
|
||||
if int(_n) < int(_tot):
|
||||
logger.info(f"cookbook state POST: rejecting stale done for {_it.get('sessionId')} "
|
||||
f"(last shard {_n}/{_tot}, no DOWNLOAD_OK)")
|
||||
_it["status"] = "running"
|
||||
else:
|
||||
_completed = _out.count("Download complete")
|
||||
_starts = _out.count("Downloading '")
|
||||
if _starts > _completed:
|
||||
logger.info(f"cookbook state POST: rejecting stale done for {_it.get('sessionId')} "
|
||||
f"({_completed}/{_starts} files complete, no DOWNLOAD_OK)")
|
||||
_it["status"] = "running"
|
||||
incoming_ids = {t.get("sessionId") for t in incoming_tasks if isinstance(t, dict) and t.get("sessionId")}
|
||||
import time as _t
|
||||
now_ms = int(_t.time() * 1000)
|
||||
@@ -1763,6 +1952,43 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
def _cookbook_tasks_status_sync():
|
||||
import subprocess
|
||||
|
||||
def _download_cache_complete(repo_id: str, remote_host: str = "", ssh_port: str = "") -> bool:
|
||||
"""Best-effort check for a completed HF cache entry.
|
||||
|
||||
tmux output can stop at a stale progress line if the pane/session
|
||||
disappears before Cookbook captures the final DOWNLOAD_OK marker.
|
||||
In that case, trust the cache shape: a snapshot directory with files
|
||||
and no *.incomplete blobs means HuggingFace finished materializing the
|
||||
model.
|
||||
"""
|
||||
if not repo_id or "/" not in repo_id:
|
||||
return False
|
||||
py = (
|
||||
"import os,sys;"
|
||||
"repo=sys.argv[1];"
|
||||
"base=os.environ.get('HUGGINGFACE_HUB_CACHE') or os.path.join(os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')), 'hub');"
|
||||
"d=os.path.join(base,'models--'+repo.replace('/','--'));"
|
||||
"snap=os.path.join(d,'snapshots');"
|
||||
"ok=os.path.isdir(snap) and any(os.path.isdir(os.path.join(snap,x)) and os.listdir(os.path.join(snap,x)) for x in os.listdir(snap));"
|
||||
"inc=False;"
|
||||
"blobs=os.path.join(d,'blobs');"
|
||||
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
||||
"sys.exit(0 if ok and not inc else 1)"
|
||||
)
|
||||
cmd = ["python3", "-c", py, repo_id]
|
||||
try:
|
||||
if remote_host:
|
||||
ssh_base = ["ssh"]
|
||||
if ssh_port and ssh_port != "22":
|
||||
ssh_base.extend(["-p", str(ssh_port)])
|
||||
shell_cmd = " ".join(shlex.quote(x) for x in cmd)
|
||||
proc = subprocess.run(ssh_base + [remote_host, shell_cmd], timeout=12, capture_output=True)
|
||||
else:
|
||||
proc = subprocess.run(cmd, timeout=12, capture_output=True)
|
||||
return proc.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Load saved tasks from cookbook state
|
||||
tasks = []
|
||||
if _cookbook_state_path.exists():
|
||||
@@ -1902,14 +2128,21 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# persists after the process exits, so a finished download still has a
|
||||
# snapshot to classify (DOWNLOAD_OK / exit marker) — evaluate it even
|
||||
# when the PID is gone instead of blindly reporting "stopped".
|
||||
download_zero_files = False
|
||||
status = "unknown"
|
||||
if is_alive or (local_win_task and full_snapshot):
|
||||
lower = full_snapshot.lower()
|
||||
has_exit = "=== process exited with code" in lower
|
||||
exit_match = re.search(r"=== process exited with code\s+(-?\d+)", full_snapshot, re.I)
|
||||
has_exit = exit_match is not None
|
||||
exit_code = int(exit_match.group(1)) if exit_match else None
|
||||
has_error = "error" in lower or "failed" in lower or "traceback" in lower
|
||||
if has_exit and task_type == "serve":
|
||||
# Serve tasks that exit are always errors — they should run indefinitely
|
||||
status = "error"
|
||||
elif has_exit and task_type == "download":
|
||||
# Dependency installs are tracked as download tasks but only
|
||||
# emit the generic runner exit marker, not HF download markers.
|
||||
status = "completed" if exit_code == 0 else "error"
|
||||
elif has_exit and "unrecognized arguments" in lower:
|
||||
status = "error"
|
||||
elif has_error and not ("application startup complete" in lower):
|
||||
@@ -1918,7 +2151,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Only download tasks treat 100% as "completed".
|
||||
# Serve tasks log 100%|██████| during inference progress
|
||||
# (diffusion sampling, etc.) — that's "running", not done.
|
||||
status = "completed"
|
||||
if re.search(r"Fetching\s+0\s+files", full_snapshot, re.IGNORECASE):
|
||||
status = "error"
|
||||
download_zero_files = True
|
||||
else:
|
||||
status = "completed"
|
||||
elif "application startup complete" in lower:
|
||||
status = "ready"
|
||||
elif not is_alive:
|
||||
@@ -1928,7 +2165,14 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
status = "running"
|
||||
else:
|
||||
# Session is dead — check if it completed or crashed
|
||||
status = "stopped"
|
||||
if task_type == "download" and _download_cache_complete(_payload.get("repo_id") or model, remote, str(_tport or "")):
|
||||
status = "completed"
|
||||
if not progress_text:
|
||||
progress_text = "Download complete"
|
||||
if not full_snapshot:
|
||||
full_snapshot = "DOWNLOAD_OK"
|
||||
else:
|
||||
status = "stopped"
|
||||
|
||||
# Parse structured phase info — single source of truth for the UI
|
||||
phase_info = _parse_serve_phase(full_snapshot, task_type) if (task_type == "serve" and status == "running" and full_snapshot) else {}
|
||||
@@ -1938,6 +2182,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
diagnosis = _diagnose_serve_output(full_snapshot) if task_type == "serve" and full_snapshot else None
|
||||
if diagnosis and status in {"running", "unknown", "stopped"}:
|
||||
status = "error"
|
||||
if download_zero_files:
|
||||
diagnosis = {"message": "No matching files were downloaded. The model repo or filename/quant pattern may be wrong (for example a ':Q4_K_M' tag that does not exist in the repo). Check the repo and the include/quant pattern."}
|
||||
output_tail = "\n".join(full_snapshot.splitlines()[-12:]) if full_snapshot else ""
|
||||
|
||||
results.append({
|
||||
|
||||
@@ -152,7 +152,7 @@ def _resolve_user_upload_path(
|
||||
owner=owner,
|
||||
auth_manager=auth_manager,
|
||||
)
|
||||
if not resolved:
|
||||
if not isinstance(resolved, dict) or not resolved:
|
||||
return None
|
||||
path = resolved.get("path")
|
||||
upload_dir = getattr(upload_handler, "upload_dir", None)
|
||||
@@ -203,6 +203,8 @@ def _assert_pdf_marker_upload_owned(
|
||||
def _derive_title(content: str) -> str:
|
||||
"""Derive a title from document content."""
|
||||
import re
|
||||
if not isinstance(content, str):
|
||||
return "Untitled"
|
||||
text = content.strip()
|
||||
if not text:
|
||||
return "Untitled"
|
||||
|
||||
@@ -15,6 +15,21 @@ from src.auth_helpers import get_current_user
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _aggregate_language_facets(lang_rows):
|
||||
"""Sum document counts per display language for the library facet.
|
||||
|
||||
NULL-language and explicit "text" rows share the "text" bucket (the
|
||||
language filter treats them as one), so they must be ADDED. The old dict
|
||||
comprehension keyed both to "text", silently overwriting one group and
|
||||
undercounting the facet versus what the filter actually returns.
|
||||
"""
|
||||
out = {}
|
||||
for lang, cnt in lang_rows:
|
||||
key = lang or "text"
|
||||
out[key] = out.get(key, 0) + cnt
|
||||
return out
|
||||
|
||||
|
||||
|
||||
from routes.document_helpers import (
|
||||
DocumentCreate, DocumentUpdate, DocumentPatch,
|
||||
@@ -145,7 +160,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
create_form_markdown_document,
|
||||
create_plain_pdf_document,
|
||||
)
|
||||
from src.document_processor import _process_pdf
|
||||
from src.document_processor import _process_pdf, strip_pdf_content_marker
|
||||
import os
|
||||
|
||||
from src.auth_helpers import require_privilege
|
||||
@@ -184,7 +199,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
|
||||
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
||||
try:
|
||||
body_text = _process_pdf(pdf_path).lstrip("\n[PDF content]:").strip()
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
except Exception:
|
||||
body_text = None
|
||||
|
||||
@@ -258,7 +273,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
)
|
||||
lang_q = _owner_session_filter(lang_q, user)
|
||||
lang_rows = lang_q.group_by(Document.language).all()
|
||||
languages = {lang or "text": cnt for lang, cnt in lang_rows}
|
||||
languages = _aggregate_language_facets(lang_rows)
|
||||
|
||||
# Session count (owner-filtered)
|
||||
sc_q = (
|
||||
@@ -402,7 +417,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
text extraction was wired, plus for scanned/image-only PDFs where the
|
||||
VL model picks up text the basic pypdf path missed."""
|
||||
import re
|
||||
from src.document_processor import _process_pdf
|
||||
from src.document_processor import _process_pdf, strip_pdf_content_marker
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
|
||||
user = get_current_user(request)
|
||||
@@ -423,7 +438,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
raise HTTPException(404, "Source PDF could not be located")
|
||||
|
||||
try:
|
||||
body_text = _process_pdf(pdf_path).lstrip("\n[PDF content]:").strip()
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
except Exception as e:
|
||||
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
||||
raise HTTPException(500, f"Extraction failed: {e}")
|
||||
@@ -593,6 +608,15 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
if req.session_id is not None:
|
||||
# Empty string = unlink from session
|
||||
doc.session_id = req.session_id if req.session_id else None
|
||||
if not req.session_id:
|
||||
# Tab closed / doc detached from its session — drop the
|
||||
# in-memory active-doc pointer so the last-resort injection
|
||||
# path doesn't re-surface this doc in a later chat (#1160).
|
||||
try:
|
||||
from src.tool_implementations import clear_active_document
|
||||
clear_active_document(doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
db.commit()
|
||||
db.refresh(doc)
|
||||
return _doc_to_dict(doc)
|
||||
@@ -615,6 +639,13 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
raise HTTPException(404, "Document not found")
|
||||
_verify_doc_owner(db, doc, user)
|
||||
doc.is_active = False
|
||||
# Closed/deleted — drop the in-memory active-doc pointer so it isn't
|
||||
# re-injected into a later, unrelated chat (#1160).
|
||||
try:
|
||||
from src.tool_implementations import clear_active_document
|
||||
clear_active_document(doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
db.commit()
|
||||
return {"status": "deleted", "id": doc_id}
|
||||
except HTTPException:
|
||||
@@ -885,7 +916,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
for i, doc in enumerate(batch):
|
||||
if i >= len(verdicts):
|
||||
break
|
||||
verdict = verdicts[i].lower().strip()
|
||||
verdict = str(verdicts[i] or "").lower().strip()
|
||||
if verdict == "junk":
|
||||
doc.tidy_verdict = "junk"
|
||||
db.delete(doc)
|
||||
|
||||
@@ -67,6 +67,14 @@ def _summary(d: EditorDraft) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _load_payload(raw: Optional[str]) -> Dict[str, Any]:
|
||||
try:
|
||||
payload = json.loads(raw) if raw else {}
|
||||
except Exception:
|
||||
return {}
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
|
||||
def setup_editor_draft_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["editor-drafts"])
|
||||
|
||||
@@ -93,13 +101,9 @@ def setup_editor_draft_routes() -> APIRouter:
|
||||
).first()
|
||||
if not d or not _owns(d, user):
|
||||
raise HTTPException(404, "Draft not found")
|
||||
try:
|
||||
payload = json.loads(d.payload) if d.payload else {}
|
||||
except Exception:
|
||||
payload = {}
|
||||
return {
|
||||
**_summary(d),
|
||||
"payload": payload,
|
||||
"payload": _load_payload(d.payload),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
+151
-62
@@ -15,7 +15,6 @@ and `email_pollers.py` (the background loops):
|
||||
import os
|
||||
import imaplib
|
||||
import smtplib
|
||||
import ssl
|
||||
import email as email_mod
|
||||
import email.header
|
||||
import email.utils
|
||||
@@ -33,47 +32,43 @@ from fastapi import Query, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import _auth_disabled, get_current_user
|
||||
from src.secret_storage import decrypt as _decrypt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message: str | bytes, timeout: int = 30) -> None:
|
||||
"""Send through SMTP using the conventional TLS mode for the configured port.
|
||||
def _smtp_security_mode(cfg: dict) -> str:
|
||||
raw = str(cfg.get("smtp_security") or "").strip().lower()
|
||||
if raw in {"ssl", "starttls", "none"}:
|
||||
return raw
|
||||
port = int(cfg.get("smtp_port") or 465)
|
||||
if port == 587:
|
||||
return "starttls"
|
||||
return "ssl"
|
||||
|
||||
Account settings only store host/port today. Port 465 is implicit TLS
|
||||
(SMTP_SSL); port 587 is plain SMTP upgraded with STARTTLS. Using SSL
|
||||
directly against 587 raises the classic "[SSL: WRONG_VERSION_NUMBER]"
|
||||
error even when credentials are correct.
|
||||
"""
|
||||
|
||||
def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message: str | bytes, timeout: int = 30) -> None:
|
||||
"""Send through SMTP using the configured transport security mode."""
|
||||
host = cfg["smtp_host"]
|
||||
port = int(cfg.get("smtp_port") or 465)
|
||||
user = cfg.get("smtp_user") or ""
|
||||
password = cfg.get("smtp_password") or ""
|
||||
def _send_starttls(starttls_port: int = 587) -> None:
|
||||
with smtplib.SMTP(host, starttls_port, timeout=timeout) as smtp:
|
||||
smtp.starttls()
|
||||
if user and password:
|
||||
smtp.login(user, password)
|
||||
smtp.sendmail(from_addr, recipients, message)
|
||||
security = _smtp_security_mode(cfg)
|
||||
|
||||
if port == 587:
|
||||
_send_starttls(587)
|
||||
return
|
||||
|
||||
try:
|
||||
if security == "ssl":
|
||||
with smtplib.SMTP_SSL(host, port, timeout=timeout) as smtp:
|
||||
if user and password:
|
||||
smtp.login(user, password)
|
||||
smtp.sendmail(from_addr, recipients, message)
|
||||
return
|
||||
except (TimeoutError, ssl.SSLError) as e:
|
||||
if port == 465:
|
||||
logger.warning("SMTP implicit TLS on %s:465 failed (%s); retrying STARTTLS on 587", host, e)
|
||||
_send_starttls(587)
|
||||
return
|
||||
raise
|
||||
|
||||
with smtplib.SMTP(host, port, timeout=timeout) as smtp:
|
||||
if security == "starttls":
|
||||
smtp.starttls()
|
||||
if user and password:
|
||||
smtp.login(user, password)
|
||||
smtp.sendmail(from_addr, recipients, message)
|
||||
|
||||
|
||||
def _strip_think(text: str) -> str:
|
||||
@@ -152,6 +147,8 @@ def _require_auth(request: Request) -> str:
|
||||
u = get_current_user(request)
|
||||
if u:
|
||||
return u
|
||||
if _auth_disabled():
|
||||
return ""
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
if auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
@@ -300,7 +297,8 @@ def _init_scheduled_db():
|
||||
send_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
error TEXT
|
||||
error TEXT,
|
||||
owner TEXT DEFAULT ''
|
||||
)
|
||||
""")
|
||||
# Email summary cache (keyed by Message-ID)
|
||||
@@ -438,6 +436,35 @@ def _init_scheduled_db():
|
||||
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN account_id TEXT")
|
||||
if "odysseus_kind" not in cols:
|
||||
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN odysseus_kind TEXT")
|
||||
if "owner" not in cols:
|
||||
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN owner TEXT DEFAULT ''")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_scheduled_emails_owner_status ON scheduled_emails(owner, status)")
|
||||
# Backfill owner on legacy rows from the owning email account so the
|
||||
# owner-scoped list/cancel routes surface pre-migration scheduled
|
||||
# sends to the right user (the poller already resolves these by
|
||||
# account at send time; this aligns the UI with that).
|
||||
legacy_accounts = conn.execute(
|
||||
"SELECT DISTINCT account_id FROM scheduled_emails "
|
||||
"WHERE (owner IS NULL OR owner = '') AND account_id IS NOT NULL AND account_id != ''"
|
||||
).fetchall()
|
||||
if legacy_accounts:
|
||||
try:
|
||||
from core.database import SessionLocal as _SL, EmailAccount as _EA
|
||||
_db = _SL()
|
||||
try:
|
||||
for (acct_id,) in legacy_accounts:
|
||||
row = _db.query(_EA.owner).filter(_EA.id == acct_id).first()
|
||||
acct_owner = (row[0] or "") if row else ""
|
||||
if acct_owner:
|
||||
conn.execute(
|
||||
"UPDATE scheduled_emails SET owner = ? "
|
||||
"WHERE account_id = ? AND (owner IS NULL OR owner = '')",
|
||||
(acct_owner, acct_id),
|
||||
)
|
||||
finally:
|
||||
_db.close()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# Lazy migration: add turns_json to email_boundaries for server-side
|
||||
@@ -541,6 +568,7 @@ def _get_email_config(account_id: str | None = None, owner: str = "") -> dict:
|
||||
"account_name": row.name,
|
||||
"smtp_host": row.smtp_host or "",
|
||||
"smtp_port": int(row.smtp_port or 465),
|
||||
"smtp_security": _smtp_security_mode({"smtp_security": getattr(row, "smtp_security", ""), "smtp_port": row.smtp_port}),
|
||||
"smtp_user": row.smtp_user or "",
|
||||
"smtp_password": _decrypt(row.smtp_password or ""),
|
||||
"imap_host": row.imap_host or "",
|
||||
@@ -567,6 +595,10 @@ def _get_email_config(account_id: str | None = None, owner: str = "") -> dict:
|
||||
"account_name": "legacy",
|
||||
"smtp_host": settings.get("smtp_host", os.environ.get("SMTP_HOST", "")),
|
||||
"smtp_port": int(settings.get("smtp_port", os.environ.get("SMTP_PORT", "465")) or 465),
|
||||
"smtp_security": _smtp_security_mode({
|
||||
"smtp_security": settings.get("smtp_security", os.environ.get("SMTP_SECURITY", "")),
|
||||
"smtp_port": settings.get("smtp_port", os.environ.get("SMTP_PORT", "465")),
|
||||
}),
|
||||
"smtp_user": settings.get("smtp_user", os.environ.get("SMTP_USER", "")),
|
||||
"smtp_password": settings.get("smtp_password", os.environ.get("SMTP_PASSWORD", "")),
|
||||
"imap_host": settings.get("imap_host", os.environ.get("IMAP_HOST", "")),
|
||||
@@ -606,7 +638,32 @@ def _list_email_accounts() -> list[dict]:
|
||||
|
||||
# ── IMAP helpers ──
|
||||
|
||||
_IMAP_TIMEOUT_SECONDS = 15
|
||||
def _coerce_imap_timeout_seconds(raw: str | None) -> int:
|
||||
try:
|
||||
value = int(raw or "30")
|
||||
except (TypeError, ValueError):
|
||||
value = 30
|
||||
return max(5, min(value, 300))
|
||||
|
||||
|
||||
_IMAP_TIMEOUT_SECONDS = _coerce_imap_timeout_seconds(os.environ.get("ODYSSEUS_IMAP_TIMEOUT_SECONDS"))
|
||||
|
||||
|
||||
def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int = _IMAP_TIMEOUT_SECONDS):
|
||||
"""Open an IMAP connection using the configured security mode."""
|
||||
port = int(port or 993)
|
||||
if starttls:
|
||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||
conn.starttls()
|
||||
elif port == 993:
|
||||
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
||||
else:
|
||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||
try:
|
||||
conn.sock.settimeout(timeout)
|
||||
except Exception:
|
||||
pass
|
||||
return conn
|
||||
|
||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
# SECURITY: passing `owner` scopes the fallback config lookup so a brand
|
||||
@@ -620,17 +677,12 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
# The last branch is critical: previously this fell into IMAP4_SSL
|
||||
# for any non-STARTTLS port, which would fail the TLS handshake on
|
||||
# plain local servers (Dovecot on 31143, etc.).
|
||||
if cfg.get("imap_starttls"):
|
||||
conn = imaplib.IMAP4(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
|
||||
conn.starttls()
|
||||
elif int(cfg.get("imap_port") or 993) == 993:
|
||||
conn = imaplib.IMAP4_SSL(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
|
||||
else:
|
||||
conn = imaplib.IMAP4(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
|
||||
try:
|
||||
conn.sock.settimeout(_IMAP_TIMEOUT_SECONDS)
|
||||
except Exception:
|
||||
pass
|
||||
conn = _open_imap_connection(
|
||||
cfg["imap_host"],
|
||||
cfg["imap_port"],
|
||||
starttls=bool(cfg.get("imap_starttls")),
|
||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||
)
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
return conn
|
||||
|
||||
@@ -699,7 +751,13 @@ def _decode_header(raw):
|
||||
decoded = []
|
||||
for data, charset in parts:
|
||||
if isinstance(data, bytes):
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
try:
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
except (LookupError, ValueError):
|
||||
# Unknown/invalid MIME charset (e.g. a malformed or spam header
|
||||
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
|
||||
# byte-decode errors, not codec lookup, so fall back to utf-8.
|
||||
decoded.append(data.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(data)
|
||||
return " ".join(decoded)
|
||||
@@ -793,22 +851,27 @@ def _detect_spam_folder(conn):
|
||||
return None
|
||||
|
||||
|
||||
def _imap_move(uid, dest, src="INBOX"):
|
||||
def _imap_move(uid, dest, src="INBOX", account_id: str | None = None, owner: str = ""):
|
||||
"""Move a single IMAP UID from src folder to dest. Returns True on success."""
|
||||
c = None
|
||||
try:
|
||||
c = _imap_connect()
|
||||
c = _imap_connect(account_id, owner=owner)
|
||||
c.select(_q(src))
|
||||
status, _ = c.copy(uid, _q(dest))
|
||||
if status != "OK":
|
||||
c.logout()
|
||||
return False
|
||||
c.store(uid, "+FLAGS", "\\Deleted")
|
||||
c.expunge()
|
||||
c.logout()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"IMAP move {uid} → {dest} failed: {e}")
|
||||
return False
|
||||
finally:
|
||||
if c:
|
||||
try:
|
||||
c.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_attachment_text(msg, max_chars: int = 6000) -> str:
|
||||
@@ -999,7 +1062,9 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
exclude_folder: str = "INBOX",
|
||||
limit: int = 3,
|
||||
max_chars_per_email: int = 1500,
|
||||
max_attachment_chars: int = 4000) -> str:
|
||||
max_attachment_chars: int = 4000,
|
||||
account_id: str | None = None,
|
||||
owner: str = "") -> str:
|
||||
"""Pull the last N emails from `sender_addr` (across common folders),
|
||||
extract their body snippets + attachment text, and return one formatted
|
||||
block ready to be glued into an LLM system prompt as "REFERENCED MATERIAL".
|
||||
@@ -1021,7 +1086,7 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
||||
|
||||
try:
|
||||
conn = _imap_connect()
|
||||
conn = _imap_connect(account_id, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning(f"sender-thread-context: imap connect failed: {e}")
|
||||
return ""
|
||||
@@ -1104,7 +1169,12 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
return "\n\n=====\n\n".join(blocks)
|
||||
|
||||
|
||||
def _pre_retrieve_context(body: str, sender: str) -> tuple:
|
||||
def _pre_retrieve_context(
|
||||
body: str,
|
||||
sender: str,
|
||||
account_id: str | None = None,
|
||||
owner: str = "",
|
||||
) -> tuple:
|
||||
"""Extract key terms from an incoming email and search past emails + contacts.
|
||||
|
||||
Returns (context_snippets, terms_list). Best-effort; never raises.
|
||||
@@ -1128,18 +1198,37 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
|
||||
# ── Known-sender check: only retrieve context for senders we already
|
||||
# have a relationship with. New / cold senders get an empty context.
|
||||
sender_addr = email.utils.parseaddr(sender or "")[1].lower()
|
||||
is_known = False
|
||||
# The CardDAV address book is global admin data backed by a single
|
||||
# Radicale instance, so only fold it into reply context for an admin /
|
||||
# single-user owner. Non-admin owners still get their own (owner-scoped)
|
||||
# IMAP history below, just not the shared contacts.
|
||||
try:
|
||||
from routes.contacts_routes import _fetch_contacts
|
||||
for c in _fetch_contacts() or []:
|
||||
if (c.get("email") or "").lower() == sender_addr:
|
||||
is_known = True
|
||||
break
|
||||
from src.tool_security import owner_is_admin_or_single_user
|
||||
contacts_allowed = owner_is_admin_or_single_user(owner or None)
|
||||
except Exception:
|
||||
pass
|
||||
contacts_allowed = not bool(owner)
|
||||
is_known = False
|
||||
if contacts_allowed:
|
||||
try:
|
||||
from routes.contacts_routes import _fetch_contacts
|
||||
for c in _fetch_contacts() or []:
|
||||
# Contacts are normalized to plural `emails` lists, but
|
||||
# keep the legacy singular key fallback for older data.
|
||||
contact_emails = []
|
||||
raw_emails = c.get("emails")
|
||||
if isinstance(raw_emails, list):
|
||||
contact_emails.extend(str(e or "") for e in raw_emails)
|
||||
legacy_email = c.get("email")
|
||||
if legacy_email:
|
||||
contact_emails.append(str(legacy_email))
|
||||
if any((addr or "").strip().lower() == sender_addr for addr in contact_emails):
|
||||
is_known = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if not is_known and sender_addr:
|
||||
try:
|
||||
with _imap() as _ck:
|
||||
with _imap(account_id, owner=owner) as _ck:
|
||||
_ck.select("INBOX", readonly=True)
|
||||
st_known, dk = _ck.search(None, f'(FROM "{sender_addr}")')
|
||||
if st_known == "OK" and dk and dk[0]:
|
||||
@@ -1177,7 +1266,7 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
|
||||
return context_snippets, terms_list
|
||||
|
||||
try:
|
||||
ctx_conn = _imap_connect()
|
||||
ctx_conn = _imap_connect(account_id, owner=owner)
|
||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||
try:
|
||||
st_sel, _sd = ctx_conn.select(_q(folder), readonly=True)
|
||||
@@ -1221,18 +1310,18 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
|
||||
|
||||
try:
|
||||
from routes.contacts_routes import _fetch_contacts
|
||||
all_contacts = _fetch_contacts()
|
||||
all_contacts = _fetch_contacts() if contacts_allowed else []
|
||||
for term in terms_list:
|
||||
t_lower = term.lower()
|
||||
matches = [c for c in all_contacts
|
||||
if t_lower in (c.get("name") or "").lower()
|
||||
or t_lower in (c.get("email") or "").lower()]
|
||||
or any(t_lower in (e or "").lower() for e in (c.get("emails") or []))]
|
||||
for c in matches[:2]:
|
||||
parts = [f"Name: {c.get('name','')}"]
|
||||
if c.get("email"):
|
||||
parts.append(f"Email: {c['email']}")
|
||||
if c.get("phone"):
|
||||
parts.append(f"Phone: {c['phone']}")
|
||||
if c.get("emails"):
|
||||
parts.append(f"Email: {', '.join(c['emails'])}")
|
||||
if c.get("phones"):
|
||||
parts.append(f"Phone: {', '.join(c['phones'])}")
|
||||
context_snippets.append(f"[Contact match for \"{term}\"] " + ", ".join(parts))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
+98
-47
@@ -45,6 +45,21 @@ from routes.email_helpers import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _owner_for_email_account(account_id: str | None) -> str:
|
||||
if not account_id:
|
||||
return ""
|
||||
try:
|
||||
from core.database import SessionLocal as _SL, EmailAccount as _EA
|
||||
db = _SL()
|
||||
try:
|
||||
row = db.query(_EA.owner).filter(_EA.id == account_id).first()
|
||||
return (row[0] or "") if row else ""
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
# ── Routes ──
|
||||
|
||||
async def _emit_progress(progress_cb, message: str):
|
||||
@@ -84,6 +99,36 @@ async def _run_auto_summarize_once(do_summary: bool = True, do_reply: bool = Tru
|
||||
_save_settings(s2)
|
||||
|
||||
|
||||
def _latest_inbox_fallback_uids(conn, reconnect):
|
||||
"""Latest INBOX UIDs via ``SEARCH ALL``, with a poisoned-socket guard (#1613).
|
||||
|
||||
On a large Gmail mailbox the fallback ``SEARCH ALL`` can time out mid-reply,
|
||||
leaving its enormous ``* SEARCH <uids…>`` line unread on the socket. The next
|
||||
command (the downstream re-select / EXAMINE) then reads those leftover bytes
|
||||
and fails with ``EXAMINE => unexpected response: b'325188 …'``. Reconnecting
|
||||
on failure guarantees the downstream command starts from a clean socket.
|
||||
|
||||
Returns ``(uids, conn)`` — ``conn`` is the live connection to keep using: the
|
||||
same one on success, a fresh one (via ``reconnect()``) if we had to recover.
|
||||
"""
|
||||
try:
|
||||
conn.select("INBOX", readonly=True)
|
||||
status, data = conn.uid("SEARCH", None, "ALL")
|
||||
uids = []
|
||||
if status == "OK" and data and data[0]:
|
||||
for u in reversed(data[0].split()[-8:]):
|
||||
uids.append(("INBOX", u))
|
||||
logger.info("Email task SINCE scan found no messages; fell back to latest INBOX messages")
|
||||
return uids, conn
|
||||
except Exception as _e:
|
||||
logger.warning(f"Latest-INBOX fallback scan failed: {_e}")
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
return [], reconnect()
|
||||
|
||||
|
||||
async def _auto_summarize_pass(days_back: int = 1, account_id: str | None = None, progress_cb=None) -> str:
|
||||
"""Single pass of the auto-summarize/reply scan.
|
||||
|
||||
@@ -132,7 +177,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
import sqlite3 as _sql3
|
||||
import requests as _req
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import _uses_max_completion_tokens
|
||||
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
||||
|
||||
settings = _load_settings()
|
||||
auto_sum = settings.get("email_auto_summarize", False)
|
||||
@@ -143,25 +188,18 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
if not auto_sum and not auto_reply and not auto_tag and not auto_spam and not auto_cal:
|
||||
return "Nothing to do"
|
||||
|
||||
# Owner of the account being processed. All calendar reads/writes below are
|
||||
# scoped to this user: the multi-account fan-out runs every user's mailbox,
|
||||
# so an unscoped pass would disclose and mutate other tenants' calendars.
|
||||
_acct_owner = None
|
||||
try:
|
||||
from core.database import SessionLocal as _SLo, EmailAccount as _EAo
|
||||
_dbo = _SLo()
|
||||
try:
|
||||
if account_id:
|
||||
_arow = _dbo.query(_EAo).filter(_EAo.id == account_id).first()
|
||||
_acct_owner = _arow.owner if _arow else None
|
||||
finally:
|
||||
_dbo.close()
|
||||
except Exception:
|
||||
_acct_owner = None
|
||||
# Owner of the account being processed. All calendar + mailbox reads/writes
|
||||
# below are scoped to this user: the multi-account fan-out runs every user's
|
||||
# mailbox, so an unscoped pass would disclose/mutate other tenants' data.
|
||||
# One resolution feeds both the mailbox path (account_owner) and upstream's
|
||||
# calendar path (_acct_owner, which expects None rather than "").
|
||||
account_owner = _owner_for_email_account(account_id)
|
||||
_acct_owner = account_owner or None
|
||||
|
||||
conn = None
|
||||
try:
|
||||
await _emit_progress(progress_cb, "Connecting to mail…")
|
||||
conn = _imap_connect(account_id)
|
||||
conn = _imap_connect(account_id, owner=account_owner)
|
||||
from datetime import timedelta as _td
|
||||
since = (datetime.utcnow() - _td(days=max(1, days_back))).strftime("%d-%b-%Y")
|
||||
# uid_list carries real IMAP UIDs, matching the email UI/read routes.
|
||||
@@ -193,26 +231,27 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
# the latest visible inbox messages so Clear cache -> Run again can
|
||||
# actually repopulate AI reply/summary/tag caches.
|
||||
if not uid_list:
|
||||
try:
|
||||
conn.select("INBOX", readonly=True)
|
||||
status, data = conn.uid("SEARCH", None, "ALL")
|
||||
if status == "OK" and data and data[0]:
|
||||
for u in reversed(data[0].split()[-8:]):
|
||||
uid_list.append(("INBOX", u))
|
||||
logger.info("Email task SINCE scan found no messages; fell back to latest INBOX messages")
|
||||
except Exception as _e:
|
||||
logger.warning(f"Latest-INBOX fallback scan failed: {_e}")
|
||||
# Re-select INBOX as default for downstream code
|
||||
_fb_uids, conn = _latest_inbox_fallback_uids(
|
||||
conn, lambda: _imap_connect(account_id, owner=account_owner)
|
||||
)
|
||||
uid_list.extend(_fb_uids)
|
||||
# Re-select INBOX as default for downstream code (on a clean socket even
|
||||
# if the SEARCH ALL fallback above failed — see #1613).
|
||||
conn.select("INBOX", readonly=True)
|
||||
if not uid_list:
|
||||
conn.logout()
|
||||
return "No recent emails"
|
||||
await _emit_progress(progress_cb, f"Found {len(uid_list)} recent email(s); checking cache…")
|
||||
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_sum_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_summaries").fetchall()}
|
||||
_reply_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_ai_replies").fetchall()}
|
||||
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags").fetchall()} if (auto_tag or auto_spam) else set()
|
||||
if auto_tag or auto_spam:
|
||||
if account_owner:
|
||||
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner=?", (account_owner,)).fetchall()}
|
||||
else:
|
||||
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner='' OR owner IS NULL").fetchall()}
|
||||
else:
|
||||
_tag_existing = set()
|
||||
_cal_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_calendar_extractions").fetchall()} if auto_cal else set()
|
||||
# Urgency is handled by the built-in `check_email_urgency` task. Keep
|
||||
# this legacy poller path disabled so users don't get two independent
|
||||
@@ -225,7 +264,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
# this per-iteration was making big inbox scans crawl. Used by the
|
||||
# urgency self-loop check below.
|
||||
try:
|
||||
_self_self_addr = (_get_email_config(account_id).get("from_address") or "").strip().lower()
|
||||
_self_self_addr = (_get_email_config(account_id, owner=account_owner).get("from_address") or "").strip().lower()
|
||||
except Exception:
|
||||
_self_self_addr = ""
|
||||
|
||||
@@ -233,11 +272,10 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
if auto_spam and not spam_folder:
|
||||
logger.warning("Auto-spam enabled but no Junk/Spam folder detected — will classify but not move")
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=account_owner)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=account_owner)
|
||||
if not url or not model:
|
||||
conn.logout()
|
||||
return "No model configured"
|
||||
|
||||
writing_style = settings.get("email_writing_style", "")
|
||||
@@ -355,6 +393,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
"temperature": 0.3,
|
||||
"stream": False,
|
||||
}
|
||||
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
|
||||
if _restricts_temperature(model):
|
||||
payload.pop("temperature", None)
|
||||
try:
|
||||
# Use to_thread so this sync HTTP call doesn't freeze
|
||||
# the entire event loop while the LLM thinks (240s).
|
||||
@@ -392,8 +433,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
await _emit_progress(progress_cb, f"Drafting reply {processed + 1}/{_max_process} · checked {examined}/{len(uid_list)}")
|
||||
# Background reply drafting should not make the whole app
|
||||
# feel busy. Keep it lightweight: no extra IMAP context
|
||||
# mining here; manual AI Reply can still do that when the
|
||||
# user explicitly asks for a draft on one email.
|
||||
# mining here; manual AI Reply can still do that (owner-scoped)
|
||||
# when the user explicitly asks for a draft on one email.
|
||||
context_snippets, _terms = [], []
|
||||
sys_prompt = _EMAIL_REPLY_SYS_PROMPT_BASE
|
||||
if att_text:
|
||||
@@ -708,7 +749,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
# Send alert email immediately if critical or high
|
||||
if urgency in ("critical", "high"):
|
||||
try:
|
||||
cfg = _get_email_config(account_id)
|
||||
cfg = _get_email_config(account_id, owner=account_owner)
|
||||
to_addr = cfg["from_address"] # self-email
|
||||
|
||||
# Deep-link to open the original email in Odysseus (if public URL is configured).
|
||||
@@ -716,8 +757,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
from src.settings import load_settings as _ls
|
||||
_pub = (_ls().get("app_public_url") or "").rstrip("/")
|
||||
uid_str = uid.decode() if isinstance(uid, bytes) else str(uid)
|
||||
from urllib.parse import quote as _q
|
||||
open_url = f"{_pub}/#email={_q(_folder, safe='')}:{uid_str}" if _pub else ""
|
||||
from urllib.parse import quote as _url_q
|
||||
open_url = f"{_pub}/#email={_url_q(_folder, safe='')}:{uid_str}" if _pub else ""
|
||||
|
||||
alert_subject = f"[{urgency.upper()}] {subject}"
|
||||
alert_body = (
|
||||
@@ -806,12 +847,15 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
"temperature": 0.1,
|
||||
"stream": False,
|
||||
}
|
||||
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
|
||||
if _restricts_temperature(model):
|
||||
payload.pop("temperature", None)
|
||||
# to_thread keeps the event loop responsive during the LLM call
|
||||
resp = await asyncio.to_thread(
|
||||
_req.post, url, json=payload, headers=req_headers, timeout=120
|
||||
)
|
||||
if not resp.ok:
|
||||
logger.warning(f"Auto-classify {uid.decode()} HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
logger.warning(f"Auto-classify {uid.decode() if isinstance(uid, bytes) else str(uid)} HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
else:
|
||||
rdata = resp.json()
|
||||
m = (rdata.get("choices") or [{}])[0].get("message", {})
|
||||
@@ -840,17 +884,17 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
|
||||
moved_to = ""
|
||||
if is_spam and auto_spam and spam_folder:
|
||||
if _imap_move(uid, spam_folder):
|
||||
if _imap_move(uid, spam_folder, account_id=account_id, owner=account_owner):
|
||||
moved_to = spam_folder
|
||||
logger.info(f"Auto-spam moved uid={uid.decode()} to {spam_folder}: {spam_reason}")
|
||||
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute("""
|
||||
INSERT OR REPLACE INTO email_tags
|
||||
(message_id, uid, folder, subject, sender, tags, spam_verdict,
|
||||
(message_id, owner, uid, folder, subject, sender, tags, spam_verdict,
|
||||
spam_reason, moved_to, model_used, created_at)
|
||||
VALUES (?, ?, 'INBOX', ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, uid.decode(), subject, sender,
|
||||
VALUES (?, ?, ?, 'INBOX', ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (message_id, account_owner or "", uid.decode(), subject, sender,
|
||||
json.dumps(tags), 1 if is_spam else 0,
|
||||
spam_reason, moved_to, model, datetime.utcnow().isoformat()))
|
||||
_c.commit()
|
||||
@@ -865,7 +909,6 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
logger.warning(f"Auto-process {uid} failed: {e}")
|
||||
continue
|
||||
|
||||
conn.logout()
|
||||
await _emit_progress(progress_cb, "Finishing…")
|
||||
if processed > 0:
|
||||
logger.info(f"Auto-processed {processed} new email(s) for summary/reply/classify")
|
||||
@@ -902,6 +945,12 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-summarize pass error: {e}")
|
||||
return f"Error: {e}"
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _auto_summarize_poller():
|
||||
@@ -930,8 +979,9 @@ def _scheduled_poll_once() -> dict:
|
||||
conn = sqlite3.connect(SCHEDULED_DB)
|
||||
cols = [row[1] for row in conn.execute("PRAGMA table_info(scheduled_emails)").fetchall()]
|
||||
kind_expr = "odysseus_kind" if "odysseus_kind" in cols else "'scheduled' AS odysseus_kind"
|
||||
owner_expr = "owner" if "owner" in cols else "'' AS owner"
|
||||
rows = conn.execute(f"""
|
||||
SELECT id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, account_id, {kind_expr}
|
||||
SELECT id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, account_id, {kind_expr}, {owner_expr}
|
||||
FROM scheduled_emails
|
||||
WHERE status = 'pending' AND send_at <= ?
|
||||
""", (now_iso,)).fetchall()
|
||||
@@ -943,7 +993,8 @@ def _scheduled_poll_once() -> dict:
|
||||
attachments = json.loads(r[8] or "[]")
|
||||
row_account_id = r[9] if len(r) > 9 else None
|
||||
odysseus_kind = r[10] if len(r) > 10 else "scheduled"
|
||||
cfg = _get_email_config(row_account_id)
|
||||
row_owner = (r[11] if len(r) > 11 else "") or _owner_for_email_account(row_account_id)
|
||||
cfg = _get_email_config(row_account_id, owner=row_owner)
|
||||
has_atts = bool(attachments)
|
||||
if has_atts:
|
||||
outer = MIMEMultipart("mixed")
|
||||
@@ -980,7 +1031,7 @@ def _scheduled_poll_once() -> dict:
|
||||
|
||||
# Append to local Sent folder
|
||||
try:
|
||||
with _imap() as imap:
|
||||
with _imap(row_account_id, owner=row_owner) as imap:
|
||||
sent_folder = _detect_sent_folder(imap)
|
||||
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
|
||||
except Exception as e:
|
||||
|
||||
+113
-58
@@ -17,7 +17,6 @@ import sqlite3 as _sql3
|
||||
import email as email_mod
|
||||
import email.header
|
||||
import email.utils
|
||||
import imaplib
|
||||
import smtplib
|
||||
import json
|
||||
import re
|
||||
@@ -40,7 +39,8 @@ from routes.email_helpers import (
|
||||
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
||||
_q, _attach_compose_uploads, _cleanup_compose_uploads,
|
||||
_load_settings, _save_settings, _get_email_config,
|
||||
_send_smtp_message,
|
||||
_send_smtp_message, _smtp_security_mode,
|
||||
_IMAP_TIMEOUT_SECONDS, _open_imap_connection,
|
||||
_imap_connect, _imap, _decode_header, _detect_sent_folder, _detect_drafts_folder,
|
||||
_extract_attachment_text, _list_attachments_from_msg,
|
||||
_extract_attachment_to_disk, _extract_html, _extract_text,
|
||||
@@ -90,6 +90,16 @@ def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[st
|
||||
return out or [""]
|
||||
|
||||
|
||||
def _email_tag_owner_clause(account_id: str | None, owner: str = "") -> tuple[str, list[str]]:
|
||||
aliases = _email_tag_owner_aliases(account_id, owner)
|
||||
placeholders = ",".join("?" * len(aliases))
|
||||
# In configured multi-user mode, do not treat legacy owner='' rows as
|
||||
# visible to everyone. Single-user/unconfigured mode keeps legacy rows.
|
||||
if owner:
|
||||
return f"owner IN ({placeholders})", aliases
|
||||
return f"(owner IN ({placeholders}) OR owner IS NULL)", aliases
|
||||
|
||||
|
||||
def _record_email_received_events(owner: str, account_id: str | None, folder: str, emails: list[dict]):
|
||||
"""Baseline inbox messages, then fire `email_received` for new arrivals."""
|
||||
if not owner or (folder or "INBOX").upper() != "INBOX" or not emails:
|
||||
@@ -312,6 +322,20 @@ def _apply_odysseus_headers(msg, kind: str | None = None, ref_id: str | None = N
|
||||
msg["X-Odysseus-Ref"] = re.sub(r"[^A-Za-z0-9_.:-]", "-", ref_id)[:128]
|
||||
|
||||
|
||||
def _envelope_recipients(*fields: str) -> list:
|
||||
"""Extract bare SMTP envelope addresses from one or more To/Cc/Bcc header
|
||||
strings. A naive `field.split(",")` corrupts display names that contain a
|
||||
comma (e.g. `"Smith, John" <john@corp.com>`, the canonical Outlook form):
|
||||
it splits into `"Smith` and `John" <john@corp.com>`, breaking delivery.
|
||||
email.utils.getaddresses parses the address grammar correctly."""
|
||||
out = []
|
||||
for _name, addr in email.utils.getaddresses([f for f in fields if f]):
|
||||
addr = (addr or "").strip()
|
||||
if addr:
|
||||
out.append(addr)
|
||||
return out
|
||||
|
||||
|
||||
def _md_to_email_html(text: str) -> str:
|
||||
"""Render the compose markdown body to a SAFE HTML fragment for the email's
|
||||
text/html part. Everything is HTML-escaped FIRST (so a pasted <script> /
|
||||
@@ -457,7 +481,7 @@ def setup_email_routes():
|
||||
_IMAP_POOL = {} # account_id → (conn, last_used_at)
|
||||
_IMAP_IDLE_MAX = 60.0
|
||||
_WARMING_READS = set()
|
||||
_WARM_READ_LIMIT = 3
|
||||
_WARM_READ_LIMIT = 1
|
||||
_WARM_MAX_BYTES = 128 * 1024
|
||||
_WARM_RECENT_SECONDS = 7 * 24 * 60 * 60
|
||||
_pool_lock = _threading.Lock()
|
||||
@@ -591,11 +615,11 @@ def setup_email_routes():
|
||||
SECURITY: `owner` is propagated so when `account_id` is missing,
|
||||
the fallback config lookup is scoped to this user's accounts only.
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account_id, owner=owner)
|
||||
select_status, _ = conn.select(_q(folder), readonly=True)
|
||||
if select_status != "OK":
|
||||
conn.logout()
|
||||
return {"emails": [], "total": 0, "folder": folder, "error": f"Folder not found: {folder}"}
|
||||
|
||||
from_clause = ""
|
||||
@@ -645,8 +669,7 @@ def setup_email_routes():
|
||||
try:
|
||||
import sqlite3 as _sql3t
|
||||
_ct = _sql3t.connect(SCHEDULED_DB)
|
||||
_owner_aliases = _email_tag_owner_aliases(account_id, owner)
|
||||
_owner_ph = ",".join("?" * len(_owner_aliases))
|
||||
_owner_clause, _owner_params = _email_tag_owner_clause(account_id, owner)
|
||||
# SECURITY: owner-scope the lookup (review C2/H8). Without
|
||||
# this, user A's `tag:urgent` filter would surface UIDs
|
||||
# written by user B and IMAP would return whatever
|
||||
@@ -658,8 +681,8 @@ def setup_email_routes():
|
||||
rows_t = _ct.execute(
|
||||
"SELECT message_id, uid FROM email_tags "
|
||||
"WHERE folder=? AND spam_verdict=1 "
|
||||
f"AND (owner IN ({_owner_ph}) OR owner IS NULL)",
|
||||
(folder, *_owner_aliases),
|
||||
f"AND {_owner_clause}",
|
||||
(folder, *_owner_params),
|
||||
).fetchall()
|
||||
for mid, uid in rows_t:
|
||||
if mid:
|
||||
@@ -670,8 +693,8 @@ def setup_email_routes():
|
||||
rows_t = _ct.execute(
|
||||
"SELECT message_id, uid, tags FROM email_tags "
|
||||
"WHERE folder=? AND tags IS NOT NULL AND tags != '' "
|
||||
f"AND (owner IN ({_owner_ph}) OR owner IS NULL)",
|
||||
(folder, *_owner_aliases),
|
||||
f"AND {_owner_clause}",
|
||||
(folder, *_owner_params),
|
||||
).fetchall()
|
||||
for r in rows_t:
|
||||
try:
|
||||
@@ -743,12 +766,11 @@ def setup_email_routes():
|
||||
_uid_strs = [u.decode() for u in uid_list]
|
||||
if _uid_strs:
|
||||
placeholders = ",".join("?" * len(_uid_strs))
|
||||
_owner_aliases = _email_tag_owner_aliases(account_id, owner)
|
||||
_owner_ph = ",".join("?" * len(_owner_aliases))
|
||||
_owner_clause, _owner_params = _email_tag_owner_clause(account_id, owner)
|
||||
rows = _c.execute(
|
||||
f"SELECT uid, tags, spam_verdict FROM email_tags "
|
||||
f"WHERE folder=? AND (owner IN ({_owner_ph}) OR owner IS NULL) AND uid IN ({placeholders})",
|
||||
[folder, *_owner_aliases, *_uid_strs],
|
||||
f"WHERE folder=? AND {_owner_clause} AND uid IN ({placeholders})",
|
||||
[folder, *_owner_params, *_uid_strs],
|
||||
).fetchall()
|
||||
for r in rows:
|
||||
try:
|
||||
@@ -805,14 +827,13 @@ def setup_email_routes():
|
||||
if header_ids:
|
||||
import sqlite3 as _sql3m
|
||||
_cm = _sql3m.connect(SCHEDULED_DB)
|
||||
_owner_aliases_m = _email_tag_owner_aliases(account_id, owner)
|
||||
_owner_ph_m = ",".join("?" * len(_owner_aliases_m))
|
||||
_owner_clause_m, _owner_params_m = _email_tag_owner_clause(account_id, owner)
|
||||
_mid_ph = ",".join("?" * len(header_ids))
|
||||
rows_m = _cm.execute(
|
||||
f"SELECT message_id, tags, spam_verdict FROM email_tags "
|
||||
f"WHERE folder=? AND (owner IN ({_owner_ph_m}) OR owner IS NULL) "
|
||||
f"WHERE folder=? AND {_owner_clause_m} "
|
||||
f"AND message_id IN ({_mid_ph})",
|
||||
[folder, *_owner_aliases_m, *header_ids],
|
||||
[folder, *_owner_params_m, *header_ids],
|
||||
).fetchall()
|
||||
_cm.close()
|
||||
for mid, tags_raw, spam_raw in rows_m:
|
||||
@@ -924,12 +945,17 @@ def setup_email_routes():
|
||||
except Exception as _summary_err:
|
||||
logger.debug(f"Bulk summary attach skipped: {_summary_err}")
|
||||
|
||||
conn.logout()
|
||||
return {"emails": emails, "total": total, "folder": folder, "offset": offset}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list emails: {e}")
|
||||
detail = str(e).strip()
|
||||
return {"emails": [], "total": 0, "error": f"Mail operation failed: {detail[:180]}" if detail else "Mail operation failed"}
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@router.get("/list")
|
||||
async def list_emails(
|
||||
@@ -971,10 +997,11 @@ def setup_email_routes():
|
||||
async def unflag_spam(uid: str, owner: str = Depends(require_owner)):
|
||||
"""User override — mark email as not spam."""
|
||||
try:
|
||||
owner_clause, owner_params = _email_tag_owner_clause(None, owner)
|
||||
_c = _sql3.connect(SCHEDULED_DB)
|
||||
_c.execute(
|
||||
"UPDATE email_tags SET spam_verdict=0, spam_reason='' WHERE uid=?",
|
||||
(uid,),
|
||||
f"UPDATE email_tags SET spam_verdict=0, spam_reason='' WHERE uid=? AND {owner_clause}",
|
||||
[uid, *owner_params],
|
||||
)
|
||||
_c.commit()
|
||||
_c.close()
|
||||
@@ -997,8 +1024,10 @@ def setup_email_routes():
|
||||
ql = (q or "").strip().lower()
|
||||
try:
|
||||
conn = _sql3.connect(SCHEDULED_DB)
|
||||
owner_clause, owner_params = _email_tag_owner_clause(None, owner)
|
||||
rows = conn.execute(
|
||||
"SELECT sender FROM email_tags WHERE sender IS NOT NULL AND sender != ''"
|
||||
f"SELECT sender FROM email_tags WHERE sender IS NOT NULL AND sender != '' AND {owner_clause}",
|
||||
owner_params,
|
||||
).fetchall()
|
||||
conn.close()
|
||||
seen = {}
|
||||
@@ -1046,7 +1075,7 @@ def setup_email_routes():
|
||||
|
||||
# Escape backslash and quote for the IMAP-SEARCH quoted-string.
|
||||
q_escaped = q.replace('\\', '\\\\').replace('"', '\\"')
|
||||
search_cmd = f'(OR FROM "{q_escaped}" TEXT "{q_escaped}")'
|
||||
search_cmd = f'(OR OR FROM "{q_escaped}" SUBJECT "{q_escaped}" TEXT "{q_escaped}")'
|
||||
|
||||
status, data = _imap_uid_search(conn, search_cmd)
|
||||
if status != "OK" or not data[0]:
|
||||
@@ -1928,11 +1957,7 @@ def setup_email_routes():
|
||||
outer.attach(body_container)
|
||||
_attach_compose_uploads(outer, attachments)
|
||||
|
||||
recipients = [r.strip() for r in to.split(",") if r.strip()]
|
||||
if cc:
|
||||
recipients.extend([r.strip() for r in cc.split(",") if r.strip()])
|
||||
if bcc:
|
||||
recipients.extend([r.strip() for r in bcc.split(",") if r.strip()])
|
||||
recipients = _envelope_recipients(to, cc, bcc)
|
||||
|
||||
_send_smtp_message(cfg, cfg["from_address"], recipients, outer.as_string())
|
||||
|
||||
@@ -1964,13 +1989,22 @@ def setup_email_routes():
|
||||
# minute doesn't trip the past-time guard.
|
||||
if parsed_at < now_utc:
|
||||
return {"success": False, "error": "send_at must be in the future"}
|
||||
# Normalize to naive UTC before storing: the poller selects due
|
||||
# rows with a lexicographic string compare against a naive
|
||||
# datetime.utcnow().isoformat(), so storing the raw client string
|
||||
# makes "+02:00" schedules fire hours late, negative offsets fire
|
||||
# hours early, and a "Z" suffix compares after the fractional
|
||||
# seconds of the poller timestamp.
|
||||
if parsed_at.tzinfo:
|
||||
parsed_at = parsed_at.astimezone(_tz.utc).replace(tzinfo=None)
|
||||
send_at = parsed_at.isoformat()
|
||||
|
||||
sid = _uuid.uuid4().hex[:16]
|
||||
conn = sqlite3.connect(SCHEDULED_DB)
|
||||
conn.execute("""
|
||||
INSERT INTO scheduled_emails
|
||||
(id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, send_at, created_at, status, account_id, odysseus_kind)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?)
|
||||
(id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, send_at, created_at, status, account_id, odysseus_kind, owner)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)
|
||||
""", (
|
||||
sid,
|
||||
req.get("to", ""),
|
||||
@@ -1985,6 +2019,7 @@ def setup_email_routes():
|
||||
datetime.utcnow().isoformat(),
|
||||
req.get("account_id") or None,
|
||||
req.get("odysseus_kind") or "scheduled",
|
||||
owner or "",
|
||||
))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -2003,9 +2038,9 @@ def setup_email_routes():
|
||||
rows = conn.execute("""
|
||||
SELECT id, to_addr, cc, subject, send_at, created_at, status, error
|
||||
FROM scheduled_emails
|
||||
WHERE status IN ('pending', 'failed')
|
||||
WHERE status IN ('pending', 'failed') AND owner = ?
|
||||
ORDER BY send_at ASC
|
||||
""").fetchall()
|
||||
""", (owner or "",)).fetchall()
|
||||
conn.close()
|
||||
return {"scheduled": [
|
||||
{
|
||||
@@ -2023,7 +2058,10 @@ def setup_email_routes():
|
||||
import sqlite3
|
||||
try:
|
||||
conn = sqlite3.connect(SCHEDULED_DB)
|
||||
conn.execute("DELETE FROM scheduled_emails WHERE id = ? AND status = 'pending'", (sid,))
|
||||
conn.execute(
|
||||
"DELETE FROM scheduled_emails WHERE id = ? AND status = 'pending' AND owner = ?",
|
||||
(sid, owner or ""),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return {"success": True}
|
||||
@@ -2035,7 +2073,7 @@ def setup_email_routes():
|
||||
async def resolve_contact(name: str = Query(..., description="Name to search for"), owner: str = Depends(require_owner)):
|
||||
"""Search Sent folder for a contact by name. Returns matching email addresses."""
|
||||
try:
|
||||
with _imap() as conn:
|
||||
with _imap(owner=owner) as conn:
|
||||
matches = {}
|
||||
for folder in ["Sent", "INBOX", "Drafts"]:
|
||||
try:
|
||||
@@ -2133,12 +2171,9 @@ def setup_email_routes():
|
||||
outer.attach(body_container)
|
||||
_attach_compose_uploads(outer, req.attachments)
|
||||
|
||||
# Build recipient list
|
||||
recipients = [r.strip() for r in req.to.split(",") if r.strip()]
|
||||
if req.cc:
|
||||
recipients.extend([r.strip() for r in req.cc.split(",") if r.strip()])
|
||||
if req.bcc:
|
||||
recipients.extend([r.strip() for r in req.bcc.split(",") if r.strip()])
|
||||
# Build recipient list (parse the address grammar so display names with
|
||||
# commas don't get split into broken envelope addresses)
|
||||
recipients = _envelope_recipients(req.to, req.cc, req.bcc)
|
||||
|
||||
# Serialize what the background task needs so the request object can be GC'd
|
||||
outer_bytes = outer.as_bytes()
|
||||
@@ -2146,6 +2181,7 @@ def setup_email_routes():
|
||||
_from = cfg["from_address"]
|
||||
_smtp_host = cfg["smtp_host"]
|
||||
_smtp_port = cfg["smtp_port"]
|
||||
_smtp_security = cfg.get("smtp_security")
|
||||
_smtp_user = cfg["smtp_user"]
|
||||
_smtp_pw = cfg["smtp_password"]
|
||||
_recipients = list(recipients)
|
||||
@@ -2163,6 +2199,7 @@ def setup_email_routes():
|
||||
{
|
||||
"smtp_host": _smtp_host,
|
||||
"smtp_port": _smtp_port,
|
||||
"smtp_security": _smtp_security,
|
||||
"smtp_user": _smtp_user,
|
||||
"smtp_password": _smtp_pw,
|
||||
},
|
||||
@@ -2417,7 +2454,7 @@ def setup_email_routes():
|
||||
"""Generate a quick AI summary of an email body."""
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import _uses_max_completion_tokens
|
||||
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
||||
import requests as _req
|
||||
|
||||
body = data.get("body", "")
|
||||
@@ -2474,6 +2511,9 @@ def setup_email_routes():
|
||||
"temperature": 0.3,
|
||||
"stream": False,
|
||||
}
|
||||
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
|
||||
if _restricts_temperature(model):
|
||||
payload.pop("temperature", None)
|
||||
resp = await asyncio.to_thread(
|
||||
_req.post, url, json=payload, headers=req_headers, timeout=180
|
||||
)
|
||||
@@ -2585,7 +2625,7 @@ def setup_email_routes():
|
||||
# `api_key` field.
|
||||
from core.database import SessionLocal as _SL, Session as _CS
|
||||
_db = _SL()
|
||||
sess = _db.query(_CS).filter(_CS.id == session_id).first()
|
||||
sess = _db.query(_CS).filter(_CS.id == session_id, _CS.owner == owner).first()
|
||||
if sess and sess.endpoint_url:
|
||||
url = sess.endpoint_url
|
||||
# Some sessions stored headers double-encoded (a JSON
|
||||
@@ -2644,9 +2684,10 @@ def setup_email_routes():
|
||||
# Manual AI Reply should feel immediate. The heavier context mining
|
||||
# can involve multiple IMAP folder searches and attachment parsing;
|
||||
# reserve that for callers that explicitly opt out of fast mode.
|
||||
# Owner-scoped so pre-retrieval never crosses tenants.
|
||||
context_snippets, _terms = ([], [])
|
||||
if not fast_reply:
|
||||
context_snippets, _terms = _pre_retrieve_context(original_body, to)
|
||||
context_snippets, _terms = _pre_retrieve_context(original_body, to, owner=owner)
|
||||
|
||||
# NEW: also pull the last few emails from the original sender +
|
||||
# their attachments. The "to" field on this endpoint is the
|
||||
@@ -2662,6 +2703,7 @@ def setup_email_routes():
|
||||
exclude_uid=source_uid,
|
||||
exclude_folder=source_folder,
|
||||
limit=3,
|
||||
owner=owner,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.warning(f"sender-thread-context failed: {_e}")
|
||||
@@ -2723,7 +2765,7 @@ def setup_email_routes():
|
||||
# Configured fallback chains last.
|
||||
for cand in resolve_utility_fallback_candidates(owner=owner) or []:
|
||||
_add(*cand)
|
||||
for cand in resolve_chat_fallback_candidates() or []:
|
||||
for cand in resolve_chat_fallback_candidates(owner=owner) or []:
|
||||
_add(*cand)
|
||||
try:
|
||||
reply = await llm_call_async_with_fallback(
|
||||
@@ -2814,13 +2856,16 @@ def setup_email_routes():
|
||||
import uuid as _uuid
|
||||
db = SessionLocal()
|
||||
try:
|
||||
row = db.query(EmailAccount).filter(EmailAccount.is_default == True).first() # noqa: E712
|
||||
q = db.query(EmailAccount).filter(EmailAccount.is_default == True) # noqa: E712
|
||||
if owner:
|
||||
q = q.filter(EmailAccount.owner == owner)
|
||||
row = q.first()
|
||||
if row is None:
|
||||
row = EmailAccount(id=_uuid.uuid4().hex, name="Default", is_default=True, enabled=True)
|
||||
row = EmailAccount(id=_uuid.uuid4().hex, owner=owner, name="Default", is_default=True, enabled=True)
|
||||
db.add(row)
|
||||
field_map = {
|
||||
"smtp_host": "smtp_host", "smtp_port": "smtp_port", "smtp_user": "smtp_user",
|
||||
"imap_host": "imap_host", "imap_port": "imap_port", "imap_user": "imap_user",
|
||||
"smtp_security": "smtp_security", "imap_host": "imap_host", "imap_port": "imap_port", "imap_user": "imap_user",
|
||||
"imap_starttls": "imap_starttls", "email_from": "from_address",
|
||||
}
|
||||
for in_key, col_name in field_map.items():
|
||||
@@ -2838,6 +2883,10 @@ def setup_email_routes():
|
||||
row.imap_password = _enc(data["imap_password"])
|
||||
if data.get("smtp_password"):
|
||||
row.smtp_password = _enc(data["smtp_password"])
|
||||
clear_q = db.query(EmailAccount).filter(EmailAccount.id != row.id)
|
||||
if owner:
|
||||
clear_q = clear_q.filter(EmailAccount.owner == owner)
|
||||
clear_q.update({EmailAccount.is_default: False})
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
@@ -2902,6 +2951,7 @@ def setup_email_routes():
|
||||
"imap_starttls": bool(r.imap_starttls),
|
||||
"smtp_host": r.smtp_host or "",
|
||||
"smtp_port": int(r.smtp_port or 465),
|
||||
"smtp_security": _smtp_security_mode({"smtp_security": getattr(r, "smtp_security", ""), "smtp_port": r.smtp_port}),
|
||||
"smtp_user": r.smtp_user or "",
|
||||
"from_address": r.from_address or "",
|
||||
"has_imap_password": bool(r.imap_password),
|
||||
@@ -2934,6 +2984,7 @@ def setup_email_routes():
|
||||
imap_starttls=bool(data.get("imap_starttls", True)),
|
||||
smtp_host=(data.get("smtp_host") or "").strip(),
|
||||
smtp_port=int(data.get("smtp_port") or 465),
|
||||
smtp_security=_smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or 465}),
|
||||
smtp_user=(data.get("smtp_user") or "").strip(),
|
||||
smtp_password=_enc(data.get("smtp_password") or ""),
|
||||
from_address=(data.get("from_address") or "").strip(),
|
||||
@@ -2977,6 +3028,8 @@ def setup_email_routes():
|
||||
for key in ("imap_port", "smtp_port"):
|
||||
if data.get(key) not in (None, ""):
|
||||
setattr(row, key, int(data[key]))
|
||||
if "smtp_security" in data:
|
||||
row.smtp_security = _smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or row.smtp_port})
|
||||
for key in ("imap_starttls", "enabled"):
|
||||
if key in data:
|
||||
setattr(row, key, bool(data[key]))
|
||||
@@ -3061,6 +3114,7 @@ def setup_email_routes():
|
||||
"imap_starttls": bool(row.imap_starttls),
|
||||
"smtp_host": row.smtp_host or "",
|
||||
"smtp_port": row.smtp_port or 465,
|
||||
"smtp_security": _smtp_security_mode({"smtp_security": getattr(row, "smtp_security", ""), "smtp_port": row.smtp_port}),
|
||||
"smtp_user": row.smtp_user or "",
|
||||
"smtp_password": _decrypt(row.smtp_password or ""),
|
||||
}
|
||||
@@ -3093,13 +3147,12 @@ def setup_email_routes():
|
||||
# port (Dovecot on 31143, etc.) would always fail the SSL
|
||||
# handshake because they're not actually wrapped in TLS.
|
||||
try:
|
||||
if imap_starttls:
|
||||
conn = imaplib.IMAP4(imap_host, imap_port, timeout=10)
|
||||
conn.starttls()
|
||||
elif imap_port == 993:
|
||||
conn = imaplib.IMAP4_SSL(imap_host, imap_port, timeout=10)
|
||||
else:
|
||||
conn = imaplib.IMAP4(imap_host, imap_port, timeout=10)
|
||||
conn = _open_imap_connection(
|
||||
imap_host,
|
||||
imap_port,
|
||||
starttls=imap_starttls,
|
||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
conn.login(imap_user, imap_pass)
|
||||
imap_result = {"ok": True}
|
||||
@@ -3112,14 +3165,16 @@ def setup_email_routes():
|
||||
smtp_host = (body.get("smtp_host") or "").strip()
|
||||
if smtp_host:
|
||||
smtp_port = int(body.get("smtp_port") or 465)
|
||||
smtp_security = _smtp_security_mode({"smtp_security": body.get("smtp_security"), "smtp_port": smtp_port})
|
||||
smtp_user = (body.get("smtp_user") or imap_user).strip()
|
||||
smtp_pass = body.get("smtp_password") or imap_pass
|
||||
try:
|
||||
if smtp_port == 587:
|
||||
smtp = smtplib.SMTP(smtp_host, smtp_port, timeout=10)
|
||||
smtp.starttls()
|
||||
else:
|
||||
if smtp_security == "ssl":
|
||||
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port, timeout=10)
|
||||
else:
|
||||
smtp = smtplib.SMTP(smtp_host, smtp_port, timeout=10)
|
||||
if smtp_security == "starttls":
|
||||
smtp.starttls()
|
||||
try:
|
||||
smtp.login(smtp_user, smtp_pass)
|
||||
smtp_result = {"ok": True}
|
||||
|
||||
@@ -86,7 +86,8 @@ def _load_custom_endpoint() -> dict:
|
||||
"""Load the saved custom embedding endpoint, if any."""
|
||||
try:
|
||||
if os.path.exists(_ENDPOINT_FILE):
|
||||
return json.loads(Path(_ENDPOINT_FILE).read_text(encoding="utf-8"))
|
||||
data = json.loads(Path(_ENDPOINT_FILE).read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else {}
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
@@ -160,7 +161,7 @@ def setup_embedding_routes():
|
||||
_downloading[model_name] = True
|
||||
try:
|
||||
# Run in thread to not block the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
cache = _cache_dir()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
@@ -242,6 +243,18 @@ def setup_embedding_routes():
|
||||
if not url:
|
||||
raise HTTPException(400, "URL is required")
|
||||
|
||||
# SSRF hardening: validate the user-supplied URL before any outbound
|
||||
# request. Local-first means loopback/LAN endpoints are allowed by
|
||||
# default; non-HTTP(S) schemes and the cloud metadata range are always
|
||||
# rejected. Set EMBEDDING_BLOCK_PRIVATE_IPS=true for full lockdown.
|
||||
from src.url_safety import check_outbound_url
|
||||
ok, reason = check_outbound_url(
|
||||
url,
|
||||
block_private=os.getenv("EMBEDDING_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||
)
|
||||
if not ok:
|
||||
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
|
||||
|
||||
# Quick health check
|
||||
try:
|
||||
import httpx
|
||||
|
||||
+10
-2
@@ -5,6 +5,15 @@ from fastapi import APIRouter
|
||||
|
||||
CUSTOM_FONTS_DIR = os.path.join("static", "fonts", "custom")
|
||||
FONT_EXTENSIONS = {".ttf", ".otf", ".woff", ".woff2"}
|
||||
FAMILY_SUFFIX_WORDS = ("Display", "Rounded", "Serif", "Sans", "Mono", "Code", "Text")
|
||||
|
||||
|
||||
def _split_family_token(token):
|
||||
"""Split common compact font-family suffixes without breaking brand names."""
|
||||
for suffix in FAMILY_SUFFIX_WORDS:
|
||||
if token.endswith(suffix) and len(token) > len(suffix):
|
||||
return f"{token[:-len(suffix)]} {suffix}"
|
||||
return re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', token)
|
||||
|
||||
|
||||
def _derive_family(filename):
|
||||
@@ -15,10 +24,9 @@ def _derive_family(filename):
|
||||
r'[-_ ]?(Thin|ExtraLight|UltraLight|Light|Regular|Medium|SemiBold|DemiBold|Bold|ExtraBold|UltraBold|Black|Heavy|Italic|Oblique|Variable|VF)$',
|
||||
'', name, flags=re.IGNORECASE
|
||||
)
|
||||
# Insert spaces before uppercase runs: "JetBrainsMono" → "Jet Brains Mono"
|
||||
name = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', name)
|
||||
# Replace dashes/underscores with spaces
|
||||
name = re.sub(r'[-_]+', ' ', name).strip()
|
||||
name = " ".join(_split_family_token(part) for part in name.split())
|
||||
return name or filename
|
||||
|
||||
|
||||
|
||||
@@ -32,10 +32,21 @@ def _extract_exif(content: bytes) -> dict:
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
img = Image.open(BytesIO(content))
|
||||
# Read the raw EXIF before any transpose: exif_transpose strips the
|
||||
# orientation tag and with it the parsed EXIF view.
|
||||
exif = img._getexif() if hasattr(img, '_getexif') else None
|
||||
|
||||
# Record DISPLAY dimensions (EXIF-rotated), matching upload_handler.
|
||||
# A phone photo with Orientation 6/8 is stored landscape but shown
|
||||
# portrait, so the raw width/height swap the aspect ratio.
|
||||
try:
|
||||
from PIL import ImageOps
|
||||
img = ImageOps.exif_transpose(img) or img
|
||||
except Exception:
|
||||
pass
|
||||
result["width"] = img.width
|
||||
result["height"] = img.height
|
||||
|
||||
exif = img._getexif() if hasattr(img, '_getexif') else None
|
||||
if not exif:
|
||||
return result
|
||||
|
||||
@@ -110,9 +121,17 @@ def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any
|
||||
|
||||
|
||||
def _owner_filter(q, user):
|
||||
"""Apply owner filtering to a gallery query."""
|
||||
"""Apply owner filtering to a gallery query.
|
||||
|
||||
When auth is disabled (single-user mode) get_current_user returns None
|
||||
and there is no per-user scoping. The main library list and stats already
|
||||
treat None as "show everything" (`if user is not None`), so this helper
|
||||
must too — otherwise the tag/model filter sidebars come back empty and the
|
||||
tag-cleanup endpoints (clear-user-tags, clear-ai-tags, dedupe-tags)
|
||||
silently affect zero rows in the most common self-hosted deployment.
|
||||
"""
|
||||
if user is None:
|
||||
return q.filter(False)
|
||||
return q
|
||||
return q.filter(GalleryImage.owner == user)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
import os
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
@@ -17,6 +20,14 @@ from routes.gallery_helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_gallery_filename(filename: str) -> str:
|
||||
"""Return a local filename safe to join under generated_images."""
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
|
||||
if not safe_name or safe_name in {".", ".."}:
|
||||
safe_name = uuid.uuid4().hex[:12]
|
||||
return safe_name
|
||||
|
||||
def setup_gallery_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["gallery"])
|
||||
|
||||
@@ -122,7 +133,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
content = await file.read()
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
img_path = img_dir / img.filename
|
||||
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
||||
img_path.write_bytes(content)
|
||||
|
||||
# Refresh dimensions in case the editor resized the canvas.
|
||||
@@ -912,6 +923,16 @@ def setup_gallery_routes() -> APIRouter:
|
||||
body = await request.json()
|
||||
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
||||
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
||||
# SSRF hardening: validate a client-supplied endpoint before any
|
||||
# outbound request (mirrors routes/embedding_routes.py).
|
||||
if base:
|
||||
from src.url_safety import check_outbound_url
|
||||
ok, reason = check_outbound_url(
|
||||
base,
|
||||
block_private=os.getenv("IMAGE_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||
)
|
||||
if not ok:
|
||||
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
|
||||
chosen_model = (body.pop("_model", "") or "").strip()
|
||||
api_key = None
|
||||
if not base:
|
||||
@@ -1104,6 +1125,18 @@ def setup_gallery_routes() -> APIRouter:
|
||||
raise HTTPException(400, "No image provided")
|
||||
|
||||
endpoint = (body.get("_endpoint") or "").rstrip("/")
|
||||
# SSRF hardening: a client-supplied endpoint is fetched server-side
|
||||
# below, so validate it first (mirrors routes/embedding_routes.py).
|
||||
# Local-first means loopback/LAN is allowed by default; the cloud
|
||||
# metadata range and non-HTTP(S) schemes are always rejected.
|
||||
if endpoint:
|
||||
from src.url_safety import check_outbound_url
|
||||
ok, reason = check_outbound_url(
|
||||
endpoint,
|
||||
block_private=os.getenv("IMAGE_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||
)
|
||||
if not ok:
|
||||
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
|
||||
model = (body.get("_model") or "").strip()
|
||||
|
||||
base = endpoint
|
||||
@@ -1125,7 +1158,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for ep in db.query(ModelEndpoint).all():
|
||||
if ep.base_url.rstrip("/").rstrip("/v1") == base.rstrip("/v1"):
|
||||
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
|
||||
api_key = ep.api_key
|
||||
break
|
||||
finally:
|
||||
@@ -1696,7 +1729,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
return {"error": "No vision-capable endpoint configured"}
|
||||
|
||||
# Call vision model — format differs between Anthropic and OpenAI
|
||||
from src.llm_core import _detect_provider
|
||||
from src.llm_core import _detect_provider, _restricts_temperature, _uses_max_completion_tokens
|
||||
provider = _detect_provider(chat_url)
|
||||
tag_prompt = (
|
||||
"Analyze this photo. Return ONLY a comma-separated list of tags. "
|
||||
@@ -1721,6 +1754,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
}],
|
||||
}
|
||||
else:
|
||||
_tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model_name) else "max_tokens"
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": [{
|
||||
@@ -1730,9 +1764,12 @@ def setup_gallery_routes() -> APIRouter:
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
|
||||
],
|
||||
}],
|
||||
"max_tokens": 200,
|
||||
_tok_key: 200,
|
||||
"temperature": 0.3,
|
||||
}
|
||||
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
|
||||
if _restricts_temperature(model_name):
|
||||
payload.pop("temperature", None)
|
||||
|
||||
h = {"Content-Type": "application/json"}
|
||||
if headers:
|
||||
|
||||
+17
-10
@@ -58,7 +58,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
.all()
|
||||
)
|
||||
import json as _json
|
||||
history_dict = []
|
||||
db_history = []
|
||||
for m in db_messages:
|
||||
entry = {"role": m.role, "content": m.content}
|
||||
meta = {}
|
||||
@@ -71,12 +71,19 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
meta["timestamp"] = m.timestamp.isoformat() + "Z"
|
||||
if meta:
|
||||
entry["metadata"] = meta
|
||||
history_dict.append(entry)
|
||||
if history_dict:
|
||||
db_history.append(entry)
|
||||
if db_history:
|
||||
# Rebuild in-memory history from the full set so hidden
|
||||
# messages (e.g. compaction summaries) are kept for AI context.
|
||||
session.history = [
|
||||
ChatMessage(role=m["role"], content=m["content"], metadata=m.get("metadata"))
|
||||
for m in history_dict
|
||||
for m in db_history
|
||||
]
|
||||
# Response excludes hidden messages, matching the in-memory path.
|
||||
history_dict = [
|
||||
m for m in db_history
|
||||
if not (m.get("metadata") or {}).get("hidden")
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"DB fallback failed for {session_id}: {e}")
|
||||
finally:
|
||||
@@ -265,7 +272,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
db_messages = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
|
||||
.order_by(DbChatMessage.created_at.desc())
|
||||
.order_by(DbChatMessage.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
if db_messages:
|
||||
@@ -320,7 +327,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
db_msg = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
|
||||
.order_by(DbChatMessage.created_at.desc())
|
||||
.order_by(DbChatMessage.timestamp.desc())
|
||||
.first()
|
||||
)
|
||||
if db_msg:
|
||||
@@ -401,7 +408,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
db_messages = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id)
|
||||
.order_by(DbChatMessage.created_at)
|
||||
.order_by(DbChatMessage.timestamp)
|
||||
.all()
|
||||
)
|
||||
# Find last two assistant messages in DB
|
||||
@@ -477,10 +484,10 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
|
||||
@router.get("/api/conversations/topics")
|
||||
async def get_conversation_topics(request: Request) -> Dict[str, Any]:
|
||||
from src.auth_helpers import get_current_user
|
||||
user = get_current_user(request)
|
||||
from src.auth_helpers import require_user
|
||||
user = require_user(request)
|
||||
try:
|
||||
return analyze_topics(session_manager, owner=user)
|
||||
return analyze_topics(session_manager, owner=user or None)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Topic analysis failed: {e}")
|
||||
|
||||
|
||||
+152
-76
@@ -1,87 +1,105 @@
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
||||
# services.hwfit.fit understands so a simulated box ranks like a real one:
|
||||
# "metal" routes through the Apple-Silicon path (GGUF-only, llama.cpp/Ollama),
|
||||
# the CPU backends through the RAM/offload path, cuda/rocm through vLLM.
|
||||
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
||||
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
|
||||
The previous additive behavior averaged the manual VRAM across
|
||||
all GPUs (base + manual), which meant adding "1× 400 GB" on top
|
||||
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
|
||||
(= 540 / 3), so GGUF models bigger than that still didn't surface
|
||||
— exactly the "cap stuck at detected level" bug the user hit.
|
||||
"""
|
||||
manual_mode = (manual_mode or "").lower()
|
||||
if manual_mode not in {"gpu", "ram"}:
|
||||
return system
|
||||
|
||||
try:
|
||||
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
|
||||
except ValueError:
|
||||
override_ram_gb = 0
|
||||
override_ram_gb = max(0.0, override_ram_gb)
|
||||
if override_ram_gb:
|
||||
# Replace RAM, don't add. The number in the field is the
|
||||
# TOTAL system memory the user wants to simulate.
|
||||
system["available_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["total_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["manual_hardware"] = True
|
||||
|
||||
if manual_mode == "ram":
|
||||
# RAM-only simulation — wipe GPU entirely so the ranker uses
|
||||
# CPU/RAM paths.
|
||||
system["has_gpu"] = False
|
||||
system["gpu_name"] = None
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpus"] = []
|
||||
system["gpu_groups"] = []
|
||||
system["backend"] = "cpu_x86"
|
||||
system.pop("unified_memory", None)
|
||||
return system
|
||||
|
||||
try:
|
||||
count = int(manual_gpu_count) if manual_gpu_count else 1
|
||||
except ValueError:
|
||||
count = 1
|
||||
try:
|
||||
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
|
||||
except ValueError:
|
||||
vram_each = 8.0
|
||||
count = max(1, min(count, 16))
|
||||
vram_each = max(1.0, vram_each)
|
||||
backend = (manual_backend or system.get("backend") or "cuda").lower()
|
||||
if backend not in _MANUAL_BACKENDS:
|
||||
backend = "cuda"
|
||||
total_vram = round(vram_each * count, 1)
|
||||
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
|
||||
system["has_gpu"] = True
|
||||
system["gpu_name"] = gpu_name
|
||||
system["gpu_vram_gb"] = total_vram
|
||||
system["gpu_count"] = count
|
||||
system["gpus"] = [
|
||||
{"index": i, "name": gpu_name, "vram_gb": vram_each}
|
||||
for i in range(count)
|
||||
]
|
||||
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
|
||||
# VRAM the user entered, not an average. That's the whole point:
|
||||
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
|
||||
# math) all the way up, not just by a small fraction.
|
||||
system["gpu_groups"] = [{
|
||||
"name": gpu_name,
|
||||
"vram_each": vram_each,
|
||||
"count": count,
|
||||
"indices": list(range(count)),
|
||||
"vram_total": total_vram,
|
||||
}]
|
||||
system["homogeneous"] = True
|
||||
system["backend"] = backend
|
||||
# Apple Silicon shares one unified memory pool with the GPU; flag it so
|
||||
# the API/UI report it the way real Metal detection does. Discrete GPUs
|
||||
# (cuda/rocm) and the CPU backends carry separate VRAM, so clear any
|
||||
# stale flag a previous detection left on the dict.
|
||||
if backend == "metal":
|
||||
system["unified_memory"] = True
|
||||
else:
|
||||
system.pop("unified_memory", None)
|
||||
return system
|
||||
|
||||
|
||||
def setup_hwfit_routes():
|
||||
router = APIRouter(prefix="/api/hwfit", tags=["hwfit"])
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
|
||||
The previous additive behavior averaged the manual VRAM across
|
||||
all GPUs (base + manual), which meant adding "1× 400 GB" on top
|
||||
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
|
||||
(= 540 / 3), so GGUF models bigger than that still didn't surface
|
||||
— exactly the "cap stuck at detected level" bug the user hit.
|
||||
"""
|
||||
manual_mode = (manual_mode or "").lower()
|
||||
if manual_mode not in {"gpu", "ram"}:
|
||||
return system
|
||||
|
||||
try:
|
||||
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
|
||||
except ValueError:
|
||||
override_ram_gb = 0
|
||||
override_ram_gb = max(0.0, override_ram_gb)
|
||||
if override_ram_gb:
|
||||
# Replace RAM, don't add. The number in the field is the
|
||||
# TOTAL system memory the user wants to simulate.
|
||||
system["available_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["total_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["manual_hardware"] = True
|
||||
|
||||
if manual_mode == "ram":
|
||||
# RAM-only simulation — wipe GPU entirely so the ranker uses
|
||||
# CPU/RAM paths.
|
||||
system["has_gpu"] = False
|
||||
system["gpu_name"] = None
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpus"] = []
|
||||
system["gpu_groups"] = []
|
||||
system["backend"] = "cpu_x86"
|
||||
return system
|
||||
|
||||
try:
|
||||
count = int(manual_gpu_count) if manual_gpu_count else 1
|
||||
except ValueError:
|
||||
count = 1
|
||||
try:
|
||||
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
|
||||
except ValueError:
|
||||
vram_each = 8.0
|
||||
count = max(1, min(count, 16))
|
||||
vram_each = max(1.0, vram_each)
|
||||
backend = (manual_backend or system.get("backend") or "cuda").lower()
|
||||
if backend not in {"cuda", "rocm", "cpu_x86", "cpu_arm"}:
|
||||
backend = "cuda"
|
||||
total_vram = round(vram_each * count, 1)
|
||||
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
|
||||
system["has_gpu"] = True
|
||||
system["gpu_name"] = gpu_name
|
||||
system["gpu_vram_gb"] = total_vram
|
||||
system["gpu_count"] = count
|
||||
system["gpus"] = [
|
||||
{"index": i, "name": gpu_name, "vram_gb": vram_each}
|
||||
for i in range(count)
|
||||
]
|
||||
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
|
||||
# VRAM the user entered, not an average. That's the whole point:
|
||||
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
|
||||
# math) all the way up, not just by a small fraction.
|
||||
system["gpu_groups"] = [{
|
||||
"name": gpu_name,
|
||||
"vram_each": vram_each,
|
||||
"count": count,
|
||||
"indices": list(range(count)),
|
||||
"vram_total": total_vram,
|
||||
}]
|
||||
system["homogeneous"] = True
|
||||
system["backend"] = backend
|
||||
return system
|
||||
|
||||
@router.get("/system")
|
||||
def get_system(host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False):
|
||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||
@@ -181,6 +199,64 @@ def setup_hwfit_routes():
|
||||
results = rank_models(system, use_case=use_case or None, limit=limit, search=search or None, sort=sort, quant=quant or None, target_context=target_context, fit_only=fit_only)
|
||||
return {"system": system, "models": results}
|
||||
|
||||
@router.get("/profiles")
|
||||
def get_serve_profiles(model: str = "", host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, serve_weights_gb: float = 0.0, serve_quant: str = ""):
|
||||
"""Compute llama.cpp serve profiles (Quality/Balanced/Speed) for `model`
|
||||
against the detected hardware on `host` (or local). Returns concrete
|
||||
flags (n_gpu_layers, n_cpu_moe, cache_type, ctx) the serve UI can apply.
|
||||
|
||||
`model` is matched against the catalog by name; if it's not in the
|
||||
catalog (e.g. an ad-hoc HF repo), pass enough hints via a minimal synthetic
|
||||
entry isn't possible here, so we return [] and the UI keeps manual flags.
|
||||
"""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.models import get_models
|
||||
from services.hwfit.profiles import compute_serve_profiles
|
||||
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
if system.get("error"):
|
||||
return {"system": system, "profiles": [], "error": system["error"]}
|
||||
catalog = {m.get("name"): m for m in (get_models() or [])}
|
||||
|
||||
def _norm(s):
|
||||
# Normalize for matching: drop org/ prefix, a trailing -GGUF/-gguf
|
||||
# marker, and any quant tag, lowercase. So "DeepSeek-Coder-V2-Lite-
|
||||
# Instruct-GGUF" (a local folder name) matches catalog entry
|
||||
# "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct".
|
||||
s = (s or "").lower().strip()
|
||||
s = s.split("/")[-1] # drop org prefix
|
||||
s = re.sub(r"[-_.]?gguf$", "", s) # drop trailing gguf marker
|
||||
s = re.sub(r"[-_.](q\d[^/]*|iq\d[^/]*|fp8|bf16|f16|awq[^/]*|gptq[^/]*)$", "", s)
|
||||
return s
|
||||
|
||||
m = catalog.get(model)
|
||||
if m is None and model:
|
||||
want = _norm(model)
|
||||
for name, entry in catalog.items():
|
||||
nn = _norm(name)
|
||||
if nn and (nn == want or want.endswith(nn) or nn.endswith(want)):
|
||||
m = entry
|
||||
break
|
||||
if m is None:
|
||||
return {"system": system, "profiles": [], "error": "model not in catalog"}
|
||||
# Surface the model's trained context limit so the serve UI can clamp a
|
||||
# user-typed context down to it (asking for ctx > n_ctx_train overflows
|
||||
# and, with a quantized KV cache, can crash the GPU).
|
||||
model_ctx_max = 0
|
||||
for k in ("context_length", "max_position_embeddings", "n_ctx_train", "context"):
|
||||
v = m.get(k)
|
||||
if isinstance(v, (int, float)) and v > 0:
|
||||
model_ctx_max = int(v)
|
||||
break
|
||||
return {
|
||||
"system": system,
|
||||
"profiles": compute_serve_profiles(
|
||||
system, m,
|
||||
serve_weights_gb=(serve_weights_gb or None),
|
||||
serve_quant=(serve_quant or None),
|
||||
),
|
||||
"model_ctx_max": model_ctx_max,
|
||||
}
|
||||
|
||||
@router.get("/image-models")
|
||||
def get_image_models(sort: str = "fit", search: str = "", host: str = "", gpu_count: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
|
||||
"""Rank image generation models against detected hardware."""
|
||||
|
||||
@@ -27,7 +27,7 @@ from src.request_models import MemoryAddRequest
|
||||
from core.database import SessionLocal
|
||||
from src.llm_core import llm_call_async
|
||||
from services.memory.memory_extractor import audit_memories
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -191,8 +191,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
||||
@router.post("/extract")
|
||||
async def extract_memory(request: Request, session: str = Form(...)) -> Dict[str, List[str]]:
|
||||
"""Analyze a session's chat history and return memory suggestions."""
|
||||
if not get_current_user(request):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
require_user(request)
|
||||
try:
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
|
||||
+331
-125
@@ -1,73 +1,213 @@
|
||||
# routes/model_routes.py
|
||||
"""Routes for model and provider management."""
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import json
|
||||
import socket
|
||||
import time as _time
|
||||
import logging
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
|
||||
from core.middleware import require_admin
|
||||
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
|
||||
from src.llm_core import _detect_provider, _host_match, ANTHROPIC_MODELS
|
||||
from src.settings import load_settings as _load_settings, save_settings as _save_settings
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.endpoint_resolver import (
|
||||
normalize_base as _normalize_base,
|
||||
build_chat_url,
|
||||
build_models_url,
|
||||
build_headers,
|
||||
)
|
||||
from src.auth_helpers import _auth_disabled, owner_filter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SPEECH_ENDPOINT_SETTINGS = (
|
||||
("tts_provider", "tts_model", "tts-1", "Text to Speech"),
|
||||
("stt_provider", "stt_model", "base", "Speech to Text"),
|
||||
)
|
||||
|
||||
def _anthropic_api_root(base: str) -> str:
|
||||
"""Return Anthropic's API root without duplicating /v1."""
|
||||
base = (base or "").strip().rstrip("/")
|
||||
host = urlparse(base).hostname or ""
|
||||
if host.endswith("anthropic.com") and base.endswith("/v1"):
|
||||
return base[:-3].rstrip("/")
|
||||
return base
|
||||
_ENDPOINT_SETTING_FIELDS = {
|
||||
"default_endpoint_id": ("default_model", "Default Model"),
|
||||
"utility_endpoint_id": ("utility_model", "Utility Model"),
|
||||
"research_endpoint_id": ("research_model", "Deep Research"),
|
||||
"task_endpoint_id": ("task_model", "Background Tasks"),
|
||||
}
|
||||
|
||||
_ENDPOINT_FALLBACK_FIELDS = {
|
||||
"default_model_fallbacks": "Default Model Fallbacks",
|
||||
"utility_model_fallbacks": "Utility Model Fallbacks",
|
||||
"vision_model_fallbacks": "Vision Model Fallbacks",
|
||||
}
|
||||
|
||||
|
||||
def _ollama_api_root(base: str) -> str:
|
||||
"""Return Ollama's native API root without depending on deferred imports."""
|
||||
base = (base or "").strip().rstrip("/")
|
||||
parsed = urlparse(base)
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path.endswith("/api"):
|
||||
return base
|
||||
if host.endswith("ollama.com"):
|
||||
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
|
||||
return root.rstrip("/") + "/api"
|
||||
return base
|
||||
def _speech_settings_using_endpoint(settings: dict, ep_id: str) -> list:
|
||||
"""Return speech settings that reference a model endpoint."""
|
||||
endpoint_ref = f"endpoint:{ep_id}"
|
||||
return [
|
||||
label
|
||||
for provider_key, _, _, label in _SPEECH_ENDPOINT_SETTINGS
|
||||
if (settings.get(provider_key) or "") == endpoint_ref
|
||||
]
|
||||
|
||||
|
||||
def _models_url(base: str) -> str:
|
||||
"""Return provider-specific model-list URL for route-local probing."""
|
||||
provider = _detect_provider(base)
|
||||
host = urlparse(base).hostname or ""
|
||||
if provider == "anthropic" or host.endswith("anthropic.com"):
|
||||
return _anthropic_api_root(base) + "/v1/models"
|
||||
if provider == "ollama" or host.endswith("ollama.com"):
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
return base.rstrip("/") + "/models"
|
||||
def _clear_speech_settings_for_endpoint(settings: dict, ep_id: str) -> list:
|
||||
"""Reset speech settings that reference a model endpoint."""
|
||||
endpoint_ref = f"endpoint:{ep_id}"
|
||||
cleared = []
|
||||
for provider_key, model_key, default_model, label in _SPEECH_ENDPOINT_SETTINGS:
|
||||
if (settings.get(provider_key) or "") == endpoint_ref:
|
||||
settings[provider_key] = "disabled"
|
||||
settings[model_key] = default_model
|
||||
cleared.append(label)
|
||||
return cleared
|
||||
|
||||
|
||||
def _provider_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
"""Build provider auth headers without depending on import-time stubs."""
|
||||
if not api_key:
|
||||
return {}
|
||||
provider = _detect_provider(base)
|
||||
host = urlparse(base).hostname or ""
|
||||
if provider == "anthropic" or host.endswith("anthropic.com"):
|
||||
return {
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
def _endpoint_settings_using_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list:
|
||||
"""Return labels for settings and fallback chains that reference an endpoint."""
|
||||
affected = []
|
||||
for ep_key, (_, label) in _ENDPOINT_SETTING_FIELDS.items():
|
||||
if (settings.get(ep_key) or "") == ep_id:
|
||||
affected.append(label)
|
||||
for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items():
|
||||
chain = settings.get(fallback_key) or []
|
||||
if any(isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id for entry in chain):
|
||||
affected.append(label)
|
||||
if include_speech:
|
||||
affected.extend(_speech_settings_using_endpoint(settings, ep_id))
|
||||
return affected
|
||||
|
||||
|
||||
def _clear_endpoint_settings_for_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list:
|
||||
"""Remove an endpoint from direct settings and model fallback chains."""
|
||||
cleared = []
|
||||
for ep_key, (model_key, label) in _ENDPOINT_SETTING_FIELDS.items():
|
||||
if (settings.get(ep_key) or "") == ep_id:
|
||||
settings[ep_key] = ""
|
||||
settings[model_key] = ""
|
||||
cleared.append(label)
|
||||
for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items():
|
||||
chain = settings.get(fallback_key)
|
||||
if not isinstance(chain, list):
|
||||
continue
|
||||
kept = [
|
||||
entry for entry in chain
|
||||
if not (isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id)
|
||||
]
|
||||
if len(kept) != len(chain):
|
||||
settings[fallback_key] = kept
|
||||
cleared.append(label)
|
||||
if include_speech:
|
||||
cleared.extend(_clear_speech_settings_for_endpoint(settings, ep_id))
|
||||
return cleared
|
||||
|
||||
|
||||
def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int:
|
||||
"""Remove endpoint references from scoped or legacy-flat user preferences."""
|
||||
if not isinstance(all_prefs, dict):
|
||||
return 0
|
||||
users = all_prefs.get("_users")
|
||||
pref_sets = users.values() if isinstance(users, dict) else [all_prefs]
|
||||
cleared_users = 0
|
||||
for prefs in pref_sets:
|
||||
if isinstance(prefs, dict) and _clear_endpoint_settings_for_endpoint(prefs, ep_id):
|
||||
cleared_users += 1
|
||||
return cleared_users
|
||||
|
||||
|
||||
# Loopback hosts a user might type for a local model server (LM Studio,
|
||||
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
|
||||
# host the server actually runs on.
|
||||
_ANY_BIND_HOSTS = {"0.0.0.0", "::"}
|
||||
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1", *_ANY_BIND_HOSTS}
|
||||
|
||||
|
||||
def _docker_host_gateway_reachable() -> bool:
|
||||
"""True when we run inside a container whose host is reachable via
|
||||
``host.docker.internal`` (compose maps it to ``host-gateway``). Returns
|
||||
False on native installs and on container setups without the mapping, so
|
||||
the loopback rewrite below stays a no-op there."""
|
||||
in_container = os.path.exists("/.dockerenv")
|
||||
if not in_container:
|
||||
try:
|
||||
with open("/proc/1/cgroup", encoding="utf-8") as fh:
|
||||
in_container = any(t in fh.read() for t in ("docker", "containerd", "kubepods"))
|
||||
except OSError:
|
||||
in_container = False
|
||||
if not in_container:
|
||||
return False
|
||||
try:
|
||||
socket.getaddrinfo("host.docker.internal", None)
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
def _container_loopback_reachable(base_url: str, timeout: float = 0.2) -> bool:
|
||||
"""True when the requested loopback host:port is already reachable from
|
||||
inside the current container.
|
||||
|
||||
This distinguishes "a model server running alongside Odysseus in the same
|
||||
container" from "a model server running on the Docker host". Only the
|
||||
latter should be rewritten to host.docker.internal.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(base_url)
|
||||
except Exception:
|
||||
return False
|
||||
host = (parsed.hostname or "").lower()
|
||||
port = parsed.port
|
||||
if host not in _LOOPBACK_HOSTS or not port:
|
||||
return False
|
||||
probe_host = "::1" if host == "::1" else "127.0.0.1"
|
||||
family = socket.AF_INET6 if probe_host == "::1" else socket.AF_INET
|
||||
try:
|
||||
with socket.socket(family, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(timeout)
|
||||
sock.connect((probe_host, port))
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _rewrite_loopback_for_docker(base_url: str, *, container_local: bool = False) -> str:
|
||||
"""Rewrite a loopback model-endpoint URL to ``host.docker.internal`` when
|
||||
running in Docker. A URL like ``http://localhost:1234/v1`` (the LM Studio
|
||||
default) otherwise targets the Odysseus container itself, so the probe gets
|
||||
a connection error and the endpoint is rejected with a misleading "No
|
||||
models found for that provider/key".
|
||||
|
||||
Cookbook local serves are the opposite case: Odysseus started the model
|
||||
server inside the same container/process environment, so the saved endpoint
|
||||
must remain container-local. In that mode, normalize a bind address such as
|
||||
0.0.0.0 to a connectable loopback host, but do not jump to the Docker host.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(base_url)
|
||||
except Exception:
|
||||
return base_url
|
||||
host = (parsed.hostname or "").lower()
|
||||
if host not in _LOOPBACK_HOSTS:
|
||||
return base_url
|
||||
if container_local:
|
||||
if host in _ANY_BIND_HOSTS:
|
||||
netloc = "127.0.0.1" + (f":{parsed.port}" if parsed.port else "")
|
||||
return urlunparse(parsed._replace(netloc=netloc))
|
||||
return base_url
|
||||
if host in _ANY_BIND_HOSTS and not _docker_host_gateway_reachable():
|
||||
netloc = "127.0.0.1" + (f":{parsed.port}" if parsed.port else "")
|
||||
return urlunparse(parsed._replace(netloc=netloc))
|
||||
if _container_loopback_reachable(base_url):
|
||||
return base_url
|
||||
if not _docker_host_gateway_reachable():
|
||||
return base_url
|
||||
netloc = "host.docker.internal" + (f":{parsed.port}" if parsed.port else "")
|
||||
return urlunparse(parsed._replace(netloc=netloc))
|
||||
|
||||
|
||||
# ── Curated model lists per provider ──
|
||||
@@ -84,10 +224,13 @@ _PROVIDER_CURATED = {
|
||||
"claude-sonnet-4-5", "claude-haiku-3-5",
|
||||
],
|
||||
"zai": [
|
||||
"glm-5", "glm-4.7", "glm-4.7-flash",
|
||||
"glm-5", "glm-5.1", "glm-5v-turbo", "glm-4.7", "glm-4.7-flash",
|
||||
"glm-4.6", "glm-4.6v",
|
||||
"glm-4.5", "glm-4.5v", "glm-4.5-air", "glm-4.5-flash",
|
||||
],
|
||||
"zai-coding": [
|
||||
"glm-5.1", "glm-5v-turbo", "glm-5-turbo", "glm-4.7", "glm-4.5-air",
|
||||
],
|
||||
"deepseek": [
|
||||
"deepseek-chat", "deepseek-reasoner",
|
||||
],
|
||||
@@ -122,31 +265,40 @@ _PROVIDER_CURATED = {
|
||||
],
|
||||
}
|
||||
|
||||
# Map URL substrings → curated-list keys for providers whose _detect_provider()
|
||||
# Map hostnames → curated-list keys for providers whose _detect_provider()
|
||||
# returns a generic value (e.g. "openai") but deserve their own curated list.
|
||||
# "openrouter" is a sentinel meaning "no curation — show all models as curated".
|
||||
_URL_TO_CURATED = {
|
||||
"z.ai": "zai",
|
||||
"api.deepseek.com": "deepseek",
|
||||
"api.groq.com": "groq",
|
||||
"api.mistral.ai": "mistral",
|
||||
"api.together.xyz": "together",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
"generativelanguage.googleapis.com": "google",
|
||||
"api.x.ai": "xai",
|
||||
"openrouter.ai": "openrouter",
|
||||
"ollama.com": "ollama",
|
||||
}
|
||||
# Entries are matched by hostname equality or subdomain suffix (via _host_match),
|
||||
# so e.g. "deepseek.com" covers api.deepseek.com without matching the substring
|
||||
# inside an unrelated URL.
|
||||
_HOST_TO_CURATED = (
|
||||
("z.ai", "zai"),
|
||||
("deepseek.com", "deepseek"),
|
||||
("groq.com", "groq"),
|
||||
("mistral.ai", "mistral"),
|
||||
("together.xyz", "together"),
|
||||
("together.ai", "together"),
|
||||
("fireworks.ai", "fireworks"),
|
||||
("googleapis.com", "google"),
|
||||
("x.ai", "xai"),
|
||||
("openrouter.ai", "openrouter"),
|
||||
("ollama.com", "ollama"),
|
||||
)
|
||||
|
||||
|
||||
def _match_provider_curated(base_url: str, provider: str) -> str:
|
||||
"""Return the curated-list key for a given endpoint.
|
||||
|
||||
Checks the base URL against _URL_TO_CURATED first, then falls back
|
||||
to the raw provider string from _detect_provider().
|
||||
Checks path-based overrides first (for hosts serving multiple plans),
|
||||
then matches the base URL's hostname against known providers, and
|
||||
finally falls back to the raw provider string from _detect_provider().
|
||||
"""
|
||||
for substring, key in _URL_TO_CURATED.items():
|
||||
if substring in (base_url or ""):
|
||||
# Path-based overrides for hosts that serve multiple curated lists.
|
||||
parsed = urlparse(base_url)
|
||||
if _host_match(base_url, "z.ai") and "/api/coding" in (parsed.path or ""):
|
||||
return "zai-coding"
|
||||
for domain, key in _HOST_TO_CURATED:
|
||||
if _host_match(base_url, domain):
|
||||
return key
|
||||
return provider
|
||||
|
||||
@@ -235,16 +387,20 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1
|
||||
elif provider == "ollama":
|
||||
from src.llm_core import _build_ollama_payload
|
||||
target_url = build_chat_url(base)
|
||||
h = _provider_headers(api_key, base)
|
||||
h = build_headers(api_key, base)
|
||||
h["Content-Type"] = "application/json"
|
||||
payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools)
|
||||
else:
|
||||
target_url = build_chat_url(base)
|
||||
h = _provider_headers(api_key, base)
|
||||
h = build_headers(api_key, base)
|
||||
h["Content-Type"] = "application/json"
|
||||
from src.llm_core import _uses_max_completion_tokens
|
||||
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
||||
_max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens"
|
||||
payload = {"model": model_id, "messages": messages, _max_key: 5, "temperature": 0.0}
|
||||
payload = {"model": model_id, "messages": messages, _max_key: 5}
|
||||
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature, so a
|
||||
# probe that hardcodes one falsely reports a working endpoint as failing.
|
||||
if not _restricts_temperature(model_id):
|
||||
payload["temperature"] = 0.0
|
||||
if _test_tools:
|
||||
payload["tools"] = _test_tools
|
||||
|
||||
@@ -308,7 +464,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
if _detect_provider(base) == "anthropic":
|
||||
# Try Anthropic's /v1/models endpoint first
|
||||
url = _anthropic_api_root(base) + "/v1/models"
|
||||
url = build_models_url(base)
|
||||
headers = {"anthropic-version": "2023-06-01"}
|
||||
if api_key:
|
||||
headers["x-api-key"] = api_key
|
||||
@@ -331,8 +487,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return []
|
||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||
return list(ANTHROPIC_MODELS)
|
||||
url = _models_url(base)
|
||||
headers = _provider_headers(api_key, base)
|
||||
url = build_models_url(base)
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout)
|
||||
r.raise_for_status()
|
||||
@@ -343,6 +499,13 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
if not models:
|
||||
models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
|
||||
if models:
|
||||
# Z.AI coding plan omits some working models from /models;
|
||||
# append curated-only entries for that endpoint only.
|
||||
if _host_match(base, "z.ai") and "/api/coding" in (urlparse(base).path or ""):
|
||||
_ck = _match_provider_curated(base, None)
|
||||
for _e in _PROVIDER_CURATED.get(_ck, []):
|
||||
if _e not in set(models) and not any(m.startswith(_e) for m in models):
|
||||
models.append(_e)
|
||||
return models
|
||||
except httpx.HTTPStatusError as e:
|
||||
if api_key:
|
||||
@@ -387,7 +550,24 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
||||
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the
|
||||
# base is the host root or the native /api root (e.g. http://localhost:11434,
|
||||
# http://localhost:11434/api) because /models lives under /v1 there. Treat
|
||||
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
|
||||
# than as a definitive offline verdict — Ollama is reachable, it just
|
||||
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
|
||||
# marks an alive Ollama as offline whenever cached_models is empty (issue
|
||||
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
|
||||
# _ping_endpoint() was returning before that fallback could run.
|
||||
parsed_base = urlparse(base)
|
||||
looks_like_ollama = (
|
||||
parsed_base.port == 11434
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
url = base + "/models"
|
||||
last_error: Optional[str] = None
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout)
|
||||
if 300 <= r.status_code < 400:
|
||||
@@ -399,17 +579,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
|
||||
if r.status_code < 500:
|
||||
return {"reachable": r.status_code < 400, "status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
if r.status_code < 500 and not looks_like_ollama:
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
else:
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
|
||||
try:
|
||||
parsed = urlparse(base)
|
||||
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
|
||||
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
|
||||
if looks_like_ollama:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
root = root[: -len(suffix)].rstrip("/")
|
||||
break
|
||||
for path in ("/api/version", "/api/tags"):
|
||||
try:
|
||||
r = httpx.get(root + path, timeout=timeout)
|
||||
@@ -449,6 +633,15 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
|
||||
return "No models found for that provider/key."
|
||||
|
||||
|
||||
def _visible_models(cached_models, hidden_models):
|
||||
"""Filter cached model IDs by hidden_models. Returns list of visible IDs."""
|
||||
all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or [])
|
||||
if not hidden_models:
|
||||
return all_models
|
||||
hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or []))
|
||||
return [m for m in all_models if m not in hidden]
|
||||
|
||||
|
||||
def setup_model_routes(model_discovery):
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
@@ -625,7 +818,7 @@ def setup_model_routes(model_discovery):
|
||||
# list to unauthenticated callers.
|
||||
try:
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
if not owner and auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
|
||||
if not owner and not _auth_disabled() and auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -746,8 +939,8 @@ def setup_model_routes(model_discovery):
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = 0
|
||||
else:
|
||||
url = _models_url(base)
|
||||
headers = _provider_headers(ep.api_key, base)
|
||||
url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
t0 = _time.time()
|
||||
r = httpx.get(url, headers=headers, timeout=5)
|
||||
@@ -965,23 +1158,23 @@ def setup_model_routes(model_discovery):
|
||||
require_models: str = Form("false"),
|
||||
model_type: str = Form("llm"),
|
||||
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
||||
container_local: str = Form("false"),
|
||||
# Default `shared=true` → endpoints are visible to all users (the
|
||||
# app's historical behaviour). Admins can pass `shared=false` to
|
||||
# scope a new endpoint to their own account only.
|
||||
shared: str = Form("true"),
|
||||
):
|
||||
require_admin(request)
|
||||
base_url = base_url.strip().rstrip("/")
|
||||
# Normalize: strip trailing /models, /chat/completions, /v1/messages etc to get clean base
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
if base_url.endswith(suffix):
|
||||
base_url = base_url[:-len(suffix)].rstrip("/")
|
||||
base_url = _normalize_base(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(400, "Base URL is required")
|
||||
# Resolve hostname via Tailscale if DNS fails
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base_url = resolve_url(base_url)
|
||||
# In Docker, manually added loopback URLs usually point at a host-local
|
||||
# server. Cookbook local serves are launched inside Odysseus itself, so
|
||||
# keep those container-local when the frontend marks them as such.
|
||||
base_url = _rewrite_loopback_for_docker(base_url, container_local=_truthy(container_local))
|
||||
|
||||
# Auto-generate name from URL if not provided
|
||||
if not name.strip():
|
||||
@@ -1052,11 +1245,15 @@ def setup_model_routes(model_discovery):
|
||||
)
|
||||
db.add(ep)
|
||||
db.commit()
|
||||
# Auto-set as default chat endpoint if none configured yet
|
||||
# Auto-set as default chat endpoint if none configured yet. Seed
|
||||
# the first CHAT model (not raw model_ids[0]) so we don't pin the
|
||||
# global default to an embedding/tts/etc. entry a provider happens
|
||||
# to list first.
|
||||
settings = _load_settings()
|
||||
if not settings.get("default_endpoint_id"):
|
||||
from src.endpoint_resolver import _first_chat_model
|
||||
settings["default_endpoint_id"] = ep.id
|
||||
settings["default_model"] = model_ids[0] if model_ids else ""
|
||||
settings["default_model"] = _first_chat_model(model_ids) or ""
|
||||
_save_settings(settings)
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
@@ -1081,14 +1278,12 @@ def setup_model_routes(model_discovery):
|
||||
api_key: str = Form(""),
|
||||
):
|
||||
require_admin(request)
|
||||
base_url = base_url.strip().rstrip("/")
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
if base_url.endswith(suffix):
|
||||
base_url = base_url[:-len(suffix)].rstrip("/")
|
||||
base_url = _normalize_base(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(400, "Base URL is required")
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base_url = resolve_url(base_url)
|
||||
base_url = _rewrite_loopback_for_docker(base_url)
|
||||
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
|
||||
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
@@ -1301,9 +1496,9 @@ def setup_model_routes(model_discovery):
|
||||
chat_url = build_chat_url(base)
|
||||
if not model and getattr(ep, "cached_models", None):
|
||||
try:
|
||||
models = _json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else ep.cached_models
|
||||
if models:
|
||||
model = models[0]
|
||||
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None))
|
||||
if visible:
|
||||
model = visible[0]
|
||||
except Exception:
|
||||
pass
|
||||
return {"endpoint_id": ep.id, "endpoint_url": chat_url, "model": model}
|
||||
@@ -1337,58 +1532,63 @@ def setup_model_routes(model_discovery):
|
||||
ep.name = body["name"].strip() or ep.name
|
||||
if "model_type" in body and isinstance(body["model_type"], str):
|
||||
ep.model_type = body["model_type"].strip() or ep.model_type
|
||||
# Rotating an API key used to require DELETE+POST, which wiped
|
||||
# endpoint_url/model from every session referencing the old base
|
||||
# URL. Allow in-place updates so the admin can change the key
|
||||
# (or correct a typo'd base URL) without nuking session state.
|
||||
if "api_key" in body and isinstance(body["api_key"], str):
|
||||
_new_key = body["api_key"].strip()
|
||||
# Empty string means "clear it" (e.g. local Ollama no longer needs a key).
|
||||
ep.api_key = _new_key or None
|
||||
if "base_url" in body and isinstance(body["base_url"], str):
|
||||
_new_base = body["base_url"].strip().rstrip("/")
|
||||
for _suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
||||
if _new_base.endswith(_suffix):
|
||||
_new_base = _new_base[: -len(_suffix)].rstrip("/")
|
||||
_new_base = _normalize_base(_new_base)
|
||||
if _new_base:
|
||||
ep.base_url = _new_base
|
||||
else:
|
||||
ep.is_enabled = not ep.is_enabled
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
return {
|
||||
"id": ep.id,
|
||||
"is_enabled": ep.is_enabled,
|
||||
"supports_tools": ep.supports_tools,
|
||||
"name": ep.name,
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ── Settings fields that store an endpoint ID ──
|
||||
_EP_SETTING_FIELDS = {
|
||||
"default_endpoint_id": ("default_model", "Default Model"),
|
||||
"utility_endpoint_id": ("utility_model", "Utility Model"),
|
||||
"research_endpoint_id": ("research_model", "Deep Research"),
|
||||
"task_endpoint_id": ("task_model", "Background Tasks"),
|
||||
}
|
||||
|
||||
def _settings_using_endpoint(ep_id: str) -> list:
|
||||
"""Return human-readable labels for settings that reference this endpoint."""
|
||||
settings = _load_settings()
|
||||
affected = []
|
||||
for ep_key, (_, label) in _EP_SETTING_FIELDS.items():
|
||||
if (settings.get(ep_key) or "") == ep_id:
|
||||
affected.append(label)
|
||||
tts_prov = settings.get("tts_provider") or ""
|
||||
if tts_prov == f"endpoint:{ep_id}":
|
||||
affected.append("Text to Speech")
|
||||
return affected
|
||||
return _endpoint_settings_using_endpoint(_load_settings(), ep_id, include_speech=True)
|
||||
|
||||
def _clear_settings_for_endpoint(ep_id: str) -> list:
|
||||
"""Clear all settings that reference this endpoint. Returns list of cleared labels."""
|
||||
settings = _load_settings()
|
||||
cleared = []
|
||||
for ep_key, (model_key, label) in _EP_SETTING_FIELDS.items():
|
||||
if (settings.get(ep_key) or "") == ep_id:
|
||||
settings[ep_key] = ""
|
||||
settings[model_key] = ""
|
||||
cleared.append(label)
|
||||
tts_prov = settings.get("tts_provider") or ""
|
||||
if tts_prov == f"endpoint:{ep_id}":
|
||||
settings["tts_provider"] = "disabled"
|
||||
settings["tts_model"] = "tts-1"
|
||||
cleared.append("Text to Speech")
|
||||
cleared = _clear_endpoint_settings_for_endpoint(settings, ep_id, include_speech=True)
|
||||
if cleared:
|
||||
_save_settings(settings)
|
||||
return cleared
|
||||
|
||||
def _clear_user_prefs_for_endpoint(ep_id: str) -> int:
|
||||
"""Clear per-user endpoint selections and fallback chains."""
|
||||
try:
|
||||
from routes.prefs_routes import _load as _load_prefs, _save as _save_prefs
|
||||
all_prefs = _load_prefs()
|
||||
cleared_users = _clear_user_pref_endpoint_refs(all_prefs, ep_id)
|
||||
if cleared_users:
|
||||
_save_prefs(all_prefs)
|
||||
return cleared_users
|
||||
except Exception as e:
|
||||
logger.warning("Failed to clear user prefs for endpoint %s: %s", ep_id, e)
|
||||
return 0
|
||||
|
||||
def _session_uses_endpoint_url(session_url: str, base_url: str) -> bool:
|
||||
if not session_url or not base_url:
|
||||
return False
|
||||
@@ -1402,12 +1602,18 @@ def setup_model_routes(model_discovery):
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
def _clear_sessions_for_endpoint(db, base_url: str) -> int:
|
||||
"""Drop stored auth for sessions using an endpoint being deleted.
|
||||
|
||||
Keep the session's endpoint URL and model intact. If the admin is
|
||||
replacing an endpoint with the same URL, clearing those fields leaves
|
||||
the UI looking selected while chat requests arrive with an empty model.
|
||||
The chat-time orphan guard still clears truly dead endpoints when no
|
||||
matching enabled endpoint exists.
|
||||
"""
|
||||
cleared = 0
|
||||
rows = db.query(DbSession).filter(DbSession.endpoint_url.isnot(None)).all()
|
||||
for row in rows:
|
||||
if _session_uses_endpoint_url(row.endpoint_url or "", base_url):
|
||||
row.endpoint_url = ""
|
||||
row.model = ""
|
||||
row.headers = {}
|
||||
row.updated_at = datetime.utcnow()
|
||||
cleared += 1
|
||||
@@ -1425,8 +1631,6 @@ def setup_model_routes(model_discovery):
|
||||
try:
|
||||
for sess in list(getattr(manager, "sessions", {}).values()):
|
||||
if _session_uses_endpoint_url(getattr(sess, "endpoint_url", "") or "", base_url):
|
||||
sess.endpoint_url = ""
|
||||
sess.model = ""
|
||||
sess.headers = {}
|
||||
cleared += 1
|
||||
except Exception:
|
||||
@@ -1449,6 +1653,7 @@ def setup_model_routes(model_discovery):
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
# Clean up any settings that reference this endpoint
|
||||
cleared = _clear_settings_for_endpoint(ep_id)
|
||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||
db.delete(ep)
|
||||
@@ -1458,6 +1663,7 @@ def setup_model_routes(model_discovery):
|
||||
return {
|
||||
"deleted": True,
|
||||
"cleared_settings": cleared,
|
||||
"cleared_user_preferences": cleared_user_preferences,
|
||||
"cleared_sessions": cleared_sessions,
|
||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||
}
|
||||
|
||||
@@ -683,9 +683,8 @@ def setup_note_routes(task_scheduler=None):
|
||||
Returns {synthesis, email_sent}.
|
||||
"""
|
||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||
from src.auth_helpers import get_current_user as _gcu
|
||||
if not _gcu(request):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
from src.auth_helpers import require_user as _ru
|
||||
_ru(request)
|
||||
body = await request.json()
|
||||
note_id = body.get("note_id")
|
||||
title = (body.get("title") or "").strip()
|
||||
@@ -697,7 +696,7 @@ def setup_note_routes(task_scheduler=None):
|
||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
||||
return await dispatch_reminder(
|
||||
title=title, note_body=note_body, note_id=note_id,
|
||||
owner=_gcu(request) or "",
|
||||
owner=_owner(request) or "",
|
||||
queue_browser=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -69,9 +69,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
if not directory:
|
||||
raise HTTPException(400, "Directory path is required")
|
||||
|
||||
base_abs = os.path.abspath(PERSONAL_DIR)
|
||||
# realpath (not abspath) so a symlink inside PERSONAL_DIR that points
|
||||
# outside it is resolved before the commonpath confinement check below;
|
||||
# abspath only normalises `..` and would let such a symlink escape.
|
||||
base_abs = os.path.realpath(PERSONAL_DIR)
|
||||
candidate = directory if os.path.isabs(directory) else os.path.join(base_abs, directory)
|
||||
resolved = os.path.abspath(candidate)
|
||||
resolved = os.path.realpath(candidate)
|
||||
try:
|
||||
in_base = os.path.commonpath([resolved, base_abs]) == base_abs
|
||||
except ValueError:
|
||||
|
||||
+14
-2
@@ -12,7 +12,8 @@ def _load():
|
||||
"""Load the raw prefs file (internal use only)."""
|
||||
try:
|
||||
with open(PREFS_FILE, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
data = json.load(f)
|
||||
return data if isinstance(data, dict) else {}
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
@@ -40,7 +41,18 @@ def _save_for_user(user: Optional[str], prefs: dict):
|
||||
"""Save preferences for a specific user."""
|
||||
all_prefs = _load()
|
||||
if user is None:
|
||||
# Auth disabled — save flat
|
||||
# Auth disabled. If the store is already multi-user (e.g. auth was
|
||||
# turned off on a deployment that previously ran multi-user), writing
|
||||
# `prefs` flat would overwrite the whole `_users` map and destroy every
|
||||
# other user's preferences. Instead write back into the same (first)
|
||||
# slot _load_for_user(None) reads from, preserving the others.
|
||||
if "_users" in all_prefs:
|
||||
users = all_prefs["_users"]
|
||||
first_key = next(iter(users), None)
|
||||
if first_key is not None:
|
||||
users[first_key] = prefs
|
||||
_save(all_prefs)
|
||||
return
|
||||
_save(prefs)
|
||||
return
|
||||
if "_users" not in all_prefs:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -12,7 +13,9 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import _auth_disabled, get_current_user
|
||||
|
||||
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,9 +58,15 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
verify the session belongs to this user."""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
if _auth_disabled():
|
||||
return ""
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
return user
|
||||
|
||||
def _validate_session_id(session_id: str) -> None:
|
||||
if not _SESSION_ID_RE.fullmatch(session_id):
|
||||
raise HTTPException(400, "Invalid session ID format")
|
||||
|
||||
def _owns_in_memory(session_id: str, user: str) -> bool:
|
||||
"""Ownership check for an in-flight (in-memory) research task.
|
||||
Falls back to the on-disk JSON if the task has already finished."""
|
||||
@@ -95,6 +104,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
@router.get("/api/research/status/{session_id}")
|
||||
async def research_status(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
status = research_handler.get_status(session_id)
|
||||
@@ -105,6 +115,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
@router.post("/api/research/cancel/{session_id}")
|
||||
async def research_cancel(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
cancelled = research_handler.cancel_research(session_id)
|
||||
@@ -113,6 +124,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
@router.post("/api/research/result/{session_id}")
|
||||
async def research_result(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research result available")
|
||||
result = research_handler.get_result(session_id)
|
||||
@@ -140,6 +152,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_report(session_id: str, request: Request):
|
||||
"""Serve the visual HTML report for a completed research session."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
_assert_owns_research(session_id, user)
|
||||
logger.info(f"Visual report requested for session {session_id}")
|
||||
try:
|
||||
@@ -160,6 +173,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Mark an image URL as hidden for this research's visual report.
|
||||
Persisted to the research JSON so subsequent /report renders skip it."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
_assert_owns_research(session_id, user)
|
||||
ok = research_handler.hide_image(session_id, body.url)
|
||||
if not ok:
|
||||
@@ -170,6 +184,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_unhide_images(session_id: str, request: Request):
|
||||
"""Clear the hidden-images list for a research session."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
_assert_owns_research(session_id, user)
|
||||
ok = research_handler.unhide_all_images(session_id)
|
||||
if not ok:
|
||||
@@ -235,6 +250,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Return the full JSON for a single research result — sources,
|
||||
summary, stats — used by the Library preview panel."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
@@ -251,6 +267,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_archive(session_id: str, request: Request, archived: bool = Query(True)):
|
||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
@@ -270,6 +287,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_delete(session_id: str, request: Request):
|
||||
"""Delete a research result from disk."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
data_dir = Path("data/deep_research")
|
||||
json_path = data_dir / f"{session_id}.json"
|
||||
deleted = False
|
||||
@@ -299,7 +317,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
endpoint_id: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
max_time: int = Field(default=300, ge=60, le=1800)
|
||||
extraction_timeout: Optional[int] = Field(default=None, ge=15, le=600)
|
||||
extraction_timeout: Optional[int] = Field(default=None, ge=15, le=3600)
|
||||
extraction_concurrency: Optional[int] = Field(default=None, ge=1, le=12)
|
||||
category: Optional[str] = None
|
||||
|
||||
@@ -413,6 +431,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_stream(session_id: str, request: Request):
|
||||
"""SSE stream of research progress events."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
async def _generate():
|
||||
@@ -446,6 +465,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
async def research_result_peek(session_id: str, request: Request):
|
||||
"""Get research result without clearing it (for panel use)."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
result = research_handler.get_result(session_id)
|
||||
@@ -474,7 +494,14 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
injects a single system message containing the report and sources so
|
||||
the user can ask follow-up questions in a clean conversation.
|
||||
"""
|
||||
_require_user(request)
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
# SECURITY: gate on ownership before reading the persisted research —
|
||||
# otherwise any authenticated user could spin off (and thereby read)
|
||||
# another user's report by guessing its session ID. Mirrors every other
|
||||
# endpoint in this file (see result_peek above).
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
if session_manager is None:
|
||||
raise HTTPException(500, "session_manager not configured")
|
||||
|
||||
@@ -555,7 +582,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
|
||||
# Create new session
|
||||
new_sid = str(uuid.uuid4())
|
||||
user = get_current_user(request)
|
||||
|
||||
title_query = (query or "research").strip()
|
||||
if len(title_query) > 60:
|
||||
|
||||
+162
-48
@@ -11,45 +11,118 @@ from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.auth_helpers import get_current_user, effective_user
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str):
|
||||
"""Verify the current user owns the session. Raises 404 if not."""
|
||||
user = get_current_user(request)
|
||||
def _sanitize_export_filename(name: str) -> str:
|
||||
"""Return a conservative filename safe for Content-Disposition."""
|
||||
name = name if isinstance(name, str) else ""
|
||||
name = re.sub(r"[^A-Za-z0-9._-]", "_", name)
|
||||
return name[:128]
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||
"""Verify the current user owns the session. Raises 404 if not.
|
||||
|
||||
Ownership is checked against the DB row when one exists (unchanged). If
|
||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
||||
that lives only in ``session_manager`` because it was never persisted, or
|
||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
||||
user can still manage and delete it. Without this fallback such sessions are
|
||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
||||
|
||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
||||
that only care about persisted sessions keep their exact prior behavior.
|
||||
"""
|
||||
user = effective_user(request)
|
||||
if not user:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
||||
if not row:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
if row.owner != user:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
finally:
|
||||
db.close()
|
||||
if row is not None:
|
||||
if row.owner != user:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
return
|
||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||
if session_manager is not None:
|
||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
||||
return
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["sessions"])
|
||||
|
||||
def _pick_endpoint_for_sort():
|
||||
|
||||
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||
if not callable(is_admin):
|
||||
return False
|
||||
try:
|
||||
return bool(is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _reject_raw_endpoint_url_for_non_admin(
|
||||
request: Request,
|
||||
user: str | None,
|
||||
endpoint_id: str | None,
|
||||
endpoint_url: str | None,
|
||||
) -> None:
|
||||
"""Require registered endpoints for signed-in non-admin session changes."""
|
||||
if endpoint_id and endpoint_id.strip():
|
||||
return
|
||||
if not endpoint_url:
|
||||
return
|
||||
# Raw URLs make the server dial whatever host the request supplies. For
|
||||
# non-admin users, require a saved endpoint row so normal owner scoping and
|
||||
# endpoint validation have already happened.
|
||||
if user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered model endpoint")
|
||||
|
||||
|
||||
def _persist_session_headers(session_id: str, headers: dict | None) -> None:
|
||||
"""Persist endpoint auth headers for DB-backed session metadata."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.headers = headers or {}
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _pick_endpoint_for_sort(owner=None):
|
||||
"""Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default."""
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
# Try utility endpoint first (what the user configured for background tasks)
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
# Fall back to task endpoint
|
||||
try:
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
except Exception:
|
||||
pass
|
||||
# Fall back to default
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
return None, None, None
|
||||
@@ -63,7 +136,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
@router.get("/sessions")
|
||||
def list_sessions(request: Request):
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
# Lazy purge: incognito sessions are ephemeral by design — wipe leftovers
|
||||
# from the DB and session_manager so they vanish on the next page refresh.
|
||||
# BUT: skip sessions that were created within the last 10 minutes.
|
||||
@@ -172,11 +245,41 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
endpoint_id: str = Form(""),
|
||||
):
|
||||
skip_val = str(skip_validation).lower() == "true"
|
||||
user = get_current_user(request)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
if endpoint_id and endpoint_id.strip():
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
q = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == endpoint_id.strip(),
|
||||
ModelEndpoint.is_enabled == True,
|
||||
)
|
||||
if user:
|
||||
q = owner_filter(q, ModelEndpoint, user)
|
||||
endpoint_row = q.first()
|
||||
if not endpoint_row:
|
||||
raise HTTPException(400, "Model endpoint no longer exists")
|
||||
endpoint_base_url = endpoint_row.base_url or ""
|
||||
endpoint_api_key = endpoint_row.api_key or ""
|
||||
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||
finally:
|
||||
_db.close()
|
||||
|
||||
if not endpoint_url and not skip_val:
|
||||
raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
|
||||
|
||||
model_to_use = model
|
||||
request_api_key = api_key.strip() if api_key else ""
|
||||
effective_api_key = request_api_key or endpoint_api_key
|
||||
validation_headers = None
|
||||
if effective_api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
validation_headers = build_headers(effective_api_key, endpoint_base_url or endpoint_url)
|
||||
|
||||
if skip_val:
|
||||
# skip_validation = trust the caller and do NOT probe /v1/models.
|
||||
@@ -187,7 +290,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
elif not model_to_use:
|
||||
from src.llm_core import list_model_ids
|
||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
||||
headers=validation_headers)
|
||||
if not ids:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
# Default to the first CHAT model — endpoints often list embedding/
|
||||
@@ -202,7 +305,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
import os as _os
|
||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
||||
headers=validation_headers)
|
||||
if not avail:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
if model_to_use not in avail:
|
||||
@@ -217,7 +320,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
model_to_use = found
|
||||
|
||||
sid = str(uuid.uuid4())
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
session = session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=name or "",
|
||||
@@ -227,22 +330,15 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
owner=user,
|
||||
)
|
||||
# Set auth headers for custom API-key endpoints
|
||||
resolved_key = api_key.strip() if api_key else ""
|
||||
resolved_key = request_api_key
|
||||
resolved_base = endpoint_url
|
||||
if not resolved_key and endpoint_id and endpoint_id.strip():
|
||||
from core.database import ModelEndpoint
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first()
|
||||
if ep and ep.api_key:
|
||||
resolved_key = ep.api_key
|
||||
resolved_base = ep.base_url
|
||||
finally:
|
||||
_db.close()
|
||||
if not resolved_key and endpoint_api_key:
|
||||
resolved_key = endpoint_api_key
|
||||
resolved_base = endpoint_base_url
|
||||
if resolved_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(resolved_key, resolved_base)
|
||||
session_manager.save_sessions()
|
||||
_persist_session_headers(sid, session.headers)
|
||||
# Fire webhook (sync-safe)
|
||||
if webhook_manager:
|
||||
webhook_manager.fire_and_forget("session.created", {
|
||||
@@ -288,27 +384,38 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.close()
|
||||
# Switch model/endpoint mid-session
|
||||
if model is not None and endpoint_url is not None:
|
||||
user = get_current_user(request)
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
if endpoint_id:
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
q = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == endpoint_id,
|
||||
ModelEndpoint.is_enabled == True,
|
||||
)
|
||||
if user:
|
||||
q = owner_filter(q, ModelEndpoint, user)
|
||||
ep = q.first()
|
||||
if not ep:
|
||||
raise HTTPException(400, "Model endpoint no longer exists")
|
||||
endpoint_base_url = ep.base_url or ""
|
||||
endpoint_api_key = ep.api_key or ""
|
||||
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||
finally:
|
||||
_db.close()
|
||||
session.model = model
|
||||
session.endpoint_url = endpoint_url
|
||||
# Update auth headers from the endpoint's stored API key
|
||||
if endpoint_id:
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if ep and ep.api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(ep.api_key, ep.base_url)
|
||||
finally:
|
||||
_db.close()
|
||||
if endpoint_api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(endpoint_api_key, endpoint_base_url)
|
||||
else:
|
||||
session.headers = {}
|
||||
# Persist to DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -316,6 +423,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.endpoint_url = endpoint_url
|
||||
db_session.headers = session.headers or {}
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
finally:
|
||||
@@ -356,7 +464,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
ids = []
|
||||
for sid in ids:
|
||||
try:
|
||||
_verify_session_owner(request, sid)
|
||||
_verify_session_owner(request, sid, session_manager)
|
||||
session_manager.delete_session(sid)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -374,7 +482,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
@router.delete("/session/{sid}")
|
||||
def delete_session(request: Request, sid: str):
|
||||
"""Permanently delete a session and all its messages."""
|
||||
_verify_session_owner(request, sid)
|
||||
_verify_session_owner(request, sid, session_manager)
|
||||
try:
|
||||
# Block deletion of starred/favorited sessions
|
||||
db = SessionLocal()
|
||||
@@ -499,7 +607,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
@router.get("/sessions/archived")
|
||||
def list_archived_sessions(request: Request, search: str = "", offset: int = 0, limit: int = 20, sort: str = "recent", model: str = ""):
|
||||
"""List archived sessions for the archive browser."""
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(DbSession).filter(DbSession.archived == True)
|
||||
@@ -510,7 +618,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
safe_search = search.replace('%', r'\%').replace('_', r'\_')
|
||||
q = q.filter(DbSession.name.ilike(f"%{safe_search}%", escape='\\'))
|
||||
if model:
|
||||
q = q.filter(DbSession.model.ilike(f"%{model}"))
|
||||
# Contains match (mirrors the name filter above). The old
|
||||
# f"%{model}" was a SUFFIX-only match, so filtering by "gpt-4"
|
||||
# dropped "gpt-4o" and over-matched on shared suffixes; it also
|
||||
# left LIKE wildcards in the user value unescaped.
|
||||
safe_model = model.replace('%', r'\%').replace('_', r'\_')
|
||||
q = q.filter(DbSession.model.ilike(f"%{safe_model}%", escape='\\'))
|
||||
total = q.count()
|
||||
sort_map = {
|
||||
"recent": DbSession.updated_at.desc(),
|
||||
@@ -558,6 +671,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
safe_name = re.sub(r'[^\w\-_]', '_', session.name)
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
filename = _sanitize_export_filename(filename)
|
||||
|
||||
if fmt == "json":
|
||||
import json as _json
|
||||
@@ -635,7 +749,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
@router.post("/sessions/save")
|
||||
def sessions_save_now(request: Request):
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
session_manager.save_sessions()
|
||||
@@ -651,7 +765,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if not OPENAI_API_KEY:
|
||||
raise HTTPException(400, "Server missing OPENAI_API_KEY")
|
||||
sid = str(uuid.uuid4())
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
session = session_manager.create_session(
|
||||
session_id=sid,
|
||||
name="",
|
||||
@@ -728,7 +842,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||
if not url or not model:
|
||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||
if not url or not model:
|
||||
@@ -791,7 +905,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
users can clean junk without spending tokens.
|
||||
"""
|
||||
from src.llm_core import llm_call
|
||||
user = get_current_user(request)
|
||||
user = effective_user(request)
|
||||
user_sessions = session_manager.get_sessions_for_user(user)
|
||||
|
||||
# Delete empty and throwaway sessions before sorting
|
||||
@@ -928,9 +1042,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
# Pick an endpoint — prefer admin-configured task endpoint
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=user)
|
||||
if not url:
|
||||
url, model, headers = _pick_endpoint_for_sort()
|
||||
url, model, headers = _pick_endpoint_for_sort(owner=user)
|
||||
if not url:
|
||||
raise HTTPException(503, "No available model endpoint for auto-sort")
|
||||
|
||||
|
||||
+173
-13
@@ -118,6 +118,7 @@ def _running_in_container(dockerenv_path="/.dockerenv", cgroup_path="/proc/1/cgr
|
||||
|
||||
|
||||
DockerRowStatus = namedtuple("DockerRowStatus", ["applicable", "install_hint"])
|
||||
PackageUpdateStatus = namedtuple("PackageUpdateStatus", ["available", "note"])
|
||||
|
||||
|
||||
def _docker_row_status(*, on_remote, in_container, installed, default_hint):
|
||||
@@ -127,6 +128,24 @@ def _docker_row_status(*, on_remote, in_container, installed, default_hint):
|
||||
return DockerRowStatus(applicable=True, install_hint=default_hint)
|
||||
|
||||
|
||||
def _pip_dist_name(pkg: dict) -> str:
|
||||
"""Distribution name for importlib.metadata lookups.
|
||||
|
||||
The Cookbook package catalog carries both the import name (``name``, e.g.
|
||||
``llama_cpp``) and the pip spec (``pip``, e.g. ``llama-cpp-python[server]``).
|
||||
The distribution is NOT always the import name with underscores swapped for
|
||||
dashes — ``llama_cpp`` ships in the ``llama-cpp-python`` distribution — so
|
||||
derive it from the pip spec (stripping any ``[extras]`` and version markers)
|
||||
and fall back to the munged import name only when no pip spec is declared.
|
||||
"""
|
||||
pip = (pkg.get("pip") or "").strip()
|
||||
if pip:
|
||||
base = re.split(r"[\[<>=!~;\s]", pip, maxsplit=1)[0].strip()
|
||||
if base:
|
||||
return base
|
||||
return (pkg.get("name") or "").replace("_", "-")
|
||||
|
||||
|
||||
def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
||||
"""Return whether an optional dependency is usable by Cookbook.
|
||||
|
||||
@@ -162,7 +181,10 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
locations = module.get("locations") or []
|
||||
if name == "vllm":
|
||||
if binaries.get("vllm"):
|
||||
return f"vLLM CLI: {binaries['vllm']}"
|
||||
parts = [f"vLLM CLI: {binaries['vllm']}"]
|
||||
if dists.get("vllm"):
|
||||
parts.append(f"python package: vllm {dists['vllm']}")
|
||||
return "; ".join(parts)
|
||||
if module.get("found") and not dists.get("vllm"):
|
||||
loc = locations[0] if locations else module.get("origin") or "unknown path"
|
||||
return f"Python sees a vllm namespace at {loc}, but no vLLM CLI is on PATH."
|
||||
@@ -183,13 +205,70 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
|
||||
"""Return whether the Dependencies UI should offer a generic pip update.
|
||||
|
||||
"Installed" means Cookbook can use the dependency. It does not always mean
|
||||
the dependency is a Python package that Cookbook should update with pip:
|
||||
native llama-server can come from a package manager/source build, and a CLI
|
||||
may be on PATH without matching Python package metadata.
|
||||
"""
|
||||
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
||||
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
|
||||
|
||||
name = pkg.get("name")
|
||||
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
|
||||
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
|
||||
|
||||
if name == "llama_cpp" and binaries.get("llama-server"):
|
||||
return PackageUpdateStatus(
|
||||
False,
|
||||
"Using native llama-server on PATH; update it with its package manager or source checkout.",
|
||||
)
|
||||
if name == "vllm" and binaries.get("vllm") and not dists.get("vllm"):
|
||||
return PackageUpdateStatus(
|
||||
False,
|
||||
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
||||
)
|
||||
|
||||
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
|
||||
|
||||
|
||||
def _prepend_user_install_bins_to_path() -> None:
|
||||
"""Make pip --user console scripts visible to dependency probes.
|
||||
|
||||
Docker Cookbook installs vLLM with `python -m pip install --user`, which
|
||||
drops the `vllm` CLI in /app/.local/bin. The running app process does not
|
||||
inherit that PATH update, so `shutil.which("vllm")` can report missing even
|
||||
after a successful install.
|
||||
"""
|
||||
try:
|
||||
import site
|
||||
|
||||
candidates = [os.path.join(site.USER_BASE, "bin")]
|
||||
except Exception:
|
||||
candidates = []
|
||||
candidates.append(os.path.expanduser("~/.local/bin"))
|
||||
|
||||
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||
changed = False
|
||||
for path in reversed([p for p in candidates if p]):
|
||||
if path not in parts:
|
||||
parts.insert(0, path)
|
||||
changed = True
|
||||
if changed:
|
||||
os.environ["PATH"] = os.pathsep.join(parts)
|
||||
|
||||
|
||||
def _package_probe_script(names: list[str]) -> str:
|
||||
names_lit = ",".join(repr(n) for n in names)
|
||||
return f"""
|
||||
import importlib.util
|
||||
import importlib.metadata as md
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import site
|
||||
|
||||
names=[{names_lit}]
|
||||
dist_names={{
|
||||
@@ -204,6 +283,24 @@ bin_names={{
|
||||
'llama_cpp':['llama-server'],
|
||||
}}
|
||||
|
||||
def add_user_install_bins_to_path():
|
||||
candidates = []
|
||||
try:
|
||||
candidates.append(os.path.join(site.USER_BASE, 'bin'))
|
||||
except Exception:
|
||||
pass
|
||||
candidates.append(os.path.expanduser('~/.local/bin'))
|
||||
parts = os.environ.get('PATH', '').split(os.pathsep) if os.environ.get('PATH') else []
|
||||
changed = False
|
||||
for path in reversed([p for p in candidates if p]):
|
||||
if path not in parts:
|
||||
parts.insert(0, path)
|
||||
changed = True
|
||||
if changed:
|
||||
os.environ['PATH'] = os.pathsep.join(parts)
|
||||
|
||||
add_user_install_bins_to_path()
|
||||
|
||||
def mod_status(n):
|
||||
spec = importlib.util.find_spec(n)
|
||||
loader = getattr(spec, 'loader', None) if spec else None
|
||||
@@ -317,7 +414,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
yield f"data: {json.dumps({'exit_code': -1, 'error': PTY_UNSUPPORTED_ERROR})}\n\n"
|
||||
return
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
|
||||
# Set master to non-blocking
|
||||
@@ -469,7 +566,8 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
f"EC=${{PIPESTATUS[0]}}\n"
|
||||
f"echo ':::EXIT_CODE:::'$EC >> '{log_path}'\n"
|
||||
f"rm -f '{script_path}'\n"
|
||||
f"exit $EC\n"
|
||||
f"exit $EC\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
script_path.chmod(0o755)
|
||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
||||
@@ -504,7 +602,7 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Read new lines from log
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(errors="replace").splitlines()
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
new_lines = lines[lines_sent:]
|
||||
for line in new_lines:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
@@ -532,7 +630,7 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Session ended — do one final read
|
||||
await asyncio.sleep(0.5)
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(errors="replace").splitlines()
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
try:
|
||||
@@ -735,10 +833,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
]
|
||||
|
||||
finished = 0
|
||||
deadline = (asyncio.get_event_loop().time() + timeout) if timeout else None
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = (loop.time() + timeout) if timeout else None
|
||||
while finished < 2:
|
||||
if deadline:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
remaining = deadline - loop.time()
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError()
|
||||
wait = min(remaining, 2.0)
|
||||
@@ -791,7 +890,15 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""
|
||||
_require_admin(request)
|
||||
_reject_cross_site(request)
|
||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
|
||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
|
||||
_prepend_user_install_bins_to_path()
|
||||
importlib.invalidate_caches()
|
||||
try:
|
||||
user_site = site.getusersitepackages()
|
||||
if user_site and os.path.isdir(user_site) and user_site not in sys.path:
|
||||
sys.path.append(user_site)
|
||||
except Exception:
|
||||
pass
|
||||
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
||||
_port = str(ssh_port).strip()
|
||||
if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535):
|
||||
@@ -870,6 +977,7 @@ def setup_shell_routes() -> APIRouter:
|
||||
|
||||
for pkg in packages:
|
||||
on_remote = bool(host and pkg.get("target") == "remote")
|
||||
probe = None
|
||||
if on_remote:
|
||||
pkg["installed"] = bool(remote_status.get(pkg["name"], False))
|
||||
probe = remote_details.get(pkg["name"])
|
||||
@@ -883,19 +991,36 @@ def setup_shell_routes() -> APIRouter:
|
||||
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||
pkg["installed"] = True
|
||||
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
|
||||
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
|
||||
elif pkg["name"] == "vllm":
|
||||
_vllm_cli = shutil.which("vllm")
|
||||
pkg["installed"] = _vllm_cli is not None
|
||||
if pkg["installed"]:
|
||||
try:
|
||||
_vllm_version = importlib_metadata.version(_pip_dist_name(pkg))
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_vllm_version = None
|
||||
probe = {
|
||||
"binaries": {"vllm": _vllm_cli},
|
||||
"dists": {"vllm": _vllm_version} if _vllm_version else {},
|
||||
}
|
||||
pkg["status_note"] = _package_status_note("vllm", probe)
|
||||
else:
|
||||
try:
|
||||
importlib.import_module(pkg["name"])
|
||||
if pkg["name"] == "vllm":
|
||||
pkg["installed"] = shutil.which("vllm") is not None
|
||||
else:
|
||||
importlib_metadata.version(pkg["name"].replace("_", "-"))
|
||||
pkg["installed"] = True
|
||||
importlib_metadata.version(_pip_dist_name(pkg))
|
||||
pkg["installed"] = True
|
||||
except ImportError:
|
||||
pkg["installed"] = False
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pkg["installed"] = False
|
||||
|
||||
if pkg.get("installed"):
|
||||
update_status = _package_pip_update_status(pkg, probe)
|
||||
pkg["pip_update_available"] = update_status.available
|
||||
if update_status.note:
|
||||
pkg["update_note"] = update_status.note
|
||||
|
||||
if pkg["name"] == "docker":
|
||||
status = _docker_row_status(
|
||||
on_remote=on_remote,
|
||||
@@ -933,4 +1058,39 @@ def setup_shell_routes() -> APIRouter:
|
||||
return {"ok": True, "output": stdout.decode()[-200:]}
|
||||
return {"ok": False, "error": stderr.decode()[-300:]}
|
||||
|
||||
@router.post("/api/cookbook/rebuild-engine")
|
||||
async def rebuild_engine(request: Request):
|
||||
"""Clear the cached llama.cpp build so the next serve recompiles.
|
||||
|
||||
Admin only — this removes the Cookbook-managed ``~/bin/llama-server``
|
||||
symlink and ``~/llama.cpp/build`` directory, locally or on the selected
|
||||
remote server. It installs and downloads nothing; the next llama.cpp
|
||||
serve rebuilds from source and picks up CUDA/HIP if a toolchain is now
|
||||
present. This is the missing "force a fresh GPU build" lever for hosts
|
||||
stuck on a CPU-only llama-server.
|
||||
"""
|
||||
_require_admin(request)
|
||||
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
||||
body = await request.json()
|
||||
engine = str(body.get("engine") or "llamacpp").strip()
|
||||
if engine != "llamacpp":
|
||||
return {"ok": False, "error": f"Unsupported engine: {engine}"}
|
||||
host = str(body.get("remote_host") or "").strip()
|
||||
ssh_port = body.get("ssh_port")
|
||||
cmd = _llama_cpp_rebuild_cmd()
|
||||
try:
|
||||
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, err = await asyncio.wait_for(proc.communicate(), timeout=30)
|
||||
except asyncio.TimeoutError:
|
||||
return {"ok": False, "error": "Rebuild-engine command timed out."}
|
||||
if proc.returncode == 0:
|
||||
return {"ok": True, "output": out.decode("utf-8", errors="replace")[-400:]}
|
||||
return {"ok": False, "error": err.decode("utf-8", errors="replace")[-400:]}
|
||||
|
||||
return router
|
||||
|
||||
+41
-29
@@ -79,6 +79,8 @@ def _skill_test_task(skill: dict) -> str:
|
||||
an email); if we just hand over the 'when to use' text the agent has nothing
|
||||
to work on and stalls asking for input. So we tell it to create its own
|
||||
realistic fixture first, then apply the skill end-to-end."""
|
||||
if not isinstance(skill, dict):
|
||||
skill = {}
|
||||
ctx = (skill.get("when_to_use") or skill.get("description") or skill.get("name") or "").strip()
|
||||
return (
|
||||
"Test this skill end-to-end. FIRST, set up a small realistic scenario it "
|
||||
@@ -310,6 +312,8 @@ def _should_check_retrieval_precision(skill: dict) -> bool:
|
||||
"installation", "install", "system", "ssh", "document", "documents",
|
||||
"search", "email", "calendar", "gpu", "server", "python",
|
||||
}
|
||||
if not isinstance(skill, dict):
|
||||
return False
|
||||
tags = {str(t or "").strip().lower() for t in (skill.get("tags") or [])}
|
||||
if tags & broad:
|
||||
return True
|
||||
@@ -463,13 +467,13 @@ async def _run_skill_test_job(key, name, md, task, url, model, headers, owner, s
|
||||
if skills_manager is not None:
|
||||
v = (job["verdict"] or {}).get("verdict") or "unknown"
|
||||
try:
|
||||
skills_manager.set_audit(name, v, by_teacher=False, worker_model=model)
|
||||
skills_manager.set_audit(name, v, by_teacher=False, worker_model=model, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
conf = {"pass": 0.95, "needs_work": 0.6, "fail": 0.4}.get(v)
|
||||
if conf is not None:
|
||||
try:
|
||||
skills_manager.update_skill(name, {"confidence": conf})
|
||||
skills_manager.update_skill(name, {"confidence": conf}, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
job["status"] = "done"
|
||||
@@ -563,6 +567,7 @@ def _skill_duplicate_blocker(skills_manager, name: str, owner) -> Optional[str]:
|
||||
False,
|
||||
[keeper_name],
|
||||
f"Lower-priority duplicate of {keeper_name}",
|
||||
owner=owner,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -629,7 +634,7 @@ def _audit_finalize_status(skills_manager, name: str, owner, verdict: str,
|
||||
if generic_reason:
|
||||
necessary = False
|
||||
try:
|
||||
skills_manager.set_necessity(name, False, [], generic_reason)
|
||||
skills_manager.set_necessity(name, False, [], generic_reason, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
duplicate_of = _skill_duplicate_blocker(skills_manager, name, owner) if verdict == "pass" else None
|
||||
@@ -638,7 +643,7 @@ def _audit_finalize_status(skills_manager, name: str, owner, verdict: str,
|
||||
c = float(confidence or 0.0)
|
||||
status = "published" if (auto_publish and necessary and verdict == "pass" and c >= min_conf) else "draft"
|
||||
try:
|
||||
skills_manager.update_skill(name, {"status": status})
|
||||
skills_manager.update_skill(name, {"status": status}, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
return status
|
||||
@@ -662,7 +667,7 @@ def _apply_skill_md(skills_manager, name: str, md: str, owner) -> bool:
|
||||
"teacher_model": sk.teacher_model, "owner": sk.owner or owner,
|
||||
"when_to_use": sk.when_to_use, "procedure": sk.procedure,
|
||||
"pitfalls": sk.pitfalls, "verification": sk.verification, "body_extra": sk.body_extra,
|
||||
}))
|
||||
}, owner=owner))
|
||||
except Exception as e:
|
||||
logger.warning(f"Audit: could not save edited skill {name}: {e}")
|
||||
return False
|
||||
@@ -762,11 +767,11 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
# earns a bit less; a skill that still fails is marked low.
|
||||
def _set_conf(c):
|
||||
try:
|
||||
skills_manager.update_skill(name, {"confidence": c})
|
||||
skills_manager.update_skill(name, {"confidence": c}, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
md = skills_manager.read_skill_md(name)
|
||||
md = skills_manager.read_skill_md(name, owner=owner)
|
||||
if not md:
|
||||
log(f"{name}: no source — skipped")
|
||||
return {"skill": name, "result": "skipped"}
|
||||
@@ -788,7 +793,8 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
nec = await _eval_skill_necessity(md, others, url, model, headers)
|
||||
if nec is not None:
|
||||
skills_manager.set_necessity(name, nec.get("necessary", True),
|
||||
nec.get("redundant_with"), nec.get("reason"))
|
||||
nec.get("redundant_with"), nec.get("reason"),
|
||||
owner=owner)
|
||||
if not nec.get("necessary", True):
|
||||
log(f"{name}: possibly unnecessary — {nec.get('reason', '')[:80]}")
|
||||
except Exception as e:
|
||||
@@ -799,12 +805,12 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
if generic_reason or duplicate_of or (isinstance(nec, dict) and nec.get("necessary") is False):
|
||||
reason = generic_reason or (f"Lower-priority duplicate of {duplicate_of}" if duplicate_of else str((nec or {}).get("reason") or "Unnecessary skill"))
|
||||
try:
|
||||
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35})
|
||||
skills_manager.set_audit(name, "skipped", by_teacher=False, worker_model=model)
|
||||
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35}, owner=owner)
|
||||
skills_manager.set_audit(name, "skipped", by_teacher=False, worker_model=model, owner=owner)
|
||||
if duplicate_of:
|
||||
skills_manager.set_necessity(name, False, [duplicate_of], reason)
|
||||
skills_manager.set_necessity(name, False, [duplicate_of], reason, owner=owner)
|
||||
else:
|
||||
skills_manager.set_necessity(name, False, [], reason)
|
||||
skills_manager.set_necessity(name, False, [], reason, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
log(f"{name}: draft — skipped functional test ({reason[:100]})")
|
||||
@@ -848,13 +854,13 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
if fixed and fixed.strip() != md.strip():
|
||||
_apply_skill_md(skills_manager, name, fixed, owner)
|
||||
_set_conf(0.95)
|
||||
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model)
|
||||
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model, owner=owner)
|
||||
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
|
||||
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.95, (refreshed or {}).get("necessity"), verdict)
|
||||
log(f"{name}: {status} — confidence 95%")
|
||||
return {"skill": name, "result": "pass", "verdict": verdict, "confidence": 0.95, "status": status}
|
||||
if v in ("unknown", "inconclusive"):
|
||||
skills_manager.set_audit(name, "inconclusive", by_teacher=False, worker_model=model)
|
||||
skills_manager.set_audit(name, "inconclusive", by_teacher=False, worker_model=model, owner=owner)
|
||||
status = _audit_finalize_status(skills_manager, name, owner, "inconclusive", skill.get("confidence") or 0.0, skill.get("necessity"))
|
||||
log(f"{name}: {status} — inconclusive")
|
||||
return {"skill": name, "result": "inconclusive", "verdict": verdict, "status": status}
|
||||
@@ -869,7 +875,7 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
log(f"{name}: retry (self) = {v}")
|
||||
if v == "pass":
|
||||
_set_conf(0.85)
|
||||
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model)
|
||||
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model, owner=owner)
|
||||
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
|
||||
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.85, (refreshed or {}).get("necessity"), verdict)
|
||||
log(f"{name}: {status} — confidence 85% after self-edit")
|
||||
@@ -893,7 +899,9 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
log(f"{name}: retry on student after teacher rewrite = {v}")
|
||||
if v == "pass":
|
||||
_set_conf(0.8)
|
||||
skills_manager.set_audit(name, "pass", by_teacher=True, worker_model=model, teacher_model=t_model)
|
||||
skills_manager.set_audit(
|
||||
name, "pass", by_teacher=True, worker_model=model, teacher_model=t_model, owner=owner
|
||||
)
|
||||
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
|
||||
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.8, (refreshed or {}).get("necessity"), verdict)
|
||||
log(f"{name}: {status} — confidence 80% after teacher rewrite")
|
||||
@@ -901,13 +909,14 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
|
||||
|
||||
# Still failing → demote to draft + low confidence + flag (do NOT delete).
|
||||
try:
|
||||
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35})
|
||||
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35}, owner=owner)
|
||||
except Exception:
|
||||
pass
|
||||
skills_manager.set_audit(
|
||||
name, v or "fail", by_teacher=teacher_ran,
|
||||
worker_model=model,
|
||||
teacher_model=(teacher[1] if teacher_ran and teacher else ""),
|
||||
owner=owner,
|
||||
)
|
||||
log(f"{name}: flagged — confidence lowered, kept as draft for manual review")
|
||||
return {"skill": name, "result": "flagged", "verdict": verdict, "confidence": 0.35}
|
||||
@@ -976,7 +985,7 @@ async def _run_audit_all_job(key, skills_manager, names, url, model, headers, te
|
||||
job.pop("task", None)
|
||||
|
||||
|
||||
def _resolve_audit_models():
|
||||
def _resolve_audit_models(owner=None):
|
||||
"""Resolve (url, model, headers, teacher) for an audit run from Settings.
|
||||
|
||||
Worker = Utility model (falling back to Default, normalized to a served
|
||||
@@ -985,7 +994,7 @@ def _resolve_audit_models():
|
||||
ValueError if no worker model.
|
||||
"""
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
raise ValueError("No model configured — set a Default or Utility model in Settings.")
|
||||
try:
|
||||
@@ -1029,7 +1038,7 @@ async def run_scheduled_skill_audit(skills_manager: SkillsManager,
|
||||
return {"status": "running", "skipped": True}
|
||||
|
||||
try:
|
||||
url, model, headers, teacher = _resolve_audit_models()
|
||||
url, model, headers, teacher = _resolve_audit_models(owner=owner)
|
||||
except ValueError as e:
|
||||
logger.info(f"Scheduled skill audit skipped — {e}")
|
||||
return {"status": "skipped", "reason": str(e)}
|
||||
@@ -1246,7 +1255,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
if not match:
|
||||
raise HTTPException(404, "Skill not found")
|
||||
_verify_owner(match, user)
|
||||
md = skills_manager.read_skill_md(match.get("name"))
|
||||
md = skills_manager.read_skill_md(match.get("name"), owner=user)
|
||||
if md is None:
|
||||
raise HTTPException(404, "Skill source unavailable (legacy entry?)")
|
||||
return {"name": match.get("name"), "markdown": md}
|
||||
@@ -1273,14 +1282,14 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
raise HTTPException(404, "Skill not found")
|
||||
_verify_owner(match, user)
|
||||
name = match.get("name")
|
||||
md = skills_manager.read_skill_md(name) or ""
|
||||
md = skills_manager.read_skill_md(name, owner=user) or ""
|
||||
|
||||
if not task:
|
||||
task = _skill_test_task(match)
|
||||
|
||||
# Prefer the configured DEFAULT (→ Utility) model — not the current chat
|
||||
# session's model. Fall back to the caller's session model only if unset.
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=user)
|
||||
if not url or not model:
|
||||
url = url or ((body.get("endpoint_url") or "").strip() or None)
|
||||
model = model or ((body.get("model") or "").strip() or None)
|
||||
@@ -1360,7 +1369,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
|
||||
# Worker model (Default, normalized) + optional teacher — shared resolver.
|
||||
try:
|
||||
url, model, headers, teacher = _resolve_audit_models()
|
||||
url, model, headers, teacher = _resolve_audit_models(owner=user)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
@@ -1437,7 +1446,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
@router.post("/{skill_id}/markdown")
|
||||
async def save_skill_markdown(request: Request, skill_id: str):
|
||||
"""Replace SKILL.md with new raw content. Parses + validates first."""
|
||||
from services.memory.skill_format import Skill, slugify
|
||||
from services.memory.skill_format import Skill
|
||||
user = _owner(request)
|
||||
body = await request.json()
|
||||
new_content = body.get("markdown")
|
||||
@@ -1452,7 +1461,10 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
sk = Skill.from_markdown(new_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(400, f"Could not parse SKILL.md: {e}")
|
||||
sk.name = slugify(sk.name or match.get("name"))
|
||||
# Never rename on save: a changed `name` in the markdown would move
|
||||
# the skill dir (update_skill) and orphan the original id, so a later
|
||||
# delete 404s (#1333). Pin to the stored name, like _apply_skill_md.
|
||||
sk.name = match.get("name")
|
||||
if not sk.owner:
|
||||
sk.owner = match.get("owner") or user
|
||||
ok = skills_manager.update_skill(match.get("name"), {
|
||||
@@ -1474,7 +1486,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
"pitfalls": sk.pitfalls,
|
||||
"verification": sk.verification,
|
||||
"body_extra": sk.body_extra,
|
||||
})
|
||||
}, owner=user)
|
||||
if not ok:
|
||||
raise HTTPException(500, "Update failed")
|
||||
# Manual markdown edits can create or substantially rewrite a draft
|
||||
@@ -1496,7 +1508,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
updates = body.dict(exclude_none=True)
|
||||
if not updates:
|
||||
return {"ok": True}
|
||||
ok = skills_manager.update_skill(match.get("name"), updates)
|
||||
ok = skills_manager.update_skill(match.get("name"), updates, owner=user)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Skill not found")
|
||||
if not match.get("audit_verdict"):
|
||||
@@ -1511,7 +1523,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
if not match:
|
||||
raise HTTPException(404, "Skill not found")
|
||||
_verify_owner(match, user)
|
||||
ok = skills_manager.delete_skill(match.get("name"))
|
||||
ok = skills_manager.delete_skill(match.get("name"), owner=user)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Skill not found")
|
||||
return {"ok": True}
|
||||
|
||||
+14
-10
@@ -8,6 +8,7 @@ from typing import List
|
||||
import logging
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.upload_handler import count_recent_uploads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,15 +25,18 @@ def setup_upload_routes(upload_handler):
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
out = []
|
||||
|
||||
# Limit concurrent uploads per IP
|
||||
ip_upload_count = sum(
|
||||
1 for f in files
|
||||
if client_ip in upload_handler.upload_rate_log and
|
||||
any(now > time.time() - 10 for now in upload_handler.upload_rate_log[client_ip][-len(files):])
|
||||
|
||||
# Limit concurrent uploads per IP. Count genuine recent upload events —
|
||||
# NOT the number of files in this batch. The previous check summed over
|
||||
# `files`, so a single multi-file request counted itself as N concurrent
|
||||
# uploads and tripped the limit (issue #1346: "attach more than one file
|
||||
# → the model doesn't even see them"). save_upload still enforces the
|
||||
# per-minute sliding-window rate limit per file.
|
||||
recent_uploads = count_recent_uploads(
|
||||
upload_handler.upload_rate_log.get(client_ip, []), time.time()
|
||||
)
|
||||
|
||||
if ip_upload_count >= upload_handler.max_concurrent_uploads:
|
||||
|
||||
if recent_uploads >= upload_handler.max_concurrent_uploads:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Maximum concurrent uploads ({upload_handler.max_concurrent_uploads}) exceeded"
|
||||
@@ -107,7 +111,7 @@ def setup_upload_routes(upload_handler):
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
|
||||
info = next((fi for fi in db.values() if fi.get("id") == file_id), None)
|
||||
if info:
|
||||
original_name = info.get("name", file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
@@ -155,7 +159,7 @@ def setup_upload_routes(upload_handler):
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
|
||||
info = next((fi for fi in db.values() if fi.get("id") == file_id), None)
|
||||
return info
|
||||
|
||||
def _vision_cache_path(file_id: str) -> str:
|
||||
|
||||
+15
-3
@@ -61,7 +61,8 @@ def _find_bw() -> str:
|
||||
def _load_config() -> dict:
|
||||
if VAULT_FILE.exists():
|
||||
try:
|
||||
return json.loads(VAULT_FILE.read_text(encoding="utf-8"))
|
||||
data = json.loads(VAULT_FILE.read_text(encoding="utf-8"))
|
||||
return data if isinstance(data, dict) else {}
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
@@ -75,11 +76,18 @@ def _save_config(cfg: dict):
|
||||
safe_chmod(str(VAULT_FILE), 0o600)
|
||||
|
||||
|
||||
async def _run_bw(args: list, session: str = None, input_text: str = None) -> tuple:
|
||||
async def _run_bw(args: list, session: str = None, input_text: str = None,
|
||||
bw_password: str = None) -> tuple:
|
||||
env = {}
|
||||
env.update(os.environ)
|
||||
if session:
|
||||
env["BW_SESSION"] = session
|
||||
# Secrets must never be passed as argv — process arguments are world-readable
|
||||
# via `ps` / `/proc/<pid>/cmdline` to any local user. Keep --passwordenv
|
||||
# support for bw commands that need it; unlock/login callers should prefer
|
||||
# stdin so the master password is not left in the child environment either.
|
||||
if bw_password is not None:
|
||||
env["BW_PASSWORD"] = bw_password
|
||||
bw_path = _find_bw()
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
@@ -175,8 +183,12 @@ def setup_vault_routes():
|
||||
async def unlock(req: VaultUnlockRequest, request: Request):
|
||||
"""Unlock the vault and save the session key."""
|
||||
require_admin(request)
|
||||
# Pass the master password on stdin, not argv. argv is visible through
|
||||
# `ps` / /proc/<pid>/cmdline; stdin also avoids leaving the secret in
|
||||
# the child process environment.
|
||||
stdout, stderr, rc = await _run_bw(
|
||||
["unlock", req.master_password, "--raw"],
|
||||
["unlock", "--raw"],
|
||||
input_text=req.master_password + "\n",
|
||||
)
|
||||
if rc != 0:
|
||||
return {"ok": False, "error": f"Unlock failed: {stderr[:300]}"}
|
||||
|
||||
@@ -26,6 +26,44 @@ MAX_MESSAGE_LEN = 32_000
|
||||
from core.middleware import require_admin as _require_admin
|
||||
|
||||
|
||||
def _first_enabled_endpoint(db, owner):
|
||||
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus
|
||||
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint
|
||||
is per-user (core/database.py — "when non-null, the model picker only shows
|
||||
the endpoint to that user"), and the sync-chat fallback uses the row's
|
||||
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token
|
||||
(e.g. a paired mobile device) fall back onto ANOTHER user's private
|
||||
endpoint and silently spend that owner's API key / quota — and reach
|
||||
whatever internal base_url they configured. Mirrors the owner_filter scoping
|
||||
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
|
||||
no-op (single-user / legacy mode), preserving the original behaviour.
|
||||
"""
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
return q.first()
|
||||
|
||||
|
||||
def _caller_owns_session(sess_owner, caller) -> bool:
|
||||
"""Strict session-ownership gate for the token-authenticated sync-chat
|
||||
endpoint (`POST /api/v1/chat`).
|
||||
|
||||
Mirrors ``_verify_session_owner`` in session_routes.py and the null-owner
|
||||
gates in notes/calendar/gallery: a caller may resume a session ONLY when
|
||||
its owner matches them exactly. A null/empty session owner (legacy or
|
||||
migrated rows) is deliberately NOT resumable by an arbitrary token — the
|
||||
old ``sess_owner and sess_owner != caller`` form skipped the check whenever
|
||||
``sess_owner`` was falsy, so any chat-scoped token (e.g. a paired mobile
|
||||
device) could resume such a session, inject a message, and read back its
|
||||
history and reuse the owner's endpoint credentials. Fail closed: an
|
||||
unresolvable caller also returns False.
|
||||
"""
|
||||
if not caller:
|
||||
return False
|
||||
return sess_owner == caller
|
||||
|
||||
|
||||
def setup_webhook_routes(
|
||||
webhook_manager: WebhookManager,
|
||||
auth_manager,
|
||||
@@ -159,6 +197,7 @@ def setup_webhook_routes(
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"ollama": "https://ollama.com/api",
|
||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||
"venice": "https://api.venice.ai/api/v1",
|
||||
}
|
||||
|
||||
# Model prefix → provider mapping for auto-detection
|
||||
@@ -203,7 +242,6 @@ def setup_webhook_routes(
|
||||
|
||||
from core.models import ChatMessage
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import ModelEndpoint
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
||||
|
||||
message = body.message.strip()
|
||||
@@ -228,8 +266,11 @@ def setup_webhook_routes(
|
||||
_tok_user = token_owner or getattr(request.state, "user", None) or _gcu(request)
|
||||
except Exception:
|
||||
_tok_user = None
|
||||
# Strict ownership (see _caller_owns_session): fail closed so a
|
||||
# null-owner / cross-owner session can't be resumed by an arbitrary
|
||||
# chat-scoped token.
|
||||
_sess_owner = getattr(sess, "owner", None)
|
||||
if _tok_user and _sess_owner and _sess_owner != _tok_user:
|
||||
if not _caller_owns_session(_sess_owner, _tok_user):
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
# --- Case 2: Direct API key + model (no pre-configured endpoint needed) ---
|
||||
@@ -265,7 +306,9 @@ def setup_webhook_routes(
|
||||
if not sess:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
# Owner-scoped: only THIS token owner's endpoints + legacy
|
||||
# shared rows, never another user's private endpoint/api_key.
|
||||
ep = _first_enabled_endpoint(db, token_owner)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user