diff --git a/routes/compare_routes.py b/routes/compare_routes.py index 35cd21289..ad42f1a89 100644 --- a/routes/compare_routes.py +++ b/routes/compare_routes.py @@ -12,6 +12,7 @@ import logging from core.database import Comparison, SessionLocal from core.session_manager import SessionManager from src.auth_helpers import get_current_user +from routes.session_routes import _reject_raw_endpoint_url_for_non_admin logger = logging.getLogger(__name__) @@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner): return owner_filter(q, ModelEndpoint, owner).first() +def _owned_endpoint_by_id(db, endpoint_id, owner): + """ModelEndpoint whose id == `endpoint_id` and is VISIBLE to `owner` (their + own rows + legacy null-owner "shared" rows); None otherwise. + + Preferred over _owned_endpoint_by_url for credential resolution: two visible + endpoints can share the same base_url but hold DIFFERENT api_keys (e.g. two + accounts on the same provider). A base_url-only match returns whichever row + sorts first, so it can copy the WRONG owner-scoped key into the [CMP] session. + An id pins the exact registered endpoint, so /api/compare/start prefers it and + only falls back to URL matching for legacy / admin raw-URL callers. Owner + scoping is identical to _owned_endpoint_by_url (a null/empty owner is a no-op). + """ + from core.database import ModelEndpoint + from src.auth_helpers import owner_filter + q = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id) + return owner_filter(q, ModelEndpoint, owner).first() + + class RecordVoteRequest(BaseModel): prompt: str models: List[str] @@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager): prompt: str = Form(...), model_a: str = Form(...), model_b: str = Form(...), - endpoint_a: str = Form(...), - endpoint_b: str = Form(...), + endpoint_a: str = Form(""), + endpoint_b: str = Form(""), + endpoint_a_id: str = Form(""), + endpoint_b_id: str = Form(""), is_blind: str = Form("true"), ): """Create two ephemeral sessions and a comparison record. @@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager): Returns the comparison ID and the two session IDs so the client can fire two independent SSE streams to /api/chat_stream. """ + user = getattr(request.state, 'current_user', None) comp_id = str(uuid.uuid4()) sid_a = str(uuid.uuid4()) sid_b = str(uuid.uuid4()) - user = getattr(request.state, 'current_user', None) # Blind mapping: randomly assign left/right blind = str(is_blind).lower() == "true" @@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager): # de-anonymizing the comparison before the user votes (issue #1285). slot_name = {session_left: "Model A", session_right: "Model B"} - # Create ephemeral sessions (prefixed [CMP]) - for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]: + # SECURITY: resolve and validate BOTH endpoints before creating any + # session. Compare copies a registered endpoint's Authorization header + # into the [CMP] session, so validating one endpoint while creating its + # session, then rejecting the other, would leave a partial compare + # session behind with that header attached. Doing all the owner-scope + # resolution + raw-URL rejection up front means a 403 on either endpoint + # aborts the whole request with nothing created and no header copied. + from src.endpoint_resolver import build_chat_url, build_headers, normalize_base + resolved = [] + db = SessionLocal() + try: + for sid, model, endpoint, endpoint_id in [ + (sid_a, model_a, endpoint_a, endpoint_a_id), + (sid_b, model_b, endpoint_b, endpoint_b_id), + ]: + # Prefer an explicit endpoint id: it pins the EXACT registered + # endpoint (and its api_key), even when two endpoints visible to + # the caller share a base_url with different keys — a URL-only + # match would copy whichever row sorts first, i.e. possibly the + # wrong key. Fall back to URL resolution only for legacy / admin + # raw-URL callers that don't send an id. + eid = endpoint_id.strip() if isinstance(endpoint_id, str) else "" + if eid: + ep = _owned_endpoint_by_id(db, eid, user) + if ep is None: + # An id the caller can't see (wrong owner / deleted) must + # NOT silently fall back to a same-URL row with a different + # key — that's exactly the mix-up ids exist to prevent. + raise HTTPException(404, "Model endpoint not found") + # The id already resolved the endpoint; ignore any raw URL the + # caller also sent and dial the stored config instead. + endpoint = ep.base_url + elif not endpoint: + raise HTTPException( + 422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required" + ) + else: + # Resolve the supplied URL to a ModelEndpoint the caller owns + # (their own rows + legacy null-owner shared rows), scoped so a + # comparison can't borrow another user's private endpoint key. + base = normalize_base(endpoint) + ep = _owned_endpoint_by_url(db, base, user) + # Reject *unregistered* raw URLs for signed-in non-admins; a + # matched registered endpoint supplies an id so the caller can + # still compare endpoints they own. Blanket-rejecting here (the + # earlier `endpoint_id=None` call) locked non-admins out of + # compare entirely, since compare resolves endpoints by URL with + # no endpoint_id. Mirrors the gallery inpaint/harmonize checks. + # Raised here (phase 1), before any session exists. + _reject_raw_endpoint_url_for_non_admin( + request, user, str(ep.id) if ep is not None else None, endpoint + ) + # Bind the [CMP] session to the RESOLVED endpoint, not the raw + # caller-supplied string. When the URL matches a registered + # endpoint visible to the caller, use that row's own normalized + # base URL (the same value owner scoping + endpoint validation + # already vetted) so the session dials exactly where the stored + # config points. The raw `endpoint` only survives for callers + # allowed to pass one — admins / single-user mode, where + # `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep` + # is None. Mirrors the registered-endpoint path in session_routes. + session_endpoint_url = ( + build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint + ) + # Headers come only from a matched endpoint's key; None when + # `ep` is None (raw admin URL or no match), so a comparison can + # never inherit another user's key/headers. + headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None + resolved.append((sid, model, session_endpoint_url, headers)) + finally: + db.close() + + # Both endpoints validated — only now create the ephemeral [CMP] + # sessions and copy any resolved headers. + for sid, model, session_endpoint_url, headers in resolved: name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}" session_manager.create_session( session_id=sid, name=name, - endpoint_url=endpoint, + endpoint_url=session_endpoint_url, model=model, rag=False, owner=user, ) - # Copy API key from endpoint config - db = SessionLocal() - try: - from src.endpoint_resolver import build_headers, normalize_base - # Find matching endpoint by URL, scoped to the caller so a - # comparison can't borrow another user's private endpoint key. - base = normalize_base(endpoint) - ep = _owned_endpoint_by_url(db, base, user) - if ep and ep.api_key: - s = session_manager.sessions.get(sid) - if s: - s.headers = build_headers(ep.api_key, ep.base_url) - finally: - db.close() + if headers: + s = session_manager.sessions.get(sid) + if s: + s.headers = headers # Store comparison record db = SessionLocal() @@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager): prompt=prompt, model_a=model_a, model_b=model_b, - endpoint_a=endpoint_a, - endpoint_b=endpoint_b, + # Record the URL the session actually dials. For URL callers this + # is their raw input; for id-only callers (empty endpoint_a/_b) + # fall back to the resolved endpoint URL so the column stays + # meaningful and non-null. resolved is in [a, b] order. + endpoint_a=endpoint_a or resolved[0][2], + endpoint_b=endpoint_b or resolved[1][2], is_blind=blind, blind_mapping=json.dumps(mapping), owner=user, diff --git a/routes/gallery_routes.py b/routes/gallery_routes.py index 6f3427eed..ed598f031 100644 --- a/routes/gallery_routes.py +++ b/routes/gallery_routes.py @@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Query, Request from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint from core.database import Session as DbSession -from src.auth_helpers import get_current_user, require_privilege +from src.auth_helpers import get_current_user, owner_filter, require_privilege from src.upload_limits import read_upload_limited from src.constants import GENERATED_IMAGES_DIR @@ -26,6 +26,19 @@ GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", st GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))) +def _current_user_is_admin(request: Request, user: str | None) -> bool: + if not user: + return False + auth_mgr = getattr(request.app.state, "auth_manager", None) + is_admin = getattr(auth_mgr, "is_admin", None) + if not callable(is_admin): + return False + try: + return bool(is_admin(user)) + except Exception: + return False + + def _sanitize_gallery_filename(filename: str) -> str: """Return a local filename safe to join under generated_images.""" safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128] @@ -1043,7 +1056,10 @@ def setup_gallery_routes() -> APIRouter: try: ep = _visible_image_endpoint_for_base(db, _target, user) if ep: + base = (ep.base_url or base).rstrip("/") api_key = ep.api_key + elif user and not _current_user_is_admin(request, user): + raise HTTPException(403, "Choose a registered image endpoint") finally: db.close() @@ -1234,7 +1250,10 @@ def setup_gallery_routes() -> APIRouter: try: ep = _visible_image_endpoint_for_base(db, base, user) if ep: + base = (ep.base_url or base).rstrip("/") api_key = ep.api_key + elif user and not _current_user_is_admin(request, user): + raise HTTPException(403, "Choose a registered image endpoint") finally: db.close() diff --git a/routes/research_routes.py b/routes/research_routes.py index ea9d207a3..1ef36bd75 100644 --- a/routes/research_routes.py +++ b/routes/research_routes.py @@ -38,9 +38,9 @@ def _first_chat_model(models) -> str: return (models[0] if models else "") -def _resolve_research_endpoint(sess) -> tuple: +def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple: """Return (endpoint_url, model, headers) for Deep Research, checking admin overrides.""" - owner = getattr(sess, "owner", None) or None + owner = owner or getattr(sess, "owner", None) or None url, model, headers = resolve_endpoint( "research", fallback_url=sess.endpoint_url, diff --git a/tests/test_aux_llm_owner_scope.py b/tests/test_aux_llm_owner_scope.py index 233ae5695..534a2e429 100644 --- a/tests/test_aux_llm_owner_scope.py +++ b/tests/test_aux_llm_owner_scope.py @@ -64,4 +64,8 @@ def test_research_routes_fallbacks_are_owner_scoped(): assert '_merge(*resolve_endpoint("utility", owner=user))' in src assert "ep = _owned_enabled_endpoint(db, user)" in src assert "db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()" not in src - assert "owner = getattr(sess, \"owner\", None) or None" in src + # _resolve_research_endpoint derives the scope from the session owner. The + # rebased code generalized this to honor an explicit `owner` argument first + # (``owner = owner or getattr(sess, "owner", None) or None``), so assert on + # the stable session-derivation substring rather than the exact line. + assert 'getattr(sess, "owner", None) or None' in src diff --git a/tests/test_endpoint_owner_scope_followup.py b/tests/test_endpoint_owner_scope_followup.py new file mode 100644 index 000000000..2d630d506 --- /dev/null +++ b/tests/test_endpoint_owner_scope_followup.py @@ -0,0 +1,414 @@ +"""Regression tests for endpoint owner scoping in secondary model routes.""" + +from pathlib import Path +from types import SimpleNamespace + +import pytest +from fastapi import HTTPException + + +def _compare_request(user="alice", is_admin=False): + return SimpleNamespace( + state=SimpleNamespace(current_user=user), + app=SimpleNamespace( + state=SimpleNamespace( + auth_manager=SimpleNamespace(is_admin=lambda u: is_admin) + ) + ), + ) + + +def _compare_start_route(session_manager): + from routes.compare_routes import setup_compare_routes + + router = setup_compare_routes(session_manager) + # setup_compare_routes registers on a module-global router, so each call + # appends another /start route; take the most recently registered one so we + # get the handler bound to *this* session_manager. + return [ + r.endpoint for r in router.routes + if getattr(r, "path", "") == "/api/compare/start" + ][-1] + + +class _FakeDB: + """The endpoint lookup is patched, so only the trailing Comparison insert + touches this — swallow add/commit/close so the test never hits a real DB.""" + + def add(self, *a, **k): + pass + + def commit(self): + pass + + def close(self): + pass + + +class _SessionStore: + def __init__(self, store): + self._store = store + + def get(self, key, default=None): + return self._store.get(key, default) + + +def test_compare_start_rejects_unregistered_endpoint_for_non_admin(monkeypatch): + import routes.compare_routes as cr + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + # Nothing visible to the caller matches the supplied URL → raw, unregistered. + monkeypatch.setattr(cr, "_owned_endpoint_by_url", lambda *a, **k: None) + + start = _compare_start_route( + SimpleNamespace(create_session=lambda **_: None, sessions={}) + ) + with pytest.raises(HTTPException) as exc: + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="http://127.0.0.1:8000/v1", + endpoint_b="http://127.0.0.1:8001/v1", + ) + + assert exc.value.status_code == 403 + + +def test_compare_start_allows_owned_registered_endpoint_for_non_admin(monkeypatch): + # Regression: the followup must not blanket-reject non-admins. Compare + # resolves endpoints by URL (no endpoint_id), so a caller comparing a + # registered endpoint they own has to be allowed — only truly raw, + # unregistered URLs are rejected. + import routes.compare_routes as cr + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + owned = SimpleNamespace(id=7, api_key="sk-secret", base_url="http://127.0.0.1:8000/v1") + monkeypatch.setattr(cr, "_owned_endpoint_by_url", lambda *a, **k: owned) + + created = {} + + def _create_session(session_id, **_): + created[session_id] = SimpleNamespace(headers={}) + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + # Must complete without raising 403. + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="http://127.0.0.1:8000/v1", + endpoint_b="http://127.0.0.1:8000/v1", + ) + + # Both [CMP] sessions created, each with the owned endpoint's key copied in. + assert len(created) == 2 + for s in created.values(): + assert s.headers + + +def test_compare_start_rejects_another_users_private_endpoint(monkeypatch): + # bob owns the endpoint at this URL; alice supplying the same URL gets no + # match from the owner-scoped lookup (owner_filter drops bob's private row), + # so compare treats it exactly like a raw unregistered URL → 403. She can + # neither bind a session to his endpoint nor copy his key. + import routes.compare_routes as cr + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + + def _scoped(db, base, owner): + # Only the owner ("bob") can see this private row; everyone else → None. + if owner == "bob": + return SimpleNamespace(id=9, api_key="sk-bob", base_url=base) + return None + + monkeypatch.setattr(cr, "_owned_endpoint_by_url", _scoped) + + created = {} + + def _create_session(session_id, **_): + created[session_id] = SimpleNamespace(headers={}) + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + with pytest.raises(HTTPException) as exc: + start( + _compare_request(user="alice"), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="http://10.0.0.5:9000/v1", + endpoint_b="http://10.0.0.5:9000/v1", + ) + + assert exc.value.status_code == 403 + # Nothing was created → no session bound to bob's endpoint, no key copied. + assert created == {} + + +def test_compare_start_rejects_before_creating_any_session_on_mixed_endpoints(monkeypatch): + # Mixed request: endpoint A is a registered endpoint the caller owns, + # endpoint B is a raw/unregistered URL. Both endpoints are resolved and + # validated up front, so the unregistered B makes the WHOLE request 403 with + # nothing created — no half-built [CMP] session for A, and therefore none of + # A's Authorization header left behind. Fails on the old interleaved loop + # that created A's session before reaching (and rejecting) B. + import routes.compare_routes as cr + from src.endpoint_resolver import normalize_base + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + owned = SimpleNamespace(id=7, api_key="sk-secret", base_url="http://127.0.0.1:8000/v1") + owned_base = normalize_base(owned.base_url) + + def _scoped(db, base, owner): + # Only endpoint A's URL maps to a visible registered endpoint; B → None. + return owned if base == owned_base else None + + monkeypatch.setattr(cr, "_owned_endpoint_by_url", _scoped) + + created = {} + + def _create_session(session_id, **kw): + created[session_id] = SimpleNamespace(headers={}) + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + with pytest.raises(HTTPException) as exc: + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="http://127.0.0.1:8000/v1", # owned, registered + endpoint_b="http://203.0.113.9:9999/v1", # raw, unregistered + ) + + assert exc.value.status_code == 403 + # No partial session survives the reject, so no copied header does either. + assert created == {} + + +def test_compare_start_binds_session_to_registered_endpoint_url(monkeypatch): + # The session must dial the registered endpoint's OWN normalized base URL, + # never the raw caller-supplied string. Mint the owned row with a base URL + # that differs from the messy raw input so a regression to `endpoint_url= + # endpoint` would surface here. + import routes.compare_routes as cr + from src.endpoint_resolver import build_chat_url, normalize_base + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + owned = SimpleNamespace(id=7, api_key="sk-secret", base_url="http://127.0.0.1:8000/v1") + monkeypatch.setattr(cr, "_owned_endpoint_by_url", lambda *a, **k: owned) + + created = {} + captured = {} + + def _create_session(session_id, **kw): + created[session_id] = SimpleNamespace(headers={}) + captured[session_id] = kw + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + raw_url = "http://127.0.0.1:8000/v1/" # trailing slash → not byte-identical + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a=raw_url, + endpoint_b=raw_url, + ) + + expected = build_chat_url(normalize_base(owned.base_url)) + assert captured and all(kw["endpoint_url"] == expected for kw in captured.values()) + # The owned endpoint's key is copied into each session's headers. + for s in created.values(): + assert s.headers + + +def test_compare_start_admin_raw_endpoint_carries_no_borrowed_key(monkeypatch): + # Explicit admin/raw-endpoint behavior: an admin may pass a raw URL that + # matches no registered endpoint. It is allowed (the reject helper is a + # no-op for admins), the session keeps the raw URL, and — because nothing + # matched — no key/headers are inherited from any endpoint row. + import routes.compare_routes as cr + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + monkeypatch.setattr(cr, "_owned_endpoint_by_url", lambda *a, **k: None) + + created = {} + captured = {} + + def _create_session(session_id, **kw): + created[session_id] = SimpleNamespace(headers={}) + captured[session_id] = kw + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + raw_url = "http://198.51.100.7:1234/v1" + start( + _compare_request(user="root", is_admin=True), + prompt="p", + model_a="a", + model_b="b", + endpoint_a=raw_url, + endpoint_b=raw_url, + ) + + assert len(created) == 2 + for kw in captured.values(): + assert kw["endpoint_url"] == raw_url # raw URL preserved for admins + for s in created.values(): + assert s.headers == {} # no borrowed key/headers + + +def test_compare_start_prefers_endpoint_id_over_url(monkeypatch): + # Two endpoints visible to the caller share a base_url but hold DIFFERENT + # api_keys (e.g. two accounts on one provider). A base_url-only match returns + # whichever row sorts first, so it can copy the WRONG key. Passing the + # explicit id must pin the intended endpoint and copy ITS key. + import routes.compare_routes as cr + from src.endpoint_resolver import build_chat_url, build_headers, normalize_base + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + + url = "http://127.0.0.1:8000/v1" + by_url = SimpleNamespace(id=1, api_key="sk-first", base_url=url) # URL match + by_id = SimpleNamespace(id=2, api_key="sk-second", base_url=url) # id match + + # URL resolution would return the WRONG row; the id resolves the intended one. + monkeypatch.setattr(cr, "_owned_endpoint_by_url", lambda *a, **k: by_url) + monkeypatch.setattr( + cr, "_owned_endpoint_by_id", lambda db, eid, owner: by_id if eid == "2" else None + ) + + created = {} + captured = {} + + def _create_session(session_id, **kw): + created[session_id] = SimpleNamespace(headers={}) + captured[session_id] = kw + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="", + endpoint_b="", + endpoint_a_id="2", + endpoint_b_id="2", + ) + + expected_url = build_chat_url(normalize_base(url)) + expected_headers = build_headers("sk-second", url) + assert captured and all(kw["endpoint_url"] == expected_url for kw in captured.values()) + # The id's key is copied in, NOT the same-URL row's key. + for s in created.values(): + assert s.headers == expected_headers + + +def test_compare_start_rejects_unowned_endpoint_id(monkeypatch): + # An id the caller can't see (wrong owner / deleted) must 404 and must NOT + # silently fall back to a same-URL row with a different key. + import routes.compare_routes as cr + + monkeypatch.setattr(cr, "SessionLocal", lambda: _FakeDB()) + # A same-URL row exists and would resolve, but the governing id is invisible. + monkeypatch.setattr( + cr, + "_owned_endpoint_by_url", + lambda *a, **k: SimpleNamespace(id=1, api_key="sk", base_url="http://127.0.0.1:8000/v1"), + ) + monkeypatch.setattr(cr, "_owned_endpoint_by_id", lambda *a, **k: None) + + created = {} + + def _create_session(session_id, **_): + created[session_id] = SimpleNamespace(headers={}) + + start = _compare_start_route( + SimpleNamespace(create_session=_create_session, sessions=_SessionStore(created)) + ) + with pytest.raises(HTTPException) as exc: + start( + _compare_request(), + prompt="p", + model_a="a", + model_b="b", + endpoint_a="", + endpoint_b="", + endpoint_a_id="999", + endpoint_b_id="999", + ) + + assert exc.value.status_code == 404 + assert created == {} + + +def test_compare_endpoint_key_lookup_is_owner_scoped(): + body = Path("routes/compare_routes.py").read_text(encoding="utf-8") + start_body = body.split("def start_comparison", 1)[1].split("# Store comparison record", 1)[0] + helper_body = body.split("def _owned_endpoint_by_url", 1)[1].split("class RecordVoteRequest", 1)[0] + id_helper_body = body.split("def _owned_endpoint_by_id", 1)[1].split("class RecordVoteRequest", 1)[0] + + assert "_reject_raw_endpoint_url_for_non_admin" in start_body + assert "_owned_endpoint_by_url(db, base, user)" in start_body + # Credentials prefer an explicit endpoint id (pins the exact key) and only + # fall back to URL matching for legacy / admin raw-URL callers. + assert "_owned_endpoint_by_id(db, eid, user)" in start_body + # The session binds to the resolved endpoint's stored base URL, not the raw + # caller-supplied string (the reviewer's remaining compare blocker). + assert "build_chat_url(normalize_base(ep.base_url))" in start_body + assert "owner_filter(q, ModelEndpoint, owner)" in helper_body + # The id lookup is owner-scoped the same way the URL lookup is. + assert "owner_filter(q, ModelEndpoint, owner)" in id_helper_body + + +def test_gallery_image_endpoint_lookups_are_owner_scoped(): + body = Path("routes/gallery_routes.py").read_text(encoding="utf-8") + helper_body = body.split("def _visible_image_endpoint_query", 1)[1].split( + "def _first_visible_image_endpoint", 1 + )[0] + + assert "owner_filter(q, ModelEndpoint, owner)" in helper_body + assert body.count("_first_visible_image_endpoint(db, user)") >= 4 + assert body.count("_visible_image_endpoint_for_base(db,") >= 2 + assert "def _current_user_is_admin" in body + assert body.count('raise HTTPException(403, "Choose a registered image endpoint")') == 2 + for marker in ( + "async def gallery_ai_upscale", + "async def gallery_style_transfer", + "async def inpaint_proxy", + "async def harmonize_image", + ): + section = body.split(marker, 1)[1].split("@router.", 1)[0] + assert "user = require_privilege(request, \"can_generate_images\")" in section + assert ( + "_first_visible_image_endpoint(db, user)" in section + or "_visible_image_endpoint_for_base(db," in section + ) + + +def test_research_endpoint_resolution_passes_owner(): + body = Path("routes/research_routes.py").read_text(encoding="utf-8") + + assert "def _resolve_research_endpoint(sess, owner:" in body + assert 'resolve_endpoint("research", owner=user)' in body + assert 'resolve_endpoint("utility", owner=user)' in body + assert 'resolve_endpoint("default", owner=user)' in body + assert 'resolve_endpoint("chat", owner=user)' in body + helper_body = body.split("def _owned_enabled_endpoint", 1)[1].split("def setup_research_routes", 1)[0] + assert "owner_filter(q, ModelEndpoint, owner)" in helper_body + assert body.count("_owned_enabled_endpoint(db, user") >= 2 diff --git a/tests/test_gallery_image_privileges.py b/tests/test_gallery_image_privileges.py index 2fe21c385..9be5383ab 100644 --- a/tests/test_gallery_image_privileges.py +++ b/tests/test_gallery_image_privileges.py @@ -37,4 +37,6 @@ def test_image_generation_endpoints_require_image_privilege(): def test_gallery_routes_imports_privilege_helper(): - assert "from src.auth_helpers import get_current_user, require_privilege" in _gallery_source() + source = _gallery_source() + assert "get_current_user" in source + assert "require_privilege" in source