Files
odysseus/src/model_context.py
T
Ocean Bennett e7c1d75884 fix(models): query v1 models for llama-server endpoints (#3380)
* fix(models): query v1 models for llama-server endpoints

* test(models): accept owner kwargs in llama-server regression
2026-06-09 01:09:02 +02:00

395 lines
14 KiB
Python

"""
model_context.py
Query and cache model context window sizes from OpenAI-compatible APIs.
Provides token estimation for context usage tracking.
"""
import logging
import sys
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
_LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.internal"}
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
"172.30.", "172.31.", "192.168.", "100.")
def _normalize_base_for_compare(url: str) -> str:
url = (url or "").strip().rstrip("/")
for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"):
if url.endswith(suffix):
url = url[: -len(suffix)].rstrip("/")
return url
def _configured_endpoint_kind(url: str) -> Optional[str]:
    """Return the configured endpoint kind for a chat/base URL when available.

    ``url`` is normalized with ``_normalize_base_for_compare``. It is then matched
    against the normalized ``base_url`` of every enabled ``ModelEndpoint`` row,
    either exactly or as a path prefix. When several rows match, the longest
    (most specific) base wins. The query has no ORDER BY, so taking the first
    match would let a broad ``http://host`` row shadow a more specific
    ``http://host/v1`` row depending on row order.

    The database is consulted only if ``core.database`` is already in
    ``sys.modules``; this helper never triggers that import by itself.

    Returns:
        - ``"local"``, ``"api"`` or ``"proxy"`` when the matching row sets that
          kind explicitly.
        - ``"proxy"`` when the row has an API key and its base path ends in
          ``/v1`` or contains ``/openai``, provided the port is not 11434 and
          "ollama" is not in the host name.
        - ``"auto"`` for any other match.
        - ``None`` when the URL is empty, the DB module is not loaded, nothing
          matches, or the lookup raises.
    """
    target = _normalize_base_for_compare(url)
    if not target:
        return None
    if "core.database" not in sys.modules:
        return None
    try:
        from core.database import SessionLocal, ModelEndpoint
        db = SessionLocal()
        try:
            rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()  # noqa: E712 - SQLAlchemy column expression
            best_ep = None
            best_base = ""
            for ep in rows:
                base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "")
                if not base:
                    continue
                if target != base and not target.startswith(base + "/"):
                    continue
                # Prefer the most specific base; equal lengths keep the first row seen.
                if best_ep is None or len(base) > len(best_base):
                    best_ep, best_base = ep, base
            if best_ep is None:
                return None
            kind = (getattr(best_ep, "endpoint_kind", None) or "auto").strip().lower()
            if kind in ("local", "api", "proxy"):
                return kind
            if getattr(best_ep, "api_key", None):
                # A keyed OpenAI-style base (.../v1 or .../openai) is treated as
                # a proxy unless it looks like Ollama (port 11434 / "ollama" host).
                parsed = urlparse(best_base)
                host = (parsed.hostname or "").lower()
                path = (parsed.path or "").rstrip("/")
                if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path):
                    return "proxy"
            return "auto"
        finally:
            db.close()
    except Exception as exc:
        # Best-effort lookup: a DB/config problem must not break callers.
        logger.debug(f"Endpoint kind lookup failed for {url}: {exc}")
        return None
