mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 09:45:24 -04:00
Scope auxiliary LLM endpoints by owner (#2996)
* fix(auth): scope auxiliary llm endpoints by owner
* fix(auth): scope auxiliary llm fallbacks by owner
This commit is contained in:
@@ -551,7 +551,7 @@ async def build_chat_context(
|
||||
|
||||
# Auto-compact
|
||||
messages, context_length, was_compacted = await maybe_compact(
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||
)
|
||||
messages = trim_for_context(messages, context_length)
|
||||
|
||||
|
||||
+12
-11
@@ -39,11 +39,13 @@ def _first_chat_model(models) -> str:
|
||||
|
||||
def _resolve_research_endpoint(sess) -> tuple:
    """Resolve the Deep Research endpoint for a session.

    Asks ``resolve_endpoint`` for the "research" endpoint, offering the
    session's own endpoint URL, model and headers as fallbacks and scoping
    the lookup to the session's owner.

    Args:
        sess: Session object exposing ``endpoint_url``, ``model`` and
            ``headers``; an ``owner`` attribute is optional.

    Returns:
        The ``(endpoint_url, model, headers)`` triple produced by
        ``resolve_endpoint``.
    """
    # A missing or falsy owner attribute is normalized to None.
    session_owner = getattr(sess, "owner", None)
    if not session_owner:
        session_owner = None

    endpoint_url, model_name, request_headers = resolve_endpoint(
        "research",
        fallback_url=sess.endpoint_url,
        fallback_model=sess.model,
        fallback_headers=sess.headers,
        owner=session_owner,
    )
    return endpoint_url, model_name, request_headers
|
||||
|
||||
@@ -392,17 +394,17 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
|
||||
# When neither research nor utility is configured, use the user's
|
||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||
if not ep_url:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
@@ -572,19 +574,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep_headers = dict(r_headers)
|
||||
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("chat"))
|
||||
_merge(*resolve_endpoint("chat", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("research"))
|
||||
_merge(*resolve_endpoint("research", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("utility"))
|
||||
_merge(*resolve_endpoint("utility", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
# Last resort: any enabled endpoint
|
||||
# Last resort: this user's enabled endpoint, plus legacy shared rows.
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
fallback_url = build_chat_url(base)
|
||||
@@ -594,7 +595,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
try:
|
||||
models = json.loads(ep.cached_models)
|
||||
if models:
|
||||
fallback_model = models[0]
|
||||
fallback_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
_merge(fallback_url, fallback_model, fallback_headers)
|
||||
|
||||
@@ -924,7 +924,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||
owner = getattr(session, "owner", None) or effective_user(request)
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||
if not url or not model:
|
||||
|
||||
@@ -291,20 +291,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
def _owner(request: Request):
|
||||
return get_current_user(request)
|
||||
|
||||
async def _generate_task_name(prompt: str) -> str:
|
||||
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
|
||||
"""Use LLM to generate a short task name from the prompt."""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import Session as DbSession
|
||||
db = SessionLocal()
|
||||
try:
|
||||
recent = db.query(DbSession).filter(
|
||||
q = db.query(DbSession).filter(
|
||||
DbSession.endpoint_url.isnot(None),
|
||||
DbSession.model.isnot(None),
|
||||
).order_by(DbSession.created_at.desc()).first()
|
||||
)
|
||||
if owner:
|
||||
q = q.filter(DbSession.owner == owner)
|
||||
recent = q.order_by(DbSession.created_at.desc()).first()
|
||||
if not recent:
|
||||
return prompt[:50].strip()
|
||||
url, model = recent.endpoint_url, recent.model
|
||||
headers = recent.headers or {}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -315,6 +319,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
{"role": "user", "content": prompt[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
)
|
||||
title = result.strip().strip('"\'').strip()
|
||||
@@ -479,7 +484,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||
elif req.prompt:
|
||||
name = await _generate_task_name(req.prompt)
|
||||
name = await _generate_task_name(req.prompt, owner=user)
|
||||
else:
|
||||
name = "Untitled Task"
|
||||
|
||||
|
||||
@@ -307,6 +307,7 @@ async def maybe_compact(
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
headers: Optional[Dict] = None,
|
||||
owner: Optional[str] = None,
|
||||
) -> tuple:
|
||||
"""Check context usage and compact if above threshold.
|
||||
|
||||
@@ -353,7 +354,7 @@ async def maybe_compact(
|
||||
)
|
||||
|
||||
# Use utility model if configured, otherwise fall back to session model
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner)
|
||||
compact_url = util_url or endpoint_url
|
||||
compact_model = util_model or model
|
||||
compact_headers = util_headers if util_url else headers
|
||||
|
||||
@@ -132,7 +132,7 @@ async def run_auto_sort(owner: str, skip_llm: bool = False, delete_throwaway: bo
|
||||
if skip_llm:
|
||||
return f"Cleaned {deleted_empty + deleted_throwaway} sessions (folder sort skipped)."
|
||||
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=owner or None)
|
||||
if not url:
|
||||
return f"Cleaned {deleted_empty + deleted_throwaway} sessions. No model endpoint available for sorting."
|
||||
|
||||
|
||||
+21
-9
@@ -1580,9 +1580,12 @@ class TaskScheduler:
|
||||
try:
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_headers
|
||||
from src.auth_helpers import owner_filter
|
||||
db2 = SessionLocal()
|
||||
try:
|
||||
eps = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
ep_q = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
ep_q = owner_filter(ep_q, ModelEndpoint, task.owner or None)
|
||||
eps = ep_q.all()
|
||||
for ep in eps:
|
||||
if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url):
|
||||
headers = build_headers(ep.api_key, normalize_base(ep.base_url))
|
||||
@@ -1603,7 +1606,7 @@ class TaskScheduler:
|
||||
# chat uses but with the utility list (`utility_model_fallbacks`).
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_utility_fallback_candidates
|
||||
_task_fallbacks = resolve_utility_fallback_candidates()
|
||||
_task_fallbacks = resolve_utility_fallback_candidates(owner=task.owner or None)
|
||||
except Exception:
|
||||
_task_fallbacks = []
|
||||
async for event_str in stream_agent_loop(
|
||||
@@ -1646,7 +1649,7 @@ class TaskScheduler:
|
||||
else:
|
||||
grace_context += "No tool results were captured."
|
||||
grace_context += "\n\nSummarize what you accomplished and what's still pending. Be concise."
|
||||
_grace_candidates = [(endpoint_url, model, headers)] + resolve_utility_fallback_candidates()
|
||||
_grace_candidates = [(endpoint_url, model, headers)] + resolve_utility_fallback_candidates(owner=task.owner or None)
|
||||
full_text = await llm_call_async_with_fallback(
|
||||
_grace_candidates,
|
||||
messages=[
|
||||
@@ -1674,6 +1677,8 @@ class TaskScheduler:
|
||||
# Resolve endpoint/model: research settings > task settings > session defaults
|
||||
endpoint_url = task.endpoint_url
|
||||
model = task.model
|
||||
headers = {}
|
||||
headers_from_resolver = False
|
||||
|
||||
if not endpoint_url or not model:
|
||||
try:
|
||||
@@ -1683,9 +1688,13 @@ class TaskScheduler:
|
||||
endpoint_url or None,
|
||||
model or None,
|
||||
None,
|
||||
owner=task.owner or None,
|
||||
)
|
||||
endpoint_url = ep_url or endpoint_url
|
||||
model = ep_model or model
|
||||
if ep_headers is not None:
|
||||
headers = ep_headers
|
||||
headers_from_resolver = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1697,16 +1706,19 @@ class TaskScheduler:
|
||||
self._last_run_model = model
|
||||
|
||||
# Resolve headers
|
||||
headers = {}
|
||||
try:
|
||||
from core.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_headers
|
||||
from src.auth_helpers import owner_filter
|
||||
db2 = db
|
||||
eps = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in eps:
|
||||
if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url):
|
||||
headers = build_headers(ep.api_key, normalize_base(ep.base_url))
|
||||
break
|
||||
if not headers_from_resolver:
|
||||
ep_q = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
ep_q = owner_filter(ep_q, ModelEndpoint, task.owner or None)
|
||||
eps = ep_q.all()
|
||||
for ep in eps:
|
||||
if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url):
|
||||
headers = build_headers(ep.api_key, normalize_base(ep.base_url))
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def _src(path: str) -> str:
    """Return the UTF-8 text of the file at *path*, taken relative to ROOT."""
    target = ROOT / path
    return target.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_registered_manual_compaction_uses_session_owner_for_utility_endpoint():
    """Check routes/session_routes.py for the owner-scoped utility lookup."""
    route_source = _src("routes/session_routes.py")

    # Each snippet must appear verbatim in the route module's source text.
    required_snippets = (
        'owner = getattr(session, "owner", None) or effective_user(request)',
        'resolve_endpoint("utility", owner=owner)',
    )
    for snippet in required_snippets:
        assert snippet in route_source
|
||||
|
||||
|
||||
def test_task_name_generation_uses_owner_scoped_session_endpoint():
    """Check routes/task_routes.py for owner-scoped task-name generation."""
    route_source = _src("routes/task_routes.py")

    # Each snippet must appear verbatim in the route module's source text.
    required_snippets = (
        "async def _generate_task_name(prompt: str, owner: Optional[str] = None)",
        "q = q.filter(DbSession.owner == owner)",
        "headers = recent.headers or {}",
        "headers=headers",
        "await _generate_task_name(req.prompt, owner=user)",
    )
    for snippet in required_snippets:
        assert snippet in route_source
|
||||
|
||||
|
||||
def test_auto_compaction_utility_endpoint_keeps_chat_owner():
    """Check that the chat helper and the compactor both carry an owner."""
    # Pairs of (source text, snippet that must appear verbatim in it),
    # read in the same order as before: chat helper first, compactor second.
    chat_helper_text = _src("routes/chat_helpers.py")
    compactor_text = _src("src/context_compactor.py")

    expectations = (
        (chat_helper_text, "owner=user"),
        (compactor_text, "owner: Optional[str] = None"),
        (compactor_text, 'resolve_endpoint("utility", owner=owner)'),
    )
    for source_text, snippet in expectations:
        assert snippet in source_text
|
||||
|
||||
|
||||
def test_background_session_sort_uses_owner_task_endpoint():
    """Check src/session_actions.py for the owner-scoped task endpoint call."""
    actions_source = _src("src/session_actions.py")
    expected_call = "resolve_task_endpoint(owner=owner or None)"

    assert expected_call in actions_source
|
||||
|
||||
|
||||
def test_scheduler_fallbacks_and_research_headers_are_owner_scoped():
    """Check src/task_scheduler.py for owner-scoped fallbacks and headers."""
    scheduler_source = _src("src/task_scheduler.py")

    # Each snippet must appear verbatim in the scheduler's source text.
    # NOTE(review): the whitespace after "\n" in the second snippet has to
    # match the indentation used in task_scheduler.py exactly -- confirm.
    required_snippets = (
        "resolve_utility_fallback_candidates(owner=task.owner or None)",
        'resolve_endpoint(\n "research",',
        "owner=task.owner or None",
        "headers_from_resolver = False",
        "headers_from_resolver = True",
        "from src.auth_helpers import owner_filter",
        "owner_filter(ep_q, ModelEndpoint, task.owner or None)",
    )
    for snippet in required_snippets:
        assert snippet in scheduler_source
|
||||
|
||||
|
||||
def test_research_routes_fallbacks_are_owner_scoped():
    """Check routes/research_routes.py for owner-scoped endpoint fallbacks."""
    route_source = _src("routes/research_routes.py")

    # Snippets that must appear verbatim, checked before the unscoped query.
    required_first = (
        'resolve_endpoint("research", owner=user)',
        'resolve_endpoint("utility", owner=user)',
        'resolve_endpoint("default", owner=user)',
        'resolve_endpoint("chat", owner=user)',
        '_merge(*resolve_endpoint("chat", owner=user))',
        '_merge(*resolve_endpoint("research", owner=user))',
        '_merge(*resolve_endpoint("utility", owner=user))',
        "ep = _owned_enabled_endpoint(db, user)",
    )
    for snippet in required_first:
        assert snippet in route_source

    # The unscoped "first enabled endpoint" query must no longer be present.
    unscoped_query = (
        "db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()"
    )
    assert unscoped_query not in route_source

    assert 'owner = getattr(sess, "owner", None) or None' in route_source
|
||||
@@ -133,7 +133,7 @@ class TestMaybeCompactFourthMessage:
|
||||
|
||||
cc.get_context_length = lambda url, model: context_length
|
||||
cc.llm_call_async = _fake_summary
|
||||
cc.resolve_endpoint = lambda which: (None, None, None)
|
||||
cc.resolve_endpoint = lambda which, owner=None: (None, None, None)
|
||||
cc._update_session_history = lambda *a, **k: None
|
||||
try:
|
||||
return asyncio.run(
|
||||
|
||||
@@ -79,6 +79,7 @@ class _FakeSession:
|
||||
endpoint_url = "http://example.test/v1"
|
||||
model = "test-model"
|
||||
headers = {}
|
||||
owner = "session-owner"
|
||||
|
||||
def __init__(self, history):
|
||||
self.history = history
|
||||
@@ -107,7 +108,11 @@ def _compact_prompt_for(monkeypatch, history):
|
||||
import src.model_context as model_context
|
||||
|
||||
monkeypatch.setattr(agent_runs, "is_active", lambda session_id: False)
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {}))
|
||||
def fake_resolve_endpoint(kind, owner=None):
|
||||
captured.setdefault("resolve_calls", []).append((kind, owner))
|
||||
return None, None, {}
|
||||
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", fake_resolve_endpoint)
|
||||
monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async)
|
||||
monkeypatch.setattr(model_context, "estimate_tokens", lambda messages: 100)
|
||||
monkeypatch.setattr(model_context, "get_context_length", lambda endpoint_url, model: 1000)
|
||||
@@ -146,7 +151,11 @@ def _registered_compact_response(monkeypatch, history, active_run=False):
|
||||
import src.llm_core as llm_core
|
||||
|
||||
monkeypatch.setattr(agent_runs, "is_active", lambda session_id: active_run)
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {}))
|
||||
def fake_resolve_endpoint(kind, owner=None):
|
||||
captured.setdefault("resolve_calls", []).append((kind, owner))
|
||||
return None, None, {}
|
||||
|
||||
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", fake_resolve_endpoint)
|
||||
monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async)
|
||||
|
||||
session = _FakeSession(history)
|
||||
@@ -212,6 +221,24 @@ def test_registered_manual_compact_route_tolerates_none_content(monkeypatch):
|
||||
assert manager.replaced_messages is not None
|
||||
|
||||
|
||||
def test_registered_manual_compact_route_uses_session_owner(monkeypatch):
    """The compact route should resolve "utility" with the session's owner."""
    # (role, content) pairs for the conversation fed to the compact route.
    turns = [
        ("user", "start"),
        ("assistant", "tool call"),
        ("tool", "tool result"),
        ("assistant", "done"),
        ("user", "next"),
        ("assistant", "final"),
    ]
    history = [ChatMessage(role=role, content=content) for role, content in turns]

    response, captured, manager = _registered_compact_response(monkeypatch, history)

    assert response.status_code == 200
    assert manager.replaced_messages is not None
    # The recorded resolver calls must include the owner-scoped utility lookup.
    assert ("utility", "session-owner") in captured["resolve_calls"]
|
||||
|
||||
|
||||
def test_registered_manual_compact_route_rejects_active_agent_run(monkeypatch):
|
||||
response, captured, manager = _registered_compact_response(
|
||||
monkeypatch,
|
||||
|
||||
@@ -365,7 +365,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
|
||||
def fake_add_user_message(sess, chat_handler, preprocessed, incognito=False):
|
||||
sess.messages.append({"role": "user", "content": preprocessed.user_content})
|
||||
|
||||
async def fake_maybe_compact(sess, endpoint_url, model, messages, headers):
|
||||
async def fake_maybe_compact(sess, endpoint_url, model, messages, headers, owner=None):
|
||||
return messages, 123, False
|
||||
|
||||
monkeypatch.setattr(chat_helpers, "preprocess", fake_preprocess)
|
||||
|
||||
Reference in New Issue
Block a user