fix(llm): route harmony thinking streams (#2449)

This commit is contained in:
nubs
2026-06-05 13:22:08 +00:00
committed by GitHub
parent 8159733c6c
commit 8354948a1c
2 changed files with 224 additions and 53 deletions
+190 -53
View File
@@ -6,8 +6,9 @@ import json
import logging
import hashlib
import threading
import re
from fastapi import HTTPException
from typing import Optional, Dict, List
from typing import Optional, Dict, List, Tuple
from src.model_context import get_context_length, DEFAULT_CONTEXT
from urllib.parse import urlparse
@@ -66,6 +67,103 @@ _host_fails: Dict[str, int] = {}
_host_health_lock = threading.Lock()
_model_activity: Dict[str, float] = {}
# Compiled pattern that matches one harmony control marker at a time.
# Alternatives are tried left to right; only the channel alternative has a
# capturing group, so group(1) is None for every other marker.
_HARMONY_MARKER_RE = re.compile(
# Channel selector: group(1) captures the channel name, "analysis" or "final".
r"<\|channel\|>(analysis|final)"
# Start-of-message marker; the role name after it is optional.
r"|<\|start\|>(?:assistant|system|user|tool)?"
# Marker that introduces the message body.
r"|<\|message\|>"
# Message terminators.
r"|<\|end\|>"
r"|<\|return\|>"
r"|<\|call\|>"
)
# Literal spellings of the markers above. _harmony_suffix_hold_len() compares
# trailing text against these with str.startswith to detect a marker that has
# only partially arrived.
_HARMONY_MARKERS = (
"<|channel|>analysis",
"<|channel|>final",
"<|start|>assistant",
"<|start|>system",
"<|start|>user",
"<|start|>tool",
"<|start|>",
"<|message|>",
"<|end|>",
"<|return|>",
"<|call|>",
)
# Length of the longest literal marker; bounds how many trailing characters
# _harmony_suffix_hold_len() ever has to inspect.
_HARMONY_MAX_MARKER_LEN = max(len(marker) for marker in _HARMONY_MARKERS)
def _harmony_suffix_hold_len(text: str) -> int:
    """Return the length of the longest tail of *text* that begins a marker.

    Only tails strictly shorter than the longest entry in ``_HARMONY_MARKERS``
    are considered. Lengths are tried from longest to shortest and the first
    one that is a prefix of any marker wins; ``0`` means nothing needs to be
    held back.
    """
    longest = min(len(text), _HARMONY_MAX_MARKER_LEN - 1)
    # Lazy generator: evaluation stops at the first (i.e. longest) hit.
    candidates = (
        size
        for size in range(longest, 0, -1)
        if any(known.startswith(text[-size:]) for known in _HARMONY_MARKERS)
    )
    return next(candidates, 0)
class _HarmonyStreamRouter:
    """Incrementally split an OpenAI harmony stream into routed text parts.

    Text is pushed in with :meth:`feed` and the remainder is released with
    :meth:`flush`. Both return a list of ``(text, is_thinking)`` tuples with
    the harmony markers removed. Until the first marker is seen, text passes
    through unchanged with ``is_thinking=False``. Afterwards only text that
    follows a ``<|message|>`` marker is emitted, flagged as thinking when the
    current channel is ``analysis``; text outside a message body is dropped.
    """

    def __init__(self) -> None:
        # Pending input that has not been routed yet (may end in a partial marker).
        self._buf = ""
        # Becomes True once any harmony marker has been matched.
        self._seen_harmony = False
        # Channel named by the most recent channel marker, if any.
        self._channel: Optional[str] = None
        # True between a <|message|> marker and the next marker of any kind.
        self._in_message = False

    def feed(self, text: str) -> List[Tuple[str, bool]]:
        """Buffer *text* and return whatever parts can already be routed."""
        if text:
            self._buf += text
            return self._drain(final=False)
        return []

    def flush(self) -> List[Tuple[str, bool]]:
        """Route everything still buffered, holding nothing back."""
        return self._drain(final=True)

    def _append_text(self, out: List[Tuple[str, bool]], text: str) -> None:
        """Append *text* to *out* with its thinking flag, or drop it."""
        if not text:
            return
        if not self._seen_harmony:
            # No marker seen so far: treat the stream as ordinary content.
            out.append((text, False))
        elif self._in_message:
            out.append((text, self._channel == "analysis"))
        # Otherwise the text sits outside a message body and is discarded.

    def _handle_marker(self, match: re.Match[str]) -> None:
        """Update channel/message state for one matched marker."""
        token = match.group(0)
        self._seen_harmony = True
        if token == "<|message|>":
            self._in_message = True
            return
        # Every other marker ends (or stays outside of) a message body.
        self._in_message = False
        if token.startswith("<|channel|>"):
            self._channel = match.group(1)
        elif token in ("<|end|>", "<|return|>", "<|call|>"):
            self._channel = None

    def _drain(self, *, final: bool) -> List[Tuple[str, bool]]:
        """Consume complete markers from the buffer and route the text around them.

        When *final* is false, a trailing fragment that could still grow into
        a marker stays in the buffer for the next call.
        """
        routed: List[Tuple[str, bool]] = []
        hit = _HARMONY_MARKER_RE.search(self._buf)
        while hit is not None:
            self._append_text(routed, self._buf[:hit.start()])
            self._handle_marker(hit)
            self._buf = self._buf[hit.end():]
            hit = _HARMONY_MARKER_RE.search(self._buf)
        held = 0 if final else _harmony_suffix_hold_len(self._buf)
        if held:
            ready, self._buf = self._buf[:-held], self._buf[-held:]
        else:
            ready, self._buf = self._buf, ""
        self._append_text(routed, ready)
        return routed
def _stream_delta_event(text: str, *, thinking: bool = False) -> str:
payload = {"delta": text}
if thinking:
payload["thinking"] = True
return f"data: {json.dumps(payload)}\n\n"
def _model_activity_key(url: str, model: str) -> str:
return f"{(url or '').strip()}|{(model or '').strip()}"
@@ -1217,6 +1315,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
# ── Native Ollama streaming ──
if provider == "ollama":
_ollama_tool_calls: List[Dict] = []
_harmony_router = _HarmonyStreamRouter()
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
@@ -1236,10 +1335,11 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
message = j.get("message") or {}
thinking = message.get("thinking") or ""
if thinking:
yield f'data: {json.dumps({"delta": thinking, "thinking": True})}\n\n'
yield _stream_delta_event(thinking, thinking=True)
content = message.get("content") or ""
if content:
yield f'data: {json.dumps({"delta": content})}\n\n'
for part, is_thinking in _harmony_router.feed(content):
yield _stream_delta_event(part, thinking=is_thinking)
for tc in message.get("tool_calls") or []:
fn = tc.get("function") or {}
if fn.get("name"):
@@ -1249,12 +1349,16 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
"arguments": json.dumps(fn.get("arguments") or {}),
})
if j.get("done"):
for part, is_thinking in _harmony_router.flush():
yield _stream_delta_event(part, thinking=is_thinking)
if _ollama_tool_calls:
yield f'data: {json.dumps({"type": "tool_calls", "calls": _ollama_tool_calls})}\n\n'
if j.get("prompt_eval_count") is not None or j.get("eval_count") is not None:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": j.get("prompt_eval_count", 0), "output_tokens": j.get("eval_count", 0)}})}\n\n'
yield "data: [DONE]\n\n"
return
for part, is_thinking in _harmony_router.flush():
yield _stream_delta_event(part, thinking=is_thinking)
yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url)
@@ -1387,6 +1491,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
_first_content_sent = False
_in_think_tag = False # True while consuming <think>…</think> content
_think_open_stripped = False # opening <think> tag already removed
_harmony_router = _HarmonyStreamRouter()
_harmony_active = False # sticky: gpt-oss harmony <|channel|> stream detected
def _emit_tool_calls():
"""Build the tool_calls event string if any were accumulated."""
@@ -1395,6 +1501,22 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
calls = [_tc_acc[i] for i in sorted(_tc_acc)]
return f'data: {json.dumps({"type": "tool_calls", "calls": calls})}\n\n'
def _format_routed_content(parts: List[Tuple[str, bool]]) -> List[str]:
nonlocal _first_content_sent
events = []
for part, is_thinking in parts:
if is_thinking:
events.append(_stream_delta_event(part, thinking=True))
continue
# Some thinking backends start normal content with a stray closing
# tag. Repair only that shape; do not wrap every first token for
# model families like MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and part.lstrip().lower().startswith("</think"):
part = "<think>" + part
_first_content_sent = True
events.append(_stream_delta_event(part))
return events
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
@@ -1415,6 +1537,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
if line.startswith("data:"):
data = line[5:].strip()
if data == "[DONE]":
for event in _format_routed_content(_harmony_router.flush()):
yield event
tc_event = _emit_tool_calls()
if tc_event:
yield tc_event
@@ -1438,6 +1562,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
_delta0.get("content")
or _delta0.get("reasoning_content")
or _delta0.get("reasoning")
or _delta0.get("thinking")
or _delta0.get("tool_calls")
)
if "usage" in j and not _delta_has_output:
@@ -1462,59 +1587,67 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
delta = _c0.get("delta") or {}
if isinstance(delta, dict):
# Text content
# Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1, Nemotron). vLLM 0.20.2 / NIM emit the field as `reasoning`; older builds use `reasoning_content`. Accept either.
reasoning = delta.get("reasoning_content") or delta.get("reasoning") or ""
# Reasoning tokens (VLLM --reasoning-parser, e.g. Qwen3/DeepSeek-R1, Nemotron). vLLM 0.20.2 / NIM emit the field as `reasoning`; older builds use `reasoning_content`. Some OpenAI-compatible Ollama builds use `thinking`.
reasoning = delta.get("reasoning_content") or delta.get("reasoning") or delta.get("thinking") or ""
if reasoning:
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
yield _stream_delta_event(reasoning, thinking=True)
content = delta.get("content") or ""
if content:
stripped = content.lstrip()
# Auto-detect <think>…</think> in content stream.
# Covers Qwen3-derived models (Qwopus, QwQ forks) whose
# names don't match _THINKING_MODEL_PATTERNS but still
# emit literal <think> markup via llama.cpp --jinja.
if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
_thinking_model = True
_in_think_tag = True
if _in_think_tag:
close_idx = content.lower().find("</think>")
if close_idx != -1:
# Split: up-to-</think> → thinking, remainder → content
think_part = content[:close_idx]
if not _think_open_stripped:
# Strip the opening <think[...] > from the first chunk.
# Use a dedicated flag — _first_content_sent stays False
# throughout the think block, so it must not be reused.
tag_end = think_part.lower().find(">")
if tag_end != -1:
think_part = think_part[tag_end + 1:]
_think_open_stripped = True
regular_part = content[close_idx + len("</think>"):]
_in_think_tag = False
if think_part:
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
if regular_part:
_first_content_sent = True
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
else:
# Still inside <think>: route to thinking channel
if not _think_open_stripped:
# Strip the opening <think[...] > tag (first chunk only)
tag_end = stripped.lower().find(">")
if tag_end != -1:
content = stripped[tag_end + 1:]
_think_open_stripped = True
if content:
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
# gpt-oss harmony format (<|channel|>analysis/final): route via the harmony
# stream router. Sticky once the first marker appears — distinct from the
# <think> path below (handled in the else, preserving #2588 behaviour).
if _harmony_active or "<|" in content:
_harmony_active = True
for event in _format_routed_content(_harmony_router.feed(content)):
yield event
else:
# Some thinking backends start normal content with a
# stray closing tag. Repair only that shape; do not
# wrap every first token for model families like
# MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
content = "<think>" + content
_first_content_sent = True
yield f'data: {json.dumps({"delta": content})}\n\n'
# Auto-detect <think>…</think> in content stream.
# Covers Qwen3-derived models (Qwopus, QwQ forks) whose
# names don't match _THINKING_MODEL_PATTERNS but still
# emit literal <think> markup via llama.cpp --jinja.
if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
_thinking_model = True
_in_think_tag = True
if _in_think_tag:
close_idx = content.lower().find("</think>")
if close_idx != -1:
# Split: up-to-</think> → thinking, remainder → content
think_part = content[:close_idx]
if not _think_open_stripped:
# Strip the opening <think[...] > from the first chunk.
# Use a dedicated flag — _first_content_sent stays False
# throughout the think block, so it must not be reused.
tag_end = think_part.lower().find(">")
if tag_end != -1:
think_part = think_part[tag_end + 1:]
_think_open_stripped = True
regular_part = content[close_idx + len("</think>"):]
_in_think_tag = False
if think_part:
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
if regular_part:
_first_content_sent = True
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
else:
# Still inside <think>: route to thinking channel
if not _think_open_stripped:
# Strip the opening <think[...] > tag (first chunk only)
tag_end = stripped.lower().find(">")
if tag_end != -1:
content = stripped[tag_end + 1:]
_think_open_stripped = True
if content:
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
else:
# Some thinking backends start normal content with a
# stray closing tag. Repair only that shape; do not
# wrap every first token for model families like
# MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
content = "<think>" + content
_first_content_sent = True
yield f'data: {json.dumps({"delta": content})}\n\n'
# Native tool calls — accumulate across chunks
for tc in delta.get("tool_calls") or []:
if tc is None:
@@ -1563,15 +1696,19 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
yield f'data: {json.dumps({"type": "tool_call_delta", "index": idx, "name": _tc_acc[idx]["name"], "arg_delta": func["arguments"]})}\n\n'
elif "text" in j:
if j["text"]:
yield f'data: {json.dumps({"delta": j["text"]})}\n\n'
for event in _format_routed_content(_harmony_router.feed(j["text"])):
yield event
else:
if data.strip():
yield f'data: {json.dumps({"delta": data})}\n\n'
for event in _format_routed_content(_harmony_router.feed(data)):
yield event
except Exception as e:
logger.error(f"Error parsing stream data: {e}")
continue
# End of stream (no explicit [DONE] received)
for event in _format_routed_content(_harmony_router.flush()):
yield event
tc_event = _emit_tool_calls()
if tc_event:
yield tc_event
+34
View File
@@ -172,3 +172,37 @@ def test_registered_thinking_model_stray_close_tag_repair_unchanged(monkeypatch)
assert deltas, deltas
first = deltas[0]["delta"]
assert first.startswith("<think>"), f"expected repair prefix, got: {first!r}"
def test_thinking_field_emits_thinking_chunk(monkeypatch):
    """A delta carrying a ``thinking`` field comes out as a thinking chunk.

    Feeds one ``thinking`` delta and one ``content`` delta through
    ``_run_stream`` and checks each arrives with the expected flag.
    """
    sse_lines = [
        'data: {"choices":[{"delta":{"thinking":"checking files"}}]}',
        'data: {"choices":[{"delta":{"content":"visible answer"}}]}',
        "data: [DONE]",
    ]
    events = _run_stream("gpt-oss:20b", sse_lines, monkeypatch)

    saw_reasoning = any(
        ev.get("thinking") and ev["delta"] == "checking files" for ev in events
    )
    assert saw_reasoning, events

    saw_answer = any(
        (not ev.get("thinking")) and ev["delta"] == "visible answer" for ev in events
    )
    assert saw_answer, events
