mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Chat: route image sessions only to matching image endpoints
Co-authored-by: ghreprimand <203024559+ghreprimand@users.noreply.github.com>
This commit is contained in:
+55
-22
@@ -43,6 +43,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Track active streams for partial-save safety net
|
# Track active streams for partial-save safety net
|
||||||
_active_streams: Dict[str, dict] = {}
|
_active_streams: Dict[str, dict] = {}
|
||||||
|
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
||||||
|
|
||||||
|
|
||||||
def _stream_set(session_id: str, **fields) -> None:
|
def _stream_set(session_id: str, **fields) -> None:
|
||||||
@@ -98,6 +99,59 @@ def _clear_orphaned_session_endpoint(sess) -> bool:
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
||||||
|
"""Return True when a populated endpoint model cache includes ``model``.
|
||||||
|
|
||||||
|
Empty/malformed caches are treated as unknown rather than a negative match
|
||||||
|
so older image endpoints without cached models still work.
|
||||||
|
"""
|
||||||
|
raw = getattr(endpoint, "cached_models", None)
|
||||||
|
if not raw:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
if not isinstance(models, list) or not models:
|
||||||
|
return True
|
||||||
|
wanted = (model or "").strip()
|
||||||
|
return wanted in {str(item).strip() for item in models}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_image_generation_session(sess) -> bool:
|
||||||
|
"""Whether this chat session should bypass text chat and generate images.
|
||||||
|
|
||||||
|
Model-name prefixes are explicit image models. Endpoint type is only used
|
||||||
|
when the current session endpoint actually matches that image endpoint, and
|
||||||
|
when a populated endpoint model cache includes the selected model. This
|
||||||
|
prevents an image endpoint on the same host from misrouting ordinary text
|
||||||
|
models into the image-generation path.
|
||||||
|
"""
|
||||||
|
model = (getattr(sess, "model", "") or "").strip()
|
||||||
|
if any(model.lower().startswith(prefix) for prefix in _IMAGE_MODEL_PREFIXES):
|
||||||
|
return True
|
||||||
|
|
||||||
|
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||||
|
if not endpoint_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||||
|
for endpoint in endpoints:
|
||||||
|
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
||||||
|
continue
|
||||||
|
if not _session_url_matches_endpoint(endpoint_url, getattr(endpoint, "base_url", "") or ""):
|
||||||
|
continue
|
||||||
|
if _endpoint_cache_contains_model(endpoint, model):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _recover_empty_session_model(sess, session_id: str) -> bool:
|
def _recover_empty_session_model(sess, session_id: str) -> bool:
|
||||||
"""Re-populate sess.model from the matching endpoint's cached models.
|
"""Re-populate sess.model from the matching endpoint's cached models.
|
||||||
|
|
||||||
@@ -726,28 +780,7 @@ def setup_chat_routes(
|
|||||||
_model_info["character_name"] = ctx.preset.character_name
|
_model_info["character_name"] = ctx.preset.character_name
|
||||||
yield f'data: {json.dumps(_model_info)}\n\n'
|
yield f'data: {json.dumps(_model_info)}\n\n'
|
||||||
|
|
||||||
# Detect image models and route directly to image generation
|
if _is_image_generation_session(sess):
|
||||||
_IMAGE_MODEL_PREFIXES = ("gpt-image", "dall-e", "chatgpt-image")
|
|
||||||
_is_image_model = any(sess.model.lower().startswith(p) for p in _IMAGE_MODEL_PREFIXES)
|
|
||||||
|
|
||||||
# Also check if the endpoint is registered as an image-type endpoint
|
|
||||||
if not _is_image_model:
|
|
||||||
try:
|
|
||||||
from src.endpoint_resolver import normalize_base as _nb
|
|
||||||
_ep_base = _nb(sess.endpoint_url)
|
|
||||||
_db = SessionLocal()
|
|
||||||
try:
|
|
||||||
_is_image_model = _db.query(ModelEndpoint).filter(
|
|
||||||
ModelEndpoint.model_type == "image",
|
|
||||||
ModelEndpoint.is_enabled == True,
|
|
||||||
ModelEndpoint.base_url.contains(_ep_base.split("://")[-1].split("/")[0]),
|
|
||||||
).first() is not None
|
|
||||||
finally:
|
|
||||||
_db.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if _is_image_model:
|
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
if not get_setting("image_gen_enabled", True):
|
if not get_setting("image_gen_enabled", True):
|
||||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from routes import chat_routes
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQuery:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self.rows = rows
|
||||||
|
|
||||||
|
def filter(self, *conditions):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
return list(self.rows)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDb:
|
||||||
|
def __init__(self, rows):
|
||||||
|
self.rows = rows
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def query(self, model):
|
||||||
|
return _FakeQuery(self.rows)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
|
||||||
|
def _session(model="qwen3.5:latest", endpoint_url="http://localhost:11434/v1/chat/completions"):
|
||||||
|
return SimpleNamespace(model=model, endpoint_url=endpoint_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _endpoint(base_url, model_type="image", models=None):
|
||||||
|
cached_models = None if models is None else json.dumps(models)
|
||||||
|
return SimpleNamespace(
|
||||||
|
base_url=base_url,
|
||||||
|
model_type=model_type,
|
||||||
|
is_enabled=True,
|
||||||
|
cached_models=cached_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_model_prefix_routes_to_image_generation_without_endpoint_lookup(monkeypatch):
|
||||||
|
def fail_if_called():
|
||||||
|
raise AssertionError("prefixed image models should not need a DB lookup")
|
||||||
|
|
||||||
|
monkeypatch.setattr(chat_routes, "SessionLocal", fail_if_called)
|
||||||
|
|
||||||
|
assert chat_routes._is_image_generation_session(_session(model="dall-e-3"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_endpoint_does_not_catch_text_model_on_different_path(monkeypatch):
|
||||||
|
db = _FakeDb([
|
||||||
|
_endpoint("http://localhost:11434/v1/images", models=["sdxl-local"]),
|
||||||
|
])
|
||||||
|
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||||
|
|
||||||
|
assert not chat_routes._is_image_generation_session(_session())
|
||||||
|
assert db.closed
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_endpoint_cache_must_contain_selected_model(monkeypatch):
|
||||||
|
db = _FakeDb([
|
||||||
|
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
|
||||||
|
])
|
||||||
|
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||||
|
|
||||||
|
assert not chat_routes._is_image_generation_session(_session(model="qwen3.5:latest"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_matching_image_endpoint_routes_selected_image_model(monkeypatch):
|
||||||
|
db = _FakeDb([
|
||||||
|
_endpoint("http://localhost:11434/v1", models=["sdxl-local"]),
|
||||||
|
])
|
||||||
|
monkeypatch.setattr(chat_routes, "SessionLocal", lambda: db)
|
||||||
|
|
||||||
|
assert chat_routes._is_image_generation_session(_session(model="sdxl-local"))
|
||||||
Reference in New Issue
Block a user