Preserve large pasted messages in context

This commit is contained in:
pewdiepie-archdaemon
2026-06-01 12:38:10 +09:00
parent 1ce00b5dea
commit a66f241e21
2 changed files with 100 additions and 8 deletions
+69 -8
View File
@@ -6,7 +6,7 @@ Summarizes older messages via the same LLM, preserving key context.
""" """
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from src.model_context import get_context_length, estimate_tokens from src.model_context import get_context_length, estimate_tokens
from src.llm_core import llm_call_async from src.llm_core import llm_call_async
@@ -95,6 +95,55 @@ def _sanitize_tool_messages(msgs: List[Dict]) -> List[Dict]:
return out return out
def _message_text_token_estimate(text: str) -> int:
return int(len(text) * 0.3) + 4
def _truncate_text_to_token_budget(text: str, token_budget: int) -> str:
"""Trim a too-large current user message instead of dropping it entirely."""
if token_budget <= 32:
return "[Current user message omitted: it exceeded the model context window.]"
# Match src.model_context.estimate_tokens' rough chars * 0.3 estimate.
max_chars = max(200, int((token_budget - 16) / 0.3))
if len(text) <= max_chars:
return text
notice = (
"\n\n[Notice: the pasted message was too large for this model's context "
"window, so Odysseus kept the beginning and end.]"
)
keep_chars = max(200, max_chars - len(notice))
head_len = max(100, int(keep_chars * 0.7))
tail_len = max(80, keep_chars - head_len)
return text[:head_len].rstrip() + notice + "\n\n" + text[-tail_len:].lstrip()
def _truncate_message_to_token_budget(msg: Dict[str, Any], token_budget: int) -> Dict[str, Any]:
"""Return a copy of msg whose text content fits inside token_budget."""
out = dict(msg)
content = out.get("content", "")
if isinstance(content, str):
out["content"] = _truncate_text_to_token_budget(content, token_budget)
return out
if isinstance(content, list):
remaining = token_budget
new_content = []
for item in content:
if not isinstance(item, dict) or item.get("type") != "text":
new_content.append(item)
continue
text = item.get("text", "")
truncated = _truncate_text_to_token_budget(text, remaining)
cloned = dict(item)
cloned["text"] = truncated
new_content.append(cloned)
remaining -= _message_text_token_estimate(truncated)
out["content"] = new_content
return out
def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens: int = 512) -> List[Dict]: def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens: int = 512) -> List[Dict]:
"""Trim system messages to fit within context_length. """Trim system messages to fit within context_length.
@@ -153,19 +202,31 @@ def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens:
if estimate_tokens(trimmed) <= budget: if estimate_tokens(trimmed) <= budget:
return _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs) return _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs)
# Still too big — drop older conversation turns BUT protect the last 10. # Still too big — drop older conversation turns BUT always keep the current
# user turn. If a pasted message alone exceeds the model context, truncate
# that message with a visible notice instead of dropping it; otherwise the
# model appears to "ignore" large pastes because it never receives them.
# Hermes-style: recent context matters more than old context. # Hermes-style: recent context matters more than old context.
PROTECT_RECENT = 10 PROTECT_RECENT = 10
if len(convo_msgs) > PROTECT_RECENT: current_msg = convo_msgs[-1:] if convo_msgs else []
old_msgs = convo_msgs[:-PROTECT_RECENT] prior_convo = convo_msgs[:-1] if convo_msgs else []
recent_msgs = convo_msgs[-PROTECT_RECENT:] if len(prior_convo) >= PROTECT_RECENT:
old_msgs = prior_convo[:-(PROTECT_RECENT - 1)]
recent_msgs = prior_convo[-(PROTECT_RECENT - 1):] + current_msg
while old_msgs and estimate_tokens(essential_system + old_msgs + recent_msgs) > budget: while old_msgs and estimate_tokens(essential_system + old_msgs + recent_msgs) > budget:
old_msgs.pop(0) old_msgs.pop(0)
convo_msgs = old_msgs + recent_msgs convo_msgs = old_msgs + recent_msgs
else: else:
# Not enough messages to split — just trim from front convo_msgs = prior_convo + current_msg
while convo_msgs and estimate_tokens(essential_system + convo_msgs) > budget: while prior_convo and estimate_tokens(essential_system + prior_convo + current_msg) > budget:
convo_msgs.pop(0) prior_convo.pop(0)
convo_msgs = prior_convo + current_msg
# If the current message itself is too large, shrink only that message.
if current_msg and estimate_tokens(essential_system + protected_msgs + convo_msgs) > budget:
prefix = essential_system + protected_msgs + convo_msgs[:-1]
available_for_current = max(64, budget - estimate_tokens(prefix))
convo_msgs[-1] = _truncate_message_to_token_budget(convo_msgs[-1], available_for_current)
result = _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs) result = _sanitize_tool_messages(essential_system + protected_msgs + convo_msgs)
logger.info(f"Trimmed to {estimate_tokens(result)} tokens ({len(result)} messages)") logger.info(f"Trimmed to {estimate_tokens(result)} tokens ({len(result)} messages)")
+31
View File
@@ -18,6 +18,7 @@ from src.context_compactor import (
COMPACT_THRESHOLD, COMPACT_THRESHOLD,
SELF_SUMMARY_SYSTEM_PROMPT, SELF_SUMMARY_SYSTEM_PROMPT,
SUMMARY_MAX_TOKENS, SUMMARY_MAX_TOKENS,
trim_for_context,
) )
@@ -53,3 +54,33 @@ class TestSelfSummaryPrompt:
def test_mentions_compactions(self): def test_mentions_compactions(self):
assert "Compactions so far" in SELF_SUMMARY_SYSTEM_PROMPT assert "Compactions so far" in SELF_SUMMARY_SYSTEM_PROMPT
class TestTrimForContext:
def test_keeps_current_large_user_message_by_truncating(self):
huge = "A" * 20000
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": huge},
]
trimmed = trim_for_context(messages, context_length=2048, reserve_tokens=512)
user_msgs = [m for m in trimmed if m.get("role") == "user"]
assert len(user_msgs) == 1
content = user_msgs[0]["content"]
assert "pasted message was too large" in content
assert content.startswith("A")
assert len(content) < len(huge)
def test_drops_older_messages_before_latest_user_paste(self):
huge = "B" * 12000
messages = [{"role": "system", "content": "You are helpful."}]
messages.extend({"role": "user", "content": f"old-{i} " + ("x" * 1000)} for i in range(8))
messages.append({"role": "user", "content": huge})
trimmed = trim_for_context(messages, context_length=2048, reserve_tokens=512)
assert trimmed[-1]["role"] == "user"
assert "pasted message was too large" in trimmed[-1]["content"]
assert "old-0" not in "\n".join(str(m.get("content", "")) for m in trimmed)