fix(llm): route harmony thinking streams (#2449)

This commit is contained in:
nubs
2026-06-05 13:22:08 +00:00
committed by GitHub
parent 8159733c6c
commit 8354948a1c
2 changed files with 224 additions and 53 deletions
+190 -53
View File
@@ -6,8 +6,9 @@ import json
import logging import logging
import hashlib import hashlib
import threading import threading
import re
from fastapi import HTTPException from fastapi import HTTPException
from typing import Optional, Dict, List from typing import Optional, Dict, List, Tuple
from src.model_context import get_context_length, DEFAULT_CONTEXT from src.model_context import get_context_length, DEFAULT_CONTEXT
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -66,6 +67,103 @@ _host_fails: Dict[str, int] = {}
_host_health_lock = threading.Lock() _host_health_lock = threading.Lock()
_model_activity: Dict[str, float] = {} _model_activity: Dict[str, float] = {}
_HARMONY_MARKER_RE = re.compile(
r"<\|channel\|>(analysis|final)"
r"|<\|start\|>(?:assistant|system|user|tool)?"
r"|<\|message\|>"
r"|<\|end\|>"
r"|<\|return\|>"
r"|<\|call\|>"
)
_HARMONY_MARKERS = (
"<|channel|>analysis",
"<|channel|>final",
"<|start|>assistant",
"<|start|>system",
"<|start|>user",
"<|start|>tool",
"<|start|>",
"<|message|>",
"<|end|>",
"<|return|>",
"<|call|>",
)
_HARMONY_MAX_MARKER_LEN = max(len(marker) for marker in _HARMONY_MARKERS)
def _harmony_suffix_hold_len(text: str) -> int:
"""Return how many trailing chars could be the start of a harmony marker."""
limit = min(len(text), _HARMONY_MAX_MARKER_LEN - 1)
for n in range(limit, 0, -1):
suffix = text[-n:]
if any(marker.startswith(suffix) for marker in _HARMONY_MARKERS):
return n
return 0
class _HarmonyStreamRouter:
"""Route OpenAI harmony analysis/final channels without leaking markers."""
def __init__(self) -> None:
self._buf = ""
self._seen_harmony = False
self._channel: Optional[str] = None
self._in_message = False
def feed(self, text: str) -> List[Tuple[str, bool]]:
if not text:
return []
self._buf += text
return self._drain(final=False)
def flush(self) -> List[Tuple[str, bool]]:
return self._drain(final=True)
def _append_text(self, out: List[Tuple[str, bool]], text: str) -> None:
if not text:
return
if not self._seen_harmony:
out.append((text, False))
return
if self._in_message:
out.append((text, self._channel == "analysis"))
def _handle_marker(self, match: re.Match[str]) -> None:
marker = match.group(0)
self._seen_harmony = True
if marker.startswith("<|channel|>"):
self._channel = match.group(1)
self._in_message = False
elif marker == "<|message|>":
self._in_message = True
else:
self._in_message = False
if marker in {"<|end|>", "<|return|>", "<|call|>"}:
self._channel = None
def _drain(self, *, final: bool) -> List[Tuple[str, bool]]:
out: List[Tuple[str, bool]] = []
while True:
match = _HARMONY_MARKER_RE.search(self._buf)
if not match:
break
self._append_text(out, self._buf[:match.start()])
self._handle_marker(match)
self._buf = self._buf[match.end():]
hold = 0 if final else _harmony_suffix_hold_len(self._buf)
emit = self._buf if hold == 0 else self._buf[:-hold]
self._buf = "" if hold == 0 else self._buf[-hold:]
self._append_text(out, emit)
return out
def _stream_delta_event(text: str, *, thinking: bool = False) -> str:
payload = {"delta": text}
if thinking:
payload["thinking"] = True
return f"data: {json.dumps(payload)}\n\n"
def _model_activity_key(url: str, model: str) -> str: def _model_activity_key(url: str, model: str) -> str:
return f"{(url or '').strip()}|{(model or '').strip()}" return f"{(url or '').strip()}|{(model or '').strip()}"
@@ -1217,6 +1315,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
# ── Native Ollama streaming ── # ── Native Ollama streaming ──
if provider == "ollama": if provider == "ollama":
_ollama_tool_calls: List[Dict] = [] _ollama_tool_calls: List[Dict] = []
_harmony_router = _HarmonyStreamRouter()
try: try:
client = _get_http_client() client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r: async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
@@ -1236,10 +1335,11 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
message = j.get("message") or {} message = j.get("message") or {}
thinking = message.get("thinking") or "" thinking = message.get("thinking") or ""
if thinking: if thinking:
yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n' yield _stream_delta_event(thinking, thinking=True)
content = message.get("content") or "" content = message.get("content") or ""
if content: if content:
yield f'data: {json.dumps({"delta": content})}\n\n' for part, is_thinking in _harmony_router.feed(content):
yield _stream_delta_event(part, thinking=is_thinking)
for tc in message.get("tool_calls") or []: for tc in message.get("tool_calls") or []:
fn = tc.get("function") or {} fn = tc.get("function") or {}
if fn.get("name"): if fn.get("name"):
@@ -1249,12 +1349,16 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
"arguments": json.dumps(fn.get("arguments") or {}), "arguments": json.dumps(fn.get("arguments") or {}),
}) })
if j.get("done"): if j.get("done"):
for part, is_thinking in _harmony_router.flush():
yield _stream_delta_event(part, thinking=is_thinking)
if _ollama_tool_calls: if _ollama_tool_calls:
yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n' yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None: if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n' yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return return
for part, is_thinking in _harmony_router.flush():
yield _stream_delta_event(part, thinking=is_thinking)
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e: except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url) _cooled = _mark_host_dead(target_url)
@@ -1387,6 +1491,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
_first_content_sent = False _first_content_sent = False
_in_think_tag = False # True while consuming <think>…</think> content _in_think_tag = False # True while consuming <think>…</think> content
_think_open_stripped = False # opening <think> tag already removed _think_open_stripped = False # opening <think> tag already removed
_harmony_router = _HarmonyStreamRouter()
_harmony_active = False # sticky: gpt-oss harmony <|channel|> stream detected
def _emit_tool_calls(): def _emit_tool_calls():
"""Build the tool_calls event string if any were accumulated.""" """Build the tool_calls event string if any were accumulated."""
@@ -1395,6 +1501,22 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
calls = [_tc_acc[i] for i in sorted(_tc_acc)] calls = [_tc_acc[i] for i in sorted(_tc_acc)]
return f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n' return f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'
def _format_routed_content(parts: List[Tuple[str, bool]]) -> List[str]:
nonlocal _first_content_sent
events = []
for part, is_thinking in parts:
if is_thinking:
events.append(_stream_delta_event(part, thinking=True))
continue
# Some thinking backends start normal content with a stray closing
# tag. Repair only that shape; do not wrap every first token for
# model families like MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and part.lstrip().lower().startswith("</think"):
part = "<think>" + part
_first_content_sent = True
events.append(_stream_delta_event(part))
return events
try: try:
client = _get_http_client() client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r: async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
@@ -1415,6 +1537,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
if line.startswith("data:"): if line.startswith("data:"):
data = line[5:].strip() data = line[5:].strip()
if data == "[DONE]": if data == "[DONE]":
for event in _format_routed_content(_harmony_router.flush()):
yield event
tc_event = _emit_tool_calls() tc_event = _emit_tool_calls()
if tc_event: if tc_event:
yield tc_event yield tc_event
@@ -1438,6 +1562,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
_delta0.get("content") _delta0.get("content")
or _delta0.get("reasoning_content") or _delta0.get("reasoning_content")
or _delta0.get("reasoning") or _delta0.get("reasoning")
or _delta0.get("thinking")
or _delta0.get("tool_calls") or _delta0.get("tool_calls")
) )
if "usage" in j and not _delta_has_output: if "usage" in j and not _delta_has_output:
@@ -1462,59 +1587,67 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
delta = _c0.get("delta") or {} delta = _c0.get("delta") or {}
if isinstance(delta, dict): if isinstance(delta, dict):
# Text content # Text content
# Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1, Nemotron). vLLM 0.20.2 / NIM emit the field as `reasoning`; older builds use `reasoning_content`. Accept either. # Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1, Nemotron). vLLM 0.20.2 / NIM emit the field as `reasoning`; older builds use `reasoning_content`. Some OpenAI-compatible Ollama builds use `thinking`.
reasoning = delta.get("reasoning_content") or delta.get("reasoning") or "" reasoning = delta.get("reasoning_content") or delta.get("reasoning") or delta.get("thinking") or ""
if reasoning: if reasoning:
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n' yield _stream_delta_event(reasoning, thinking=True)
content = delta.get("content") or "" content = delta.get("content") or ""
if content: if content:
stripped = content.lstrip() stripped = content.lstrip()
# Auto-detect <think>…</think> in content stream. # gpt-oss harmony format (<|channel|>analysis/final): route via the harmony
# Covers Qwen3-derived models (Qwopus, QwQ forks) whose # stream router. Sticky once the first marker appears — distinct from the
# names don't match _THINKING_MODEL_PATTERNS but still # <think> path below (handled in the else, preserving #2588 behaviour).
# emit literal <think> markup via llama.cpp --jinja. if _harmony_active or "<|" in content:
if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"): _harmony_active = True
_thinking_model = True for event in _format_routed_content(_harmony_router.feed(content)):
_in_think_tag = True yield event
if _in_think_tag:
close_idx = content.lower().find("</think>")
if close_idx != -1:
# Split: up-to-</think> → thinking, remainder → content
think_part = content[:close_idx]
if not _think_open_stripped:
# Strip the opening <think[...] > from the first chunk.
# Use a dedicated flag — _first_content_sent stays False
# throughout the think block, so it must not be reused.
tag_end = think_part.lower().find(">")
if tag_end != -1:
think_part = think_part[tag_end + 1:]
_think_open_stripped = True
regular_part = content[close_idx + len("</think>"):]
_in_think_tag = False
if think_part:
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
if regular_part:
_first_content_sent = True
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
else:
# Still inside <think>: route to thinking channel
if not _think_open_stripped:
# Strip the opening <think[...] > tag (first chunk only)
tag_end = stripped.lower().find(">")
if tag_end != -1:
content = stripped[tag_end + 1:]
_think_open_stripped = True
if content:
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
else: else:
# Some thinking backends start normal content with a # Auto-detect <think>…</think> in content stream.
# stray closing tag. Repair only that shape; do not # Covers Qwen3-derived models (Qwopus, QwQ forks) whose
# wrap every first token for model families like # names don't match _THINKING_MODEL_PATTERNS but still
# MiniMax, which often stream ordinary answers. # emit literal <think> markup via llama.cpp --jinja.
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"): if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
content = "<think>" + content _thinking_model = True
_first_content_sent = True _in_think_tag = True
yield f'data: {json.dumps({"delta": content})}\n\n' if _in_think_tag:
close_idx = content.lower().find("</think>")
if close_idx != -1:
# Split: up-to-</think> → thinking, remainder → content
think_part = content[:close_idx]
if not _think_open_stripped:
# Strip the opening <think[...] > from the first chunk.
# Use a dedicated flag — _first_content_sent stays False
# throughout the think block, so it must not be reused.
tag_end = think_part.lower().find(">")
if tag_end != -1:
think_part = think_part[tag_end + 1:]
_think_open_stripped = True
regular_part = content[close_idx + len("</think>"):]
_in_think_tag = False
if think_part:
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
if regular_part:
_first_content_sent = True
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
else:
# Still inside <think>: route to thinking channel
if not _think_open_stripped:
# Strip the opening <think[...] > tag (first chunk only)
tag_end = stripped.lower().find(">")
if tag_end != -1:
content = stripped[tag_end + 1:]
_think_open_stripped = True
if content:
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
else:
# Some thinking backends start normal content with a
# stray closing tag. Repair only that shape; do not
# wrap every first token for model families like
# MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
content = "<think>" + content
_first_content_sent = True
yield f'data: {json.dumps({"delta": content})}\n\n'
# Native tool calls — accumulate across chunks # Native tool calls — accumulate across chunks
for tc in delta.get("tool_calls") or []: for tc in delta.get("tool_calls") or []:
if tc is None: if tc is None:
@@ -1563,15 +1696,19 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _tc_acc[idx]["name"], "arg_delta": func["arguments"]})}\n\n' yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _tc_acc[idx]["name"], "arg_delta": func["arguments"]})}\n\n'
elif "text" in j: elif "text" in j:
if j["text"]: if j["text"]:
yield f'data: {json.dumps({"delta": j["text"]})}\n\n' for event in _format_routed_content(_harmony_router.feed(j["text"])):
yield event
else: else:
if data.strip(): if data.strip():
yield f'data: {json.dumps({"delta": data})}\n\n' for event in _format_routed_content(_harmony_router.feed(data)):
yield event
except Exception as e: except Exception as e:
logger.error(f"Error parsing stream data: {e}") logger.error(f"Error parsing stream data: {e}")
continue continue
# End of stream (no explicit [DONE] received) # End of stream (no explicit [DONE] received)
for event in _format_routed_content(_harmony_router.flush()):
yield event
tc_event = _emit_tool_calls() tc_event = _emit_tool_calls()
if tc_event: if tc_event:
yield tc_event yield tc_event
+34
View File
@@ -172,3 +172,37 @@ def test_registered_thinking_model_stray_close_tag_repair_unchanged(monkeypatch)
assert deltas, deltas assert deltas, deltas
first = deltas[0]["delta"] first = deltas[0]["delta"]
assert first.startswith("<think>"), f"expected repair prefix, got: {first!r}" assert first.startswith("<think>"), f"expected repair prefix, got: {first!r}"
def test_thinking_field_emits_thinking_chunk(monkeypatch):
deltas = _run_stream(
"gpt-oss:20b",
[
'data: {"choices":[{"delta":{"thinking":"checking files"}}]}',
'data: {"choices":[{"delta":{"content":"visible answer"}}]}',
"data: [DONE]",
],
monkeypatch,
)
assert any(d.get("thinking") and d["delta"] == "checking files" for d in deltas), deltas
assert any((not d.get("thinking")) and d["delta"] == "visible answer" for d in deltas), deltas
def test_harmony_analysis_channel_routes_to_thinking(monkeypatch):
deltas = _run_stream(
"gpt-oss:20b",
[
'data: {"choices":[{"delta":{"content":"<|channel|>ana"}}]}',
'data: {"choices":[{"delta":{"content":"lysis<|message|>We need to inspect."}}]}',
'data: {"choices":[{"delta":{"content":"<|end|><|channel|>final<|message|>Here "}}]}',
'data: {"choices":[{"delta":{"content":"are the files.<|end|>"}}]}',
"data: [DONE]",
],
monkeypatch,
)
thinking = "".join(d["delta"] for d in deltas if d.get("thinking"))
answer = "".join(d["delta"] for d in deltas if not d.get("thinking"))
assert thinking == "We need to inspect."
assert answer == "Here are the files."
assert "<|channel|>" not in thinking + answer
assert "<|message|>" not in thinking + answer