def _is_local_endpoint(url: str) -> bool:
    """Return True when ``url`` should be treated as a local/self-hosted endpoint.

    An endpoint kind configured in the database takes precedence: "local"
    gives True, and "api"/"proxy" gives False. Otherwise the URL's hostname
    decides: it counts as local if it is one of the known loopback/Docker host
    names or starts with a private/Tailscale address prefix. Unparseable URLs
    are treated as not local.
    """
    configured = _configured_endpoint_kind(url)
    if configured == "local":
        return True
    if configured in ("api", "proxy"):
        return False
    try:
        hostname = urlparse(url).hostname or ""
        is_known_host = hostname in _LOCAL_HOSTS
        return is_known_host or hostname.startswith(_PRIVATE_PREFIXES)
    except Exception:
        return False
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Context window (tokens) returned when neither the endpoint nor the
# known-model table yields a value.
DEFAULT_CONTEXT = 128_000
# Timeout passed as ``timeout=`` to the httpx /slots and /models probes.
REQUEST_TIMEOUT = 5
# Known context windows for major API models (used as fallback when /models
# endpoint doesn't report context_length).
# Keys are matched as substrings of the lowercased model id, and the LONGEST
# matching key wins (see _lookup_known). A family key such as 'gpt-4' is
# therefore overridden for specific variants by adding a longer key such as
# 'gpt-4-32k'.
KNOWN_CONTEXT_WINDOWS = {
    # --- Anthropic ---
    'claude-sonnet-4-5': 200000,
    'claude-sonnet-4-6': 200000,
    'claude-sonnet-4': 200000,
    'claude-opus-4': 200000,
    'claude-haiku-4': 200000,
    'claude-haiku-3-5': 200000,
    'claude-3-5-sonnet': 200000,
    'claude-3-5-haiku': 200000,
    'claude-3-opus': 200000,
    'claude-3-sonnet': 200000,
    'claude-3-haiku': 200000,
    # --- OpenAI ---
    'gpt-5': 400000,
    'gpt-4.5': 128000,
    'gpt-4.1': 1047576,
    'gpt-4.1-mini': 1047576,
    'gpt-4.1-nano': 1047576,
    'gpt-4o': 128000,
    'gpt-4o-mini': 128000,
    'gpt-4-turbo': 128000,
    # Dated Turbo previews, vision and 32k variants: without these keys they
    # fall through to the 8k 'gpt-4' entry below.
    'gpt-4-1106': 128000,
    'gpt-4-0125': 128000,
    'gpt-4-vision': 128000,
    'gpt-4-32k': 32768,
    'gpt-4': 8192,
    'gpt-3.5-turbo': 16385,
    'o1': 200000,
    'o1-mini': 128000,
    'o1-preview': 128000,
    'o1-pro': 200000,
    'o3': 200000,
    'o3-mini': 200000,
    'o4-mini': 200000,
    # --- DeepSeek ---
    'deepseek-chat': 64000,
    'deepseek-coder': 64000,
    'deepseek-reasoner': 64000,
    'deepseek-r1': 64000,
    'deepseek-v3': 64000,
    'deepseek-v2': 64000,
    # --- Google ---
    'gemini-2.5-pro': 1048576,
    'gemini-2.5-flash': 1048576,
    'gemini-2.0-flash': 1048576,
    'gemini-1.5-pro': 1048576,
    'gemini-1.5-flash': 1048576,
    'gemma-4': 262144,
    'gemma-3': 128000,
    'gemma-2': 8192,
    # --- Mistral ---
    'mistral-large': 128000,
    'mistral-medium': 32000,
    'mistral-small': 32000,
    'mistral-nemo': 128000,
    'mistral-7b': 32000,
    'mixtral': 32000,
    'codestral': 32000,
    'pixtral': 128000,
    # --- xAI ---
    'grok-4': 131072,
    'grok-3': 131072,
    'grok-2': 131072,
    # --- Meta / Llama ---
    'llama-4': 1048576,
    'llama-3.3': 131072,
    'llama-3.2': 131072,
    'llama-3.1': 131072,
    'llama-3': 131072,
    # --- Qwen ---
    'qwen3': 131072,
    'qwen2.5': 131072,
    'qwen2': 32768,
    'qwq': 32768,
    # --- Cohere ---
    'command-r-plus': 128000,
    'command-r': 128000,
    'command-a': 256000,
    # --- Perplexity ---
    'sonar-pro': 200000,
    'sonar': 128000,
    # --- MiniMax ---
    'minimax': 1000000,
    # --- Moonshot / Kimi ---
    'moonshot': 128000,
    'kimi': 128000,
    # --- Microsoft ---
    'phi-4': 16000,
    'phi-3': 128000,
    # --- Nvidia ---
    'nemotron': 131072,
    # --- Yi ---
    'yi-large': 32768,
    'yi-1.5': 16384,
    # --- 01.ai ---
    'yi-lightning': 16384,
    # --- Nous ---
    'hermes': 131072,
    'nous-hermes': 131072,
    # --- Open community ---
    'dolphin': 32768,
    'mythomax': 4096,
    'wizard': 32768,
    'openchat': 8192,
    'solar': 32768,
}
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
_context_cache: Dict[Tuple[str, str], int] = {}
def get_context_length(endpoint_url: str, model: str) -> int:
    """Get the context window size for a model.

    The actual lookup is delegated to ``_query_context_length``. That covers
    llama.cpp /slots, the endpoint's /v1/models listing (context_length /
    context_window style fields) and the known-model table. Results for
    non-local endpoints are cached per (endpoint, model); local endpoints are
    re-queried on every call. Falls back to DEFAULT_CONTEXT if unavailable.
    """
    kind = _configured_endpoint_kind(endpoint_url)
    local = _is_local_endpoint(endpoint_url)
    # Key on (endpoint_url, model): the same model id can be served by two
    # different remote endpoints with different real context windows (e.g. a
    # capped proxy vs. the full provider), so caching by model id alone would
    # serve one endpoint's window for the other (issue #2603).
    key = (endpoint_url, model)
    use_cache = not local
    if use_cache and key in _context_cache:
        return _context_cache[key]
    ctx = _query_context_length(endpoint_url, model)
    # Cache only non-default values, so a failed lookup is retried on the next
    # request. Explicit api/proxy endpoints are the exception: their lookup
    # does no network call, so a default result is final. Local endpoints can
    # restart with a different --max-model-len while keeping the same model id,
    # so they are never cached.
    worth_caching = ctx != DEFAULT_CONTEXT or kind in ("api", "proxy")
    if use_cache and worth_caching:
        _context_cache[key] = ctx
    logger.info(f"Context length for {model}: {ctx}")
    return ctx
