diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py index b8d8b61f2..2e5db4478 100644 --- a/routes/chat_helpers.py +++ b/routes/chat_helpers.py @@ -551,7 +551,7 @@ async def build_chat_context( # Auto-compact messages, context_length, was_compacted = await maybe_compact( - sess, sess.endpoint_url, sess.model, messages, sess.headers, + sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user, ) messages = trim_for_context(messages, context_length) diff --git a/routes/research_routes.py b/routes/research_routes.py index 267ab50e9..569dad3e9 100644 --- a/routes/research_routes.py +++ b/routes/research_routes.py @@ -39,11 +39,13 @@ def _first_chat_model(models) -> str: def _resolve_research_endpoint(sess) -> tuple: """Return (endpoint_url, model, headers) for Deep Research, checking admin overrides.""" + owner = getattr(sess, "owner", None) or None url, model, headers = resolve_endpoint( "research", fallback_url=sess.endpoint_url, fallback_model=sess.model, fallback_headers=sess.headers, + owner=owner, ) return url, model, headers @@ -392,17 +394,17 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: finally: db.close() else: - ep_url, ep_model, ep_headers = resolve_endpoint("research") + ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user) if not ep_url: - ep_url, ep_model, ep_headers = resolve_endpoint("utility") + ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user) # When neither research nor utility is configured, use the user's # configured DEFAULT model (default_endpoint_id/default_model) rather # than arbitrarily grabbing the first enabled endpoint's first model # (which surfaced gpt-3.5). "Default" should mean the default model. if not ep_url: - ep_url, ep_model, ep_headers = resolve_endpoint("default") + ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user) if not ep_url: - ep_url, ep_model, ep_headers = resolve_endpoint("chat") + ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user) if not ep_url: from src.database import SessionLocal from src.endpoint_resolver import normalize_base, build_chat_url, build_headers @@ -572,19 +574,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: ep_headers = dict(r_headers) if not ep_url or not ep_model: - _merge(*resolve_endpoint("chat")) + _merge(*resolve_endpoint("chat", owner=user)) if not ep_url or not ep_model: - _merge(*resolve_endpoint("research")) + _merge(*resolve_endpoint("research", owner=user)) if not ep_url or not ep_model: - _merge(*resolve_endpoint("utility")) + _merge(*resolve_endpoint("utility", owner=user)) if not ep_url or not ep_model: - # Last resort: any enabled endpoint + # Last resort: this user's enabled endpoint, plus legacy shared rows. from src.database import SessionLocal - from src.database import ModelEndpoint from src.endpoint_resolver import normalize_base, build_chat_url, build_headers db = SessionLocal() try: - ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first() + ep = _owned_enabled_endpoint(db, user) if ep: base = normalize_base(ep.base_url) fallback_url = build_chat_url(base) @@ -594,7 +595,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: try: models = json.loads(ep.cached_models) if models: - fallback_model = models[0] + fallback_model = _first_chat_model(models) except Exception: pass _merge(fallback_url, fallback_model, fallback_headers) diff --git a/routes/session_routes.py b/routes/session_routes.py index 4dbacde0d..9aa94c11d 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -924,7 +924,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ from src.endpoint_resolver import resolve_endpoint from src.llm_core import llm_call_async - url, model, headers = resolve_endpoint("utility", owner=get_current_user(request)) + owner = getattr(session, "owner", None) or effective_user(request) + url, model, headers = resolve_endpoint("utility", owner=owner) if not url or not model: url, model, headers = session.endpoint_url, session.model, session.headers if not url or not model: diff --git a/routes/task_routes.py b/routes/task_routes.py index a31d12995..49210f5bc 100644 --- a/routes/task_routes.py +++ b/routes/task_routes.py @@ -291,20 +291,24 @@ def setup_task_routes(task_scheduler) -> APIRouter: def _owner(request: Request): return get_current_user(request) - async def _generate_task_name(prompt: str) -> str: + async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str: """Use LLM to generate a short task name from the prompt.""" try: from src.llm_core import llm_call_async from core.database import Session as DbSession db = SessionLocal() try: - recent = db.query(DbSession).filter( + q = db.query(DbSession).filter( DbSession.endpoint_url.isnot(None), DbSession.model.isnot(None), - ).order_by(DbSession.created_at.desc()).first() + ) + if owner: + q = q.filter(DbSession.owner == owner) + recent = q.order_by(DbSession.created_at.desc()).first() if not recent: return prompt[:50].strip() url, model = recent.endpoint_url, recent.model + headers = recent.headers or {} finally: db.close() @@ -315,6 +319,7 @@ def setup_task_routes(task_scheduler) -> APIRouter: {"role": "user", "content": prompt[:500]}, ], max_tokens=20, + headers=headers, timeout=15, ) title = result.strip().strip('"\'').strip() @@ -479,7 +484,7 @@ def setup_task_routes(task_scheduler) -> APIRouter: from src.builtin_actions import BUILTIN_ACTION_INFO name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task") elif req.prompt: - name = await _generate_task_name(req.prompt) + name = await _generate_task_name(req.prompt, owner=user) else: name = "Untitled Task" diff --git a/src/context_compactor.py b/src/context_compactor.py index 7da52425a..c87ea4c43 100644 --- a/src/context_compactor.py +++ b/src/context_compactor.py @@ -307,6 +307,7 @@ async def maybe_compact( model: str, messages: List[Dict], headers: Optional[Dict] = None, + owner: Optional[str] = None, ) -> tuple: """Check context usage and compact if above threshold. @@ -353,7 +354,7 @@ async def maybe_compact( ) # Use utility model if configured, otherwise fall back to session model - util_url, util_model, util_headers = resolve_endpoint("utility") + util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner) compact_url = util_url or endpoint_url compact_model = util_model or model compact_headers = util_headers if util_url else headers diff --git a/src/session_actions.py b/src/session_actions.py index 7f0944b2f..7376952d1 100644 --- a/src/session_actions.py +++ b/src/session_actions.py @@ -132,7 +132,7 @@ async def run_auto_sort(owner: str, skip_llm: bool = False, delete_throwaway: bo if skip_llm: return f"Cleaned {deleted_empty + deleted_throwaway} sessions (folder sort skipped)." - url, model, headers = resolve_task_endpoint() + url, model, headers = resolve_task_endpoint(owner=owner or None) if not url: return f"Cleaned {deleted_empty + deleted_throwaway} sessions. No model endpoint available for sorting." diff --git a/src/task_scheduler.py b/src/task_scheduler.py index 96b866720..5cc0e717a 100644 --- a/src/task_scheduler.py +++ b/src/task_scheduler.py @@ -1580,9 +1580,12 @@ class TaskScheduler: try: from core.database import SessionLocal, ModelEndpoint from src.endpoint_resolver import normalize_base, build_headers + from src.auth_helpers import owner_filter db2 = SessionLocal() try: - eps = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + ep_q = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + ep_q = owner_filter(ep_q, ModelEndpoint, task.owner or None) + eps = ep_q.all() for ep in eps: if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url): headers = build_headers(ep.api_key, normalize_base(ep.base_url)) @@ -1603,7 +1606,7 @@ class TaskScheduler: # chat uses but with the utility list (`utility_model_fallbacks`). try: from src.endpoint_resolver import resolve_utility_fallback_candidates - _task_fallbacks = resolve_utility_fallback_candidates() + _task_fallbacks = resolve_utility_fallback_candidates(owner=task.owner or None) except Exception: _task_fallbacks = [] async for event_str in stream_agent_loop( @@ -1646,7 +1649,7 @@ class TaskScheduler: else: grace_context += "No tool results were captured." grace_context += "\n\nSummarize what you accomplished and what's still pending. Be concise." - _grace_candidates = [(endpoint_url, model, headers)] + resolve_utility_fallback_candidates() + _grace_candidates = [(endpoint_url, model, headers)] + resolve_utility_fallback_candidates(owner=task.owner or None) full_text = await llm_call_async_with_fallback( _grace_candidates, messages=[ @@ -1674,6 +1677,8 @@ class TaskScheduler: # Resolve endpoint/model: research settings > task settings > session defaults endpoint_url = task.endpoint_url model = task.model + headers = {} + headers_from_resolver = False if not endpoint_url or not model: try: @@ -1683,9 +1688,13 @@ class TaskScheduler: endpoint_url or None, model or None, None, + owner=task.owner or None, ) endpoint_url = ep_url or endpoint_url model = ep_model or model + if ep_headers is not None: + headers = ep_headers + headers_from_resolver = True except Exception: pass @@ -1697,16 +1706,19 @@ class TaskScheduler: self._last_run_model = model # Resolve headers - headers = {} try: from core.database import ModelEndpoint from src.endpoint_resolver import normalize_base, build_headers + from src.auth_helpers import owner_filter db2 = db - eps = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() - for ep in eps: - if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url): - headers = build_headers(ep.api_key, normalize_base(ep.base_url)) - break + if not headers_from_resolver: + ep_q = db2.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + ep_q = owner_filter(ep_q, ModelEndpoint, task.owner or None) + eps = ep_q.all() + for ep in eps: + if normalize_base(ep.base_url) in endpoint_url or endpoint_url in normalize_base(ep.base_url): + headers = build_headers(ep.api_key, normalize_base(ep.base_url)) + break except Exception: pass diff --git a/tests/test_aux_llm_owner_scope.py b/tests/test_aux_llm_owner_scope.py new file mode 100644 index 000000000..233ae5695 --- /dev/null +++ b/tests/test_aux_llm_owner_scope.py @@ -0,0 +1,67 @@ +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] + + +def _src(path: str) -> str: + return (ROOT / path).read_text(encoding="utf-8") + + +def test_registered_manual_compaction_uses_session_owner_for_utility_endpoint(): + session_src = _src("routes/session_routes.py") + + assert 'owner = getattr(session, "owner", None) or effective_user(request)' in session_src + assert 'resolve_endpoint("utility", owner=owner)' in session_src + + +def test_task_name_generation_uses_owner_scoped_session_endpoint(): + src = _src("routes/task_routes.py") + + assert "async def _generate_task_name(prompt: str, owner: Optional[str] = None)" in src + assert "q = q.filter(DbSession.owner == owner)" in src + assert "headers = recent.headers or {}" in src + assert "headers=headers" in src + assert "await _generate_task_name(req.prompt, owner=user)" in src + + +def test_auto_compaction_utility_endpoint_keeps_chat_owner(): + helper_src = _src("routes/chat_helpers.py") + compact_src = _src("src/context_compactor.py") + + assert "owner=user" in helper_src + assert "owner: Optional[str] = None" in compact_src + assert 'resolve_endpoint("utility", owner=owner)' in compact_src + + +def test_background_session_sort_uses_owner_task_endpoint(): + src = _src("src/session_actions.py") + + assert "resolve_task_endpoint(owner=owner or None)" in src + + +def test_scheduler_fallbacks_and_research_headers_are_owner_scoped(): + src = _src("src/task_scheduler.py") + + assert "resolve_utility_fallback_candidates(owner=task.owner or None)" in src + assert 'resolve_endpoint(\n "research",' in src + assert "owner=task.owner or None" in src + assert "headers_from_resolver = False" in src + assert "headers_from_resolver = True" in src + assert "from src.auth_helpers import owner_filter" in src + assert "owner_filter(ep_q, ModelEndpoint, task.owner or None)" in src + + +def test_research_routes_fallbacks_are_owner_scoped(): + src = _src("routes/research_routes.py") + + assert 'resolve_endpoint("research", owner=user)' in src + assert 'resolve_endpoint("utility", owner=user)' in src + assert 'resolve_endpoint("default", owner=user)' in src + assert 'resolve_endpoint("chat", owner=user)' in src + assert '_merge(*resolve_endpoint("chat", owner=user))' in src + assert '_merge(*resolve_endpoint("research", owner=user))' in src + assert '_merge(*resolve_endpoint("utility", owner=user))' in src + assert "ep = _owned_enabled_endpoint(db, user)" in src + assert "db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()" not in src + assert "owner = getattr(sess, \"owner\", None) or None" in src diff --git a/tests/test_context_compactor.py b/tests/test_context_compactor.py index 393b4ac57..8b9da3972 100644 --- a/tests/test_context_compactor.py +++ b/tests/test_context_compactor.py @@ -133,7 +133,7 @@ class TestMaybeCompactFourthMessage: cc.get_context_length = lambda url, model: context_length cc.llm_call_async = _fake_summary - cc.resolve_endpoint = lambda which: (None, None, None) + cc.resolve_endpoint = lambda which, owner=None: (None, None, None) cc._update_session_history = lambda *a, **k: None try: return asyncio.run( diff --git a/tests/test_history_compact_tool_calls.py b/tests/test_history_compact_tool_calls.py index b2535d582..41dd3531d 100644 --- a/tests/test_history_compact_tool_calls.py +++ b/tests/test_history_compact_tool_calls.py @@ -79,6 +79,7 @@ class _FakeSession: endpoint_url = "http://example.test/v1" model = "test-model" headers = {} + owner = "session-owner" def __init__(self, history): self.history = history @@ -107,7 +108,11 @@ def _compact_prompt_for(monkeypatch, history): import src.model_context as model_context monkeypatch.setattr(agent_runs, "is_active", lambda session_id: False) - monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {})) + def fake_resolve_endpoint(kind, owner=None): + captured.setdefault("resolve_calls", []).append((kind, owner)) + return None, None, {} + + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", fake_resolve_endpoint) monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async) monkeypatch.setattr(model_context, "estimate_tokens", lambda messages: 100) monkeypatch.setattr(model_context, "get_context_length", lambda endpoint_url, model: 1000) @@ -146,7 +151,11 @@ def _registered_compact_response(monkeypatch, history, active_run=False): import src.llm_core as llm_core monkeypatch.setattr(agent_runs, "is_active", lambda session_id: active_run) - monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", lambda kind, owner=None: (None, None, {})) + def fake_resolve_endpoint(kind, owner=None): + captured.setdefault("resolve_calls", []).append((kind, owner)) + return None, None, {} + + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint", fake_resolve_endpoint) monkeypatch.setattr(llm_core, "llm_call_async", fake_llm_call_async) session = _FakeSession(history) @@ -212,6 +221,24 @@ def test_registered_manual_compact_route_tolerates_none_content(monkeypatch): assert manager.replaced_messages is not None +def test_registered_manual_compact_route_uses_session_owner(monkeypatch): + response, captured, manager = _registered_compact_response( + monkeypatch, + [ + ChatMessage(role="user", content="start"), + ChatMessage(role="assistant", content="tool call"), + ChatMessage(role="tool", content="tool result"), + ChatMessage(role="assistant", content="done"), + ChatMessage(role="user", content="next"), + ChatMessage(role="assistant", content="final"), + ], + ) + + assert response.status_code == 200 + assert manager.replaced_messages is not None + assert ("utility", "session-owner") in captured["resolve_calls"] + + def test_registered_manual_compact_route_rejects_active_agent_run(monkeypatch): response, captured, manager = _registered_compact_response( monkeypatch, diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py index 747867e63..cd451111b 100644 --- a/tests/test_review_regressions.py +++ b/tests/test_review_regressions.py @@ -365,7 +365,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess def fake_add_user_message(sess, chat_handler, preprocessed, incognito=False): sess.messages.append({"role": "user", "content": preprocessed.user_content}) - async def fake_maybe_compact(sess, endpoint_url, model, messages, headers): + async def fake_maybe_compact(sess, endpoint_url, model, messages, headers, owner=None): return messages, 123, False monkeypatch.setattr(chat_helpers, "preprocess", fake_preprocess)