fix(endpoints): normalize URL handling (#4338)

This commit is contained in:
RaresKeY
2026-06-16 05:59:18 +03:00
committed by GitHub
parent a031a94a2e
commit 33fe7276be
13 changed files with 300 additions and 40 deletions
+17 -9
View File
@@ -5,6 +5,7 @@ import re
import uuid import uuid
import json import json
import hashlib import hashlib
import ipaddress
import socket import socket
import time as _time import time as _time
import logging import logging
@@ -562,6 +563,8 @@ def _safe_build_models_url(base_url: str) -> str:
"""Build a /models URL without letting optional provider imports break probes.""" """Build a /models URL without letting optional provider imports break probes."""
try: try:
return build_models_url(base_url) return build_models_url(base_url)
except ValueError:
raise
except Exception as exc: except Exception as exc:
logger.debug("Model URL detection failed for %s: %s", base_url, exc) logger.debug("Model URL detection failed for %s: %s", base_url, exc)
return f"{(base_url or '').rstrip('/')}/models" return f"{(base_url or '').rstrip('/')}/models"
@@ -633,7 +636,7 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1
try: try:
t0 = _time.time() t0 = _time.time()
r = httpx.post(target_url, headers=h, json=payload, timeout=timeout) r = httpx.post(target_url, headers=h, json=payload, timeout=timeout, verify=llm_verify())
latency = round((_time.time() - t0) * 1000) latency = round((_time.time() - t0) * 1000)
if r.is_success: if r.is_success:
return {"status": "ok", "latency_ms": latency} return {"status": "ok", "latency_ms": latency}
@@ -659,13 +662,20 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1
# Hostnames / IP prefixes that indicate a local endpoint # Hostnames / IP prefixes that indicate a local endpoint
_LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1"} _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1"}
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.", _PRIVATE_NETWORKS = (
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.", ipaddress.ip_network("10.0.0.0/8"),
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.", ipaddress.ip_network("172.16.0.0/12"),
"172.30.", "172.31.", "192.168.") ipaddress.ip_network("192.168.0.0/16"),
)
_TAILSCALE_CGNAT = ipaddress.ip_network("100.64.0.0/10")
_TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.") def _local_ip_literal(host: str) -> bool:
try:
ip = ipaddress.ip_address(host)
except ValueError:
return False
return any(ip in network for network in _PRIVATE_NETWORKS) or ip in _TAILSCALE_CGNAT
def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str: def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
@@ -679,9 +689,7 @@ def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
return "api" return "api"
try: try:
host = urlparse(base_url).hostname or "" host = urlparse(base_url).hostname or ""
if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES): if host in _LOCAL_HOSTS or _local_ip_literal(host):
return "local"
if _TAILSCALE_RE.match(host):
return "local" return "local"
except Exception: except Exception:
pass pass
+39 -11
View File
@@ -161,6 +161,32 @@ def normalize_base(url: str) -> str:
return url return url
def _validated_endpoint_base(url: str) -> str:
"""Return a base URL that is safe for endpoint path appends."""
base = (url or "").strip().rstrip("/")
if "?" in base or "#" in base:
raise ValueError("Endpoint base URL must not include query or fragment")
return urlunparse(urlparse(base)._replace(query="", fragment="")).rstrip("/")
def _prepare_endpoint_base(base: str) -> str:
base = _validated_endpoint_base(normalize_base(base))
return _validated_endpoint_base(normalize_base(resolve_url(base)))
def _append_endpoint_path(base: str, suffix: str) -> str:
parsed = urlparse(base)
current = (parsed.path or "").rstrip("/")
extra = "/" + suffix.lstrip("/")
path = f"{current}{extra}" if current else extra
return urlunparse(parsed._replace(path=path, query="", fragment=""))
def _pathless_host(base: str, host: str) -> bool:
parsed = urlparse(base)
return (parsed.hostname or "").lower() == host and not (parsed.path or "").strip("/")
def _anthropic_api_root(base: str) -> str: def _anthropic_api_root(base: str) -> str:
"""Return Anthropic's API root, preserving /v1 for OpenAI-compatible APIs elsewhere.""" """Return Anthropic's API root, preserving /v1 for OpenAI-compatible APIs elsewhere."""
base = (base or "").strip().rstrip("/") base = (base or "").strip().rstrip("/")
@@ -171,15 +197,17 @@ def _anthropic_api_root(base: str) -> str:
def build_chat_url(base: str) -> str: def build_chat_url(base: str) -> str:
"""Return the correct chat endpoint URL for a given base.""" """Return the correct chat endpoint URL for a given base."""
base = resolve_url(base) base = _prepare_endpoint_base(base)
provider = _detect_provider(base) provider = _detect_provider(base)
if provider == "anthropic": if provider == "anthropic":
return _anthropic_api_root(base) + "/v1/messages" return _append_endpoint_path(_anthropic_api_root(base), "/v1/messages")
if provider == "ollama": if provider == "ollama":
return _ollama_api_root(base) + "/chat" return _append_endpoint_path(_ollama_api_root(base), "/chat")
if provider == "chatgpt-subscription": if provider == "chatgpt-subscription":
return base.rstrip("/") + "/responses" return _append_endpoint_path(base, "/responses")
return base + "/chat/completions" if _pathless_host(base, "api.openai.com"):
base = _append_endpoint_path(base, "/v1")
return _append_endpoint_path(base, "/chat/completions")
def build_models_url(base: str) -> Optional[str]: def build_models_url(base: str) -> Optional[str]:
@@ -193,12 +221,12 @@ def build_models_url(base: str) -> Optional[str]:
untouched (so custom prefixes like ``/openai`` or ``/api/openai/v1`` keep untouched (so custom prefixes like ``/openai`` or ``/api/openai/v1`` keep
their semantics). their semantics).
""" """
base = normalize_base(resolve_url(base)) base = _prepare_endpoint_base(base)
provider = _detect_provider(base) provider = _detect_provider(base)
if provider == "anthropic": if provider == "anthropic":
return _anthropic_api_root(base) + "/v1/models" return _append_endpoint_path(_anthropic_api_root(base), "/v1/models")
if provider == "ollama": if provider == "ollama":
return _ollama_api_root(base) + "/tags" return _append_endpoint_path(_ollama_api_root(base), "/tags")
if provider == "chatgpt-subscription": if provider == "chatgpt-subscription":
return None return None
# Generic OpenAI-compatible fallback: local model servers with no explicit # Generic OpenAI-compatible fallback: local model servers with no explicit
@@ -208,10 +236,10 @@ def build_models_url(base: str) -> Optional[str]:
parsed = urlparse(base) parsed = urlparse(base)
host = (parsed.hostname or "").lower() host = (parsed.hostname or "").lower()
is_local = host in {"localhost", "127.0.0.1", "::1", "host.docker.internal"} is_local = host in {"localhost", "127.0.0.1", "::1", "host.docker.internal"}
uses_v1_models_by_default = is_local or host in {"api.deepseek.com"} uses_v1_models_by_default = is_local or host in {"api.deepseek.com", "api.openai.com"}
if not parsed.path and uses_v1_models_by_default: if not parsed.path and uses_v1_models_by_default:
base = base + "/v1" base = _append_endpoint_path(base, "/v1")
return base + "/models" return _append_endpoint_path(base, "/models")
def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]: def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
+35 -8
View File
@@ -4,6 +4,7 @@ import uuid
import logging import logging
import re import re
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from urllib.parse import urljoin, urlparse, urlunparse
import httpx import httpx
from fastapi import HTTPException from fastapi import HTTPException
@@ -202,6 +203,22 @@ def mask_integration_secret(integration: Dict[str, Any]) -> Dict[str, Any]:
return safe return safe
def _normalize_integration_base_url(base_url: Any) -> str:
if not isinstance(base_url, str) or not base_url.strip():
raise ValueError("Integration base URL is required")
cleaned = base_url.strip().rstrip("/")
if "?" in cleaned or "#" in cleaned:
raise ValueError("Integration base URL must not include query or fragment")
parsed = urlparse(cleaned)
if parsed.scheme.lower() not in ("http", "https") or not parsed.hostname:
raise ValueError("Integration base URL must be an HTTP(S) URL")
return urlunparse(parsed._replace(scheme=parsed.scheme.lower(), query="", fragment="")).rstrip("/")
def _join_integration_url(base_url: str, path: str) -> str:
return urljoin(base_url.rstrip("/") + "/", path.lstrip("/"))
def load_integrations() -> List[Dict[str, Any]]: def load_integrations() -> List[Dict[str, Any]]:
"""Load all integrations from disk with secrets decrypted for runtime use.""" """Load all integrations from disk with secrets decrypted for runtime use."""
if not os.path.exists(DATA_FILE): if not os.path.exists(DATA_FILE):
@@ -261,8 +278,10 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(integration.get("name"), str) or not integration["name"].strip(): if not isinstance(integration.get("name"), str) or not integration["name"].strip():
raise HTTPException(400, "Integration name is required") raise HTTPException(400, "Integration name is required")
if not isinstance(integration.get("base_url"), str) or not integration["base_url"].strip(): try:
raise HTTPException(400, "Integration base URL is required") integration["base_url"] = _normalize_integration_base_url(integration.get("base_url"))
except ValueError as exc:
raise HTTPException(400, str(exc)) from exc
integrations = load_integrations() integrations = load_integrations()
integrations.append(integration) integrations.append(integration)
@@ -272,10 +291,14 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
def update_integration(integration_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: def update_integration(integration_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Update fields on an existing integration. Returns updated integration or None.""" """Update fields on an existing integration. Returns updated integration or None."""
data = dict(data)
if "name" in data and (not isinstance(data["name"], str) or not data["name"].strip()): if "name" in data and (not isinstance(data["name"], str) or not data["name"].strip()):
raise HTTPException(400, "Integration name is required") raise HTTPException(400, "Integration name is required")
if "base_url" in data and (not isinstance(data["base_url"], str) or not data["base_url"].strip()): if "base_url" in data:
raise HTTPException(400, "Integration base URL is required") try:
data["base_url"] = _normalize_integration_base_url(data["base_url"])
except ValueError as exc:
raise HTTPException(400, str(exc)) from exc
integrations = load_integrations() integrations = load_integrations()
for item in integrations: for item in integrations:
@@ -341,9 +364,10 @@ async def execute_api_call(
if not integration.get("enabled", True): if not integration.get("enabled", True):
return {"error": f"Integration '{integration.get('name')}' is disabled", "exit_code": 1} return {"error": f"Integration '{integration.get('name')}' is disabled", "exit_code": 1}
base_url = integration.get("base_url", "").rstrip("/") try:
if not base_url: base_url = _normalize_integration_base_url(integration.get("base_url", ""))
return {"error": "Integration has no base_url configured", "exit_code": 1} except ValueError as exc:
return {"error": str(exc), "exit_code": 1}
# Strip common API path suffixes users might accidentally include # Strip common API path suffixes users might accidentally include
# (e.g. "http://host/v1/" → "http://host"). The integration's preset # (e.g. "http://host/v1/" → "http://host"). The integration's preset
@@ -366,7 +390,10 @@ async def execute_api_call(
if re.search(r"^https?://", path) or "://" in path: if re.search(r"^https?://", path) or "://" in path:
return {"error": "Path must not contain a protocol scheme", "exit_code": 1} return {"error": "Path must not contain a protocol scheme", "exit_code": 1}
url = base_url + path if "#" in path:
return {"error": "Path must not contain a fragment", "exit_code": 1}
url = _join_integration_url(base_url, path)
method = method.upper() method = method.upper()
# Build headers # Build headers
+14 -5
View File
@@ -17,10 +17,11 @@ import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.internal"} _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.internal"}
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.", _PRIVATE_NETWORKS = (
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.", ipaddress.ip_network("10.0.0.0/8"),
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.", ipaddress.ip_network("172.16.0.0/12"),
"172.30.", "172.31.", "192.168.") ipaddress.ip_network("192.168.0.0/16"),
)
# Tailscale uses the CGNAT range 100.64.0.0/10, NOT all of 100.0.0.0/8. # Tailscale uses the CGNAT range 100.64.0.0/10, NOT all of 100.0.0.0/8.
# A bare "100." prefix would classify public addresses (e.g. AWS ranges # A bare "100." prefix would classify public addresses (e.g. AWS ranges
@@ -36,6 +37,14 @@ def _in_tailscale_range(host: str) -> bool:
return False return False
def _is_private_ip_literal(host: str) -> bool:
try:
ip = ipaddress.ip_address(host)
except ValueError:
return False
return any(ip in network for network in _PRIVATE_NETWORKS)
def _normalize_base_for_compare(url: str) -> str: def _normalize_base_for_compare(url: str) -> str:
url = (url or "").strip().rstrip("/") url = (url or "").strip().rstrip("/")
for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"): for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"):
@@ -87,7 +96,7 @@ def is_local_endpoint(url: str) -> bool:
return True return True
try: try:
host = urlparse(url).hostname or "" host = urlparse(url).hostname or ""
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) or _in_tailscale_range(host) return host in _LOCAL_HOSTS or _is_private_ip_literal(host) or _in_tailscale_range(host)
except Exception: except Exception:
return False return False
+10
View File
@@ -38,6 +38,16 @@ def test_unknown_public_host_gets_no_affinity_fields(monkeypatch):
assert payload == {} assert payload == {}
@pytest.mark.parametrize("url", [
"https://10.example-cloud.com/v1",
"https://172.16.example-cloud.com/v1",
"https://192.168.example-cloud.com/v1",
])
def test_private_prefix_dns_host_gets_no_affinity_fields(monkeypatch, url):
payload = _affinity_fields(url, monkeypatch)
assert payload == {}
def test_localhost_server_gets_affinity_fields(monkeypatch): def test_localhost_server_gets_affinity_fields(monkeypatch):
payload = _affinity_fields("http://localhost:8080/v1", monkeypatch) payload = _affinity_fields("http://localhost:8080/v1", monkeypatch)
assert payload == {"session_id": "sess-123", "cache_prompt": True} assert payload == {"session_id": "sess-123", "cache_prompt": True}
+27 -7
View File
@@ -264,7 +264,7 @@ class TestProbeSingleModel:
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
captured = {} captured = {}
def fake_post(url, headers=None, json=None, timeout=None): def fake_post(url, headers=None, json=None, timeout=None, verify=None):
captured["url"] = url captured["url"] = url
return _resp(200, json={"choices": [{"message": {"content": "OK"}}]}) return _resp(200, json={"choices": [{"message": {"content": "OK"}}]})
@@ -274,11 +274,31 @@ class TestProbeSingleModel:
assert "latency_ms" in result assert "latency_ms" in result
assert captured["url"] == "https://api.example.com/v1/chat/completions" assert captured["url"] == "https://api.example.com/v1/chat/completions"
@pytest.mark.parametrize("base,api_key,model_id", [
("https://api.example.com/v1", "key", "gpt-4o"),
("http://localhost:11434/v1", None, "llama3.2"),
("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5"),
])
def test_completion_probe_uses_llm_verify(self, monkeypatch, base, api_key, model_id):
_patch_resolve(monkeypatch)
marker = object()
captured = {}
monkeypatch.setattr(model_routes, "llm_verify", lambda: marker)
def fake_post(url, headers=None, json=None, timeout=None, verify=None):
captured["verify"] = verify
return _resp(200, json={"choices": [{"message": {"content": "OK"}}]})
monkeypatch.setattr(model_routes.httpx, "post", fake_post)
result = _probe_single_model(base, api_key, model_id)
assert result["status"] == "ok"
assert captured["verify"] is marker
def test_extracts_dict_error_message(self, monkeypatch): def test_extracts_dict_error_message(self, monkeypatch):
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
monkeypatch.setattr( monkeypatch.setattr(
model_routes.httpx, "post", model_routes.httpx, "post",
lambda url, headers=None, json=None, timeout=None: _resp( lambda url, headers=None, json=None, timeout=None, verify=None: _resp(
400, json={"error": {"message": "model not found"}}), 400, json={"error": {"message": "model not found"}}),
) )
result = _probe_single_model("https://api.example.com/v1", "key", "ghost") result = _probe_single_model("https://api.example.com/v1", "key", "ghost")
@@ -289,7 +309,7 @@ class TestProbeSingleModel:
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
monkeypatch.setattr( monkeypatch.setattr(
model_routes.httpx, "post", model_routes.httpx, "post",
lambda url, headers=None, json=None, timeout=None: _resp( lambda url, headers=None, json=None, timeout=None, verify=None: _resp(
403, json={"error": "forbidden"}), 403, json={"error": "forbidden"}),
) )
result = _probe_single_model("https://api.example.com/v1", "key", "m") result = _probe_single_model("https://api.example.com/v1", "key", "m")
@@ -299,7 +319,7 @@ class TestProbeSingleModel:
def test_timeout(self, monkeypatch): def test_timeout(self, monkeypatch):
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
def fake_post(url, headers=None, json=None, timeout=None): def fake_post(url, headers=None, json=None, timeout=None, verify=None):
raise httpx.TimeoutException("timed out") raise httpx.TimeoutException("timed out")
monkeypatch.setattr(model_routes.httpx, "post", fake_post) monkeypatch.setattr(model_routes.httpx, "post", fake_post)
@@ -310,7 +330,7 @@ class TestProbeSingleModel:
def test_transport_error_is_fail(self, monkeypatch): def test_transport_error_is_fail(self, monkeypatch):
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
def fake_post(url, headers=None, json=None, timeout=None): def fake_post(url, headers=None, json=None, timeout=None, verify=None):
raise httpx.ConnectError("refused") raise httpx.ConnectError("refused")
monkeypatch.setattr(model_routes.httpx, "post", fake_post) monkeypatch.setattr(model_routes.httpx, "post", fake_post)
@@ -322,7 +342,7 @@ class TestProbeSingleModel:
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
captured = {} captured = {}
def fake_post(url, headers=None, json=None, timeout=None): def fake_post(url, headers=None, json=None, timeout=None, verify=None):
captured.update(url=url, headers=headers, payload=json) captured.update(url=url, headers=headers, payload=json)
return _resp(200, json={"content": [{"type": "text", "text": "OK"}]}) return _resp(200, json={"content": [{"type": "text", "text": "OK"}]})
@@ -337,7 +357,7 @@ class TestProbeSingleModel:
_patch_resolve(monkeypatch) _patch_resolve(monkeypatch)
captured = {} captured = {}
def fake_post(url, headers=None, json=None, timeout=None): def fake_post(url, headers=None, json=None, timeout=None, verify=None):
captured["payload"] = json captured["payload"] = json
return _resp(200, json={"content": []}) return _resp(200, json={"content": []})
+26
View File
@@ -1,6 +1,8 @@
"""Tests for endpoint_resolver — pure functions tested directly.""" """Tests for endpoint_resolver — pure functions tested directly."""
import json import json
import pytest
from src.endpoint_resolver import ( from src.endpoint_resolver import (
_first_chat_model, _first_chat_model,
_endpoint_hidden_models, _endpoint_hidden_models,
@@ -45,6 +47,9 @@ class TestBuildChatUrl:
def test_openai_style(self): def test_openai_style(self):
assert build_chat_url("https://api.openai.com/v1") == "https://api.openai.com/v1/chat/completions" assert build_chat_url("https://api.openai.com/v1") == "https://api.openai.com/v1/chat/completions"
def test_pathless_openai_style_adds_v1(self):
assert build_chat_url("https://api.openai.com") == "https://api.openai.com/v1/chat/completions"
def test_anthropic_style(self): def test_anthropic_style(self):
assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages" assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages"
@@ -66,14 +71,35 @@ class TestBuildChatUrl:
def test_ollama_v1_preserves_openai_compat(self): def test_ollama_v1_preserves_openai_compat(self):
assert build_chat_url("http://nas:11434/v1") == "http://nas:11434/v1/chat/completions" assert build_chat_url("http://nas:11434/v1") == "http://nas:11434/v1/chat/completions"
@pytest.mark.parametrize("bad_base", [
"https://api.example.com/v1?token=abc",
"https://api.example.com/v1#fragment",
"http://localhost:1234?",
])
def test_rejects_query_or_fragment_base(self, bad_base):
with pytest.raises(ValueError, match="query or fragment"):
build_chat_url(bad_base)
class TestBuildModelsUrl: class TestBuildModelsUrl:
def test_openai_models(self): def test_openai_models(self):
assert build_models_url("https://api.openai.com/v1") == "https://api.openai.com/v1/models" assert build_models_url("https://api.openai.com/v1") == "https://api.openai.com/v1/models"
def test_pathless_openai_models_adds_v1(self):
assert build_models_url("https://api.openai.com") == "https://api.openai.com/v1/models"
def test_ollama_tags(self): def test_ollama_tags(self):
assert build_models_url("https://ollama.com/api") == "https://ollama.com/api/tags" assert build_models_url("https://ollama.com/api") == "https://ollama.com/api/tags"
@pytest.mark.parametrize("bad_base", [
"https://api.example.com/v1?token=abc",
"https://api.example.com/v1#fragment",
"http://localhost:1234?",
])
def test_rejects_query_or_fragment_base(self, bad_base):
with pytest.raises(ValueError, match="query or fragment"):
build_models_url(bad_base)
class TestBuildHeaders: class TestBuildHeaders:
def test_no_key(self): def test_no_key(self):
@@ -87,11 +87,60 @@ async def _call(json_data, status=200):
return await integrations.execute_api_call("test_integ", "GET", "/items") return await integrations.execute_api_call("test_integ", "GET", "/items")
async def _call_with_integration(integration, path="/items"):
mock_resp = _make_response({"ok": True})
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client.request = AsyncMock(return_value=mock_resp)
with (
patch.object(integrations, "_find_integration", return_value=integration),
patch("httpx.AsyncClient", return_value=mock_client),
):
result = await integrations.execute_api_call("test_integ", "GET", path)
return result, mock_client
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Tests # Tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_api_call_rejects_stored_base_url_with_query_without_requesting():
integration = {**DUMMY_INTEGRATION, "base_url": "http://api.example.com/api?token=abc"}
result, mock_client = await _call_with_integration(integration)
assert result == {
"error": "Integration base URL must not include query or fragment",
"exit_code": 1,
}
mock_client.request.assert_not_called()
@pytest.mark.asyncio
async def test_api_call_joins_path_under_configured_base_path():
integration = {**DUMMY_INTEGRATION, "base_url": "http://api.example.com/root"}
result, mock_client = await _call_with_integration(integration, "/v1/items?limit=1")
assert result.get("exit_code") == 0
mock_client.request.assert_called_once()
assert mock_client.request.call_args.args[:2] == (
"GET",
"http://api.example.com/root/v1/items?limit=1",
)
@pytest.mark.asyncio
async def test_api_call_rejects_path_fragment_without_requesting():
result, mock_client = await _call_with_integration(DUMMY_INTEGRATION, "/items#fragment")
assert result == {"error": "Path must not contain a fragment", "exit_code": 1}
mock_client.request.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_large_json_list_returns_valid_json_with_sentinel(): async def test_large_json_list_returns_valid_json_with_sentinel():
"""A JSON list whose serialized form exceeds 12000 chars must be truncated """A JSON list whose serialized form exceeds 12000 chars must be truncated
+50
View File
@@ -83,6 +83,27 @@ def test_create_integration_rejects_blank_base_url_without_persisting(integratio
assert integrations.load_integrations() == [] assert integrations.load_integrations() == []
@pytest.mark.parametrize(("base_url", "message"), [
("ftp://example.test", "Integration base URL must be an HTTP(S) URL"),
("https://example.test/api?token=abc", "Integration base URL must not include query or fragment"),
("https://example.test/api#fragment", "Integration base URL must not include query or fragment"),
])
def test_create_integration_rejects_invalid_base_url_without_persisting(
integrations_routes, base_url, message
):
endpoint, session_cookie, http_exception = integrations_routes
create_integration = endpoint("/api/auth/integrations", "POST")
with pytest.raises(http_exception) as exc:
asyncio.run(create_integration(
_JsonRequest({"name": "Example", "base_url": base_url}, session_cookie)
))
assert exc.value.status_code == 400
assert exc.value.detail == message
assert integrations.load_integrations() == []
@pytest.mark.parametrize("blank_name", ["", " "]) @pytest.mark.parametrize("blank_name", ["", " "])
def test_update_integration_rejects_blank_name_without_changing_existing(integrations_routes, blank_name): def test_update_integration_rejects_blank_name_without_changing_existing(integrations_routes, blank_name):
endpoint, session_cookie, http_exception = integrations_routes endpoint, session_cookie, http_exception = integrations_routes
@@ -127,3 +148,32 @@ def test_update_integration_rejects_blank_base_url_without_changing_existing(int
assert exc.value.status_code == 400 assert exc.value.status_code == 400
assert exc.value.detail == "Integration base URL is required" assert exc.value.detail == "Integration base URL is required"
assert integrations.load_integrations()[0]["base_url"] == "https://example.test" assert integrations.load_integrations()[0]["base_url"] == "https://example.test"
@pytest.mark.parametrize(("base_url", "message"), [
("ftp://example.test", "Integration base URL must be an HTTP(S) URL"),
("https://example.test/api?token=abc", "Integration base URL must not include query or fragment"),
("https://example.test/api#fragment", "Integration base URL must not include query or fragment"),
])
def test_update_integration_rejects_invalid_base_url_without_changing_existing(
integrations_routes, base_url, message
):
endpoint, session_cookie, http_exception = integrations_routes
update_integration = endpoint("/api/auth/integrations/{integration_id}", "PUT")
integrations.save_integrations([
{
"id": "existing",
"name": "Original",
"base_url": "https://example.test",
}
])
with pytest.raises(http_exception) as exc:
asyncio.run(update_integration(
integration_id="existing",
request=_JsonRequest({"base_url": base_url}, session_cookie),
))
assert exc.value.status_code == 400
assert exc.value.detail == message
assert integrations.load_integrations()[0]["base_url"] == "https://example.test"
+14
View File
@@ -17,6 +17,7 @@ This module pins both behaviors so future refactors don't regress them.
""" """
import httpx import httpx
import pytest
from src import endpoint_resolver, llm_core from src import endpoint_resolver, llm_core
@@ -90,6 +91,19 @@ def test_build_models_url_preserves_explicit_non_v1_path(monkeypatch):
) )
@pytest.mark.parametrize("base_url", [
"http://localhost:1234?",
"http://localhost:1234#fragment",
"http://localhost:1234/v1?token=abc",
])
def test_build_models_url_rejects_query_or_fragment_base(monkeypatch, base_url):
monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda url: url)
_neutralize_provider_detection(monkeypatch)
with pytest.raises(ValueError, match="query or fragment"):
endpoint_resolver.build_models_url(base_url)
# ── list_model_ids: parse LM Studio's response ───────────────────────── # ── list_model_ids: parse LM Studio's response ─────────────────────────
+8
View File
@@ -67,6 +67,14 @@ class TestIsLocalEndpoint:
def test_private_10(self): def test_private_10(self):
assert is_local_endpoint("http://10.0.0.5:8000/v1/chat/completions") is True assert is_local_endpoint("http://10.0.0.5:8000/v1/chat/completions") is True
@pytest.mark.parametrize("host", [
"10.example-cloud.com",
"172.16.example-cloud.com",
"192.168.example-cloud.com",
])
def test_private_prefix_dns_names_are_remote(self, host):
assert is_local_endpoint(f"https://{host}/v1/chat/completions") is False
def test_tailscale_100(self): def test_tailscale_100(self):
# 100.64.0.0/10 is the CGNAT range Tailscale uses. # 100.64.0.0/10 is the CGNAT range Tailscale uses.
assert is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True assert is_local_endpoint("http://100.64.0.1:5000/v1/chat/completions") is True
+8
View File
@@ -419,6 +419,14 @@ class TestClassifyEndpoint:
def test_private_10(self): def test_private_10(self):
assert _classify_endpoint("http://10.0.0.5:8000") == "local" assert _classify_endpoint("http://10.0.0.5:8000") == "local"
@pytest.mark.parametrize("host", [
"10.example-cloud.com",
"172.16.example-cloud.com",
"192.168.example-cloud.com",
])
def test_private_prefix_dns_names_are_api(self, host):
assert _classify_endpoint(f"https://{host}/v1") == "api"
def test_public_api(self): def test_public_api(self):
assert _classify_endpoint("https://api.openai.com/v1") == "api" assert _classify_endpoint("https://api.openai.com/v1") == "api"
+3
View File
@@ -37,6 +37,9 @@ PROVIDER_CASES = [
("openai", "https://api.openai.com/v1", ("openai", "https://api.openai.com/v1",
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
"https://api.openai.com/v1/models"), "https://api.openai.com/v1/models"),
("openai_pathless", "https://api.openai.com",
"https://api.openai.com/v1/chat/completions",
"https://api.openai.com/v1/models"),
("anthropic", "https://api.anthropic.com", ("anthropic", "https://api.anthropic.com",
"https://api.anthropic.com/v1/messages", "https://api.anthropic.com/v1/messages",
"https://api.anthropic.com/v1/models"), "https://api.anthropic.com/v1/models"),