def _lookup_known(model: str) -> Optional[int]:
    """Return the known context window for ``model``, or None if unknown.

    Every KNOWN_CONTEXT_WINDOWS key is tested as a substring of the lowercased
    model id, and of its bare name (provider ``/`` prefix and ``:tag`` suffix
    removed). Among matching keys the LONGEST one is used, so a short key never
    shadows a more specific one. For example, 'o1' (200k) must not win over
    'o1-mini' (128k) for "o1-mini". Equal-length matches resolve to the key
    listed first in the table.
    """
    lowered = model.lower()
    # Bare model name: drop any provider prefix and a :free / :extended style tag.
    bare = lowered.rsplit("/", 1)[-1].split(":", 1)[0]
    matches = [key for key in KNOWN_CONTEXT_WINDOWS if key in bare or key in lowered]
    if not matches:
        return None
    # max() returns the first maximal element, matching "first longest key wins".
    return KNOWN_CONTEXT_WINDOWS[max(matches, key=len)]
def _positive_int(value) -> Optional[int]:
    """Return ``value`` as a positive int, or None if it is not a usable count.

    Accepts ints and finite floats (truncated toward zero). Rejects bools,
    non-numbers, NaN/inf, and anything that is not > 0 after truncation.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        ctx = int(value)
    except (OverflowError, ValueError):  # float("inf") / float("nan")
        return None
    return ctx if ctx > 0 else None


def _llamacpp_server_root(endpoint_url: str) -> str:
    """Return the server root under which llama.cpp exposes /slots.

    The normalized URL is cut at its first "/v1", so
    ``http://h:8080/v1/chat/completions`` becomes ``http://h:8080``. URLs
    without "/v1" keep their normalized base, so
    ``http://h:8080/chat/completions`` also becomes ``http://h:8080``. By
    contrast, a naive ``rsplit("/", 1)`` would turn a bare ``http://h:8080``
    into ``http:/``.
    """
    base = _normalize_base_for_compare(endpoint_url)
    root, found, _ = base.partition("/v1")
    return root if found else base


def _probe_llamacpp_slots(endpoint_url: str, model: str) -> Optional[int]:
    """Return ``n_ctx`` from llama.cpp's /slots endpoint, or None (best-effort).

    /slots reports the context the server is actually serving with. Any
    network/JSON error or unexpected payload shape yields None.
    """
    try:
        r = httpx.get(f"{_llamacpp_server_root(endpoint_url)}/slots", timeout=REQUEST_TIMEOUT)
        if not r.is_success:
            return None
        slots = r.json()
    except Exception as e:
        logger.debug(f"llama.cpp /slots probe failed for {model}: {e}")
        return None
    if not isinstance(slots, list) or not slots or not isinstance(slots[0], dict):
        return None
    n_ctx = slots[0].get("n_ctx")
    if isinstance(n_ctx, int) and not isinstance(n_ctx, bool) and n_ctx > 0:
        logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
        return n_ctx
    return None


def _context_from_model_entry(entry: Dict) -> Optional[int]:
    """Extract a context window from one /models ``data`` entry, or None.

    Common top-level fields are checked first. After that, a nested ``meta`` /
    ``model_extra`` dict is checked.
    """
    for field in (
        "context_length",
        "context_window",
        "max_model_len",
        "max_context_length",
        "max_seq_len",
    ):
        ctx = _positive_int(entry.get(field))
        if ctx:
            return ctx
    meta = entry.get("meta") or entry.get("model_extra") or {}
    if isinstance(meta, dict):
        # n_ctx is the actual serving context (set via -c flag in llama.cpp)
        for field in ("n_ctx", "context_length", "context_window", "max_model_len"):
            ctx = _positive_int(meta.get(field))
            if ctx:
                return ctx
    return None


def _query_models_context(endpoint_url: str, model: str) -> Optional[int]:
    """Look up ``model`` in the endpoint's /models listing and return its context.

    The first entry is used whose id equals ``model``, or whose last ``/``
    segment equals the model's last segment. Network, HTTP, JSON and payload
    shape problems are logged at debug level and yield None.
    """
    from src.endpoint_resolver import build_models_url
    models_url = build_models_url(endpoint_url)
    wanted_tail = model.split("/")[-1]
    try:
        r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
        if not r.is_success:
            return None
        data = r.json()
        models_list = (data.get("data") if isinstance(data, dict) else None) or []
        if not isinstance(models_list, list):
            return None
        for m in models_list:
            if not isinstance(m, dict):
                continue
            mid = m.get("id") or ""
            if not isinstance(mid, str):
                continue
            if mid == model or mid.split("/")[-1] == wanted_tail:
                return _context_from_model_entry(m)
    except Exception as e:
        logger.debug(f"Failed to query context length for {model}: {e}")
    return None


def _query_context_length(endpoint_url: str, model: str) -> int:
    """Query the model API for context length (uncached).

    Resolution order:
      1. Endpoints configured as "api"/"proxy": the known-model table, else
         DEFAULT_CONTEXT. No network call is made, because large
         OpenAI-compatible proxies can make /models expensive.
      2. Local endpoints: llama.cpp /slots (actual serving context).
      3. GitHub Copilot bases: the known-model table, else DEFAULT_CONTEXT.
      4. The endpoint's /models listing, reconciled with the known table.
         Local endpoints trust a smaller API value, since the user may have set
         --max-model-len. Remote endpoints take the larger of the two, since
         APIs can report low defaults.
      5. The known-model table, else DEFAULT_CONTEXT.
    """
    known = _lookup_known(model)
    configured_kind = _configured_endpoint_kind(endpoint_url)
    if configured_kind in ("api", "proxy"):
        if known:
            logger.info(f"Using known context window for {model}: {known}")
            return known
        return DEFAULT_CONTEXT
    # _is_local_endpoint hits the endpoint DB. Resolve it once and reuse the
    # result for the /slots probe and for the reconciliation below.
    is_local = _is_local_endpoint(endpoint_url)
    if is_local:
        n_ctx = _probe_llamacpp_slots(endpoint_url, model)
        if n_ctx:
            return n_ctx
    # GitHub Copilot's /models requires auth + X-GitHub-Api-Version headers that
    # aren't available here; an unauthenticated probe just 400s. All Copilot
    # picker models are major API models covered by the known-context table, so
    # rely on that instead of a doomed network call.
    from src.copilot import is_copilot_base
    if is_copilot_base(endpoint_url):
        if known:
            logger.info(f"Using known context window for {model}: {known}")
        return known or DEFAULT_CONTEXT
    api_ctx = _query_models_context(endpoint_url, model)
    if api_ctx and known:
        # For local/self-hosted endpoints, trust the API value (user set --max-model-len)
        if is_local and api_ctx < known:
            logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
            return api_ctx
        # For cloud APIs, use the larger value (API can report low defaults)
        if api_ctx < known:
            logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
        return max(api_ctx, known)
    if api_ctx:
        return api_ctx
    if known:
        logger.info(f"Using known context window for {model}: {known}")
        return known
    return DEFAULT_CONTEXT
def estimate_tokens(messages: List[Dict]) -> int:
    """Rough token estimate for a list of messages.

    Uses chars * 0.3, which is closer to real BPE tokenizer output than the
    commonly-cited chars/4 (which underestimates by ~20-30%). Adds ~4 tokens
    per message for role/formatting overhead and ~4 per tool call.

    Assistant tool_calls (name + arguments) are counted too. A tool-only turn
    carries content=None with the real payload in tool_calls, so ignoring them
    made the estimate (and the compaction/trim gates that rely on it) blind to
    large tool arguments.

    Content may be a string or a list of parts; only ``{"type": "text"}``
    parts are counted. Null or non-string ``text`` / ``name`` / ``arguments``
    values are tolerated instead of raising.

    Args:
        messages: Chat messages as dicts; ``content`` and optional
            ``tool_calls`` keys are read. ``None`` is treated as empty.

    Returns:
        Estimated token count (>= 0).
    """
    total = 0
    for msg in messages or ():
        total += 4  # per-message overhead (role, separators)
        content = msg.get("content", "")
        if isinstance(content, str):
            total += int(len(content) * 0.3)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    # A text part may carry "text": None; count it as empty
                    # rather than crashing in len().
                    text = item.get("text") or ""
                    total += int(len(str(text)) * 0.3)
        # Tool calls carry real payload too: a tool-only assistant turn is stored
        # with content=None and the actual args (e.g. a create_document body) in
        # tool_calls[].function.arguments.
        tool_calls = msg.get("tool_calls")
        if isinstance(tool_calls, list):
            for tc in tool_calls:
                if not isinstance(tc, dict):
                    continue
                # OpenAI shape nests name/arguments under "function"; accept a flat dict too.
                fn = tc.get("function") if isinstance(tc.get("function"), dict) else tc
                name = fn.get("name", "") or ""
                args = fn.get("arguments", "") or ""
                if not isinstance(args, str):
                    args = str(args)  # some shapes store arguments as a dict
                total += 4  # per tool-call overhead (id, type, wrapper)
                total += int((len(str(name)) + len(args)) * 0.3)
    return total