mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
refactor(tests): replace local function copies in test_endpoint_resolver with real imports (#3359)
* refactor(tests): replace local function copies in test_endpoint_resolver with real imports The test file carried 9 verbatim copies of src/endpoint_resolver.py functions to avoid import-pollution concerns, but these copies are a drift hazard — PR #3343 had to update both in parallel. Replace them with direct imports so future changes to endpoint_resolver are automatically exercised by the test suite. Also fixes _ollama_api_root in endpoint_resolver.py: the bare-URL Ollama case (e.g. http://nas:11434 with empty path) was already handled correctly in the test copy but was missing from the real function, which would return /chat instead of /api/chat for native Ollama endpoints without an explicit /api prefix. Closes #3351 * refactor: import _ollama_api_root from llm_core instead of duplicating it endpoint_resolver already imports _detect_provider and _host_match from llm_core. Add _ollama_api_root to that import and remove the local copy, collapsing two implementations to one source of truth. llm_core's version is a superset (also strips /api/chat|tags|generate paths), and since normalize_base already removes those suffixes upstream the result is identical for every input used here.
This commit is contained in:
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Dict
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from src.llm_core import _detect_provider, _host_match
|
||||
from src.llm_core import _detect_provider, _host_match, _ollama_api_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -150,19 +150,6 @@ def _anthropic_api_root(base: str) -> str:
|
||||
return base
|
||||
|
||||
|
||||
def _ollama_api_root(base: str) -> str:
|
||||
"""Return the native Ollama API root, adding /api for ollama.com hosts."""
|
||||
base = (base or "").strip().rstrip("/")
|
||||
parsed = urlparse(base)
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path.endswith("/api"):
|
||||
return base
|
||||
if _host_match(base, "ollama.com"):
|
||||
root = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "https://ollama.com"
|
||||
return root.rstrip("/") + "/api"
|
||||
return base
|
||||
|
||||
|
||||
def build_chat_url(base: str) -> str:
|
||||
"""Return the correct chat endpoint URL for a given base."""
|
||||
base = resolve_url(base)
|
||||
|
||||
@@ -1,119 +1,17 @@
|
||||
"""Tests for endpoint_resolver — pure functions tested directly to avoid import pollution."""
|
||||
"""Tests for endpoint_resolver — pure functions tested directly."""
|
||||
import json
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
# Copy the pure functions to test them without importing the full module.
|
||||
# This avoids module cache conflicts with other test files that mock dependencies.
|
||||
|
||||
_NON_CHAT_MODEL = (
|
||||
"text-embedding", "embedding", "tts-", "whisper", "dall-e",
|
||||
"moderation", "rerank", "reranker", "clip", "stable-diffusion",
|
||||
from src.endpoint_resolver import (
|
||||
_first_chat_model,
|
||||
_endpoint_hidden_models,
|
||||
_endpoint_enabled_models,
|
||||
normalize_base,
|
||||
build_chat_url,
|
||||
build_models_url,
|
||||
build_headers,
|
||||
)
|
||||
|
||||
|
||||
def _first_chat_model(models):
|
||||
for m in (models or []):
|
||||
if not any(p in str(m).lower() for p in _NON_CHAT_MODEL):
|
||||
return m
|
||||
return (models[0] if models else None)
|
||||
|
||||
|
||||
def _endpoint_cached_models(ep) -> list:
|
||||
raw = getattr(ep, "cached_models", None) or getattr(ep, "models", None)
|
||||
if not raw:
|
||||
return []
|
||||
try:
|
||||
models = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
return []
|
||||
return models if isinstance(models, list) else []
|
||||
|
||||
|
||||
def _endpoint_hidden_models(ep) -> set:
|
||||
raw = getattr(ep, "hidden_models", None)
|
||||
if not raw:
|
||||
return set()
|
||||
try:
|
||||
hidden = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except Exception:
|
||||
return set()
|
||||
return set(hidden) if isinstance(hidden, list) else set()
|
||||
|
||||
|
||||
def _endpoint_enabled_models(ep) -> list:
|
||||
hidden = _endpoint_hidden_models(ep)
|
||||
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
|
||||
|
||||
def normalize_base(url: str) -> str:
|
||||
url = (url or "").strip().rstrip("/")
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
if url.endswith(suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
for suffix in ["/chat", "/tags", "/generate"]:
|
||||
if url.endswith("/api" + suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
return url
|
||||
|
||||
|
||||
def _detect_provider(url: str) -> str:
|
||||
parsed = urlparse(url or "")
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if host.endswith("ollama.com"):
|
||||
return "ollama"
|
||||
if path.startswith("/v1"):
|
||||
pass # OpenAI compat
|
||||
elif (parsed.port == 11434 or host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"}) and (path == "" or path == "/api" or path.startswith("/api/")):
|
||||
return "ollama"
|
||||
if "anthropic.com" in (url or ""):
|
||||
return "anthropic"
|
||||
return "openai"
|
||||
|
||||
|
||||
def _ollama_api_root(base: str) -> str:
|
||||
base = (base or "").strip().rstrip("/")
|
||||
parsed = urlparse(base)
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
if path.endswith("/api"):
|
||||
return base
|
||||
if path == "":
|
||||
return base + "/api"
|
||||
if host.endswith("ollama.com"):
|
||||
return f"{parsed.scheme}://{parsed.netloc}/api"
|
||||
return base
|
||||
|
||||
|
||||
def build_chat_url(base: str) -> str:
|
||||
provider = _detect_provider(base)
|
||||
if provider == "anthropic":
|
||||
host = urlparse(base).hostname or ""
|
||||
if host.endswith("anthropic.com") and base.rstrip("/").endswith("/v1"):
|
||||
base = base.rstrip("/")[:-3].rstrip("/")
|
||||
return base + "/v1/messages"
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/chat"
|
||||
return base + "/chat/completions"
|
||||
|
||||
|
||||
def build_models_url(base: str) -> str:
|
||||
provider = _detect_provider(base)
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
return base + "/models"
|
||||
|
||||
|
||||
def build_headers(api_key, base: str) -> dict:
|
||||
if not api_key:
|
||||
return {}
|
||||
provider = _detect_provider(base)
|
||||
if provider == "anthropic":
|
||||
return {"x-api-key": api_key, "anthropic-version": "2023-06-01"}
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
|
||||
class TestNormalizeBase:
|
||||
def test_strips_models(self):
|
||||
assert normalize_base("https://api.openai.com/v1/models") == "https://api.openai.com/v1"
|
||||
|
||||
Reference in New Issue
Block a user