""" model_context.py Query and cache model context window sizes from OpenAI-compatible APIs. Provides token estimation for context usage tracking. """ import logging import sys from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx logger = logging.getLogger(__name__) _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.internal"} _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.", "172.20.", "172.21.", "172.22.", "172.23.", "172.24.", "172.25.", "172.26.", "172.27.", "172.28.", "172.29.", "172.30.", "172.31.", "192.168.", "100.") def _normalize_base_for_compare(url: str) -> str: url = (url or "").strip().rstrip("/") for suffix in ("/chat/completions", "/models", "/completions", "/v1/messages"): if url.endswith(suffix): url = url[: -len(suffix)].rstrip("/") return url def _configured_endpoint_kind(url: str) -> Optional[str]: """Return configured endpoint kind for a chat/base URL when available.""" target = _normalize_base_for_compare(url) if not target: return None if "core.database" not in sys.modules: return None try: from core.database import SessionLocal, ModelEndpoint db = SessionLocal() try: rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() for ep in rows: base = _normalize_base_for_compare(getattr(ep, "base_url", "") or "") if not base: continue if target != base and not target.startswith(base + "/"): continue kind = (getattr(ep, "endpoint_kind", None) or "auto").strip().lower() if kind in ("local", "api", "proxy"): return kind if getattr(ep, "api_key", None): parsed = urlparse(base) host = (parsed.hostname or "").lower() path = (parsed.path or "").rstrip("/") if parsed.port != 11434 and "ollama" not in host and (path.endswith("/v1") or "/openai" in path): return "proxy" return "auto" finally: db.close() except Exception: return None def _is_local_endpoint(url: str) -> bool: """Check if URL points to a local/private/tailscale address.""" kind = _configured_endpoint_kind(url) if kind in ("api", "proxy"): return False if kind == "local": return True try: host = urlparse(url).hostname or "" return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) except Exception: return False # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- DEFAULT_CONTEXT = 128000 REQUEST_TIMEOUT = 5 # Known context windows for major API models (used as fallback when /models # endpoint doesn't report context_length). # Substring matching — use the shortest unique prefix so variants get caught. KNOWN_CONTEXT_WINDOWS = { # --- Anthropic --- 'claude-sonnet-4-5': 200000, 'claude-sonnet-4-6': 200000, 'claude-sonnet-4': 200000, 'claude-opus-4': 200000, 'claude-haiku-4': 200000, 'claude-haiku-3-5': 200000, 'claude-3-5-sonnet': 200000, 'claude-3-5-haiku': 200000, 'claude-3-opus': 200000, 'claude-3-sonnet': 200000, 'claude-3-haiku': 200000, # --- OpenAI --- 'gpt-5': 400000, 'gpt-4.1': 1047576, 'gpt-4.1-mini': 1047576, 'gpt-4.1-nano': 1047576, 'gpt-4o': 128000, 'gpt-4o-mini': 128000, 'gpt-4-turbo': 128000, 'gpt-4': 8192, 'gpt-3.5-turbo': 16385, 'o1': 200000, 'o1-mini': 128000, 'o1-pro': 200000, 'o3': 200000, 'o3-mini': 200000, 'o4-mini': 200000, # --- DeepSeek --- 'deepseek-chat': 64000, 'deepseek-coder': 64000, 'deepseek-reasoner': 64000, 'deepseek-r1': 64000, 'deepseek-v3': 64000, 'deepseek-v2': 64000, # --- Google --- 'gemini-2.5-pro': 1048576, 'gemini-2.5-flash': 1048576, 'gemini-2.0-flash': 1048576, 'gemini-1.5-pro': 1048576, 'gemini-1.5-flash': 1048576, 'gemma-4': 262144, 'gemma-3': 128000, 'gemma-2': 8192, # --- Mistral --- 'mistral-large': 128000, 'mistral-medium': 32000, 'mistral-small': 32000, 'mistral-nemo': 128000, 'mistral-7b': 32000, 'mixtral': 32000, 'codestral': 32000, 'pixtral': 128000, # --- xAI --- 'grok-4': 131072, 'grok-3': 131072, 'grok-2': 131072, # --- Meta / Llama --- 'llama-4': 1048576, 'llama-3.3': 131072, 'llama-3.2': 131072, 'llama-3.1': 131072, 'llama-3': 131072, # --- Qwen --- 'qwen3': 131072, 'qwen2.5': 131072, 'qwen2': 32768, 'qwq': 32768, # --- Cohere --- 'command-r-plus': 128000, 'command-r': 128000, 'command-a': 256000, # --- Perplexity --- 'sonar-pro': 200000, 'sonar': 128000, # --- MiniMax --- 'minimax': 1000000, # --- Moonshot / Kimi --- 'moonshot': 128000, 'kimi': 128000, # --- Microsoft --- 'phi-4': 16000, 'phi-3': 128000, # --- Nvidia --- 'nemotron': 131072, # --- Yi --- 'yi-large': 32768, 'yi-1.5': 16384, # --- 01.ai --- 'yi-lightning': 16384, # --- Nous --- 'hermes': 131072, 'nous-hermes': 131072, # --- Open community --- 'dolphin': 32768, 'mythomax': 4096, 'wizard': 32768, 'openchat': 8192, 'solar': 32768, } # --------------------------------------------------------------------------- # Cache # --------------------------------------------------------------------------- _context_cache: Dict[Tuple[str, str], int] = {} def get_context_length(endpoint_url: str, model: str) -> int: """Get the context window size for a model. Queries /v1/models on the endpoint and looks for context_length or context_window fields. Caches result per (endpoint, model). Falls back to DEFAULT_CONTEXT if unavailable. """ configured_kind = _configured_endpoint_kind(endpoint_url) is_local = _is_local_endpoint(endpoint_url) # Key on (endpoint_url, model): the same model id can be served by two # different remote endpoints with different real context windows (e.g. a # capped proxy vs. the full provider), so caching by model id alone would # serve one endpoint's window for the other (issue #2603). cache_key = (endpoint_url, model) if not is_local and cache_key in _context_cache: return _context_cache[cache_key] ctx = _query_context_length(endpoint_url, model) # Only cache non-default values to allow retry on next request. # Local endpoints can restart with a different --max-model-len while keeping # the same model id, so always re-query them instead of serving stale cache. if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")): _context_cache[cache_key] = ctx logger.info(f"Context length for {model}: {ctx}") return ctx def _lookup_known(model: str) -> Optional[int]: """Check known context windows by substring match. Picks the LONGEST matching key so a short key never shadows a more specific one. Without this, 'o1' (200k) precedes 'o1-mini' (128k) in the table and a first-match return would report o1-mini's window as 200k. """ name = model.lower() basename = name.split("/")[-1] if "/" in name else name basename = basename.split(":")[0] # strip :free, :extended etc. best_key: Optional[str] = None best_ctx: Optional[int] = None for key, ctx in KNOWN_CONTEXT_WINDOWS.items(): if key in basename or key in name: if best_key is None or len(key) > len(best_key): best_key, best_ctx = key, ctx return best_ctx def _query_context_length(endpoint_url: str, model: str) -> int: """Query the model API for context length.""" known = _lookup_known(model) api_ctx = None configured_kind = _configured_endpoint_kind(endpoint_url) # Large OpenAI-compatible proxies can make /models expensive. If the # endpoint is explicitly configured as API/proxy, prefer known context # metadata (or the default) over downloading the full catalog. if configured_kind in ("api", "proxy"): if known: logger.info(f"Using known context window for {model}: {known}") return known return DEFAULT_CONTEXT # Try llama.cpp /slots endpoint first — reports actual serving context if _is_local_endpoint(endpoint_url): try: base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0] r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT) if r.is_success: slots = r.json() if isinstance(slots, list) and slots: n_ctx = slots[0].get("n_ctx") if n_ctx and isinstance(n_ctx, int) and n_ctx > 0: logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}") return n_ctx except Exception: pass # GitHub Copilot's /models requires auth + X-GitHub-Api-Version headers that # aren't available here; an unauthenticated probe just 400s. All Copilot # picker models are major API models covered by the known-context table, so # rely on that instead of a doomed network call. from src.copilot import is_copilot_base if is_copilot_base(endpoint_url): if known: logger.info(f"Using known context window for {model}: {known}") return known or DEFAULT_CONTEXT from src.endpoint_resolver import build_models_url models_url = build_models_url(endpoint_url) try: r = httpx.get(models_url, timeout=REQUEST_TIMEOUT) if r.is_success: data = r.json() models_list = data.get("data") or [] for m in models_list: mid = m.get("id", "") if mid == model or mid.split("/")[-1] == model.split("/")[-1]: for field in ( "context_length", "context_window", "max_model_len", "max_context_length", "max_seq_len", ): val = m.get(field) if val and isinstance(val, (int, float)) and val > 0: api_ctx = int(val) break if not api_ctx: meta = m.get("meta") or m.get("model_extra") or {} if isinstance(meta, dict): # n_ctx is the actual serving context (set via -c flag in llama.cpp) for field in ("n_ctx", "context_length", "context_window", "max_model_len"): val = meta.get(field) if val and isinstance(val, (int, float)) and val > 0: api_ctx = int(val) break break except Exception as e: logger.debug(f"Failed to query context length for {model}: {e}") # For local/self-hosted endpoints, trust the API value (user set --max-model-len) # For cloud APIs, use the larger value (API can report low defaults) if api_ctx and known: _is_local = _is_local_endpoint(endpoint_url) if _is_local and api_ctx < known: logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value") return api_ctx result = max(api_ctx, known) if api_ctx < known: logger.info(f"API reported {api_ctx} for {model}, using known {known} instead") return result if api_ctx: return api_ctx if known: logger.info(f"Using known context window for {model}: {known}") return known return DEFAULT_CONTEXT def estimate_tokens(messages: List[Dict]) -> int: """Rough token estimate for a list of messages. Uses chars * 0.3 which is closer to real BPE tokenizer output than the commonly-cited chars/4 (which underestimates by ~20-30%). Also adds ~4 tokens per message for role/formatting overhead, and counts assistant tool_calls (name + arguments) — a tool-only turn carries content=None with the real payload in tool_calls, so ignoring them made the estimate (and the compaction/trim gates that rely on it) blind to large tool arguments. """ total = 0 for msg in messages: total += 4 # per-message overhead (role, separators) content = msg.get("content", "") if isinstance(content, str): total += int(len(content) * 0.3) elif isinstance(content, list): for item in content: if isinstance(item, dict) and item.get("type") == "text": total += int(len(item.get("text", "")) * 0.3) # Tool calls carry real payload too: a tool-only assistant turn is stored # with content=None and the actual args (e.g. a create_document body) in # tool_calls[].function.arguments. Ignoring them made large tool arguments # read as ~0 tokens, so the compaction/trim gates missed genuine overflow. tool_calls = msg.get("tool_calls") if isinstance(tool_calls, list): for tc in tool_calls: if not isinstance(tc, dict): continue fn = tc.get("function") if isinstance(tc.get("function"), dict) else tc name = fn.get("name", "") or "" args = fn.get("arguments", "") or "" if not isinstance(args, str): args = str(args) # some shapes store arguments as a dict total += 4 # per tool-call overhead (id, type, wrapper) total += int((len(str(name)) + len(args)) * 0.3) return total