Harden API-token chat endpoint selection

Validate only token-supplied direct base_url values for API-token chat requests, while keeping admin-configured endpoints available for local/LAN providers.

Scope configured endpoint fallback selection to the API token owner, fail closed for unknown token owners, and preserve strict session ownership checks when resuming sessions from chat-scoped API tokens.

Add focused regression coverage for direct base_url SSRF rejection, configured endpoint fallback behavior, token-owner scoping, URL validation, and null-owner session/endpoint handling.
This commit is contained in:
Alexandre Teixeira
2026-06-03 13:05:13 +01:00
committed by GitHub
parent 145f4fd2b4
commit b1a4ed13b0
4 changed files with 543 additions and 31 deletions
+26 -24
View File
@@ -9,7 +9,9 @@ import httpx
from fastapi import APIRouter, HTTPException, Request, Form from fastapi import APIRouter, HTTPException, Request, Form
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.database import SessionLocal, Webhook from core.database import SessionLocal, Webhook, ModelEndpoint
from src.auth_helpers import owner_filter
from src.url_security import validate_public_http_url
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,23 +28,19 @@ MAX_MESSAGE_LEN = 32_000
from core.middleware import require_admin as _require_admin from core.middleware import require_admin as _require_admin
def _first_enabled_endpoint(db, owner): def _select_api_chat_fallback_endpoint(db, token_owner: Optional[str]):
"""First enabled ModelEndpoint VISIBLE to `owner` — their own rows plus """First enabled ModelEndpoint visible to token_owner — their own rows plus
legacy null-owner ("shared") rows. Owner-scoped on purpose: ModelEndpoint legacy null-owner ("shared") rows. Owner-scoped: an unscoped .first() would
is per-user (core/database.py — "when non-null, the model picker only shows let a chat-scoped token fall back onto another user's private endpoint and
the endpoint to that user"), and the sync-chat fallback uses the row's silently spend that owner's API key/quota. Prefer owner rows before shared
decrypted `api_key`. An unscoped ``.first()`` would let a chat-scoped token rows. Fails closed to null-owner rows only when token_owner is absent.
(e.g. a paired mobile device) fall back onto ANOTHER user's private Does not validate base_url — admin-configured local/LAN endpoints remain allowed.
endpoint and silently spend that owner's API key / quota — and reach
whatever internal base_url they configured. Mirrors the owner_filter scoping
in routes/model_routes.py and companion/routes.py. A null/empty owner is a
no-op (single-user / legacy mode), preserving the original behaviour.
""" """
from core.database import ModelEndpoint query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712
from src.auth_helpers import owner_filter if token_owner:
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) # noqa: E712 query = owner_filter(query, ModelEndpoint, token_owner)
q = owner_filter(q, ModelEndpoint, owner) return query.order_by(ModelEndpoint.owner.desc(), ModelEndpoint.created_at).first()
return q.first() return query.filter(ModelEndpoint.owner == None).order_by(ModelEndpoint.created_at).first() # noqa: E711
def _caller_owns_session(sess_owner, caller) -> bool: def _caller_owns_session(sess_owner, caller) -> bool:
@@ -278,15 +276,21 @@ def setup_webhook_routes(
api_key = body.api_key.strip() api_key = body.api_key.strip()
model = body.model or "deepseek-chat" model = body.model or "deepseek-chat"
# Resolve base_url: explicit > provider name > model prefix auto-detect # Validate only token-supplied direct base_url; auto-resolved known-provider
base_url = body.base_url.strip().rstrip("/") if body.base_url else None # URLs are not subject to extra local/LAN blocking beyond existing provider logic.
if not base_url: direct_base_url = body.base_url.strip().rstrip("/") if body.base_url else None
if direct_base_url:
try:
base_url = validate_public_http_url(direct_base_url)
except ValueError as e:
detail = str(e).replace("URL", "base_url", 1)
raise HTTPException(400, detail)
else:
base_url = _resolve_base_url(model, body.provider) base_url = _resolve_base_url(model, body.provider)
if not base_url: if not base_url:
raise HTTPException(400, raise HTTPException(400,
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') " "Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
"or provider ('deepseek', 'openai', 'groq', etc.)") "or provider ('deepseek', 'openai', 'groq', etc.)")
base_url = normalize_base(base_url) base_url = normalize_base(base_url)
endpoint_url = build_chat_url(base_url) endpoint_url = build_chat_url(base_url)
@@ -306,9 +310,7 @@ def setup_webhook_routes(
if not sess: if not sess:
db = SessionLocal() db = SessionLocal()
try: try:
# Owner-scoped: only THIS token owner's endpoints + legacy ep = _select_api_chat_fallback_endpoint(db, token_owner)
# shared rows, never another user's private endpoint/api_key.
ep = _first_enabled_endpoint(db, token_owner)
finally: finally:
db.close() db.close()
+94
View File
@@ -0,0 +1,94 @@
"""URL validation helpers for server-side outbound requests."""
from __future__ import annotations
import ipaddress
import socket
from urllib.parse import urlparse
_INTERNAL_HOSTNAMES = {
"localhost",
"metadata",
"metadata.google.internal",
}
_INTERNAL_SUFFIXES = (
".localhost",
".local",
".internal",
".lan",
".intranet",
)
_BLOCKED_NETWORKS = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
)
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
ips: list[ipaddress._BaseAddress] = []
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
if family in (socket.AF_INET, socket.AF_INET6):
ips.append(ipaddress.ip_address(sockaddr[0]))
return ips
def _blocked_ip(addr: ipaddress._BaseAddress) -> bool:
return (
any(addr in net for net in _BLOCKED_NETWORKS)
or addr.is_private
or addr.is_loopback
or addr.is_link_local
or addr.is_multicast
or addr.is_unspecified
or addr.is_reserved
)
def _host_resolves_publicly(hostname: str) -> bool:
host = hostname.strip().lower()
if host in _INTERNAL_HOSTNAMES or host.endswith(_INTERNAL_SUFFIXES):
return False
try:
return not _blocked_ip(ipaddress.ip_address(host))
except ValueError:
pass
try:
addrs = _resolve_hostname_ips(host)
except OSError:
return False
return bool(addrs) and all(not _blocked_ip(addr) for addr in addrs)
def is_public_http_url(url: str) -> bool:
parsed = urlparse((url or "").strip())
if parsed.scheme not in ("http", "https") or not parsed.hostname:
return False
return _host_resolves_publicly(parsed.hostname)
def validate_public_http_url(url: str, *, max_length: int = 2048) -> str:
"""Validate a user/API-token supplied server-side HTTP(S) endpoint.
This is for untrusted outbound URLs, not admin-created model endpoints
that are intentionally allowed to point at private model providers. DNS
failures fail closed, and DNS checks reduce obvious private-network
targets but do not eliminate every DNS rebinding race by themselves.
"""
cleaned = (url or "").strip()
if len(cleaned) > max_length:
raise ValueError("URL is too long")
if not is_public_http_url(cleaned):
raise ValueError("URL must point to a public HTTP(S) endpoint")
return cleaned
+401
View File
@@ -0,0 +1,401 @@
import ipaddress
import importlib.util
import sys
import types
from pathlib import Path
import pytest
@pytest.mark.parametrize("url", [
"http://127.0.0.1:8000/v1",
"http://localhost:8000/v1",
"http://10.0.0.5/v1",
"http://172.16.0.1/v1",
"http://192.168.1.2/v1",
"http://169.254.169.254/latest/meta-data/",
"http://metadata.google.internal/",
"http://[::1]:8000/v1",
"http://[fc00::1]/v1",
"http://224.0.0.1/v1",
"http://0.0.0.0/v1",
"file:///etc/passwd",
])
def test_public_url_validator_blocks_internal_targets(url):
from src.url_security import is_public_http_url
assert is_public_http_url(url) is False
def test_public_url_validator_allows_public_endpoint(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
assert url_security.validate_public_http_url("https://api.example.com/v1") == "https://api.example.com/v1"
def test_public_url_validator_blocks_dns_to_private(monkeypatch):
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("10.0.0.5")],
)
with pytest.raises(ValueError):
url_security.validate_public_http_url("https://api.example.com/v1")
def _load_webhook_routes_for_test(monkeypatch):
# Load under a unique module name so each test gets a fresh module object
# rather than a cached one from a previous monkeypatch run.
core_pkg = types.ModuleType("core")
core_pkg.__path__ = []
core_db = types.ModuleType("core.database")
core_db.SessionLocal = object
core_db.Webhook = object
core_db.ModelEndpoint = object
core_middleware = types.ModuleType("core.middleware")
core_middleware.require_admin = lambda request: None
webhook_manager = types.ModuleType("src.webhook_manager")
webhook_manager.WebhookManager = object
webhook_manager.validate_webhook_url = lambda url: url
webhook_manager.validate_events = lambda events: events
monkeypatch.setitem(sys.modules, "core", core_pkg)
monkeypatch.setitem(sys.modules, "core.database", core_db)
monkeypatch.setitem(sys.modules, "core.middleware", core_middleware)
monkeypatch.setitem(sys.modules, "src.webhook_manager", webhook_manager)
module_name = "routes.webhook_routes_under_test"
spec = importlib.util.spec_from_file_location(
module_name,
Path(__file__).resolve().parent.parent / "routes" / "webhook_routes.py",
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
class _Expr:
def __init__(self, fn):
self.fn = fn
def __call__(self, row):
return self.fn(row)
def __or__(self, other):
return _Expr(lambda row: self(row) or other(row))
class _Column:
def __init__(self, name):
self.name = name
def __eq__(self, other):
return _Expr(lambda row: getattr(row, self.name) == other)
def desc(self):
return ("desc", self.name)
class _ModelEndpoint:
is_enabled = _Column("is_enabled")
owner = _Column("owner")
created_at = _Column("created_at")
class _Endpoint:
def __init__(
self,
*,
owner,
is_enabled=True,
created_at=1,
base_url="https://api.example.com/v1",
api_key=None,
):
self.owner = owner
self.is_enabled = is_enabled
self.created_at = created_at
self.base_url = base_url
self.api_key = api_key
class _EndpointQuery:
def __init__(self, rows):
self.rows = rows
self.filters = []
self.orders = []
def filter(self, *exprs):
self.filters.extend(exprs)
return self
def order_by(self, *exprs):
self.orders.extend(exprs)
return self
def first(self):
rows = self.rows
for expr in self.filters:
rows = [row for row in rows if expr(row)]
# Apply sort keys right-to-left so the leftmost key ends up as the
# primary sort (stable-sort reversal idiom mirrors SQLAlchemy's
# multi-column ORDER BY behaviour).
for order in reversed(self.orders):
reverse = False
name = getattr(order, "name", None)
if isinstance(order, tuple) and order[0] == "desc":
reverse = True
name = order[1]
rows = sorted(rows, key=lambda row: getattr(row, name) is not None, reverse=reverse)
if name != "owner":
rows = sorted(rows, key=lambda row: getattr(row, name), reverse=reverse)
return rows[0] if rows else None
class _DB:
def __init__(self, rows):
self.query_obj = _EndpointQuery(rows)
self.closed = False
def query(self, model):
assert model is _ModelEndpoint
return self.query_obj
def close(self):
self.closed = True
class _ChatSession:
def __init__(self, endpoint_url, model):
self.endpoint_url = endpoint_url
self.model = model
self.headers = {}
self.history = []
def add_message(self, message):
self.history.append(message)
class _SessionManager:
def __init__(self):
self.created = []
self.save_calls = 0
def create_session(self, *, session_id, name, endpoint_url, model, owner):
session = _ChatSession(endpoint_url, model)
self.created.append({
"session_id": session_id,
"name": name,
"endpoint_url": endpoint_url,
"model": model,
"owner": owner,
"session": session,
})
return session
def save_sessions(self):
self.save_calls += 1
class _Request:
def __init__(self, *, owner="alice"):
self.state = types.SimpleNamespace(
api_token=True,
api_token_scopes=["chat"],
api_token_owner=owner,
)
class _WebhookManager:
async def fire(self, event, payload):
return None
def _install_sync_chat_stubs(monkeypatch):
# FastAPI checks for python_multipart at import time when Form is used;
# stub it so the optional dependency is not required in the test environment.
python_multipart = types.ModuleType("python_multipart")
python_multipart.__version__ = "0.0.13"
core_models = types.ModuleType("core.models")
class _ChatMessage:
def __init__(self, role, content):
self.role = role
self.content = content
async def _llm_call_async(endpoint_url, model, messages, headers=None, timeout=None):
return "mocked response"
endpoint_resolver = types.ModuleType("src.endpoint_resolver")
endpoint_resolver.normalize_base = lambda url: (url or "").strip().rstrip("/")
endpoint_resolver.build_chat_url = lambda base_url: f"{base_url}/chat/completions"
endpoint_resolver.build_models_url = lambda base_url: f"{base_url}/models"
endpoint_resolver.build_headers = lambda api_key, base_url: {"Authorization": f"Bearer {api_key}"}
llm_core = types.ModuleType("src.llm_core")
llm_core.llm_call_async = _llm_call_async
core_models.ChatMessage = _ChatMessage
monkeypatch.setitem(sys.modules, "python_multipart", python_multipart)
monkeypatch.setitem(sys.modules, "core.models", core_models)
monkeypatch.setitem(sys.modules, "src.llm_core", llm_core)
monkeypatch.setitem(sys.modules, "src.endpoint_resolver", endpoint_resolver)
def _sync_chat_endpoint(webhook_routes, session_manager):
router = webhook_routes.setup_webhook_routes(
_WebhookManager(),
auth_manager=None,
session_manager=session_manager,
)
for route in router.routes:
if route.path == "/api/v1/chat":
return route.endpoint
raise AssertionError("sync chat route not found")
@pytest.mark.parametrize("base_url", [
"http://127.0.0.1:11434/v1",
"http://localhost:11434/v1",
"http://10.0.0.5/v1",
"http://169.254.169.254/latest/meta-data/",
])
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_rejects_local_private_targets(monkeypatch, base_url):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url=base_url,
model="test-model",
provider=None,
session=None,
)
with pytest.raises(webhook_routes.HTTPException) as exc:
await sync_chat(_Request(), body)
assert exc.value.status_code == 400
assert exc.value.detail == "base_url must point to a public HTTP(S) endpoint"
assert session_manager.created == []
@pytest.mark.asyncio
async def test_api_chat_direct_base_url_allows_mocked_public_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
from src import url_security
monkeypatch.setattr(
url_security,
"_resolve_hostname_ips",
lambda host: [ipaddress.ip_address("93.184.216.34")],
)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
api_key="test-key",
base_url="https://api.example.com/v1",
model="test-model",
provider=None,
session=None,
)
response = await sync_chat(_Request(), body)
assert response["response"] == "mocked response"
assert response["model"] == "test-model"
assert session_manager.created[0]["endpoint_url"] == "https://api.example.com/v1/chat/completions"
def test_api_chat_fallback_endpoint_selection_for_owned_token(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", is_enabled=False, created_at=0),
_Endpoint(owner="bob", created_at=0),
_Endpoint(owner=None, created_at=1),
_Endpoint(owner="alice", created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), "alice")
assert selected.owner == "alice"
assert selected.is_enabled is True
assert selected.created_at == 2
def test_api_chat_fallback_without_owner_uses_shared_only(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
rows = [
_Endpoint(owner="alice", created_at=0),
_Endpoint(owner=None, is_enabled=False, created_at=1),
_Endpoint(owner=None, created_at=2),
]
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
selected = webhook_routes._select_api_chat_fallback_endpoint(_DB(rows), None)
assert selected.owner is None
assert selected.is_enabled is True
assert selected.created_at == 2
@pytest.mark.asyncio
async def test_api_chat_fallback_trusts_configured_local_endpoint(monkeypatch):
webhook_routes = _load_webhook_routes_for_test(monkeypatch)
_install_sync_chat_stubs(monkeypatch)
local_endpoint = _Endpoint(
owner=None,
base_url="http://localhost:11434/v1",
api_key="configured-key",
)
db = _DB([local_endpoint])
calls = []
def _session_local():
return db
def _validate_public_http_url(url, *, max_length=2048):
calls.append(url)
raise AssertionError("configured fallback endpoint should not be publicly validated")
monkeypatch.setattr(webhook_routes, "ModelEndpoint", _ModelEndpoint)
monkeypatch.setattr(webhook_routes, "SessionLocal", _session_local)
monkeypatch.setattr(webhook_routes, "validate_public_http_url", _validate_public_http_url)
session_manager = _SessionManager()
sync_chat = _sync_chat_endpoint(webhook_routes, session_manager)
body = types.SimpleNamespace(
message="hello",
model="local-model",
api_key=None,
base_url=None,
provider=None,
session=None,
)
response = await sync_chat(_Request(owner=None), body)
assert response["response"] == "mocked response"
assert response["model"] == "local-model"
assert session_manager.created[0]["endpoint_url"] == "http://localhost:11434/v1/chat/completions"
assert calls == []
+22 -7
View File
@@ -247,10 +247,14 @@ class _Column:
def __eq__(self, value): def __eq__(self, value):
return _Predicate(lambda row: getattr(row, self.name) == value) return _Predicate(lambda row: getattr(row, self.name) == value)
def desc(self):
return self
class _ModelEndpoint: class _ModelEndpoint:
is_enabled = _Column("is_enabled") is_enabled = _Column("is_enabled")
owner = _Column("owner") owner = _Column("owner")
created_at = _Column("created_at")
class _Query: class _Query:
@@ -261,6 +265,9 @@ class _Query:
self._rows = [r for r in self._rows if all(p(r) for p in predicates)] self._rows = [r for r in self._rows if all(p(r) for p in predicates)]
return self return self
def order_by(self, *exprs):
return self
def first(self): def first(self):
return self._rows[0] if self._rows else None return self._rows[0] if self._rows else None
@@ -280,8 +287,10 @@ def _ep(name, owner, *, is_enabled=True):
def _select(rows, owner): def _select(rows, owner):
wh_mod = _import_webhook_helper() wh_mod = _import_webhook_helper()
sys.modules["core.database"].ModelEndpoint = _ModelEndpoint # _select_api_chat_fallback_endpoint uses the module-level ModelEndpoint
return wh_mod._first_enabled_endpoint(_DB(rows), owner) # (not a local import), so we patch the module attribute directly.
wh_mod.ModelEndpoint = _ModelEndpoint
return wh_mod._select_api_chat_fallback_endpoint(_DB(rows), owner)
def test_sync_chat_fallback_never_picks_another_owners_endpoint(): def test_sync_chat_fallback_never_picks_another_owners_endpoint():
@@ -310,9 +319,15 @@ def test_sync_chat_fallback_skips_disabled_owned_endpoint():
assert ep is not None and ep.name == "shared" assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_is_legacy_single_user_noop(): def test_sync_chat_fallback_null_owner_uses_shared_rows_only():
# An unresolvable/empty token owner keeps the original single-user behaviour # When no token owner is known, only null-owner (shared) endpoints are
# (owner_filter no-op): first enabled row, whatever it is. # visible — private endpoints of any user must not be returned.
rows = [_ep("first", "bob"), _ep("second", "alice")] rows = [_ep("bob-private", "bob"), _ep("shared", None)]
ep = _select(rows, None) ep = _select(rows, None)
assert ep is not None and ep.name == "first" assert ep is not None and ep.name == "shared"
def test_sync_chat_fallback_null_owner_returns_none_with_no_shared():
# No shared rows → fail closed rather than returning another user's endpoint.
rows = [_ep("bob-private", "bob"), _ep("alice-private", "alice")]
assert _select(rows, None) is None