def test_harmony_analysis_channel_routes_to_thinking(monkeypatch):
    """Harmony ``analysis`` text is routed to thinking, ``final`` text to the answer.

    The first marker is deliberately split across two chunks
    (``<|channel|>ana`` + ``lysis``) so a marker cut at a chunk boundary is
    covered; no marker text may reach either output.
    """
    sse_lines = [
        'data: {"choices":[{"delta":{"content":"<|channel|>ana"}}]}',
        'data: {"choices":[{"delta":{"content":"lysis<|message|>We need to inspect."}}]}',
        'data: {"choices":[{"delta":{"content":"<|end|><|channel|>final<|message|>Here "}}]}',
        'data: {"choices":[{"delta":{"content":"are the files.<|end|>"}}]}',
        "data: [DONE]",
    ]
    events = _run_stream("gpt-oss:20b", sse_lines, monkeypatch)

    reasoning_text = "".join(ev["delta"] for ev in events if ev.get("thinking"))
    answer_text = "".join(ev["delta"] for ev in events if not ev.get("thinking"))

    assert reasoning_text == "We need to inspect."
    assert answer_text == "Here are the files."

    # Neither output may contain leftover harmony markers.
    everything = reasoning_text + answer_text
    assert "<|channel|>" not in everything
    assert "<|message|>" not in everything