diff --git a/routes/gallery_routes.py b/routes/gallery_routes.py index 13d10179d..8bc5438c5 100644 --- a/routes/gallery_routes.py +++ b/routes/gallery_routes.py @@ -53,6 +53,46 @@ def _gallery_image_path(filename: str) -> Path: raise HTTPException(400, "Unsafe gallery filename") return path + +def _normalize_image_endpoint_base(url: str) -> str: + base = (url or "").strip().rstrip("/") + if base.endswith("/v1"): + base = base[:-3].rstrip("/") + return base + + +def _visible_image_endpoint_query(db, owner: str | None): + from src.auth_helpers import owner_filter + q = db.query(ModelEndpoint).filter( + ModelEndpoint.model_type == "image", + ModelEndpoint.is_enabled == True, # noqa: E712 + ) + return owner_filter(q, ModelEndpoint, owner) + + +def _first_visible_image_endpoint(db, owner: str | None): + endpoints = _visible_image_endpoint_query(db, owner).all() + if owner: + for ep in endpoints: + if getattr(ep, "owner", None) == owner: + return ep + return endpoints[0] if endpoints else None + + +def _visible_image_endpoint_for_base(db, base: str, owner: str | None): + target = _normalize_image_endpoint_base(base) + if not target: + return None + fallback = None + for ep in _visible_image_endpoint_query(db, owner).all(): + if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == target: + if owner and getattr(ep, "owner", None) == owner: + return ep + if fallback is None: + fallback = ep + return fallback + + def setup_gallery_routes() -> APIRouter: router = APIRouter(tags=["gallery"]) @@ -272,7 +312,7 @@ def setup_gallery_routes() -> APIRouter: """AI upscale using img2img with the diffusion server.""" import base64, httpx - require_privilege(request, "can_generate_images") + user = require_privilege(request, "can_generate_images") form = await request.form() file = form.get("image") if not file: raise HTTPException(400, "No image") @@ -284,7 +324,7 @@ def setup_gallery_routes() -> APIRouter: # Find image endpoint db = SessionLocal() try: - ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first() + ep = _first_visible_image_endpoint(db, user) finally: db.close() @@ -315,7 +355,7 @@ def setup_gallery_routes() -> APIRouter: """Style transfer using img2img with the diffusion server.""" import base64, httpx - require_privilege(request, "can_generate_images") + user = require_privilege(request, "can_generate_images") form = await request.form() file = form.get("image") prompt = form.get("prompt", "") @@ -327,7 +367,7 @@ def setup_gallery_routes() -> APIRouter: db = SessionLocal() try: - ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first() + ep = _first_visible_image_endpoint(db, user) finally: db.close() @@ -960,7 +1000,7 @@ def setup_gallery_routes() -> APIRouter: the request for /v1/images/edits (multipart, inverted mask). Otherwise proxy through to a self-hosted diffusion server's /v1/images/inpaint.""" import httpx - require_privilege(request, "can_generate_images") + user = require_privilege(request, "can_generate_images") body = await request.json() # Use endpoint from request body (editor dropdown) or fall back to DB lookup base = (body.pop("_endpoint", "") or "").rstrip("/") @@ -979,14 +1019,11 @@ def setup_gallery_routes() -> APIRouter: if not base: db = SessionLocal() try: - eps = db.query(ModelEndpoint).filter( - ModelEndpoint.is_enabled == True, - ModelEndpoint.model_type == "image", - ).all() - if not eps: + ep = _first_visible_image_endpoint(db, user) + if not ep: raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.") - base = eps[0].base_url.rstrip("/") - api_key = eps[0].api_key + base = ep.base_url.rstrip("/") + api_key = ep.api_key finally: db.close() else: @@ -1003,10 +1040,9 @@ def setup_gallery_routes() -> APIRouter: _target = _norm_url(base) db = SessionLocal() try: - for ep in db.query(ModelEndpoint).all(): - if _norm_url(ep.base_url) == _target: - api_key = ep.api_key - break + ep = _visible_image_endpoint_for_base(db, _target, user) + if ep: + api_key = ep.api_key finally: db.close() @@ -1158,7 +1194,7 @@ def setup_gallery_routes() -> APIRouter: you get edge blending + lighting unification while keeping the composition recognisable.""" import httpx, base64 as _b64 - require_privilege(request, "can_generate_images") + user = require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") @@ -1185,23 +1221,19 @@ def setup_gallery_routes() -> APIRouter: if not base: db = SessionLocal() try: - eps = db.query(ModelEndpoint).filter( - ModelEndpoint.is_enabled == True, - ModelEndpoint.model_type == "image", - ).all() - if not eps: + ep = _first_visible_image_endpoint(db, user) + if not ep: raise HTTPException(400, "No image generation endpoint configured.") - base = eps[0].base_url.rstrip("/") - api_key = eps[0].api_key + base = ep.base_url.rstrip("/") + api_key = ep.api_key finally: db.close() else: db = SessionLocal() try: - for ep in db.query(ModelEndpoint).all(): - if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"): - api_key = ep.api_key - break + ep = _visible_image_endpoint_for_base(db, base, user) + if ep: + api_key = ep.api_key finally: db.close() diff --git a/tests/test_gallery_endpoint_matching.py b/tests/test_gallery_endpoint_matching.py index 6bec8f582..8157bb3bf 100644 --- a/tests/test_gallery_endpoint_matching.py +++ b/tests/test_gallery_endpoint_matching.py @@ -1,34 +1,11 @@ -import ast -from pathlib import Path - def test_gallery_url_normalization_bug(): - # Read and parse the actual source file - source_path = Path("routes/gallery_routes.py") - assert source_path.exists(), "gallery_routes.py could not be found" - - source = source_path.read_text(encoding="utf-8") - tree = ast.parse(source) - - # Locate the comparison node within harmonize_image that references ep.base_url and base - compare_node = None - for node in ast.walk(tree): - if isinstance(node, ast.Compare): - segment = ast.get_source_segment(source, node) or "" - if "ep.base_url" in segment and "base" in segment and "_norm_url" not in segment: - compare_node = node - break - - assert compare_node is not None, "Could not find the ep.base_url vs base comparison inside gallery_routes.py" - - # Compile the compare node into an expression - expr = ast.Expression(body=compare_node) - compiled_code = compile(expr, "", "eval") - + from routes.gallery_routes import _normalize_image_endpoint_base + def check_match(ep_url: str, base_url: str) -> bool: - class MockEP: - def __init__(self, url): - self.base_url = url - return eval(compiled_code, {}, {"ep": MockEP(ep_url), "base": base_url}) + return ( + _normalize_image_endpoint_base(ep_url) + == _normalize_image_endpoint_base(base_url) + ) # Test cases that SHOULD NOT match under a correct implementation # (Buggy rstrip('/v1') logic incorrectly treats these as equal) diff --git a/tests/test_gallery_image_endpoint_owner_scope.py b/tests/test_gallery_image_endpoint_owner_scope.py new file mode 100644 index 000000000..acc193a78 --- /dev/null +++ b/tests/test_gallery_image_endpoint_owner_scope.py @@ -0,0 +1,126 @@ +"""Owner-scope regression for gallery image endpoint selection. + +The image editor/upscale proxies select ``ModelEndpoint`` rows and may copy the +row's stored ``api_key`` for OpenAI-compatible image endpoints. That lookup must +only consider endpoints visible to the caller, otherwise users sharing the same +base URL can borrow another account's private image API key. +""" + +from types import SimpleNamespace + +import routes.gallery_routes as gallery_routes + + +class _Predicate: + def __init__(self, check): + self._check = check + + def __call__(self, row): + return self._check(row) + + def __or__(self, other): + return _Predicate(lambda row: self(row) or other(row)) + + +class _Column: + def __init__(self, name): + self.name = name + + def __eq__(self, value): + return _Predicate(lambda row: getattr(row, self.name) == value) + + +class _ModelEndpoint: + base_url = _Column("base_url") + model_type = _Column("model_type") + is_enabled = _Column("is_enabled") + owner = _Column("owner") + + +class _Query: + def __init__(self, rows): + self._rows = list(rows) + + def filter(self, *predicates): + self._rows = [row for row in self._rows if all(pred(row) for pred in predicates)] + return self + + def all(self): + return list(self._rows) + + +class _DB: + def __init__(self, rows): + self._rows = rows + + def query(self, model): + assert model is _ModelEndpoint + return _Query(self._rows) + + +def _ep(base_url, owner, *, enabled=True, model_type="image", api_key="sk-secret"): + return SimpleNamespace( + base_url=base_url, + owner=owner, + is_enabled=enabled, + model_type=model_type, + api_key=api_key, + ) + + +def _patch_model(monkeypatch): + monkeypatch.setattr(gallery_routes, "ModelEndpoint", _ModelEndpoint) + + +URL = "https://api.example.com/v1" + + +def test_first_visible_image_endpoint_rejects_another_owner(monkeypatch): + _patch_model(monkeypatch) + rows = [_ep(URL, "bob")] + + assert gallery_routes._first_visible_image_endpoint(_DB(rows), "alice") is None + + +def test_first_visible_image_endpoint_prefers_callers_own_row(monkeypatch): + _patch_model(monkeypatch) + rows = [_ep(URL, None, api_key="shared"), _ep(URL, "alice", api_key="own")] + + ep = gallery_routes._first_visible_image_endpoint(_DB(rows), "alice") + + assert ep is not None + assert ep.owner == "alice" + assert ep.api_key == "own" + + +def test_visible_image_endpoint_for_base_rejects_same_url_other_owner(monkeypatch): + _patch_model(monkeypatch) + rows = [_ep(URL, "bob")] + + assert gallery_routes._visible_image_endpoint_for_base(_DB(rows), URL, "alice") is None + + +def test_visible_image_endpoint_for_base_allows_shared_or_own(monkeypatch): + _patch_model(monkeypatch) + rows = [ + _ep("https://other.example/v1", "alice"), + _ep(URL, None, api_key="shared"), + _ep(URL, "alice", api_key="own"), + ] + + ep = gallery_routes._visible_image_endpoint_for_base(_DB(rows), "https://api.example.com", "alice") + + assert ep is not None + assert ep.owner == "alice" + assert ep.api_key == "own" + assert ep.base_url == URL + + +def test_image_endpoint_owner_filter_is_noop_in_single_user_mode(monkeypatch): + _patch_model(monkeypatch) + rows = [_ep(URL, "bob")] + + ep = gallery_routes._visible_image_endpoint_for_base(_DB(rows), URL, None) + + assert ep is not None + assert ep.owner == "bob"