diff --git a/routes/preset_routes.py b/routes/preset_routes.py index 4f6814fb6..20c6c830a 100644 --- a/routes/preset_routes.py +++ b/routes/preset_routes.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from src.request_models import PresetUpdateRequest from core.middleware import require_admin +from src.auth_helpers import effective_user logger = logging.getLogger(__name__) @@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter: try: model_spec = data.get("model") or "" - url, model, headers = _resolve_model(model_spec) + user = effective_user(request) + url, model, headers = _resolve_model(model_spec, owner=user) result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers) return {"success": True, "prompt": result.strip()} except Exception as e: diff --git a/routes/skills_routes.py b/routes/skills_routes.py index 8a7c5c269..3d6ede921 100644 --- a/routes/skills_routes.py +++ b/routes/skills_routes.py @@ -1020,7 +1020,7 @@ def _resolve_audit_models(owner=None): spec = (get_setting("teacher_model", "") or "").strip() if spec: from src.ai_interaction import _resolve_model - t_url, t_model, t_headers = _resolve_model(spec) + t_url, t_model, t_headers = _resolve_model(spec, owner=owner) if t_url and t_model: teacher = (t_url, t_model, t_headers) except Exception as e: diff --git a/src/teacher_escalation.py b/src/teacher_escalation.py index e830ce17f..94d9ee81c 100644 --- a/src/teacher_escalation.py +++ b/src/teacher_escalation.py @@ -229,12 +229,13 @@ portable across users / hosts. """ -async def _call_teacher(teacher_model_spec: str, prompt: str) -> Optional[str]: +async def _call_teacher(teacher_model_spec: str, prompt: str, + owner: Optional[str] = None) -> Optional[str]: """Call the configured teacher endpoint with the escalation prompt.""" from src.llm_core import llm_call_async from src.ai_interaction import _resolve_model, _TEACHER_SYSTEM_PROMPT try: - url, model, headers = _resolve_model(teacher_model_spec) + url, model, headers = _resolve_model(teacher_model_spec, owner=owner) except Exception as e: logger.warning(f"teacher endpoint not resolvable ({teacher_model_spec!r}): {e}") return None @@ -388,7 +389,7 @@ async def escalate_and_learn( untrusted_trace_guard=_UNTRUSTED_TRACE_GUARD, trace=_format_trace(tool_results, agent_reply), ) - response = await _call_teacher(teacher_spec, prompt) + response = await _call_teacher(teacher_spec, prompt, owner=owner) if not response: return None @@ -523,7 +524,7 @@ async def run_teacher_inline( # Resolve teacher endpoint try: from src.ai_interaction import _resolve_model - teacher_url, teacher_model, teacher_headers = _resolve_model(teacher_spec) + teacher_url, teacher_model, teacher_headers = _resolve_model(teacher_spec, owner=owner) except Exception as e: logger.warning(f"teacher endpoint not resolvable ({teacher_spec!r}): {e}") yield ( @@ -617,7 +618,7 @@ async def run_teacher_inline( untrusted_trace_guard=_UNTRUSTED_TRACE_GUARD, trace=_format_trace(captured_tool_events, teacher_text), ) - skill_response = await _call_teacher(teacher_spec, prompt) + skill_response = await _call_teacher(teacher_spec, prompt, owner=owner) if skill_response and "NO_SKILL" in skill_response and not _extract_skill_json(skill_response): logger.info("teacher declined to write a skill (NO_SKILL)") yield ( diff --git a/tests/test_preset_expand_owner_scope.py b/tests/test_preset_expand_owner_scope.py new file mode 100644 index 000000000..4fc3e1123 --- /dev/null +++ b/tests/test_preset_expand_owner_scope.py @@ -0,0 +1,86 @@ +"""Route-level owner-scope test for POST /api/presets/expand. + +`expand_character_prompt` resolves a model endpoint to run its LLM call. It must +scope that lookup to the calling user, otherwise it can resolve another owner's +ModelEndpoint (and its decrypted api_key) in a multi-user deployment. See #2283. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import MagicMock + +from routes.preset_routes import setup_preset_routes + + +class _FakeRequest: + """Minimal stand-in: an async ``json()`` body plus a ``state`` namespace.""" + + def __init__(self, body, **state): + self._body = body + self.state = SimpleNamespace(**state) + + async def json(self): + return self._body + + +def _expand_endpoint(): + router = setup_preset_routes(MagicMock()) + for route in router.routes: + if getattr(route, "path", "") == "/api/presets/expand" and "POST" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("POST /api/presets/expand route not registered") + + +def _patch_model_pipeline(monkeypatch): + """Capture the owner passed to _resolve_model and stub the LLM call.""" + seen = {} + + def fake_resolve_model(spec, owner=None): + seen["spec"] = spec + seen["owner"] = owner + return ("http://endpoint.local/v1", "test-model", {}) + + async def fake_llm_call_async(url, model, messages, **kwargs): + return " expanded prompt " + + monkeypatch.setattr("src.ai_interaction._resolve_model", fake_resolve_model) + monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async) + return seen + + +def test_expand_scopes_model_resolution_to_cookie_user(monkeypatch): + seen = _patch_model_pipeline(monkeypatch) + endpoint = _expand_endpoint() + + req = _FakeRequest({"name": "Pirate", "prompt": "talks like a pirate", "model": "test-model"}, + current_user="alice") + result = asyncio.run(endpoint(req)) + + assert seen["owner"] == "alice" + assert seen["spec"] == "test-model" + assert result == {"success": True, "prompt": "expanded prompt"} + + +def test_expand_attributes_bearer_token_to_its_owner(monkeypatch): + # effective_user (not get_current_user) resolves a bearer ody_ caller to the + # token's real owner instead of the sandbox "api" pseudo-user. + seen = _patch_model_pipeline(monkeypatch) + endpoint = _expand_endpoint() + + req = _FakeRequest({"name": "Pirate", "model": ""}, + current_user="api", api_token=True, api_token_owner="bob") + asyncio.run(endpoint(req)) + + assert seen["owner"] == "bob" + + +def test_expand_short_circuits_without_input(monkeypatch): + seen = _patch_model_pipeline(monkeypatch) + endpoint = _expand_endpoint() + + req = _FakeRequest({}, current_user="alice") + result = asyncio.run(endpoint(req)) + + # Nothing to expand: no model resolution attempted. + assert result["success"] is False + assert "owner" not in seen diff --git a/tests/test_teacher_audit_owner_scope.py b/tests/test_teacher_audit_owner_scope.py new file mode 100644 index 000000000..5bd6228d9 --- /dev/null +++ b/tests/test_teacher_audit_owner_scope.py @@ -0,0 +1,64 @@ +"""Owner-scope tests for the remaining _resolve_model call sites. + +Both the teacher-escalation path and the skill-audit teacher resolution map a +model spec to an endpoint (and its decrypted api_key). Like /presets/expand, +that lookup must be scoped to the calling user, otherwise it can resolve another +owner's ModelEndpoint in a multi-user deployment. See #2283. +""" + +import asyncio + +import src.teacher_escalation as teacher_escalation +import routes.skills_routes as skills_routes + + +def test_call_teacher_scopes_model_resolution_to_owner(monkeypatch): + seen = {} + + def fake_resolve_model(spec, owner=None): + seen["spec"] = spec + seen["owner"] = owner + return ("http://endpoint.local/v1", "teacher-model", {}) + + async def fake_llm_call_async(url, model, messages, **kwargs): + return "teacher reply" + + monkeypatch.setattr("src.ai_interaction._resolve_model", fake_resolve_model) + monkeypatch.setattr("src.ai_interaction._TEACHER_SYSTEM_PROMPT", "sys", raising=False) + monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async) + + result = asyncio.run( + teacher_escalation._call_teacher("teacher-model", "prompt", owner="alice") + ) + + assert result == "teacher reply" + assert seen["owner"] == "alice" + assert seen["spec"] == "teacher-model" + + +def test_audit_teacher_resolution_scoped_to_owner(monkeypatch): + seen = {} + + def fake_resolve_endpoint(role, owner=None): + return ("http://worker.local/v1", "worker-model", {}) + + def fake_get_setting(key, default=None): + return {"teacher_enabled": True, "teacher_model": "teacher-model"}.get(key, default) + + def fake_resolve_model(spec, owner=None): + seen["spec"] = spec + seen["owner"] = owner + return ("http://endpoint.local/v1", "teacher-model", {}) + + monkeypatch.setattr("src.endpoint_resolver.resolve_endpoint", fake_resolve_endpoint) + monkeypatch.setattr("src.settings.get_setting", fake_get_setting) + monkeypatch.setattr("src.ai_interaction._resolve_model", fake_resolve_model) + # list_model_ids is best-effort; force it to no-op so the worker model passes through. + monkeypatch.setattr("src.llm_core.list_model_ids", lambda url, headers=None: []) + + url, model, headers, teacher = skills_routes._resolve_audit_models(owner="alice") + + assert (url, model) == ("http://worker.local/v1", "worker-model") + assert teacher == ("http://endpoint.local/v1", "teacher-model", {}) + assert seen["owner"] == "alice" + assert seen["spec"] == "teacher-model"