Merge remote-tracking branch 'origin/main' into visual-pr-playground

# Conflicts:
#	routes/cookbook_routes.py
#	routes/hwfit_routes.py
#	services/hwfit/fit.py
#	services/hwfit/models.py
#	static/js/cookbook-diagnosis.js
#	static/js/cookbook-hwfit.js
#	static/js/cookbook.js
#	static/js/cookbookRunning.js
This commit is contained in:
pewdiepie-archdaemon
2026-06-03 16:49:10 +09:00
569 changed files with 35252 additions and 3489 deletions
+3 -1
View File
@@ -27,6 +27,7 @@ from core.database import (
Document,
DocumentVersion,
GalleryImage,
GalleryAlbum,
CalendarEvent,
CalendarCal,
)
@@ -145,8 +146,9 @@ def setup_admin_wipe_routes(session_manager):
return {"status": "deleted", "kind": kind, "count": count}
if kind == "gallery":
count = db.query(GalleryImage).count()
count = db.query(GalleryImage).count() + db.query(GalleryAlbum).count()
db.query(GalleryImage).delete()
db.query(GalleryAlbum).delete()
db.commit()
# Also drop the upload dir so disk doesn't keep orphans.
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
+31 -6
View File
@@ -67,6 +67,8 @@ class DeleteUserRequest(BaseModel):
class RenameUserRequest(BaseModel):
    """Request body for the user-rename endpoint."""

    # Username carried by the rename request.
    # NOTE(review): presumably the *new* name for the account — confirm against the handler.
    username: str
class SetOpenRegistrationRequest(BaseModel):
    """Request body for PUT /open-signup: the explicit target state for open registration."""

    # Desired state; the handler assigns this value to auth_manager.signup_enabled as-is.
    enabled: bool
SESSION_COOKIE = "odysseus_session"
@@ -295,6 +297,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
# owner-scoped DB rows before changing auth so the account keeps
# access to its sessions, docs, email accounts, tasks, etc.
try:
from sqlalchemy import func
from core.database import Base, SessionLocal
db = SessionLocal()
try:
@@ -304,7 +307,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
continue
(
db.query(model)
.filter(model.owner == old_username)
.filter(func.lower(model.owner) == old_username)
.update({"owner": new_username}, synchronize_session=False)
)
db.commit()
@@ -322,9 +325,15 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
from routes.prefs_routes import _load as _load_prefs, _save as _save_prefs
prefs = _load_prefs()
users = prefs.get("_users") if isinstance(prefs, dict) else None
if isinstance(users, dict) and old_username in users and new_username not in users:
users[new_username] = users.pop(old_username)
_save_prefs(prefs)
if isinstance(users, dict):
prefs_key = next(
(k for k in users if str(k).strip().lower() == old_username),
None,
)
new_taken = any(str(k).strip().lower() == new_username for k in users)
if prefs_key is not None and not new_taken:
users[new_username] = users.pop(prefs_key)
_save_prefs(prefs)
except Exception as e:
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
@@ -333,15 +342,31 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
raise HTTPException(400, "Cannot rename user")
return {"ok": True, "username": new_username, "renamed_self": old_username == user}
@router.post("/signup-toggle", deprecated=True)
async def toggle_signup(request: Request):
    """
    Toggle open registration on/off. Admin only.

    DEPRECATED: This endpoint uses toggle semantics which can lead to unsafe state changes.
    Use PUT /open-signup instead.
    This endpoint is kept for backward compatibility and may be removed in future versions.

    Raises:
        HTTPException: 403 when the caller is not an authenticated admin.
    """
    user = _get_current_user(request)
    if not user or not auth_manager.is_admin(user):
        raise HTTPException(403, "Admin only")
    # Flip relative to the current value: the outcome depends on prior state,
    # which is exactly why the explicit PUT /open-signup is preferred.
    auth_manager.signup_enabled = not auth_manager.signup_enabled
    return {"ok": True, "signup_enabled": auth_manager.signup_enabled}
@router.put("/open-signup")
async def set_signup_enabled(body: SetOpenRegistrationRequest, request: Request):
    """Set the open-signup flag to the value given in the request body. Admin only.

    The caller states the target value, so repeating the same request leaves
    the flag unchanged.

    Raises:
        HTTPException: 403 when the caller is not an authenticated admin.
    """
    caller = _get_current_user(request)
    # Short-circuit keeps is_admin() from being called without a user.
    allowed = bool(caller) and auth_manager.is_admin(caller)
    if not allowed:
        raise HTTPException(403, "Admin only")

    auth_manager.signup_enabled = body.enabled
    response = {"ok": True, "signup_enabled": auth_manager.signup_enabled}
    return response
@router.delete("/users")
async def admin_delete_user(body: DeleteUserRequest, request: Request):
user = _get_current_user(request)
+6 -1
View File
@@ -77,7 +77,12 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
# ── Memories ──
if "memories" in body and isinstance(body["memories"], list):
existing = memory_manager.load_all()
existing_texts = {e.get("text", "").strip().lower() for e in existing}
# Dedup against THIS user's own memories only. Using every tenant's
# rows (load_all) meant a memory whose text matched any other
# user's was silently skipped, so the importing user lost their own
# data. The full store is still saved back below.
existing_texts = {e.get("text", "").strip().lower()
for e in existing if e.get("owner") == user}
added = 0
for mem in body["memories"]:
if not isinstance(mem, dict) or not mem.get("text"):
+147 -27
View File
@@ -12,10 +12,27 @@ from dateutil.rrule import rrulestr, rruleset
from dateutil.rrule import DAILY, WEEKLY, MONTHLY, YEARLY
from core.database import SessionLocal, CalendarCal, CalendarEvent
from src.auth_helpers import get_current_user
from src.auth_helpers import get_current_user, require_user
logger = logging.getLogger(__name__)
def _ics_naive_dtstart(dt):
"""Naive value matching how import_ics STORES CalendarEvent.dtstart.
Timed tz-aware events are stored as UTC with tzinfo stripped, all-day
dates as midnight datetimes, naive datetimes unchanged. The ICS dedup
must compute the same value or a re-import never matches the stored row.
"""
if isinstance(dt, datetime):
if dt.tzinfo is not None:
from datetime import timezone as _tz
return dt.astimezone(_tz.utc).replace(tzinfo=None)
return dt
if isinstance(dt, date):
return datetime(dt.year, dt.month, dt.day)
return dt
# Single-user fallback identity. Used only when:
# 1. The app is configured for single-user (no auth middleware), AND
# 2. The request didn't resolve to an authenticated user.
@@ -28,16 +45,17 @@ _SINGLE_USER_MODE = _os.environ.get("ODYSSEUS_SINGLE_USER", "1") != "0"
def _require_user(request: Request) -> str:
    """Return the owner identity to use for calendar data on this request.

    Delegates to ``require_user``. As described for that helper here, it
    returns "" when auth is disabled or unconfigured (single-user) and raises
    401 only when auth is configured but the caller is unauthenticated —
    NOTE(review): confirm against src.auth_helpers.

    An empty result is mapped to FALLBACK_OWNER so calendar rows are never
    stored under an empty owner and always have a stable owner to filter on.
    """
    user = require_user(request)
    if user:
        return user
    # require_user returned "" — auth is off or unconfigured (single-user).
    return FALLBACK_OWNER
