Chat: route image sessions only to matching image endpoints

Co-authored-by: ghreprimand <203024559+ghreprimand@users.noreply.github.com>
This commit is contained in:
ghreprimand
2026-06-02 06:52:03 -05:00
committed by GitHub
parent 064c1ace91
commit 4cec31d988
2 changed files with 133 additions and 22 deletions
+55 -22
View File
@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
# Track active streams for partial-save safety net
_active_streams: Dict[str, dict] = {}
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
def _stream_set(session_id: str, **fields) -> None:
@@ -98,6 +99,59 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
db.close()
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
"""Return True when a populated endpoint model cache includes ``model``.
Empty/malformed caches are treated as unknown rather than a negative match
so older image endpoints without cached models still work.
"""
raw = getattr(endpoint, "cached_models", None)
if not raw:
return True
try:
models = json.loads(raw) if isinstance(raw, str) else raw
except Exception:
return True
if not isinstance(models, list) or not models:
return True
wanted = (model or "").strip()
return wanted in {str(item).strip() for item in models}
def _is_image_generation_session(sess) -> bool:
"""Whether this chat session should bypass text chat and generate images.
Model-name prefixes are explicit image models. Endpoint type is only used
when the current session endpoint actually matches that image endpoint, and
when a populated endpoint model cache includes the selected model. This
prevents an image endpoint on the same host from misrouting ordinary text
models into the image-generation path.
"""
model = (getattr(sess, "model", "") or "").strip()
if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES):
return True
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
if not endpoint_url:
return False
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for endpoint in endpoints:
if (getattr(endpoint, "model_type", None) or "llm") != "image":
continue
if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""):
continue
if _endpoint_cache_contains_model(endpoint, model):
return True
except Exception:
return False
finally:
db.close()
return False
def _recover_empty_session_model(sess, session_id: str) -> bool:
"""Re-populate sess.model from the matching endpoint's cached models.
@@ -726,28 +780,7 @@ def setup_chat_routes(
_model_info["character_name"] = ctx.preset.character_name
yield f'data: {json.dumps(_model_info)}\n\n'
# Detect image models and route directly to image generation
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
# Also check if the endpoint is registered as an image-type endpoint
if not _is_image_model:
try:
from src.endpoint_resolver import normalize_base as _nb
_ep_base = _nb(sess.endpoint_url)
_db = SessionLocal()
try:
_is_image_model = _db.query(ModelEndpoint).filter(
ModelEndpoint.model_type == "image",
ModelEndpoint.is_enabled == True,
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
).first() is not None
finally:
_db.close()
except Exception:
pass
if _is_image_model:
if _is_image_generation_session(sess):
from src.settings import get_setting
if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'