mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-30 08:32:07 -04:00
Merge remote-tracking branch 'origin/dev' into fix/native-agent-loop-guard-signals
# Conflicts: # src/agent_loop.py
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
_REMOTE_HOST_RE = re.compile(
|
||||
r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$"
|
||||
)
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
|
||||
|
||||
def validate_remote_host(v: str | None) -> str | None:
    """Validate an SSH destination given as ``host`` or ``user@host``.

    Args:
        v: Raw value from the request; ``None`` or ``""`` means "not set".

    Returns:
        ``None`` when no host was supplied, otherwise ``v`` unchanged.

    Raises:
        HTTPException: 400 when ``v`` does not match ``_REMOTE_HOST_RE``.
    """
    if v is None or v == "":
        return None
    # fullmatch, not match: with match() the pattern's trailing ``$`` also
    # succeeds just before a final newline, so a value ending in a newline
    # would be accepted. validate_ssh_port already uses fullmatch.
    if not _REMOTE_HOST_RE.fullmatch(v):
        raise HTTPException(
            400,
            "Invalid remote_host — must be host or user@host, no SSH option syntax",
        )
    return v
|
||||
|
||||
|
||||
def validate_ssh_port(v: str | None) -> str | None:
    """Validate an SSH port and return it as a canonical decimal string.

    Args:
        v: Raw value from the request; ``None`` or ``""`` means "not set".

    Returns:
        ``None`` when no port was supplied, otherwise ``str(int(v))``.

    Raises:
        HTTPException: 400 when the value is not 1-5 digits or falls outside
            the range 1..65535.
    """
    if v is None or v == "":
        return None

    as_text = str(v)
    if _SSH_PORT_RE.fullmatch(as_text) is None:
        raise HTTPException(400, "Invalid ssh_port")

    number = int(v)
    # The regex admits up to 99999 (and 0), so the numeric range still has to
    # be checked separately.
    if not 1 <= number <= 65535:
        raise HTTPException(400, "Invalid ssh_port")

    return str(number)