def _get_or_404_calendar(db, cal_id: str, owner: str) -> CalendarCal:
@@ -64,6 +82,24 @@ def _get_or_404_event(db, uid: str, owner: str) -> CalendarEvent:
return ev
def _ics_escape(text: str) -> str:
"""Escape a value for an iCalendar TEXT field (RFC 5545 §3.3.11).
Backslash, semicolon and comma are structural in TEXT values and must be
escaped, and newlines become a literal ``\\n``. Backslash is escaped first
so the escapes we add aren't re-escaped.
"""
return (
(text or "")
.replace("\\", "\\\\")
.replace(";", "\\;")
.replace(",", "\\,")
.replace("\r\n", "\\n")
.replace("\n", "\\n")
.replace("\r", "\\n")
)
def _resolve_base_uid(uid: str) -> str:
"""Extract the base series UID from a compound occurrence UID.
@@ -319,8 +355,8 @@ def _parse_dt(s: str) -> datetime:
return None
return h, mn
# today/tomorrow/yesterday [at] TIME
m = _re.match(r'^(today|tomorrow|tmrw|yesterday)(?:\s+at)?\s*(.*)$', lower)
# today/tonight/tomorrow/yesterday [at] TIME
m = _re.match(r'^(today|tonight|tomorrow|tmrw|yesterday)(?:\s+at)?\s*(.*)$', lower)
if m:
word, rest = m.group(1), m.group(2).strip()
base = today
@@ -434,8 +470,21 @@ def _expand_rrule(
return [d]
# Parse the rrule, applying it to the base dtstart.
rrule_str = ev.rrule
if ev.dtstart is not None and getattr(ev.dtstart, "tzinfo", None) is None:
# Events are stored with a naive (UTC) dtstart, but standard .ics
# exporters (Google/Apple/Outlook/Fastmail) write the bound as an
# absolute UTC value, e.g. UNTIL=20240105T090000Z. dateutil refuses to
# mix a tz-aware UNTIL with a naive DTSTART ("RRULE UNTIL values must be
# specified in UTC when DTSTART is timezone-aware"), so the except branch
# below would silently collapse the whole series to a single event.
# Drop the trailing Z so UNTIL matches the naive DTSTART.
import re as _re
rrule_str = _re.sub(
r"(UNTIL=\d{8}(?:T\d{6})?)Z", r"\1", rrule_str, flags=_re.IGNORECASE
)
try:
rule = rrulestr(ev.rrule, dtstart=ev.dtstart)
rule = rrulestr(rrule_str, dtstart=ev.dtstart)
except Exception as ex:
logger.warning(
"Failed to parse rrule=%r for event %s: %s", ev.rrule, ev.uid, ex
@@ -509,13 +558,20 @@ def setup_calendar_routes() -> APIRouter:
owner = _require_user(request)
from routes.prefs_routes import _load_for_user
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
caldav_password = cfg.get("password") or ""
if caldav_password:
try:
from src.secret_storage import decrypt
caldav_password = decrypt(caldav_password)
except Exception:
pass
# Surface url+username but never hand the password back to the
# client — saved-state UI shouldn't leak the credential.
return {
"url": cfg.get("url", "") or "",
"username": cfg.get("username", "") or "",
"password": "",
"has_password": bool(cfg.get("password")),
"has_password": bool(caldav_password),
"local": not bool(cfg.get("url")),
}
@@ -534,12 +590,20 @@ def setup_calendar_routes() -> APIRouter:
prefs.pop("caldav", None)
_save_for_user(owner, prefs)
return {"ok": True, "cleared": True}
cfg["url"] = body.get("url", "").strip()
from src.caldav_sync import validate_caldav_url
try:
cfg["url"] = validate_caldav_url(body.get("url", ""))
except ValueError as e:
raise HTTPException(400, str(e))
cfg["username"] = (body.get("username") or "").strip()
# Preserve the stored password when the client sends an empty
# one (edit form re-submitted without re-typing the password).
if body.get("password"):
cfg["password"] = body["password"]
from src.secret_storage import encrypt
cfg["password"] = encrypt(body["password"])
elif cfg.get("password"):
from src.secret_storage import encrypt
cfg["password"] = encrypt(cfg["password"])
prefs["caldav"] = cfg
_save_for_user(owner, prefs)
return {"ok": True}
@@ -566,9 +630,21 @@ def setup_calendar_routes() -> APIRouter:
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
url = url or (cfg.get("url") or "")
user = user or (cfg.get("username") or "")
pw = pw or (cfg.get("password") or "")
if not pw:
pw = cfg.get("password") or ""
if pw:
try:
from src.secret_storage import decrypt
pw = decrypt(pw)
except Exception:
pass
if not (url and user and pw):
return {"ok": False, "error": "Missing URL, username, or password"}
from src.caldav_sync import validate_caldav_url
try:
url = validate_caldav_url(url)
except ValueError as e:
return {"ok": False, "error": str(e)}
import httpx
propfind_body = (
'<?xml version="1.0" encoding="UTF-8"?>\n'
@@ -576,13 +652,25 @@ def setup_calendar_routes() -> APIRouter:
'</d:prop></d:propfind>'
)
try:
async with httpx.AsyncClient(timeout=8.0, follow_redirects=True) as cx:
async with httpx.AsyncClient(timeout=8.0, follow_redirects=False, trust_env=False) as cx:
r = await cx.request(
"PROPFIND", url,
auth=(user, pw),
headers={"Depth": "0", "Content-Type": "application/xml"},
content=propfind_body,
)
# If the server demands Digest (Baïkal default, SabreDAV-based
# servers, Radicale with htdigest), the Basic attempt above
# 401s. Retry once with httpx.DigestAuth so this test matches
# what the real sync does via caldav.DAVClient in
# src/caldav_sync.py (which negotiates the scheme).
if r.status_code == 401 and "digest" in r.headers.get("www-authenticate", "").lower():
r = await cx.request(
"PROPFIND", url,
auth=httpx.DigestAuth(user, pw),
headers={"Depth": "0", "Content-Type": "application/xml"},
content=propfind_body,
)
# 207 = Multi-Status — standard CalDAV success. 200 also
# acceptable. Anything else (401/403/404/5xx) means trouble.
if r.status_code in (200, 207):
@@ -593,6 +681,8 @@ def setup_calendar_routes() -> APIRouter:
return {"ok": False, "error": "Forbidden — user can't access that URL"}
if r.status_code == 404:
return {"ok": False, "error": "Not found — check the URL path"}
if 300 <= r.status_code < 400:
return {"ok": False, "error": "Redirects are not followed for CalDAV safety; use the final URL"}
return {"ok": False, "error": f"HTTP {r.status_code}"}
except httpx.ConnectError as e:
return {"ok": False, "error": f"Connection refused: {e}"[:200]}
@@ -739,6 +829,16 @@ def setup_calendar_routes() -> APIRouter:
)
db.add(ev)
db.commit()
if cal.source == "caldav":
# Push the new event to the remote so it appears on the user's
# other devices — the sync is otherwise pull-only (#800).
from src.caldav_writeback import writeback_event
await writeback_event(owner, cal.source, cal.id, {
"uid": uid, "summary": data.summary, "description": data.description,
"location": data.location, "dtstart": dtstart, "dtend": dtend,
"all_day": data.all_day, "is_utc": _is_utc and not data.all_day,
"rrule": data.rrule or "",
})
return {"ok": True, "uid": uid}
except HTTPException:
raise
@@ -785,6 +885,14 @@ def setup_calendar_routes() -> APIRouter:
if data.color is not None:
ev.color = data.color if data.color else None
db.commit()
cal = db.query(CalendarCal).filter(CalendarCal.id == ev.calendar_id).first()
if cal and cal.source == "caldav":
from src.caldav_writeback import writeback_event
await writeback_event(owner, cal.source, cal.id, {
"uid": ev.uid, "summary": ev.summary, "description": ev.description,
"location": ev.location, "dtstart": ev.dtstart, "dtend": ev.dtend,
"all_day": ev.all_day, "is_utc": ev.is_utc, "rrule": ev.rrule or "",
})
return {"ok": True}
except HTTPException:
raise
@@ -805,8 +913,15 @@ def setup_calendar_routes() -> APIRouter:
db = SessionLocal()
try:
ev = _get_or_404_event(db, base_uid, owner)
# Capture what the remote push needs BEFORE the row is gone.
_cal = db.query(CalendarCal).filter(CalendarCal.id == ev.calendar_id).first()
_is_caldav = bool(_cal and _cal.source == "caldav")
_cal_id, _ev_uid = ev.calendar_id, ev.uid
db.delete(ev)
db.commit()
if _is_caldav:
from src.caldav_writeback import writeback_event
await writeback_event(owner, "caldav", _cal_id, {"uid": _ev_uid}, delete=True)
return {"ok": True}
except HTTPException:
raise
@@ -938,7 +1053,12 @@ def setup_calendar_routes() -> APIRouter:
source_uid = str(comp.get("uid", "")) or None
if source_uid:
src_dtstart = dtstart.dt
naive_src = src_dtstart.replace(tzinfo=None) if hasattr(src_dtstart, 'tzinfo') and src_dtstart.tzinfo else src_dtstart
# Normalize to the SAME naive form import_ics stores, so a
# re-import of a tz-aware event matches the existing row.
# The old code stripped tzinfo WITHOUT converting to UTC
# (wall clock), while storage converts to UTC first, so
# every re-import of a TZID event created a duplicate.
naive_src = _ics_naive_dtstart(src_dtstart)
existing = (
db.query(CalendarEvent)
.filter(
@@ -1032,23 +1152,23 @@ def setup_calendar_routes() -> APIRouter:
"BEGIN:VCALENDAR",
"VERSION:2.0",
"PRODID:-//Odysseus//Calendar//EN",
f"X-WR-CALNAME:{cal.name}",
f"X-WR-CALNAME:{_ics_escape(cal.name)}",
]
for ev in events:
lines.append("BEGIN:VEVENT")
lines.append(f"UID:{ev.uid}")
lines.append(f"SUMMARY:{ev.summary or ''}")
lines.append(f"SUMMARY:{_ics_escape(ev.summary or '')}")
if ev.all_day:
lines.append(f"DTSTART;VALUE=DATE:{ev.dtstart.strftime('%Y%m%d')}")
lines.append(f"DTEND;VALUE=DATE:{ev.dtend.strftime('%Y%m%d')}")
else:
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}")
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}")
_dt_suffix = "Z" if getattr(ev, "is_utc", False) else ""
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}{_dt_suffix}")
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}{_dt_suffix}")
if ev.description:
desc = ev.description.replace(chr(10), '\\n')
lines.append(f"DESCRIPTION:{desc}")
lines.append(f"DESCRIPTION:{_ics_escape(ev.description)}")
if ev.location:
lines.append(f"LOCATION:{ev.location}")
lines.append(f"LOCATION:{_ics_escape(ev.location)}")
if ev.rrule:
lines.append(f"RRULE:{ev.rrule}")
lines.append("END:VEVENT")
+113 -18
View File
@@ -3,6 +3,7 @@
import asyncio
import json
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Any, Optional
@@ -11,6 +12,7 @@ from core.models import ChatMessage
from core.database import SessionLocal
from core.database import Session as DBSession, ModelEndpoint
from src.llm_core import normalize_model_id
from src.endpoint_resolver import normalize_base
from src.context_compactor import maybe_compact, trim_for_context
from src.auth_helpers import get_current_user
from src.prompt_security import untrusted_context_message
@@ -119,7 +121,7 @@ def needs_auto_name(name: str) -> bool:
if name.startswith("Chat:") or name == "Chat":
return True
# Default frontend name: "modelname HH:MM:SS AM/PM"
if re.match(r'^.+ \d{1,2}:\d{2}:\d{2}\s*(AM|PM)$', name):
if re.match(r"^.+ \d{1,2}:\d{2}:\d{2}(\s*(AM|PM))?$", name, re.IGNORECASE):
return True
return False
@@ -146,9 +148,13 @@ async def auto_name_session(session_manager, sess):
if not first_msg:
return
owner = getattr(sess, "owner", None)
t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
sess.endpoint_url, sess.model, sess.headers, owner=owner,
)
if not t_model:
logger.debug("[auto-name] No model provided, skipping")
return
# max_tokens big enough that reasoning models (Minimax M2,
# DeepSeek R1, QwQ, etc.) have headroom for <think>…</think>
@@ -306,7 +312,24 @@ def fire_message_event(request, webhook_manager, session_id: str, sess, message:
fire_event("message_sent", user)
def resolve_session_auth(sess, session_id: str):
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
if not session_url or not endpoint_base:
return False
try:
from src.endpoint_resolver import build_chat_url, normalize_base
sess_url = session_url.rstrip("/")
base = normalize_base(endpoint_base).rstrip("/")
return sess_url in {
base,
base + "/chat/completions",
build_chat_url(base).rstrip("/"),
}
except Exception:
return False
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
@@ -315,25 +338,96 @@ def resolve_session_auth(sess, session_id: str):
return
try:
from src.endpoint_resolver import build_headers
from src.endpoint_resolver import build_headers, normalize_base
db = SessionLocal()
try:
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
if domain:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
if ep and ep.api_key:
sess.headers = build_headers(ep.api_key, ep.base_url)
db.query(DBSession).filter(DBSession.id == session_id).update(
{"headers": json.dumps(sess.headers)}
)
db.commit()
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
target_url = getattr(sess, "endpoint_url", "") or ""
if not target_url:
return
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
# Missing headers usually means "recover from the saved endpoint".
# Scope that lookup to the session owner, otherwise two users
# with similar endpoint URLs can borrow each other's API key.
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
for ep in q.all():
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
continue
if not ep.api_key:
return
base = normalize_base(ep.base_url or "")
sess.headers = build_headers(ep.api_key, base)
update_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
update_q = update_q.filter(DBSession.owner == owner)
update_q.update({"headers": sess.headers})
db.commit()
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
return
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to resolve session headers: {e}")
def _match_cached_model_id(requested: str, models) -> Optional[str]:
if not requested or not models:
return None
model_ids = [str(m) for m in models if m]
if requested in model_ids:
return requested
req_base = os.path.basename(requested.rstrip("/"))
for model_id in model_ids:
if os.path.basename(model_id.rstrip("/")) == req_base:
return model_id
return None
def _normalize_model_id_from_cache(sess) -> Optional[str]:
"""Use stored endpoint model IDs before falling back to a live /models probe."""
endpoint_url = getattr(sess, "endpoint_url", "") or ""
requested = getattr(sess, "model", "") or ""
if not endpoint_url or not requested:
return None
try:
session_base = normalize_base(endpoint_url)
except Exception:
session_base = endpoint_url.rstrip("/")
if not session_base:
return None
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for ep in endpoints:
try:
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
continue
except Exception:
continue
raw_models = getattr(ep, "cached_models", None)
if not raw_models:
continue
try:
models = json.loads(raw_models) if isinstance(raw_models, str) else raw_models
except Exception:
continue
matched = _match_cached_model_id(requested, models)
if matched:
return matched
except Exception as e:
logger.debug("Cached model normalization skipped: %s", e)
finally:
db.close()
return None
async def build_chat_context(
sess,
request,
@@ -434,8 +528,9 @@ async def build_chat_context(
for transcript in preprocessed.youtube_transcripts:
preface.append(untrusted_context_message("youtube transcript", transcript))
# Normalize model ID
norm = normalize_model_id(sess.endpoint_url, sess.model)
# Normalize model ID. Prefer cached endpoint models so group chat does not
# re-hit slow local /models endpoints on every participant turn.
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
if norm:
sess.model = norm
@@ -743,7 +838,7 @@ def run_post_response_tasks(
from services.memory.memory_extractor import extract_and_store
from src.task_endpoint import resolve_task_endpoint
t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
sess.endpoint_url, sess.model, sess.headers, owner=owner,
)
asyncio.create_task(extract_and_store(
sess, memory_manager, memory_vector,
@@ -780,7 +875,7 @@ def run_post_response_tasks(
from services.memory.skill_extractor import maybe_extract_skill
from src.task_endpoint import resolve_task_endpoint
s_url, s_model, s_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
sess.endpoint_url, sess.model, sess.headers, owner=owner,
)
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
asyncio.create_task(maybe_extract_skill(
+222 -44
View File
@@ -23,10 +23,12 @@ from src.prompt_security import untrusted_context_message
from core.exceptions import SessionNotFoundError
from src.auth_helpers import get_current_user
from routes.session_routes import _verify_session_owner
from routes.document_helpers import _owner_session_filter
from core.database import SessionLocal, get_session_mode, set_session_mode
from core.database import Session as DBSession, ChatMessage as DBChatMessage
from core.database import Document as DBDocument, ModelEndpoint
from routes.research_routes import _resolve_research_endpoint
from routes.model_routes import _visible_models
from routes.chat_helpers import (
resolve_session_auth,
build_chat_context,
@@ -41,6 +43,7 @@ logger = logging.getLogger(__name__)
# Track active streams for partial-save safety net
_active_streams: Dict[str, dict] = {}
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
def _stream_set(session_id: str, **fields) -> None:
@@ -69,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
return sess in variants or sess.startswith(base + "/")
def _clear_orphaned_session_endpoint(sess) -> bool:
def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
if not getattr(sess, "endpoint_url", ""):
return False
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for ep in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
return False
@@ -96,6 +103,132 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
db.close()
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
"""Return True when a populated endpoint model cache includes ``model``.
Empty/malformed caches are treated as unknown rather than a negative match
so older image endpoints without cached models still work.
"""
raw = getattr(endpoint, "cached_models", None)
if not raw:
return True
try:
models = json.loads(raw) if isinstance(raw, str) else raw
except Exception:
return True
if not isinstance(models, list) or not models:
return True
wanted = (model or "").strip()
return wanted in {str(item).strip() for item in models}
def _is_image_generation_session(sess, owner: str | None = None) -> bool:
    """Decide whether this session should be routed to image generation.

    A model name starting with one of ``_IMAGE_MODEL_PREFIXES`` (compared
    case-insensitively) is always an image session. Otherwise the session
    qualifies only if an enabled endpoint of type ``image`` matches the
    session URL and ``_endpoint_cache_contains_model`` accepts the selected
    model. When ``owner`` is given, only that owner's endpoints are examined.
    Any error during the lookup yields False.
    """
    model = (getattr(sess, "model", "") or "").strip()
    lowered = model.lower()
    for prefix in _IMAGE_MODEL_PREFIXES:
        if lowered.startswith(prefix):
            return True

    endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
    if not endpoint_url:
        return False

    db = SessionLocal()
    try:
        query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
        if owner:
            from src.auth_helpers import owner_filter
            query = owner_filter(query, ModelEndpoint, owner)

        for endpoint in query.all():
            # Endpoints with no explicit type are treated as text ("llm").
            is_image_endpoint = (getattr(endpoint, "model_type", None) or "llm") == "image"
            if (
                is_image_endpoint
                and _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or "")
                and _endpoint_cache_contains_model(endpoint, model)
            ):
                return True
        return False
    except Exception:
        return False
    finally:
        db.close()
def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
"""Re-populate sess.model from the matching endpoint's cached models.
Covers the window between endpoint setup and the first chat send: the
picker showed a model in the dropdown but the session record never got
written (Issue #587 — UI uses the cached endpoint list, not s.model).
Without this, we'd POST the upstream with model="" and get a generic
401/503 instead of using the model the user already picked.
Returns True iff sess.model was repaired.
"""
if getattr(sess, "model", None):
return False
db = SessionLocal()
try:
# Prefer the endpoint whose base URL matches the session — we know the
# user already pointed this session at that endpoint, so its first
# cached model is the most defensible default.
ep = None
if getattr(sess, "endpoint_url", ""):
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for cand in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
ep = cand
break
if not ep:
return False
try:
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
except Exception:
cached = []
if not cached:
return False
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if not visible:
return False
model = visible[0]
if not isinstance(model, str) or not model.strip():
return False
model = model.strip()
# Persist so the next request, websocket reconnect, or page reload
# picks up the same model (we'd otherwise re-pick on every send
# and silently switch on the user if the cached order shifts).
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
if db_session:
db_session.model = model
db_session.updated_at = datetime.utcnow()
db.commit()
sess.model = model
logger.info(
"Recovered empty session model for %s — picked %r from endpoint %s",
session_id, model, ep.id,
)
return True
except Exception as e:
db.rollback()
logger.warning("Failed to recover empty session model for %s: %s", session_id, e)
return False
finally:
db.close()
def setup_chat_routes(
session_manager,
chat_handler,
@@ -130,9 +263,20 @@ def setup_chat_routes(
sess = session_manager.get_session(session)
except KeyError:
raise HTTPException(404, f"Session '{session}' not found")
if _clear_orphaned_session_endpoint(sess):
owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Empty model + live endpoint = setup race (Issue #587). Repair from
# the endpoint's cached model list before privilege checks, which
# otherwise see "" and behave inconsistently with the allowlist.
_recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip():
raise HTTPException(
400,
"No model selected for this chat. Open the model picker and choose one before sending.",
)
# Same allowed_models + daily-cap gate as chat_stream (mirror so the
# non-streaming path can't be used to bypass).
_enforce_chat_privileges(request, sess)
@@ -270,8 +414,21 @@ def setup_chat_routes(
# but BEFORE loading. Prevents cross-user session hijack.
_verify_session_owner(request, session)
sess = session_manager.get_session(session)
if _clear_orphaned_session_endpoint(sess):
owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Issue #587: picker shows a model from the endpoint cache but
# s.model never made it onto the DB row (first-send race after
# endpoint setup, or a previous endpoint delete/recreate). Pull
# the first cached model off the matching endpoint so the
# upstream isn't called with model="" (which surfaces as a
# generic 401/503).
_recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip():
raise HTTPException(
400,
"No model selected for this chat. Open the model picker and choose one before sending.",
)
except SessionNotFoundError as e:
raise HTTPException(404, str(e))
except (ValueError, ValidationError):
@@ -288,7 +445,7 @@ def setup_chat_routes(
_enforce_chat_privileges(request, sess)
# Ensure session has auth headers
resolve_session_auth(sess, session)
resolve_session_auth(sess, session, owner=get_current_user(request))
# Check for research_pending BEFORE mode persist overwrites it
do_research = str(use_research).lower() == "true"
@@ -343,18 +500,22 @@ def setup_chat_routes(
try:
if active_doc_id:
logger.info(f"[doc-inject] active_doc_id from frontend: {active_doc_id}")
active_doc = _doc_db.query(DBDocument).filter(
DBDocument.id == active_doc_id,
).first()
# Scope to the caller's documents. The session and in-memory
# fallbacks below are already owner/session-bound; this
# explicit-id path looked up by id alone, so a user could
# inject another user's document by passing its id.
_doc_q = _doc_db.query(DBDocument).filter(DBDocument.id == active_doc_id)
active_doc = _owner_session_filter(_doc_q, ctx.user).first()
if active_doc:
logger.info(f"[doc-inject] found by ID: title={active_doc.title!r}, lang={active_doc.language!r}, is_active={active_doc.is_active}, content_len={len(active_doc.current_content or '')}")
else:
logger.warning(f"[doc-inject] NOT FOUND by ID {active_doc_id}")
if not active_doc:
active_doc = _doc_db.query(DBDocument).filter(
_session_doc_q = _doc_db.query(DBDocument).filter(
DBDocument.session_id == session,
DBDocument.is_active == True
).order_by(DBDocument.updated_at.desc()).first()
)
active_doc = _owner_session_filter(_session_doc_q, ctx.user).order_by(DBDocument.updated_at.desc()).first()
if active_doc:
logger.info(f"[doc-inject] found by session fallback: title={active_doc.title!r}")
# Last resort: the document the agent itself just created/edited
@@ -368,7 +529,8 @@ def setup_chat_routes(
from src.tool_implementations import get_active_document
_mem_id = get_active_document()
if _mem_id:
cand = _doc_db.query(DBDocument).filter(DBDocument.id == _mem_id).first()
_mem_q = _doc_db.query(DBDocument).filter(DBDocument.id == _mem_id)
cand = _owner_session_filter(_mem_q, ctx.user).first()
if cand and (not cand.session_id or cand.session_id == session):
active_doc = cand
logger.info(f"[doc-inject] found by in-memory active id: title={active_doc.title!r} (session_id={cand.session_id!r})")
@@ -563,6 +725,7 @@ def setup_chat_routes(
prior_findings=_prior_findings,
prior_urls=_prior_urls,
on_complete=_on_research_done,
owner=_user,
)
_heartbeat_counter = 0
@@ -619,7 +782,7 @@ def setup_chat_routes(
# output. Resolved once per request.
try:
from src.endpoint_resolver import resolve_chat_fallback_candidates
_fallback_candidates = resolve_chat_fallback_candidates()
_fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
except Exception:
_fallback_candidates = []
@@ -632,28 +795,7 @@ def setup_chat_routes(
_model_info["character_name"] = ctx.preset.character_name
yield f'data: {json.dumps(_model_info)}\n\n'
# Detect image models and route directly to image generation
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
# Also check if the endpoint is registered as an image-type endpoint
if not _is_image_model:
try:
from src.endpoint_resolver import normalize_base as _nb
_ep_base = _nb(sess.endpoint_url)
_db = SessionLocal()
try:
_is_image_model = _db.query(ModelEndpoint).filter(
ModelEndpoint.model_type == "image",
ModelEndpoint.is_enabled == True,
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
).first() is not None
finally:
_db.close()
except Exception:
pass
if _is_image_model:
if _is_image_generation_session(sess, owner=_user):
from src.settings import get_setting
if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
@@ -664,7 +806,7 @@ def setup_chat_routes(
_user_msg = message or ""
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
yield ": heartbeat\n\n"
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session)
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
_img_output = _img_result.get("results", _img_result.get("error", ""))
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):
@@ -688,6 +830,7 @@ def setup_chat_routes(
return
elif chat_mode == "chat":
_chat_start = time.time()
_answered_by = None # set if the selected model failed and a fallback answered
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
try:
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
@@ -708,16 +851,35 @@ def setup_chat_routes(
try:
data = json.loads(chunk[6:])
if "delta" in data:
full_response += data["delta"]
_stream_set(session, partial=full_response)
# Reasoning tokens arrive flagged thinking:true.
# Forward them so the client can show a thinking
# indicator, but don't fold them into the saved
# reply (mirrors the rewrite path below).
if not data.get("thinking"):
full_response += data["delta"]
_stream_set(session, partial=full_response)
yield chunk
elif data.get("type") == "fallback":
# Selected model failed; a fallback answered.
# Forward the notice and remember the real model.
_answered_by = data.get("answered_by") or _answered_by
yield chunk
elif data.get("type") == "usage":
last_metrics = data.get("data", {})
last_metrics["model"] = sess.model
last_metrics["model"] = _answered_by or sess.model
if ctx.context_length and last_metrics.get("input_tokens"):
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
last_metrics["context_percent"] = pct
last_metrics["context_length"] = ctx.context_length
# The frontend reads `tokens_per_second`; the raw usage event
# carries the backend's true gen speed as `gen_tps` (llama.cpp
# timings). Map it through so this direct-chat path shows real
# t/s instead of "n/a" → falling back to a bare token count.
if last_metrics.get("gen_tps") and not last_metrics.get("tokens_per_second"):
last_metrics["tokens_per_second"] = last_metrics["gen_tps"]
last_metrics["tps_source"] = "backend"
# Wall-clock response time for the stats popup ("Time").
last_metrics.setdefault("response_time", round(time.time() - _chat_start, 2))
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
except json.JSONDecodeError:
yield chunk
@@ -781,6 +943,7 @@ def setup_chat_routes(
# ── Agent mode: full agent loop with tools ──
_agent_rounds = 0
_agent_tool_calls = 0
_answered_by = None # set if the selected model failed and a fallback answered
try:
from src.settings import get_setting
_tool_budget = int(get_setting("agent_max_tool_calls", 0))
@@ -805,8 +968,12 @@ def setup_chat_routes(
try:
data = json.loads(chunk[6:])
if "delta" in data:
full_response += data["delta"]
_stream_set(session, partial=full_response)
# Reasoning tokens arrive flagged thinking:true.
# Forward them for the live indicator, but keep
# them out of the saved reply (same as chat mode).
if not data.get("thinking"):
full_response += data["delta"]
_stream_set(session, partial=full_response)
yield chunk
elif data.get("type") == "web_sources":
web_sources = data.get("data", [])
@@ -821,9 +988,16 @@ def setup_chat_routes(
elif data.get("type") == "tool_start":
_agent_tool_calls += 1
yield chunk
elif data.get("type") == "fallback":
# Selected model failed; a fallback answered.
# Forward the notice and remember the real
# model so metrics reflect it, not the masked
# selected model.
_answered_by = data.get("answered_by") or _answered_by
yield chunk
elif data.get("type") == "metrics":
last_metrics = data.get("data", {})
last_metrics["model"] = sess.model
last_metrics["model"] = _answered_by or sess.model
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
except json.JSONDecodeError:
yield chunk
@@ -920,11 +1094,15 @@ def setup_chat_routes(
_verify_session_owner(request, session_id)
# A detached run can still be going even if _active_streams was popped;
# report it as active so the client knows to reconnect via /resume.
if session_id not in _active_streams:
# Read once via .get() to avoid a KeyError race between the membership
# check and the indexed read if a sibling stream's finally pops the
# entry in between (same pattern _stream_set already uses).
rec = _active_streams.get(session_id)
if rec is None:
if agent_runs.is_active(session_id):
return {"status": "streaming", "detached": True}
raise HTTPException(404, "No active stream for this session")
return _active_streams[session_id]
return rec
# ------------------------------------------------------------------ #
# POST /api/inject_context
@@ -1088,7 +1266,7 @@ def setup_chat_routes(
db_msg = (
db.query(DBChatMessage)
.filter(DBChatMessage.session_id == session_id, DBChatMessage.role == 'assistant')
.order_by(DBChatMessage.created_at.desc())
.order_by(DBChatMessage.timestamp.desc())
.first()
)
if db_msg:
+19 -12
View File
@@ -130,21 +130,28 @@ def _parse_vcards(text: str) -> List[Dict]:
contact = {"name": "", "emails": [], "phones": [], "uid": ""}
for line in block.split("\n"):
line = line.strip()
if line.startswith("FN:") or line.startswith("FN;"):
contact["name"] = _vunesc(line.split(":", 1)[1]) if ":" in line else ""
elif line.startswith("EMAIL"):
# Strip an optional RFC 6350 group prefix (e.g. "item1.EMAIL;...")
# that Apple Contacts / iCloud / many CardDAV servers emit by
# default — without this the property-name checks below miss those
# lines and silently drop the email / phone. The group token only
# precedes the property name, so it is safe to strip for matching
# and value extraction, and a no-op for non-grouped lines.
name_part = re.sub(r"^[A-Za-z0-9-]+\.", "", line, count=1)
if name_part.startswith("FN:") or name_part.startswith("FN;"):
contact["name"] = _vunesc(name_part.split(":", 1)[1]) if ":" in name_part else ""
elif name_part.startswith("EMAIL"):
# Handle EMAIL:foo@bar OR EMAIL;TYPE=...:foo@bar OR EMAIL;PREF=1:foo@bar
if ":" in line:
email_addr = _vunesc(line.split(":", 1)[1])
if ":" in name_part:
email_addr = _vunesc(name_part.split(":", 1)[1])
if email_addr and email_addr not in contact["emails"]:
contact["emails"].append(email_addr)
elif line.startswith("TEL"):
if ":" in line:
phone = _vunesc(line.split(":", 1)[1])
elif name_part.startswith("TEL"):
if ":" in name_part:
phone = _vunesc(name_part.split(":", 1)[1])
if phone and phone not in contact["phones"]:
contact["phones"].append(phone)
elif line.startswith("UID:"):
contact["uid"] = _vunesc(line[4:])
elif name_part.startswith("UID:"):
contact["uid"] = _vunesc(name_part[4:])
if contact["name"] or contact["emails"]:
contacts.append(contact)
return contacts
@@ -676,8 +683,8 @@ def setup_contacts_routes():
@router.post("/add")
async def add_contact(data: dict, _admin: str = Depends(require_admin)):
"""Add a new contact."""
name = data.get("name", "").strip()
email = data.get("email", "").strip()
name = (data.get("name") or "").strip()
email = (data.get("email") or "").strip()
if not email:
return {"success": False, "error": "Email required"}
# Check if already exists
+250 -8
View File
@@ -148,6 +148,108 @@ def _local_tooling_path_export(executable: str) -> str:
return f'export PATH="{esc}:$PATH"'
def _pip_install_no_cache(cmd: str) -> str:
    """Return ``cmd`` with ``--no-cache-dir`` added to its pip install.

    Cookbook dependency installs (vLLM, llama-cpp-python, etc.) build large
    wheels. pip's default cache lives under ``$HOME/.cache/pip`` and those
    builds can fill a small home filesystem with ``[Errno 28] No space left on
    device`` mid-build (issue #1219), leaving the dependency "installed" but
    unusable (#1459). Turning the cache off for these one-off installs keeps
    them off the home disk (the maintainer's suggested ``PIP_CACHE_DIR=``
    workaround, made the default).

    Only the first ``pip install`` occurrence is rewritten. The command is
    returned untouched when it is empty, contains no ``pip install``, or
    already mentions ``--no-cache-dir`` anywhere, so the helper is idempotent.
    """
    needle = "pip install"
    flag = "--no-cache-dir"
    needs_flag = bool(cmd) and needle in cmd and flag not in cmd
    if not needs_flag:
        return cmd
    # Splice the flag in right after the first "pip install".
    before, _, after = cmd.partition(needle)
    return f"{before}{needle} {flag}{after}"
def _pip_install_attempt(pip_cmd: str) -> str:
    """Wrap one pip install command so its real exit status is preserved.

    A plain ``pip ... 2>&1 | tail -5`` returns ``tail``'s exit code (0), which
    masks pip's failure and stops the next fallback from running. The
    generated ``bash -c`` snippet instead captures all output to a temp file,
    prints its last 5 lines (whether pip succeeded or not, so the Cookbook log
    panel shows diagnostics), removes the temp file, and exits with pip's
    original status.

    Args:
        pip_cmd: A shell command fragment such as
            ``python3 -m pip install -q 'pkg[extra]'``. It may already contain
            shell quoting.

    Returns:
        A single ``bash -c <script>`` command string.
    """
    script = (
        f'_out=$(mktemp) && {pip_cmd} >"$_out" 2>&1; _rc=$?; '
        'tail -5 "$_out"; rm -f "$_out"; exit $_rc'
    )
    # Quote the whole script instead of wrapping it in literal single quotes:
    # a pip_cmd that itself contains single quotes (e.g. a shlex-quoted
    # ``'vllm>=0.5'`` or ``'pkg[extra]'``) would otherwise close the outer
    # quoting early, so the inner shell would see ``>=0.5`` as a redirect or
    # the brackets as a glob. For quote-free commands the output is unchanged.
    return f"bash -c {shlex.quote(script)}"
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
    """Compose a two-step bash pip install chain that surfaces errors.

    The first attempt installs into the active interpreter/environment. The
    second attempt adds ``--user --break-system-packages`` and is only reached
    when the first one failed AND the interpreter is not inside a venv
    (``--user`` is invalid inside many venvs).

    Both attempts go through :func:`_pip_install_attempt`, so pip's own exit
    code is what drives the chain.

    Args:
        package: Requirement spec; passed through ``shlex.quote`` before being
            embedded (plain names such as ``huggingface_hub`` come back
            unchanged, extras specs with brackets get quoted).
        python_cmd: How pip is invoked, e.g. ``"python3 -m pip"`` or ``"pip"``.
        upgrade: When true, ``-U`` is added to both attempts.

    Returns:
        A shell snippet of the form ``<base> || { ! <venv probe> && <user>; }``.
    """
    quoted_pkg = shlex.quote(package)
    common_args = f"-q{' -U' if upgrade else ''} {quoted_pkg}"
    primary = _pip_install_attempt(f"{python_cmd} install {common_args}")
    user_site = _pip_install_attempt(
        f"{python_cmd} install --user --break-system-packages {common_args}"
    )

    # The venv probe must run under the interpreter that owns pip; hardcoding
    # python3 breaks when pip lives in a venv that only ships "python".
    if " -m pip" in python_cmd:
        interpreter = python_cmd.replace(" -m pip", "")
    else:
        # Bare "pip" pairs with "python"; "pip3" and anything else use python3.
        interpreter = {"pip": "python"}.get(python_cmd.strip(), "python3")
    in_venv_probe = (
        f'{interpreter} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
    )

    # The probe exits 0 inside a venv. Negating it means: outside a venv the
    # `&&` proceeds to the --user attempt; inside a venv the group stops with a
    # non-zero status, so the failure of the first attempt propagates instead
    # of being masked as success (the `|| { venv_check || … }` shape from #903
    # swallowed the exit code).
    return f"{primary} || {{ ! {in_venv_probe} && {user_site}; }}"
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
    """Strip pip user-install flags that a local venv install would reject.

    Cookbook dependency installs run through the model-serve task path so
    users can watch progress in the same log UI. For local POSIX runs that
    task prepends Odysseus' own interpreter directory to PATH, so when
    Odysseus itself runs from a venv ``python3`` resolves to the venv Python
    and pip rejects ``--user`` with "User site-packages are not visible in
    this virtualenv".

    Remote and non-venv installs are returned unchanged: remotes may
    intentionally use the system Python, and Docker/non-venv installs still
    need the user-site fallback.

    Args:
        cmd: The install command; may be empty or ``None``.
        local: Whether the command runs on the local host.
        in_venv: Whether the local interpreter is inside a venv.

    Returns:
        ``cmd`` untouched when the rewrite does not apply or the command
        cannot be tokenized; otherwise the command re-serialized with
        ``shlex.join`` minus ``--user`` and ``--break-system-packages``.
    """
    rewrite_applies = local and in_venv and "pip install" in (cmd or "")
    if not rewrite_applies:
        return cmd
    try:
        tokens = shlex.split(cmd)
    except ValueError:
        # Unbalanced quoting: leave the command exactly as given.
        return cmd
    user_site_flags = ("--user", "--break-system-packages")
    # NOTE(review): every token is re-quoted by shlex.join, so this presumably
    # expects a single plain pip command rather than a shell pipeline — confirm
    # against callers.
    return shlex.join(tok for tok in tokens if tok not in user_site_flags)
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
"""Build the standalone Python scanner used by /api/model/cached."""
lines = [
@@ -166,6 +268,38 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" for root, dirs, fns in os.walk(top, followlinks=False):",
" dirs[:] = [d for d in dirs if not os.path.islink(os.path.join(root, d)) and safe_path(os.path.join(root, d))]",
" yield root, dirs, fns",
"def gguf_role(name):",
" n = name.lower()",
" if n.startswith('mmproj') or 'mmproj' in n: return 'projector'",
" return 'model'",
"def gguf_quant(name):",
" m = re.search(r'(?i)(UD-)?(IQ[0-9]_[A-Z0-9_]+|Q[0-9](?:_[A-Z0-9]+)+|BF16|F16|FP16|F32|Q8_0)', name)",
" return m.group(0).upper() if m else ''",
"def collect_ggufs(base):",
" files = []",
" split_groups = {}",
" if not os.path.isdir(base) or not safe_path(base): return files",
" for root, dirs, fns in safe_walk(base):",
" for fn in sorted(fns):",
" if not fn.lower().endswith('.gguf'): continue",
" fp = os.path.join(root, fn)",
" try: size = os.path.getsize(fp)",
" except Exception: size = 0",
" try: rel = os.path.relpath(fp, base).replace(os.sep, '/')",
" except Exception: rel = fn",
" sm = re.match(r'(?i)^(.+)-(\\d+)-of-(\\d+)\\.gguf$', fn)",
" if sm:",
" prefix, part_s, total_s = sm.group(1), sm.group(2), sm.group(3)",
" key = (root, prefix, total_s)",
" g = split_groups.setdefault(key, {'name':fn,'rel_path':rel,'size_bytes':0,'role':gguf_role(fn),'quant':gguf_quant(fn),'parts':int(total_s),'split':True})",
" g['size_bytes'] += size",
" if int(part_s) == 1:",
" g.update({'name':fn,'rel_path':rel,'role':gguf_role(fn),'quant':gguf_quant(fn)})",
" continue",
" files.append({'name':fn,'rel_path':rel,'size_bytes':size,'role':gguf_role(fn),'quant':gguf_quant(fn)})",
" files.extend(split_groups.values())",
" files.sort(key=lambda f: (f.get('role') != 'model', f.get('rel_path', '')))",
" return files",
"def scan_hf(cache):",
" if not os.path.isdir(cache): return",
" for d in sorted(os.listdir(cache)):",
@@ -180,16 +314,14 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" if f.is_file(): nf += 1; sz += f.stat().st_size",
" if f.name.endswith('.incomplete'): ic = True",
" snap = os.path.join(cache, d, 'snapshots')",
" is_diffusion = False; is_gguf = False",
" is_diffusion = False; gguf_files = []",
" if os.path.isdir(snap):",
" for sd in os.listdir(snap):",
" sf = os.path.join(snap, sd)",
" if not os.path.isdir(sf): continue",
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
" try:",
" if any(x.endswith('.gguf') for x in os.listdir(sf)): is_gguf = True",
" except Exception: pass",
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':is_gguf})",
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
"def scan_dir(p):",
" if not os.path.isdir(p) or not safe_path(p): return",
" for d in sorted(os.listdir(p)):",
@@ -198,13 +330,14 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" fp = os.path.join(p, d)",
" if not os.path.isdir(fp) or os.path.islink(fp) or not safe_path(fp): continue",
" if d in seen: continue",
" is_model = False; is_gguf = False",
" is_model = False; gguf_files = []",
" for root, dirs, fns in safe_walk(fp):",
" for fn in fns:",
" if fn.endswith('.gguf'): is_gguf = True; is_model = True",
" if fn.lower().endswith('.gguf'): is_model = True",
" elif fn == 'config.json' or fn.endswith('.safetensors') or fn.endswith('.bin'): is_model = True",
" if is_model: break",
" if not is_model: continue",
" gguf_files = collect_ggufs(fp)",
" seen.add(d)",
" sz, nf = 0, 0",
" for dp, _, fns in safe_walk(fp):",
@@ -212,7 +345,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" try: nf += 1; sz += os.path.getsize(os.path.join(dp, fn))",
" except Exception: pass",
" is_diff = os.path.exists(os.path.join(fp, 'model_index.json'))",
" models.append({'repo_id':d,'size_bytes':sz,'nb_files':nf,'has_incomplete':False,'path':p,'is_local_dir':True,'is_diffusion':is_diff,'is_gguf':is_gguf})",
" models.append({'repo_id':d,'size_bytes':sz,'nb_files':nf,'has_incomplete':False,'path':p,'is_local_dir':True,'is_diffusion':is_diff,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
"def parse_size(num, unit):",
" try: n = float(num)",
" except Exception: return 0",
@@ -293,6 +426,38 @@ _SERVE_CMD_ALLOWLIST = {
# Matches a shell prelude anchored at the very start of a command string, of
# the form ``MODEL_FILE=$(...) && { ... } || { ... } && `` (trailing
# whitespace included). The ``$(...)`` part must not span a newline and
# neither brace group may contain nested braces.
_GGUF_PRELUDE_RE = re.compile(
r'^MODEL_FILE=\$\([^\n]*?\)\s*&&\s*\{[^{}]*\}\s*\|\|\s*\{[^{}]*\}\s*&&\s*'
)
_OLLAMA_HOST_ASSIGNMENT_RE = re.compile(r"(?:^|\s)OLLAMA_HOST=([^\s]+)")
_OLLAMA_BIND_RE = re.compile(r"^\[([^\]]+)\]:(\d+)$|^([^:]+):(\d+)$")
_OLLAMA_BIND_HOST_RE = re.compile(r"^[A-Za-z0-9._:-]+$")
def _ollama_bind_from_cmd(cmd: str | None, *, default_host: str = "127.0.0.1") -> tuple[str, str]:
"""Return the Ollama bind host/port requested by a serve command.
Plain local `ollama serve` defaults to loopback. Remote callers can pass a
wider default host so the resulting API is reachable by Odysseus.
"""
if not cmd:
return default_host, "11434"
match = _OLLAMA_HOST_ASSIGNMENT_RE.search(cmd)
if not match:
return default_host, "11434"
value = match.group(1).strip("'\"")
bind_match = _OLLAMA_BIND_RE.match(value)
if not bind_match:
return "127.0.0.1", "11434"
bracketed_host = bind_match.group(1)
host = bracketed_host or bind_match.group(3) or "127.0.0.1"
port = bind_match.group(2) or bind_match.group(4) or "11434"
if not _OLLAMA_BIND_HOST_RE.match(host):
return "127.0.0.1", "11434"
try:
port_num = int(port, 10)
except ValueError:
return "127.0.0.1", "11434"
if port_num < 1 or port_num > 65535:
return "127.0.0.1", "11434"
return f"[{host}]" if bracketed_host else host, port
def _check_serve_binary(seg: str) -> None:
@@ -370,6 +535,83 @@ def _append_serve_exit_code_lines(runner_lines: list[str], *, keep_shell_open: b
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="; exec "${SHELL:-/bin/bash}"')
else:
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="')
runner_lines.append('exit "$ODYSSEUS_CMD_EXIT"')
def _append_llama_cpp_linux_accel_build_lines(runner_lines: list[str]) -> None:
    """Extend ``runner_lines`` in place with the Linux llama.cpp build script.

    The emitted shell prefers ROCm/HIP when it is detected, then CUDA when
    ``nvcc`` and a CUDA runtime library are both visible, and otherwise builds
    for CPU only. Cookbook already detects AMD GPUs elsewhere, but the
    llama.cpp bootstrap used to hard-wire CUDA on Linux, so ROCm hosts
    attempted a CUDA configure and failed with "CUDA Toolkit not found"
    instead of building with HIP.

    The script references ``$NPROC``, which this function does not set.
    """
    # Put a pip-installed nvcc (from vLLM/nvidia CUDA wheels) on PATH so
    # cmake's CUDA configure can find it. This runs before the toolchain
    # choice, but the if/elif below still tests for ROCm/HIP first, so a
    # machine with both stacks uses the native HIP toolchain rather than a
    # stray nvcc wheel.
    runner_lines.extend((
        ' for _cudir in ~/.local/lib/python*/site-packages/nvidia/cu13 ~/.local/lib/python*/site-packages/nvidia/cu12 ~/.local/lib/python*/site-packages/nvidia/cuda_nvcc; do',
        ' [ -x "$_cudir/bin/nvcc" ] && export CUDA_HOME="$_cudir" && export PATH="$_cudir/bin:$PATH" && break',
        ' done',
    ))

    # Start from a clean build directory so a poisoned CMakeCache.txt from an
    # earlier failed CUDA or HIP attempt cannot leak stale settings.
    runner_lines.append(' cd ~/llama.cpp && rm -rf build')

    # Branch 1: ROCm/HIP.
    runner_lines.extend((
        ' if command -v hipconfig &>/dev/null || [ -d /opt/rocm ] || [ -n "$ROCM_PATH" ] || [ -n "$HIP_PATH" ]; then',
        ' if command -v hipconfig &>/dev/null; then',
        ' export HIPCXX="${HIPCXX:-$(hipconfig -l)/clang}"',
        ' export HIP_PATH="${HIP_PATH:-$(hipconfig -R)}"',
        ' fi',
        ' echo "[odysseus] ROCm/HIP detected — building llama-server with HIP support..."',
        ' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_HIP=ON && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server',
    ))

    # Branch 2: CUDA. nvcc alone is not enough — pip-installed CUDA wheels or
    # incomplete tooling can expose nvcc without shipping libcudart, and cmake
    # then fails mid-build with "CUDA runtime library not found". A small
    # shell helper checks for cudart explicitly so the guard stays readable.
    runner_lines.extend((
        ' elif command -v nvcc &>/dev/null; then',
        ' _odysseus_has_cudart() {',
        ' ldconfig -p 2>/dev/null | grep -q \'libcudart\\.so\' && return 0',
        ' local _cuh="${CUDA_HOME:-/usr/local/cuda}"',
        ' ls "$_cuh/lib64/libcudart.so"* &>/dev/null && return 0',
        ' ls "$_cuh/lib/libcudart.so"* &>/dev/null && return 0',
        ' ls /usr/local/cuda/lib64/libcudart.so* &>/dev/null && return 0',
        ' ls /usr/local/cuda/lib/libcudart.so* &>/dev/null && return 0',
        ' ls "${_cuh%/cuda_nvcc}/cuda_runtime/lib/libcudart.so"* &>/dev/null && return 0',
        ' return 1',
        ' }',
        ' if _odysseus_has_cudart; then',
        ' echo "[odysseus] CUDA nvcc + cudart found — building llama-server with CUDA (GPU) support..."',
        ' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=ON && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server',
        ' else',
        ' echo "[odysseus] WARNING: nvcc found but CUDA runtime (libcudart.so) is not visible — building llama-server for CPU only."',
        ' echo "[odysseus] GPU inference will not be available for this llama.cpp build."',
        ' echo "[odysseus] Ensure libcudart is installed (e.g. cuda-runtime package) and visible via ldconfig or CUDA_HOME."',
        ' cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server',
        ' fi',
    ))

    # Branch 3: no accelerator toolchain at all — CPU-only build.
    runner_lines.extend((
        ' else',
        ' echo "[odysseus] WARNING: no HIP/CUDA toolchain found — building llama-server for CPU only."',
        ' echo "[odysseus] GPU inference will not be available for this llama.cpp build."',
        ' echo "[odysseus] Install ROCm for AMD GPUs or vLLM/CUDA tooling for NVIDIA, then re-launch this serve task."',
        ' cmake -B build -DCMAKE_BUILD_TYPE=Release && cmake --build build -j"$NPROC" --target llama-server && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server',
        ' fi',
    ))
def _llama_cpp_rebuild_cmd() -> str:
    """Return the shell command that clears the Cookbook-managed llama.cpp build.

    The command ensures ``$HOME/bin`` exists, removes the cached
    ``llama-server`` link there, deletes ``$HOME/llama.cpp/build``, and echoes
    a notice. It deliberately installs and downloads nothing: per the original
    design note, the serve bootstrap only builds when ``llama-server`` is
    missing from PATH, so clearing these lets the next llama.cpp serve
    recompile from source and pick up a CUDA or HIP toolchain if one is now
    available, instead of reusing an existing CPU-only build forever.
    """
    notice = (
        "[odysseus] Cleared the cached llama.cpp build. "
        "Re-launch the serve task to rebuild llama-server from source "
        "(CUDA or HIP will be used if a toolchain is now available)."
    )
    steps = [
        'mkdir -p "$HOME/bin"',
        'rm -f "$HOME/bin/llama-server"',
        'rm -rf "$HOME/llama.cpp/build"',
        f'echo "{notice}"',
    ]
    # Chain with && so a failing step stops the ones after it.
    return " && ".join(steps)
class ModelDownloadRequest(BaseModel):
+319 -73
View File
@@ -37,7 +37,8 @@ from routes.cookbook_helpers import (
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
_append_serve_exit_code_lines, _cached_model_scan_script,
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache, _venv_safe_local_pip_install_cmd,
ModelDownloadRequest, ServeRequest,
)
@@ -148,6 +149,15 @@ def setup_cookbook_routes() -> APIRouter:
"No GPUs are visible to the serve process.",
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
),
(
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
"vLLM could not find a supported GPU (CUDA or ROCm). "
"This machine may have integrated or unsupported graphics only.",
[
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
],
),
(
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
"vLLM is not installed or not in PATH on this server.",
@@ -163,6 +173,11 @@ def setup_cookbook_routes() -> APIRouter:
"llama.cpp / llama-cpp-python dependencies are missing.",
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
),
(
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
),
(
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
"Diffusion serving requires PyTorch and diffusers.",
@@ -368,11 +383,15 @@ def setup_cookbook_routes() -> APIRouter:
encoding="utf-8",
)
argv = [os.environ.get("ComSpec", "cmd.exe"), "/c", str(script_path)]
env = os.environ.copy()
env["PYTHONUTF8"] = "1"
env["PYTHONIOENCODING"] = "utf-8"
proc = subprocess.Popen(
argv,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
stdin=subprocess.DEVNULL,
env=env,
**detached_popen_kwargs(),
)
pid_path.write_text(str(proc.pid), encoding="utf-8")
@@ -432,12 +451,12 @@ def setup_cookbook_routes() -> APIRouter:
# throughput. Retries set disable_hf_transfer to fall back to the plain,
# slower-but-reliable downloader (resumes cleanly from the .incomplete files).
# Use `python3 -m pip` not `pip` — macOS has no bare `pip` command.
lines.append("command -v hf >/dev/null 2>&1 || python3 -m pip install --user --break-system-packages -q -U huggingface_hub 2>/dev/null || python3 -m pip install -q -U huggingface_hub 2>/dev/null")
lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', upgrade=True)}")
if req.disable_hf_transfer:
lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=4")
else:
lines.append("python3 -c 'import hf_transfer' 2>/dev/null || python3 -m pip install --user --break-system-packages -q hf_transfer 2>/dev/null || python3 -m pip install -q hf_transfer 2>/dev/null")
lines.append(f"python3 -c 'import hf_transfer' 2>/dev/null || {_pip_install_fallback_chain('hf_transfer')}")
lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
@@ -531,12 +550,18 @@ def setup_cookbook_routes() -> APIRouter:
)
# Ensure pip-user scripts (e.g. hf CLI installed via --user) are on PATH
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
# Install hf CLI + hf_transfer best-effort so future runs get the fast path.
# Install hf CLI + optional hf_transfer best-effort. Retries disable
# hf_transfer because the Rust parallel path is fast but has been
# flaky near the end of very large multi-file downloads.
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
runner_lines.append("command -v hf >/dev/null 2>&1 || pip install --user --break-system-packages -q -U huggingface_hub 2>/dev/null || pip install -q -U huggingface_hub 2>/dev/null")
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null || pip install --user --break-system-packages -q hf_transfer 2>/dev/null || pip install -q hf_transfer 2>/dev/null")
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
if req.disable_hf_transfer:
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=4")
else:
runner_lines.append(f"python3 -c 'import hf_transfer' 2>/dev/null || {_pip_install_fallback_chain('hf_transfer', python_cmd='pip')}")
runner_lines.append("python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
runner_lines.append("export HF_HUB_DOWNLOAD_MAX_WORKERS=8")
# Surface whether the HF token actually reached THIS server, so a gated
# download's "not authorized" failure can be told apart from a missing
# token (the token is masked — we only print applied / not-set).
@@ -547,15 +572,19 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append(f' {hf_cmd} < /dev/null')
runner_lines.append('elif python3 -c "import huggingface_hub" 2>/dev/null; then')
runner_lines.append(' echo "hf CLI not found, using Python huggingface_hub..."')
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers=8)"')
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers={4 if req.disable_hf_transfer else 8})"')
runner_lines.append('else')
runner_lines.append(' echo "Installing huggingface-hub and dependencies..."')
runner_lines.append(' pip install --no-deps -q huggingface-hub 2>/dev/null')
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests hf_transfer 2>/dev/null')
runner_lines.append(" python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers=8)"')
if req.disable_hf_transfer:
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests 2>/dev/null')
runner_lines.append(' export HF_HUB_ENABLE_HF_TRANSFER=0')
else:
runner_lines.append(' pip install -q filelock fsspec packaging pyyaml tqdm typer httpx requests hf_transfer 2>/dev/null')
runner_lines.append(" python3 -c 'import hf_transfer' 2>/dev/null && export HF_HUB_ENABLE_HF_TRANSFER=1")
runner_lines.append(f' python3 -c "import os; from huggingface_hub import snapshot_download; snapshot_download(\'{req.repo_id}\'{_dl_pyarg}, max_workers={4 if req.disable_hf_transfer else 8})"')
runner_lines.append('fi')
runner_lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
runner_lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
runner_lines.append(f"rm -f {remote_runner}")
runner_lines.append('exec "${SHELL:-/bin/bash}"')
runner_path = TMUX_LOG_DIR / f"{session_id}_run.sh"
@@ -586,11 +615,11 @@ def setup_cookbook_routes() -> APIRouter:
# Detached path: no controlling TTY, so skip `< /dev/null`
# (handled by Popen stdin=DEVNULL) and don't keep a shell open.
lines.append(hf_cmd)
lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
else:
# < /dev/null suppresses interactive "update available? [Y/n]" prompt
lines.append(f"{hf_cmd} < /dev/null")
lines.append('if [ $? -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $?)"; fi')
lines.append('_ec=$?; if [ $_ec -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; else echo ""; echo "DOWNLOAD_FAILED (exit $_ec)"; fi')
lines.append(f"rm -f '{wrapper_script}'")
lines.append('exec "${SHELL:-/bin/bash}"')
wrapper_script.write_text("\n".join(lines) + "\n", encoding="utf-8")
@@ -672,11 +701,14 @@ def setup_cookbook_routes() -> APIRouter:
cwd=str(Path.home()),
)
else:
# LOCAL scan: run the interpreter directly. `python3` isn't a thing on
# Windows (it's `python`/`py`), and shell single-quoting of the path
# doesn't survive cmd.exe — so resolve the interpreter and exec it
# with the script path as an argv element (no shell quoting needed).
local_py = (
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
# running under) — it's guaranteed real Python on all platforms.
# Falling back to which_tool on Windows risks hitting the Microsoft
# Store stub alias for "python3"/"python", which prints
# "Python was not found; run without arguments to install from the
# Microsoft Store" and exits 9009, producing empty stdout and a
# JSON parse error. sys.executable bypasses PATH entirely.
local_py = sys.executable or (
which_tool("python3") or which_tool("python")
or which_tool("py") or "python"
)
@@ -714,6 +746,8 @@ def setup_cookbook_routes() -> APIRouter:
entry["backend"] = m.get("backend")
if m.get("is_ollama"):
entry["is_ollama"] = True
if isinstance(m.get("gguf_files"), list):
entry["gguf_files"] = m["gguf_files"]
models.append(entry)
except Exception as e:
logger.warning(f"Failed to parse cached models: {e}")
@@ -775,6 +809,80 @@ def setup_cookbook_routes() -> APIRouter:
finally:
db.close()
def _auto_register_llm_endpoint(req: ServeRequest, remote: str | None) -> str | None:
    """Register a freshly-served LLM as a model endpoint.

    This makes the model appear in the model picker without a manual /setup
    step — the text-model sibling of _auto_register_image_endpoint.

    Cookbook serve commands launch an OpenAI-compatible server (llama.cpp's
    llama-server, vLLM, SGLang, or Ollama) on a known port. We point an
    endpoint at that server's /v1; the picker auto-discovers the model id by
    probing /v1/models and dims the endpoint until the server is reachable,
    so registering immediately (before the server finishes loading) is safe.

    Args:
        req: The serve request. ``req.cmd`` is scanned for the port and the
            tool-calling flag; ``req.repo_id`` supplies the display name.
        remote: SSH target (``user@host`` or a bare host/alias) for a remote
            serve, or a falsy value for a local serve.

    Returns:
        The id of the created or updated endpoint, or ``None`` when the
        database write failed (the failure is logged, not raised).
    """
    import re
    from core.database import SessionLocal, ModelEndpoint

    cmd = req.cmd or ""

    # Port: an explicit --port wins ("--port 8000" and "--port=8000" alike).
    port_match = re.search(r"--port(?:\s+|=)(\d+)", cmd)
    if port_match:
        port = int(port_match.group(1))
    elif "ollama" in cmd:
        # Ollama has no --port flag; it binds to OLLAMA_HOST=host:port. Honor
        # a port given there (optionally quoted, with a scheme, or with a
        # bracketed IPv6 host) so the endpoint points at the port the server
        # was actually asked to use, and fall back to Ollama's default.
        ollama_match = re.search(
            r"OLLAMA_HOST=[\"']?(?:[a-z]+://)?(?:\[[^\]\s]+\]|[^\s:/\"']+):(\d+)",
            cmd,
        )
        port = int(ollama_match.group(1)) if ollama_match else 11434
    else:
        port = 8080  # llama.cpp's llama-server default — the Apple Silicon path

    # Determine host (mirrors the image path: SSH alias for remote serves).
    # Splitting on "@" also covers the no-"@" case: the whole string is kept.
    host = remote.split("@")[-1] if remote else "localhost"
    base_url = f"http://{host}:{port}/v1"

    # Tolerate a missing repo_id: the endpoint then gets a generic name
    # instead of the whole registration raising after the server launched.
    short_name = (req.repo_id or "").rsplit("/", 1)[-1]
    display_name = short_name or "Local model"

    # If the serve command opts models into OpenAI tool-calling, record it so
    # agent_loop trusts emitted tool_calls instead of the name heuristic.
    # None (rather than False) leaves any previously stored value untouched.
    supports_tools = True if "--enable-auto-tool-choice" in cmd else None

    db = SessionLocal()
    try:
        # Reuse an endpoint already pointed at this URL instead of duplicating.
        existing = db.query(ModelEndpoint).filter(ModelEndpoint.base_url == base_url).first()
        if existing:
            existing.is_enabled = True
            existing.model_type = "llm"
            existing.name = display_name
            if supports_tools is not None:
                existing.supports_tools = supports_tools
            db.commit()
            logger.info(f"Updated existing local model endpoint: {base_url}")
            return existing.id

        ep_id = f"local-{uuid.uuid4().hex[:8]}"
        ep = ModelEndpoint(
            id=ep_id,
            name=display_name,
            base_url=base_url,
            api_key=None,
            is_enabled=True,
            model_type="llm",
            supports_tools=supports_tools,
        )
        db.add(ep)
        db.commit()
        logger.info(f"Auto-registered local model endpoint: {display_name} @ {base_url}")
        return ep_id
    except Exception as e:
        # Best-effort: a failed registration must not fail the serve itself.
        logger.error(f"Failed to auto-register local model endpoint: {e}")
        db.rollback()
        return None
    finally:
        db.close()
@router.post("/api/model/serve")
async def model_serve(request: Request, req: ServeRequest):
"""Launch a model server in a tmux session (or PowerShell background process on Windows).
@@ -800,8 +908,17 @@ def setup_cookbook_routes() -> APIRouter:
# many downstream `"engine" in req.cmd` membership checks can't hit
# `TypeError: argument of type 'NoneType'` (a 500 instead of a clean 400).
req.cmd = _validate_serve_cmd(req.cmd) or ""
req.cmd = _venv_safe_local_pip_install_cmd(
req.cmd,
local=not bool(req.remote_host),
in_venv=sys.prefix != sys.base_prefix,
)
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
if is_pip_install:
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
# pip cache so they don't fail mid-build with "No space left" (#1219)
# and leave the dep installed-but-unusable (#1459).
req.cmd = _pip_install_no_cache(req.cmd)
# PEP-508-style package spec — letters, digits, `.-_` for the
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
# v2 review HIGH-14: tightened from the previous regex which
@@ -922,7 +1039,7 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' pkg install -y cmake 2>/dev/null')
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install llama-cpp-python --no-build-isolation --no-cache-dir 2>&1 || true')
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
runner_lines.append(' fi')
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
@@ -944,61 +1061,45 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
runner_lines.append(' else')
# Detect pip-installed nvcc (from vLLM/nvidia CUDA wheels) and put
# it on PATH so cmake's CUDA configure can find it. We check the
# same three layouts as entrypoint.sh:
# nvidia/cu13 — nvidia-nvcc-cu13
# nvidia/cu12 — nvidia-nvcc-cu12
# nvidia/cuda_nvcc — nvidia-cuda-nvcc-cu12 (sub-package style)
runner_lines.append(' for _cudir in ~/.local/lib/python*/site-packages/nvidia/cu13 ~/.local/lib/python*/site-packages/nvidia/cu12 ~/.local/lib/python*/site-packages/nvidia/cuda_nvcc; do')
runner_lines.append(' [ -x "$_cudir/bin/nvcc" ] && export CUDA_HOME="$_cudir" && export PATH="$_cudir/bin:$PATH" && break')
runner_lines.append(' done')
# rm -rf build so a prior poisoned CMakeCache.txt (e.g. from a
# failed CUDA attempt) doesn't cause the next configure to reuse
# stale settings and silently produce a CPU-only binary.
runner_lines.append(' cd ~/llama.cpp && rm -rf build')
runner_lines.append(' if command -v nvcc &>/dev/null; then')
runner_lines.append(' echo "[odysseus] CUDA nvcc found — building llama-server with CUDA (GPU) support..."')
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release -DGGML_CUDA=ON \\')
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
runner_lines.append(' else')
runner_lines.append(' echo "[odysseus] WARNING: nvcc not found — building llama-server for CPU only."')
runner_lines.append(' echo "[odysseus] GPU inference will not be available for this llama.cpp build."')
runner_lines.append(' echo "[odysseus] To get a GPU build, first install vLLM via Cookbook -> Dependencies"')
runner_lines.append(' echo "[odysseus] (its CUDA wheels include nvcc), then re-launch this serve task."')
runner_lines.append(' cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
runner_lines.append(' fi')
_append_llama_cpp_linux_accel_build_lines(runner_lines)
runner_lines.append(' fi')
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
runner_lines.append(' pip install --user --break-system-packages -q llama-cpp-python 2>/dev/null || pip install -q llama-cpp-python 2>/dev/null || true')
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
runner_lines.append(' fi')
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append(' fi')
runner_lines.append('fi')
elif "ollama" in req.cmd:
handled_ollama_serve = True
_ollama_port = "11434"
_ollama_match = re.search(r"OLLAMA_HOST=[^\s:]+:(\d+)", req.cmd)
if _ollama_match:
_ollama_port = _ollama_match.group(1)
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
_ollama_host, _ollama_port = _ollama_bind_from_cmd(
req.cmd,
default_host=_ollama_default_host,
)
# Ollama can be a host binary, a system service, or a Docker
# container. If the HTTP API is already reachable, the model is
# already served and we should not require a host `ollama` CLI.
runner_lines.append(f'ODYSSEUS_OLLAMA_HOST={_bash_squote(_ollama_host)}')
runner_lines.append(f'ODYSSEUS_OLLAMA_PORT="{_ollama_port}"')
runner_lines.append('ODYSSEUS_OLLAMA_URL=""')
runner_lines.append('for _ody_ollama_port in "$ODYSSEUS_OLLAMA_PORT" 11434; do')
runner_lines.append(' [ -z "$_ody_ollama_port" ] && continue')
runner_lines.append(' for _ody_ollama_host in 127.0.0.1 localhost host.docker.internal; do')
runner_lines.append(' _ody_ollama_url="http://${_ody_ollama_host}:${_ody_ollama_port}"')
runner_lines.append(' if curl -sf "$_ody_ollama_url/api/tags" >/dev/null 2>&1; then')
runner_lines.append(' ODYSSEUS_OLLAMA_URL="$_ody_ollama_url"')
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_ollama_port"')
runner_lines.append(' break 2')
runner_lines.append(' fi')
runner_lines.append('for _ody_ollama_try in $(seq 1 20); do')
runner_lines.append(' for _ody_ollama_port in "$ODYSSEUS_OLLAMA_PORT" 11434; do')
runner_lines.append(' [ -z "$_ody_ollama_port" ] && continue')
runner_lines.append(' for _ody_ollama_host in 127.0.0.1 localhost host.docker.internal; do')
runner_lines.append(' _ody_ollama_url="http://${_ody_ollama_host}:${_ody_ollama_port}"')
runner_lines.append(' if curl -sf "$_ody_ollama_url/api/tags" >/dev/null 2>&1; then')
runner_lines.append(' ODYSSEUS_OLLAMA_URL="$_ody_ollama_url"')
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_ollama_port"')
runner_lines.append(' break 3')
runner_lines.append(' fi')
runner_lines.append(' done')
runner_lines.append(' done')
runner_lines.append(' [ "$_ody_ollama_try" -eq 1 ] && echo "[odysseus] Waiting for an existing Ollama API on ports ${ODYSSEUS_OLLAMA_PORT}/11434..."')
runner_lines.append(' sleep 1')
runner_lines.append('done')
runner_lines.append('if [ -n "$ODYSSEUS_OLLAMA_URL" ]; then')
runner_lines.append(' if [ "$ODYSSEUS_OLLAMA_PORT" != "' + _ollama_port + '" ]; then')
@@ -1015,8 +1116,12 @@ def setup_cookbook_routes() -> APIRouter:
runner_lines.append(' echo "=== Process exited with code 127 ==="')
runner_lines.append(' exec bash -i')
runner_lines.append('fi')
runner_lines.append('echo "Starting ollama server on 0.0.0.0:${ODYSSEUS_OLLAMA_PORT}..."')
runner_lines.append('OLLAMA_HOST="0.0.0.0:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
if remote and _ollama_host in ("0.0.0.0", "::"):
runner_lines.append('echo "[odysseus] WARNING: remote Ollama will bind to ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT} so Odysseus can reach it from this host."')
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
runner_lines.append('_ody_exit=$?')
runner_lines.append('echo')
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
@@ -1032,19 +1137,24 @@ def setup_cookbook_routes() -> APIRouter:
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('if ! command -v vllm &>/dev/null; then')
runner_lines.append(' echo "ERROR: vLLM is not installed. Open Cookbook -> Dependencies and install vllm on this server, then launch again."')
runner_lines.append(' echo "ERROR: vLLM is not installed."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('fi')
elif "sglang.launch_server" in req.cmd:
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('if ! python3 -c "import sglang" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: SGLang is not installed. Open Cookbook -> Dependencies and install sglang on this server, then launch again."')
runner_lines.append('if ! command -v sglang &>/dev/null; then')
runner_lines.append(' echo "ERROR: SGLang is not installed."')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('elif ! ODYSSEUS_SGLANG_IMPORT_ERROR="$(python3 -c "import sglang" 2>&1)"; then')
runner_lines.append(' echo "ERROR: SGLang is installed but failed to import."')
runner_lines.append(' printf "%s\\n" "$ODYSSEUS_SGLANG_IMPORT_ERROR"')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('fi')
elif "scripts/diffusion_server.py" in req.cmd or ".diffusion_server.py" in req.cmd:
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
runner_lines.append('if ! python3 -c "import torch, diffusers" 2>/dev/null; then')
runner_lines.append(' echo "ERROR: Diffusion serving requires PyTorch + diffusers. Open Cookbook -> Dependencies and install diffusers on this server, then launch again."')
runner_lines.append('if ! ODYSSEUS_DIFFUSION_IMPORT_ERROR="$(python3 -c "import torch, diffusers" 2>&1)"; then')
runner_lines.append(' echo "ERROR: Diffusion serving requires PyTorch + diffusers."')
runner_lines.append(' printf "%s\\n" "$ODYSSEUS_DIFFUSION_IMPORT_ERROR"')
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
runner_lines.append('fi')
@@ -1116,11 +1226,16 @@ def setup_cookbook_routes() -> APIRouter:
stderr = (await proc.stderr.read()).decode(errors="replace")
return {"ok": False, "error": stderr, "session_id": session_id}
# Auto-register as model endpoint if serving a diffusion model
# Auto-register a model endpoint so the served model shows up in the model
# picker with no manual /setup step. Diffusion models get an image
# endpoint; any other real model serve (i.e. not a pip-install task) gets
# a local LLM endpoint pointed at its /v1.
endpoint_id = None
is_diffusion = "diffusion_server.py" in req.cmd
if is_diffusion:
endpoint_id = _auto_register_image_endpoint(req, remote)
elif not is_pip_install:
endpoint_id = _auto_register_llm_endpoint(req, remote)
# Log to assistant
try:
@@ -1357,9 +1472,16 @@ def setup_cookbook_routes() -> APIRouter:
total_mb = max(0, int(total_bytes / (1024 * 1024)))
used_mb = max(0, min(total_mb, int(used_bytes / (1024 * 1024))))
free_mb = max(0, total_mb - used_mb)
# GTT = the system-RAM pool the GPU pages into when VRAM is full.
# On a discrete card a large gtt_used means the model spilled past
# VRAM into RAM over PCIe — much slower. Surface it so the UI can
# warn "spilling to RAM" instead of the user wondering why it's slow.
gtt_used_raw = await _gpu_read_file(f"{base}/mem_info_gtt_used", host, ssh_port)
gtt_used_mb = max(0, int(int(gtt_used_raw) / (1024 * 1024))) if (gtt_used_raw and gtt_used_raw.isdigit()) else 0
gpus.append({
"index": len(gpus), "name": name, "uuid": entry,
"free_mb": free_mb, "total_mb": total_mb, "used_mb": used_mb,
"gtt_used_mb": gtt_used_mb,
"util_pct": 0, "busy": bool(total_mb and (free_mb / total_mb) < 0.85),
"processes": [], "backend": "rocm", "source": "amd-sysfs",
"unified_memory": unified,
@@ -1461,6 +1583,46 @@ def setup_cookbook_routes() -> APIRouter:
if gpus:
return {"ok": True, "gpus": gpus, "backend": "cuda", "source": "nvidia-smi"}
# Local Apple Silicon / Metal fallback. macOS has no nvidia-smi and no
# Linux /sys/class/drm tree, but services.hwfit.hardware already knows
# how to size the shared unified-memory GPU budget. Keep this route in
# sync so Cookbook's GPU picker doesn't show "nvidia-smi not found" on
# native Mac launches.
if not host and sys.platform == "darwin":
try:
from services.hwfit.hardware import detect_system
info = detect_system(fresh=True)
backend = str(info.get("backend") or "").lower()
if backend in {"metal", "mps", "apple"} and info.get("gpu_count", 0) > 0:
total_mb = int(float(info.get("gpu_vram_gb") or info.get("total_ram_gb") or 0) * 1024)
free_mb = int(float(info.get("available_ram_gb") or 0) * 1024)
if total_mb and (free_mb <= 0 or free_mb > total_mb):
free_mb = total_mb
used_mb = max(0, total_mb - max(0, free_mb))
return {
"ok": True,
"gpus": [{
"index": 0,
"name": info.get("gpu_name") or info.get("cpu_name") or "Apple Silicon GPU",
"uuid": "apple-metal-0",
"free_mb": max(0, free_mb),
"total_mb": max(0, total_mb),
"used_mb": used_mb,
"util_pct": 0,
"busy": bool(total_mb and (free_mb / total_mb) < 0.5),
"processes": [],
"backend": "metal",
"source": "apple-metal",
"unified_memory": True,
}],
"backend": "metal",
"source": "apple-metal",
"fallback_from": "nvidia-smi",
"nvidia_error": nvidia_error,
}
except Exception as e:
logger.warning("Apple Metal GPU fallback failed: %s", e)
amd_gpus = await _probe_amd_sysfs(host, ssh_port)
if amd_gpus:
return {
@@ -1607,6 +1769,33 @@ def setup_cookbook_routes() -> APIRouter:
disk_tasks = on_disk.get("tasks") or [] if isinstance(on_disk, dict) else []
incoming_tasks = data.get("tasks") if isinstance(data.get("tasks"), list) else []
# Anti-poisoning guard: a stale browser tab can keep POSTing a
# download task as status='done' from before the strict-finish
# fix landed, undoing any server-side correction. For each
# incoming "done" download, override to "running" if the last
# shard pattern says N<total AND no DOWNLOAD_OK/DOWNLOAD_FAILED/
# /snapshots/ sentinel is in the output.
import re as _re_dl
for _it in incoming_tasks:
if (not isinstance(_it, dict)) or _it.get("type") != "download" or _it.get("status") != "done":
continue
_out = _it.get("output") or ""
if ("DOWNLOAD_OK" in _out) or ("DOWNLOAD_FAILED" in _out) or ("/snapshots/" in _out):
continue
_shards = _re_dl.findall(r"model-(\d+)-of-(\d+)\.safetensors", _out)
if _shards:
_n, _tot = _shards[-1]
if int(_n) < int(_tot):
logger.info(f"cookbook state POST: rejecting stale done for {_it.get('sessionId')} "
f"(last shard {_n}/{_tot}, no DOWNLOAD_OK)")
_it["status"] = "running"
else:
_completed = _out.count("Download complete")
_starts = _out.count("Downloading '")
if _starts > _completed:
logger.info(f"cookbook state POST: rejecting stale done for {_it.get('sessionId')} "
f"({_completed}/{_starts} files complete, no DOWNLOAD_OK)")
_it["status"] = "running"
incoming_ids = {t.get("sessionId") for t in incoming_tasks if isinstance(t, dict) and t.get("sessionId")}
import time as _t
now_ms = int(_t.time() * 1000)
@@ -1763,6 +1952,43 @@ def setup_cookbook_routes() -> APIRouter:
def _cookbook_tasks_status_sync():
import subprocess
def _download_cache_complete(repo_id: str, remote_host: str = "", ssh_port: str = "") -> bool:
    """Best-effort check for a completed HF cache entry.

    tmux output can stop at a stale progress line if the pane/session
    disappears before Cookbook captures the final DOWNLOAD_OK marker.
    In that case, trust the cache shape: a snapshot directory with files
    and no *.incomplete blobs means HuggingFace finished materializing the
    model.

    Args:
        repo_id: Hub repo id in ``org/name`` form; anything without a "/"
            is rejected without probing.
        remote_host: SSH target to probe; empty means probe this machine.
        ssh_port: SSH port for the remote probe; only passed to ssh when it
            is set and not "22".

    Returns:
        True only when the probe ran and exited 0. Any failure (timeout,
        missing interpreter, ssh error) is reported as False.
    """
    import sys

    if not repo_id or "/" not in repo_id:
        return False

    # One-liner executed by a separate interpreter so the same probe works
    # locally and over ssh. Cache root resolution: HF_HUB_CACHE, then the
    # legacy HUGGINGFACE_HUB_CACHE, then $HF_HOME/hub, then the default
    # ~/.cache/huggingface/hub.
    probe = (
        "import os,sys;"
        "repo=sys.argv[1];"
        "base=os.environ.get('HF_HUB_CACHE') or os.environ.get('HUGGINGFACE_HUB_CACHE') "
        "or os.path.join(os.environ.get('HF_HOME') or os.path.expanduser('~/.cache/huggingface'), 'hub');"
        "d=os.path.join(base,'models--'+repo.replace('/','--'));"
        "snap=os.path.join(d,'snapshots');"
        "ok=os.path.isdir(snap) and any(os.path.isdir(os.path.join(snap,x)) and os.listdir(os.path.join(snap,x)) for x in os.listdir(snap));"
        "blobs=os.path.join(d,'blobs');"
        "inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
        "sys.exit(0 if ok and not inc else 1)"
    )
    try:
        if remote_host:
            # The remote side goes through a shell, so quote every argv item.
            remote_argv = ["python3", "-c", probe, repo_id]
            ssh_argv = ["ssh"]
            if ssh_port and ssh_port != "22":
                ssh_argv.extend(["-p", str(ssh_port)])
            ssh_argv.append(remote_host)
            ssh_argv.append(" ".join(shlex.quote(arg) for arg in remote_argv))
            proc = subprocess.run(ssh_argv, timeout=12, capture_output=True)
        else:
            # Use the interpreter we are already running under: a bare
            # "python3" is absent (or a Microsoft Store stub) on Windows.
            local_argv = [sys.executable or "python3", "-c", probe, repo_id]
            proc = subprocess.run(local_argv, timeout=12, capture_output=True)
        return proc.returncode == 0
    except Exception:
        # Deliberately best-effort: an unverifiable cache counts as incomplete.
        return False
# Load saved tasks from cookbook state
tasks = []
if _cookbook_state_path.exists():
@@ -1902,14 +2128,21 @@ def setup_cookbook_routes() -> APIRouter:
# persists after the process exits, so a finished download still has a
# snapshot to classify (DOWNLOAD_OK / exit marker) — evaluate it even
# when the PID is gone instead of blindly reporting "stopped".
download_zero_files = False
status = "unknown"
if is_alive or (local_win_task and full_snapshot):
lower = full_snapshot.lower()
has_exit = "=== process exited with code" in lower
exit_match = re.search(r"=== process exited with code\s+(-?\d+)", full_snapshot, re.I)
has_exit = exit_match is not None
exit_code = int(exit_match.group(1)) if exit_match else None
has_error = "error" in lower or "failed" in lower or "traceback" in lower
if has_exit and task_type == "serve":
# Serve tasks that exit are always errors — they should run indefinitely
status = "error"
elif has_exit and task_type == "download":
# Dependency installs are tracked as download tasks but only
# emit the generic runner exit marker, not HF download markers.
status = "completed" if exit_code == 0 else "error"
elif has_exit and "unrecognized arguments" in lower:
status = "error"
elif has_error and not ("application startup complete" in lower):
@@ -1918,7 +2151,11 @@ def setup_cookbook_routes() -> APIRouter:
# Only download tasks treat 100% as "completed".
# Serve tasks log 100%|██████| during inference progress
# (diffusion sampling, etc.) — that's "running", not done.
status = "completed"
if re.search(r"Fetching\s+0\s+files", full_snapshot, re.IGNORECASE):
status = "error"
download_zero_files = True
else:
status = "completed"
elif "application startup complete" in lower:
status = "ready"
elif not is_alive:
@@ -1928,7 +2165,14 @@ def setup_cookbook_routes() -> APIRouter:
status = "running"
else:
# Session is dead — check if it completed or crashed
status = "stopped"
if task_type == "download" and _download_cache_complete(_payload.get("repo_id") or model, remote, str(_tport or "")):
status = "completed"
if not progress_text:
progress_text = "Download complete"
if not full_snapshot:
full_snapshot = "DOWNLOAD_OK"
else:
status = "stopped"
# Parse structured phase info — single source of truth for the UI
phase_info = _parse_serve_phase(full_snapshot, task_type) if (task_type == "serve" and status == "running" and full_snapshot) else {}
@@ -1938,6 +2182,8 @@ def setup_cookbook_routes() -> APIRouter:
diagnosis = _diagnose_serve_output(full_snapshot) if task_type == "serve" and full_snapshot else None
if diagnosis and status in {"running", "unknown", "stopped"}:
status = "error"
if download_zero_files:
diagnosis = {"message": "No matching files were downloaded. The model repo or filename/quant pattern may be wrong (for example a ':Q4_K_M' tag that does not exist in the repo). Check the repo and the include/quant pattern."}
output_tail = "\n".join(full_snapshot.splitlines()[-12:]) if full_snapshot else ""
results.append({
+3 -1
View File
@@ -152,7 +152,7 @@ def _resolve_user_upload_path(
owner=owner,
auth_manager=auth_manager,
)
if not resolved:
if not isinstance(resolved, dict) or not resolved:
return None
path = resolved.get("path")
upload_dir = getattr(upload_handler, "upload_dir", None)
@@ -203,6 +203,8 @@ def _assert_pdf_marker_upload_owned(
def _derive_title(content: str) -> str:
"""Derive a title from document content."""
import re
if not isinstance(content, str):
return "Untitled"
text = content.strip()
if not text:
return "Untitled"
+37 -6
View File
@@ -15,6 +15,21 @@ from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
def _aggregate_language_facets(lang_rows):
    """Build the per-language document totals shown in the library facet.

    ``lang_rows`` is an iterable of ``(language, count)`` pairs. A row whose
    language is NULL/empty is folded into the explicit "text" bucket, and
    counts that land in the same bucket are ADDED together. The language
    filter treats NULL and "text" as one group, so a plain dict comprehension
    (which keeps only the last row per key) would undercount the facet
    compared with what the filter actually returns.

    Returns a dict mapping display language to its summed count.
    """
    totals = {}
    for language, count in lang_rows:
        bucket = language if language else "text"
        running = totals.setdefault(bucket, 0)
        totals[bucket] = running + count
    return totals
from routes.document_helpers import (
DocumentCreate, DocumentUpdate, DocumentPatch,
@@ -145,7 +160,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
create_form_markdown_document,
create_plain_pdf_document,
)
from src.document_processor import _process_pdf
from src.document_processor import _process_pdf, strip_pdf_content_marker
import os
from src.auth_helpers import require_privilege
@@ -184,7 +199,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
try:
body_text = _process_pdf(pdf_path).lstrip("\n[PDF content]:").strip()
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
except Exception:
body_text = None
@@ -258,7 +273,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
)
lang_q = _owner_session_filter(lang_q, user)
lang_rows = lang_q.group_by(Document.language).all()
languages = {lang or "text": cnt for lang, cnt in lang_rows}
languages = _aggregate_language_facets(lang_rows)
# Session count (owner-filtered)
sc_q = (
@@ -402,7 +417,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
text extraction was wired, plus for scanned/image-only PDFs where the
VL model picks up text the basic pypdf path missed."""
import re
from src.document_processor import _process_pdf
from src.document_processor import _process_pdf, strip_pdf_content_marker
from src.pdf_form_doc import find_source_upload_id
user = get_current_user(request)
@@ -423,7 +438,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
raise HTTPException(404, "Source PDF could not be located")
try:
body_text = _process_pdf(pdf_path).lstrip("\n[PDF content]:").strip()
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
except Exception as e:
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
raise HTTPException(500, f"Extraction failed: {e}")
@@ -593,6 +608,15 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
if req.session_id is not None:
# Empty string = unlink from session
doc.session_id = req.session_id if req.session_id else None
if not req.session_id:
# Tab closed / doc detached from its session — drop the
# in-memory active-doc pointer so the last-resort injection
# path doesn't re-surface this doc in a later chat (#1160).
try:
from src.tool_implementations import clear_active_document
clear_active_document(doc_id)
except Exception:
pass
db.commit()
db.refresh(doc)
return _doc_to_dict(doc)
@@ -615,6 +639,13 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
raise HTTPException(404, "Document not found")
_verify_doc_owner(db, doc, user)
doc.is_active = False
# Closed/deleted — drop the in-memory active-doc pointer so it isn't
# re-injected into a later, unrelated chat (#1160).
try:
from src.tool_implementations import clear_active_document
clear_active_document(doc_id)
except Exception:
pass
db.commit()
return {"status": "deleted", "id": doc_id}
except HTTPException:
@@ -885,7 +916,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
for i, doc in enumerate(batch):
if i >= len(verdicts):
break
verdict = verdicts[i].lower().strip()
verdict = str(verdicts[i] or "").lower().strip()
if verdict == "junk":
doc.tidy_verdict = "junk"
db.delete(doc)
+9 -5
View File
@@ -67,6 +67,14 @@ def _summary(d: EditorDraft) -> Dict[str, Any]:
}
def _load_payload(raw: Optional[str]) -> Dict[str, Any]:
try:
payload = json.loads(raw) if raw else {}
except Exception:
return {}
return payload if isinstance(payload, dict) else {}
def setup_editor_draft_routes() -> APIRouter:
router = APIRouter(tags=["editor-drafts"])
@@ -93,13 +101,9 @@ def setup_editor_draft_routes() -> APIRouter:
).first()
if not d or not _owns(d, user):
raise HTTPException(404, "Draft not found")
try:
payload = json.loads(d.payload) if d.payload else {}
except Exception:
payload = {}
return {
**_summary(d),
"payload": payload,
"payload": _load_payload(d.payload),
}
finally:
db.close()
+151 -62
View File
@@ -15,7 +15,6 @@ and `email_pollers.py` (the background loops):
import os
import imaplib
import smtplib
import ssl
import email as email_mod
import email.header
import email.utils
@@ -33,47 +32,43 @@ from fastapi import Query, HTTPException, Request
from pydantic import BaseModel
from typing import Optional, List
from src.auth_helpers import get_current_user
from src.auth_helpers import _auth_disabled, get_current_user
from src.secret_storage import decrypt as _decrypt
logger = logging.getLogger(__name__)
def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message: str | bytes, timeout: int = 30) -> None:
"""Send through SMTP using the conventional TLS mode for the configured port.
def _smtp_security_mode(cfg: dict) -> str:
raw = str(cfg.get("smtp_security") or "").strip().lower()
if raw in {"ssl", "starttls", "none"}:
return raw
port = int(cfg.get("smtp_port") or 465)
if port == 587:
return "starttls"
return "ssl"
Account settings only store host/port today. Port 465 is implicit TLS
(SMTP_SSL); port 587 is plain SMTP upgraded with STARTTLS. Using SSL
directly against 587 raises the classic "[SSL: WRONG_VERSION_NUMBER]"
error even when credentials are correct.
"""
def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message: str | bytes, timeout: int = 30) -> None:
"""Send through SMTP using the configured transport security mode."""
host = cfg["smtp_host"]
port = int(cfg.get("smtp_port") or 465)
user = cfg.get("smtp_user") or ""
password = cfg.get("smtp_password") or ""
def _send_starttls(starttls_port: int = 587) -> None:
with smtplib.SMTP(host, starttls_port, timeout=timeout) as smtp:
smtp.starttls()
if user and password:
smtp.login(user, password)
smtp.sendmail(from_addr, recipients, message)
security = _smtp_security_mode(cfg)
if port == 587:
_send_starttls(587)
return
try:
if security == "ssl":
with smtplib.SMTP_SSL(host, port, timeout=timeout) as smtp:
if user and password:
smtp.login(user, password)
smtp.sendmail(from_addr, recipients, message)
return
except (TimeoutError, ssl.SSLError) as e:
if port == 465:
logger.warning("SMTP implicit TLS on %s:465 failed (%s); retrying STARTTLS on 587", host, e)
_send_starttls(587)
return
raise
with smtplib.SMTP(host, port, timeout=timeout) as smtp:
if security == "starttls":
smtp.starttls()
if user and password:
smtp.login(user, password)
smtp.sendmail(from_addr, recipients, message)
def _strip_think(text: str) -> str:
@@ -152,6 +147,8 @@ def _require_auth(request: Request) -> str:
u = get_current_user(request)
if u:
return u
if _auth_disabled():
return ""
auth_mgr = getattr(request.app.state, "auth_manager", None)
if auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
raise HTTPException(401, "Not authenticated")
@@ -300,7 +297,8 @@ def _init_scheduled_db():
send_at TEXT NOT NULL,
created_at TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
error TEXT
error TEXT,
owner TEXT DEFAULT ''
)
""")
# Email summary cache (keyed by Message-ID)
@@ -438,6 +436,35 @@ def _init_scheduled_db():
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN account_id TEXT")
if "odysseus_kind" not in cols:
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN odysseus_kind TEXT")
if "owner" not in cols:
conn.execute("ALTER TABLE scheduled_emails ADD COLUMN owner TEXT DEFAULT ''")
conn.execute("CREATE INDEX IF NOT EXISTS ix_scheduled_emails_owner_status ON scheduled_emails(owner, status)")
# Backfill owner on legacy rows from the owning email account so the
# owner-scoped list/cancel routes surface pre-migration scheduled
# sends to the right user (the poller already resolves these by
# account at send time; this aligns the UI with that).
legacy_accounts = conn.execute(
"SELECT DISTINCT account_id FROM scheduled_emails "
"WHERE (owner IS NULL OR owner = '') AND account_id IS NOT NULL AND account_id != ''"
).fetchall()
if legacy_accounts:
try:
from core.database import SessionLocal as _SL, EmailAccount as _EA
_db = _SL()
try:
for (acct_id,) in legacy_accounts:
row = _db.query(_EA.owner).filter(_EA.id == acct_id).first()
acct_owner = (row[0] or "") if row else ""
if acct_owner:
conn.execute(
"UPDATE scheduled_emails SET owner = ? "
"WHERE account_id = ? AND (owner IS NULL OR owner = '')",
(acct_owner, acct_id),
)
finally:
_db.close()
except Exception:
pass
except Exception:
pass
# Lazy migration: add turns_json to email_boundaries for server-side
@@ -541,6 +568,7 @@ def _get_email_config(account_id: str | None = None, owner: str = "") -> dict:
"account_name": row.name,
"smtp_host": row.smtp_host or "",
"smtp_port": int(row.smtp_port or 465),
"smtp_security": _smtp_security_mode({"smtp_security": getattr(row, "smtp_security", ""), "smtp_port": row.smtp_port}),
"smtp_user": row.smtp_user or "",
"smtp_password": _decrypt(row.smtp_password or ""),
"imap_host": row.imap_host or "",
@@ -567,6 +595,10 @@ def _get_email_config(account_id: str | None = None, owner: str = "") -> dict:
"account_name": "legacy",
"smtp_host": settings.get("smtp_host", os.environ.get("SMTP_HOST", "")),
"smtp_port": int(settings.get("smtp_port", os.environ.get("SMTP_PORT", "465")) or 465),
"smtp_security": _smtp_security_mode({
"smtp_security": settings.get("smtp_security", os.environ.get("SMTP_SECURITY", "")),
"smtp_port": settings.get("smtp_port", os.environ.get("SMTP_PORT", "465")),
}),
"smtp_user": settings.get("smtp_user", os.environ.get("SMTP_USER", "")),
"smtp_password": settings.get("smtp_password", os.environ.get("SMTP_PASSWORD", "")),
"imap_host": settings.get("imap_host", os.environ.get("IMAP_HOST", "")),
@@ -606,7 +638,32 @@ def _list_email_accounts() -> list[dict]:
# ── IMAP helpers ──
_IMAP_TIMEOUT_SECONDS = 15
def _coerce_imap_timeout_seconds(raw: str | None) -> int:
try:
value = int(raw or "30")
except (TypeError, ValueError):
value = 30
return max(5, min(value, 300))
_IMAP_TIMEOUT_SECONDS = _coerce_imap_timeout_seconds(os.environ.get("ODYSSEUS_IMAP_TIMEOUT_SECONDS"))
def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int = _IMAP_TIMEOUT_SECONDS):
    """Open an IMAP connection using the configured security mode.

    Args:
        host: IMAP server hostname.
        port: IMAP server port; a falsy value is treated as 993.
        starttls: When true, connect in plain text and upgrade with STARTTLS.
        timeout: Socket timeout in seconds, used for the connect and then
            applied to the established socket.

    Returns:
        A connected, not-yet-authenticated ``imaplib.IMAP4`` (or
        ``IMAP4_SSL``) instance.

    Raises:
        Whatever ``imaplib`` raises on connect or STARTTLS failure. If the
        STARTTLS upgrade fails, the already-open plain connection is shut
        down before the error propagates, so the socket is not leaked.
    """
    port = int(port or 993)
    if not starttls and port == 993:
        # Implicit TLS on the conventional IMAPS port.
        conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
    else:
        conn = imaplib.IMAP4(host, port, timeout=timeout)
        if starttls:
            try:
                conn.starttls()
            except Exception:
                # The caller never receives `conn` on this path, so close the
                # socket here instead of leaving it to garbage collection.
                try:
                    conn.shutdown()
                except Exception:
                    pass
                raise
    try:
        # Best-effort: keep later reads/writes bounded by the same timeout.
        conn.sock.settimeout(timeout)
    except Exception:
        pass
    return conn
def _imap_connect(account_id: str | None = None, owner: str = ""):
# SECURITY: passing `owner` scopes the fallback config lookup so a brand
@@ -620,17 +677,12 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
# The last branch is critical: previously this fell into IMAP4_SSL
# for any non-STARTTLS port, which would fail the TLS handshake on
# plain local servers (Dovecot on 31143, etc.).
if cfg.get("imap_starttls"):
conn = imaplib.IMAP4(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
conn.starttls()
elif int(cfg.get("imap_port") or 993) == 993:
conn = imaplib.IMAP4_SSL(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
else:
conn = imaplib.IMAP4(cfg["imap_host"], cfg["imap_port"], timeout=_IMAP_TIMEOUT_SECONDS)
try:
conn.sock.settimeout(_IMAP_TIMEOUT_SECONDS)
except Exception:
pass
conn = _open_imap_connection(
cfg["imap_host"],
cfg["imap_port"],
starttls=bool(cfg.get("imap_starttls")),
timeout=_IMAP_TIMEOUT_SECONDS,
)
conn.login(cfg["imap_user"], cfg["imap_password"])
return conn
@@ -699,7 +751,13 @@ def _decode_header(raw):
decoded = []
for data, charset in parts:
if isinstance(data, bytes):
decoded.append(data.decode(charset or "utf-8", errors="replace"))
try:
decoded.append(data.decode(charset or "utf-8", errors="replace"))
except (LookupError, ValueError):
# Unknown/invalid MIME charset (e.g. a malformed or spam header
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
# byte-decode errors, not codec lookup, so fall back to utf-8.
decoded.append(data.decode("utf-8", errors="replace"))
else:
decoded.append(data)
return " ".join(decoded)
@@ -793,22 +851,27 @@ def _detect_spam_folder(conn):
return None
def _imap_move(uid, dest, src="INBOX"):
def _imap_move(uid, dest, src="INBOX", account_id: str | None = None, owner: str = ""):
"""Move a single IMAP UID from src folder to dest. Returns True on success."""
c = None
try:
c = _imap_connect()
c = _imap_connect(account_id, owner=owner)
c.select(_q(src))
status, _ = c.copy(uid, _q(dest))
if status != "OK":
c.logout()
return False
c.store(uid, "+FLAGS", "\\Deleted")
c.expunge()
c.logout()
return True
except Exception as e:
logger.warning(f"IMAP move {uid}{dest} failed: {e}")
return False
finally:
if c:
try:
c.logout()
except Exception:
pass
def _extract_attachment_text(msg, max_chars: int = 6000) -> str:
@@ -999,7 +1062,9 @@ def _fetch_sender_thread_context(sender_addr: str,
exclude_folder: str = "INBOX",
limit: int = 3,
max_chars_per_email: int = 1500,
max_attachment_chars: int = 4000) -> str:
max_attachment_chars: int = 4000,
account_id: str | None = None,
owner: str = "") -> str:
"""Pull the last N emails from `sender_addr` (across common folders),
extract their body snippets + attachment text, and return one formatted
block ready to be glued into an LLM system prompt as "REFERENCED MATERIAL".
@@ -1021,7 +1086,7 @@ def _fetch_sender_thread_context(sender_addr: str,
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
try:
conn = _imap_connect()
conn = _imap_connect(account_id, owner=owner)
except Exception as e:
logger.warning(f"sender-thread-context: imap connect failed: {e}")
return ""
@@ -1104,7 +1169,12 @@ def _fetch_sender_thread_context(sender_addr: str,
return "\n\n=====\n\n".join(blocks)
def _pre_retrieve_context(body: str, sender: str) -> tuple:
def _pre_retrieve_context(
body: str,
sender: str,
account_id: str | None = None,
owner: str = "",
) -> tuple:
"""Extract key terms from an incoming email and search past emails + contacts.
Returns (context_snippets, terms_list). Best-effort; never raises.
@@ -1128,18 +1198,37 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
# ── Known-sender check: only retrieve context for senders we already
# have a relationship with. New / cold senders get an empty context.
sender_addr = email.utils.parseaddr(sender or "")[1].lower()
is_known = False
# The CardDAV address book is global admin data backed by a single
# Radicale instance, so only fold it into reply context for an admin /
# single-user owner. Non-admin owners still get their own (owner-scoped)
# IMAP history below, just not the shared contacts.
try:
from routes.contacts_routes import _fetch_contacts
for c in _fetch_contacts() or []:
if (c.get("email") or "").lower() == sender_addr:
is_known = True
break
from src.tool_security import owner_is_admin_or_single_user
contacts_allowed = owner_is_admin_or_single_user(owner or None)
except Exception:
pass
contacts_allowed = not bool(owner)
is_known = False
if contacts_allowed:
try:
from routes.contacts_routes import _fetch_contacts
for c in _fetch_contacts() or []:
# Contacts are normalized to plural `emails` lists, but
# keep the legacy singular key fallback for older data.
contact_emails = []
raw_emails = c.get("emails")
if isinstance(raw_emails, list):
contact_emails.extend(str(e or "") for e in raw_emails)
legacy_email = c.get("email")
if legacy_email:
contact_emails.append(str(legacy_email))
if any((addr or "").strip().lower() == sender_addr for addr in contact_emails):
is_known = True
break
except Exception:
pass
if not is_known and sender_addr:
try:
with _imap() as _ck:
with _imap(account_id, owner=owner) as _ck:
_ck.select("INBOX", readonly=True)
st_known, dk = _ck.search(None, f'(FROM "{sender_addr}")')
if st_known == "OK" and dk and dk[0]:
@@ -1177,7 +1266,7 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
return context_snippets, terms_list
try:
ctx_conn = _imap_connect()
ctx_conn = _imap_connect(account_id, owner=owner)
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
try:
st_sel, _sd = ctx_conn.select(_q(folder), readonly=True)
@@ -1221,18 +1310,18 @@ def _pre_retrieve_context(body: str, sender: str) -> tuple:
try:
from routes.contacts_routes import _fetch_contacts
all_contacts = _fetch_contacts()
all_contacts = _fetch_contacts() if contacts_allowed else []
for term in terms_list:
t_lower = term.lower()
matches = [c for c in all_contacts
if t_lower in (c.get("name") or "").lower()
or t_lower in (c.get("email") or "").lower()]
or any(t_lower in (e or "").lower() for e in (c.get("emails") or []))]
for c in matches[:2]:
parts = [f"Name: {c.get('name','')}"]
if c.get("email"):
parts.append(f"Email: {c['email']}")
if c.get("phone"):
parts.append(f"Phone: {c['phone']}")
if c.get("emails"):
parts.append(f"Email: {', '.join(c['emails'])}")
if c.get("phones"):
parts.append(f"Phone: {', '.join(c['phones'])}")
context_snippets.append(f"[Contact match for \"{term}\"] " + ", ".join(parts))
except Exception:
pass
+98 -47
View File
@@ -45,6 +45,21 @@ from routes.email_helpers import (
logger = logging.getLogger(__name__)
def _owner_for_email_account(account_id: str | None) -> str:
if not account_id:
return ""
try:
from core.database import SessionLocal as _SL, EmailAccount as _EA
db = _SL()
try:
row = db.query(_EA.owner).filter(_EA.id == account_id).first()
return (row[0] or "") if row else ""
finally:
db.close()
except Exception:
return ""
# ── Routes ──
async def _emit_progress(progress_cb, message: str):
@@ -84,6 +99,36 @@ async def _run_auto_summarize_once(do_summary: bool = True, do_reply: bool = Tru
_save_settings(s2)
def _latest_inbox_fallback_uids(conn, reconnect):
"""Latest INBOX UIDs via ``SEARCH ALL``, with a poisoned-socket guard (#1613).
On a large Gmail mailbox the fallback ``SEARCH ALL`` can time out mid-reply,
leaving its enormous ``* SEARCH <uids>`` line unread on the socket. The next
command (the downstream re-select / EXAMINE) then reads those leftover bytes
and fails with ``EXAMINE => unexpected response: b'325188 …'``. Reconnecting
on failure guarantees the downstream command starts from a clean socket.
``conn`` is the IMAP connection to scan; ``reconnect`` is a zero-argument
callable that returns a replacement connection.
Returns ``(uids, conn)`` — ``uids`` is a list of ``("INBOX", uid)`` pairs and
``conn`` is the live connection to keep using: the same one on success, a
fresh one (via ``reconnect()``) if we had to recover.
"""
try:
conn.select("INBOX", readonly=True)
status, data = conn.uid("SEARCH", None, "ALL")
uids = []
if status == "OK" and data and data[0]:
# Keep at most the last 8 UIDs of the reply, in reverse order
# (presumably newest first, assuming the server lists UIDs ascending).
for u in reversed(data[0].split()[-8:]):
uids.append(("INBOX", u))
logger.info("Email task SINCE scan found no messages; fell back to latest INBOX messages")
return uids, conn
except Exception as _e:
logger.warning(f"Latest-INBOX fallback scan failed: {_e}")
# Best-effort logout of the failed connection; errors here are ignored.
try:
conn.logout()
except Exception:
pass
# Hand back an empty result plus a fresh connection for the caller.
return [], reconnect()
async def _auto_summarize_pass(days_back: int = 1, account_id: str | None = None, progress_cb=None) -> str:
"""Single pass of the auto-summarize/reply scan.
@@ -132,7 +177,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
import sqlite3 as _sql3
import requests as _req
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import _uses_max_completion_tokens
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
settings = _load_settings()
auto_sum = settings.get("email_auto_summarize", False)
@@ -143,25 +188,18 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
if not auto_sum and not auto_reply and not auto_tag and not auto_spam and not auto_cal:
return "Nothing to do"
# Owner of the account being processed. All calendar reads/writes below are
# scoped to this user: the multi-account fan-out runs every user's mailbox,
# so an unscoped pass would disclose and mutate other tenants' calendars.
_acct_owner = None
try:
from core.database import SessionLocal as _SLo, EmailAccount as _EAo
_dbo = _SLo()
try:
if account_id:
_arow = _dbo.query(_EAo).filter(_EAo.id == account_id).first()
_acct_owner = _arow.owner if _arow else None
finally:
_dbo.close()
except Exception:
_acct_owner = None
# Owner of the account being processed. All calendar + mailbox reads/writes
# below are scoped to this user: the multi-account fan-out runs every user's
# mailbox, so an unscoped pass would disclose/mutate other tenants' data.
# One resolution feeds both the mailbox path (account_owner) and upstream's
# calendar path (_acct_owner, which expects None rather than "").
account_owner = _owner_for_email_account(account_id)
_acct_owner = account_owner or None
conn = None
try:
await _emit_progress(progress_cb, "Connecting to mail…")
conn = _imap_connect(account_id)
conn = _imap_connect(account_id, owner=account_owner)
from datetime import timedelta as _td
since = (datetime.utcnow() - _td(days=max(1, days_back))).strftime("%d-%b-%Y")
# uid_list carries real IMAP UIDs, matching the email UI/read routes.
@@ -193,26 +231,27 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
# the latest visible inbox messages so Clear cache -> Run again can
# actually repopulate AI reply/summary/tag caches.
if not uid_list:
try:
conn.select("INBOX", readonly=True)
status, data = conn.uid("SEARCH", None, "ALL")
if status == "OK" and data and data[0]:
for u in reversed(data[0].split()[-8:]):
uid_list.append(("INBOX", u))
logger.info("Email task SINCE scan found no messages; fell back to latest INBOX messages")
except Exception as _e:
logger.warning(f"Latest-INBOX fallback scan failed: {_e}")
# Re-select INBOX as default for downstream code
_fb_uids, conn = _latest_inbox_fallback_uids(
conn, lambda: _imap_connect(account_id, owner=account_owner)
)
uid_list.extend(_fb_uids)
# Re-select INBOX as default for downstream code (on a clean socket even
# if the SEARCH ALL fallback above failed — see #1613).
conn.select("INBOX", readonly=True)
if not uid_list:
conn.logout()
return "No recent emails"
await _emit_progress(progress_cb, f"Found {len(uid_list)} recent email(s); checking cache…")
_c = _sql3.connect(SCHEDULED_DB)
_sum_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_summaries").fetchall()}
_reply_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_ai_replies").fetchall()}
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags").fetchall()} if (auto_tag or auto_spam) else set()
if auto_tag or auto_spam:
if account_owner:
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner=?", (account_owner,)).fetchall()}
else:
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner='' OR owner IS NULL").fetchall()}
else:
_tag_existing = set()
_cal_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_calendar_extractions").fetchall()} if auto_cal else set()
# Urgency is handled by the built-in `check_email_urgency` task. Keep
# this legacy poller path disabled so users don't get two independent
@@ -225,7 +264,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
# this per-iteration was making big inbox scans crawl. Used by the
# urgency self-loop check below.
try:
_self_self_addr = (_get_email_config(account_id).get("from_address") or "").strip().lower()
_self_self_addr = (_get_email_config(account_id, owner=account_owner).get("from_address") or "").strip().lower()
except Exception:
_self_self_addr = ""
@@ -233,11 +272,10 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
if auto_spam and not spam_folder:
logger.warning("Auto-spam enabled but no Junk/Spam folder detected — will classify but not move")
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=account_owner)
if not url:
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=account_owner)
if not url or not model:
conn.logout()
return "No model configured"
writing_style = settings.get("email_writing_style", "")
@@ -355,6 +393,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
"temperature": 0.3,
"stream": False,
}
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
if _restricts_temperature(model):
payload.pop("temperature", None)
try:
# Use to_thread so this sync HTTP call doesn't freeze
# the entire event loop while the LLM thinks (240s).
@@ -392,8 +433,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
await _emit_progress(progress_cb, f"Drafting reply {processed + 1}/{_max_process} · checked {examined}/{len(uid_list)}")
# Background reply drafting should not make the whole app
# feel busy. Keep it lightweight: no extra IMAP context
# mining here; manual AI Reply can still do that when the
# user explicitly asks for a draft on one email.
# mining here; manual AI Reply can still do that (owner-scoped)
# when the user explicitly asks for a draft on one email.
context_snippets, _terms = [], []
sys_prompt = _EMAIL_REPLY_SYS_PROMPT_BASE
if att_text:
@@ -708,7 +749,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
# Send alert email immediately if critical or high
if urgency in ("critical", "high"):
try:
cfg = _get_email_config(account_id)
cfg = _get_email_config(account_id, owner=account_owner)
to_addr = cfg["from_address"] # self-email
# Deep-link to open the original email in Odysseus (if public URL is configured).
@@ -716,8 +757,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
from src.settings import load_settings as _ls
_pub = (_ls().get("app_public_url") or "").rstrip("/")
uid_str = uid.decode() if isinstance(uid, bytes) else str(uid)
from urllib.parse import quote as _q
open_url = f"{_pub}/#email={_q(_folder, safe='')}:{uid_str}" if _pub else ""
from urllib.parse import quote as _url_q
open_url = f"{_pub}/#email={_url_q(_folder, safe='')}:{uid_str}" if _pub else ""
alert_subject = f"[{urgency.upper()}] {subject}"
alert_body = (
@@ -806,12 +847,15 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
"temperature": 0.1,
"stream": False,
}
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
if _restricts_temperature(model):
payload.pop("temperature", None)
# to_thread keeps the event loop responsive during the LLM call
resp = await asyncio.to_thread(
_req.post, url, json=payload, headers=req_headers, timeout=120
)
if not resp.ok:
logger.warning(f"Auto-classify {uid.decode()} HTTP {resp.status_code}: {resp.text[:200]}")
logger.warning(f"Auto-classify {uid.decode() if isinstance(uid, bytes) else str(uid)} HTTP {resp.status_code}: {resp.text[:200]}")
else:
rdata = resp.json()
m = (rdata.get("choices") or [{}])[0].get("message", {})
@@ -840,17 +884,17 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
moved_to = ""
if is_spam and auto_spam and spam_folder:
if _imap_move(uid, spam_folder):
if _imap_move(uid, spam_folder, account_id=account_id, owner=account_owner):
moved_to = spam_folder
logger.info(f"Auto-spam moved uid={uid.decode()} to {spam_folder}: {spam_reason}")
_c = _sql3.connect(SCHEDULED_DB)
_c.execute("""
INSERT OR REPLACE INTO email_tags
(message_id, uid, folder, subject, sender, tags, spam_verdict,
(message_id, owner, uid, folder, subject, sender, tags, spam_verdict,
spam_reason, moved_to, model_used, created_at)
VALUES (?, ?, 'INBOX', ?, ?, ?, ?, ?, ?, ?, ?)
""", (message_id, uid.decode(), subject, sender,
VALUES (?, ?, ?, 'INBOX', ?, ?, ?, ?, ?, ?, ?, ?)
""", (message_id, account_owner or "", uid.decode(), subject, sender,
json.dumps(tags), 1 if is_spam else 0,
spam_reason, moved_to, model, datetime.utcnow().isoformat()))
_c.commit()
@@ -865,7 +909,6 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
logger.warning(f"Auto-process {uid} failed: {e}")
continue
conn.logout()
await _emit_progress(progress_cb, "Finishing…")
if processed > 0:
logger.info(f"Auto-processed {processed} new email(s) for summary/reply/classify")
@@ -902,6 +945,12 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
except Exception as e:
logger.warning(f"Auto-summarize pass error: {e}")
return f"Error: {e}"
finally:
if conn:
try:
conn.logout()
except Exception:
pass
async def _auto_summarize_poller():
@@ -930,8 +979,9 @@ def _scheduled_poll_once() -> dict:
conn = sqlite3.connect(SCHEDULED_DB)
cols = [row[1] for row in conn.execute("PRAGMA table_info(scheduled_emails)").fetchall()]
kind_expr = "odysseus_kind" if "odysseus_kind" in cols else "'scheduled' AS odysseus_kind"
owner_expr = "owner" if "owner" in cols else "'' AS owner"
rows = conn.execute(f"""
SELECT id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, account_id, {kind_expr}
SELECT id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, account_id, {kind_expr}, {owner_expr}
FROM scheduled_emails
WHERE status = 'pending' AND send_at <= ?
""", (now_iso,)).fetchall()
@@ -943,7 +993,8 @@ def _scheduled_poll_once() -> dict:
attachments = json.loads(r[8] or "[]")
row_account_id = r[9] if len(r) > 9 else None
odysseus_kind = r[10] if len(r) > 10 else "scheduled"
cfg = _get_email_config(row_account_id)
row_owner = (r[11] if len(r) > 11 else "") or _owner_for_email_account(row_account_id)
cfg = _get_email_config(row_account_id, owner=row_owner)
has_atts = bool(attachments)
if has_atts:
outer = MIMEMultipart("mixed")
@@ -980,7 +1031,7 @@ def _scheduled_poll_once() -> dict:
# Append to local Sent folder
try:
with _imap() as imap:
with _imap(row_account_id, owner=row_owner) as imap:
sent_folder = _detect_sent_folder(imap)
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
except Exception as e:
+113 -58
View File
@@ -17,7 +17,6 @@ import sqlite3 as _sql3
import email as email_mod
import email.header
import email.utils
import imaplib
import smtplib
import json
import re
@@ -40,7 +39,8 @@ from routes.email_helpers import (
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
_q, _attach_compose_uploads, _cleanup_compose_uploads,
_load_settings, _save_settings, _get_email_config,
_send_smtp_message,
_send_smtp_message, _smtp_security_mode,
_IMAP_TIMEOUT_SECONDS, _open_imap_connection,
_imap_connect, _imap, _decode_header, _detect_sent_folder, _detect_drafts_folder,
_extract_attachment_text, _list_attachments_from_msg,
_extract_attachment_to_disk, _extract_html, _extract_text,
@@ -90,6 +90,16 @@ def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[st
return out or [""]
def _email_tag_owner_clause(account_id: str | None, owner: str = "") -> tuple[str, list[str]]:
    """Build the owner-scoping SQL fragment for ``email_tags`` queries.

    Returns ``(clause, params)``: ``clause`` is a WHERE fragment with one
    ``?`` placeholder per owner alias and ``params`` holds the aliases to
    bind, in order.

    When ``owner`` is set (configured multi-user mode) rows with a NULL
    owner are NOT matched, so legacy unowned rows are not visible to
    everyone. With no owner (single-user/unconfigured mode) NULL-owner rows
    are matched as well.
    """
    aliases = _email_tag_owner_aliases(account_id, owner)
    marks = ",".join(["?"] * len(aliases))
    clause = f"owner IN ({marks})"
    if not owner:
        clause = f"({clause} OR owner IS NULL)"
    return clause, aliases
def _record_email_received_events(owner: str, account_id: str | None, folder: str, emails: list[dict]):
"""Baseline inbox messages, then fire `email_received` for new arrivals."""
if not owner or (folder or "INBOX").upper() != "INBOX" or not emails:
@@ -312,6 +322,20 @@ def _apply_odysseus_headers(msg, kind: str | None = None, ref_id: str | None = N
msg["X-Odysseus-Ref"] = re.sub(r"[^A-Za-z0-9_.:-]", "-", ref_id)[:128]
def _envelope_recipients(*fields: str) -> list:
"""Extract bare SMTP envelope addresses from one or more To/Cc/Bcc header
strings. A naive `field.split(",")` corrupts display names that contain a
comma (e.g. `"Smith, John" <john@corp.com>`, the canonical Outlook form):
it splits into `"Smith` and `John" <john@corp.com>`, breaking delivery.
email.utils.getaddresses parses the address grammar correctly."""
out = []
for _name, addr in email.utils.getaddresses([f for f in fields if f]):
addr = (addr or "").strip()
if addr:
out.append(addr)
return out
def _md_to_email_html(text: str) -> str:
"""Render the compose markdown body to a SAFE HTML fragment for the email's
text/html part. Everything is HTML-escaped FIRST (so a pasted <script> /
@@ -457,7 +481,7 @@ def setup_email_routes():
_IMAP_POOL = {} # account_id → (conn, last_used_at)
_IMAP_IDLE_MAX = 60.0
_WARMING_READS = set()
_WARM_READ_LIMIT = 3
_WARM_READ_LIMIT = 1
_WARM_MAX_BYTES = 128 * 1024
_WARM_RECENT_SECONDS = 7 * 24 * 60 * 60
_pool_lock = _threading.Lock()
@@ -591,11 +615,11 @@ def setup_email_routes():
SECURITY: `owner` is propagated so when `account_id` is missing,
the fallback config lookup is scoped to this user's accounts only.
"""
conn = None
try:
conn = _imap_connect(account_id, owner=owner)
select_status, _ = conn.select(_q(folder), readonly=True)
if select_status != "OK":
conn.logout()
return {"emails": [], "total": 0, "folder": folder, "error": f"Folder not found: {folder}"}
from_clause = ""
@@ -645,8 +669,7 @@ def setup_email_routes():
try:
import sqlite3 as _sql3t
_ct = _sql3t.connect(SCHEDULED_DB)
_owner_aliases = _email_tag_owner_aliases(account_id, owner)
_owner_ph = ",".join("?" * len(_owner_aliases))
_owner_clause, _owner_params = _email_tag_owner_clause(account_id, owner)
# SECURITY: owner-scope the lookup (review C2/H8). Without
# this, user A's `tag:urgent` filter would surface UIDs
# written by user B and IMAP would return whatever
@@ -658,8 +681,8 @@ def setup_email_routes():
rows_t = _ct.execute(
"SELECT message_id, uid FROM email_tags "
"WHERE folder=? AND spam_verdict=1 "
f"AND (owner IN ({_owner_ph}) OR owner IS NULL)",
(folder, *_owner_aliases),
f"AND {_owner_clause}",
(folder, *_owner_params),
).fetchall()
for mid, uid in rows_t:
if mid:
@@ -670,8 +693,8 @@ def setup_email_routes():
rows_t = _ct.execute(
"SELECT message_id, uid, tags FROM email_tags "
"WHERE folder=? AND tags IS NOT NULL AND tags != '' "
f"AND (owner IN ({_owner_ph}) OR owner IS NULL)",
(folder, *_owner_aliases),
f"AND {_owner_clause}",
(folder, *_owner_params),
).fetchall()
for r in rows_t:
try:
@@ -743,12 +766,11 @@ def setup_email_routes():
_uid_strs = [u.decode() for u in uid_list]
if _uid_strs:
placeholders = ",".join("?" * len(_uid_strs))
_owner_aliases = _email_tag_owner_aliases(account_id, owner)
_owner_ph = ",".join("?" * len(_owner_aliases))
_owner_clause, _owner_params = _email_tag_owner_clause(account_id, owner)
rows = _c.execute(
f"SELECT uid, tags, spam_verdict FROM email_tags "
f"WHERE folder=? AND (owner IN ({_owner_ph}) OR owner IS NULL) AND uid IN ({placeholders})",
[folder, *_owner_aliases, *_uid_strs],
f"WHERE folder=? AND {_owner_clause} AND uid IN ({placeholders})",
[folder, *_owner_params, *_uid_strs],
).fetchall()
for r in rows:
try:
@@ -805,14 +827,13 @@ def setup_email_routes():
if header_ids:
import sqlite3 as _sql3m
_cm = _sql3m.connect(SCHEDULED_DB)
_owner_aliases_m = _email_tag_owner_aliases(account_id, owner)
_owner_ph_m = ",".join("?" * len(_owner_aliases_m))
_owner_clause_m, _owner_params_m = _email_tag_owner_clause(account_id, owner)
_mid_ph = ",".join("?" * len(header_ids))
rows_m = _cm.execute(
f"SELECT message_id, tags, spam_verdict FROM email_tags "
f"WHERE folder=? AND (owner IN ({_owner_ph_m}) OR owner IS NULL) "
f"WHERE folder=? AND {_owner_clause_m} "
f"AND message_id IN ({_mid_ph})",
[folder, *_owner_aliases_m, *header_ids],
[folder, *_owner_params_m, *header_ids],
).fetchall()
_cm.close()
for mid, tags_raw, spam_raw in rows_m:
@@ -924,12 +945,17 @@ def setup_email_routes():
except Exception as _summary_err:
logger.debug(f"Bulk summary attach skipped: {_summary_err}")
conn.logout()
return {"emails": emails, "total": total, "folder": folder, "offset": offset}
except Exception as e:
logger.error(f"Failed to list emails: {e}")
detail = str(e).strip()
return {"emails": [], "total": 0, "error": f"Mail operation failed: {detail[:180]}" if detail else "Mail operation failed"}
finally:
if conn:
try:
conn.logout()
except Exception:
pass
@router.get("/list")
async def list_emails(
@@ -971,10 +997,11 @@ def setup_email_routes():
async def unflag_spam(uid: str, owner: str = Depends(require_owner)):
"""User override — mark email as not spam."""
try:
owner_clause, owner_params = _email_tag_owner_clause(None, owner)
_c = _sql3.connect(SCHEDULED_DB)
_c.execute(
"UPDATE email_tags SET spam_verdict=0, spam_reason='' WHERE uid=?",
(uid,),
f"UPDATE email_tags SET spam_verdict=0, spam_reason='' WHERE uid=? AND {owner_clause}",
[uid, *owner_params],
)
_c.commit()
_c.close()
@@ -997,8 +1024,10 @@ def setup_email_routes():
ql = (q or "").strip().lower()
try:
conn = _sql3.connect(SCHEDULED_DB)
owner_clause, owner_params = _email_tag_owner_clause(None, owner)
rows = conn.execute(
"SELECT sender FROM email_tags WHERE sender IS NOT NULL AND sender != ''"
f"SELECT sender FROM email_tags WHERE sender IS NOT NULL AND sender != '' AND {owner_clause}",
owner_params,
).fetchall()
conn.close()
seen = {}
@@ -1046,7 +1075,7 @@ def setup_email_routes():
# Escape backslash and quote for the IMAP-SEARCH quoted-string.
q_escaped = q.replace('\\', '\\\\').replace('"', '\\"')
search_cmd = f'(OR FROM "{q_escaped}" TEXT "{q_escaped}")'
search_cmd = f'(OR OR FROM "{q_escaped}" SUBJECT "{q_escaped}" TEXT "{q_escaped}")'
status, data = _imap_uid_search(conn, search_cmd)
if status != "OK" or not data[0]:
@@ -1928,11 +1957,7 @@ def setup_email_routes():
outer.attach(body_container)
_attach_compose_uploads(outer, attachments)
recipients = [r.strip() for r in to.split(",") if r.strip()]
if cc:
recipients.extend([r.strip() for r in cc.split(",") if r.strip()])
if bcc:
recipients.extend([r.strip() for r in bcc.split(",") if r.strip()])
recipients = _envelope_recipients(to, cc, bcc)
_send_smtp_message(cfg, cfg["from_address"], recipients, outer.as_string())
@@ -1964,13 +1989,22 @@ def setup_email_routes():
# minute doesn't trip the past-time guard.
if parsed_at < now_utc:
return {"success": False, "error": "send_at must be in the future"}
# Normalize to naive UTC before storing: the poller selects due
# rows with a lexicographic string compare against a naive
# datetime.utcnow().isoformat(), so storing the raw client string
# makes "+02:00" schedules fire hours late, negative offsets fire
# hours early, and a "Z" suffix compares after the fractional
# seconds of the poller timestamp.
if parsed_at.tzinfo:
parsed_at = parsed_at.astimezone(_tz.utc).replace(tzinfo=None)
send_at = parsed_at.isoformat()
sid = _uuid.uuid4().hex[:16]
conn = sqlite3.connect(SCHEDULED_DB)
conn.execute("""
INSERT INTO scheduled_emails
(id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, send_at, created_at, status, account_id, odysseus_kind)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?)
(id, to_addr, cc, bcc, subject, body, in_reply_to, references_hdr, attachments, send_at, created_at, status, account_id, odysseus_kind, owner)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'pending', ?, ?, ?)
""", (
sid,
req.get("to", ""),
@@ -1985,6 +2019,7 @@ def setup_email_routes():
datetime.utcnow().isoformat(),
req.get("account_id") or None,
req.get("odysseus_kind") or "scheduled",
owner or "",
))
conn.commit()
conn.close()
@@ -2003,9 +2038,9 @@ def setup_email_routes():
rows = conn.execute("""
SELECT id, to_addr, cc, subject, send_at, created_at, status, error
FROM scheduled_emails
WHERE status IN ('pending', 'failed')
WHERE status IN ('pending', 'failed') AND owner = ?
ORDER BY send_at ASC
""").fetchall()
""", (owner or "",)).fetchall()
conn.close()
return {"scheduled": [
{
@@ -2023,7 +2058,10 @@ def setup_email_routes():
import sqlite3
try:
conn = sqlite3.connect(SCHEDULED_DB)
conn.execute("DELETE FROM scheduled_emails WHERE id = ? AND status = 'pending'", (sid,))
conn.execute(
"DELETE FROM scheduled_emails WHERE id = ? AND status = 'pending' AND owner = ?",
(sid, owner or ""),
)
conn.commit()
conn.close()
return {"success": True}
@@ -2035,7 +2073,7 @@ def setup_email_routes():
async def resolve_contact(name: str = Query(..., description="Name to search for"), owner: str = Depends(require_owner)):
"""Search Sent folder for a contact by name. Returns matching email addresses."""
try:
with _imap() as conn:
with _imap(owner=owner) as conn:
matches = {}
for folder in ["Sent", "INBOX", "Drafts"]:
try:
@@ -2133,12 +2171,9 @@ def setup_email_routes():
outer.attach(body_container)
_attach_compose_uploads(outer, req.attachments)
# Build recipient list
recipients = [r.strip() for r in req.to.split(",") if r.strip()]
if req.cc:
recipients.extend([r.strip() for r in req.cc.split(",") if r.strip()])
if req.bcc:
recipients.extend([r.strip() for r in req.bcc.split(",") if r.strip()])
# Build recipient list (parse the address grammar so display names with
# commas don't get split into broken envelope addresses)
recipients = _envelope_recipients(req.to, req.cc, req.bcc)
# Serialize what the background task needs so the request object can be GC'd
outer_bytes = outer.as_bytes()
@@ -2146,6 +2181,7 @@ def setup_email_routes():
_from = cfg["from_address"]
_smtp_host = cfg["smtp_host"]
_smtp_port = cfg["smtp_port"]
_smtp_security = cfg.get("smtp_security")
_smtp_user = cfg["smtp_user"]
_smtp_pw = cfg["smtp_password"]
_recipients = list(recipients)
@@ -2163,6 +2199,7 @@ def setup_email_routes():
{
"smtp_host": _smtp_host,
"smtp_port": _smtp_port,
"smtp_security": _smtp_security,
"smtp_user": _smtp_user,
"smtp_password": _smtp_pw,
},
@@ -2417,7 +2454,7 @@ def setup_email_routes():
"""Generate a quick AI summary of an email body."""
try:
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import _uses_max_completion_tokens
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
import requests as _req
body = data.get("body", "")
@@ -2474,6 +2511,9 @@ def setup_email_routes():
"temperature": 0.3,
"stream": False,
}
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
if _restricts_temperature(model):
payload.pop("temperature", None)
resp = await asyncio.to_thread(
_req.post, url, json=payload, headers=req_headers, timeout=180
)
@@ -2585,7 +2625,7 @@ def setup_email_routes():
# `api_key` field.
from core.database import SessionLocal as _SL, Session as _CS
_db = _SL()
sess = _db.query(_CS).filter(_CS.id == session_id).first()
sess = _db.query(_CS).filter(_CS.id == session_id, _CS.owner == owner).first()
if sess and sess.endpoint_url:
url = sess.endpoint_url
# Some sessions stored headers double-encoded (a JSON
@@ -2644,9 +2684,10 @@ def setup_email_routes():
# Manual AI Reply should feel immediate. The heavier context mining
# can involve multiple IMAP folder searches and attachment parsing;
# reserve that for callers that explicitly opt out of fast mode.
# Owner-scoped so pre-retrieval never crosses tenants.
context_snippets, _terms = ([], [])
if not fast_reply:
context_snippets, _terms = _pre_retrieve_context(original_body, to)
context_snippets, _terms = _pre_retrieve_context(original_body, to, owner=owner)
# NEW: also pull the last few emails from the original sender +
# their attachments. The "to" field on this endpoint is the
@@ -2662,6 +2703,7 @@ def setup_email_routes():
exclude_uid=source_uid,
exclude_folder=source_folder,
limit=3,
owner=owner,
)
except Exception as _e:
logger.warning(f"sender-thread-context failed: {_e}")
@@ -2723,7 +2765,7 @@ def setup_email_routes():
# Configured fallback chains last.
for cand in resolve_utility_fallback_candidates(owner=owner) or []:
_add(*cand)
for cand in resolve_chat_fallback_candidates() or []:
for cand in resolve_chat_fallback_candidates(owner=owner) or []:
_add(*cand)
try:
reply = await llm_call_async_with_fallback(
@@ -2814,13 +2856,16 @@ def setup_email_routes():
import uuid as _uuid
db = SessionLocal()
try:
row = db.query(EmailAccount).filter(EmailAccount.is_default == True).first() # noqa: E712
q = db.query(EmailAccount).filter(EmailAccount.is_default == True) # noqa: E712
if owner:
q = q.filter(EmailAccount.owner == owner)
row = q.first()
if row is None:
row = EmailAccount(id=_uuid.uuid4().hex, name="Default", is_default=True, enabled=True)
row = EmailAccount(id=_uuid.uuid4().hex, owner=owner, name="Default", is_default=True, enabled=True)
db.add(row)
field_map = {
"smtp_host": "smtp_host", "smtp_port": "smtp_port", "smtp_user": "smtp_user",
"imap_host": "imap_host", "imap_port": "imap_port", "imap_user": "imap_user",
"smtp_security": "smtp_security", "imap_host": "imap_host", "imap_port": "imap_port", "imap_user": "imap_user",
"imap_starttls": "imap_starttls", "email_from": "from_address",
}
for in_key, col_name in field_map.items():
@@ -2838,6 +2883,10 @@ def setup_email_routes():
row.imap_password = _enc(data["imap_password"])
if data.get("smtp_password"):
row.smtp_password = _enc(data["smtp_password"])
clear_q = db.query(EmailAccount).filter(EmailAccount.id != row.id)
if owner:
clear_q = clear_q.filter(EmailAccount.owner == owner)
clear_q.update({EmailAccount.is_default: False})
db.commit()
finally:
db.close()
@@ -2902,6 +2951,7 @@ def setup_email_routes():
"imap_starttls": bool(r.imap_starttls),
"smtp_host": r.smtp_host or "",
"smtp_port": int(r.smtp_port or 465),
"smtp_security": _smtp_security_mode({"smtp_security": getattr(r, "smtp_security", ""), "smtp_port": r.smtp_port}),
"smtp_user": r.smtp_user or "",
"from_address": r.from_address or "",
"has_imap_password": bool(r.imap_password),
@@ -2934,6 +2984,7 @@ def setup_email_routes():
imap_starttls=bool(data.get("imap_starttls", True)),
smtp_host=(data.get("smtp_host") or "").strip(),
smtp_port=int(data.get("smtp_port") or 465),
smtp_security=_smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or 465}),
smtp_user=(data.get("smtp_user") or "").strip(),
smtp_password=_enc(data.get("smtp_password") or ""),
from_address=(data.get("from_address") or "").strip(),
@@ -2977,6 +3028,8 @@ def setup_email_routes():
for key in ("imap_port", "smtp_port"):
if data.get(key) not in (None, ""):
setattr(row, key, int(data[key]))
if "smtp_security" in data:
row.smtp_security = _smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or row.smtp_port})
for key in ("imap_starttls", "enabled"):
if key in data:
setattr(row, key, bool(data[key]))
@@ -3061,6 +3114,7 @@ def setup_email_routes():
"imap_starttls": bool(row.imap_starttls),
"smtp_host": row.smtp_host or "",
"smtp_port": row.smtp_port or 465,
"smtp_security": _smtp_security_mode({"smtp_security": getattr(row, "smtp_security", ""), "smtp_port": row.smtp_port}),
"smtp_user": row.smtp_user or "",
"smtp_password": _decrypt(row.smtp_password or ""),
}
@@ -3093,13 +3147,12 @@ def setup_email_routes():
# port (Dovecot on 31143, etc.) would always fail the SSL
# handshake because they're not actually wrapped in TLS.
try:
if imap_starttls:
conn = imaplib.IMAP4(imap_host, imap_port, timeout=10)
conn.starttls()
elif imap_port == 993:
conn = imaplib.IMAP4_SSL(imap_host, imap_port, timeout=10)
else:
conn = imaplib.IMAP4(imap_host, imap_port, timeout=10)
conn = _open_imap_connection(
imap_host,
imap_port,
starttls=imap_starttls,
timeout=_IMAP_TIMEOUT_SECONDS,
)
try:
conn.login(imap_user, imap_pass)
imap_result = {"ok": True}
@@ -3112,14 +3165,16 @@ def setup_email_routes():
smtp_host = (body.get("smtp_host") or "").strip()
if smtp_host:
smtp_port = int(body.get("smtp_port") or 465)
smtp_security = _smtp_security_mode({"smtp_security": body.get("smtp_security"), "smtp_port": smtp_port})
smtp_user = (body.get("smtp_user") or imap_user).strip()
smtp_pass = body.get("smtp_password") or imap_pass
try:
if smtp_port == 587:
smtp = smtplib.SMTP(smtp_host, smtp_port, timeout=10)
smtp.starttls()
else:
if smtp_security == "ssl":
smtp = smtplib.SMTP_SSL(smtp_host, smtp_port, timeout=10)
else:
smtp = smtplib.SMTP(smtp_host, smtp_port, timeout=10)
if smtp_security == "starttls":
smtp.starttls()
try:
smtp.login(smtp_user, smtp_pass)
smtp_result = {"ok": True}
+15 -2
View File
@@ -86,7 +86,8 @@ def _load_custom_endpoint() -> dict:
"""Load the saved custom embedding endpoint, if any."""
try:
if os.path.exists(_ENDPOINT_FILE):
return json.loads(Path(_ENDPOINT_FILE).read_text(encoding="utf-8"))
data = json.loads(Path(_ENDPOINT_FILE).read_text(encoding="utf-8"))
return data if isinstance(data, dict) else {}
except Exception:
pass
return {}
@@ -160,7 +161,7 @@ def setup_embedding_routes():
_downloading[model_name] = True
try:
# Run in thread to not block the event loop
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
cache = _cache_dir()
await loop.run_in_executor(
None,
@@ -242,6 +243,18 @@ def setup_embedding_routes():
if not url:
raise HTTPException(400, "URL is required")
# SSRF hardening: validate the user-supplied URL before any outbound
# request. Local-first means loopback/LAN endpoints are allowed by
# default; non-HTTP(S) schemes and the cloud metadata range are always
# rejected. Set EMBEDDING_BLOCK_PRIVATE_IPS=true for full lockdown.
from src.url_safety import check_outbound_url
ok, reason = check_outbound_url(
url,
block_private=os.getenv("EMBEDDING_BLOCK_PRIVATE_IPS", "false").lower() == "true",
)
if not ok:
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
# Quick health check
try:
import httpx
+10 -2
View File
@@ -5,6 +5,15 @@ from fastapi import APIRouter
CUSTOM_FONTS_DIR = os.path.join("static", "fonts", "custom")
FONT_EXTENSIONS = {".ttf", ".otf", ".woff", ".woff2"}
FAMILY_SUFFIX_WORDS = ("Display", "Rounded", "Serif", "Sans", "Mono", "Code", "Text")
def _split_family_token(token):
"""Split common compact font-family suffixes without breaking brand names."""
for suffix in FAMILY_SUFFIX_WORDS:
if token.endswith(suffix) and len(token) > len(suffix):
return f"{token[:-len(suffix)]} {suffix}"
return re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', token)
def _derive_family(filename):
@@ -15,10 +24,9 @@ def _derive_family(filename):
r'[-_ ]?(Thin|ExtraLight|UltraLight|Light|Regular|Medium|SemiBold|DemiBold|Bold|ExtraBold|UltraBold|Black|Heavy|Italic|Oblique|Variable|VF)$',
'', name, flags=re.IGNORECASE
)
# Insert spaces before uppercase runs: "JetBrainsMono" → "Jet Brains Mono"
name = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', name)
# Replace dashes/underscores with spaces
name = re.sub(r'[-_]+', ' ', name).strip()
name = " ".join(_split_family_token(part) for part in name.split())
return name or filename
+22 -3
View File
@@ -32,10 +32,21 @@ def _extract_exif(content: bytes) -> dict:
from PIL import Image
from io import BytesIO
img = Image.open(BytesIO(content))
# Read the raw EXIF before any transpose: exif_transpose strips the
# orientation tag and with it the parsed EXIF view.
exif = img._getexif() if hasattr(img, '_getexif') else None
# Record DISPLAY dimensions (EXIF-rotated), matching upload_handler.
# A phone photo with Orientation 6/8 is stored landscape but shown
# portrait, so the raw width/height swap the aspect ratio.
try:
from PIL import ImageOps
img = ImageOps.exif_transpose(img) or img
except Exception:
pass
result["width"] = img.width
result["height"] = img.height
exif = img._getexif() if hasattr(img, '_getexif') else None
if not exif:
return result
@@ -110,9 +121,17 @@ def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any
def _owner_filter(q, user):
"""Apply owner filtering to a gallery query."""
"""Apply owner filtering to a gallery query.
When auth is disabled (single-user mode) get_current_user returns None
and there is no per-user scoping. The main library list and stats already
treat None as "show everything" (`if user is not None`), so this helper
must too otherwise the tag/model filter sidebars come back empty and the
tag-cleanup endpoints (clear-user-tags, clear-ai-tags, dedupe-tags)
silently affect zero rows in the most common self-hosted deployment.
"""
if user is None:
return q.filter(False)
return q
return q.filter(GalleryImage.owner == user)
+41 -4
View File
@@ -3,6 +3,9 @@
import os
import hashlib
import logging
import re
import uuid
from pathlib import Path
from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Query, Request
@@ -17,6 +20,14 @@ from routes.gallery_helpers import (
logger = logging.getLogger(__name__)
def _sanitize_gallery_filename(filename: str) -> str:
"""Return a local filename safe to join under generated_images."""
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
if not safe_name or safe_name in {".", ".."}:
safe_name = uuid.uuid4().hex[:12]
return safe_name
def setup_gallery_routes() -> APIRouter:
router = APIRouter(tags=["gallery"])
@@ -122,7 +133,7 @@ def setup_gallery_routes() -> APIRouter:
content = await file.read()
img_dir = Path("data/generated_images")
img_dir.mkdir(parents=True, exist_ok=True)
img_path = img_dir / img.filename
img_path = img_dir / _sanitize_gallery_filename(img.filename)
img_path.write_bytes(content)
# Refresh dimensions in case the editor resized the canvas.
@@ -912,6 +923,16 @@ def setup_gallery_routes() -> APIRouter:
body = await request.json()
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
base = (body.pop("_endpoint", "") or "").rstrip("/")
# SSRF hardening: validate a client-supplied endpoint before any
# outbound request (mirrors routes/embedding_routes.py).
if base:
from src.url_safety import check_outbound_url
ok, reason = check_outbound_url(
base,
block_private=os.getenv("IMAGE_BLOCK_PRIVATE_IPS", "false").lower() == "true",
)
if not ok:
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
chosen_model = (body.pop("_model", "") or "").strip()
api_key = None
if not base:
@@ -1104,6 +1125,18 @@ def setup_gallery_routes() -> APIRouter:
raise HTTPException(400, "No image provided")
endpoint = (body.get("_endpoint") or "").rstrip("/")
# SSRF hardening: a client-supplied endpoint is fetched server-side
# below, so validate it first (mirrors routes/embedding_routes.py).
# Local-first means loopback/LAN is allowed by default; the cloud
# metadata range and non-HTTP(S) schemes are always rejected.
if endpoint:
from src.url_safety import check_outbound_url
ok, reason = check_outbound_url(
endpoint,
block_private=os.getenv("IMAGE_BLOCK_PRIVATE_IPS", "false").lower() == "true",
)
if not ok:
raise HTTPException(400, f"Rejected endpoint URL: {reason}")
model = (body.get("_model") or "").strip()
base = endpoint
@@ -1125,7 +1158,7 @@ def setup_gallery_routes() -> APIRouter:
db = SessionLocal()
try:
for ep in db.query(ModelEndpoint).all():
if ep.base_url.rstrip("/").rstrip("/v1") == base.rstrip("/v1"):
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
api_key = ep.api_key
break
finally:
@@ -1696,7 +1729,7 @@ def setup_gallery_routes() -> APIRouter:
return {"error": "No vision-capable endpoint configured"}
# Call vision model — format differs between Anthropic and OpenAI
from src.llm_core import _detect_provider
from src.llm_core import _detect_provider, _restricts_temperature, _uses_max_completion_tokens
provider = _detect_provider(chat_url)
tag_prompt = (
"Analyze this photo. Return ONLY a comma-separated list of tags. "
@@ -1721,6 +1754,7 @@ def setup_gallery_routes() -> APIRouter:
}],
}
else:
_tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model_name) else "max_tokens"
payload = {
"model": model_name,
"messages": [{
@@ -1730,9 +1764,12 @@ def setup_gallery_routes() -> APIRouter:
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}},
],
}],
"max_tokens": 200,
_tok_key: 200,
"temperature": 0.3,
}
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature.
if _restricts_temperature(model_name):
payload.pop("temperature", None)
h = {"Content-Type": "application/json"}
if headers:
+17 -10
View File
@@ -58,7 +58,7 @@ def setup_history_routes(session_manager) -> APIRouter:
.all()
)
import json as _json
history_dict = []
db_history = []
for m in db_messages:
entry = {"role": m.role, "content": m.content}
meta = {}
@@ -71,12 +71,19 @@ def setup_history_routes(session_manager) -> APIRouter:
meta["timestamp"] = m.timestamp.isoformat() + "Z"
if meta:
entry["metadata"] = meta
history_dict.append(entry)
if history_dict:
db_history.append(entry)
if db_history:
# Rebuild in-memory history from the full set so hidden
# messages (e.g. compaction summaries) are kept for AI context.
session.history = [
ChatMessage(role=m["role"], content=m["content"], metadata=m.get("metadata"))
for m in history_dict
for m in db_history
]
# Response excludes hidden messages, matching the in-memory path.
history_dict = [
m for m in db_history
if not (m.get("metadata") or {}).get("hidden")
]
except Exception as e:
logger.error(f"DB fallback failed for {session_id}: {e}")
finally:
@@ -265,7 +272,7 @@ def setup_history_routes(session_manager) -> APIRouter:
db_messages = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
.order_by(DbChatMessage.created_at.desc())
.order_by(DbChatMessage.timestamp.desc())
.first()
)
if db_messages:
@@ -320,7 +327,7 @@ def setup_history_routes(session_manager) -> APIRouter:
db_msg = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
.order_by(DbChatMessage.created_at.desc())
.order_by(DbChatMessage.timestamp.desc())
.first()
)
if db_msg:
@@ -401,7 +408,7 @@ def setup_history_routes(session_manager) -> APIRouter:
db_messages = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id)
.order_by(DbChatMessage.created_at)
.order_by(DbChatMessage.timestamp)
.all()
)
# Find last two assistant messages in DB
@@ -477,10 +484,10 @@ def setup_history_routes(session_manager) -> APIRouter:
@router.get("/api/conversations/topics")
async def get_conversation_topics(request: Request) -> Dict[str, Any]:
from src.auth_helpers import get_current_user
user = get_current_user(request)
from src.auth_helpers import require_user
user = require_user(request)
try:
return analyze_topics(session_manager, owner=user)
return analyze_topics(session_manager, owner=user or None)
except Exception as e:
raise HTTPException(500, f"Topic analysis failed: {e}")
+152 -76
View File
@@ -1,87 +1,105 @@
import re
from copy import deepcopy
from fastapi import APIRouter
# Backends the manual hardware simulator accepts. Must stay a subset of what
# services.hwfit.fit understands so a simulated box ranks like a real one:
# "metal" routes through the Apple-Silicon path (GGUF-only, llama.cpp/Ollama),
# the CPU backends through the RAM/offload path, cuda/rocm through vLLM.
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
The previous additive behavior averaged the manual VRAM across
all GPUs (base + manual), which meant adding "1× 400 GB" on top
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
(= 540 / 3), so GGUF models bigger than that still didn't surface
exactly the "cap stuck at detected level" bug the user hit.
"""
manual_mode = (manual_mode or "").lower()
if manual_mode not in {"gpu", "ram"}:
return system
try:
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
except ValueError:
override_ram_gb = 0
override_ram_gb = max(0.0, override_ram_gb)
if override_ram_gb:
# Replace RAM, don't add. The number in the field is the
# TOTAL system memory the user wants to simulate.
system["available_ram_gb"] = round(override_ram_gb, 1)
system["total_ram_gb"] = round(override_ram_gb, 1)
system["manual_hardware"] = True
if manual_mode == "ram":
# RAM-only simulation — wipe GPU entirely so the ranker uses
# CPU/RAM paths.
system["has_gpu"] = False
system["gpu_name"] = None
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpus"] = []
system["gpu_groups"] = []
system["backend"] = "cpu_x86"
system.pop("unified_memory", None)
return system
try:
count = int(manual_gpu_count) if manual_gpu_count else 1
except ValueError:
count = 1
try:
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
except ValueError:
vram_each = 8.0
count = max(1, min(count, 16))
vram_each = max(1.0, vram_each)
backend = (manual_backend or system.get("backend") or "cuda").lower()
if backend not in _MANUAL_BACKENDS:
backend = "cuda"
total_vram = round(vram_each * count, 1)
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
system["has_gpu"] = True
system["gpu_name"] = gpu_name
system["gpu_vram_gb"] = total_vram
system["gpu_count"] = count
system["gpus"] = [
{"index": i, "name": gpu_name, "vram_gb": vram_each}
for i in range(count)
]
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
# VRAM the user entered, not an average. That's the whole point:
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
# math) all the way up, not just by a small fraction.
system["gpu_groups"] = [{
"name": gpu_name,
"vram_each": vram_each,
"count": count,
"indices": list(range(count)),
"vram_total": total_vram,
}]
system["homogeneous"] = True
system["backend"] = backend
# Apple Silicon shares one unified memory pool with the GPU; flag it so
# the API/UI report it the way real Metal detection does. Discrete GPUs
# (cuda/rocm) and the CPU backends carry separate VRAM, so clear any
# stale flag a previous detection left on the dict.
if backend == "metal":
system["unified_memory"] = True
else:
system.pop("unified_memory", None)
return system
def setup_hwfit_routes():
router = APIRouter(prefix="/api/hwfit", tags=["hwfit"])
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
The previous additive behavior averaged the manual VRAM across
all GPUs (base + manual), which meant adding "1× 400 GB" on top
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
(= 540 / 3), so GGUF models bigger than that still didn't surface
exactly the "cap stuck at detected level" bug the user hit.
"""
manual_mode = (manual_mode or "").lower()
if manual_mode not in {"gpu", "ram"}:
return system
try:
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
except ValueError:
override_ram_gb = 0
override_ram_gb = max(0.0, override_ram_gb)
if override_ram_gb:
# Replace RAM, don't add. The number in the field is the
# TOTAL system memory the user wants to simulate.
system["available_ram_gb"] = round(override_ram_gb, 1)
system["total_ram_gb"] = round(override_ram_gb, 1)
system["manual_hardware"] = True
if manual_mode == "ram":
# RAM-only simulation — wipe GPU entirely so the ranker uses
# CPU/RAM paths.
system["has_gpu"] = False
system["gpu_name"] = None
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpus"] = []
system["gpu_groups"] = []
system["backend"] = "cpu_x86"
return system
try:
count = int(manual_gpu_count) if manual_gpu_count else 1
except ValueError:
count = 1
try:
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
except ValueError:
vram_each = 8.0
count = max(1, min(count, 16))
vram_each = max(1.0, vram_each)
backend = (manual_backend or system.get("backend") or "cuda").lower()
if backend not in {"cuda", "rocm", "cpu_x86", "cpu_arm"}:
backend = "cuda"
total_vram = round(vram_each * count, 1)
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
system["has_gpu"] = True
system["gpu_name"] = gpu_name
system["gpu_vram_gb"] = total_vram
system["gpu_count"] = count
system["gpus"] = [
{"index": i, "name": gpu_name, "vram_gb": vram_each}
for i in range(count)
]
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
# VRAM the user entered, not an average. That's the whole point:
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
# math) all the way up, not just by a small fraction.
system["gpu_groups"] = [{
"name": gpu_name,
"vram_each": vram_each,
"count": count,
"indices": list(range(count)),
"vram_total": total_vram,
}]
system["homogeneous"] = True
system["backend"] = backend
return system
@router.get("/system")
def get_system(host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False):
"""Detect and return current system hardware info. Pass host=user@server for remote.
@@ -181,6 +199,64 @@ def setup_hwfit_routes():
results = rank_models(system, use_case=use_case or None, limit=limit, search=search or None, sort=sort, quant=quant or None, target_context=target_context, fit_only=fit_only)
return {"system": system, "models": results}
@router.get("/profiles")
def get_serve_profiles(model: str = "", host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, serve_weights_gb: float = 0.0, serve_quant: str = ""):
    """Compute llama.cpp serve profiles (Quality/Balanced/Speed) for `model`
    against the detected hardware on `host` (or local). Returns concrete
    flags (n_gpu_layers, n_cpu_moe, cache_type, ctx) the serve UI can apply.

    `model` is matched against the catalog by name: first exactly, then by
    normalized name, then by normalized suffix. If it is not in the catalog
    (e.g. an ad-hoc HF repo) there is no entry to compute from, so the
    response carries `profiles: []` plus an `error` and the UI keeps its
    manual flags.

    Returns:
        `{"system", "profiles", "model_ctx_max"}` on success, or
        `{"system", "profiles": [], "error"}` when hardware detection fails
        or the model cannot be matched.
    """
    from services.hwfit.hardware import detect_system
    from services.hwfit.models import get_models
    from services.hwfit.profiles import compute_serve_profiles

    system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
    if system.get("error"):
        return {"system": system, "profiles": [], "error": system["error"]}

    catalog = {entry.get("name"): entry for entry in (get_models() or [])}

    def _norm(s):
        # Normalize for matching: drop org/ prefix, a trailing -GGUF/-gguf
        # marker, and any quant tag, lowercase. So "DeepSeek-Coder-V2-Lite-
        # Instruct-GGUF" (a local folder name) matches catalog entry
        # "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct".
        s = (s or "").lower().strip()
        s = s.split("/")[-1]                      # drop org prefix
        s = re.sub(r"[-_.]?gguf$", "", s)         # drop trailing gguf marker
        s = re.sub(r"[-_.](q\d[^/]*|iq\d[^/]*|fp8|bf16|f16|awq[^/]*|gptq[^/]*)$", "", s)
        return s

    m = catalog.get(model)
    if m is None and model:
        want = _norm(model)
        # An empty normalized name (e.g. model="GGUF") would suffix-match
        # every catalog entry, so it is treated as "no match".
        if want:
            normalized = [(_norm(name), entry) for name, entry in catalog.items()]
            # Prefer an exact normalized match over a looser suffix match so
            # catalog order cannot pick the wrong model.
            m = next((entry for nn, entry in normalized if nn == want), None)
            if m is None:
                m = next(
                    (entry for nn, entry in normalized
                     if nn and (want.endswith(nn) or nn.endswith(want))),
                    None,
                )
    if m is None:
        return {"system": system, "profiles": [], "error": "model not in catalog"}

    # Surface the model's trained context limit so the serve UI can clamp a
    # user-typed context down to it (asking for ctx > n_ctx_train overflows
    # and, with a quantized KV cache, can crash the GPU).
    model_ctx_max = 0
    for k in ("context_length", "max_position_embeddings", "n_ctx_train", "context"):
        v = m.get(k)
        if isinstance(v, (int, float)) and v > 0:
            model_ctx_max = int(v)
            break

    return {
        "system": system,
        "profiles": compute_serve_profiles(
            system, m,
            serve_weights_gb=(serve_weights_gb or None),
            serve_quant=(serve_quant or None),
        ),
        "model_ctx_max": model_ctx_max,
    }
@router.get("/image-models")
def get_image_models(sort: str = "fit", search: str = "", host: str = "", gpu_count: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
"""Rank image generation models against detected hardware."""
+2 -3
View File
@@ -27,7 +27,7 @@ from src.request_models import MemoryAddRequest
from core.database import SessionLocal
from src.llm_core import llm_call_async
from services.memory.memory_extractor import audit_memories
from src.auth_helpers import get_current_user
from src.auth_helpers import get_current_user, require_user
from src.endpoint_resolver import resolve_endpoint
logger = logging.getLogger(__name__)
@@ -191,8 +191,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
@router.post("/extract")
async def extract_memory(request: Request, session: str = Form(...)) -> Dict[str, List[str]]:
"""Analyze a session's chat history and return memory suggestions."""
if not get_current_user(request):
raise HTTPException(401, "Not authenticated")
require_user(request)
try:
sess = session_manager.get_session(session)
except KeyError:
+331 -125
View File
@@ -1,73 +1,213 @@
# routes/model_routes.py
"""Routes for model and provider management."""
import os
import re
import uuid
import json
import socket
import time as _time
import logging
import httpx
from datetime import datetime
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
from core.middleware import require_admin
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
from src.llm_core import _detect_provider, _host_match, ANTHROPIC_MODELS
from src.settings import load_settings as _load_settings, save_settings as _save_settings
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
from src.auth_helpers import owner_filter
from src.endpoint_resolver import (
normalize_base as _normalize_base,
build_chat_url,
build_models_url,
build_headers,
)
from src.auth_helpers import _auth_disabled, owner_filter
logger = logging.getLogger(__name__)
# Speech settings that can point at a model endpoint (stored as
# "endpoint:<id>" in the provider setting). Each row is:
# (provider setting key, model setting key,
#  model value written back when the setting is cleared, human-readable label).
_SPEECH_ENDPOINT_SETTINGS = (
("tts_provider", "tts_model", "tts-1", "Text to Speech"),
("stt_provider", "stt_model", "base", "Speech to Text"),
)
def _anthropic_api_root(base: str) -> str:
"""Return Anthropic's API root without duplicating /v1."""
base = (base or "").strip().rstrip("/")
host = urlparse(base).hostname or ""
if host.endswith("anthropic.com") and base.endswith("/v1"):
return base[:-3].rstrip("/")
return base
# Settings keys that store an endpoint ID, mapped to
# (companion model setting key, human-readable label).
_ENDPOINT_SETTING_FIELDS = {
"default_endpoint_id": ("default_model", "Default Model"),
"utility_endpoint_id": ("utility_model", "Utility Model"),
"research_endpoint_id": ("research_model", "Deep Research"),
"task_endpoint_id": ("task_model", "Background Tasks"),
}
# Settings keys that store a model fallback chain (a list whose dict entries
# carry an "endpoint_id"), mapped to a human-readable label.
_ENDPOINT_FALLBACK_FIELDS = {
"default_model_fallbacks": "Default Model Fallbacks",
"utility_model_fallbacks": "Utility Model Fallbacks",
"vision_model_fallbacks": "Vision Model Fallbacks",
}
def _ollama_api_root(base: str) -> str:
"""Return Ollama's native API root without depending on deferred imports."""
base = (base or "").strip().rstrip("/")
parsed = urlparse(base)
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
if path.endswith("/api"):
return base
if host.endswith("ollama.com"):
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
return root.rstrip("/") + "/api"
return base
def _speech_settings_using_endpoint(settings: dict, ep_id: str) -> list:
    """List the labels of speech settings whose provider is this endpoint.

    A speech setting references an endpoint when its provider value equals
    ``"endpoint:<ep_id>"``. ``settings`` is only read, never modified.
    """
    wanted = f"endpoint:{ep_id}"
    labels = []
    for row in _SPEECH_ENDPOINT_SETTINGS:
        provider_key, _model_key, _default_model, label = row
        current = settings.get(provider_key) or ""
        if current == wanted:
            labels.append(label)
    return labels
def _models_url(base: str) -> str:
    """Return provider-specific model-list URL for route-local probing.

    Anthropic endpoints list models under ``/v1/models``, Ollama endpoints
    under the native ``/tags`` path; every other base gets ``/models``
    appended.
    """
    provider = _detect_provider(base)
    host = urlparse(base).hostname or ""

    def _hosted_on(domain: str) -> bool:
        # Exact host or a subdomain; a bare endswith(domain) would also accept
        # unrelated hosts such as "not" + domain.
        return host == domain or host.endswith("." + domain)

    if provider == "anthropic" or _hosted_on("anthropic.com"):
        return _anthropic_api_root(base) + "/v1/models"
    if provider == "ollama" or _hosted_on("ollama.com"):
        return _ollama_api_root(base) + "/tags"
    return base.rstrip("/") + "/models"
def _clear_speech_settings_for_endpoint(settings: dict, ep_id: str) -> list:
    """Disable every speech setting that points at the given endpoint.

    For each matching row the provider is set to ``"disabled"`` and the model
    is put back to that row's default. ``settings`` is modified in place; the
    labels of the settings that were reset are returned.
    """
    wanted = f"endpoint:{ep_id}"
    reset_labels = []
    for provider_key, model_key, default_model, label in _SPEECH_ENDPOINT_SETTINGS:
        if (settings.get(provider_key) or "") != wanted:
            continue
        settings[provider_key] = "disabled"
        settings[model_key] = default_model
        reset_labels.append(label)
    return reset_labels
def _provider_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
"""Build provider auth headers without depending on import-time stubs."""
if not api_key:
return {}
provider = _detect_provider(base)
host = urlparse(base).hostname or ""
if provider == "anthropic" or host.endswith("anthropic.com"):
return {
"x-api-key": api_key,
"anthropic-version": "2023-06-01",
}
return {"Authorization": f"Bearer {api_key}"}
def _endpoint_settings_using_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list:
    """Return labels for settings and fallback chains that reference an endpoint.

    Checks, in order: the direct endpoint-ID settings, the model fallback
    chains, and (only when ``include_speech`` is true) the speech settings.
    ``settings`` is only read, never modified.
    """
    affected = [
        label
        for ep_key, (_, label) in _ENDPOINT_SETTING_FIELDS.items()
        if (settings.get(ep_key) or "") == ep_id
    ]
    for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items():
        chain = settings.get(fallback_key)
        # Only a list is a valid chain, matching the guard used by
        # _clear_endpoint_settings_for_endpoint. Without it a corrupt value
        # (e.g. a number) would raise TypeError on iteration.
        if not isinstance(chain, list):
            continue
        if any(isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id for entry in chain):
            affected.append(label)
    if include_speech:
        affected.extend(_speech_settings_using_endpoint(settings, ep_id))
    return affected
def _clear_endpoint_settings_for_endpoint(settings: dict, ep_id: str, *, include_speech: bool = False) -> list:
    """Strip an endpoint out of direct settings and model fallback chains.

    Direct endpoint settings that equal ``ep_id`` are blanked together with
    their companion model setting. Fallback chains (lists only) lose every
    dict entry whose ``endpoint_id`` equals ``ep_id``. Speech settings are
    reset as well when ``include_speech`` is true. ``settings`` is modified
    in place; the labels of everything that changed are returned.
    """
    def _refers_to_endpoint(entry) -> bool:
        return isinstance(entry, dict) and (entry.get("endpoint_id") or "") == ep_id

    removed_labels = []
    for ep_key, (model_key, label) in _ENDPOINT_SETTING_FIELDS.items():
        if (settings.get(ep_key) or "") != ep_id:
            continue
        settings[ep_key] = ""
        settings[model_key] = ""
        removed_labels.append(label)

    for fallback_key, label in _ENDPOINT_FALLBACK_FIELDS.items():
        chain = settings.get(fallback_key)
        if isinstance(chain, list):
            survivors = [entry for entry in chain if not _refers_to_endpoint(entry)]
            # Only rewrite (and report) chains that actually lost an entry.
            if len(survivors) != len(chain):
                settings[fallback_key] = survivors
                removed_labels.append(label)

    if include_speech:
        removed_labels.extend(_clear_speech_settings_for_endpoint(settings, ep_id))
    return removed_labels
def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int:
"""Remove endpoint references from scoped or legacy-flat user preferences."""
if not isinstance(all_prefs, dict):
return 0
users = all_prefs.get("_users")
pref_sets = users.values() if isinstance(users, dict) else [all_prefs]
cleared_users = 0
for prefs in pref_sets:
if isinstance(prefs, dict) and _clear_endpoint_settings_for_endpoint(prefs, ep_id):
cleared_users += 1
return cleared_users
# Loopback hosts a user might type for a local model server (LM Studio,
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
# host the server actually runs on.
# _ANY_BIND_HOSTS are the wildcard bind addresses; the Docker loopback rewrite
# may normalize them to 127.0.0.1 so the URL names a connectable host.
_ANY_BIND_HOSTS = {"0.0.0.0", "::"}
_LOOPBACK_HOSTS = {"localhost", "127.0.0.1", "::1", *_ANY_BIND_HOSTS}
def _docker_host_gateway_reachable() -> bool:
    """True when we run inside a container whose host is reachable via
    ``host.docker.internal`` (compose maps it to ``host-gateway``). Returns
    False on native installs and on container setups without the mapping, so
    the loopback rewrite stays a no-op there.

    Container detection: ``/.dockerenv`` exists, or ``/proc/1/cgroup``
    mentions one of "docker", "containerd", "kubepods".
    """
    in_container = os.path.exists("/.dockerenv")
    if not in_container:
        try:
            with open("/proc/1/cgroup", encoding="utf-8") as fh:
                # Read exactly once. Calling fh.read() inside the any()
                # generator returns "" from the second marker on, so only
                # "docker" would ever be checked.
                cgroup = fh.read()
        except OSError:
            cgroup = ""
        in_container = any(marker in cgroup for marker in ("docker", "containerd", "kubepods"))
    if not in_container:
        return False
    try:
        # Resolution success is the signal; the addresses are not used.
        socket.getaddrinfo("host.docker.internal", None)
    except OSError:
        return False
    return True
def _container_loopback_reachable(base_url: str, timeout: float = 0.2) -> bool:
"""True when the requested loopback host:port is already reachable from
inside the current container.
This distinguishes "a model server running alongside Odysseus in the same
container" from "a model server running on the Docker host". Only the
latter should be rewritten to host.docker.internal.
"""
try:
parsed = urlparse(base_url)
except Exception:
return False
host = (parsed.hostname or "").lower()
port = parsed.port
if host not in _LOOPBACK_HOSTS or not port:
return False
probe_host = "::1" if host == "::1" else "127.0.0.1"
family = socket.AF_INET6 if probe_host == "::1" else socket.AF_INET
try:
with socket.socket(family, socket.SOCK_STREAM) as sock:
sock.settimeout(timeout)
sock.connect((probe_host, port))
return True
except OSError:
return False
def _rewrite_loopback_for_docker(base_url: str, *, container_local: bool = False) -> str:
"""Rewrite a loopback model-endpoint URL to ``host.docker.internal`` when
running in Docker. A URL like ``http://localhost:1234/v1`` (the LM Studio
default) otherwise targets the Odysseus container itself, so the probe gets
a connection error and the endpoint is rejected with a misleading "No
models found for that provider/key".
Cookbook local serves are the opposite case: Odysseus started the model
server inside the same container/process environment, so the saved endpoint
must remain container-local. In that mode, normalize a bind address such as
0.0.0.0 to a connectable loopback host, but do not jump to the Docker host.
"""
try:
parsed = urlparse(base_url)
except Exception:
return base_url
host = (parsed.hostname or "").lower()
if host not in _LOOPBACK_HOSTS:
return base_url
if container_local:
if host in _ANY_BIND_HOSTS:
netloc = "127.0.0.1" + (f":{parsed.port}" if parsed.port else "")
return urlunparse(parsed._replace(netloc=netloc))
return base_url
if host in _ANY_BIND_HOSTS and not _docker_host_gateway_reachable():
netloc = "127.0.0.1" + (f":{parsed.port}" if parsed.port else "")
return urlunparse(parsed._replace(netloc=netloc))
if _container_loopback_reachable(base_url):
return base_url
if not _docker_host_gateway_reachable():
return base_url
netloc = "host.docker.internal" + (f":{parsed.port}" if parsed.port else "")
return urlunparse(parsed._replace(netloc=netloc))
# ── Curated model lists per provider ──
@@ -84,10 +224,13 @@ _PROVIDER_CURATED = {
"claude-sonnet-4-5", "claude-haiku-3-5",
],
"zai": [
"glm-5", "glm-4.7", "glm-4.7-flash",
"glm-5", "glm-5.1", "glm-5v-turbo", "glm-4.7", "glm-4.7-flash",
"glm-4.6", "glm-4.6v",
"glm-4.5", "glm-4.5v", "glm-4.5-air", "glm-4.5-flash",
],
"zai-coding": [
"glm-5.1", "glm-5v-turbo", "glm-5-turbo", "glm-4.7", "glm-4.5-air",
],
"deepseek": [
"deepseek-chat", "deepseek-reasoner",
],
@@ -122,31 +265,40 @@ _PROVIDER_CURATED = {
],
}
# Map URL substrings → curated-list keys for providers whose _detect_provider()
# Map hostnames → curated-list keys for providers whose _detect_provider()
# returns a generic value (e.g. "openai") but deserve their own curated list.
# "openrouter" is a sentinel meaning "no curation — show all models as curated".
_URL_TO_CURATED = {
"z.ai": "zai",
"api.deepseek.com": "deepseek",
"api.groq.com": "groq",
"api.mistral.ai": "mistral",
"api.together.xyz": "together",
"api.fireworks.ai": "fireworks",
"generativelanguage.googleapis.com": "google",
"api.x.ai": "xai",
"openrouter.ai": "openrouter",
"ollama.com": "ollama",
}
# Entries are matched by hostname equality or subdomain suffix (via _host_match),
# so e.g. "deepseek.com" covers api.deepseek.com without matching the substring
# inside an unrelated URL.
_HOST_TO_CURATED = (
("z.ai", "zai"),
("deepseek.com", "deepseek"),
("groq.com", "groq"),
("mistral.ai", "mistral"),
("together.xyz", "together"),
("together.ai", "together"),
("fireworks.ai", "fireworks"),
("googleapis.com", "google"),
("x.ai", "xai"),
("openrouter.ai", "openrouter"),
("ollama.com", "ollama"),
)
def _match_provider_curated(base_url: str, provider: str) -> str:
"""Return the curated-list key for a given endpoint.
Checks the base URL against _URL_TO_CURATED first, then falls back
to the raw provider string from _detect_provider().
Checks path-based overrides first (for hosts serving multiple plans),
then matches the base URL's hostname against known providers, and
finally falls back to the raw provider string from _detect_provider().
"""
for substring, key in _URL_TO_CURATED.items():
if substring in (base_url or ""):
# Path-based overrides for hosts that serve multiple curated lists.
parsed = urlparse(base_url)
if _host_match(base_url, "z.ai") and "/api/coding" in (parsed.path or ""):
return "zai-coding"
for domain, key in _HOST_TO_CURATED:
if _host_match(base_url, domain):
return key
return provider
@@ -235,16 +387,20 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1
elif provider == "ollama":
from src.llm_core import _build_ollama_payload
target_url = build_chat_url(base)
h = _provider_headers(api_key, base)
h = build_headers(api_key, base)
h["Content-Type"] = "application/json"
payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools)
else:
target_url = build_chat_url(base)
h = _provider_headers(api_key, base)
h = build_headers(api_key, base)
h["Content-Type"] = "application/json"
from src.llm_core import _uses_max_completion_tokens
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
_max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens"
payload = {"model": model_id, "messages": messages, _max_key: 5, "temperature": 0.0}
payload = {"model": model_id, "messages": messages, _max_key: 5}
# Reasoning models (o1/o3/o4/gpt-5) reject an explicit temperature, so a
# probe that hardcodes one falsely reports a working endpoint as failing.
if not _restricts_temperature(model_id):
payload["temperature"] = 0.0
if _test_tools:
payload["tools"] = _test_tools
@@ -308,7 +464,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
base = resolve_url(_normalize_base(base_url))
if _detect_provider(base) == "anthropic":
# Try Anthropic's /v1/models endpoint first
url = _anthropic_api_root(base) + "/v1/models"
url = build_models_url(base)
headers = {"anthropic-version": "2023-06-01"}
if api_key:
headers["x-api-key"] = api_key
@@ -331,8 +487,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
return []
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
return list(ANTHROPIC_MODELS)
url = _models_url(base)
headers = _provider_headers(api_key, base)
url = build_models_url(base)
headers = build_headers(api_key, base)
try:
r = httpx.get(url, headers=headers, timeout=timeout)
r.raise_for_status()
@@ -343,6 +499,13 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
if not models:
models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
if models:
# Z.AI coding plan omits some working models from /models;
# append curated-only entries for that endpoint only.
if _host_match(base, "z.ai") and "/api/coding" in (urlparse(base).path or ""):
_ck = _match_provider_curated(base, None)
for _e in _PROVIDER_CURATED.get(_ck, []):
if _e not in set(models) and not any(m.startswith(_e) for m in models):
models.append(_e)
return models
except httpx.HTTPStatusError as e:
if api_key:
@@ -387,7 +550,24 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the
# base is the host root or the native /api root (e.g. http://localhost:11434,
# http://localhost:11434/api) because /models lives under /v1 there. Treat
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
# than as a definitive offline verdict — Ollama is reachable, it just
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
# marks an alive Ollama as offline whenever cached_models is empty (issue
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
# _ping_endpoint() was returning before that fallback could run.
parsed_base = urlparse(base)
looks_like_ollama = (
parsed_base.port == 11434
or "ollama" in (parsed_base.hostname or "").lower()
)
url = base + "/models"
last_error: Optional[str] = None
try:
r = httpx.get(url, headers=headers, timeout=timeout)
if 300 <= r.status_code < 400:
@@ -399,17 +579,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
}
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
if r.status_code < 500:
return {"reachable": r.status_code < 400, "status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
if r.status_code < 400:
return {"reachable": True, "status_code": r.status_code, "error": None}
if r.status_code < 500 and not looks_like_ollama:
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
last_error = f"HTTP {r.status_code}"
except Exception as e:
last_error = str(e)[:120]
else:
last_error = f"HTTP {r.status_code}"
try:
parsed = urlparse(base)
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
if looks_like_ollama:
root = base
for suffix in ("/v1", "/api"):
if root.endswith(suffix):
root = root[: -len(suffix)].rstrip("/")
break
for path in ("/api/version", "/api/tags"):
try:
r = httpx.get(root + path, timeout=timeout)
@@ -449,6 +633,15 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
return "No models found for that provider/key."
def _visible_models(cached_models, hidden_models):
"""Filter cached model IDs by hidden_models. Returns list of visible IDs."""
all_models = json.loads(cached_models) if isinstance(cached_models, str) else (cached_models or [])
if not hidden_models:
return all_models
hidden = set(json.loads(hidden_models) if isinstance(hidden_models, str) else (hidden_models or []))
return [m for m in all_models if m not in hidden]
def setup_model_routes(model_discovery):
router = APIRouter(prefix="/api")
@@ -625,7 +818,7 @@ def setup_model_routes(model_discovery):
# list to unauthenticated callers.
try:
auth_mgr = getattr(request.app.state, "auth_manager", None)
if not owner and auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
if not owner and not _auth_disabled() and auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
raise HTTPException(401, "Not authenticated")
except HTTPException:
raise
@@ -746,8 +939,8 @@ def setup_model_routes(model_discovery):
entry["error"] = str(e)
entry["model_count"] = 0
else:
url = _models_url(base)
headers = _provider_headers(ep.api_key, base)
url = build_models_url(base)
headers = build_headers(ep.api_key, base)
try:
t0 = _time.time()
r = httpx.get(url, headers=headers, timeout=5)
@@ -965,23 +1158,23 @@ def setup_model_routes(model_discovery):
require_models: str = Form("false"),
model_type: str = Form("llm"),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
container_local: str = Form("false"),
# Default `shared=true` → endpoints are visible to all users (the
# app's historical behaviour). Admins can pass `shared=false` to
# scope a new endpoint to their own account only.
shared: str = Form("true"),
):
require_admin(request)
base_url = base_url.strip().rstrip("/")
# Normalize: strip trailing /models, /chat/completions, /v1/messages etc to get clean base
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
if base_url.endswith(suffix):
base_url = base_url[:-len(suffix)].rstrip("/")
base_url = _normalize_base(base_url)
if not base_url:
raise HTTPException(400, "Base URL is required")
# Resolve hostname via Tailscale if DNS fails
from src.endpoint_resolver import resolve_url
base_url = resolve_url(base_url)
# In Docker, manually added loopback URLs usually point at a host-local
# server. Cookbook local serves are launched inside Odysseus itself, so
# keep those container-local when the frontend marks them as such.
base_url = _rewrite_loopback_for_docker(base_url, container_local=_truthy(container_local))
# Auto-generate name from URL if not provided
if not name.strip():
@@ -1052,11 +1245,15 @@ def setup_model_routes(model_discovery):
)
db.add(ep)
db.commit()
# Auto-set as default chat endpoint if none configured yet
# Auto-set as default chat endpoint if none configured yet. Seed
# the first CHAT model (not raw model_ids[0]) so we don't pin the
# global default to an embedding/tts/etc. entry a provider happens
# to list first.
settings = _load_settings()
if not settings.get("default_endpoint_id"):
from src.endpoint_resolver import _first_chat_model
settings["default_endpoint_id"] = ep.id
settings["default_model"] = model_ids[0] if model_ids else ""
settings["default_model"] = _first_chat_model(model_ids) or ""
_save_settings(settings)
_invalidate_models_cache()
_local_probe_cache["data"] = None
@@ -1081,14 +1278,12 @@ def setup_model_routes(model_discovery):
api_key: str = Form(""),
):
require_admin(request)
base_url = base_url.strip().rstrip("/")
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
if base_url.endswith(suffix):
base_url = base_url[:-len(suffix)].rstrip("/")
base_url = _normalize_base(base_url)
if not base_url:
raise HTTPException(400, "Base URL is required")
from src.endpoint_resolver import resolve_url
base_url = resolve_url(base_url)
base_url = _rewrite_loopback_for_docker(base_url)
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
@@ -1301,9 +1496,9 @@ def setup_model_routes(model_discovery):
chat_url = build_chat_url(base)
if not model and getattr(ep, "cached_models", None):
try:
models = _json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else ep.cached_models
if models:
model = models[0]
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None))
if visible:
model = visible[0]
except Exception:
pass
return {"endpoint_id": ep.id, "endpoint_url": chat_url, "model": model}
@@ -1337,58 +1532,63 @@ def setup_model_routes(model_discovery):
ep.name = body["name"].strip() or ep.name
if "model_type" in body and isinstance(body["model_type"], str):
ep.model_type = body["model_type"].strip() or ep.model_type
# Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key
# (or correct a typo'd base URL) without nuking session state.
if "api_key" in body and isinstance(body["api_key"], str):
_new_key = body["api_key"].strip()
# Empty string means "clear it" (e.g. local Ollama no longer needs a key).
ep.api_key = _new_key or None
if "base_url" in body and isinstance(body["base_url"], str):
_new_base = body["base_url"].strip().rstrip("/")
for _suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
if _new_base.endswith(_suffix):
_new_base = _new_base[: -len(_suffix)].rstrip("/")
_new_base = _normalize_base(_new_base)
if _new_base:
ep.base_url = _new_base
else:
ep.is_enabled = not ep.is_enabled
db.commit()
_invalidate_models_cache()
_local_probe_cache["data"] = None
return {
"id": ep.id,
"is_enabled": ep.is_enabled,
"supports_tools": ep.supports_tools,
"name": ep.name,
"model_type": ep.model_type,
"base_url": ep.base_url,
}
finally:
db.close()
# ── Settings fields that store an endpoint ID ──
_EP_SETTING_FIELDS = {
"default_endpoint_id": ("default_model", "Default Model"),
"utility_endpoint_id": ("utility_model", "Utility Model"),
"research_endpoint_id": ("research_model", "Deep Research"),
"task_endpoint_id": ("task_model", "Background Tasks"),
}
def _settings_using_endpoint(ep_id: str) -> list:
"""Return human-readable labels for settings that reference this endpoint."""
settings = _load_settings()
affected = []
for ep_key, (_, label) in _EP_SETTING_FIELDS.items():
if (settings.get(ep_key) or "") == ep_id:
affected.append(label)
tts_prov = settings.get("tts_provider") or ""
if tts_prov == f"endpoint:{ep_id}":
affected.append("Text to Speech")
return affected
return _endpoint_settings_using_endpoint(_load_settings(), ep_id, include_speech=True)
def _clear_settings_for_endpoint(ep_id: str) -> list:
"""Clear all settings that reference this endpoint. Returns list of cleared labels."""
settings = _load_settings()
cleared = []
for ep_key, (model_key, label) in _EP_SETTING_FIELDS.items():
if (settings.get(ep_key) or "") == ep_id:
settings[ep_key] = ""
settings[model_key] = ""
cleared.append(label)
tts_prov = settings.get("tts_provider") or ""
if tts_prov == f"endpoint:{ep_id}":
settings["tts_provider"] = "disabled"
settings["tts_model"] = "tts-1"
cleared.append("Text to Speech")
cleared = _clear_endpoint_settings_for_endpoint(settings, ep_id, include_speech=True)
if cleared:
_save_settings(settings)
return cleared
def _clear_user_prefs_for_endpoint(ep_id: str) -> int:
    """Clear per-user endpoint selections and fallback chains.

    Loads the stored preferences, strips references to ``ep_id`` and saves
    them back only if something changed. Best effort: any failure is logged
    as a warning and reported as 0 cleared users.
    """
    try:
        from routes.prefs_routes import _load as _load_prefs, _save as _save_prefs
        stored = _load_prefs()
        touched = _clear_user_pref_endpoint_refs(stored, ep_id)
        if touched:
            _save_prefs(stored)
    except Exception as e:
        logger.warning("Failed to clear user prefs for endpoint %s: %s", ep_id, e)
        return 0
    else:
        return touched
def _session_uses_endpoint_url(session_url: str, base_url: str) -> bool:
if not session_url or not base_url:
return False
@@ -1402,12 +1602,18 @@ def setup_model_routes(model_discovery):
return sess in variants or sess.startswith(base + "/")
def _clear_sessions_for_endpoint(db, base_url: str) -> int:
"""Drop stored auth for sessions using an endpoint being deleted.
Keep the session's endpoint URL and model intact. If the admin is
replacing an endpoint with the same URL, clearing those fields leaves
the UI looking selected while chat requests arrive with an empty model.
The chat-time orphan guard still clears truly dead endpoints when no
matching enabled endpoint exists.
"""
cleared = 0
rows = db.query(DbSession).filter(DbSession.endpoint_url.isnot(None)).all()
for row in rows:
if _session_uses_endpoint_url(row.endpoint_url or "", base_url):
row.endpoint_url = ""
row.model = ""
row.headers = {}
row.updated_at = datetime.utcnow()
cleared += 1
@@ -1425,8 +1631,6 @@ def setup_model_routes(model_discovery):
try:
for sess in list(getattr(manager, "sessions", {}).values()):
if _session_uses_endpoint_url(getattr(sess, "endpoint_url", "") or "", base_url):
sess.endpoint_url = ""
sess.model = ""
sess.headers = {}
cleared += 1
except Exception:
@@ -1449,6 +1653,7 @@ def setup_model_routes(model_discovery):
raise HTTPException(404, "Endpoint not found")
# Clean up any settings that reference this endpoint
cleared = _clear_settings_for_endpoint(ep_id)
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
db.delete(ep)
@@ -1458,6 +1663,7 @@ def setup_model_routes(model_discovery):
return {
"deleted": True,
"cleared_settings": cleared,
"cleared_user_preferences": cleared_user_preferences,
"cleared_sessions": cleared_sessions,
"cleared_loaded_sessions": cleared_loaded_sessions,
}
+3 -4
View File
@@ -683,9 +683,8 @@ def setup_note_routes(task_scheduler=None):
Returns {synthesis, email_sent}.
"""
# Gate against anonymous callers — LLM synthesis can burn tokens.
from src.auth_helpers import get_current_user as _gcu
if not _gcu(request):
raise HTTPException(401, "Not authenticated")
from src.auth_helpers import require_user as _ru
_ru(request)
body = await request.json()
note_id = body.get("note_id")
title = (body.get("title") or "").strip()
@@ -697,7 +696,7 @@ def setup_note_routes(task_scheduler=None):
# the same dispatch without an HTTP roundtrip + auth cookie.
return await dispatch_reminder(
title=title, note_body=note_body, note_id=note_id,
owner=_gcu(request) or "",
owner=_owner(request) or "",
queue_browser=False,
)
+5 -2
View File
@@ -69,9 +69,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
if not directory:
raise HTTPException(400, "Directory path is required")
base_abs = os.path.abspath(PERSONAL_DIR)
# realpath (not abspath) so a symlink inside PERSONAL_DIR that points
# outside it is resolved before the commonpath confinement check below;
# abspath only normalises `..` and would let such a symlink escape.
base_abs = os.path.realpath(PERSONAL_DIR)
candidate = directory if os.path.isabs(directory) else os.path.join(base_abs, directory)
resolved = os.path.abspath(candidate)
resolved = os.path.realpath(candidate)
try:
in_base = os.path.commonpath([resolved, base_abs]) == base_abs
except ValueError:
+14 -2
View File
@@ -12,7 +12,8 @@ def _load():
"""Load the raw prefs file (internal use only)."""
try:
with open(PREFS_FILE, "r", encoding="utf-8") as f:
return json.load(f)
data = json.load(f)
return data if isinstance(data, dict) else {}
except (FileNotFoundError, json.JSONDecodeError):
return {}
@@ -40,7 +41,18 @@ def _save_for_user(user: Optional[str], prefs: dict):
"""Save preferences for a specific user."""
all_prefs = _load()
if user is None:
# Auth disabled — save flat
# Auth disabled. If the store is already multi-user (e.g. auth was
# turned off on a deployment that previously ran multi-user), writing
# `prefs` flat would overwrite the whole `_users` map and destroy every
# other user's preferences. Instead write back into the same (first)
# slot _load_for_user(None) reads from, preserving the others.
if "_users" in all_prefs:
users = all_prefs["_users"]
first_key = next(iter(users), None)
if first_key is not None:
users[first_key] = prefs
_save(all_prefs)
return
_save(prefs)
return
if "_users" not in all_prefs:
+30 -4
View File
@@ -3,6 +3,7 @@
import asyncio
import json
import logging
import re
import uuid
from datetime import datetime
from pathlib import Path
@@ -12,7 +13,9 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from src.endpoint_resolver import resolve_endpoint
from src.auth_helpers import get_current_user
from src.auth_helpers import _auth_disabled, get_current_user
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
logger = logging.getLogger(__name__)
@@ -55,9 +58,15 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
verify the session belongs to this user."""
user = get_current_user(request)
if not user:
if _auth_disabled():
return ""
raise HTTPException(401, "Not authenticated")
return user
def _validate_session_id(session_id: str) -> None:
    """Reject a session ID that does not fully match ``_SESSION_ID_RE``.

    Raises:
        HTTPException: 400 when ``session_id`` does not match the pattern.
    """
    if _SESSION_ID_RE.fullmatch(session_id) is None:
        raise HTTPException(400, "Invalid session ID format")
def _owns_in_memory(session_id: str, user: str) -> bool:
"""Ownership check for an in-flight (in-memory) research task.
Falls back to the on-disk JSON if the task has already finished."""
@@ -95,6 +104,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
@router.get("/api/research/status/{session_id}")
async def research_status(session_id: str, request: Request):
user = _require_user(request)
_validate_session_id(session_id)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
status = research_handler.get_status(session_id)
@@ -105,6 +115,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
@router.post("/api/research/cancel/{session_id}")
async def research_cancel(session_id: str, request: Request):
user = _require_user(request)
_validate_session_id(session_id)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
cancelled = research_handler.cancel_research(session_id)
@@ -113,6 +124,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
@router.post("/api/research/result/{session_id}")
async def research_result(session_id: str, request: Request):
user = _require_user(request)
_validate_session_id(session_id)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research result available")
result = research_handler.get_result(session_id)
@@ -140,6 +152,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_report(session_id: str, request: Request):
"""Serve the visual HTML report for a completed research session."""
user = _require_user(request)
_validate_session_id(session_id)
_assert_owns_research(session_id, user)
logger.info(f"Visual report requested for session {session_id}")
try:
@@ -160,6 +173,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
"""Mark an image URL as hidden for this research's visual report.
Persisted to the research JSON so subsequent /report renders skip it."""
user = _require_user(request)
_validate_session_id(session_id)
_assert_owns_research(session_id, user)
ok = research_handler.hide_image(session_id, body.url)
if not ok:
@@ -170,6 +184,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_unhide_images(session_id: str, request: Request):
"""Clear the hidden-images list for a research session."""
user = _require_user(request)
_validate_session_id(session_id)
_assert_owns_research(session_id, user)
ok = research_handler.unhide_all_images(session_id)
if not ok:
@@ -235,6 +250,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
"""Return the full JSON for a single research result — sources,
summary, stats used by the Library preview panel."""
user = _require_user(request)
_validate_session_id(session_id)
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
@@ -251,6 +267,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_archive(session_id: str, request: Request, archived: bool = Query(True)):
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
user = _require_user(request)
_validate_session_id(session_id)
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
@@ -270,6 +287,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_delete(session_id: str, request: Request):
"""Delete a research result from disk."""
user = _require_user(request)
_validate_session_id(session_id)
data_dir = Path("data/deep_research")
json_path = data_dir / f"{session_id}.json"
deleted = False
@@ -299,7 +317,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
endpoint_id: Optional[str] = None
model: Optional[str] = None
max_time: int = Field(default=300, ge=60, le=1800)
extraction_timeout: Optional[int] = Field(default=None, ge=15, le=600)
extraction_timeout: Optional[int] = Field(default=None, ge=15, le=3600)
extraction_concurrency: Optional[int] = Field(default=None, ge=1, le=12)
category: Optional[str] = None
@@ -413,6 +431,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_stream(session_id: str, request: Request):
"""SSE stream of research progress events."""
user = _require_user(request)
_validate_session_id(session_id)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
async def _generate():
@@ -446,6 +465,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
async def research_result_peek(session_id: str, request: Request):
"""Get research result without clearing it (for panel use)."""
user = _require_user(request)
_validate_session_id(session_id)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
result = research_handler.get_result(session_id)
@@ -474,7 +494,14 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
injects a single system message containing the report and sources so
the user can ask follow-up questions in a clean conversation.
"""
_require_user(request)
user = _require_user(request)
_validate_session_id(session_id)
# SECURITY: gate on ownership before reading the persisted research —
# otherwise any authenticated user could spin off (and thereby read)
# another user's report by guessing its session ID. Mirrors every other
# endpoint in this file (see result_peek above).
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
if session_manager is None:
raise HTTPException(500, "session_manager not configured")
@@ -555,7 +582,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
# Create new session
new_sid = str(uuid.uuid4())
user = get_current_user(request)
title_query = (query or "research").strip()
if len(title_query) > 60:
+162 -48
View File
@@ -11,45 +11,118 @@ from core.session_manager import SessionManager
from core.models import ChatMessage
from src.request_models import SessionResponse
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
from src.auth_helpers import get_current_user
from src.auth_helpers import get_current_user, effective_user
def _verify_session_owner(request: Request, session_id: str):
"""Verify the current user owns the session. Raises 404 if not."""
user = get_current_user(request)
def _sanitize_export_filename(name: str) -> str:
"""Return a conservative filename safe for Content-Disposition."""
name = name if isinstance(name, str) else ""
name = re.sub(r"[^A-Za-z0-9._-]", "_", name)
return name[:128]
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
    """Verify the current user owns the session. Raises 404 if not.

    Ownership is checked against the DB row when one exists. If there is no
    DB row but the caller owns an in-memory "ghost" session — one that lives
    only in ``session_manager`` because it was never persisted, or its DB row
    was removed out-of-band — fall back to the in-memory owner so the user can
    still manage and delete it. Without this fallback such sessions are listed
    by ``/api/sessions`` (they come from the in-memory manager) yet every
    per-session operation 404s, making them impossible to delete (issue #1044).

    ``session_manager`` is optional and defaults to ``None`` so existing callers
    that only care about persisted sessions keep their exact prior behavior.

    Raises:
        HTTPException: 403 when no user can be resolved for the request;
            404 when the session does not exist or belongs to someone else
            (the same status for both, so existence is not disclosed).
    """
    user = effective_user(request)
    if not user:
        raise HTTPException(403, "Authentication required")

    # Only the lookup happens while the DB session is open. The "row missing"
    # decision must be made AFTER it, otherwise the ghost fallback below is
    # never reached.
    db = SessionLocal()
    try:
        row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
    finally:
        db.close()

    if row is not None:
        if row.owner != user:
            raise HTTPException(404, f"Session {session_id} not found")
        return

    # No DB row — allow the caller to act on an in-memory ghost they own.
    if session_manager is not None:
        ghost = getattr(session_manager, "sessions", {}).get(session_id)
        if ghost is not None and getattr(ghost, "owner", None) == user:
            return

    raise HTTPException(404, f"Session {session_id} not found")
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["sessions"])
def _pick_endpoint_for_sort():
def _current_user_is_admin(request: Request, user: str | None) -> bool:
if not user:
return False
auth_mgr = getattr(request.app.state, "auth_manager", None)
is_admin = getattr(auth_mgr, "is_admin", None)
if not callable(is_admin):
return False
try:
return bool(is_admin(user))
except Exception:
return False
def _reject_raw_endpoint_url_for_non_admin(
request: Request,
user: str | None,
endpoint_id: str | None,
endpoint_url: str | None,
) -> None:
"""Require registered endpoints for signed-in non-admin session changes."""
if endpoint_id and endpoint_id.strip():
return
if not endpoint_url:
return
# Raw URLs make the server dial whatever host the request supplies. For
# non-admin users, require a saved endpoint row so normal owner scoping and
# endpoint validation have already happened.
if user and not _current_user_is_admin(request, user):
raise HTTPException(403, "Choose a registered model endpoint")
def _persist_session_headers(session_id: str, headers: dict | None) -> None:
    """Persist endpoint auth headers for DB-backed session metadata.

    Looks up the session row by ``session_id``; when it exists, stores
    ``headers`` (or an empty dict when falsy), bumps ``updated_at`` and
    commits. A missing row is a no-op. Any exception rolls the transaction
    back and is re-raised; the DB session is always closed.
    """
    db = SessionLocal()
    try:
        record = db.query(DbSession).filter(DbSession.id == session_id).first()
        if not record:
            return
        record.headers = headers or {}
        record.updated_at = datetime.utcnow()
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def _pick_endpoint_for_sort(owner=None):
    """Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default.

    Resolution order, each resolved for ``owner``: the "utility" endpoint,
    then the task endpoint, then the "default" endpoint. The first candidate
    that yields both a URL and a model wins.

    Returns:
        ``(url, model, headers)`` for the chosen endpoint, or
        ``(None, None, None)`` when none of the candidates resolves.
    """
    from src.endpoint_resolver import resolve_endpoint

    # Try utility endpoint first (what the user configured for background tasks)
    url, model, headers = resolve_endpoint("utility", owner=owner)
    if url and model:
        return url, model, headers

    # Fall back to task endpoint
    try:
        from src.task_endpoint import resolve_task_endpoint
        url, model, headers = resolve_task_endpoint(owner=owner)
        if url and model:
            return url, model, headers
    except Exception:
        # Best-effort: a failure importing or resolving the task endpoint
        # must not prevent the default fallback below.
        pass

    # Fall back to default
    url, model, headers = resolve_endpoint("default", owner=owner)
    if url and model:
        return url, model, headers
    return None, None, None
@@ -63,7 +136,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
@router.get("/sessions")
def list_sessions(request: Request):
user = get_current_user(request)
user = effective_user(request)
# Lazy purge: incognito sessions are ephemeral by design — wipe leftovers
# from the DB and session_manager so they vanish on the next page refresh.
# BUT: skip sessions that were created within the last 10 minutes.
@@ -172,11 +245,41 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
endpoint_id: str = Form(""),
):
skip_val = str(skip_validation).lower() == "true"
user = get_current_user(request)
endpoint_api_key = ""
endpoint_base_url = ""
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
if endpoint_id and endpoint_id.strip():
from core.database import ModelEndpoint
from src.auth_helpers import owner_filter
from src.endpoint_resolver import build_chat_url, normalize_base
_db = SessionLocal()
try:
q = _db.query(ModelEndpoint).filter(
ModelEndpoint.id == endpoint_id.strip(),
ModelEndpoint.is_enabled == True,
)
if user:
q = owner_filter(q, ModelEndpoint, user)
endpoint_row = q.first()
if not endpoint_row:
raise HTTPException(400, "Model endpoint no longer exists")
endpoint_base_url = endpoint_row.base_url or ""
endpoint_api_key = endpoint_row.api_key or ""
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
finally:
_db.close()
if not endpoint_url and not skip_val:
raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
model_to_use = model
request_api_key = api_key.strip() if api_key else ""
effective_api_key = request_api_key or endpoint_api_key
validation_headers = None
if effective_api_key:
from src.endpoint_resolver import build_headers
validation_headers = build_headers(effective_api_key, endpoint_base_url or endpoint_url)
if skip_val:
# skip_validation = trust the caller and do NOT probe /v1/models.
@@ -187,7 +290,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
elif not model_to_use:
from src.llm_core import list_model_ids
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
headers=validation_headers)
if not ids:
raise HTTPException(400, "Cannot reach /v1/models")
# Default to the first CHAT model — endpoints often list embedding/
@@ -202,7 +305,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
import os as _os
req_base = _os.path.basename(model_to_use.rstrip("/"))
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
headers=validation_headers)
if not avail:
raise HTTPException(400, "Cannot reach /v1/models")
if model_to_use not in avail:
@@ -217,7 +320,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
model_to_use = found
sid = str(uuid.uuid4())
user = get_current_user(request)
user = effective_user(request)
session = session_manager.create_session(
session_id=sid,
name=name or "",
@@ -227,22 +330,15 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
owner=user,
)
# Set auth headers for custom API-key endpoints
resolved_key = api_key.strip() if api_key else ""
resolved_key = request_api_key
resolved_base = endpoint_url
if not resolved_key and endpoint_id and endpoint_id.strip():
from core.database import ModelEndpoint
_db = SessionLocal()
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first()
if ep and ep.api_key:
resolved_key = ep.api_key
resolved_base = ep.base_url
finally:
_db.close()
if not resolved_key and endpoint_api_key:
resolved_key = endpoint_api_key
resolved_base = endpoint_base_url
if resolved_key:
from src.endpoint_resolver import build_headers
session.headers = build_headers(resolved_key, resolved_base)
session_manager.save_sessions()
_persist_session_headers(sid, session.headers)
# Fire webhook (sync-safe)
if webhook_manager:
webhook_manager.fire_and_forget("session.created", {
@@ -288,27 +384,38 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
db.close()
# Switch model/endpoint mid-session
if model is not None and endpoint_url is not None:
user = get_current_user(request)
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
endpoint_api_key = ""
endpoint_base_url = ""
if endpoint_id:
from core.database import ModelEndpoint
from src.auth_helpers import owner_filter
from src.endpoint_resolver import build_chat_url, normalize_base
_db = SessionLocal()
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
q = _db.query(ModelEndpoint).filter(
ModelEndpoint.id == endpoint_id,
ModelEndpoint.is_enabled == True,
)
if user:
q = owner_filter(q, ModelEndpoint, user)
ep = q.first()
if not ep:
raise HTTPException(400, "Model endpoint no longer exists")
endpoint_base_url = ep.base_url or ""
endpoint_api_key = ep.api_key or ""
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
finally:
_db.close()
session.model = model
session.endpoint_url = endpoint_url
# Update auth headers from the endpoint's stored API key
if endpoint_id:
_db = SessionLocal()
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
if ep and ep.api_key:
from src.endpoint_resolver import build_headers
session.headers = build_headers(ep.api_key, ep.base_url)
finally:
_db.close()
if endpoint_api_key:
from src.endpoint_resolver import build_headers
session.headers = build_headers(endpoint_api_key, endpoint_base_url)
else:
session.headers = {}
# Persist to DB
db = SessionLocal()
try:
@@ -316,6 +423,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
if db_session:
db_session.model = model
db_session.endpoint_url = endpoint_url
db_session.headers = session.headers or {}
db_session.updated_at = datetime.utcnow()
db.commit()
finally:
@@ -356,7 +464,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
ids = []
for sid in ids:
try:
_verify_session_owner(request, sid)
_verify_session_owner(request, sid, session_manager)
session_manager.delete_session(sid)
db = SessionLocal()
try:
@@ -374,7 +482,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
@router.delete("/session/{sid}")
def delete_session(request: Request, sid: str):
"""Permanently delete a session and all its messages."""
_verify_session_owner(request, sid)
_verify_session_owner(request, sid, session_manager)
try:
# Block deletion of starred/favorited sessions
db = SessionLocal()
@@ -499,7 +607,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
@router.get("/sessions/archived")
def list_archived_sessions(request: Request, search: str = "", offset: int = 0, limit: int = 20, sort: str = "recent", model: str = ""):
"""List archived sessions for the archive browser."""
user = get_current_user(request)
user = effective_user(request)
db = SessionLocal()
try:
q = db.query(DbSession).filter(DbSession.archived == True)
@@ -510,7 +618,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
safe_search = search.replace('%', r'\%').replace('_', r'\_')
q = q.filter(DbSession.name.ilike(f"%{safe_search}%", escape='\\'))
if model:
q = q.filter(DbSession.model.ilike(f"%{model}"))
# Contains match (mirrors the name filter above). The old
# f"%{model}" was a SUFFIX-only match, so filtering by "gpt-4"
# dropped "gpt-4o" and over-matched on shared suffixes; it also
# left LIKE wildcards in the user value unescaped.
safe_model = model.replace('%', r'\%').replace('_', r'\_')
q = q.filter(DbSession.model.ilike(f"%{safe_model}%", escape='\\'))
total = q.count()
sort_map = {
"recent": DbSession.updated_at.desc(),
@@ -558,6 +671,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
safe_name = re.sub(r'[^\w\-_]', '_', session.name)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = _sanitize_export_filename(filename)
if fmt == "json":
import json as _json
@@ -635,7 +749,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
@router.post("/sessions/save")
def sessions_save_now(request: Request):
user = get_current_user(request)
user = effective_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
session_manager.save_sessions()
@@ -651,7 +765,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
if not OPENAI_API_KEY:
raise HTTPException(400, "Server missing OPENAI_API_KEY")
sid = str(uuid.uuid4())
user = get_current_user(request)
user = effective_user(request)
session = session_manager.create_session(
session_id=sid,
name="",
@@ -728,7 +842,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
if not url or not model:
url, model, headers = session.endpoint_url, session.model, session.headers
if not url or not model:
@@ -791,7 +905,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
users can clean junk without spending tokens.
"""
from src.llm_core import llm_call
user = get_current_user(request)
user = effective_user(request)
user_sessions = session_manager.get_sessions_for_user(user)
# Delete empty and throwaway sessions before sorting
@@ -928,9 +1042,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
# Pick an endpoint — prefer admin-configured task endpoint
from src.task_endpoint import resolve_task_endpoint
url, model, headers = resolve_task_endpoint()
url, model, headers = resolve_task_endpoint(owner=user)
if not url:
url, model, headers = _pick_endpoint_for_sort()
url, model, headers = _pick_endpoint_for_sort(owner=user)
if not url:
raise HTTPException(503, "No available model endpoint for auto-sort")
+173 -13
View File
@@ -118,6 +118,7 @@ def _running_in_container(dockerenv_path="/.dockerenv", cgroup_path="/proc/1/cgr
DockerRowStatus = namedtuple("DockerRowStatus", ["applicable", "install_hint"])
PackageUpdateStatus = namedtuple("PackageUpdateStatus", ["available", "note"])
def _docker_row_status(*, on_remote, in_container, installed, default_hint):
@@ -127,6 +128,24 @@ def _docker_row_status(*, on_remote, in_container, installed, default_hint):
return DockerRowStatus(applicable=True, install_hint=default_hint)
def _pip_dist_name(pkg: dict) -> str:
"""Distribution name for importlib.metadata lookups.
The Cookbook package catalog carries both the import name (``name``, e.g.
``llama_cpp``) and the pip spec (``pip``, e.g. ``llama-cpp-python[server]``).
The distribution is NOT always the import name with underscores swapped for
dashes ``llama_cpp`` ships in the ``llama-cpp-python`` distribution so
derive it from the pip spec (stripping any ``[extras]`` and version markers)
and fall back to the munged import name only when no pip spec is declared.
"""
pip = (pkg.get("pip") or "").strip()
if pip:
base = re.split(r"[\[<>=!~;\s]", pip, maxsplit=1)[0].strip()
if base:
return base
return (pkg.get("name") or "").replace("_", "-")
def _package_installed_from_probe(name: str, probe: dict) -> bool:
"""Return whether an optional dependency is usable by Cookbook.
@@ -162,7 +181,10 @@ def _package_status_note(name: str, probe: dict) -> str:
locations = module.get("locations") or []
if name == "vllm":
if binaries.get("vllm"):
return f"vLLM CLI: {binaries['vllm']}"
parts = [f"vLLM CLI: {binaries['vllm']}"]
if dists.get("vllm"):
parts.append(f"python package: vllm {dists['vllm']}")
return "; ".join(parts)
if module.get("found") and not dists.get("vllm"):
loc = locations[0] if locations else module.get("origin") or "unknown path"
return f"Python sees a vllm namespace at {loc}, but no vLLM CLI is on PATH."
@@ -183,13 +205,70 @@ def _package_status_note(name: str, probe: dict) -> str:
return ""
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
    """Return whether the Dependencies UI should offer a generic pip update.

    "Installed" means Cookbook can use the dependency. It does not always mean
    the dependency is a Python package that Cookbook should update with pip:
    native llama-server can come from a package manager/source build, and a CLI
    may be on PATH without matching Python package metadata.

    Args:
        pkg: Catalog entry; its ``kind``, ``pip`` and ``name`` keys are read.
        probe: Optional probe result; its ``binaries`` and ``dists`` mappings
            are used when present and are treated as empty otherwise.

    Returns:
        ``PackageUpdateStatus(available, note)``.
    """
    def _section(key: str) -> dict:
        # Anything that is not a dict (missing probe, wrong shape) reads as empty.
        value = probe.get(key) if isinstance(probe, dict) else None
        return value if isinstance(value, dict) else {}

    if pkg.get("kind") == "system" or not pkg.get("pip"):
        return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")

    name = pkg.get("name")
    binaries = _section("binaries")
    dists = _section("dists")

    if name == "llama_cpp" and binaries.get("llama-server"):
        note = "Using native llama-server on PATH; update it with its package manager or source checkout."
        return PackageUpdateStatus(False, note)

    if name == "vllm" and binaries.get("vllm") and not dists.get("vllm"):
        note = "Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus."
        return PackageUpdateStatus(False, note)

    return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
def _prepend_user_install_bins_to_path() -> None:
"""Make pip --user console scripts visible to dependency probes.
Docker Cookbook installs vLLM with `python -m pip install --user`, which
drops the `vllm` CLI in /app/.local/bin. The running app process does not
inherit that PATH update, so `shutil.which("vllm")` can report missing even
after a successful install.
"""
try:
import site
candidates = [os.path.join(site.USER_BASE, "bin")]
except Exception:
candidates = []
candidates.append(os.path.expanduser("~/.local/bin"))
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
changed = False
for path in reversed([p for p in candidates if p]):
if path not in parts:
parts.insert(0, path)
changed = True
if changed:
os.environ["PATH"] = os.pathsep.join(parts)
def _package_probe_script(names: list[str]) -> str:
names_lit = ",".join(repr(n) for n in names)
return f"""
import importlib.util
import importlib.metadata as md
import json
import os
import shutil
import site
names=[{names_lit}]
dist_names={{
@@ -204,6 +283,24 @@ bin_names={{
'llama_cpp':['llama-server'],
}}
def add_user_install_bins_to_path():
candidates = []
try:
candidates.append(os.path.join(site.USER_BASE, 'bin'))
except Exception:
pass
candidates.append(os.path.expanduser('~/.local/bin'))
parts = os.environ.get('PATH', '').split(os.pathsep) if os.environ.get('PATH') else []
changed = False
for path in reversed([p for p in candidates if p]):
if path not in parts:
parts.insert(0, path)
changed = True
if changed:
os.environ['PATH'] = os.pathsep.join(parts)
add_user_install_bins_to_path()
def mod_status(n):
spec = importlib.util.find_spec(n)
loader = getattr(spec, 'loader', None) if spec else None
@@ -317,7 +414,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
yield f"data: {json.dumps({'exit_code': -1, 'error': PTY_UNSUPPORTED_ERROR})}\n\n"
return
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
master_fd, slave_fd = pty.openpty()
# Set master to non-blocking
@@ -469,7 +566,8 @@ async def _generate_tmux(cmd: str, request: Request):
f"EC=${{PIPESTATUS[0]}}\n"
f"echo ':::EXIT_CODE:::'$EC >> '{log_path}'\n"
f"rm -f '{script_path}'\n"
f"exit $EC\n"
f"exit $EC\n",
encoding="utf-8",
)
script_path.chmod(0o755)
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
@@ -504,7 +602,7 @@ async def _generate_tmux(cmd: str, request: Request):
# Read new lines from log
try:
if log_path.exists():
lines = log_path.read_text(errors="replace").splitlines()
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
new_lines = lines[lines_sent:]
for line in new_lines:
if line.startswith(":::EXIT_CODE:::"):
@@ -532,7 +630,7 @@ async def _generate_tmux(cmd: str, request: Request):
# Session ended — do one final read
await asyncio.sleep(0.5)
if log_path.exists():
lines = log_path.read_text(errors="replace").splitlines()
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
for line in lines[lines_sent:]:
if line.startswith(":::EXIT_CODE:::"):
try:
@@ -735,10 +833,11 @@ def setup_shell_routes() -> APIRouter:
]
finished = 0
deadline = (asyncio.get_event_loop().time() + timeout) if timeout else None
loop = asyncio.get_running_loop()
deadline = (loop.time() + timeout) if timeout else None
while finished < 2:
if deadline:
remaining = deadline - asyncio.get_event_loop().time()
remaining = deadline - loop.time()
if remaining <= 0:
raise asyncio.TimeoutError()
wait = min(remaining, 2.0)
@@ -791,7 +890,15 @@ def setup_shell_routes() -> APIRouter:
"""
_require_admin(request)
_reject_cross_site(request)
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
_prepend_user_install_bins_to_path()
importlib.invalidate_caches()
try:
user_site = site.getusersitepackages()
if user_site and os.path.isdir(user_site) and user_site not in sys.path:
sys.path.append(user_site)
except Exception:
pass
if ssh_port and str(ssh_port).strip() not in ("", "22"):
_port = str(ssh_port).strip()
if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535):
@@ -870,6 +977,7 @@ def setup_shell_routes() -> APIRouter:
for pkg in packages:
on_remote = bool(host and pkg.get("target") == "remote")
probe = None
if on_remote:
pkg["installed"] = bool(remote_status.get(pkg["name"], False))
probe = remote_details.get(pkg["name"])
@@ -883,19 +991,36 @@ def setup_shell_routes() -> APIRouter:
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
pkg["installed"] = True
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
elif pkg["name"] == "vllm":
_vllm_cli = shutil.which("vllm")
pkg["installed"] = _vllm_cli is not None
if pkg["installed"]:
try:
_vllm_version = importlib_metadata.version(_pip_dist_name(pkg))
except importlib_metadata.PackageNotFoundError:
_vllm_version = None
probe = {
"binaries": {"vllm": _vllm_cli},
"dists": {"vllm": _vllm_version} if _vllm_version else {},
}
pkg["status_note"] = _package_status_note("vllm", probe)
else:
try:
importlib.import_module(pkg["name"])
if pkg["name"] == "vllm":
pkg["installed"] = shutil.which("vllm") is not None
else:
importlib_metadata.version(pkg["name"].replace("_", "-"))
pkg["installed"] = True
importlib_metadata.version(_pip_dist_name(pkg))
pkg["installed"] = True
except ImportError:
pkg["installed"] = False
except importlib_metadata.PackageNotFoundError:
pkg["installed"] = False
if pkg.get("installed"):
update_status = _package_pip_update_status(pkg, probe)
pkg["pip_update_available"] = update_status.available
if update_status.note:
pkg["update_note"] = update_status.note
if pkg["name"] == "docker":
status = _docker_row_status(
on_remote=on_remote,
@@ -933,4 +1058,39 @@ def setup_shell_routes() -> APIRouter:
return {"ok": True, "output": stdout.decode()[-200:]}
return {"ok": False, "error": stderr.decode()[-300:]}
@router.post("/api/cookbook/rebuild-engine")
async def rebuild_engine(request: Request):
    """Clear the cached llama.cpp build so the next serve recompiles.

    Admin only. This removes the Cookbook-managed ``~/bin/llama-server``
    symlink and ``~/llama.cpp/build`` directory, locally or on the selected
    remote server. It installs and downloads nothing; the next llama.cpp
    serve rebuilds from source and picks up CUDA/HIP if a toolchain is now
    present. This is the missing "force a fresh GPU build" lever for hosts
    stuck on a CPU-only llama-server.

    Request JSON fields (all optional): ``engine`` (only ``"llamacpp"`` is
    accepted), ``remote_host`` and ``ssh_port``.

    Returns:
        ``{"ok": True, "output": ...}`` with the last 400 characters of
        stdout on success, otherwise ``{"ok": False, "error": ...}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or when
            ``_ssh_base_argv`` rejects the host/port with ``ValueError``.
    """
    _require_admin(request)
    from routes.cookbook_helpers import _llama_cpp_rebuild_cmd

    body = await request.json()
    if not isinstance(body, dict):
        # A JSON array/scalar has no .get(); answer 400 instead of a 500.
        raise HTTPException(400, "Request body must be a JSON object")

    engine = str(body.get("engine") or "llamacpp").strip()
    if engine != "llamacpp":
        return {"ok": False, "error": f"Unsupported engine: {engine}"}

    host = str(body.get("remote_host") or "").strip()
    ssh_port = body.get("ssh_port")
    cmd = _llama_cpp_rebuild_cmd()
    try:
        argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
    except ValueError as e:
        raise HTTPException(400, str(e)) from e

    try:
        proc = await asyncio.create_subprocess_exec(
            *argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
    except OSError as e:
        # e.g. the ssh/bash binary is missing; report it in the endpoint's
        # usual {"ok": False} shape rather than crashing the request.
        return {"ok": False, "error": f"Could not start rebuild-engine command: {e}"}

    try:
        out, err = await asyncio.wait_for(proc.communicate(), timeout=30)
    except asyncio.TimeoutError:
        # wait_for only cancels communicate(); the child would keep running
        # (and stay unreaped) unless it is killed and waited on explicitly.
        try:
            proc.kill()
        except ProcessLookupError:
            pass  # already exited between the timeout and the kill
        await proc.wait()
        return {"ok": False, "error": "Rebuild-engine command timed out."}

    if proc.returncode == 0:
        return {"ok": True, "output": out.decode("utf-8", errors="replace")[-400:]}
    return {"ok": False, "error": err.decode("utf-8", errors="replace")[-400:]}
return router
+41 -29
View File
@@ -79,6 +79,8 @@ def _skill_test_task(skill: dict) -> str:
an email); if we just hand over the 'when to use' text the agent has nothing
to work on and stalls asking for input. So we tell it to create its own
realistic fixture first, then apply the skill end-to-end."""
if not isinstance(skill, dict):
skill = {}
ctx = (skill.get("when_to_use") or skill.get("description") or skill.get("name") or "").strip()
return (
"Test this skill end-to-end. FIRST, set up a small realistic scenario it "
@@ -310,6 +312,8 @@ def _should_check_retrieval_precision(skill: dict) -> bool:
"installation", "install", "system", "ssh", "document", "documents",
"search", "email", "calendar", "gpu", "server", "python",
}
if not isinstance(skill, dict):
return False
tags = {str(t or "").strip().lower() for t in (skill.get("tags") or [])}
if tags & broad:
return True
@@ -463,13 +467,13 @@ async def _run_skill_test_job(key, name, md, task, url, model, headers, owner, s
if skills_manager is not None:
v = (job["verdict"] or {}).get("verdict") or "unknown"
try:
skills_manager.set_audit(name, v, by_teacher=False, worker_model=model)
skills_manager.set_audit(name, v, by_teacher=False, worker_model=model, owner=owner)
except Exception:
pass
conf = {"pass": 0.95, "needs_work": 0.6, "fail": 0.4}.get(v)
if conf is not None:
try:
skills_manager.update_skill(name, {"confidence": conf})
skills_manager.update_skill(name, {"confidence": conf}, owner=owner)
except Exception:
pass
job["status"] = "done"
@@ -563,6 +567,7 @@ def _skill_duplicate_blocker(skills_manager, name: str, owner) -> Optional[str]:
False,
[keeper_name],
f"Lower-priority duplicate of {keeper_name}",
owner=owner,
)
except Exception:
pass
@@ -629,7 +634,7 @@ def _audit_finalize_status(skills_manager, name: str, owner, verdict: str,
if generic_reason:
necessary = False
try:
skills_manager.set_necessity(name, False, [], generic_reason)
skills_manager.set_necessity(name, False, [], generic_reason, owner=owner)
except Exception:
pass
duplicate_of = _skill_duplicate_blocker(skills_manager, name, owner) if verdict == "pass" else None
@@ -638,7 +643,7 @@ def _audit_finalize_status(skills_manager, name: str, owner, verdict: str,
c = float(confidence or 0.0)
status = "published" if (auto_publish and necessary and verdict == "pass" and c >= min_conf) else "draft"
try:
skills_manager.update_skill(name, {"status": status})
skills_manager.update_skill(name, {"status": status}, owner=owner)
except Exception:
pass
return status
@@ -662,7 +667,7 @@ def _apply_skill_md(skills_manager, name: str, md: str, owner) -> bool:
"teacher_model": sk.teacher_model, "owner": sk.owner or owner,
"when_to_use": sk.when_to_use, "procedure": sk.procedure,
"pitfalls": sk.pitfalls, "verification": sk.verification, "body_extra": sk.body_extra,
}))
}, owner=owner))
except Exception as e:
logger.warning(f"Audit: could not save edited skill {name}: {e}")
return False
@@ -762,11 +767,11 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
# earns a bit less; a skill that still fails is marked low.
def _set_conf(c):
try:
skills_manager.update_skill(name, {"confidence": c})
skills_manager.update_skill(name, {"confidence": c}, owner=owner)
except Exception:
pass
md = skills_manager.read_skill_md(name)
md = skills_manager.read_skill_md(name, owner=owner)
if not md:
log(f"{name}: no source — skipped")
return {"skill": name, "result": "skipped"}
@@ -788,7 +793,8 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
nec = await _eval_skill_necessity(md, others, url, model, headers)
if nec is not None:
skills_manager.set_necessity(name, nec.get("necessary", True),
nec.get("redundant_with"), nec.get("reason"))
nec.get("redundant_with"), nec.get("reason"),
owner=owner)
if not nec.get("necessary", True):
log(f"{name}: possibly unnecessary — {nec.get('reason', '')[:80]}")
except Exception as e:
@@ -799,12 +805,12 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
if generic_reason or duplicate_of or (isinstance(nec, dict) and nec.get("necessary") is False):
reason = generic_reason or (f"Lower-priority duplicate of {duplicate_of}" if duplicate_of else str((nec or {}).get("reason") or "Unnecessary skill"))
try:
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35})
skills_manager.set_audit(name, "skipped", by_teacher=False, worker_model=model)
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35}, owner=owner)
skills_manager.set_audit(name, "skipped", by_teacher=False, worker_model=model, owner=owner)
if duplicate_of:
skills_manager.set_necessity(name, False, [duplicate_of], reason)
skills_manager.set_necessity(name, False, [duplicate_of], reason, owner=owner)
else:
skills_manager.set_necessity(name, False, [], reason)
skills_manager.set_necessity(name, False, [], reason, owner=owner)
except Exception:
pass
log(f"{name}: draft — skipped functional test ({reason[:100]})")
@@ -848,13 +854,13 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
if fixed and fixed.strip() != md.strip():
_apply_skill_md(skills_manager, name, fixed, owner)
_set_conf(0.95)
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model)
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model, owner=owner)
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.95, (refreshed or {}).get("necessity"), verdict)
log(f"{name}: {status} — confidence 95%")
return {"skill": name, "result": "pass", "verdict": verdict, "confidence": 0.95, "status": status}
if v in ("unknown", "inconclusive"):
skills_manager.set_audit(name, "inconclusive", by_teacher=False, worker_model=model)
skills_manager.set_audit(name, "inconclusive", by_teacher=False, worker_model=model, owner=owner)
status = _audit_finalize_status(skills_manager, name, owner, "inconclusive", skill.get("confidence") or 0.0, skill.get("necessity"))
log(f"{name}: {status} — inconclusive")
return {"skill": name, "result": "inconclusive", "verdict": verdict, "status": status}
@@ -869,7 +875,7 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
log(f"{name}: retry (self) = {v}")
if v == "pass":
_set_conf(0.85)
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model)
skills_manager.set_audit(name, "pass", by_teacher=False, worker_model=model, owner=owner)
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.85, (refreshed or {}).get("necessity"), verdict)
log(f"{name}: {status} — confidence 85% after self-edit")
@@ -893,7 +899,9 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
log(f"{name}: retry on student after teacher rewrite = {v}")
if v == "pass":
_set_conf(0.8)
skills_manager.set_audit(name, "pass", by_teacher=True, worker_model=model, teacher_model=t_model)
skills_manager.set_audit(
name, "pass", by_teacher=True, worker_model=model, teacher_model=t_model, owner=owner
)
refreshed = next((s for s in skills_manager.load(owner=owner) if s.get("name") == name), None)
status = _audit_finalize_status(skills_manager, name, owner, "pass", 0.8, (refreshed or {}).get("necessity"), verdict)
log(f"{name}: {status} — confidence 80% after teacher rewrite")
@@ -901,13 +909,14 @@ async def _audit_one_skill(skills_manager, skill, url, model, headers,
# Still failing → demote to draft + low confidence + flag (do NOT delete).
try:
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35})
skills_manager.update_skill(name, {"status": "draft", "confidence": 0.35}, owner=owner)
except Exception:
pass
skills_manager.set_audit(
name, v or "fail", by_teacher=teacher_ran,
worker_model=model,
teacher_model=(teacher[1] if teacher_ran and teacher else ""),
owner=owner,
)
log(f"{name}: flagged — confidence lowered, kept as draft for manual review")
return {"skill": name, "result": "flagged", "verdict": verdict, "confidence": 0.35}
@@ -976,7 +985,7 @@ async def _run_audit_all_job(key, skills_manager, names, url, model, headers, te
job.pop("task", None)
def _resolve_audit_models():
def _resolve_audit_models(owner=None):
"""Resolve (url, model, headers, teacher) for an audit run from Settings.
Worker = Utility model (falling back to Default, normalized to a served
@@ -985,7 +994,7 @@ def _resolve_audit_models():
ValueError if no worker model.
"""
from src.endpoint_resolver import resolve_endpoint
url, model, headers = resolve_endpoint("utility")
url, model, headers = resolve_endpoint("utility", owner=owner)
if not url or not model:
raise ValueError("No model configured — set a Default or Utility model in Settings.")
try:
@@ -1029,7 +1038,7 @@ async def run_scheduled_skill_audit(skills_manager: SkillsManager,
return {"status": "running", "skipped": True}
try:
url, model, headers, teacher = _resolve_audit_models()
url, model, headers, teacher = _resolve_audit_models(owner=owner)
except ValueError as e:
logger.info(f"Scheduled skill audit skipped — {e}")
return {"status": "skipped", "reason": str(e)}
@@ -1246,7 +1255,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
if not match:
raise HTTPException(404, "Skill not found")
_verify_owner(match, user)
md = skills_manager.read_skill_md(match.get("name"))
md = skills_manager.read_skill_md(match.get("name"), owner=user)
if md is None:
raise HTTPException(404, "Skill source unavailable (legacy entry?)")
return {"name": match.get("name"), "markdown": md}
@@ -1273,14 +1282,14 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
raise HTTPException(404, "Skill not found")
_verify_owner(match, user)
name = match.get("name")
md = skills_manager.read_skill_md(name) or ""
md = skills_manager.read_skill_md(name, owner=user) or ""
if not task:
task = _skill_test_task(match)
# Prefer the configured DEFAULT (→ Utility) model — not the current chat
# session's model. Fall back to the caller's session model only if unset.
url, model, headers = resolve_endpoint("default")
url, model, headers = resolve_endpoint("default", owner=user)
if not url or not model:
url = url or ((body.get("endpoint_url") or "").strip() or None)
model = model or ((body.get("model") or "").strip() or None)
@@ -1360,7 +1369,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
# Worker model (Default, normalized) + optional teacher — shared resolver.
try:
url, model, headers, teacher = _resolve_audit_models()
url, model, headers, teacher = _resolve_audit_models(owner=user)
except ValueError as e:
raise HTTPException(400, str(e))
@@ -1437,7 +1446,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
@router.post("/{skill_id}/markdown")
async def save_skill_markdown(request: Request, skill_id: str):
"""Replace SKILL.md with new raw content. Parses + validates first."""
from services.memory.skill_format import Skill, slugify
from services.memory.skill_format import Skill
user = _owner(request)
body = await request.json()
new_content = body.get("markdown")
@@ -1452,7 +1461,10 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
sk = Skill.from_markdown(new_content)
except Exception as e:
raise HTTPException(400, f"Could not parse SKILL.md: {e}")
sk.name = slugify(sk.name or match.get("name"))
# Never rename on save: a changed `name` in the markdown would move
# the skill dir (update_skill) and orphan the original id, so a later
# delete 404s (#1333). Pin to the stored name, like _apply_skill_md.
sk.name = match.get("name")
if not sk.owner:
sk.owner = match.get("owner") or user
ok = skills_manager.update_skill(match.get("name"), {
@@ -1474,7 +1486,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
"pitfalls": sk.pitfalls,
"verification": sk.verification,
"body_extra": sk.body_extra,
})
}, owner=user)
if not ok:
raise HTTPException(500, "Update failed")
# Manual markdown edits can create or substantially rewrite a draft
@@ -1496,7 +1508,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
updates = body.dict(exclude_none=True)
if not updates:
return {"ok": True}
ok = skills_manager.update_skill(match.get("name"), updates)
ok = skills_manager.update_skill(match.get("name"), updates, owner=user)
if not ok:
raise HTTPException(404, "Skill not found")
if not match.get("audit_verdict"):
@@ -1511,7 +1523,7 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
if not match:
raise HTTPException(404, "Skill not found")
_verify_owner(match, user)
ok = skills_manager.delete_skill(match.get("name"))
ok = skills_manager.delete_skill(match.get("name"), owner=user)
if not ok:
raise HTTPException(404, "Skill not found")
return {"ok": True}
+14 -10
View File
@@ -8,6 +8,7 @@ from typing import List
import logging
from core.middleware import require_admin
from src.auth_helpers import get_current_user
from src.upload_handler import count_recent_uploads
logger = logging.getLogger(__name__)
@@ -24,15 +25,18 @@ def setup_upload_routes(upload_handler):
client_ip = request.client.host if request.client else "unknown"
out = []
# Limit concurrent uploads per IP
ip_upload_count = sum(
1 for f in files
if client_ip in upload_handler.upload_rate_log and
any(now > time.time() - 10 for now in upload_handler.upload_rate_log[client_ip][-len(files):])
# Limit concurrent uploads per IP. Count genuine recent upload events —
# NOT the number of files in this batch. The previous check summed over
# `files`, so a single multi-file request counted itself as N concurrent
# uploads and tripped the limit (issue #1346: "attach more than one file
# → the model doesn't even see them"). save_upload still enforces the
# per-minute sliding-window rate limit per file.
recent_uploads = count_recent_uploads(
upload_handler.upload_rate_log.get(client_ip, []), time.time()
)
if ip_upload_count >= upload_handler.max_concurrent_uploads:
if recent_uploads >= upload_handler.max_concurrent_uploads:
raise HTTPException(
status_code=429,
detail=f"Maximum concurrent uploads ({upload_handler.max_concurrent_uploads}) exceeded"
@@ -107,7 +111,7 @@ def setup_upload_routes(upload_handler):
if os.path.exists(uploads_db):
with open(uploads_db, encoding="utf-8") as f:
db = json.load(f)
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
info = next((fi for fi in db.values() if fi.get("id") == file_id), None)
if info:
original_name = info.get("name", file_id)
auth_mgr = getattr(request.app.state, "auth_manager", None)
@@ -155,7 +159,7 @@ def setup_upload_routes(upload_handler):
if os.path.exists(uploads_db):
with open(uploads_db, encoding="utf-8") as f:
db = json.load(f)
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
info = next((fi for fi in db.values() if fi.get("id") == file_id), None)
return info
def _vision_cache_path(file_id: str) -> str:
+15 -3
View File
@@ -61,7 +61,8 @@ def _find_bw() -> str:
def _load_config() -> dict:
    """Load the vault configuration from ``VAULT_FILE``.

    Returns:
        The parsed JSON object, or an empty dict when the file does not
        exist, cannot be read or parsed, or holds a top-level JSON value
        that is not an object (list, string, number, ...).
    """
    if VAULT_FILE.exists():
        try:
            data = json.loads(VAULT_FILE.read_text(encoding="utf-8"))
        except Exception:
            # Deliberate best-effort: an unreadable or corrupt file is
            # treated the same as "no config yet".
            return {}
        # Validate the shape before returning so callers can always rely on
        # dict methods; the unvalidated early return used to bypass this.
        if isinstance(data, dict):
            return data
    return {}
@@ -75,11 +76,18 @@ def _save_config(cfg: dict):
safe_chmod(str(VAULT_FILE), 0o600)
async def _run_bw(args: list, session: str = None, input_text: str = None) -> tuple:
async def _run_bw(args: list, session: str = None, input_text: str = None,
bw_password: str = None) -> tuple:
env = {}
env.update(os.environ)
if session:
env["BW_SESSION"] = session
# Secrets must never be passed as argv — process arguments are world-readable
# via `ps` / `/proc/<pid>/cmdline` to any local user. Keep --passwordenv
# support for bw commands that need it; unlock/login callers should prefer
# stdin so the master password is not left in the child environment either.
if bw_password is not None:
env["BW_PASSWORD"] = bw_password
bw_path = _find_bw()
try:
proc = await asyncio.create_subprocess_exec(
@@ -175,8 +183,12 @@ def setup_vault_routes():
async def unlock(req: VaultUnlockRequest, request: Request):
"""Unlock the vault and save the session key."""
require_admin(request)
# Pass the master password on stdin, not argv. argv is visible through
# `ps` / /proc/<pid>/cmdline; stdin also avoids leaving the secret in
# the child process environment.
stdout, stderr, rc = await _run_bw(
["unlock", req.master_password, "--raw"],
["unlock", "--raw"],
input_text=req.master_password + "\n",
)
if rc != 0:
return {"ok": False, "error": f"Unlock failed: {stderr[:300]}"}
+46 -3
View File
@@ -26,6 +26,44 @@ MAX_MESSAGE_LEN = 32_000
from core.middleware import require_admin as _require_admin
def _first_enabled_endpoint(db, owner):
    """Return the first enabled ``ModelEndpoint`` visible to ``owner``.

    "Visible" means the owner's own rows plus legacy null-owner ("shared")
    rows, as applied by ``owner_filter``. The scoping is intentional:
    ModelEndpoint is per-user, and the sync-chat fallback uses the row's
    decrypted ``api_key``. An unscoped ``.first()`` would let a chat-scoped
    token (e.g. a paired mobile device) fall back onto ANOTHER user's private
    endpoint, silently spending that owner's API key / quota and reaching
    whatever internal base_url they configured. This mirrors the
    ``owner_filter`` scoping in routes/model_routes.py and
    companion/routes.py. A null/empty owner is a no-op (single-user / legacy
    mode), preserving the original behaviour.

    Args:
        db: Open database session used to build the query.
        owner: Username the lookup is scoped to; may be null/empty.

    Returns:
        The first matching row, or whatever ``.first()`` yields when no row
        matches.
    """
    from core.database import ModelEndpoint
    from src.auth_helpers import owner_filter

    enabled_rows = db.query(ModelEndpoint).filter(
        ModelEndpoint.is_enabled == True  # noqa: E712
    )
    return owner_filter(enabled_rows, ModelEndpoint, owner).first()
def _caller_owns_session(sess_owner, caller) -> bool:
"""Strict session-ownership gate for the token-authenticated sync-chat
endpoint (`POST /api/v1/chat`).
Mirrors ``_verify_session_owner`` in session_routes.py and the null-owner
gates in notes/calendar/gallery: a caller may resume a session ONLY when
its owner matches them exactly. A null/empty session owner (legacy or
migrated rows) is deliberately NOT resumable by an arbitrary token the
old ``sess_owner and sess_owner != caller`` form skipped the check whenever
``sess_owner`` was falsy, so any chat-scoped token (e.g. a paired mobile
device) could resume such a session, inject a message, and read back its
history and reuse the owner's endpoint credentials. Fail closed: an
unresolvable caller also returns False.
"""
if not caller:
return False
return sess_owner == caller
def setup_webhook_routes(
webhook_manager: WebhookManager,
auth_manager,
@@ -159,6 +197,7 @@ def setup_webhook_routes(
"openrouter": "https://openrouter.ai/api/v1",
"ollama": "https://ollama.com/api",
"fireworks": "https://api.fireworks.ai/inference/v1",
"venice": "https://api.venice.ai/api/v1",
}
# Model prefix → provider mapping for auto-detection
@@ -203,7 +242,6 @@ def setup_webhook_routes(
from core.models import ChatMessage
from src.llm_core import llm_call_async
from core.database import ModelEndpoint
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
message = body.message.strip()
@@ -228,8 +266,11 @@ def setup_webhook_routes(
_tok_user = token_owner or getattr(request.state, "user", None) or _gcu(request)
except Exception:
_tok_user = None
# Strict ownership (see _caller_owns_session): fail closed so a
# null-owner / cross-owner session can't be resumed by an arbitrary
# chat-scoped token.
_sess_owner = getattr(sess, "owner", None)
if _tok_user and _sess_owner and _sess_owner != _tok_user:
if not _caller_owns_session(_sess_owner, _tok_user):
raise HTTPException(404, "Session not found")
# --- Case 2: Direct API key + model (no pre-configured endpoint needed) ---
@@ -265,7 +306,9 @@ def setup_webhook_routes(
if not sess:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
# Owner-scoped: only THIS token owner's endpoints + legacy
# shared rows, never another user's private endpoint/api_key.
ep = _first_enabled_endpoint(db, token_owner)
finally:
db.close()