|
||||
@@ -25,6 +25,8 @@ ALLOWED_SCOPES = {
|
||||
"calendar:write",
|
||||
"memory:read",
|
||||
"memory:write",
|
||||
"cookbook:read",
|
||||
"cookbook:launch",
|
||||
}
|
||||
TOKEN_PROFILES = {
|
||||
"chat": ["chat"],
|
||||
@@ -65,6 +67,7 @@ def _normalize_scopes(scopes: str | list[str] | None = None, profile: str | None
|
||||
ensure_before("calendar:write", "calendar:read")
|
||||
ensure_before("memory:write", "memory:read")
|
||||
ensure_before("email:draft", "email:read")
|
||||
ensure_before("cookbook:launch", "cookbook:read")
|
||||
|
||||
return normalized or [DEFAULT_SCOPES]
|
||||
|
||||
@@ -151,6 +154,7 @@ def setup_api_token_routes() -> APIRouter:
|
||||
@router.patch("/tokens/{token_id}")
|
||||
async def update_token(request: Request, token_id: str):
|
||||
require_admin(request)
|
||||
current_user = get_current_user(request)
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
@@ -159,6 +163,8 @@ def setup_api_token_routes() -> APIRouter:
|
||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||
if not token:
|
||||
raise HTTPException(404, "Token not found")
|
||||
if current_user and token.owner != current_user:
|
||||
raise HTTPException(403, "Not your token")
|
||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||
# Only touch scopes when the caller actually sent them. A partial
|
||||
@@ -186,10 +192,14 @@ def setup_api_token_routes() -> APIRouter:
|
||||
@router.delete("/tokens/{token_id}")
|
||||
def delete_token(request: Request, token_id: str):
|
||||
require_admin(request)
|
||||
current_user = get_current_user(request)
|
||||
with get_db_session() as db:
|
||||
deleted = db.query(ApiToken).filter(ApiToken.id == token_id).delete()
|
||||
if not deleted:
|
||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||
if not token:
|
||||
raise HTTPException(404, "Token not found")
|
||||
if current_user and token.owner != current_user:
|
||||
raise HTTPException(403, "Not your token")
|
||||
db.delete(token)
|
||||
_invalidate_cache(request)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
+162
-12
@@ -7,7 +7,13 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from core.atomic_io import atomic_write_json, atomic_write_text
|
||||
from core.auth import AuthManager
|
||||
from src.constants import DEEP_RESEARCH_DIR, MEMORY_FILE, SKILLS_DIR
|
||||
from src.rate_limiter import RateLimiter
|
||||
from src.settings_scrub import scrub_settings
|
||||
from src.settings import (
|
||||
@@ -291,9 +297,30 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
if new_username in auth_manager.users:
|
||||
raise HTTPException(409, "Username already taken")
|
||||
|
||||
# Gate on auth first. Every mutation below is contingent on this
|
||||
# succeeding — doing it last meant a rejected rename (e.g. reserved
|
||||
# username) left file-backed owner fields already rewritten with no
|
||||
# way to roll them back.
|
||||
ok = auth_manager.rename_user(old_username, new_username, user)
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot rename user")
|
||||
|
||||
def _rollback_auth_rename() -> bool:
|
||||
# On self-rename the admin session has already moved to the new
|
||||
# username, so the rollback must authenticate as the new user.
|
||||
rollback_user = new_username if user == old_username else user
|
||||
try:
|
||||
return bool(auth_manager.rename_user(new_username, old_username, rollback_user))
|
||||
except Exception as rollback_err:
|
||||
logger.error(
|
||||
"Failed to roll back auth rename %s -> %s after owner migration failure: %s",
|
||||
new_username, old_username, rollback_err,
|
||||
)
|
||||
return False
|
||||
|
||||
# Usernames are ownership keys for user data. Rename the common
|
||||
# owner-scoped DB rows before changing auth so the account keeps
|
||||
# access to its sessions, docs, email accounts, tasks, etc.
|
||||
# owner-scoped DB rows so the account keeps access to its sessions,
|
||||
# docs, email accounts, tasks, etc.
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
from core.database import Base, SessionLocal
|
||||
@@ -316,6 +343,11 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error("Failed to rename owner references %s -> %s: %s", old_username, new_username, e)
|
||||
if not _rollback_auth_rename():
|
||||
logger.error(
|
||||
"Auth rename %s -> %s could not be rolled back after owner migration failure",
|
||||
old_username, new_username,
|
||||
)
|
||||
raise HTTPException(500, "Failed to rename user data")
|
||||
|
||||
# Per-user prefs are JSON-backed, not SQL-backed.
|
||||
@@ -335,9 +367,116 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
ok = auth_manager.rename_user(old_username, new_username, user)
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot rename user")
|
||||
# In-flight deep-research tasks live in the process-local
|
||||
# ResearchHandler registry. They are not covered by the persisted JSON
|
||||
# migration above, but the research routes filter and cancel by this
|
||||
# owner field while the job is running. Do this before sweeping
|
||||
# completed JSON files so a job that finishes during the rename saves
|
||||
# with the new owner or is caught by the disk sweep below.
|
||||
try:
|
||||
rh = getattr(request.app.state, "research_handler", None)
|
||||
rename_owner = getattr(rh, "rename_owner", None)
|
||||
if callable(rename_owner):
|
||||
rename_owner(old_username, new_username)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# deep_research: each completed report is a standalone JSON file with
|
||||
# an `owner` field. research_routes filters by d.get("owner") == user,
|
||||
# so a stale owner makes every report invisible to the renamed user.
|
||||
try:
|
||||
dr_dir = Path(DEEP_RESEARCH_DIR)
|
||||
if dr_dir.is_dir():
|
||||
for p in dr_dir.glob("*.json"):
|
||||
try:
|
||||
d = json.loads(p.read_text(encoding="utf-8"))
|
||||
if str(d.get("owner", "")).strip().lower() == old_username:
|
||||
d["owner"] = new_username
|
||||
atomic_write_json(str(p), d)
|
||||
except Exception as err:
|
||||
logger.warning("Failed to update research owner in %s: %s", p.name, err)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename research owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# memory.json: a flat JSON array where each entry carries an `owner`
|
||||
# field. memory_manager.load(owner=user) filters on it, so stale
|
||||
# entries disappear from the memory panel.
|
||||
try:
|
||||
if os.path.isfile(MEMORY_FILE):
|
||||
with open(MEMORY_FILE, encoding="utf-8") as fh:
|
||||
entries = json.loads(fh.read())
|
||||
if isinstance(entries, list):
|
||||
changed = False
|
||||
for entry in entries:
|
||||
if isinstance(entry, dict) and str(entry.get("owner", "")).strip().lower() == old_username:
|
||||
entry["owner"] = new_username
|
||||
changed = True
|
||||
if changed:
|
||||
atomic_write_json(MEMORY_FILE, entries)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename memory.json owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# uploads.json: upload rows use owner metadata for access checks and
|
||||
# owner-prefixed index keys for dedupe. Rename both so attachments keep
|
||||
# resolving after the account username changes.
|
||||
try:
|
||||
upload_handler = getattr(request.app.state, "upload_handler", None)
|
||||
rename_owner = getattr(upload_handler, "rename_owner", None)
|
||||
if callable(rename_owner):
|
||||
rename_owner(old_username, new_username)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename upload owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# skills: SKILL.md frontmatter carries owner: <username>; the usage
|
||||
# sidecar (_usage.json) keys entries as owner::skill-name. Both must
|
||||
# be updated or the renamed user's Skills panel goes empty.
|
||||
try:
|
||||
skills_root = Path(SKILLS_DIR)
|
||||
if skills_root.is_dir():
|
||||
_owner_re = re.compile(
|
||||
r'(?m)^(owner:\s*)' + re.escape(old_username) + r'\s*$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
for p in skills_root.rglob("SKILL.md"):
|
||||
try:
|
||||
text = p.read_text(encoding="utf-8")
|
||||
new_text = _owner_re.sub(r'\g<1>' + new_username, text)
|
||||
if new_text != text:
|
||||
atomic_write_text(str(p), new_text)
|
||||
except Exception as err:
|
||||
logger.warning("Failed to update skill owner in %s: %s", p, err)
|
||||
usage_path = skills_root / "_usage.json"
|
||||
if usage_path.is_file():
|
||||
try:
|
||||
usage = json.loads(usage_path.read_text(encoding="utf-8"))
|
||||
if isinstance(usage, dict):
|
||||
new_usage = {}
|
||||
changed = False
|
||||
for k, v in usage.items():
|
||||
owner_part, sep, skill_part = k.partition("::")
|
||||
if sep and owner_part.lower() == old_username:
|
||||
new_usage[new_username + "::" + skill_part] = v
|
||||
changed = True
|
||||
else:
|
||||
new_usage[k] = v
|
||||
if changed:
|
||||
atomic_write_json(str(usage_path), new_usage)
|
||||
except Exception as err:
|
||||
logger.warning("Failed to update skills usage keys %s -> %s: %s", old_username, new_username, err)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename skills owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# The in-memory session cache (session_manager.sessions) stores each
|
||||
# session's owner at load time. Without this patch the renamed user's
|
||||
# sessions are invisible on the next /api/sessions call because
|
||||
# get_sessions_for_user does an exact `s.owner == username` comparison
|
||||
# against stale in-memory values.
|
||||
sm = getattr(request.app.state, "session_manager", None)
|
||||
if sm is not None:
|
||||
for sess in list(getattr(sm, "sessions", {}).values()):
|
||||
if str(getattr(sess, "owner", None) or "").strip().lower() == old_username:
|
||||
sess.owner = new_username
|
||||
|
||||
# The owner-rename loop above updated ApiToken.owner in the DB, but the
|
||||
# bearer-token cache still maps each token to the OLD owner. Without
|
||||
# refreshing it, the renamed user's API tokens resolve to the old (now
|
||||
@@ -378,7 +517,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
ok = auth_manager.delete_user(body.username, user)
|
||||
|
||||
def _invalidate_api_token_cache():
|
||||
try:
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if invalidator:
|
||||
invalidator()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
ok = auth_manager.delete_user(body.username, user)
|
||||
except Exception:
|
||||
# delete_user can touch ApiToken rows before a later auth-store write
|
||||
# fails. Dirty the bearer cache anyway so a partial token purge does
|
||||
# not leave already-cached tokens authenticating until restart.
|
||||
_invalidate_api_token_cache()
|
||||
raise
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot delete user")
|
||||
# delete_user removes the user's ApiToken rows, but the bearer-auth
|
||||
@@ -386,12 +541,7 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
# rebuilds when flagged dirty. Without this, a deleted user's already
|
||||
# cached token keeps authenticating until some other token op or a
|
||||
# restart clears the cache. Mirror what the token routes do.
|
||||
try:
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if invalidator:
|
||||
invalidator()
|
||||
except Exception:
|
||||
pass
|
||||
_invalidate_api_token_cache()
|
||||
return {"ok": True}
|
||||
|
||||
# ---- Feature visibility (admin-managed) ----
|
||||
|
||||
@@ -101,11 +101,17 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
||||
# ── Skills ──
|
||||
if "skills" in body and isinstance(body["skills"], list):
|
||||
existing = skills_manager.load_all()
|
||||
existing_names = {s.get("name") for s in existing if s.get("name")}
|
||||
existing_ids = {s.get("id") for s in existing if s.get("id")}
|
||||
# Dedup against THIS user's own skills only. Using every tenant's
|
||||
# rows (load_all) meant a skill whose id/name/title matched any
|
||||
# other user's was silently skipped, so the importing user lost
|
||||
# their own data — same cross-tenant bug fixed for memories above.
|
||||
# The full store is still saved back below.
|
||||
own = [s for s in existing if s.get("owner") == user]
|
||||
existing_names = {s.get("name") for s in own if s.get("name")}
|
||||
existing_ids = {s.get("id") for s in own if s.get("id")}
|
||||
existing_titles = {
|
||||
(s.get("title") or s.get("description") or "").strip().lower()
|
||||
for s in existing
|
||||
for s in own
|
||||
}
|
||||
added = 0
|
||||
for skill in body["skills"]:
|
||||
|
||||
@@ -851,28 +851,27 @@ def setup_calendar_routes() -> APIRouter:
|
||||
from src.caldav_sync import sync_caldav
|
||||
return await sync_caldav(owner)
|
||||
|
||||
|
||||
@router.delete("/calendars/{cal_id}")
|
||||
async def delete_calendar(cal_id: str, request: Request):
|
||||
async def delete_calendar(request: Request, cal_id: str):
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cal = db.query(CalendarCal).filter(
|
||||
CalendarCal.id == cal_id,
|
||||
CalendarCal.owner == owner,
|
||||
).first()
|
||||
if not cal:
|
||||
raise HTTPException(404, "Calendar not found")
|
||||
cal = _get_or_404_calendar(db, cal_id, owner)
|
||||
db.query(CalendarEvent).filter(CalendarEvent.calendar_id == cal_id).delete()
|
||||
db.delete(cal)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error("Failed to delete calendar %s: %s", cal_id, e)
|
||||
raise HTTPException(500, "Failed to delete calendar")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/calendars")
|
||||
async def list_calendars(request: Request):
|
||||
owner = _require_user(request)
|
||||
@@ -1152,23 +1151,6 @@ def setup_calendar_routes() -> APIRouter:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/calendars/{cal_id}")
|
||||
async def delete_calendar(request: Request, cal_id: str):
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cal = _get_or_404_calendar(db, cal_id, owner)
|
||||
db.query(CalendarEvent).filter(CalendarEvent.calendar_id == cal_id).delete()
|
||||
db.delete(cal)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": str(e)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
|
||||
# file into memory is unavoidable with python-icalendar, so an unbounded
|
||||
|
||||
+91
-5
@@ -615,6 +615,26 @@ async def build_chat_context(
|
||||
# Build messages
|
||||
messages = preface + sess.get_context_messages()
|
||||
|
||||
# Current date/time — injected as a standalone *user*-role context message
|
||||
# placed immediately before the latest user turn, NOT folded into the
|
||||
# system prompt. Its text changes every minute, and local OpenAI-compatible
|
||||
# backends (llama.cpp / LM Studio) key their KV-cache prefix off the
|
||||
# system message byte-for-byte; mixing ever-changing timestamp text into
|
||||
# it would invalidate the cached prefix on every request (issue #2927).
|
||||
# Placing it at the tail also keeps it out of the stable
|
||||
# preface+history prefix, so that prefix stays byte-identical turn over
|
||||
# turn (modulo the genuinely new history entries) and the cache survives.
|
||||
if not agent_mode:
|
||||
try:
|
||||
from src.user_time import current_datetime_context_message
|
||||
_dt_msg = current_datetime_context_message()
|
||||
if messages and messages[-1].get("role") == "user":
|
||||
messages.insert(len(messages) - 1, _dt_msg)
|
||||
else:
|
||||
messages.append(_dt_msg)
|
||||
except Exception:
|
||||
logger.debug("Failed to add current date/time context", exc_info=True)
|
||||
|
||||
# Auto-compact
|
||||
messages, context_length, was_compacted = await maybe_compact(
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||
@@ -911,6 +931,54 @@ def save_assistant_response(
|
||||
return None
|
||||
|
||||
|
||||
def _is_session_stream_active(session_id: str) -> bool:
|
||||
"""Best-effort check for "is a chat completion currently streaming for
|
||||
this session?" — used to keep background extraction from overlapping a
|
||||
main completion and competing for the local backend's processing slots
|
||||
(issue #2927). Lazily imports the route module's live registry to avoid
|
||||
a circular import (chat_routes imports this module at load time)."""
|
||||
try:
|
||||
from routes import chat_routes as _cr
|
||||
return session_id in getattr(_cr, "_active_streams", {})
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
async def _run_extraction_jobs_sequentially(session_id: str, jobs: list, max_wait_s: float = 120.0):
    """Await queued background-extraction coroutines one after another.

    Each job only starts once no chat completion is streaming for the session
    (or once ``max_wait_s`` seconds of polling have elapsed). As diagnosed in
    issue #2927, running memory/skill extraction concurrently with the main
    completion, or with each other, makes them compete for the local
    backend's limited processing slots, which evicts the conversation's
    cached KV-cache checkpoint and forces a full prompt re-evaluation on the
    next turn. Running strictly in sequence keeps at most one "side" request
    in flight, and never alongside the user's own conversation.

    Args:
        session_id: Session whose stream activity gates the jobs.
        jobs: Iterable of ``(name, awaitable)`` pairs; ``name`` is used only
            in the failure log line.
        max_wait_s: Upper bound, per wait, on how long to poll for the stream
            to go idle before proceeding anyway.
    """
    poll_interval = 0.25

    async def _wait_for_idle_stream() -> None:
        # Poll until the session's stream is gone or the wait budget is spent.
        elapsed = 0.0
        while _is_session_stream_active(session_id) and elapsed < max_wait_s:
            await asyncio.sleep(poll_interval)
            elapsed += poll_interval

    # Let the triggering turn's own stream finish winding down first. It has
    # usually finished by the time this task is scheduled, so this is a small
    # safety margin rather than the primary mechanism.
    await _wait_for_idle_stream()

    for label, pending in jobs:
        # Re-check before every job: a fast follow-up message may have started
        # a new stream for this session in the meantime.
        await _wait_for_idle_stream()
        try:
            await pending
        except Exception:
            logger.warning("[bg-extract] %s extraction job failed for session %s", label, session_id, exc_info=True)
|
||||
|
||||
|
||||
def run_post_response_tasks(
|
||||
sess,
|
||||
session_manager,
|
||||
@@ -933,7 +1001,22 @@ def run_post_response_tasks(
|
||||
extract_skills: bool = True,
|
||||
allow_background_extraction: bool = True,
|
||||
):
|
||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction.
|
||||
|
||||
Memory/skill extraction are queued to run *sequentially*, after the main
|
||||
completion stream for this session has fully wound down — never
|
||||
concurrently with it or with each other. As diagnosed in issue #2927,
|
||||
firing these "side" LLM calls in parallel with the main chat completion
|
||||
makes them compete for the local backend's limited processing slots
|
||||
(llama.cpp defaults to 4), evicting the main conversation's cached
|
||||
checkpoint and forcing a full prompt re-evaluation on the next turn. By
|
||||
the time this function runs the main response is already saved, but the
|
||||
extraction calls themselves are still async — queuing them through
|
||||
``_queue_background_extraction`` keeps them from overlapping the *next*
|
||||
turn's request too.
|
||||
"""
|
||||
_extraction_jobs: list = []
|
||||
|
||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||
@@ -943,10 +1026,10 @@ def run_post_response_tasks(
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
asyncio.create_task(extract_and_store(
|
||||
_extraction_jobs.append(("memory", extract_and_store(
|
||||
sess, memory_manager, memory_vector,
|
||||
t_url, t_model, t_headers,
|
||||
))
|
||||
)))
|
||||
|
||||
# Skill extraction from complex agent runs. Only when the user actually
|
||||
# chose agent mode — not a chat we auto-escalated for a notes/calendar
|
||||
@@ -982,12 +1065,15 @@ def run_post_response_tasks(
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
||||
asyncio.create_task(maybe_extract_skill(
|
||||
_extraction_jobs.append(("skill", maybe_extract_skill(
|
||||
sess, skills_manager,
|
||||
s_url, s_model, s_headers,
|
||||
agent_rounds, agent_tool_calls,
|
||||
owner=owner,
|
||||
))
|
||||
)))
|
||||
|
||||
if _extraction_jobs:
|
||||
asyncio.create_task(_run_extraction_jobs_sequentially(session_id, _extraction_jobs))
|
||||
|
||||
# Token accumulation
|
||||
if last_metrics:
|
||||
|
||||
+57
-15
@@ -62,6 +62,33 @@ def _stream_set(session_id: str, **fields) -> None:
|
||||
rec.update(fields)
|
||||
|
||||
|
||||
def _resolve_request_workspace(request, raw_value) -> tuple:
|
||||
"""Resolve the posted workspace for this request: (workspace, rejected).
|
||||
|
||||
Privilege is checked BEFORE the path ever touches the filesystem. Only
|
||||
admin/single-user callers can use the workspace-backed file/shell tools,
|
||||
so only they get vet_workspace() and the workspace_rejected signal. For
|
||||
any other caller the submitted value is dropped uniformly, with no vetting
|
||||
and no event: otherwise the presence/absence of workspace_rejected would
|
||||
let a non-admin chat caller probe which host paths exist.
|
||||
|
||||
vet_workspace rejects non-directories, sensitive roots (.ssh, .gnupg,
|
||||
...), and filesystem roots; on rejection there is no confinement and the
|
||||
default tool-path allowlist applies. The rejected value is surfaced so the
|
||||
stream can tell an admin client (which believes a workspace is active)
|
||||
that it was dropped.
|
||||
"""
|
||||
requested = (raw_value or "").strip()
|
||||
if not requested:
|
||||
return "", ""
|
||||
from src.tool_security import owner_is_admin_or_single_user
|
||||
if not owner_is_admin_or_single_user(get_current_user(request)):
|
||||
return "", ""
|
||||
from src.tool_execution import vet_workspace
|
||||
workspace = vet_workspace(requested) or ""
|
||||
return workspace, (requested if not workspace else "")
|
||||
|
||||
|
||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
if not session_url or not endpoint_base:
|
||||
return False
|
||||
@@ -400,6 +427,7 @@ def setup_chat_routes(
|
||||
temperature=ctx.preset.temperature,
|
||||
max_tokens=ctx.preset.max_tokens,
|
||||
prompt_type=preset_id,
|
||||
session_id=session,
|
||||
)
|
||||
_clean_reply, _clean_md = clean_thinking_for_save(reply, {"model": sess.model})
|
||||
sess.add_message(ChatMessage("assistant", _clean_reply, metadata=_clean_md))
|
||||
@@ -446,20 +474,23 @@ def setup_chat_routes(
|
||||
use_research = form_data.get("use_research")
|
||||
time_filter = form_data.get("time_filter")
|
||||
preset_id = form_data.get("preset_id")
|
||||
allow_bash = form_data.get("allow_bash")
|
||||
allow_web_search = form_data.get("allow_web_search")
|
||||
# Issue #3229: API callers send JSON, not FormData. Read from the
|
||||
# JSON body as fallback so callers who send {"allow_bash": true}
|
||||
# actually get bash enabled.
|
||||
allow_bash = form_data.get("allow_bash") or (body or {}).get("allow_bash")
|
||||
allow_web_search = form_data.get("allow_web_search") or (body or {}).get("allow_web_search")
|
||||
use_rag = form_data.get("use_rag")
|
||||
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
|
||||
# Plan mode is not part of the merge-ready UI. Ignore stale clients or
|
||||
# manual form posts that still send plan_mode=true.
|
||||
plan_mode = False
|
||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||
# it's a real directory; ignore (no confinement) otherwise.
|
||||
workspace = (form_data.get("workspace") or "").strip()
|
||||
if workspace:
|
||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||
# Workspace: confine the agent's file/shell tools to this folder.
|
||||
workspace, workspace_rejected = _resolve_request_workspace(
|
||||
request, form_data.get("workspace")
|
||||
)
|
||||
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||
if plan_mode:
|
||||
chat_mode = "agent"
|
||||
@@ -638,7 +669,7 @@ def setup_chat_routes(
|
||||
# leak a doc that belongs to a DIFFERENT session.
|
||||
if not active_doc:
|
||||
try:
|
||||
from src.tool_implementations import get_active_document
|
||||
from src.agent_tools.document_tools import get_active_document
|
||||
_mem_id = get_active_document()
|
||||
if _mem_id:
|
||||
_mem_q = _doc_db.query(DBDocument).filter(DBDocument.id == _mem_id)
|
||||
@@ -659,9 +690,13 @@ def setup_chat_routes(
|
||||
|
||||
# Build disabled-tools set from frontend toggles + user privileges
|
||||
disabled_tools = set()
|
||||
if str(allow_bash).lower() != "true":
|
||||
# Only disable bash/web_search when the caller *explicitly* set them
|
||||
# to a falsy value. When unset (None), defer to per-user privilege
|
||||
# checks below — this lets admins with can_use_bash=True use bash
|
||||
# by default without having to send allow_bash in every request.
|
||||
if allow_bash is not None and str(allow_bash).lower() != "true":
|
||||
disabled_tools.add("bash")
|
||||
if str(allow_web_search).lower() != "true":
|
||||
if allow_web_search is not None and str(allow_web_search).lower() != "true":
|
||||
disabled_tools.add("web_search")
|
||||
disabled_tools.add("web_fetch")
|
||||
|
||||
@@ -764,6 +799,13 @@ def setup_chat_routes(
|
||||
# Register active stream for partial-save safety net
|
||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||
|
||||
# The client sent a workspace the server refused to bind (deleted
|
||||
# folder, file path, sensitive dir, filesystem root). Tell it up
|
||||
# front so the UI can clear the pill instead of displaying a
|
||||
# confinement that is not actually in effect.
|
||||
if workspace_rejected:
|
||||
yield f"data: {json.dumps({'type': 'workspace_rejected', 'data': {'path': workspace_rejected}})}\n\n"
|
||||
|
||||
if ctx.preprocessed.attachment_meta:
|
||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||
|
||||
@@ -992,6 +1034,7 @@ def setup_chat_routes(
|
||||
max_tokens=ctx.preset.max_tokens,
|
||||
prompt_type=preset_id,
|
||||
tools=None,
|
||||
session_id=session,
|
||||
):
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
try:
|
||||
@@ -1138,9 +1181,9 @@ def setup_chat_routes(
|
||||
tool_policy=tool_policy,
|
||||
owner=_user,
|
||||
fallbacks=_fallback_candidates,
|
||||
workspace=workspace or None,
|
||||
plan_mode=plan_mode,
|
||||
approved_plan=approved_plan or None,
|
||||
workspace=workspace or None,
|
||||
):
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
try:
|
||||
@@ -1272,8 +1315,7 @@ def setup_chat_routes(
|
||||
# without waiting on the next streamed chunk.
|
||||
#
|
||||
# Normal chat/agent streams keep the DETACHED behavior below: they
|
||||
# survive the client closing the tab / navigating away (true
|
||||
# terminal-agent semantics). The SSE response just subscribes (replay
|
||||
# survive the client closing the tab / navigating away. The SSE response just subscribes (replay
|
||||
# buffered output + live); dropping the SSE only removes a subscriber —
|
||||
# the run keeps going and saves the assistant message on completion
|
||||
# regardless. Reconnect via /api/chat/resume.
|
||||
|
||||
@@ -729,8 +729,11 @@ def setup_contacts_routes():
|
||||
@router.post("/import")
|
||||
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
|
||||
text = data.get("vcf") or data.get("text") or ""
|
||||
csv_text = data.get("csv") or ""
|
||||
# Coerce defensively: a non-string vcf/text/csv (e.g. a number or list
|
||||
# in the JSON body) would otherwise reach .strip() and 500 with an
|
||||
# AttributeError instead of degrading to a clean "no data" response.
|
||||
text = str(data.get("vcf") or data.get("text") or "")
|
||||
csv_text = str(data.get("csv") or "")
|
||||
if text.strip():
|
||||
if "BEGIN:VCARD" not in text.upper():
|
||||
return {"success": False, "error": "No vCard data found"}
|
||||
|
||||
+46
-28
@@ -1,16 +1,19 @@
|
||||
"""cookbook_helpers.py — validators + small helpers shared by the cookbook routes.
|
||||
Extracted from cookbook_routes.py; the routes module imports the symbols it needs."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import ntpath
|
||||
import os
|
||||
import posixpath
|
||||
import re
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,20 +33,24 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
|
||||
# Include pattern is a glob: allow typical safe glyphs only.
|
||||
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
||||
# Remote host: user@host (optionally with :port-free hostname parts).
|
||||
_REMOTE_HOST_RE = re.compile(r"^[A-Za-z0-9._-]+@[A-Za-z0-9._-]+$")
|
||||
# HF tokens and API tokens are url-safe base64-like.
|
||||
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
||||
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
||||
# Anything beyond plain alphanumerics + dash + underscore could break out
|
||||
# of the shell/PowerShell contexts the value lands in.
|
||||
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
||||
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
||||
# only (no quotes, shell metacharacters, or spaces) since it lands in a shell
|
||||
# command. A leading ~ is expanded to $HOME at command-build time.
|
||||
_LOCAL_DIR_RE = re.compile(r"^~?/[A-Za-z0-9._/-]*$|^~$")
|
||||
# only (no quotes or shell metacharacters). Spaces are allowed because command
|
||||
# builders pass the value through quoted shell/Python contexts. The character
|
||||
# class uses ``\w`` — Unicode word characters under Python 3's default str
|
||||
# matching — so non-ASCII folder names pass validation too: Cyrillic, accented
|
||||
# Latin, CJK, e.g. ``/Volumes/Модели`` or ``D:\AI Models\Модели``. This stays
|
||||
# shell-safe: none of ``; & | ` $ '' "" () {}`` newlines etc. are in ``[\w. -]``,
|
||||
# so injection vectors remain rejected. A leading ~ is expanded to $HOME at
|
||||
# command-build time. (Drive letters stay ASCII: ``[A-Za-z]:``.)
|
||||
_LOCAL_DIR_RE = re.compile(r"^~?(?:/[\w. -]*)+$|^~$")
|
||||
_WINDOWS_LOCAL_DIR_RE = re.compile(r"^[A-Za-z]:[\\/](?:[\w. -]+(?:[\\/][\w. -]+)*[\\/]?)?$")
|
||||
_WINDOWS_DRIVE_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]")
|
||||
|
||||
|
||||
@@ -77,14 +84,6 @@ def _validate_include(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_remote_host(v: str | None) -> str | None:
    """Validate an SSH target of the form ``user@host``.

    Args:
        v: Candidate remote host string, or ``None``.

    Returns:
        ``None`` when *v* is ``None`` or empty, otherwise *v* unchanged.

    Raises:
        HTTPException: 400 when *v* starts with ``-`` or does not match
            ``_REMOTE_HOST_RE`` in its entirety.
    """
    if v is None or v == "":
        return None
    # fullmatch(), not match(): with match() the pattern's trailing ``$`` also
    # matches just before a final newline, so "user@host\n" would be accepted.
    # A leading '-' is refused explicitly because '-' is in the allowed
    # character class and such a value reads as option syntax, which the error
    # message below promises to reject.
    if v.startswith("-") or not _REMOTE_HOST_RE.fullmatch(v):
        raise HTTPException(400, "Invalid remote_host — must be user@host, no SSH option syntax")
    return v
|
||||
|
||||
|
||||
def _validate_token(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
@@ -93,26 +92,43 @@ def _validate_token(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def load_stored_hf_token(*, state_path: Path | str | None = None) -> str:
|
||||
"""Return the decrypted HF token from cookbook_state.json, else env fallback."""
|
||||
path = Path(state_path) if state_path else Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
||||
token = ""
|
||||
if path.exists():
|
||||
try:
|
||||
state = json.loads(path.read_text(encoding="utf-8"))
|
||||
env = state.get("env") if isinstance(state, dict) else {}
|
||||
if isinstance(env, dict) and env.get("hfToken"):
|
||||
from src.secret_storage import decrypt
|
||||
token = decrypt(env.get("hfToken") or "")
|
||||
except Exception:
|
||||
token = ""
|
||||
if not token:
|
||||
token = (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or "").strip()
|
||||
return token
|
||||
|
||||
|
||||
def _validate_local_dir(v: str | None) -> str | None:
    """Validate a download target directory before it reaches a shell command.

    One pair of matching surrounding quotes is removed and trailing ``/``
    characters are trimmed (a path made only of slashes becomes ``/``).

    Args:
        v: Candidate directory, or ``None``.

    Returns:
        ``None`` when *v* is ``None``, empty, or an empty quoted string;
        otherwise the normalised path.

    Raises:
        HTTPException: 400 when the path matches neither ``_LOCAL_DIR_RE`` nor
            ``_WINDOWS_LOCAL_DIR_RE`` in full, or when any path segment starts
            with ``-``.
    """
    if v is None or v == "":
        return None
    if len(v) >= 2 and v[0] == v[-1] and v[0] in {"'", '"'}:
        v = v[1:-1]
        # A quoted empty string is still "no directory given"; without this
        # check it would be normalised to "/" below and accepted as the
        # filesystem root.
        if v == "":
            return None
    v = v.rstrip("/") or "/"
    # fullmatch(), not match(): the patterns end in ``$``, which with match()
    # also matches just before a trailing newline, so "/models\n" would pass
    # validation and carry a newline into the shell command.
    if not (_LOCAL_DIR_RE.fullmatch(v) or _WINDOWS_LOCAL_DIR_RE.fullmatch(v)):
        raise HTTPException(400, "Invalid local_dir — must be an absolute or ~ path with no shell metacharacters")
    # Reject path segments that start with '-' (option injection). '-' is in the
    # allowlist, so a dir like ``/models/-rf`` or ``D:\models\-rf`` could be read
    # as a CLI flag by hf/etc. — and quoting does NOT stop a value from being
    # parsed as an option. This is the one residual that command-build-time
    # quoting can't cover, so the guard lives here, keeping the safety wholly
    # inside the validator rather than relying on consumers.
    if any(seg.startswith("-") for seg in re.split(r"[\\/]", v) if seg):
        raise HTTPException(400, "Invalid local_dir — path segments cannot start with '-'")
    return v
|
||||
|
||||
|
||||
def _validate_ssh_port(v: str | None) -> str | None:
    """Validate an SSH port value.

    Kept as a module-private alias so existing callers keep working; the
    check itself lives in the shared ``validate_ssh_port`` helper imported
    from ``routes._validators``, so there is a single implementation to
    maintain.

    Args:
        v: Candidate port, or ``None``.

    Returns:
        ``None`` for ``None``/empty input, otherwise the port as a decimal
        string.

    Raises:
        HTTPException: 400 ("Invalid ssh_port") for non-numeric values or
            ports outside 1-65535.
    """
    return validate_ssh_port(v)
|
||||
|
||||
|
||||
def _validate_gpus(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
@@ -124,7 +140,7 @@ def _validate_gpus(v: str | None) -> str | None:
|
||||
def _shell_path(p: str) -> str:
|
||||
"""Render a validated path for a double-quoted shell context, expanding a
|
||||
leading ~ to $HOME (single quotes wouldn't expand it). Safe because
|
||||
_validate_local_dir already restricts the charset."""
|
||||
_validate_local_dir already rejects quotes and shell metacharacters."""
|
||||
if p == "~":
|
||||
return '"$HOME"'
|
||||
if p.startswith("~/"):
|
||||
@@ -385,6 +401,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache:
|
||||
" for root, dirs, fns in safe_walk(base):",
|
||||
" for fn in sorted(fns):",
|
||||
" if not fn.lower().endswith('.gguf'): continue",
|
||||
" if fn.startswith('._'): continue # macOS AppleDouble sidecar, not a real GGUF",
|
||||
" fp = os.path.join(root, fn)",
|
||||
" try: size = os.path.getsize(fp)",
|
||||
" except Exception: size = 0",
|
||||
@@ -787,6 +804,7 @@ def _llama_cpp_rebuild_cmd() -> str:
|
||||
|
||||
class ModelDownloadRequest(BaseModel):
|
||||
repo_id: str
|
||||
backend: str | None = None # "hf" (default) or "ollama"
|
||||
include: str | None = None # glob pattern e.g. "*Q4_K_M*"
|
||||
hf_token: str | None = None
|
||||
env_prefix: str | None = None # e.g. "source ~/venv/bin/activate"
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Pure helpers for shaping cookbook task output for the status response.
|
||||
|
||||
Kept dependency-free (no FastAPI / SQLAlchemy imports) so the behavior can be
|
||||
unit-tested without standing up the whole app.
|
||||
"""
|
||||
|
||||
|
||||
def error_aware_output_tail(
    full_snapshot: str,
    status: str,
    *,
    error_lines: int = 50,
    default_lines: int = 12,
) -> str:
    """Return the trailing slice of a task log for the status response.

    Failed tasks return the last ``error_lines`` lines (50 by default) so the
    "Copy last 50 lines" action surfaces the actual error context (stack
    traces, build output). Running and other non-error tasks keep the cheaper
    ``default_lines`` tail (12 by default) to limit the payload on the 10s
    polling interval.

    Args:
        full_snapshot: Complete task log text; may be empty.
        status: Task status; only the exact value ``"error"`` selects the
            longer tail.
        error_lines: Number of trailing lines returned for failed tasks.
        default_lines: Number of trailing lines returned otherwise.

    Returns:
        The selected trailing lines joined with ``"\\n"``, or ``""`` when the
        snapshot is empty or the selected line count is not positive.
    """
    if not full_snapshot:
        return ""
    tail_lines = error_lines if status == "error" else default_lines
    if tail_lines <= 0:
        # A slice of [-0:] would return the whole log, not nothing.
        return ""
    return "\n".join(full_snapshot.splitlines()[-tail_lines:])
|
||||
+882
-352
File diff suppressed because it is too large
Load Diff
@@ -16,9 +16,18 @@ def setup_diagnostics_routes(
|
||||
rag_manager,
|
||||
rag_available: bool,
|
||||
research_handler,
|
||||
memory_vector=None,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(tags=["diagnostics"])
|
||||
|
||||
@router.get("/api/diagnostics/services")
|
||||
async def get_service_health(request: Request) -> Dict[str, Any]:
|
||||
"""Consolidated degraded-state report for ChromaDB, SearXNG, email,
|
||||
ntfy, and provider endpoints. Non-intrusive probes — safe to poll."""
|
||||
require_admin(request)
|
||||
from src.service_health import collect_service_health
|
||||
return await collect_service_health(rag_manager, memory_vector)
|
||||
|
||||
@router.get("/api/db/stats")
|
||||
async def get_database_stats(request: Request) -> Dict[str, Any]:
|
||||
require_admin(request)
|
||||
|
||||
@@ -108,10 +108,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# to markdown for prose.
|
||||
language = req.language
|
||||
if not language:
|
||||
from src.tool_implementations import _looks_like_email_document, _sniff_doc_language
|
||||
from src.agent_tools.document_tools import _looks_like_email_document, _sniff_doc_language
|
||||
language = _sniff_doc_language(req.content)
|
||||
else:
|
||||
from src.tool_implementations import _looks_like_email_document
|
||||
from src.agent_tools.document_tools import _looks_like_email_document
|
||||
if _looks_like_email_document(req.content, req.title):
|
||||
language = "email"
|
||||
|
||||
@@ -643,7 +643,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# in-memory active-doc pointer so the last-resort injection
|
||||
# path doesn't re-surface this doc in a later chat (#1160).
|
||||
try:
|
||||
from src.tool_implementations import clear_active_document
|
||||
from src.agent_tools.document_tools import clear_active_document
|
||||
clear_active_document(doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -672,7 +672,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# Closed/deleted — drop the in-memory active-doc pointer so it isn't
|
||||
# re-injected into a later, unrelated chat (#1160).
|
||||
try:
|
||||
from src.tool_implementations import clear_active_document
|
||||
from src.agent_tools.document_tools import clear_active_document
|
||||
clear_active_document(doc_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
+60
-16
@@ -304,6 +304,7 @@ OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||
"email_ai_replies",
|
||||
"email_calendar_extractions",
|
||||
"email_urgency_alerts",
|
||||
"sender_signatures",
|
||||
}
|
||||
|
||||
|
||||
@@ -341,6 +342,55 @@ def _ensure_owner_scoped_email_cache_table(conn, table: str, create_sql: str, co
|
||||
_lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}")
|
||||
|
||||
|
||||
def _ensure_sender_signatures_table(conn):
|
||||
"""Create/migrate learned sender signatures to an owner-scoped cache."""
|
||||
create_sql = """
|
||||
CREATE TABLE IF NOT EXISTS sender_signatures (
|
||||
from_address TEXT,
|
||||
owner TEXT DEFAULT '',
|
||||
signature_text TEXT,
|
||||
sample_count INTEGER,
|
||||
last_built_at TEXT NOT NULL,
|
||||
model_used TEXT,
|
||||
source TEXT,
|
||||
PRIMARY KEY (from_address, owner)
|
||||
)
|
||||
"""
|
||||
conn.execute(create_sql)
|
||||
try:
|
||||
info = conn.execute("PRAGMA table_info(sender_signatures)").fetchall()
|
||||
cols = [r[1] for r in info]
|
||||
pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
|
||||
if "owner" in cols and pk_cols == ["from_address", "owner"]:
|
||||
return
|
||||
|
||||
conn.execute("ALTER TABLE sender_signatures RENAME TO sender_signatures__old")
|
||||
conn.execute(create_sql)
|
||||
old_cols = [r[1] for r in conn.execute("PRAGMA table_info(sender_signatures__old)").fetchall()]
|
||||
copy_cols = [
|
||||
c for c in (
|
||||
"from_address",
|
||||
"signature_text",
|
||||
"sample_count",
|
||||
"last_built_at",
|
||||
"model_used",
|
||||
"source",
|
||||
)
|
||||
if c in old_cols
|
||||
]
|
||||
source_owner = "COALESCE(owner, '')" if "owner" in old_cols else "''"
|
||||
conn.execute(
|
||||
f"INSERT OR IGNORE INTO sender_signatures "
|
||||
f"({', '.join([*copy_cols, 'owner'])}) "
|
||||
f"SELECT {', '.join([*copy_cols, source_owner])} "
|
||||
f"FROM sender_signatures__old"
|
||||
)
|
||||
conn.execute("DROP TABLE sender_signatures__old")
|
||||
except Exception as _mig_e:
|
||||
import logging as _lg
|
||||
_lg.getLogger(__name__).warning(f"sender_signatures owner-migration skipped: {_mig_e}")
|
||||
|
||||
|
||||
def attachment_extract_dir(folder: str, uid: str) -> Path:
|
||||
"""Containment-safe extraction directory for an attachment.
|
||||
|
||||
@@ -559,20 +609,10 @@ def _init_scheduled_db():
|
||||
conn.execute("ALTER TABLE email_boundaries ADD COLUMN turns_json TEXT")
|
||||
except Exception:
|
||||
pass
|
||||
# Per-sender signature cache. Populated by `learn_sender_signatures`
|
||||
# action: the LLM extracts the common trailing block across N emails
|
||||
# from each sender; the renderer folds it consistently for every
|
||||
# future email from that address.
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS sender_signatures (
|
||||
from_address TEXT PRIMARY KEY,
|
||||
signature_text TEXT,
|
||||
sample_count INTEGER,
|
||||
last_built_at TEXT NOT NULL,
|
||||
model_used TEXT,
|
||||
source TEXT
|
||||
)
|
||||
""")
|
||||
# Per-sender signature cache. Populated by `learn_sender_signatures`.
|
||||
# Message sender addresses are global, so signatures must be scoped to the
|
||||
# mailbox owner before `/read` returns them to the renderer.
|
||||
_ensure_sender_signatures_table(conn)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
@@ -762,10 +802,14 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
||||
imaplib._MAXLINE = 50_000_000
|
||||
return conn
|
||||
|
||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
def _imap_connect(account_id: str | None = None, owner: str = "",
|
||||
timeout: int = _IMAP_TIMEOUT_SECONDS):
|
||||
# SECURITY: passing `owner` scopes the fallback config lookup so a brand
|
||||
# new user doesn't get connected against another user's default mailbox
|
||||
# when they have no account configured.
|
||||
#
|
||||
# `timeout` is overridable so short-lived callers (e.g. the service-health
|
||||
# probe) can impose a tighter budget than the default IMAP timeout.
|
||||
cfg = _get_email_config(account_id, owner=owner)
|
||||
# Connection mode:
|
||||
# STARTTLS on → plain + upgrade
|
||||
@@ -778,7 +822,7 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
cfg["imap_host"],
|
||||
cfg["imap_port"],
|
||||
starttls=bool(cfg.get("imap_starttls")),
|
||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||
timeout=timeout,
|
||||
)
|
||||
try:
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
|
||||
+52
-24
@@ -249,6 +249,41 @@ def _uid_from_fetch_meta(meta_b: bytes) -> str:
|
||||
return m.group(1).decode() if m else ""
|
||||
|
||||
|
||||
_FETCH_SEQ_RE = re.compile(rb"^(\d+)\s+\(")
|
||||
|
||||
|
||||
def _group_uid_fetch_records(msg_data) -> list:
    """Collapse an imaplib UID FETCH response into per-message records.

    imaplib returns an interleaved list: 2-tuples ``(meta, literal)`` for
    attributes that carry a literal, and bare ``bytes`` items for whatever the
    server sends outside a literal. Which attributes land where depends on the
    server: when FLAGS precede the header literal they sit in the tuple meta,
    and when they follow it they arrive as a separate bare item. Discarding
    bare items would therefore lose FLAGS for servers of the second kind.

    A tuple whose meta begins with a sequence number starts a new record.
    Anything else (continuation tuple or bare bytes) is appended to the meta
    of the most recent record, separated by a single space, so attribute
    regexes can scan the whole meta text. Closing-paren terminators are
    appended as well, which is harmless. Items seen before any record has
    been opened are ignored.

    Returns:
        A list of ``(meta_bytes, payload_or_None)`` tuples, one per message.
    """
    records: list = []
    for item in msg_data or ():
        if isinstance(item, tuple):
            head = item[0]
            if not isinstance(head, (bytes, bytearray)):
                head = str(head).encode()
            if _FETCH_SEQ_RE.match(head):
                records.append((head, item[1]))
                continue
            if not records:
                continue
            meta, payload = records[-1]
            # Keep the first payload seen; only adopt this one if none yet.
            records[-1] = (meta + b" " + head, payload or item[1])
        elif records and isinstance(item, (bytes, bytearray)):
            meta, payload = records[-1]
            records[-1] = (meta + b" " + bytes(item), payload)
    return records
|
||||
|
||||
|
||||
def _smtp_ready(cfg: dict) -> bool:
|
||||
return bool(cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password"))
|
||||
|
||||
@@ -799,20 +834,11 @@ def setup_email_routes():
|
||||
except Exception as e:
|
||||
logger.warning(f"Batch fetch failed, falling back to per-UID: {e}")
|
||||
status, msg_data = "NO", []
|
||||
# imaplib batch responses interleave (meta, payload) tuples and
|
||||
# `b')'` terminators. Group by message: each tuple where the
|
||||
# meta begins with a seq number starts a new message record.
|
||||
seq_re = re.compile(rb'^(\d+)\s+\(')
|
||||
grouped = [] # list of (meta_str, payload_bytes)
|
||||
for part in (msg_data or []):
|
||||
if isinstance(part, tuple):
|
||||
meta_b = part[0] if isinstance(part[0], (bytes, bytearray)) else str(part[0]).encode()
|
||||
if seq_re.match(meta_b):
|
||||
grouped.append((meta_b, part[1]))
|
||||
elif grouped:
|
||||
# continuation of previous message — concatenate meta info if any
|
||||
cur_meta, cur_payload = grouped[-1]
|
||||
grouped[-1] = (cur_meta + b" " + meta_b, cur_payload or part[1])
|
||||
# Group the batched response into per-message (meta, payload)
|
||||
# records. Bare bytes parts must be kept: Gmail returns FLAGS
|
||||
# after the header literal as a bare element, and dropping it
|
||||
# rendered every Gmail message as unread/unflagged.
|
||||
grouped = _group_uid_fetch_records(msg_data)
|
||||
|
||||
if status != "OK" and not grouped:
|
||||
conn.logout()
|
||||
@@ -1098,14 +1124,15 @@ def setup_email_routes():
|
||||
continue
|
||||
raw_header = None
|
||||
flags = ""
|
||||
for part in msg_data:
|
||||
if isinstance(part, tuple):
|
||||
meta = part[0].decode() if isinstance(part[0], bytes) else str(part[0])
|
||||
if b"RFC822.HEADER" in part[0] if isinstance(part[0], bytes) else "RFC822.HEADER" in meta:
|
||||
raw_header = part[1]
|
||||
flag_match = re.search(r'FLAGS \(([^)]*)\)', meta)
|
||||
if flag_match:
|
||||
flags = flag_match.group(1)
|
||||
# Same Gmail caveat as the list route: FLAGS may
|
||||
# arrive after the header literal, so group bare
|
||||
# parts back into the message meta before scanning.
|
||||
for meta_b, payload in _group_uid_fetch_records(msg_data):
|
||||
if payload and b"RFC822.HEADER" in meta_b:
|
||||
raw_header = payload
|
||||
flag_match = re.search(rb'FLAGS \(([^)]*)\)', meta_b)
|
||||
if flag_match:
|
||||
flags = flag_match.group(1).decode(errors="replace")
|
||||
if not raw_header:
|
||||
continue
|
||||
msg = email_mod.message_from_bytes(raw_header)
|
||||
@@ -1247,8 +1274,9 @@ def setup_email_routes():
|
||||
try:
|
||||
if sender_addr:
|
||||
_rs = _c.execute(
|
||||
"SELECT signature_text FROM sender_signatures WHERE from_address = ?",
|
||||
(sender_addr.lower().strip(),),
|
||||
f"SELECT signature_text FROM sender_signatures "
|
||||
f"WHERE from_address = ? AND {owner_clause}",
|
||||
(sender_addr.lower().strip(), *owner_params),
|
||||
).fetchone()
|
||||
if _rs and _rs[0]:
|
||||
cached_sender_sig = _rs[0]
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import GalleryImage
|
||||
from src.auth_helpers import _auth_disabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -120,19 +121,18 @@ def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any
|
||||
}
|
||||
|
||||
|
||||
def _owner_filter(q, user, model_cls=GalleryImage):
    """Apply owner filtering to a gallery query.

    ``get_current_user`` returns None both in auth-disabled single-user mode
    and when auth is enabled but no current user was resolved. Preserve the
    single-user behavior, but fail closed for auth-enabled null-user states.

    Args:
        q: Query to restrict.
        user: Resolved current user, or ``None``.
        model_cls: Model whose ``owner`` column is compared against *user*;
            defaults to ``GalleryImage`` so existing two-argument calls keep
            their meaning.

    Returns:
        *q* filtered to rows owned by *user*; *q* unchanged when *user* is
        ``None`` and auth is disabled; otherwise *q* with an always-false
        filter so no rows are returned.
    """
    if user is not None:
        return q.filter(model_cls.owner == user)
    if _auth_disabled():
        return q
    # Auth is on but nobody was resolved: match nothing rather than everything.
    return q.filter(False)
|
||||
|
||||
|
||||
|
||||
|
||||
+10
-15
@@ -476,8 +476,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
.outerjoin(DbSession, GalleryImage.session_id == DbSession.id)
|
||||
.filter(GalleryImage.is_active == True)
|
||||
)
|
||||
if user is not None:
|
||||
q = q.filter(GalleryImage.owner == user)
|
||||
q = _owner_filter(q, user)
|
||||
|
||||
# Search filter (prompt + tags + ai_tags)
|
||||
if search:
|
||||
@@ -579,28 +578,26 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(GalleryAlbum)
|
||||
if user:
|
||||
q = q.filter(GalleryAlbum.owner == user)
|
||||
q = _owner_filter(q, user, GalleryAlbum)
|
||||
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
||||
result = []
|
||||
for a in albums:
|
||||
_count_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
)
|
||||
if user:
|
||||
_count_q = _count_q.filter(GalleryImage.owner == user)
|
||||
_count_q = _owner_filter(_count_q, user)
|
||||
count = _count_q.count()
|
||||
cover_url = None
|
||||
if a.cover_id:
|
||||
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
||||
cover_q = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id)
|
||||
cover = _owner_filter(cover_q, user).first()
|
||||
if cover:
|
||||
cover_url = f"/api/generated-image/{cover.filename}"
|
||||
elif count > 0:
|
||||
_cover_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
)
|
||||
if user:
|
||||
_cover_q = _cover_q.filter(GalleryImage.owner == user)
|
||||
_cover_q = _owner_filter(_cover_q, user)
|
||||
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
|
||||
if first:
|
||||
cover_url = f"/api/generated-image/{first.filename}"
|
||||
@@ -643,10 +640,9 @@ def setup_gallery_routes() -> APIRouter:
|
||||
base = db.query(GalleryImage).filter(GalleryImage.is_active == True)
|
||||
size_q = db.query(func.sum(GalleryImage.file_size)).filter(GalleryImage.is_active == True)
|
||||
album_q = db.query(GalleryAlbum)
|
||||
if user:
|
||||
base = base.filter(GalleryImage.owner == user)
|
||||
size_q = size_q.filter(GalleryImage.owner == user)
|
||||
album_q = album_q.filter(GalleryAlbum.owner == user)
|
||||
base = _owner_filter(base, user)
|
||||
size_q = _owner_filter(size_q, user)
|
||||
album_q = _owner_filter(album_q, user, GalleryAlbum)
|
||||
total = base.count()
|
||||
total_size = size_q.scalar() or 0
|
||||
fav_count = base.filter(GalleryImage.favorite == True).count()
|
||||
@@ -674,8 +670,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
GalleryImage.is_active == True,
|
||||
(GalleryImage.ai_tags == None) | (GalleryImage.ai_tags == ""),
|
||||
)
|
||||
if user:
|
||||
q = q.filter(GalleryImage.owner == user)
|
||||
q = _owner_filter(q, user)
|
||||
if album_id:
|
||||
q = q.filter(GalleryImage.album_id == album_id)
|
||||
untagged = q.count()
|
||||
|
||||
+41
-4
@@ -1,7 +1,9 @@
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
|
||||
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
||||
@@ -11,6 +13,14 @@ from fastapi import APIRouter
|
||||
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
||||
|
||||
|
||||
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
    """Validate the optional remote detection target.

    Both values go through the shared validators, and their ``None`` results
    are normalised to empty strings. A port supplied without a host is
    rejected.

    Returns:
        ``(host, ssh_port)`` as strings, each possibly empty.

    Raises:
        HTTPException: 400 from either validator, or 400
            ("ssh_port requires host") when only a port was supplied.
    """
    target = validate_remote_host(host)
    port = validate_ssh_port(ssh_port)
    if port and not target:
        raise HTTPException(400, "ssh_port requires host")
    return target or "", port or ""
|
||||
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
|
||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||
fresh=true bypasses the per-host cache (the Rescan button)."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
|
||||
@router.get("/models")
|
||||
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.fit import rank_models
|
||||
from services.hwfit.models import get_models, model_catalog_path
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
@@ -165,8 +177,14 @@ def setup_hwfit_routes():
|
||||
system["gpu_name"] = g["name"]
|
||||
system["active_group"] = {**g, "use_count": n}
|
||||
|
||||
if gpu_count != "":
|
||||
n = int(gpu_count)
|
||||
# Parse the optional count defensively (matches the gpu_group guard
|
||||
# above): a non-numeric query param previously raised ValueError ->
|
||||
# HTTP 500. A malformed value is ignored, same as omitting it.
|
||||
try:
|
||||
n = int(gpu_count) if gpu_count != "" else None
|
||||
except ValueError:
|
||||
n = None
|
||||
if n is not None:
|
||||
if n == 0:
|
||||
# RAM-only mode: rank against system memory, offload allowed.
|
||||
system["has_gpu"] = False
|
||||
@@ -196,7 +214,24 @@ def setup_hwfit_routes():
|
||||
if target_context is not None:
|
||||
target_context = max(1024, min(target_context, 1000000))
|
||||
|
||||
results = rank_models(system, use_case=use_case or None, limit=limit, search=search or None, sort=sort, quant=quant or None, target_context=target_context, fit_only=fit_only)
|
||||
rank_kwargs = {
|
||||
"use_case": use_case or None,
|
||||
"limit": limit,
|
||||
"search": search or None,
|
||||
"sort": sort,
|
||||
"quant": quant or None,
|
||||
"fit_only": fit_only,
|
||||
}
|
||||
if target_context is not None:
|
||||
rank_kwargs["target_context"] = target_context
|
||||
try:
|
||||
import inspect
|
||||
supported = set(inspect.signature(rank_models).parameters)
|
||||
rank_kwargs = {k: v for k, v in rank_kwargs.items() if k in supported}
|
||||
except Exception:
|
||||
rank_kwargs.pop("target_context", None)
|
||||
rank_kwargs.pop("fit_only", None)
|
||||
results = rank_models(system, **rank_kwargs)
|
||||
return {"system": system, "models": results}
|
||||
|
||||
@router.get("/profiles")
|
||||
@@ -212,6 +247,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.models import get_models
|
||||
from services.hwfit.profiles import compute_serve_profiles
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
if system.get("error"):
|
||||
return {"system": system, "profiles": [], "error": system["error"]}
|
||||
@@ -262,6 +298,7 @@ def setup_hwfit_routes():
|
||||
"""Rank image generation models against detected hardware."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.image_models import rank_image_models
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
|
||||
+18
-2
@@ -105,6 +105,13 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
||||
if memory_manager.find_duplicates(text, user_mem):
|
||||
return {"ok": True, "count": len(user_mem), "message": "Memory already exists"}
|
||||
|
||||
if memory_data.session_id:
|
||||
try:
|
||||
session_obj = session_manager.get_session(memory_data.session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
_assert_session_owner(session_obj, user)
|
||||
|
||||
new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user)
|
||||
if memory_data.session_id:
|
||||
new_entry["session_id"] = memory_data.session_id
|
||||
@@ -163,8 +170,17 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
||||
|
||||
session_id = memory.get("session_id")
|
||||
if session_id and session_id in session_manager.sessions:
|
||||
session = session_manager.get_session(session_id)
|
||||
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
if session:
|
||||
_assert_session_owner(session, user)
|
||||
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
|
||||
except KeyError:
|
||||
memory["session_name"] = "Unknown"
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 404:
|
||||
raise
|
||||
memory["session_name"] = "Unknown"
|
||||
else:
|
||||
memory["session_name"] = "Unknown"
|
||||
|
||||
|
||||
+147
-134
@@ -4,8 +4,8 @@ import os
|
||||
import re
|
||||
import uuid
|
||||
import json
|
||||
import socket
|
||||
import hashlib
|
||||
import socket
|
||||
import time as _time
|
||||
import logging
|
||||
import httpx
|
||||
@@ -123,6 +123,21 @@ def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int:
|
||||
return cleared_users
|
||||
|
||||
|
||||
def _default_endpoint_needs_assignment(current_default_id: str, enabled_endpoint_ids) -> bool:
|
||||
"""Whether the global default chat endpoint should be (re)assigned.
|
||||
|
||||
True when nothing is configured yet, or the configured default no longer
|
||||
resolves to an enabled endpoint (e.g. the user disabled it). Without the
|
||||
second case, adding a new endpoint after disabling the previous default
|
||||
leaves `default_endpoint_id` pointing at the disabled endpoint, so features
|
||||
that read the raw setting (Memory → Tidy) fail with "No default model
|
||||
configured" even though an enabled endpoint exists. See #3586.
|
||||
"""
|
||||
if not current_default_id:
|
||||
return True
|
||||
return current_default_id not in enabled_endpoint_ids
|
||||
|
||||
|
||||
# Loopback hosts a user might type for a local model server (LM Studio,
|
||||
# llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the
|
||||
# host the server actually runs on.
|
||||
@@ -283,11 +298,9 @@ _HOST_TO_CURATED = (
|
||||
("fireworks.ai", "fireworks"),
|
||||
("googleapis.com", "google"),
|
||||
("x.ai", "xai"),
|
||||
|
||||
("nvidia.com", "nvidia"),
|
||||
("openrouter.ai", "openrouter"),
|
||||
("ollama.com", "ollama"),
|
||||
("opencode.ai/zen/go", "opencode-go"),
|
||||
("opencode.ai/zen", "opencode-zen"),
|
||||
)
|
||||
|
||||
|
||||
@@ -480,10 +493,17 @@ _NON_CHAT_PREFIXES = (
|
||||
"dall-e", "tts-", "whisper", "text-embedding", "embedding",
|
||||
"davinci", "babbage", "moderation", "omni-moderation",
|
||||
"sora", "gpt-image", "chatgpt-image",
|
||||
# embedding / retrieval / non-chat models (common across providers)
|
||||
"snowflake/arctic-embed", "nvidia/nv-embed", "embed",
|
||||
)
|
||||
_NON_CHAT_CONTAINS = (
|
||||
"-realtime", "-transcribe", "-tts", "-codex",
|
||||
"codex-",
|
||||
"codex-", "content-safety", "-safety", "-reward", "nvclip",
|
||||
"kosmos", "fuyu", "deplot", "vila", "neva",
|
||||
"gliner", "riva", "-parse", "-embedqa", "-nemoretriever",
|
||||
"topic-control", "calibration",
|
||||
"ai-synthetic-video", "cosmos-reason2",
|
||||
"bge", "llama-guard",
|
||||
)
|
||||
_NON_CHAT_EXACT_PREFIXES = (
|
||||
"gpt-audio", # gpt-audio, gpt-audio-mini etc. (not gpt-4o-audio-preview which is chat)
|
||||
@@ -494,8 +514,6 @@ _NON_CHAT_EXACT_PREFIXES = (
|
||||
def _is_chat_model(model_id: str) -> bool:
|
||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||
mid = model_id.lower()
|
||||
if mid in {"gpt-5.1-codex"}:
|
||||
return True
|
||||
for prefix in _NON_CHAT_PREFIXES:
|
||||
if mid.startswith(prefix):
|
||||
return False
|
||||
@@ -509,15 +527,7 @@ def _is_chat_model(model_id: str) -> bool:
|
||||
|
||||
|
||||
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||
|
||||
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||
endpoint backed by that auth row is removed, the stored credentials should
|
||||
be cleared instead of lingering. Returns True if a row was deleted.
|
||||
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||
reference count so it does not keep its own auth alive.
|
||||
"""
|
||||
"""Delete a ProviderAuthSession once no endpoint still references it."""
|
||||
if not auth_id:
|
||||
return False
|
||||
from core.database import ProviderAuthSession
|
||||
@@ -534,40 +544,52 @@ def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Op
|
||||
return True
|
||||
|
||||
|
||||
def _is_discovery_only_provider(provider: str) -> bool:
|
||||
"""Provider that only supports model discovery, not live probing.
|
||||
def _safe_detect_provider(base_url: str) -> str:
    """Best-effort provider detection that must not break endpoint probing.

    Args:
        base_url: Endpoint base URL handed to ``_detect_provider``.

    Returns:
        Whatever ``_detect_provider`` returns, or an empty string if it raises.
    """
    try:
        return _detect_provider(base_url)
    except Exception as exc:
        # Deliberately broad: detection is optional, so any failure is logged
        # at debug level and reported as "no provider" rather than propagated.
        logger.debug("Provider detection failed for %s: %s", base_url, exc)
        return ""
|
||||
|
||||
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||
chat-completions or general health endpoint, so completion probes and
|
||||
reachability pings are skipped — status is derived from cached models.
|
||||
"""
|
||||
|
||||
def _safe_build_models_url(base_url: str) -> str:
    """Build a /models URL without letting optional provider imports break probes.

    Args:
        base_url: Endpoint base URL handed to ``build_models_url``.

    Returns:
        The result of ``build_models_url(base_url)``. If that call raises, a
        fallback made of ``base_url`` with trailing slashes removed plus
        ``/models`` (a falsy ``base_url`` yields just ``"/models"``).
    """
    try:
        return build_models_url(base_url)
    except Exception as exc:
        # Deliberately broad: fall back to the plain "<base>/models" form.
        logger.debug("Model URL detection failed for %s: %s", base_url, exc)
        trimmed = (base_url or "").rstrip("/")
        return f"{trimmed}/models"
|
||||
|
||||
|
||||
def _safe_build_headers(api_key: Optional[str], base_url: str) -> dict:
    """Build auth headers without letting optional provider imports break probes.

    Args:
        api_key: Key/bearer for the endpoint, or None/empty when there is none.
        base_url: Endpoint base URL handed to ``build_headers``.

    Returns:
        The result of ``build_headers(api_key, base_url)``. If that call
        raises, a plain ``Authorization: Bearer <api_key>`` header when a key
        is present, otherwise an empty dict.
    """
    try:
        return build_headers(api_key, base_url)
    except Exception as exc:
        # Deliberately broad: fall back to a generic bearer header.
        logger.debug("Header detection failed for %s: %s", base_url, exc)
        if api_key:
            return {"Authorization": f"Bearer {api_key}"}
        return {}
|
||||
|
||||
|
||||
def _is_discovery_only_provider(provider: str) -> bool:
|
||||
return provider == "chatgpt-subscription"
|
||||
|
||||
|
||||
def _resolve_probe_key(ep) -> Optional[str]:
|
||||
"""API key/bearer to probe an endpoint with.
|
||||
|
||||
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||
"""
|
||||
"""API key/bearer to probe an endpoint with."""
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||
return key
|
||||
except Exception as e:
|
||||
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||
except Exception as exc:
|
||||
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), exc)
|
||||
return None
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||
provider = _detect_provider(base)
|
||||
provider = _safe_detect_provider(base)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# Responses/Codex API, not chat-completions: a completion probe would
|
||||
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
@@ -587,12 +609,12 @@ def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeou
|
||||
elif provider == "ollama":
|
||||
from src.llm_core import _build_ollama_payload
|
||||
target_url = build_chat_url(base)
|
||||
h = build_headers(api_key, base)
|
||||
h = _safe_build_headers(api_key, base)
|
||||
h["Content-Type"] = "application/json"
|
||||
payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools)
|
||||
else:
|
||||
target_url = build_chat_url(base)
|
||||
h = build_headers(api_key, base)
|
||||
h = _safe_build_headers(api_key, base)
|
||||
h["Content-Type"] = "application/json"
|
||||
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
||||
_max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens"
|
||||
@@ -682,14 +704,15 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
if _detect_provider(base) == "chatgpt-subscription":
|
||||
provider = _safe_detect_provider(base)
|
||||
if provider == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
if api_key:
|
||||
return fetch_available_models(api_key, timeout=timeout)
|
||||
return []
|
||||
if _detect_provider(base) == "anthropic":
|
||||
if provider == "anthropic":
|
||||
# Try Anthropic's /v1/models endpoint first
|
||||
url = build_models_url(base)
|
||||
url = _safe_build_models_url(base)
|
||||
headers = {"anthropic-version": "2023-06-01"}
|
||||
if api_key:
|
||||
headers["x-api-key"] = api_key
|
||||
@@ -712,12 +735,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return []
|
||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||
return list(ANTHROPIC_MODELS)
|
||||
url = build_models_url(base)
|
||||
if not url:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||
return list(fallback or [])
|
||||
headers = build_headers(api_key, base)
|
||||
url = _safe_build_models_url(base)
|
||||
headers = _safe_build_headers(api_key, base)
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
r.raise_for_status()
|
||||
@@ -735,7 +754,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
for _e in _PROVIDER_CURATED.get(_ck, []):
|
||||
if _e not in set(models) and not any(m.startswith(_e) for m in models):
|
||||
models.append(_e)
|
||||
return models
|
||||
return [m for m in models if _is_chat_model(m)]
|
||||
except httpx.HTTPStatusError as e:
|
||||
if api_key:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
@@ -759,7 +778,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
data = r.json()
|
||||
models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
|
||||
if models:
|
||||
return models
|
||||
return [m for m in models if _is_chat_model(m)]
|
||||
except Exception as e:
|
||||
logger.debug(f"Ollama /api/tags probe failed for {base}: {e}")
|
||||
# Fall back to curated list if the provider has a URL-based match (e.g. z.ai has no /models endpoint)
|
||||
@@ -770,11 +789,12 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return list(fallback)
|
||||
return []
|
||||
|
||||
|
||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
headers = build_headers(api_key, base)
|
||||
headers = _safe_build_headers(api_key, base)
|
||||
|
||||
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
||||
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
|
||||
@@ -785,10 +805,6 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
# APFEL-specific detection
|
||||
host = (parsed_base.hostname or "").lower()
|
||||
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
|
||||
|
||||
def _result_from_response(r) -> Dict[str, Any]:
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
@@ -810,23 +826,7 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
last_error: Optional[str] = None
|
||||
|
||||
try:
|
||||
# APFEL does not behave like Ollama; use its health endpoint.
|
||||
if looks_like_apfel:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
root = root[: -len(suffix)].rstrip("/")
|
||||
break
|
||||
try:
|
||||
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
|
||||
result = _result_from_response(r)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
last_error = result.get("error")
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
elif looks_like_ollama:
|
||||
if looks_like_ollama:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
@@ -847,17 +847,11 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
try:
|
||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
result = _result_from_response(r)
|
||||
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
|
||||
# as a fallback. OpenAI-compatible servers like llama-swap return 404
|
||||
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
|
||||
# are definitive — probing /models would just repeat the same rejection.
|
||||
if (
|
||||
not result["reachable"]
|
||||
and result.get("status_code") is not None
|
||||
and 400 <= result["status_code"] < 500
|
||||
and result["status_code"] not in (401, 403)
|
||||
):
|
||||
models_url = build_models_url(base)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
sc = result.get("status_code") or 0
|
||||
if 400 <= sc < 500 and sc not in (401, 403):
|
||||
models_url = _safe_build_models_url(base)
|
||||
try:
|
||||
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
result2 = _result_from_response(r2)
|
||||
@@ -865,12 +859,16 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
return result2
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
if sc:
|
||||
return result
|
||||
last_error = result.get("error") or last_error
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
|
||||
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
||||
"""Return a provider-aware error message for failed endpoint probes."""
|
||||
ping = ping or {}
|
||||
@@ -1068,17 +1066,6 @@ def setup_model_routes(model_discovery):
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||
ep,
|
||||
owner=getattr(ep, "owner", None),
|
||||
)
|
||||
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||
except Exception as e:
|
||||
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||
continue
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
@@ -1156,7 +1143,7 @@ def setup_model_routes(model_discovery):
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
provider = _safe_detect_provider(base)
|
||||
# Merge cached + pinned models, then filter out hidden ones
|
||||
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
||||
model_ids = _visible_models(
|
||||
@@ -1233,8 +1220,8 @@ def setup_model_routes(model_discovery):
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
|
||||
raise HTTPException(status_code=500, detail='Internal error')
|
||||
logger.error("Auth gate error in GET /api/models, failing closed: %s", e)
|
||||
raise HTTPException(status_code=500, detail="Internal error")
|
||||
# Admins see every endpoint (they manage the global pool); regular
|
||||
# users get the owner-scoped view.
|
||||
_is_admin = False
|
||||
@@ -1298,7 +1285,14 @@ def setup_model_routes(model_discovery):
|
||||
t0 = _time.time()
|
||||
try:
|
||||
import asyncio as _asyncio
|
||||
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5)
|
||||
# Bumped 1.5s → 3.5s. The previous 1.5s budget was clipping
|
||||
# local vLLM endpoints on Tailscale links where the model
|
||||
# server is still loading (Qwen3.5-122B takes 2–3 min to
|
||||
# warm); /v1/models can take 500–2500 ms on a busy box,
|
||||
# which pushed _ping_endpoint's full path-discovery sweep
|
||||
# past the cap and marked the row offline despite the
|
||||
# user actively chatting with it.
|
||||
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 3.5)
|
||||
lat = round((_time.time() - t0) * 1000)
|
||||
return {
|
||||
"alive": bool(ping.get("reachable")),
|
||||
@@ -1336,7 +1330,7 @@ def setup_model_routes(model_discovery):
|
||||
results = []
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
provider = _safe_detect_provider(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
cached_count = len(_cached_model_ids(ep))
|
||||
entry = {
|
||||
@@ -1348,20 +1342,12 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
try:
|
||||
if _is_discovery_only_provider(provider):
|
||||
# No general health endpoint — an unauthenticated GET just
|
||||
# 401s. Report status from cached models instead of pinging.
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = None
|
||||
entry["model_count"] = cached_count
|
||||
else:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
@@ -1394,7 +1380,7 @@ def setup_model_routes(model_discovery):
|
||||
if ep_id and ep_id not in endpoints_cache:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep:
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
||||
ep_data = endpoints_cache.get(ep_id)
|
||||
if not ep_data:
|
||||
# Try to find by base_url from the model's endpoint field
|
||||
@@ -1433,7 +1419,7 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"api_key": _resolve_probe_key(ep),
|
||||
"api_key": ep.api_key,
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1522,14 +1508,37 @@ def setup_model_routes(model_discovery):
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
# admin-pinned IDs that a probe would never surface.
|
||||
status = "online" if (all_models or pinned) else "offline"
|
||||
base = _normalize_base(r.base_url)
|
||||
ping = None
|
||||
# Discovery-only providers have no health endpoint — an
|
||||
# unauthenticated ping just 401s, so don't bother.
|
||||
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
# When cached_models is empty, do a quick reachability probe.
|
||||
# Bumped 1.0s → 3.5s because the user reported endpoints they
|
||||
# were ACTIVELY chatting with showed "offline" — the previous
|
||||
# 1s timeout was clipping live cloud endpoints (DeepSeek can
|
||||
# take 1.5–2.5s on /v1/models when their region is under load,
|
||||
# vLLM on a remote GPU box behind SSH can also push past 1s).
|
||||
# 3.5s still keeps the picker render snappy in the common
|
||||
# "everything's already cached" path because this branch only
|
||||
# runs for endpoints with an empty cached_models.
|
||||
if not all_models and not pinned and r.is_enabled:
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=3.5)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
# Best-effort: if the probe came back reachable, try
|
||||
# to populate cached_models in the background so the
|
||||
# NEXT picker load shows "online" instead of "empty".
|
||||
# Failure here is silent — we already returned the
|
||||
# "empty" status, and the existing background refresh
|
||||
# path will eventually fill it in too.
|
||||
try:
|
||||
probed = _probe_endpoint(r.base_url, r.api_key, timeout=5)
|
||||
if probed:
|
||||
r.cached_models = json.dumps(probed)
|
||||
db.commit()
|
||||
all_models = probed
|
||||
visible = _visible_models(all_models, r.hidden_models, pinned)
|
||||
status = "online"
|
||||
except Exception as _refill_err:
|
||||
logger.debug(f"opportunistic cached_models refill failed for {r.id}: {_refill_err!r}")
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
@@ -1603,11 +1612,10 @@ def setup_model_routes(model_discovery):
|
||||
)
|
||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||
|
||||
# Dedupe: if an endpoint with the same base_url and compatible
|
||||
# credentials already exists and is reachable by the caller (shared or
|
||||
# owned by them), return it instead of creating a duplicate row. Keep
|
||||
# same-url/different-key rows distinct so users can group the same
|
||||
# provider URL under multiple credentials.
|
||||
# Dedupe: if an endpoint with the same base_url already exists and
|
||||
# is reachable by the caller (shared or owned by them), return it
|
||||
# instead of creating a duplicate row. Fixes "Scan for Servers"
|
||||
# re-adding manually-added endpoints under their host:port name.
|
||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||
_caller = _gcu_dedup(request) or None
|
||||
_incoming_api_key = api_key.strip()
|
||||
@@ -1734,12 +1742,19 @@ def setup_model_routes(model_discovery):
|
||||
)
|
||||
db.add(ep)
|
||||
db.commit()
|
||||
# Auto-set as default chat endpoint if none configured yet. Seed
|
||||
# the first CHAT model (not raw model_ids[0]) so we don't pin the
|
||||
# global default to an embedding/tts/etc. entry a provider happens
|
||||
# to list first.
|
||||
# Auto-set as default chat endpoint when none is usable yet — either
|
||||
# nothing is configured, or the configured default points at an
|
||||
# endpoint that is now missing/disabled (#3586). Seed the first CHAT
|
||||
# model (not raw model_ids[0]) so we don't pin the global default to
|
||||
# an embedding/tts/etc. entry a provider happens to list first.
|
||||
settings = _load_settings()
|
||||
if not settings.get("default_endpoint_id"):
|
||||
enabled_ids = {
|
||||
e.id
|
||||
for e in db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True # noqa: E712
|
||||
).all()
|
||||
}
|
||||
if _default_endpoint_needs_assignment(settings.get("default_endpoint_id") or "", enabled_ids):
|
||||
from src.endpoint_resolver import _first_chat_model
|
||||
settings["default_endpoint_id"] = ep.id
|
||||
settings["default_model"] = _first_chat_model(model_ids) or ""
|
||||
@@ -1805,7 +1820,7 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1869,7 +1884,7 @@ def setup_model_routes(model_discovery):
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
@@ -2105,8 +2120,6 @@ def setup_model_routes(model_discovery):
|
||||
"name": ep.name,
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"has_key": bool(ep.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||
|
||||
+19
-11
@@ -10,8 +10,9 @@ import logging
|
||||
from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter
|
||||
from src.session_actions import is_session_recently_active
|
||||
|
||||
|
||||
def _sanitize_export_filename(name: str) -> str:
|
||||
@@ -257,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
last_msg_map = {}
|
||||
mode_map = {}
|
||||
msg_count_map = {}
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False)
|
||||
q = owner_filter(q, DbSession, user)
|
||||
rows = q.all()
|
||||
for row in rows:
|
||||
folder_map[row.id] = row.folder
|
||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||
@@ -276,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
# Sessions with active documents that have content
|
||||
from sqlalchemy import func
|
||||
doc_session_ids = set(
|
||||
r[0] for r in db.query(Document.session_id)
|
||||
.filter(Document.is_active == True,
|
||||
Document.current_content != None,
|
||||
func.trim(Document.current_content) != "",
|
||||
Document.owner == user)
|
||||
r[0] for r in owner_filter(
|
||||
db.query(Document.session_id)
|
||||
.filter(Document.is_active == True,
|
||||
Document.current_content != None,
|
||||
func.trim(Document.current_content) != ""),
|
||||
Document, user)
|
||||
.distinct().all()
|
||||
)
|
||||
img_session_ids = set(
|
||||
r[0] for r in db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None,
|
||||
GalleryImage.owner == user)
|
||||
r[0] for r in owner_filter(
|
||||
db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None),
|
||||
GalleryImage, user)
|
||||
.distinct().all()
|
||||
)
|
||||
finally:
|
||||
@@ -1028,6 +1033,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
|
||||
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
|
||||
)
|
||||
cleanup_now = utcnow_naive()
|
||||
for row in rows:
|
||||
# Never delete important sessions
|
||||
if getattr(row, 'is_important', False):
|
||||
@@ -1040,6 +1046,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if hasattr(session_manager, 'delete_session'):
|
||||
session_manager.delete_session(row.id)
|
||||
continue
|
||||
if is_session_recently_active(row, now=cleanup_now):
|
||||
continue
|
||||
msg_count = _counts.get(row.id, 0)
|
||||
should_delete = False
|
||||
if msg_count == 0:
|
||||
|
||||
@@ -519,6 +519,15 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||
else True
|
||||
)
|
||||
# Validate chained task belongs to same owner
|
||||
if req.then_task_id:
|
||||
chain_target = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == req.then_task_id
|
||||
).first()
|
||||
if not chain_target:
|
||||
raise HTTPException(400, "Chained task not found")
|
||||
if chain_target.owner != user:
|
||||
raise HTTPException(403, "Cannot chain to another user's task")
|
||||
task = ScheduledTask(
|
||||
id=task_id,
|
||||
owner=user,
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
"""Workspace API — browse server directories to pick a tool workspace folder."""
|
||||
"""Workspace API - browse server directories to pick a tool workspace folder."""
|
||||
import os
|
||||
from fastapi import APIRouter, Request, HTTPException, Query
|
||||
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.tool_security import owner_is_admin_or_single_user
|
||||
|
||||
# Cap entries returned per directory (mirrors filesystem_tools._CODENAV_MAX_HITS).
|
||||
# A huge directory shouldn't dump thousands of rows into the picker; the user can
|
||||
# type/paste a path to jump straight in instead.
|
||||
_MAX_BROWSE_DIRS = 500
|
||||
|
||||
|
||||
def setup_workspace_routes():
|
||||
router = APIRouter(prefix="/api/workspace", tags=["workspace"])
|
||||
@@ -34,7 +39,7 @@ def setup_workspace_routes():
|
||||
with os.scandir(target) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
# Don't follow symlinks when classifying — a symlinked
|
||||
# Don't follow symlinks when classifying - a symlinked
|
||||
# dir is skipped rather than letting the browser wander
|
||||
# off via a link. Hidden entries are omitted.
|
||||
if entry.is_dir(follow_symlinks=False) and not entry.name.startswith("."):
|
||||
@@ -46,11 +51,35 @@ def setup_workspace_routes():
|
||||
except (PermissionError, OSError):
|
||||
dirs = []
|
||||
|
||||
dirs_sorted = sorted(dirs, key=lambda d: d["name"].lower())
|
||||
truncated = len(dirs_sorted) > _MAX_BROWSE_DIRS
|
||||
parent = os.path.dirname(target)
|
||||
from src.tool_execution import vet_workspace
|
||||
return {
|
||||
"path": target,
|
||||
"parent": parent if parent and parent != target else None,
|
||||
"dirs": sorted(dirs, key=lambda d: d["name"].lower()),
|
||||
"dirs": dirs_sorted[:_MAX_BROWSE_DIRS],
|
||||
"truncated": truncated,
|
||||
# Whether this directory may be bound as a workspace (filesystem
|
||||
# roots and sensitive dirs may be browsed through but not chosen).
|
||||
"selectable": vet_workspace(target) is not None,
|
||||
}
|
||||
|
||||
@router.get("/vet")
def vet(request: Request, path: str = Query(default="")):
    """Validate a workspace path without binding it.

    The UI calls this before persisting a manually typed path (/workspace
    set) so a typo, file path, deleted folder, sensitive dir, or filesystem
    root is rejected up front with the canonical path returned on success,
    instead of being stored client-side and silently dropped at chat time.
    Admin-gated like /browse: it confirms path existence on the host.

    Returns:
        ``{"ok": bool, "path": resolved}`` where ``resolved`` is the value
        returned by ``vet_workspace(path)`` and ``ok`` is True when that
        value is not None.

    Raises:
        HTTPException: 403 when the caller is neither an admin nor the
            single user.
    """
    owner = get_current_user(request)
    if not owner_is_admin_or_single_user(owner):
        raise HTTPException(status_code=403, detail="Workspace selection is admin-only")
    # Imported after the admin check, at call time rather than module import.
    from src.tool_execution import vet_workspace

    resolved = vet_workspace(path)
    return {"ok": resolved is not None, "path": resolved}
|
||||
|
||||
return router
|
||||
|
||||
Reference in New Issue
Block a user