Merge remote-tracking branch 'origin/dev' into test-main-dev-merge-20260615

# Conflicts:
#	src/tool_implementations.py
#	static/js/research/panel.js
This commit is contained in:
pewdiepie-archdaemon
2026-06-15 21:20:15 +09:00
312 changed files with 20047 additions and 2952 deletions
+3
View File
@@ -91,6 +91,9 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple(
("ui", "tool or feature toggle request", r"\b(?:disable|enable|turn\s+(?:on|off))\s+(?:the\s+)?(?:shell|search|web|browser|documents?|memory|skills|images?|calendar|email|mail|research|incognito)\b"),
# Deep research jobs, not quick conceptual mentions of research.
("web", "explicit web search request", rf"{_PLEASE}(?:do|run|use|perform|make)\s+(?:a\s+)?(?:web\s+search|search\s+the\s+web)\b.+"),
("web", "web lookup imperative request", rf"{_PLEASE}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"),
("web", "assistant web lookup request", rf"{_ACTION_QUESTION}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"),
("research", "deep research imperative request", rf"{_PLEASE}(?:research|deep\s+dive|look\s+into|investigate)\s+.+"),
("research", "assistant deep research request", rf"{_ACTION_QUESTION}(?:research|do\s+research|deep\s+dive|look\s+into|investigate)\s+.+"),
+179 -53
View File
@@ -21,7 +21,7 @@ from src.settings import get_setting
from src.prompt_security import untrusted_context_message
from src.tool_security import blocked_tools_for_owner, plan_mode_disabled_tools
from src.tool_policy import GUIDE_ONLY_DIRECTIVE, ToolPolicy
from src.tool_utils import get_mcp_manager
from src.tool_utils import _truncate, get_mcp_manager
from src.agent_tools import (
parse_tool_blocks,
strip_tool_blocks,
@@ -262,6 +262,11 @@ _DOMAIN_RULES = {
- Use `manage_settings` for preferences and tool enable/disable.
- Use named tools over `app_api` when a named wrapper exists.
- `app_api` is only for safe UI/API actions without a named tool; do not use it for shell, package installs, engine rebuilds, or sensitive auth/admin paths.""",
"contacts": """\
## Contacts rules
- Use `resolve_contact` to look up a contact's email or phone number by name. Searches the CardDAV address book and sent email history.
- Use `manage_contact` to list, add, update, or delete contacts in the address book.
- Do NOT use `manage_memory` for contact lookups — contact details live in the address book, not memory.""",
}
_DOMAIN_TOOL_MAP = {
@@ -272,8 +277,9 @@ _DOMAIN_TOOL_MAP = {
"notes_calendar_tasks": {"manage_notes", "manage_calendar", "manage_tasks"},
"ui": {"ui_control"},
"sessions": {"create_session", "list_sessions", "manage_session", "send_to_session", "search_chats"},
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls"},
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls", "get_workspace"},
"settings": {"manage_settings", "manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "app_api"},
"contacts": {"resolve_contact", "manage_contact"},
}
def _domain_rules_for_tools(tool_names: set) -> list[str]:
@@ -309,6 +315,7 @@ NEVER pipe multi-line Python through `python -c "..."` — shell quoting eats re
<python code>
```
Execute Python code. Use for computation, data processing, scripting. NOT for writing code for the user (use create_document for that). Same sandbox limits as bash — no TTY, no GUI, no `input()`; for anything the user should interact with, generate a single HTML file with inline JS instead.
Prefer a dedicated tool whenever one fits the job (reading, searching, or writing files); use python only for computation/processing no dedicated tool covers - not for reading or writing files.
Do NOT use Python/requests for web lookup/search/latest/current requests when `web_search` or `web_fetch` is available.""",
"web_search": """\
@@ -347,6 +354,11 @@ Write content to a file. First line is the path, rest is the content.""",
```
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
"get_workspace": """\
```get_workspace
```
Return the absolute path of the active workspace folder. File tools are CONFINED to it (paths can be RELATIVE to it); the shell starts there (cwd) but is NOT sandboxed. Call this first when the user says "the project"/"the code"/"this folder" without a path, instead of asking them. No arguments.""",
"create_document": """\
```create_document
<title>
@@ -598,7 +610,7 @@ _API_HOSTS = frozenset([
"api.deepseek.com", "deepseek.com",
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"ollama.com", "api.venice.ai",
"ollama.com", "api.venice.ai", "api.kimi.com",
"api.githubcopilot.com",
# Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.).
# Without these, `_is_api_model` falls back to keyword sniffing on the
@@ -785,6 +797,12 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o
domains.add("documents")
if has(r"\b(search|web|google|look up|latest|news|current|weather|forecast|stock price|price of|website|url|https?://|www\.)\b"):
domains.add("web")
if has(
r"\b(wyszukaj|wyszukać|wyszukac)\b.*\b(internet|internecie|online|web)\b",
r"\b(sprawd[zź]|znajd[zź])\b.*\b(internet|internecie|online|web)\b",
r"\b(aktualn\w*|bieżąc\w*|biezac\w*|dzisiaj|teraz)\b.*\b(pogod\w*|temperatur\w*)\b",
):
domains.add("web")
if has(r"\b(research|deep dive|investigate|look into)\b"):
domains.add("web")
if has(r"\b(open|show|toggle|turn on|turn off|disable|enable|switch model|change model|settings|theme|panel)\b"):
@@ -795,6 +813,8 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o
domains.add("files")
if has(r"\b(endpoint|api token|mcp|webhook|preference|configure|config|setting)\b"):
domains.add("settings")
if has(r"\b(contact|contacts|phone|phone number|address book|vcard)\b"):
domains.add("contacts")
low_signal = not continuation and not domains
return {
@@ -860,7 +880,7 @@ def _build_system_prompt(
_ov_sig = _hl.sha256(_json.dumps(get_builtin_overrides() or {}, sort_keys=True).encode()).hexdigest()
except Exception:
_ov_sig = ""
cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, suppress_local_context)
cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, owner, suppress_local_context)
if _cached_base_prompt and _cached_base_prompt_key == cache_key and not active_document:
agent_prompt = _cached_base_prompt
# Skill index is user-editable (name + description), so it must never
@@ -868,7 +888,7 @@ def _build_system_prompt(
# when the cache hits.
_, _skill_index_block = _build_base_prompt(
disabled_tools, mcp_mgr, needs_admin, relevant_tools,
mcp_disabled_map=mcp_disabled_map, compact=compact,
mcp_disabled_map=mcp_disabled_map, compact=compact, owner=owner,
suppress_local_context=suppress_local_context,
)
else:
@@ -879,6 +899,7 @@ def _build_system_prompt(
relevant_tools,
mcp_disabled_map=mcp_disabled_map,
compact=compact,
owner=owner,
suppress_local_context=suppress_local_context,
)
if not active_document:
@@ -894,9 +915,20 @@ def _build_system_prompt(
# Current date/time for every agent request. This is user-local when the
# browser provided timezone headers, with a server-local fallback.
#
# IMPORTANT: this is intentionally NOT prepended into agent_prompt (the
# system message) anymore. Its text changes every minute, and local
# OpenAI-compatible backends (llama.cpp / LM Studio) key their KV-cache
# prefix off the system message byte-for-byte — mixing ever-changing
# timestamp text into the (already large, tool-laden) agent system prompt
# would invalidate the cached prefix on every single request, forcing a
# full prompt re-evaluation each turn (issue #2927). It's built here as a
# standalone *user*-role message and inserted near the end of the array,
# right alongside _doc_message / _skills_message, below.
_datetime_message = None
try:
from src.user_time import current_datetime_prompt
agent_prompt = current_datetime_prompt() + agent_prompt
from src.user_time import current_datetime_context_message
_datetime_message = current_datetime_context_message()
except Exception:
pass
@@ -1296,6 +1328,9 @@ def _build_system_prompt(
last_user_idx += 1
if _skills_message:
merged.insert(last_user_idx, _skills_message)
last_user_idx += 1
if _datetime_message:
merged.insert(last_user_idx, _datetime_message)
return merged, mcp_schemas
@@ -1314,6 +1349,7 @@ def _build_base_prompt(
relevant_tools=None,
mcp_disabled_map=None,
compact: bool = False,
owner: Optional[str] = None,
suppress_local_context: bool = False,
):
"""Build the agent prompt with only relevant tools included.
@@ -1373,7 +1409,7 @@ def _build_base_prompt(
from src.constants import DATA_DIR
_sm = SkillsManager(DATA_DIR)
active_tools = list(set(TOOL_SECTIONS.keys()) - set(disabled or []))
skill_idx = _sm.index_for(owner=None, active_toolsets=active_tools)
skill_idx = _sm.index_for(owner=owner, active_toolsets=active_tools)
if skill_idx:
lines = ["## Available skills",
"Procedures the assistant should consult before doing domain work. "
@@ -1782,10 +1818,10 @@ async def stream_agent_loop(
owner: Optional[str] = None,
relevant_tools: Optional[Set[str]] = None,
fallbacks: Optional[List[tuple]] = None,
workspace: Optional[str] = None,
plan_mode: bool = False,
approved_plan: Optional[str] = None,
tool_policy: Optional[ToolPolicy] = None,
workspace: Optional[str] = None,
_is_teacher_run: bool = False,
) -> AsyncGenerator[str, None]:
"""Streaming agent loop generator.
@@ -1854,8 +1890,21 @@ async def stream_agent_loop(
logger.info(f"[tool-rag] Using caller-provided relevant_tools ({len(_relevant_tools)} tools)")
if not guide_only and not _relevant_tools and bool(_intent.get("low_signal")):
from src.tool_index import ALWAYS_AVAILABLE
_relevant_tools = set(ALWAYS_AVAILABLE)
logger.info("[tool-rag] Low-signal agent message; skipping retrieval and using always-available tools only")
if workspace:
# An active workspace IS the file-work signal: a vague "look at the
# project" means explore this folder. Surface only the READ-ONLY file
# tools (intersection with the plan-mode read-only allowlist) so the
# agent can investigate; write/shell tools stay out until the request
# actually calls for them (RAG retrieval adds those on a real ask).
_relevant_tools = set(ALWAYS_AVAILABLE)
from src.tool_security import PLAN_MODE_READONLY_TOOLS
_relevant_tools |= (_DOMAIN_TOOL_MAP["files"] & PLAN_MODE_READONLY_TOOLS)
logger.info("[tool-rag] Low-signal but workspace active; including read-only file tools")
else:
# Don't short-circuit: fall through to RAG retrieval below.
# Non-English queries are flagged low_signal by the English-only
# intent classifier, but fastembed retrieval works across languages.
logger.info("[tool-rag] Low-signal query; will run RAG retrieval")
if not guide_only and not _relevant_tools:
try:
from src.tool_index import get_tool_index, ALWAYS_AVAILABLE
@@ -1930,6 +1979,44 @@ async def stream_agent_loop(
if _relevant_tools is not None and active_document is not None:
_relevant_tools.update({"edit_document", "update_document", "suggest_document"})
# The skill index injected by _build_system_prompt tells the model to
# call `manage_skills action=view`, and Jaccard-matched skills are pasted
# into the prompt as procedures to follow — but neither path goes through
# tool selection, so the model can be handed a procedure naming tools
# (grep, read_file, ...) that aren't in its schema list. Keep the schemas
# in lockstep: manage_skills is callable whenever any skill is indexed,
# and a matched skill's declared requires_toolsets ride along with it.
if not guide_only and _relevant_tools is not None:
try:
from services.memory.skills import SkillsManager
from src.constants import DATA_DIR
_skills_on = True
try:
from routes.prefs_routes import _load_for_user as _load_prefs
_skills_on = (_load_prefs(owner) or {}).get("skills_enabled", True)
except Exception:
pass
_sm = SkillsManager(DATA_DIR)
_owner_skills = _sm.load(owner=owner) if _skills_on else []
if _owner_skills:
_relevant_tools.add("manage_skills")
if _retrieval_query:
# Validate against every known executable tool, not just
# TOOL_SECTIONS — code-nav tools (grep/glob/ls) ship as
# schemas without a prompt-prose section.
from src.tool_policy import known_tool_names
_known = known_tool_names()
for _sk in _sm.get_relevant_skills(
_retrieval_query, skills=_owner_skills,
threshold=0.25, max_items=3,
):
_relevant_tools.update(
t for t in (_sk.get("requires_toolsets") or [])
if t in _known
)
except Exception as _e:
logger.debug(f"[tool-rag] skill-aware tool include skipped: {_e}")
if _relevant_tools is not None:
logger.info("[agent-intent] selected_tools=%s", sorted(_relevant_tools)[:50])
@@ -1980,6 +2067,10 @@ async def stream_agent_loop(
# and can override this list for users who know their setup.
_model_no_tools = any(kw in _model_lc for kw in (
"deepseek-r1",
# Open-weight GPT-OSS models are commonly served through llama.cpp /
# llama-cpp-python. Their names contain "gpt-o", but they do not use
# OpenAI's native tool-call channel unless the endpoint opts in.
"gpt-oss",
))
# Native Ollama endpoints (/api/chat) handle tool schemas differently from
# the OpenAI-compat path. Models like gemma4, qwen3.5, ministral respond to
@@ -2011,27 +2102,6 @@ async def stream_agent_loop(
suppress_local_context=guide_only,
active_email=active_email,
)
if workspace and not guide_only:
# PREPEND (not append) so it dominates the large base prompt — appended
# at the end, small models ignored it and asked the user for code. The
# folder IS the project; the agent must explore it, not ask.
_ws_note = (
f"## ACTIVE WORKSPACE — READ FIRST\n"
f"The user is working in this folder: {workspace}\n"
f"It IS the project. bash/python run with cwd set here and "
f"read_file/write_file are confined to it (paths outside are rejected).\n"
f"When the user says \"the code\" / \"this project\" / \"the workspace\" "
f"or asks to review/find/edit something WITHOUT a path, they mean THIS "
f"folder. Do NOT ask the user for code or a path, and do NOT read a file "
f"literally named \"workspace\". ALWAYS start by exploring it yourself: "
f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then "
f"read_file the relevant ones by path RELATIVE to the workspace."
)
if messages and messages[0].get("role") == "system":
messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "")
else:
messages.insert(0, {"role": "system", "content": _ws_note})
logger.info("[workspace] active for this turn: %s", workspace)
if plan_mode and not guide_only:
# Steer the model to investigate-then-propose. Hard tool gating handles
# every write path except shell; this directive is what keeps the
@@ -2063,30 +2133,34 @@ async def stream_agent_loop(
_t3 = time.time()
try:
from src.context_compactor import trim_for_context
from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX
from src.settings import is_setting_overridden
from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX, DEFAULT_BUDGET, budget_is_explicit as _budget_is_explicit
from src.model_context import budget_context_for_model
soft_budget = int(get_setting("agent_input_token_budget", 6000) or 0)
soft_budget = int(get_setting("agent_input_token_budget", DEFAULT_BUDGET) or 0)
if soft_budget > 0:
before_trim_tokens = estimate_tokens(messages)
reserve_tokens = min(max(max_tokens or 1024, 512), 2048)
# Honour the configurable ceiling for the auto-derived budget path.
# No-op when the user has an explicit `agent_input_token_budget`
# (that branch ignores hard_max). Falls back to DEFAULT_HARD_MAX
# on missing/malformed values so misconfig can't zero the budget.
# Ceiling for the auto-derived budget (no effect on an explicit budget;
# see #1230). Falls back to DEFAULT_HARD_MAX on missing/malformed values
# so misconfig can't zero the budget.
try:
hard_max = int(get_setting("agent_input_token_hard_max", DEFAULT_HARD_MAX) or DEFAULT_HARD_MAX)
except (TypeError, ValueError):
hard_max = DEFAULT_HARD_MAX
if hard_max <= 0:
hard_max = DEFAULT_HARD_MAX
# Scale the default budget to the model's context window so long-context
# models aren't silently capped at 6000; an explicit user setting is
# still honoured (clamped to the window). (#1170)
# Default value = auto sentinel (scale to the window); any other value =
# explicit cap. Value-based, not presence-based, because the save path
# materializes defaults so a persisted default must still read as auto (#4121).
budget_is_explicit = _budget_is_explicit(soft_budget)
# Scale only off a window we actually discovered, bound to the value it
# proves (else 0) — not the passed-in context_length, which can be stale
# or unset for some callers (#4122 review).
ctx_for_budget = budget_context_for_model(endpoint_url, model, fallback=context_length)
effective_budget = compute_input_token_budget(
soft_budget,
context_length,
is_setting_overridden("agent_input_token_budget"),
ctx_for_budget,
budget_is_explicit,
hard_max=hard_max,
)
trimmed_messages = trim_for_context(
@@ -2161,11 +2235,12 @@ async def stream_agent_loop(
# tool, so we don't nudge on harmless transitional text like "let me
# know what you think".
_INTENT_RE = re.compile(
r"(?:^|\n)\s*(?:let me|i'?ll|i will|going to|let's)\s+"
r"(?:^|\n)\s*(?:let me|i'?ll|i will|i need to|we need to|need to|"
r"i should|we should|i must|we must|going to|let's)\s+"
r"(?:tail|check|investigate|look at|see|tail|read|fetch|inspect|"
r"verify|diagnose|examine|debug|capture|grab|pull|view|run|call|"
r"trigger|launch|start|kick off|stop|kill|restart|adopt|serve|"
r"register|adopt|list|search|find|query|hit|ping|test)"
r"register|adopt|list|search|find|query|hit|ping|test|use|perform|do)"
r"\b[^.\n]{0,140}",
re.IGNORECASE,
)
@@ -2206,9 +2281,17 @@ async def stream_agent_loop(
elif _is_api_model:
# Filter schemas by RAG-selected tools (if available)
if _relevant_tools:
# _build_base_prompt unions _ADMIN_TOOLS into the prompt
# sections when admin intent fires — the schema list must
# offer the same names, or the model reads prose describing
# tools it cannot call and substitutes the nearest schema
# it does have (e.g. manage_memory for manage_skills).
_schema_names = set(_relevant_tools)
if _needs_admin:
_schema_names |= _ADMIN_TOOLS
base_schemas = [
s for s in FUNCTION_TOOL_SCHEMAS
if s.get("function", {}).get("name") in _relevant_tools
if s.get("function", {}).get("name") in _schema_names
]
_mcp_filtered = [
s for s in mcp_schemas
@@ -2254,6 +2337,7 @@ async def stream_agent_loop(
prompt_type=prompt_type if round_num == 1 else None,
tools=all_tool_schemas if all_tool_schemas else None,
timeout=agent_stream_timeout,
session_id=session_id,
):
if time.time() > _round_deadline:
logger.warning(f"[agent] round {round_num} stream exceeded wall-clock deadline; cutting off")
@@ -2743,6 +2827,46 @@ async def stream_agent_loop(
)
desc, result = await _tool_task
# A skill the model just loaded can prescribe tools that weren't
# RAG-selected this turn (declared via requires_toolsets in its
# frontmatter). Union them into the selection so the NEXT round's
# schema list includes them — otherwise the model reads "use
# grep" from the skill it fetched but has no grep schema to call.
if (
block.tool_type == "manage_skills"
and _relevant_tools is not None
and not result.get("error")
):
_ms_args = {}
_ms_raw = (block.content or "").strip()
if _ms_raw.startswith("{"):
try:
_ms_args = json.loads(_ms_raw)
except json.JSONDecodeError:
_ms_args = {}
_ms_name = str(_ms_args.get("name", "") or "").strip()
if _ms_name and _ms_args.get("action") in ("view", "view_ref"):
try:
from services.memory.skills import SkillsManager as _SkM
from src.constants import DATA_DIR as _DD
from src.tool_policy import known_tool_names as _ktn
_known = _ktn()
for _sk in _SkM(_DD).load(owner=owner):
if _sk.get("name") == _ms_name:
_new = {
t for t in (_sk.get("requires_toolsets") or [])
if t in _known and t not in _relevant_tools
}
if _new:
_relevant_tools.update(_new)
logger.info(
"[tool-rag] skill '%s' unlocked tools for next round: %s",
_ms_name, sorted(_new),
)
break
except Exception as _e:
logger.debug(f"skill requires_toolsets unlock skipped: {_e}")
# Extract structured web sources from web_search tool output.
# web_search returns {"output": ..., "exit_code": 0}; check "output"
# first so the <!-- SOURCES:…--> marker is found and stripped even
@@ -2833,18 +2957,20 @@ async def stream_agent_loop(
# On a bash/python timeout the result carries error + (often
# empty) stdout/stderr; fall back to the error so the "timed
# out" reason reaches the UI instead of a blank result.
output_text = (result["stdout"] or result["stderr"] or result.get("error", ""))[:2000]
raw = result["stdout"] or result["stderr"] or result.get("error", "")
output_text = _truncate(raw)
elif "output" in result:
# bash / python canonical result: {"output": ..., "exit_code": ...}
output_text = (result["output"] or "")[:2000]
raw = result["output"] or ""
output_text = _truncate(raw)
elif "response" in result:
# AI interaction tools (chat_with_model, send_to_session)
label = result.get("model", result.get("session_name", "AI"))
output_text = f"{label}: {result['response']}"[:4000]
output_text = _truncate(f"{label}: {result['response']}")
elif "content" in result:
output_text = result["content"][:2000]
output_text = _truncate(result["content"])
elif "results" in result:
output_text = result["results"][:4000]
output_text = _truncate(result["results"])
elif "session_id" in result and "name" in result:
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
elif "success" in result:
@@ -2854,7 +2980,7 @@ async def stream_agent_loop(
else f"Error: {result.get('error', '')}"
)
elif "error" in result:
output_text = result["error"][:2000]
output_text = _truncate(result["error"])
# Emit tool_output (include ui_event data if present)
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
@@ -18,6 +18,30 @@ from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
logger = logging.getLogger(__name__)
from .subprocess_tools import BashTool, PythonTool
from .web_tools import WebSearchTool, WebFetchTool
from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool, GetWorkspaceTool
from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool
# Dispatch table: tool tag -> bound ``execute`` method of a tool instance that
# is constructed once, when this module is imported.
TOOL_HANDLERS = {
"bash": BashTool().execute,
"python": PythonTool().execute,
"web_search": WebSearchTool().execute,
"web_fetch": WebFetchTool().execute,
"read_file": ReadFileTool().execute,
"write_file": WriteFileTool().execute,
"edit_file": EditFileTool().execute,
"ls": LsTool().execute,
"glob": GlobTool().execute,
"grep": GrepTool().execute,
"create_document": CreateDocumentTool().execute,
"update_document": UpdateDocumentTool().execute,
"edit_document": EditDocumentTool().execute,
"suggest_document": SuggestDocumentTool().execute,
"manage_documents": ManageDocumentTool().execute,
"get_workspace": GetWorkspaceTool().execute,
}
# ---------------------------------------------------------------------------
# Constants (re-exported for backward compatibility — single source of truth
# is src.constants; always prefer importing from there for new code)
@@ -28,7 +52,7 @@ PYTHON_TIMEOUT = 30
# Tool types that trigger execution
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
"grep", "glob", "ls",
"grep", "glob", "ls", "get_workspace",
"create_document", "update_document", "edit_document",
"search_chats",
"chat_with_model", "create_session", "list_sessions",
@@ -92,15 +116,14 @@ from src.tool_execution import ( # noqa: E402, F401
format_tool_result,
)
# Document functions
from .document_tools import (
set_active_document,
set_active_model
)
# Implementations
from src.tool_implementations import ( # noqa: E402, F401
set_active_document,
set_active_model,
get_active_document,
do_create_document,
do_update_document,
do_edit_document,
do_suggest_document,
do_search_chats,
do_manage_skills,
do_manage_tasks,
@@ -108,7 +131,6 @@ from src.tool_implementations import ( # noqa: E402, F401
do_manage_mcp,
do_manage_webhooks,
do_manage_tokens,
do_manage_documents,
do_manage_settings,
do_api_call,
)
+644
View File
@@ -0,0 +1,644 @@
from typing import Any, Dict, List, Optional
import logging
import re
import json
from src.constants import MAX_READ_CHARS
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Active document state
# ---------------------------------------------------------------------------
# ID of the document that document tools operate on; None when no document is active.
_active_document_id: Optional[str] = None
# Name of the current model, kept for version summaries; None when unset.
_active_model: Optional[str] = None
def set_active_document(doc_id: Optional[str]) -> None:
    """Record ``doc_id`` as the document that document tools act on.

    Args:
        doc_id: ID of the document to make active; ``None`` resets the pointer.
    """
    global _active_document_id
    _active_document_id = doc_id
def set_active_model(model: Optional[str]) -> None:
    """Remember the name of the model in use, for version summaries.

    Args:
        model: Model name to store; ``None`` clears the stored name.
    """
    global _active_model
    _active_model = model
def get_active_document() -> Optional[str]:
    """Return the ID of the active document, or ``None`` when none is set."""
    return _active_document_id
def clear_active_document(doc_id: Optional[str] = None) -> bool:
    """Reset the in-memory active-document pointer.

    Args:
        doc_id: When given, the pointer is reset only if it currently equals
            this ID, so a different active document is left alone. When
            omitted (``None``), the pointer is reset unconditionally.

    Returns:
        True if the pointer was reset, False if it was left untouched.

    Intended for when a document is detached from its session or deleted (its
    tab is closed). Without the reset, the stale pointer lets the last-resort
    doc-injection path re-surface a closed document in a later, unrelated
    chat -- even one whose session no longer matches -- because an unlinked
    doc has session_id NULL (#1160).
    """
    global _active_document_id
    # Guard: a specific ID was requested and it is not the active one.
    if doc_id is not None and _active_document_id != doc_id:
        return False
    _active_document_id = None
    return True
def _owned_document_query(query, Document, owner: Optional[str]):
    """Restrict ``query`` to documents belonging to ``owner``.

    Args:
        query: Query object exposing ``.filter(...)``.
        Document: Model class whose ``owner`` attribute is compared.
        owner: Owner to scope to; ``None`` means "no owner known".

    Returns:
        The filtered query. With ``owner=None`` the filter is the SQL
        ``false()`` literal, so an unscoped query matches zero rows.
    """
    if owner is not None:
        return query.filter(Document.owner == owner)
    # A bare Python `False` is not a valid SQL expression -- SQLAlchemy 1.4
    # deprecates it and 2.0 raises ArgumentError -- so use the SQL `false()`
    # literal for the owner-less case.
    from sqlalchemy import false
    return query.filter(false())
def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False):
    """Fetch the document with ``doc_id`` if it passes the owner scoping.

    Args:
        db: Session object exposing ``.query(...)``.
        Document: Document model class.
        doc_id: ID to look up.
        owner: Owner used for scoping via ``_owned_document_query``.
        active_only: When true, also require ``Document.is_active``.

    Returns:
        Whatever ``.first()`` yields for the scoped query.
    """
    candidates = db.query(Document).filter(Document.id == doc_id)
    if active_only:
        # `== True` is deliberate: it builds a SQL comparison expression.
        candidates = candidates.filter(Document.is_active == True)
    return _owned_document_query(candidates, Document, owner).first()
def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False):
    """Fetch the owner-scoped document with the latest ``updated_at``.

    Args:
        db: Session object exposing ``.query(...)``.
        Document: Document model class.
        owner: Owner used for scoping via ``_owned_document_query``.
        active_only: When true, also require ``Document.is_active``.

    Returns:
        Whatever ``.first()`` yields after ordering by ``updated_at``
        descending.
    """
    candidates = db.query(Document)
    if active_only:
        # `== True` is deliberate: it builds a SQL comparison expression.
        candidates = candidates.filter(Document.is_active == True)
    scoped = _owned_document_query(candidates, Document, owner)
    newest_first = scoped.order_by(Document.updated_at.desc())
    return newest_first.first()
# ---------------------------------------------------------------------------
# Document tools — create/update/edit/suggest living documents
# ---------------------------------------------------------------------------
def _sniff_doc_language(text: str) -> str:
    """Guess a document's language from its content.

    Used when the model did not name a language. Checks run in a fixed order
    and the first match wins: email shape, markup (svg/xml/html), JSON,
    shebang line, then line-anchored code signals. Anything unrecognised --
    including empty input -- is reported as ``"markdown"`` (prose).

    Args:
        text: Raw document content; ``None``/empty is tolerated.

    Returns:
        One of "email", "svg", "xml", "html", "json", "python", "bash",
        "javascript", "sql", "css" or "markdown".
    """
    stripped = (text or "").strip()
    if not stripped:
        return "markdown"

    # Markup sniffing only looks at the first 600 characters, lower-cased.
    head_lower = stripped[:600].lower()

    if _looks_like_email_document(stripped):
        return "email"

    # Markup (unambiguous)
    if "<svg" in head_lower:
        return "svg"
    if head_lower.startswith("<?xml"):
        return "xml"
    if head_lower.startswith(("<!doctype html", "<html")) or re.search(
        r"<(div|body|head|p|span|table|button|h[1-6]|ul|ol|li|img)\b", head_lower
    ):
        return "html"

    # JSON: only attempted when the text opens like an object or array.
    if stripped[0] in "{[":
        try:
            json.loads(stripped)
        except Exception:
            pass
        else:
            return "json"

    # Shebang
    first_line = stripped.split("\n", 1)[0].strip().lower()
    if first_line.startswith("#!"):
        return "python" if "python" in first_line else "bash"

    # Strong leading code signals, line-anchored so prose containing stray
    # keywords does not match. Order matters: the first hit wins.
    code_signals = (
        ("python", r"(?m)^\s*(def \w|class \w|import \w|from \w[\w.]* import )"),
        ("javascript", r"(?m)^\s*(function \w|const \w|let \w|export |import .* from )"),
        ("sql", r"(?mi)^\s*(select .* from |create table |insert into |update \w)"),
        ("css", r"(?m)^[.#]?[\w-]+\s*\{[^{}]*:[^{}]*;"),
    )
    for language, signal in code_signals:
        if re.search(signal, stripped):
            return language
    return "markdown"
def _looks_like_email_document(text: str = "", title: str = "") -> bool:
    """Report whether a document looks like an email draft.

    A document counts as an email when its title is one of the placeholder
    titles ("new email", "new mail", "new message", case-insensitive), or
    when its text has both a line starting with ``To:`` and a line starting
    with ``Subject:`` (case-insensitive).

    The presence of a ``---`` separator is not required: the earlier
    separator-specific branch demanded the same two headers as the general
    check, so it could never change the outcome and has been folded away.

    Args:
        text: Document content; ``None``/empty is tolerated.
        title: Document title; ``None``/empty is tolerated.

    Returns:
        True if the title or the content matches the email shape.
    """
    import re as _re

    if (title or "").strip().lower() in {"new email", "new mail", "new message"}:
        return True

    body = (text or "").lstrip()
    has_to = _re.search(r"(?im)^To:\s*", body) is not None
    has_subject = _re.search(r"(?im)^Subject:\s*", body) is not None
    return has_to and has_subject
def _coerce_email_document_content(existing: str, incoming: str) -> str:
    """Force incoming content into the ``headers / --- / body`` email shape.

    Covers models that write only the body, or that dump header labels
    without the ``---`` separator.

    Args:
        existing: Current document content; its header section (everything
            before the first ``\\n---\\n``) is reused when present.
        incoming: New content produced by the model.

    Returns:
        ``incoming`` (stripped) unchanged if it already contains the
        separator. Otherwise the existing header -- or a blank
        ``To:``/``Subject:`` header when there is none -- followed by the
        separator and the body extracted from ``incoming``.
    """
    previous = existing or ""
    candidate = (incoming or "").strip()
    if "\n---\n" in candidate:
        return candidate

    if "\n---\n" in previous:
        header = previous.split("\n---\n", 1)[0]
    else:
        header = "To: \nSubject: "

    body = candidate
    if _looks_like_email_document(candidate):
        # The model wrote header lines without a separator: keep only what
        # follows the LAST recognised header line as the body.
        header_line = re.compile(
            r"^(To|Cc|Bcc|Subject|In-Reply-To|References|X-Source-UID|X-Source-Folder|X-Attachments):",
            re.I,
        )
        rows = candidate.splitlines()
        last_header = max(
            (idx for idx, row in enumerate(rows) if header_line.match(row.strip())),
            default=-1,
        )
        # strip() also discards any blank lines left at the top of the body.
        body = "\n".join(rows[last_header + 1:]).strip()

    return header.rstrip() + "\n---\n" + body
def _parse_tool_args(content):
    """Parse a tool-call argument blob.

    Accepts either a JSON string or an already-decoded dict. Unwraps the
    common `{"body": {...}}` envelope that smaller models emit when they
    read tool descriptions like "Body is JSON: {...}" literally — they
    pass `body` as a field name rather than treating it as a noun.

    Args:
        content: JSON text, a dict, or anything else. An empty or
            whitespace-only string, and any non-str/non-dict value, yields
            an empty dict.

    Returns:
        The decoded arguments. NOTE: a JSON string whose top-level value is
        not an object (e.g. ``"[1, 2]"``) is returned as decoded, so the
        result is a dict only for object/empty input.

    Raises:
        ValueError: If ``content`` is a string holding invalid JSON. The
            original decoding error is attached as ``__cause__``.
    """
    if isinstance(content, str):
        try:
            args = json.loads(content) if content.strip() else {}
        except (json.JSONDecodeError, TypeError) as e:
            # Chain explicitly so the decoder's position info stays reachable.
            raise ValueError(str(e)) from e
    elif isinstance(content, dict):
        args = content
    else:
        args = {}
    # Unwrap {"body": {...}} envelope — but only if `body` is the sole key
    # and points at a dict. We don't want to clobber a legitimate `body`
    # field on tools where it's a real arg (e.g. send_email body text).
    if (
        isinstance(args, dict)
        and len(args) == 1
        and "body" in args
        and isinstance(args["body"], dict)
        and "action" in args["body"]  # extra safety: only unwrap if the inner dict looks like a tool call
    ):
        args = args["body"]
    return args
def parse_edit_blocks(content: str) -> list:
    """Extract every ``<<<FIND>>>...<<<REPLACE>>>...<<<END>>>`` block.

    Args:
        content: Tool payload that may contain zero or more edit blocks.

    Returns:
        A list of ``{"find": ..., "replace": ...}`` dicts in order of
        appearance; empty when no block matches.
    """
    block_re = r'<<<FIND>>>\n(.*?)\n<<<REPLACE>>>\n(.*?)\n<<<END>>>'
    # DOTALL lets the find/replace texts span multiple lines.
    return [
        {"find": found, "replace": replacement}
        for found, replacement in re.findall(block_re, content, re.DOTALL)
    ]
def parse_suggest_blocks(content: str) -> list:
    """Extract ``<<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>>`` blocks.

    No-op suggestions are dropped: those whose find and replacement texts are
    equal after stripping, and those whose reason contains one of the
    "nothing to change" phrases.

    Args:
        content: Tool payload that may contain zero or more suggestion blocks.

    Returns:
        A list of dicts with keys ``id`` (``sugg-1``, ``sugg-2``, ... numbered
        over the kept suggestions only), ``find``, ``replace`` and ``reason``
        (stripped).
    """
    # NOTE(review): phrases are matched as substrings of the lower-cased
    # reason, so "clear" also hits "clearer"/"unclear" -- confirm intended.
    noop_markers = ("no change", "clear", "fine as", "looks good", "no improvement", "keep as")
    block_re = r'<<<FIND>>>\n(.*?)\n<<<SUGGEST>>>\n(.*?)\n<<<REASON>>>\n(.*?)\n<<<END>>>'

    kept = []
    for find_text, replace_text, raw_reason in re.findall(block_re, content, re.DOTALL):
        reason = raw_reason.strip()
        if find_text.strip() == replace_text.strip():
            continue
        lowered = reason.lower()
        if any(marker in lowered for marker in noop_markers):
            continue
        kept.append({
            "id": f"sugg-{len(kept)+1}",
            "find": find_text,
            "replace": replace_text,
            "reason": reason,
        })
    return kept
class CreateDocumentTool:
    """Tool that parses a model-supplied payload and stores it as a new document."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Create a new document. Supports two formats:
        1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
        2) XML-like tags: <title>...</title><language>...</language><content>...</content>
        Some models mix them — strip any XML-style tags and fall back to line parsing.

        Args:
            content: Raw tool payload in one of the formats above.
            ctx: Call context; ``session_id`` (required) and ``owner`` are read.

        Returns:
            On success a dict with ``action="create"``, ``doc_id``, ``title``,
            ``language``, ``content`` and ``version``. On failure a dict with a
            single ``error`` key (missing session, owner mismatch, DB error).
        """
        import uuid, re as _re
        from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
        raw = content or ""
        session_id = ctx.get("session_id")
        owner = ctx.get("owner")
        # Known languages the editor understands (match the <select> in HTML)
        _KNOWN_LANGS = {
            "python", "javascript", "typescript", "html", "css", "markdown", "json",
            "yaml", "bash", "sql", "rust", "go", "java", "c", "cpp", "xml", "toml",
            "ini", "ruby", "php", "csv", "email", "text", "plain", "svg",
        }
        # Try XML tag extraction first
        title = None
        language = None
        # From here on `content` holds the parsed document body, not the raw payload.
        content = None
        mt = _re.search(r"<title>\s*(.*?)\s*</title>", raw, _re.DOTALL | _re.IGNORECASE)
        ml = _re.search(r"<language>\s*(.*?)\s*</language>", raw, _re.DOTALL | _re.IGNORECASE)
        mc = _re.search(r"<content>\s*(.*?)\s*</content>", raw, _re.DOTALL | _re.IGNORECASE)
        # A <language> tag alone does not count as the tagged format.
        if mt or mc:
            title = mt.group(1).strip() if mt else None
            language = ml.group(1).strip().lower() if ml else None
            content = mc.group(1) if mc else None
        # Fall back to line-based parsing. First strip any stray XML-ish tags.
        if title is None or content is None:
            cleaned = _re.sub(r"</?(?:title|language|content)>", "", raw)
            lines = cleaned.strip().split("\n")
            if title is None:
                title = lines[0].strip() if lines else "Untitled"
                lines = lines[1:]
            # Only consume second line as language if it looks like a valid short lang token
            if language is None and lines:
                candidate = lines[0].strip().lower()
                if candidate and len(candidate) < 20 and " " not in candidate and candidate in _KNOWN_LANGS:
                    language = candidate
                    lines = lines[1:]
            if content is None:
                content = "\n".join(lines)
        # Validate language: must be in known set, else default based on content
        if language and language not in _KNOWN_LANGS:
            language = None
        if not language:
            # No explicit language — sniff it from the content so an SVG / HTML / JSON
            # / code document isn't silently saved as markdown. Prose → markdown.
            language = _sniff_doc_language(content)
        # Email-shaped documents are always stored with the "email" language.
        if _looks_like_email_document(content, title):
            language = "email"
        if not title:
            title = "Untitled"
        if not session_id:
            return {"error": "No session context for document creation"}
        db = SessionLocal()
        try:
            doc_id = str(uuid.uuid4())
            ver_id = str(uuid.uuid4())
            # Inherit ownership from the chat session so the doc survives that
            # session later being deleted (session_id → NULL).
            _sess = db.query(DbSession).filter(DbSession.id == session_id).first()
            # When the caller identifies an owner, the session must exist and belong to them.
            if owner is not None and (not _sess or _sess.owner != owner):
                return {"error": "Cannot create document in another user's session"}
            _owner = _sess.owner if _sess else None
            doc = Document(
                id=doc_id,
                session_id=session_id,
                title=title,
                language=language,
                current_content=content,
                version_count=1,
                is_active=True,
                owner=_owner,
            )
            # Version 1 snapshots the initial content alongside the document row.
            ver = DocumentVersion(
                id=ver_id,
                document_id=doc_id,
                version_number=1,
                content=content,
                summary=f"Created by {_active_model or 'AI'}",
                source="ai",
            )
            db.add(doc)
            db.add(ver)
            db.commit()
            # Later document tools default to the active document, so point it here.
            set_active_document(doc_id)
            # Event dispatch is best-effort: a failure is logged and never fails the create.
            try:
                from src.event_bus import fire_event
                fire_event("document_created", _owner)
            except Exception:
                logger.debug("document_created event dispatch failed", exc_info=True)
            return {
                "action": "create",
                "doc_id": doc_id,
                "title": title,
                "language": language,
                "content": content,
                "version": 1,
            }
        except Exception as e:
            db.rollback()
            return {"error": f"Failed to create document: {e}"}
        finally:
            db.close()
class UpdateDocumentTool:
    """Tool that replaces a document's full text and records a new version."""

    async def execute(self, content: str, ctx: dict) -> Dict:
        """Update an existing document. Content = full new document text.

        The target is ``ctx["doc_id"]`` or, failing that, the module's active
        document. If that cannot be resolved for the caller, the most recent
        owned document is used and becomes the active document.

        Returns:
            A dict with ``action="update"``, ``doc_id``, ``title``,
            ``language``, ``content`` and ``version``, or ``{"error": ...}``.
        """
        import uuid
        from src.database import SessionLocal, Document, DocumentVersion

        doc_ref = ctx.get("doc_id", None) or _active_document_id
        requester = ctx.get("owner")
        session = SessionLocal()
        try:
            record = _get_owned_document(session, Document, doc_ref, requester) if doc_ref else None
            if not record:
                record = _most_recent_owned_document(session, Document, requester)
                if record:
                    doc_ref = record.id
                    set_active_document(doc_ref)
                    logger.info(f"update_document: fell back to most recent doc id={doc_ref}")
            if not record:
                return {"error": "No documents exist to update"}

            previous_text = record.current_content or ""
            treat_as_email = record.language == "email" or _looks_like_email_document(
                previous_text, record.title or ""
            )
            if treat_as_email:
                # Email documents are merged through the email coercion helper
                # and are pinned to the "email" language.
                body_text = _coerce_email_document_content(previous_text, content)
                record.language = "email"
            else:
                body_text = content.strip()

            next_version = record.version_count + 1
            snapshot = DocumentVersion(
                id=str(uuid.uuid4()),
                document_id=doc_ref,
                version_number=next_version,
                content=body_text,
                summary=f"Updated by {_active_model or 'AI'}",
                source="ai",
            )
            record.current_content = body_text
            record.version_count = next_version
            session.add(snapshot)
            session.commit()
            return {
                "action": "update",
                "doc_id": doc_ref,
                "title": record.title,
                "language": record.language,
                "content": body_text,
                "version": next_version,
            }
        except Exception as e:
            session.rollback()
            return {"error": f"Failed to update document: {e}"}
        finally:
            session.close()
class EditDocumentTool:
    """Tool that applies FIND/REPLACE edit blocks to a stored document."""

    async def execute(self, content: str, ctx: dict) -> Dict:
        """Apply targeted FIND/REPLACE edits to an existing document.

        Args:
            content: Payload containing one or more
                ``<<<FIND>>>...<<<REPLACE>>>...<<<END>>>`` blocks.
            ctx: Call context; ``doc_id`` (optional) and ``owner`` are read.

        Returns:
            On success a dict with ``action="edit"``, ``doc_id``, ``title``,
            ``language``, ``content``, ``version``, ``applied`` and
            ``skipped``. Otherwise ``{"error": ...}`` — no blocks parsed, no
            document available, no block matched, or a database failure.
        """
        import uuid
        from src.database import SessionLocal, Document, DocumentVersion
        target_id = ctx.get("doc_id", None) or _active_document_id
        owner = ctx.get("owner")
        edits = parse_edit_blocks(content)
        if not edits:
            return {"error": "No valid <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks found"}
        db = SessionLocal()
        try:
            doc = None
            if target_id:
                doc = _get_owned_document(db, Document, target_id, owner)
            if not doc:
                # Fallback: most recently updated document. Avoids "no active doc" errors
                # after server restart or when the agent loses track of which doc to edit.
                doc = _most_recent_owned_document(db, Document, owner)
                if doc:
                    target_id = doc.id
                    set_active_document(target_id)
                    logger.info(f"edit_document: fell back to most recent doc id={target_id} title={doc.title!r}")
            if not doc:
                return {"error": "No documents exist to edit"}
            # A document whose content is NULL is treated as empty so the
            # membership tests below report "no match" instead of raising
            # TypeError (sibling tools guard the same field with `or ""`).
            updated_content = doc.current_content or ""
            applied = 0
            skipped = 0
            for edit in edits:
                _find = edit["find"]
                if _find in updated_content:
                    # Only the first occurrence is replaced.
                    updated_content = updated_content.replace(_find, edit["replace"], 1)
                    applied += 1
                else:
                    # Defensive: the active-doc context shows a "N\t" line-number
                    # gutter for reference. Weaker models sometimes copy that prefix
                    # into FIND. If the exact match failed, retry with a leading
                    # "<digits><tab>" stripped from each FIND line — but only use it
                    # when that stripped form actually matches, so we never corrupt a
                    # legitimately tab-prefixed document.
                    _stripped = "\n".join(re.sub(r"^\d+\t", "", _l) for _l in _find.split("\n"))
                    if _stripped != _find and _stripped in updated_content:
                        updated_content = updated_content.replace(_stripped, edit["replace"], 1)
                        applied += 1
                        logger.info("edit_document: matched after stripping line-number gutter from FIND")
                    else:
                        logger.warning(f"edit_document: FIND text not found, skipping: {_find[:80]!r}")
                        skipped += 1
            if applied == 0:
                return {"error": f"No edits applied — none of the FIND blocks matched the document content (skipped {skipped})"}
            new_ver = doc.version_count + 1
            ver = DocumentVersion(
                id=str(uuid.uuid4()),
                document_id=target_id,
                version_number=new_ver,
                content=updated_content,
                summary=f"Edited by {_active_model or 'AI'} ({applied} edit(s))",
                source="ai",
            )
            doc.current_content = updated_content
            doc.version_count = new_ver
            db.add(ver)
            db.commit()
            return {
                "action": "edit",
                "doc_id": target_id,
                "title": doc.title,
                "language": doc.language,
                "content": updated_content,
                "version": new_ver,
                "applied": applied,
                "skipped": skipped,
            }
        except Exception as e:
            db.rollback()
            return {"error": f"Failed to edit document: {e}"}
        finally:
            db.close()
class SuggestDocumentTool:
    """Tool that validates suggestion blocks against a document without editing it."""

    async def execute(self, content: str, ctx: dict) -> Dict:
        """Create inline suggestions for the active document WITHOUT modifying it.

        Args:
            content: Payload containing one or more
                ``<<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>>`` blocks.
            ctx: Call context; ``doc_id`` (optional) and ``owner`` are read.

        Returns:
            ``{"action": "suggest", "doc_id", "suggestions", "count"}`` holding
            only the suggestions whose FIND text occurs in the document, or
            ``{"error": ...}`` when there is no target document, no parsable
            block, or no block that matches the document.
        """
        from src.database import SessionLocal, Document
        target_id = ctx.get("doc_id", None) or _active_document_id
        owner = ctx.get("owner")
        if not target_id:
            return {"error": "No active document to suggest on"}
        suggestions = parse_suggest_blocks(content)
        if not suggestions:
            return {"error": "No valid <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks found"}
        db = SessionLocal()
        try:
            doc = _get_owned_document(db, Document, target_id, owner)
            if not doc:
                return {"error": f"Document {target_id} not found"}
            # Treat NULL content as empty: this method has no `except`, so a
            # `str in None` TypeError would otherwise escape to the caller.
            doc_text = doc.current_content or ""
            # Validate that FIND text exists in document
            valid = []
            for s in suggestions:
                if s["find"] in doc_text:
                    valid.append(s)
                else:
                    logger.warning(f"suggest_document: FIND text not found, skipping: {s['find'][:80]!r}")
            if not valid:
                return {"error": "No suggestions matched the document content"}
            return {
                "action": "suggest",
                "doc_id": target_id,
                "suggestions": valid,
                "count": len(valid),
            }
        finally:
            db.close()
# ---------------------------------------------------------------------------
# Document management tool (delete, list, organize)
# ---------------------------------------------------------------------------
class ManageDocumentTool:
    """Tool for listing, reading, deleting and tidying stored documents."""

    async def execute(self, content: str, ctx: dict) -> Dict:
        """Manage documents: list, read/view/open, delete, tidy.

        Output format mirrors `manage_session`: list rows include a
        clickable `[Title](#document-<id>)` anchor + relative timestamps
        so the user can click straight from chat to open the editor.

        Args:
            content: JSON arguments. ``action`` defaults to ``"list"``; the
                other keys read are ``search``, ``language``, ``limit`` and
                ``document_id`` / ``id`` / ``uid``.
            ctx: Call context; only ``owner`` is read.

        Returns:
            A dict that always carries ``exit_code`` (0 on success, 1 on
            failure) plus either ``response`` (with ``documents`` or
            ``document`` where applicable) or ``error``.
        """
        # NOTE(review): the sibling document tools import from `src.database`;
        # confirm that `core.database` is the intended module here.
        from core.database import SessionLocal, Document
        from datetime import datetime, timezone
        owner = ctx.get("owner")
        try:
            args = _parse_tool_args(content)
        except ValueError:
            return {"error": "Invalid JSON arguments", "exit_code": 1}
        # Valid JSON that is not an object (e.g. `[1, 2]` or `"list"`) has no
        # `.get`; reject it here instead of raising AttributeError below.
        if not isinstance(args, dict):
            return {"error": "Arguments must be a JSON object", "exit_code": 1}
        action = args.get("action", "list")
        db = SessionLocal()

        def _rel(ts):
            """Render a timestamp as a short relative age ('5m ago', '3d ago', ...)."""
            if not ts:
                return 'never'
            try:
                # Compare aware with aware and naive with naive to avoid a TypeError.
                now = datetime.now(timezone.utc) if ts.tzinfo is not None else datetime.utcnow()
                diff = (now - ts).total_seconds()
            except Exception:
                return 'unknown'
            if diff < 60:
                return 'just now'
            if diff < 3600:
                return f'{int(diff / 60)}m ago'
            if diff < 86400:
                return f'{int(diff / 3600)}h ago'
            if diff < 86400 * 7:
                return f'{int(diff / 86400)}d ago'
            return ts.strftime('%Y-%m-%d')

        try:
            if action == "list":
                q = db.query(Document).filter(Document.is_active == True)
                q = _owned_document_query(q, Document, owner)
                if args.get("search"):
                    q = q.filter(Document.title.ilike(f"%{args['search']}%"))
                if args.get("language"):
                    q = q.filter(Document.language == args["language"])
                docs = q.order_by(Document.updated_at.desc()).limit(args.get("limit", 50)).all()
                if not docs:
                    msg = "No documents found" + (f" matching '{args['search']}'" if args.get("search") else "") + "."
                    return {"response": msg, "documents": [], "exit_code": 0}
                lines = []
                items = []
                for i, d in enumerate(docs):
                    size = len(d.current_content or "")
                    lang = d.language or "text"
                    ts = getattr(d, 'updated_at', None) or getattr(d, 'created_at', None)
                    marker = " ← most recent" if i == 0 else ""
                    lines.append(
                        f"- [{d.title}](#document-{d.id}) — {lang}, {size} chars, updated {_rel(ts)}{marker}"
                    )
                    items.append({"id": d.id, "title": d.title, "language": lang, "size": size})
                header = f"Found {len(docs)} document(s), sorted most-recent first. Click a title to open:"
                return {
                    "response": header + "\n" + "\n".join(lines),
                    "documents": items,
                    "exit_code": 0,
                }
            elif action in ("read", "view", "open", "get"):
                doc_id = args.get("document_id") or args.get("id") or args.get("uid")
                if not doc_id:
                    return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
                doc = _get_owned_document(db, Document, doc_id, owner, active_only=True)
                if not doc:
                    return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
                body = doc.current_content or ""
                preview_limit = int(args.get("limit", MAX_READ_CHARS))
                truncated = len(body) > preview_limit
                preview = body[:preview_limit] + (f"\n... (truncated, {len(body)} chars total)" if truncated else "")
                anchor = f"[{doc.title}](#document-{doc.id})"
                return {
                    "response": f"{anchor} — click to open in editor.\n\n```{doc.language or ''}\n{preview}\n```",
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "language": doc.language,
                        "size": len(body),
                        "content": preview,
                        "truncated": truncated,
                    },
                    "exit_code": 0,
                }
            elif action == "delete":
                doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
                doc = None
                if doc_id:
                    doc = _get_owned_document(db, Document, doc_id, owner)
                if not doc:
                    # Fallback: most recently updated doc (likely what the user means)
                    # NOTE(review): this also triggers when an explicit id was
                    # given but not found — confirm deleting a different
                    # document in that case is intended.
                    doc = _most_recent_owned_document(db, Document, owner, active_only=True)
                if not doc:
                    return {"error": "No document to delete", "exit_code": 1}
                title = doc.title
                # Soft delete: the row is kept and only flagged inactive.
                doc.is_active = False
                db.commit()
                if _active_document_id == doc.id:
                    set_active_document(None)
                return {"response": f"Deleted document '{title}'", "exit_code": 0}
            elif action == "tidy":
                from src.document_actions import run_document_tidy
                result = await run_document_tidy(owner or "")
                return {"response": result, "exit_code": 0}
            else:
                return {"error": f"Unknown action: {action}", "exit_code": 1}
        except Exception as e:
            logger.error(f"manage_documents error: {e}")
            return {"error": str(e), "exit_code": 1}
        finally:
            db.close()
+398
View File
@@ -0,0 +1,398 @@
import asyncio
import json
import os
import difflib
import fnmatch
import shutil
from typing import Optional, Dict, Any, Tuple
from src.constants import MAX_READ_CHARS, MAX_DIFF_LINES, MAX_OUTPUT_CHARS
# Directory names the glob and grep tools skip while searching: version
# control metadata, dependency trees, caches and build output.
_CODENAV_SKIP_DIRS = frozenset({
    ".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
    ".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
    ".next", ".cache", "site-packages", ".idea", ".tox",
})
# Cap on the number of rows/paths/matches the ls, glob and grep tools return.
_CODENAV_MAX_HITS = 200
# Cap on the characters kept from one line of grep output.
_CODENAV_MAX_LINE = 400
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
if old == new:
return None
old_lines = old.splitlines()
new_lines = new.splitlines()
label = path or "file"
diff_lines = list(difflib.unified_diff(
old_lines, new_lines,
fromfile=f"a/{label}", tofile=f"b/{label}",
lineterm="",
))
added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++"))
removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---"))
truncated = False
if len(diff_lines) > MAX_DIFF_LINES:
diff_lines = diff_lines[:MAX_DIFF_LINES]
truncated = True
text = "\n".join(diff_lines)
if truncated:
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
return {
"text": text,
"added": added,
"removed": removed,
"new_file": old == "",
"file": os.path.basename(path) or (path or "file"),
}
class EditFileTool:
    """Tool that replaces an exact string inside an existing text file."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Replace ``old_string`` with ``new_string`` in the file at ``path``.

        ``content`` is a JSON object with ``path``, ``old_string``,
        ``new_string`` and optional ``replace_all``. Without ``replace_all``
        the old string must occur exactly once.

        Returns:
            ``{"output", "exit_code": 0}`` plus a ``diff`` entry when the file
            changed, or ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate

        params = {}
        if content.strip().startswith("{"):
            try:
                params = json.loads(content)
            except (json.JSONDecodeError, TypeError):
                params = {}

        raw_path = (params.get("path") or "").strip()
        old = params.get("old_string", "")
        new = params.get("new_string", "")
        replace_all = bool(params.get("replace_all", False))

        if not raw_path:
            return {"error": "edit_file: path required", "exit_code": 1}
        try:
            path = _resolve_tool_path(raw_path)
        except ValueError as e:
            return {"error": f"edit_file: {e}", "exit_code": 1}
        if old == "":
            return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
        if old == new:
            return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}

        def _rewrite():
            """Read the file, apply the replacement and write it back (runs in a worker thread)."""
            with open(path, "r", encoding="utf-8") as src:
                before = src.read()
            occurrences = before.count(old)
            if occurrences == 0:
                return before, None, "not_found"
            if occurrences > 1 and not replace_all:
                return before, None, f"not_unique:{occurrences}"
            if replace_all:
                after = before.replace(old, new)
            else:
                after = before.replace(old, new, 1)
            with open(path, "w", encoding="utf-8") as dst:
                dst.write(after)
            return before, after, "ok"

        try:
            original, updated, status = await asyncio.to_thread(_rewrite)
        except FileNotFoundError:
            return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
        except (IsADirectoryError, UnicodeDecodeError):
            return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
        except PermissionError:
            return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
        except OSError as e:
            return {"error": f"edit_file: {path}: {e}", "exit_code": 1}

        if status == "not_found":
            return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
        if status.startswith("not_unique"):
            n = status.split(":", 1)[1]
            return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}

        n = original.count(old)
        outcome = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
        change = _unified_diff(original, updated, path)
        if change:
            outcome["diff"] = change
        return outcome
class ReadFileTool:
    """Tool that returns the text of a file, optionally a line window of it."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Read a file and return its text.

        ``content`` is either a bare path on the first line or a JSON object
        with ``path`` and optional ``offset`` (1-based first line) and
        ``limit`` (maximum number of lines). Output is capped at
        ``MAX_READ_CHARS`` characters with a truncation notice appended.

        Returns:
            ``{"output", "exit_code": 0}`` or ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate

        raw_path = content.split("\n", 1)[0].strip()
        offset = 0
        limit = 0
        payload = content.strip()
        if payload.startswith("{"):
            # Fields are assigned one by one, so a bad offset/limit still
            # leaves the already-parsed path in place.
            try:
                spec = json.loads(payload)
                raw_path = str(spec.get("path", "")).strip()
                offset = int(spec.get("offset") or 0)
                limit = int(spec.get("limit") or 0)
            except (json.JSONDecodeError, TypeError, ValueError):
                pass

        try:
            path = _resolve_tool_path(raw_path)
        except ValueError as e:
            return {"error": f"read_file: {e}", "exit_code": 1}

        windowed = offset > 0 or limit > 0

        def _read_window():
            """Collect lines from `offset`, stopping at `limit` lines or the char budget."""
            first_line = max(offset, 1)
            pieces = []
            taken = 0
            remaining = MAX_READ_CHARS
            with open(path, "r", encoding="utf-8", errors="replace") as fh:
                for lineno, text in enumerate(fh, 1):
                    if lineno < first_line:
                        continue
                    if limit > 0 and taken >= limit:
                        break
                    pieces.append(text)
                    taken += 1
                    remaining -= len(text)
                    if remaining <= 0:
                        pieces.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
                        break
            return "".join(pieces)

        def _read_head():
            """Read one character past the cap so the caller can detect truncation."""
            with open(path, "r", encoding="utf-8", errors="replace") as fh:
                return fh.read(MAX_READ_CHARS + 1)

        try:
            data = await asyncio.to_thread(_read_window if windowed else _read_head)
        except FileNotFoundError:
            return {"error": f"read_file: {path}: not found", "exit_code": 1}
        except PermissionError:
            return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
        except IsADirectoryError:
            return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
        except OSError as e:
            return {"error": f"read_file: {path}: {e}", "exit_code": 1}

        if not windowed and len(data) > MAX_READ_CHARS:
            data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
        return {"output": data, "exit_code": 0}
class WriteFileTool:
    """Tool that creates or overwrites a text file."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Write a file, creating parent directories as needed.

        ``content`` holds the target path on its first line; everything after
        the first newline is written verbatim as UTF-8.

        Returns:
            ``{"output", "exit_code": 0}`` plus a ``diff`` entry when the text
            differs from what was previously readable at that path, or
            ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
        lines = content.split("\n", 1)
        raw_path = lines[0].strip()
        body = lines[1] if len(lines) > 1 else ""
        try:
            path = _resolve_tool_path(raw_path)
        except ValueError as e:
            return {"error": f"write_file: {e}", "exit_code": 1}
        try:
            def _write():
                """Capture the previous text (best effort), then overwrite the file."""
                old = ""
                try:
                    with open(path, "r", encoding="utf-8") as f:
                        old = f.read()
                except (UnicodeDecodeError, OSError):
                    # Missing, unreadable or non-text targets diff against an
                    # empty "before"; the write below reports real failures.
                    old = ""
                d = os.path.dirname(path)
                if d:
                    os.makedirs(d, exist_ok=True)
                with open(path, "w", encoding="utf-8") as f:
                    f.write(body)
                # Report encoded bytes: len(body) counts characters, which
                # under-reports any text containing non-ASCII characters.
                return old, len(body.encode("utf-8"))
            old_content, size = await asyncio.to_thread(_write)
        except PermissionError:
            return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
        except OSError as e:
            return {"error": f"write_file: {path}: {e}", "exit_code": 1}
        diff = _unified_diff(old_content, body, path)
        result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
        if diff:
            result["diff"] = diff
        return result
class LsTool:
    """Tool that lists the non-hidden entries of a directory."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """List a directory, directories first, then files with their sizes.

        ``content`` is either a bare path on the first line or a JSON object
        with a ``path`` key. A missing or ``null`` path is passed on as an
        empty string. At most ``_CODENAV_MAX_HITS`` rows are shown and names
        starting with ``.`` are skipped.

        Returns:
            ``{"output", "exit_code": 0}`` or ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
        raw_path = ""
        _s = (content or "").strip()
        if _s.startswith("{"):
            try:
                _path = json.loads(_s).get("path")
                # `{"path": null}` must mean "no path", not the literal
                # directory name "None" that str(None) would produce.
                raw_path = "" if _path is None else str(_path).strip()
            except json.JSONDecodeError:
                raw_path = ""
        else:
            raw_path = _s.split("\n", 1)[0].strip()
        try:
            root = _resolve_search_root(raw_path)
        except ValueError as e:
            return {"error": f"ls: {e}", "exit_code": 1}

        def _ls():
            """Scan `root` in a worker thread; returns (text, None) or (None, error)."""
            if not os.path.isdir(root):
                return None, f"ls: {root}: not a directory"
            rows = []
            try:
                with os.scandir(root) as it:
                    for entry in it:
                        if entry.name.startswith("."):
                            continue
                        try:
                            is_dir = entry.is_dir(follow_symlinks=False)
                            size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0
                        except OSError:
                            # Entries that cannot be inspected are left out.
                            continue
                        rows.append((is_dir, entry.name, size))
            except (PermissionError, OSError) as _e:
                return None, f"ls: {_e}"
            # Directories first, then case-insensitive by name.
            rows.sort(key=lambda r: (not r[0], r[1].lower()))
            lines = [f"{root}:"]
            for is_dir, name, size in rows[:_CODENAV_MAX_HITS]:
                lines.append(f" {name}/" if is_dir else f" {name} ({size} B)")
            if len(rows) > _CODENAV_MAX_HITS:
                lines.append(f" ... [{len(rows) - _CODENAV_MAX_HITS} more]")
            if not rows:
                lines.append(" (empty)")
            return "\n".join(lines), None

        out, err = await asyncio.to_thread(_ls)
        if err:
            return {"error": err, "exit_code": 1}
        return {"output": _truncate(out), "exit_code": 0}
class GlobTool:
    """Tool that finds files by glob pattern under a search root."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Recursively match ``pattern`` under ``path``, newest files first.

        ``content`` is either a bare pattern or a JSON object with
        ``pattern`` and optional ``path``. ``null`` values are treated as
        missing. Paths inside ``_CODENAV_SKIP_DIRS`` are skipped and at most
        ``_CODENAV_MAX_HITS`` paths are returned.

        Returns:
            ``{"output", "exit_code": 0}`` (also when nothing matched) or
            ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
        args = {}
        _s = (content or "").strip()
        if _s.startswith("{"):
            try:
                args = json.loads(_s)
            except json.JSONDecodeError:
                args = {}
        else:
            args = {"pattern": _s}
        # JSON null must read as "not provided"; str(None) would otherwise
        # search for the literal text "None".
        _raw_pattern = args.get("pattern")
        pattern = "" if _raw_pattern is None else str(_raw_pattern).strip()
        if not pattern:
            return {"error": "glob: pattern is required", "exit_code": 1}
        _raw_path = args.get("path")
        try:
            root = _resolve_search_root("" if _raw_path is None else str(_raw_path))
        except ValueError as e:
            return {"error": f"glob: {e}", "exit_code": 1}

        def _glob():
            """Walk `root` in a worker thread; returns (paths, None) or (None, error)."""
            from pathlib import Path
            base = Path(root)
            if not base.is_dir():
                return None, f"glob: {root}: not a directory"
            matched = []
            try:
                for p in base.rglob(pattern):
                    if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
                        continue
                    try:
                        mtime = p.stat().st_mtime
                    except OSError:
                        mtime = 0
                    matched.append((mtime, str(p)))
                    # Collect a bounded surplus so the newest-first sort has
                    # candidates without walking an unbounded tree.
                    if len(matched) > _CODENAV_MAX_HITS * 5:
                        break
            except (OSError, ValueError, NotImplementedError) as _e:
                # pathlib raises NotImplementedError for absolute patterns
                # such as "/etc/*"; report it like any other bad pattern.
                return None, f"glob: {_e}"
            matched.sort(key=lambda t: t[0], reverse=True)
            return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None

        paths, err = await asyncio.to_thread(_glob)
        if err:
            return {"error": err, "exit_code": 1}
        if not paths:
            return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
        out = "\n".join(paths)
        if len(paths) >= _CODENAV_MAX_HITS:
            out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
        return {"output": _truncate(out), "exit_code": 0}
class GrepTool:
    """Tool that searches file contents for a regular expression."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Search for ``pattern`` under ``path`` and return ``file:line:text`` hits.

        ``content`` is either a bare pattern or a JSON object with
        ``pattern`` and optional ``path``, ``glob``, ``ignore_case`` and
        ``max_results`` (clamped to 1..``_CODENAV_MAX_HITS``). ``null``
        values are treated as missing. ripgrep is used when it is on PATH;
        otherwise a pure-Python walk that skips ``_CODENAV_SKIP_DIRS`` and
        files that are not valid UTF-8.

        Returns:
            ``{"output", "exit_code": 0}`` (also when nothing matched) or
            ``{"error", "exit_code": 1}``.
        """
        from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
        args: Dict[str, Any] = {}
        _s = (content or "").strip()
        if _s.startswith("{"):
            try:
                args = json.loads(_s)
            except json.JSONDecodeError:
                args = {}
        else:
            args = {"pattern": _s}
        # JSON null must read as "not provided"; str(None) would otherwise
        # search for the literal text "None".
        _raw_pattern = args.get("pattern")
        pattern = "" if _raw_pattern is None else str(_raw_pattern).strip()
        if not pattern:
            return {"error": "grep: pattern is required", "exit_code": 1}
        ignore_case = bool(args.get("ignore_case"))
        glob_pat = str(args.get("glob", "") or "").strip()
        try:
            max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
        except (TypeError, ValueError):
            max_hits = _CODENAV_MAX_HITS
        max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
        _raw_path = args.get("path")
        try:
            root = _resolve_search_root("" if _raw_path is None else str(_raw_path))
        except ValueError as e:
            return {"error": f"grep: {e}", "exit_code": 1}

        def _grep():
            """Run the search in a worker thread; returns (lines, None) or (None, error)."""
            import re as _re
            import shutil
            import subprocess
            rg = shutil.which("rg")
            if rg:
                cmd = [rg, "--line-number", "--no-heading", "--color=never",
                       "--max-count", str(max_hits)]
                if ignore_case:
                    cmd.append("--ignore-case")
                if glob_pat:
                    cmd += ["--glob", glob_pat]
                for _d in _CODENAV_SKIP_DIRS:
                    cmd += ["--glob", f"!**/{_d}/**"]
                # --regexp keeps a pattern that starts with "-" from being
                # parsed as a command-line flag.
                cmd += ["--regexp", pattern, root]
                try:
                    p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
                except subprocess.TimeoutExpired:
                    return None, "grep: timed out"
                except Exception as _e:
                    return None, f"grep: {_e}"
                lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
                # ripgrep exits 0 on a match and 1 on no match; anything else
                # is a failure (e.g. an invalid regex). Surface it instead of
                # reporting "No matches" — but keep hits if some were found
                # despite a partial error such as one unreadable file.
                if p.returncode not in (0, 1) and not lines:
                    detail = " ".join((p.stderr or "").split())
                    if not detail:
                        detail = f"rg exited with status {p.returncode}"
                    return None, f"grep: {detail[:_CODENAV_MAX_LINE]}"
                return lines, None
            try:
                rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
            except _re.error as _e:
                return None, f"grep: bad pattern: {_e}"
            hits = []
            if os.path.isfile(root):
                file_iter = [root]
            else:
                file_iter = []
                for dp, dns, fns in os.walk(root):
                    # Prune in place so os.walk never descends into skipped dirs.
                    dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
                    for fn in fns:
                        if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
                            continue
                        file_iter.append(os.path.join(dp, fn))
            for fp in file_iter:
                if len(hits) >= max_hits:
                    break
                try:
                    # Strict decoding makes non-UTF-8 (binary) files raise and get skipped.
                    with open(fp, "r", encoding="utf-8", errors="strict") as f:
                        for i, line in enumerate(f, 1):
                            if rx.search(line):
                                hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
                                if len(hits) >= max_hits:
                                    break
                except (UnicodeDecodeError, OSError):
                    continue
            return hits, None

        lines, err = await asyncio.to_thread(_grep)
        if err:
            return {"error": err, "exit_code": 1}
        if not lines:
            return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
        out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
        if len(lines) >= max_hits:
            out += f"\n... [capped at {max_hits} matches]"
        return {"output": _truncate(out), "exit_code": 0}
class GetWorkspaceTool:
    """Tell the caller which workspace folder is active (takes no arguments).

    File tools are restricted to that folder. The shell merely starts there
    (cwd) and is NOT sandboxed.
    """

    async def execute(self, content: str, ctx: dict) -> dict:
        """Return the active workspace path, or a note that none is set."""
        from src.tool_execution import get_active_workspace

        workspace = get_active_workspace()
        if not workspace:
            message = (
                "No workspace is set. File tools use the default allowed roots; "
                "resolve paths from the user or use absolute paths."
            )
        else:
            message = (
                f"{workspace}\n(File tools are confined to this folder; the shell starts "
                f"here but is not sandboxed and can reach outside it.)"
            )
        return {"output": message, "exit_code": 0}
+153
View File
@@ -0,0 +1,153 @@
import asyncio
import sys
import time
import collections
from typing import Optional, Callable, Awaitable, Tuple, Dict
from src.constants import MAX_OUTPUT_CHARS
# Wall-clock limits, in seconds, after which the bash / python tools kill
# the subprocess and report a timeout.
DEFAULT_BASH_TIMEOUT = 60 * 60  # 1 hour
DEFAULT_PYTHON_TIMEOUT = 60 * 60  # 1 hour
# Seconds between progress callbacks while a subprocess is running.
PROGRESS_INTERVAL_S = 2.0
# Number of most recent output lines included in each progress callback.
PROGRESS_TAIL_LINES = 12
async def _run_subprocess_streaming(
    proc: asyncio.subprocess.Process,
    *,
    timeout: float,
    progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
) -> Tuple[str, str, Optional[int], bool]:
    """Drain a subprocess's stdout/stderr while waiting for it to exit.

    Args:
        proc: A started process whose stdout/stderr are pipes (or ``None``).
        timeout: Seconds to wait for exit before the process is killed.
        progress_cb: Optional coroutine function called every
            ``PROGRESS_INTERVAL_S`` seconds with ``{"elapsed_s", "tail"}``,
            where ``tail`` is the last ``PROGRESS_TAIL_LINES`` output lines
            (stderr lines prefixed with ``"! "``). Errors raised by the
            callback are ignored.

    Returns:
        ``(stdout, stderr, returncode, timed_out)`` with each stream joined
        by newlines.

    Raises:
        asyncio.CancelledError: Re-raised after killing the process when the
            surrounding task is cancelled.
    """
    started = time.time()
    stdout_full: list[str] = []
    stderr_full: list[str] = []
    tail = collections.deque(maxlen=PROGRESS_TAIL_LINES)

    async def _reader(stream, full_buf, label: str):
        """Append decoded lines from `stream` to `full_buf` and the shared tail."""
        if stream is None:
            return
        while True:
            try:
                line = await stream.readline()
            except ValueError:
                # readline() raises ValueError when one line exceeds the
                # StreamReader buffer limit, after discarding that data.
                # Letting the error end this task would stop draining the
                # pipe and could stall the child until the timeout, so
                # record a placeholder and keep reading.
                decoded = "[output line exceeded the stream buffer limit and was omitted]"
            else:
                if not line:
                    break
                decoded = line.decode("utf-8", errors="replace").rstrip("\n")
            full_buf.append(decoded)
            if label == "err":
                tail.append(f"! {decoded}")
            else:
                tail.append(decoded)

    async def _progress_emitter():
        """Report elapsed time and the output tail until cancelled."""
        await asyncio.sleep(PROGRESS_INTERVAL_S)
        while True:
            if progress_cb:
                try:
                    await progress_cb({
                        "elapsed_s": round(time.time() - started, 1),
                        "tail": "\n".join(list(tail)),
                    })
                except Exception:
                    # Progress reporting is best-effort and must not abort the run.
                    pass
            await asyncio.sleep(PROGRESS_INTERVAL_S)

    rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out"))
    rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err"))
    prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None
    timed_out = False
    try:
        await asyncio.wait_for(proc.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
        try:
            proc.kill()
        except Exception:
            pass
        # Give the killed process a moment to be reaped.
        try:
            await asyncio.wait_for(proc.wait(), timeout=2)
        except Exception:
            pass
    except asyncio.CancelledError:
        # The caller was cancelled: stop the process and helpers, then propagate.
        try:
            proc.kill()
        except Exception:
            pass
        try:
            await asyncio.wait_for(proc.wait(), timeout=2)
        except Exception:
            pass
        for t in (rd_out, rd_err):
            t.cancel()
        if prog_task is not None:
            prog_task.cancel()
        raise
    finally:
        if prog_task is not None and not prog_task.done():
            prog_task.cancel()
            try:
                await prog_task
            except (asyncio.CancelledError, Exception):
                pass
        # Let the readers finish the remaining buffered output, but never
        # wait more than a second per stream.
        for t in (rd_out, rd_err):
            try:
                await asyncio.wait_for(t, timeout=1)
            except Exception:
                pass
    return (
        "\n".join(stdout_full),
        "\n".join(stderr_full),
        proc.returncode,
        timed_out,
    )
class BashTool:
    """Tool that runs a shell command in the agent's working directory."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Run ``content`` through the shell and return its combined output.

        ``ctx`` may supply ``progress_cb`` (progress callback) and
        ``subproc_env`` (environment for the child process).

        Returns:
            ``{"output", "exit_code"}`` with stderr appended after a
            ``STDERR:`` marker, or on timeout an ``error`` dict with exit
            code 124 and the truncated ``stdout`` / ``stderr``.
        """
        from src.tool_execution import agent_cwd, _truncate

        on_progress = ctx.get("progress_cb")
        child_env = ctx.get("subproc_env")
        child = await asyncio.create_subprocess_shell(
            content,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=child_env,
            cwd=agent_cwd(),
        )
        stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
            child,
            timeout=DEFAULT_BASH_TIMEOUT,
            progress_cb=on_progress,
        )
        if timed_out:
            return {
                "error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed",
                "exit_code": 124,
                "stdout": _truncate(stdout, MAX_OUTPUT_CHARS),
                "stderr": _truncate(stderr, MAX_OUTPUT_CHARS),
            }

        out_text = stdout.rstrip()
        err_text = stderr.rstrip()
        if err_text and out_text:
            combined = (out_text + "\nSTDERR: " + err_text).strip()
        elif err_text:
            combined = "STDERR: " + err_text
        else:
            combined = out_text
        combined = _truncate(combined, MAX_OUTPUT_CHARS)
        return {"output": combined or "(no output)", "exit_code": rc or 0}
class PythonTool:
    """Tool that runs a Python snippet in a separate interpreter process."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Execute ``content`` with ``python -I -c`` and return its output.

        ``ctx`` may supply ``progress_cb`` (progress callback) and
        ``subproc_env`` (environment for the child process).

        Returns:
            ``{"output", "exit_code"}`` with stderr appended after a
            ``STDERR:`` marker, or on timeout an ``error`` dict with exit
            code 124 and the truncated ``stdout`` / ``stderr``.
        """
        from src.tool_execution import agent_cwd, _truncate

        interpreter = sys.executable or "python"
        runner = await asyncio.create_subprocess_exec(
            interpreter, "-I", "-c", content,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=ctx.get("subproc_env"),
            cwd=agent_cwd(),
        )
        stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
            runner,
            timeout=DEFAULT_PYTHON_TIMEOUT,
            progress_cb=ctx.get("progress_cb"),
        )
        if timed_out:
            failure = {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124}
            failure["stdout"] = _truncate(stdout, MAX_OUTPUT_CHARS)
            failure["stderr"] = _truncate(stderr, MAX_OUTPUT_CHARS)
            return failure

        report = stdout.rstrip()
        problems = stderr.rstrip()
        if problems:
            if report:
                report = (report + "\nSTDERR: " + problems).strip()
            else:
                report = "STDERR: " + problems
        report = _truncate(report, MAX_OUTPUT_CHARS)
        return {"output": report or "(no output)", "exit_code": rc or 0}
+101
View File
@@ -0,0 +1,101 @@
import asyncio
import json
from typing import Dict, Any
from src.constants import MAX_OUTPUT_CHARS
class WebSearchTool:
    """Tool that runs a web search and returns text with a sources trailer."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Search the web for a query.

        ``content`` is either a plain query or a JSON object with ``query``
        and optional ``time_filter`` / ``freshness`` (``day``, ``week``,
        ``month``, ``year``) and ``max_pages`` (1-10, default 5). When no
        time filter is given, one is inferred from recency words in the query.

        Returns:
            ``{"output", "exit_code": 0}`` where the output is capped at
            ``MAX_OUTPUT_CHARS`` and followed by a ``<!-- SOURCES:... -->``
            trailer when sources exist, or ``{"error", "exit_code": 1}`` when
            the search times out or fails.
        """
        from src.search import comprehensive_web_search
        raw = content.strip()
        query = raw
        time_filter = None
        max_pages = 5
        if raw.startswith("{"):
            try:
                parsed = json.loads(raw)
                if isinstance(parsed, dict) and "query" in parsed:
                    query = str(parsed.get("query", "")).strip()
                    tf = parsed.get("time_filter") or parsed.get("freshness")
                    if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"):
                        time_filter = tf.lower()
                    mp = parsed.get("max_pages")
                    if isinstance(mp, int) and 1 <= mp <= 10:
                        max_pages = mp
            except json.JSONDecodeError:
                # Not valid JSON after all: search for the raw text.
                pass
        if not query:
            query = raw.split("\n")[0].strip()
        if time_filter is None:
            # Infer freshness from recency wording when none was requested.
            q_lc = query.lower()
            if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")):
                time_filter = "day"
            elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")):
                time_filter = "week"
            elif any(kw in q_lc for kw in ("this month", "past month")):
                time_filter = "month"
            elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"):
                time_filter = "week"
        loop = asyncio.get_running_loop()
        # Report failures as error dicts, matching WebFetchTool, instead of
        # letting a timeout or search error propagate out of the tool.
        try:
            text, sources = await asyncio.wait_for(
                loop.run_in_executor(
                    None,
                    lambda: comprehensive_web_search(
                        query,
                        max_pages=max_pages,
                        time_filter=time_filter,
                        return_sources=True,
                    ),
                ),
                timeout=30,
            )
        except asyncio.TimeoutError:
            return {"error": f"web_search: timed out searching for {query[:80]!r}", "exit_code": 1}
        except Exception as e:
            return {"error": f"web_search: {e}", "exit_code": 1}
        output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text
        if sources:
            output += "\n\n<!-- SOURCES:" + json.dumps(sources) + " -->"
        return {"output": output, "exit_code": 0}
class WebFetchTool:
    """Tool that downloads one web page and returns its readable text."""

    async def execute(self, content: str, ctx: dict) -> dict:
        """Fetch a single URL.

        ``content`` is either a bare URL/domain or a JSON object with a
        ``url`` key. Only http/https are accepted; a value without a scheme
        is prefixed with ``https://``.

        Returns:
            ``{"output", "exit_code": 0}`` with an optional title header, a
            ``Source:`` line and the page text capped at
            ``MAX_OUTPUT_CHARS``, or ``{"error", "exit_code": 1}``.
        """
        from src.search.content import fetch_webpage_content

        raw = content.strip()
        target = ""
        if raw.startswith("{"):
            try:
                decoded = json.loads(raw)
            except json.JSONDecodeError:
                decoded = None
            if isinstance(decoded, dict):
                target = str(decoded.get("url") or "").strip()
        if not target:
            target = raw.split("\n")[0].strip()

        has_whitespace = any(ch in target for ch in (" ", "\t", "\n"))
        if not target or target.startswith("{") or has_whitespace:
            return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1}

        lowered = target.lower()
        is_http = lowered.startswith(("http://", "https://"))
        if "://" in lowered and not is_http:
            return {"error": f"web_fetch: unsupported URL scheme (only http/https): {target[:80]}", "exit_code": 1}
        url = target if is_http else "https://" + target

        loop = asyncio.get_running_loop()
        try:
            result = await asyncio.wait_for(
                loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)),
                timeout=30,
            )
        except asyncio.TimeoutError:
            return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1}
        except Exception as e:
            return {"error": f"web_fetch: {url}: {e}", "exit_code": 1}

        err = result.get("error")
        text = (result.get("content") or "").strip()
        title = result.get("title") or ""
        if not text:
            # Prefer the fetcher's own error; otherwise explain the empty page.
            if err:
                return {"error": f"web_fetch: {url}: {err}", "exit_code": 1}
            return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1}

        heading = f"# {title}\n" if title else ""
        output = heading + f"Source: {url}\n\n" + text
        if len(output) > MAX_OUTPUT_CHARS:
            output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]"
        return {"output": output, "exit_code": 0}
+9 -4
View File
@@ -24,7 +24,9 @@ MAX_PIPELINE_STEPS = 10
# ---------------------------------------------------------------------------
# Global managers (set from app.py, same pattern as _mcp_manager)
# ---------------------------------------------------------------------------
# _session_manager is kept as a local cache for performance (avoiding
# repeated get_session_manager_instance() calls). It's synced with
# the authoritative singleton in core.models.
_session_manager = None
_memory_manager = None
_memory_vector = None
@@ -33,11 +35,15 @@ _personal_docs_manager = None
def set_session_manager(mgr):
    """Set the global session manager. Syncs local cache + core singleton.

    Stores *mgr* in this module's ``_session_manager`` global and then passes
    the same object to ``core.models.set_session_manager_instance``.
    """
    global _session_manager
    _session_manager = mgr
    # Imported inside the function rather than at module top.
    # NOTE(review): presumably to avoid an import cycle with core.models — confirm.
    from core.models import set_session_manager_instance
    set_session_manager_instance(mgr)
def get_session_manager():
    """Get the global session manager.

    Returns this module's cached ``_session_manager`` value, which is
    whatever was last passed to ``set_session_manager`` (``None`` until then).
    """
    return _session_manager
@@ -966,16 +972,15 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner
memories = [m for m in memories if m.get("category", "").lower() == category_filter]
if not memories:
return {"results": "No memories found" + (f" in category '{category_filter}'" if category_filter else "") + "."}
result_lines = [f"Found {len(memories)} memory entries:\n"]
for m in memories[:100]:
for m in memories:
cat = m.get("category", "fact")
mid = m.get("id", "?")[:8]
text = m.get("text", "")
if len(text) > 150:
text = text[:150] + "..."
result_lines.append(f"- [{cat}] `{mid}` — {text}")
if len(memories) > 100:
result_lines.append(f"... and {len(memories) - 100} more")
return {"results": "\n".join(result_lines)}
elif action == "add":
+16 -2
View File
@@ -4,6 +4,8 @@ import logging
from typing import Dict
from cryptography.fernet import Fernet, InvalidToken
from core.platform_compat import safe_chmod
logger = logging.getLogger(__name__)
class APIKeyManager:
@@ -15,12 +17,20 @@ class APIKeyManager:
def get_or_create_key(self) -> bytes:
    """Get or create encryption key for API keys.

    Returns the raw Fernet key bytes stored in ``self.key_file``, generating
    and persisting a new key when the file does not exist yet.
    """
    if os.path.exists(self.key_file):
        # Older versions wrote .key with the process umask (often 0o644,
        # i.e. group/world-readable). Re-restrict on read so existing
        # installs heal without needing the key to be regenerated.
        safe_chmod(self.key_file, 0o600)
        with open(self.key_file, 'rb') as f:
            return f.read()

    key = Fernet.generate_key()
    # This key decrypts every stored provider credential. Create the file
    # with owner-only permissions from the start instead of writing it with
    # the umask and tightening afterwards, so there is no window in which
    # the key is group/world-readable.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(self.key_file, flags, 0o600)
    with os.fdopen(fd, 'wb') as f:
        f.write(key)
    # Still enforce 0o600 explicitly: the mode passed to os.open is only
    # applied when the file is newly created.
    safe_chmod(self.key_file, 0o600)
    return key
def encrypt_api_key(self, api_key: str) -> str:
@@ -57,7 +67,12 @@ class APIKeyManager:
# Legacy/wrong shape (e.g. a list) — .items() would raise. Ignore it.
logger.warning("API keys file has unexpected shape (%s); ignoring", type(encrypted_keys).__name__)
return {}
return encrypted_keys
return {
str(provider): key
for provider, key in encrypted_keys.items()
if isinstance(key, str)
}
def save(self, provider: str, api_key: str):
"""Save encrypted API key to file.
@@ -82,4 +97,3 @@ class APIKeyManager:
except (InvalidToken, ValueError) as e:
logger.warning("Failed to decrypt API key for %s: %s", provider, e)
return decrypted
+2
View File
@@ -55,6 +55,8 @@ async def _drain_agent(sess, messages):
if "delta" in d:
delta = d.get("delta")
if isinstance(delta, str):
if d.get("thinking"):
continue
full += delta
elif d.get("type") == "agent_step":
round_num = d.get("round", round_num)
+32 -16
View File
@@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple:
return etype, None
def _memory_context_lines(mems, limit: int = 40) -> list:
"""Render Memory rows into short personal-context bullets for event classify.
Reads the Memory ORM `text` column. The previous inline code read a
non-existent `content` attribute, so it raised AttributeError on the first
row, the surrounding except swallowed it, and the classifier ran with no
personal context at all. getattr keeps it robust to future schema drift.
"""
lines: list = []
for m in mems:
c = (getattr(m, "text", "") or "").strip()
if c:
lines.append(f"- {c[:200]}")
if len(lines) >= limit:
break
return lines
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
"""Hybrid classification of upcoming calendar events: fast heuristic for
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
@@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
try:
from core.database import Memory as _Mem
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
if _mems:
_lines = []
for m in _mems:
c = (m.content or "").strip()
if c:
_lines.append(f"- {c[:200]}")
if _lines:
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n"
_lines = _memory_context_lines(_mems)
if _lines:
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n"
except Exception as _me:
logger.debug(f"Could not load memory for classify: {_me}")
logger.warning(f"Could not load memory for classify: {_me}")
classified_h = 0
classified_llm = 0
@@ -796,14 +809,14 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
import email as _email_mod
import asyncio as _aio
from datetime import datetime as _dt, timedelta as _td
from routes.email_helpers import _imap_connect, SCHEDULED_DB
from routes.email_helpers import _email_cache_owner_clause, _imap_connect, SCHEDULED_DB
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
# 1. Pull recent UIDs + From headers cheaply (header-only fetch).
def _pull_headers():
results = []
conn = _imap_connect(None)
conn = _imap_connect(None, owner=owner)
try:
conn.select("INBOX", readonly=True)
status, data = conn.search(None, "ALL")
@@ -855,9 +868,11 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
# 3. Eligibility: ≥3 emails AND (no cache OR cache > 30 days old).
try:
conn = _sql3.connect(SCHEDULED_DB)
owner_clause, owner_params = _email_cache_owner_clause(owner)
cached = {
r[0]: r[1] for r in conn.execute(
"SELECT from_address, last_built_at FROM sender_signatures"
f"SELECT from_address, last_built_at FROM sender_signatures WHERE {owner_clause}",
owner_params,
).fetchall()
}
conn.close()
@@ -888,7 +903,7 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
def _fetch_bodies(_msgs):
bodies = []
conn2 = _imap_connect(None)
conn2 = _imap_connect(None, owner=owner)
try:
conn2.select("INBOX", readonly=True)
for mm in _msgs:
@@ -965,11 +980,12 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
try:
conn = _sql3.connect(SCHEDULED_DB)
owner_value = (owner or "").strip()
conn.execute(
"INSERT OR REPLACE INTO sender_signatures "
"(from_address, signature_text, sample_count, last_built_at, model_used, source) "
"VALUES (?, ?, ?, ?, ?, ?)",
(addr, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
"(from_address, owner, signature_text, sample_count, last_built_at, model_used, source) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
(addr, owner_value, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
)
conn.commit()
conn.close()
+84 -6
View File
@@ -5,11 +5,13 @@ Auto-registration of built-in MCP servers on startup.
Each server runs as a stdio subprocess managed by McpManager.
"""
import asyncio
import json
import logging
import os
import shutil
import subprocess
import sys
import asyncio
from core.platform_compat import IS_WINDOWS, which_tool
@@ -196,18 +198,29 @@ def _npx_package_from_args(args):
async def _is_npx_package_cached(npx_path, package_spec, timeout_s=5):
"""Probe whether an npx package is already in the local cache.
Runs `npx --no-install <pkg> --version`. --no-install tells npx to
fail instead of downloading, so a cache miss returns fast. We treat
"exited 0 with non-empty stdout" as proof of a working cached copy.
Anything else (non-zero exit, empty stdout, timeout, missing npx,
network error) means we should skip the server.
First checks the local `_npx` cache for an installed package. If the
package is not found there, falls back to `npx --no-install <pkg>
--version` so older npm layouts still work without downloading.
"""
if _is_package_in_npx_cache(package_spec):
return True
try:
proc = await asyncio.create_subprocess_exec(
npx_path, "--no-install", package_spec, "--version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except NotImplementedError:
try:
result = subprocess.run(
[npx_path, "--no-install", package_spec, "--version"],
capture_output=True,
timeout=timeout_s,
)
except (subprocess.TimeoutExpired, OSError, ValueError):
return False
return result.returncode == 0 and bool(result.stdout.strip())
except (OSError, ValueError):
return False
try:
@@ -220,3 +233,68 @@ async def _is_npx_package_cached(npx_path, package_spec, timeout_s=5):
pass
return False
return proc.returncode == 0 and bool(stdout.strip())
def _is_package_in_npx_cache(package_spec):
    """Report whether any npm cache root's `_npx` directory holds package_spec.

    The version/range suffix is stripped from the spec first; an empty
    resulting name is never considered cached.
    """
    name = _npx_package_name(package_spec)
    if not name:
        return False
    return any(
        _npx_cache_contains_package(os.path.join(root, "_npx"), name)
        for root in _npm_cache_roots()
    )
def _npx_package_name(package_spec):
"""Strip a version/range suffix from an npm package spec."""
if not package_spec:
return ""
if package_spec.startswith("@"):
parts = package_spec.split("@", 2)
if len(parts) >= 3:
return f"@{parts[1]}"
return package_spec
return package_spec.split("@", 1)[0]
def _npm_cache_roots():
roots = []
configured = os.environ.get("npm_config_cache")
if configured:
roots.append(os.path.expanduser(configured))
roots.append(os.path.join(os.path.expanduser("~"), ".npm"))
local_app_data = os.environ.get("LOCALAPPDATA")
if local_app_data:
roots.append(os.path.join(local_app_data, "npm-cache"))
return list(dict.fromkeys(roots))
def _npx_cache_contains_package(npx_root, package_name):
    """Check whether any sub-directory of npx_root has package_name installed.

    Each entry directly under ``npx_root`` is probed for
    ``node_modules/<package_name>/package.json``; a hit requires that
    manifest's ``name`` to equal ``package_name``.
    """
    if not os.path.isdir(npx_root):
        return False
    manifest_rel = os.path.join("node_modules", *package_name.split("/"), "package.json")
    try:
        candidates = list(os.scandir(npx_root))
    except OSError:
        return False
    for candidate in candidates:
        try:
            if not candidate.is_dir():
                continue
        except OSError:
            # Entry could not be inspected; ignore it.
            continue
        manifest = os.path.join(candidate.path, manifest_rel)
        if _cached_package_name(manifest) == package_name:
            return True
    return False
def _cached_package_name(package_json_path):
try:
with open(package_json_path, encoding="utf-8") as fh:
data = json.load(fh)
except (OSError, ValueError):
return ""
return str(data.get("name", "")).strip()
+178 -1
View File
@@ -128,6 +128,17 @@ def validate_caldav_url(raw_url: str) -> str:
return urlunparse(parsed._replace(fragment="")).rstrip("/")
def _event_etag(obj) -> str:
"""Best-effort ETag extraction from python-caldav resources."""
try:
etag = getattr(obj, "etag", None)
if callable(etag):
etag = etag()
return str(etag or "")
except Exception:
return ""
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
"""Deterministic local id for a remote CalDAV calendar, scoped to owner
and account so two users — or one user with two accounts — pointing at
@@ -316,11 +327,12 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
color="#5b8abf",
source="caldav",
account_id=account_id or None,
caldav_base_url=remote_url,
)
db.add(local_cal)
db.commit()
else:
# Refresh display name and stamp account_id if missing.
# Refresh display name and stamp CalDAV metadata if missing.
changed = False
if local_cal.name != display_name:
local_cal.name = display_name
@@ -328,6 +340,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
if account_id and not local_cal.account_id:
local_cal.account_id = account_id
changed = True
if local_cal.caldav_base_url != remote_url:
local_cal.caldav_base_url = remote_url
changed = True
if changed:
db.commit()
result["calendars"] += 1
@@ -395,6 +410,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
existing = _find_existing_event(db, pending, uid_val, local_cal.id)
if existing:
if existing.caldav_sync_pending in {"create", "update"}:
result["events"] += 1
continue
existing.calendar_id = local_cal.id
existing.summary = summary
existing.description = description
@@ -405,6 +423,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
existing.is_utc = row_is_utc
existing.rrule = rrule
existing.origin = "caldav"
existing.remote_href = str(getattr(obj, "url", "") or "") or None
existing.remote_etag = _event_etag(obj) or None
existing.caldav_sync_pending = None
else:
new_ev = CalendarEvent(
uid=uid_val,
@@ -418,6 +439,8 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
is_utc=row_is_utc,
rrule=rrule,
origin="caldav",
remote_href=str(getattr(obj, "url", "") or "") or None,
remote_etag=_event_etag(obj) or None,
)
db.add(new_ev)
pending[uid_val] = new_ev
@@ -442,6 +465,8 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
CalendarEvent.origin == "caldav",
CalendarEvent.dtstart >= start,
CalendarEvent.dtstart <= end,
CalendarEvent.remote_href.isnot(None),
CalendarEvent.caldav_sync_pending.is_(None),
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
).all()
for ev in stale:
@@ -458,6 +483,92 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
return result
def _event_payload(ev) -> dict:
return {
"uid": ev.uid,
"summary": ev.summary,
"description": ev.description,
"location": ev.location,
"dtstart": ev.dtstart,
"dtend": ev.dtend,
"all_day": ev.all_day,
"is_utc": ev.is_utc,
"rrule": ev.rrule or "",
}
def _load_event_for_writeback(owner: str, uid: str) -> tuple[str, str, dict] | None:
    """Look up owner's event by uid and package it for a CalDAV push.

    Returns ``(calendar source, calendar id, event payload)``, or ``None``
    when no such event exists, it has no calendar, or the calendar's source
    is not "caldav".
    """
    from core.database import CalendarCal, CalendarEvent, SessionLocal

    session = SessionLocal()
    try:
        owned_events = session.query(CalendarEvent).join(CalendarCal)
        match = owned_events.filter(
            CalendarEvent.uid == uid,
            CalendarCal.owner == owner,
        ).first()
        if not match:
            return None
        if not match.calendar or match.calendar.source != "caldav":
            return None
        return match.calendar.source, match.calendar.id, _event_payload(match)
    finally:
        session.close()
def _load_delete_for_writeback(owner: str, uid: str) -> tuple[str, str, dict] | None:
    """Resolve which calendar a remote delete for uid should target.

    A ``CalendarDeletedEvent`` tombstone for (uid, owner) takes precedence and
    is reported with source "caldav". Otherwise the live event is used, but
    only when it sits on a calendar whose source is "caldav". Returns ``None``
    when neither applies.
    """
    from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal

    session = SessionLocal()
    try:
        marker = (
            session.query(CalendarDeletedEvent)
            .filter(
                CalendarDeletedEvent.uid == uid,
                CalendarDeletedEvent.owner == owner,
            )
            .first()
        )
        if marker:
            return "caldav", marker.calendar_id, {"uid": uid}

        live = (
            session.query(CalendarEvent)
            .join(CalendarCal)
            .filter(CalendarEvent.uid == uid, CalendarCal.owner == owner)
            .first()
        )
        if not live:
            return None
        if not live.calendar or live.calendar.source != "caldav":
            return None
        return live.calendar.source, live.calendar.id, {"uid": uid}
    finally:
        session.close()
def _pending_writeback_uids(owner: str) -> tuple[list[str], list[str]]:
    """Collect the uids that still need pushing to the remote for owner.

    Returns ``(event_uids, delete_uids)``:
      * event_uids — non-cancelled events on owner's "caldav" calendars that
        either carry a ``caldav_sync_pending`` marker or have no
        ``remote_href`` yet;
      * delete_uids — uids of owner's ``CalendarDeletedEvent`` rows.
    """
    from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal

    session = SessionLocal()
    try:
        needs_push = (
            (CalendarEvent.caldav_sync_pending.isnot(None))
            | (CalendarEvent.remote_href.is_(None))
        )
        event_query = (
            session.query(CalendarEvent.uid)
            .join(CalendarCal)
            .filter(
                CalendarCal.owner == owner,
                CalendarCal.source == "caldav",
                CalendarEvent.status != "cancelled",
                needs_push,
            )
        )
        event_uids = [row[0] for row in event_query.all()]

        tombstone_query = session.query(CalendarDeletedEvent.uid).filter(
            CalendarDeletedEvent.owner == owner
        )
        delete_uids = [row[0] for row in tombstone_query.all()]
        return event_uids, delete_uids
    finally:
        session.close()
def _load_caldav_accounts(owner: str) -> list:
"""Return the list of CalDAV accounts for *owner*, auto-migrating the legacy
single-account ``caldav`` key to the new ``caldav_accounts`` list on first call.
@@ -533,3 +644,69 @@ async def sync_caldav(owner: str) -> dict:
for err in result.get("errors", []):
totals["errors"].append(f"{label}: {err}")
return totals
async def push_event_create(owner: str, uid: str) -> dict:
    """Push the local event uid to its remote CalDAV calendar.

    Returns ``{"ok": True, "skipped": True}`` when the event cannot be loaded
    for write-back; otherwise returns whatever ``writeback_event`` returns.
    """
    target = _load_event_for_writeback(owner, uid)
    if not target:
        return {"ok": True, "skipped": True}
    from src.caldav_writeback import writeback_event

    # target is (source, calendar_id, payload), in writeback_event's order.
    return await writeback_event(owner, *target)
async def push_event_update(owner: str, uid: str) -> dict:
    """Push an updated local event; delegates unchanged to ``push_event_create``."""
    return await push_event_create(owner, uid)
async def push_event_delete(owner: str, uid: str) -> dict:
    """Push the deletion of event uid to its remote CalDAV calendar.

    Returns ``{"ok": True, "skipped": True}`` when no delete target can be
    resolved; otherwise returns whatever ``writeback_event`` returns.
    """
    target = _load_delete_for_writeback(owner, uid)
    if not target:
        return {"ok": True, "skipped": True}
    from src.caldav_writeback import writeback_event

    # target is (source, calendar_id, payload), in writeback_event's order.
    return await writeback_event(owner, *target, delete=True)
async def push_pending_events(owner: str) -> dict:
    """Push every pending local change for owner, updates first, then deletes.

    Returns ``{"events": <count of pushes reporting ok>, "errors": [...]}``.
    One uid failing (non-ok result or raised exception) is recorded in
    ``errors`` and does not stop the remaining uids.
    """
    summary = {"events": 0, "errors": []}
    upsert_uids, removal_uids = _pending_writeback_uids(owner)

    # One work list so both kinds share a single reporting path.
    jobs = [
        (push_event_update, "CalDAV pending push failed for uid=%s: %s", uid)
        for uid in upsert_uids
    ]
    jobs.extend(
        (push_event_delete, "CalDAV pending delete failed for uid=%s: %s", uid)
        for uid in removal_uids
    )

    for pusher, failure_fmt, event_uid in jobs:
        try:
            outcome = await pusher(owner, event_uid)
            if outcome.get("ok"):
                summary["events"] += 1
            elif not outcome.get("skipped"):
                summary["errors"].append(
                    f"{event_uid}: {str(outcome.get('error') or outcome)[:160]}"
                )
        except Exception as e:
            logger.warning(failure_fmt, event_uid, e)
            summary["errors"].append(f"{event_uid}: {str(e)[:160]}")
    return summary
async def sync_caldav_direction(owner: str, direction: str = "pull") -> dict:
    """Run a CalDAV sync for owner in the requested direction.

    ``direction`` is trimmed and lower-cased (empty/None means "pull"):
      * "pull" — result of ``sync_caldav``;
      * "push" — result of ``push_pending_events``;
      * "both" — push first, then pull, as ``{"push": ..., "pull": ...}``.
    Any other value returns a zeroed result with a single error message.
    """
    mode = (direction or "pull").strip().lower()

    if mode == "both":
        # Push before pulling.
        push_result = await push_pending_events(owner)
        pull_result = await sync_caldav(owner)
        return {"push": push_result, "pull": pull_result}
    elif mode == "push":
        return await push_pending_events(owner)
    elif mode == "pull":
        return await sync_caldav(owner)

    unsupported = f"Unsupported CalDAV sync direction: {mode}"
    return {"calendars": 0, "events": 0, "deleted": 0, "errors": [unsupported]}
+92 -6
View File
@@ -89,6 +89,23 @@ def find_remote_calendar(calendars, local_cal_id: str, owner: str = "", account_
return None
def _resource_href(obj) -> str:
try:
return str(getattr(obj, "url", "") or "")
except Exception:
return ""
def _resource_etag(obj) -> str:
try:
etag = getattr(obj, "etag", None)
if callable(etag):
etag = etag()
return str(etag or "")
except Exception:
return ""
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
owner: str = "", account_id: str = "") -> dict:
"""Create/update (or delete) ``ev`` on the matching remote calendar.
@@ -105,6 +122,7 @@ def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
remote = find_remote_calendar(calendars, local_cal_id, owner=owner, account_id=account_id)
if remote is None:
return {"ok": False, "error": "remote calendar not found"}
remote_url = str(getattr(remote, "url", "") or "")
try:
existing = remote.event_by_uid(uid)
@@ -113,17 +131,34 @@ def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
if delete:
if existing is None:
return {"ok": True, "note": "already absent on remote"}
return {"ok": True, "note": "already absent on remote", "calendar_url": remote_url}
existing.delete()
return {"ok": True}
return {
"ok": True,
"calendar_url": remote_url,
"remote_href": _resource_href(existing),
"remote_etag": _resource_etag(existing),
}
ical = build_event_ical(ev)
if existing is not None:
existing.data = ical
existing.save()
return {"ok": True, "updated": True}
remote.save_event(ical)
return {"ok": True, "created": True}
return {
"ok": True,
"updated": True,
"calendar_url": remote_url,
"remote_href": _resource_href(existing),
"remote_etag": _resource_etag(existing),
}
created = remote.save_event(ical)
return {
"ok": True,
"created": True,
"calendar_url": remote_url,
"remote_href": _resource_href(created),
"remote_etag": _resource_etag(created),
}
def _discover_calendars(client):
@@ -154,6 +189,54 @@ def _writeback_blocking(local_cal_id, ev, delete, url, username, password,
owner=owner, account_id=account_id)
def _persist_writeback_result(owner: str, calendar_id: str, uid: str, result: dict, *, delete: bool) -> None:
    """Record the outcome of a CalDAV write-back in the local database.

    Args:
        owner: owner used to scope the calendar, tombstone and event lookups.
        calendar_id: local calendar id the write-back targeted.
        uid: event uid; nothing is persisted when it is empty.
        result: dict returned by the push (reads ``ok``, ``error``,
            ``calendar_url``, ``remote_href``, ``remote_etag``).
        delete: True when the write-back was a delete.

    Never raises for database errors: they are rolled back and logged.
    """
    from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal

    # Nothing to record without a uid or a dict-shaped result.
    if not uid or not isinstance(result, dict):
        return
    db = SessionLocal()
    try:
        calendar = db.query(CalendarCal).filter(
            CalendarCal.id == calendar_id,
            CalendarCal.owner == owner,
        ).first()
        # Remember the remote calendar URL whenever the push reported one.
        if calendar and result.get("calendar_url"):
            calendar.caldav_base_url = result.get("calendar_url")

        if delete:
            tombstone = db.query(CalendarDeletedEvent).filter(
                CalendarDeletedEvent.uid == uid,
                CalendarDeletedEvent.owner == owner,
            ).first()
            if result.get("ok"):
                # Remote delete succeeded: the tombstone is no longer needed.
                if tombstone:
                    db.delete(tombstone)
            elif tombstone:
                # Remote delete failed: keep the tombstone and store the
                # error text, capped at 500 characters.
                tombstone.last_error = str(result.get("error") or result)[:500]
            db.commit()
            return

        event = (
            db.query(CalendarEvent)
            .join(CalendarCal)
            .filter(CalendarEvent.uid == uid, CalendarCal.owner == owner)
            .first()
        )
        if event and result.get("ok"):
            # Only overwrite remote metadata with non-empty values.
            if result.get("remote_href"):
                event.remote_href = result.get("remote_href")
            if result.get("remote_etag"):
                event.remote_etag = result.get("remote_etag")
            # A successful push clears the pending marker.
            event.caldav_sync_pending = None
        db.commit()
    except Exception:
        db.rollback()
        logger.exception("CalDAV write-back metadata persistence failed")
    finally:
        db.close()
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
ev: dict, *, delete: bool = False) -> dict:
"""Best-effort push of a local change to the remote CalDAV server.
@@ -204,9 +287,12 @@ async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
result = await asyncio.to_thread(
_writeback_blocking, calendar_id, ev, delete, url, user, pw, owner, acc_id
)
_persist_writeback_result(owner, calendar_id, (ev or {}).get("uid", ""), result, delete=delete)
if not result.get("ok"):
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
return result
except Exception as e:
logger.exception("CalDAV write-back raised")
return {"ok": False, "error": str(e)[:200]}
result = {"ok": False, "error": str(e)[:200]}
_persist_writeback_result(owner, calendar_id, (ev or {}).get("uid", ""), result, delete=delete)
return result
+13 -9
View File
@@ -175,6 +175,19 @@ class ChatProcessor:
Returns:
Tuple of (preface messages, rag_sources list)
Note on KV-cache friendliness: the ``system``-role messages assembled
here are later concatenated into a single system message and sent as
the very first thing in the payload (see ``llm_core``'s "consolidate
system messages" step). Local OpenAI-compatible backends (llama.cpp /
LM Studio) key their KV cache off the byte-identical token prefix, so
*anything* that changes turn-to-turn — timestamps, retrieved snippets,
per-turn counts — must NOT be folded into a system message here. Such
content belongs in a separate ``user``/context message appended near
the end of the array (see ``current_datetime_context_message`` and
``untrusted_context_message`` callers in ``build_chat_context``),
which keeps the static system prefix byte-identical across turns of
the same session and lets the backend reuse its cached prefix.
"""
preface = []
rag_sources = []
@@ -185,15 +198,6 @@ class ChatProcessor:
"role": "system",
"content": preset_system_prompt
})
if not agent_mode:
try:
from src.user_time import current_datetime_prompt
preface.append({
"role": "system",
"content": current_datetime_prompt(),
})
except Exception:
logger.debug("Failed to add current date/time context", exc_info=True)
preface.append({
"role": "system",
"content": UNTRUSTED_CONTEXT_POLICY,
+27 -7
View File
@@ -31,16 +31,22 @@ def compute_input_token_budget(
Args:
configured: the value read from settings (may be the default).
context_length: the model's discovered context window (0/unknown if none).
explicit: True if the user explicitly set ``agent_input_token_budget``.
context_length: the model's discovered context window. Pass 0 when the
window is unknown / only a bare fallback — auto-scaling then stays
conservative instead of trusting an unproven window (review on #4122).
explicit: True if the user set a NON-default budget. The default value is
the "auto" sentinel (scale to the window); any other value is an
explicit cap. (A deliberately-chosen default can't be distinguished
from a materialized default by value, so the default reads as auto.)
Rules:
- Explicit user budget is honoured exactly, only clamped to the model's
window when that window is known (never send more than the model holds).
- Otherwise (default), scale to ``headroom`` of the context window, capped
at ``hard_max`` — so long-context models use their capacity.
- When the window is unknown, fall back to the configured/default value
(preserving the previous behaviour).
window when that window is known (the user's deliberate choice wins;
``hard_max`` is an auto-budget ceiling only — see #1230).
- Otherwise (auto), scale to ``headroom`` of the context window, capped at
``hard_max`` — so long-context models use their capacity.
- When the window is unknown (context_length <= 0), use the conservative
``default`` budget and do NOT scale off the fallback.
"""
configured = int(configured or 0)
context_length = int(context_length or 0)
@@ -53,3 +59,17 @@ def compute_input_token_budget(
return max(1, min(scaled, hard_max))
return configured if configured > 0 else default
def budget_is_explicit(configured: int, *, default: int = DEFAULT_BUDGET) -> bool:
    """Tell whether a configured agent_input_token_budget is a deliberate cap.

    The default value doubles as the "auto" sentinel (scale to the model's
    window), so a budget is explicit only when it is positive AND differs from
    ``default``. The decision keys off the VALUE, not settings *presence*: the
    settings-save path materializes every default into settings.json, so a
    persisted default must still read as auto (the regression #4121 / #1230
    are about). Kept in one place so the materialized-default contract is
    unit-testable and cannot silently regress to a presence check.
    """
    value = int(configured or 0)
    if value <= 0:
        return False
    return value != default
+13 -5
View File
@@ -244,9 +244,17 @@ def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens:
protected_tokens = estimate_tokens(protected_msgs)
budget -= protected_tokens
# Priority: keep first system msg (preset prompt), drop others (memory, RAG, memo)
essential_system = system_msgs[:1] if system_msgs else []
extra_system = system_msgs[1:]
# Priority: keep first system msg (preset prompt), drop others (memory, RAG, memo).
# Exception: a research-spinoff primer (the seeded report that grounds a
# "Discuss" chat) must never be dropped — it is the conversation's whole
# knowledge base. Treat any system message carrying research_spinoff_from
# metadata as essential alongside the leading system prompt.
def _is_research_primer(m):
return bool((m.get("metadata") or {}).get("research_spinoff_from"))
_primers = [m for m in system_msgs if _is_research_primer(m)]
_non_primer = [m for m in system_msgs if not _is_research_primer(m)]
essential_system = (_non_primer[:1] if _non_primer else []) + _primers
extra_system = _non_primer[1:]
# Try dropping extra system messages one by one (from the end)
trimmed = essential_system + convo_msgs
@@ -438,8 +446,8 @@ def _update_session_history(session, split_point: int, summary: str,
)
new_history = system_prefix + [summary_msg] + recent_history
try:
from core import models as _core_models
manager = getattr(_core_models, "_session_manager", None)
from core.models import get_session_manager_instance
manager = get_session_manager_instance()
except Exception:
manager = None
if manager and getattr(session, "id", None):
+22 -3
View File
@@ -136,7 +136,8 @@ async def _tick() -> None:
return
try:
state = json.loads(state_path.read_text(encoding="utf-8"))
except Exception:
except Exception as e:
logger.warning("cookbook_serve_lifecycle: state file unreadable (%s), skipping tick", e)
return
tasks = state.get("tasks") or []
now_ms = int(time.time() * 1000)
@@ -178,8 +179,26 @@ async def _tick() -> None:
if stopped_any:
try:
from core.atomic_io import atomic_write_json
state["tasks"] = tasks
atomic_write_json(state_path, state)
# Re-read the state file so concurrent UI writes (task adds,
# status flips, config edits) are not silently overwritten.
# Apply only our stop mutations to the fresh snapshot.
try:
fresh = json.loads(state_path.read_text(encoding="utf-8"))
fresh_tasks = fresh.get("tasks") or []
except Exception:
fresh = state
fresh_tasks = tasks
stopped_sids = {sid for sid, _, _ in to_stop}
for ft in fresh_tasks:
if not isinstance(ft, dict):
continue
ft_sid = ft.get("sessionId") or ft.get("id")
if ft_sid in stopped_sids:
ft["status"] = "stopped"
ft["_scheduledStopAtMs"] = None
ft["_lastStatusFlipAt"] = now_ms
fresh["tasks"] = fresh_tasks
atomic_write_json(state_path, fresh)
except Exception as e:
logger.warning(f"cookbook_serve_lifecycle: state write failed: {e}")
+5
View File
@@ -232,6 +232,7 @@ class DeepResearcher:
self._start_time: float = 0
self.queries_used: Set[str] = set()
self.urls_fetched: Set[str] = set()
self.analyzed_urls: List[Dict[str, str]] = []
self.round_count: int = 0
# Track which search providers actually returned results during the
# run, in arrival order — surfaced in the visual report so users can
@@ -525,6 +526,10 @@ class DeepResearcher:
if url and url not in self.urls_fetched:
urls_to_fetch.append(r)
self.urls_fetched.add(url)
self.analyzed_urls.append({
"url": url,
"title": r.get("title", "") or url,
})
if len(urls_to_fetch) >= self.max_urls_per_round * len(queries):
break
+11 -2
View File
@@ -196,13 +196,22 @@ def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any],
try:
chroma_client.delete_collection(name)
restored = chroma_client.get_or_create_collection(name=name, metadata=current)
old_embeddings = preserved.get("embeddings") or []
if ids and docs and old_embeddings:
# chromadb returns embeddings as a numpy ndarray, whose truth value
# is ambiguous — `preserved.get("embeddings") or []` and a bare
# `if ... and old_embeddings:` both raise ValueError, which aborts
# the restore and loses the rows the reset was supposed to keep.
# Use explicit None/len checks instead.
old_embeddings = preserved.get("embeddings")
if old_embeddings is None:
old_embeddings = []
if ids and docs and len(old_embeddings):
for start in range(0, len(ids), 100):
batch_ids = ids[start:start + 100]
batch_docs = docs[start:start + 100]
batch_metas = metas[start:start + 100]
batch_embeddings = old_embeddings[start:start + 100]
if hasattr(batch_embeddings, "tolist"):
batch_embeddings = batch_embeddings.tolist()
if len(batch_metas) < len(batch_ids):
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
restored.add(
+27 -14
View File
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Dict
from urllib.parse import urlparse, urlunparse
from core.database import SessionLocal, ModelEndpoint
from src.llm_core import _detect_provider, _host_match, _ollama_api_root
from src.llm_core import _detect_provider, _host_match, _is_kimi_code_url, KIMI_CODE_USER_AGENT, _ollama_api_root
logger = logging.getLogger(__name__)
@@ -183,7 +183,16 @@ def build_chat_url(base: str) -> str:
def build_models_url(base: str) -> Optional[str]:
"""Return the provider-specific model-list endpoint URL for a base."""
"""Return the provider-specific model-list endpoint URL for a base.
For OpenAI-compatible servers (LM Studio, llama.cpp, vLLM,
text-generation-webui, etc.) the model list is exposed at ``/v1/models``.
When the user-supplied base has no path — e.g. ``http://localhost:1234`` —
we still need to land on ``/v1/models`` (issue #25); insert the ``/v1``
segment only when the path is empty, leaving any explicit non-empty path
untouched (so custom prefixes like ``/openai`` or ``/api/openai/v1`` keep
their semantics).
"""
base = normalize_base(resolve_url(base))
provider = _detect_provider(base)
if provider == "anthropic":
@@ -192,6 +201,12 @@ def build_models_url(base: str) -> Optional[str]:
return _ollama_api_root(base) + "/tags"
if provider == "chatgpt-subscription":
return None
# Generic OpenAI-compatible fallback: ensure the path lands on /v1/models
# when the user omitted a path entirely. If a non-empty path is already
# present (e.g. /openai, /api/openai/v1, /v1), trust the caller — the
# /models suffix is appended as-is and the caller's prefix is preserved.
if not urlparse(base).path:
base = base + "/v1"
return base + "/models"
@@ -215,6 +230,8 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
if provider == "openrouter":
headers.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
headers.setdefault("X-OpenRouter-Title", "Odysseus")
if _is_kimi_code_url(base):
headers.setdefault("User-Agent", KIMI_CODE_USER_AGENT)
return headers
@@ -250,27 +267,23 @@ def resolve_endpoint(
ep_id = _stg(f"{setting_prefix}_endpoint_id")
model = _stg(f"{setting_prefix}_model")
# If the specific endpoint is not configured, but the caller provided a
# Fall back to utility model for task/research/auto-naming if not specifically configured.
if not ep_id and setting_prefix not in ("utility", "default"):
ep_id = _stg("utility_endpoint_id")
model = _stg("utility_model")
# If the endpoint is STILL not configured, but the caller provided a
# valid fallback (e.g. the active session model), use that immediately.
# This prevents background tasks from jumping to the global default_model
# when the user is mid-conversation with a different model.
if not ep_id and fallback_url and fallback_model:
return fallback_url, fallback_model, fallback_headers
# Unset Utility means "same as Default Chat Model".
if setting_prefix == "utility" and not ep_id:
# Unset Utility (or anything else that didn't have a fallback) means "same as Default Chat Model".
if not ep_id:
ep_id = _stg("default_endpoint_id")
model = _stg("default_model")
# Fall back to utility model for task/research/auto-naming if not specifically configured.
# If Utility itself is unset, the block above makes that resolve to Default Chat.
if not ep_id and setting_prefix != "utility":
ep_id = _stg("utility_endpoint_id")
model = _stg("utility_model")
if not ep_id:
ep_id = _stg("default_endpoint_id")
model = _stg("default_model")
if not ep_id:
return fallback_url, fallback_model, fallback_headers
+79 -5
View File
@@ -6,6 +6,7 @@ import re
from typing import Dict, List, Optional, Any
import httpx
from fastapi import HTTPException
from core.atomic_io import atomic_write_json
from core.platform_compat import safe_chmod
@@ -258,6 +259,11 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
integration.setdefault("name", "")
integration.setdefault("base_url", "")
if not isinstance(integration.get("name"), str) or not integration["name"].strip():
raise HTTPException(400, "Integration name is required")
if not isinstance(integration.get("base_url"), str) or not integration["base_url"].strip():
raise HTTPException(400, "Integration base URL is required")
integrations = load_integrations()
integrations.append(integration)
save_integrations(integrations)
@@ -266,6 +272,11 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
def update_integration(integration_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Update fields on an existing integration. Returns updated integration or None."""
if "name" in data and (not isinstance(data["name"], str) or not data["name"].strip()):
raise HTTPException(400, "Integration name is required")
if "base_url" in data and (not isinstance(data["base_url"], str) or not data["base_url"].strip()):
raise HTTPException(400, "Integration base URL is required")
integrations = load_integrations()
for item in integrations:
if item.get("id") == integration_id:
@@ -411,17 +422,80 @@ async def execute_api_call(
if "application/json" in content_type:
try:
data = response.json()
formatted = json.dumps(data, indent=2, ensure_ascii=False)
full = json.dumps(data, indent=2, ensure_ascii=False)
if len(full) > 12000:
if isinstance(data, list):
# Binary-search for the largest prefix such that the
# final array (prefix + sentinel) fits within the limit.
# Pre-compute the sentinel so we know its serialized size.
sentinel_placeholder = {
"_truncated": True,
"total_items": len(data),
"shown_items": 0,
}
# Overhead: the sentinel appears as an extra array element.
# Add a conservative padding for the separating comma,
# newline, and indentation characters (~6 chars).
sentinel_overhead = len(
json.dumps(sentinel_placeholder, indent=2, ensure_ascii=False)
) + 6
budget = 12000 - sentinel_overhead
lo, hi = 0, len(data)
while lo < hi:
mid = (lo + hi + 1) // 2
candidate = json.dumps(
data[:mid], indent=2, ensure_ascii=False
)
if len(candidate) < budget:
lo = mid
else:
hi = mid - 1
sentinel = {
"_truncated": True,
"total_items": len(data),
"shown_items": lo,
}
formatted = json.dumps(
data[:lo] + [sentinel], indent=2, ensure_ascii=False
)
elif isinstance(data, dict):
# Truncate dict entries until the result fits, then add
# the _truncated marker. Walk keys in insertion order.
DICT_LIMIT = 12000
kept: dict = {}
for k, v in data.items():
candidate = json.dumps(
{**kept, k: v, "_truncated": True},
indent=2,
ensure_ascii=False,
)
if len(candidate) <= DICT_LIMIT:
kept[k] = v
else:
break
formatted = json.dumps(
{**kept, "_truncated": True}, indent=2, ensure_ascii=False
)
else:
total = len(full)
formatted = full[:12000] + f"\n... (truncated, {total} chars total)"
else:
formatted = full
except (json.JSONDecodeError, ValueError):
formatted = response.text
if len(formatted) > 12000:
total = len(formatted)
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
elif "text/html" in content_type:
formatted = _strip_html_tags(response.text)
if len(formatted) > 12000:
total = len(formatted)
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
else:
formatted = response.text
# Truncate
if len(formatted) > 12000:
formatted = formatted[:12000] + "\n... (truncated)"
if len(formatted) > 12000:
total = len(formatted)
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
output = f"HTTP {status}\n{formatted}"
+314 -17
View File
@@ -7,6 +7,7 @@ import logging
import hashlib
import threading
import re
import os
from fastapi import HTTPException
from typing import Optional, Dict, List, Tuple
from src.model_context import get_context_length, DEFAULT_CONTEXT
@@ -22,6 +23,24 @@ class LLMConfig:
MAX_RETRIES = 3
RETRY_DELAY = 0.5
STREAM_TIMEOUT = 300
# TCP+TLS connect budget for a SINGLE attempt. The old hard-coded 3.0s
# assumed LAN/Tailscale peers ('SYN in <100ms'); it is too tight for public
# cloud endpoints (offshore APIs take ~0.5-1.5s cold, with jitter), so a
# brief blip on the first connect of an idle chat surfaced as a 503 on the
# streaming path (which, unlike llm_call, does not retry the connect). A
# genuinely dead upstream stays bounded by the dead-host cooldown. Override
# with env LLM_CONNECT_TIMEOUT (seconds).
CONNECT_TIMEOUT = float(os.getenv('LLM_CONNECT_TIMEOUT', '10') or '10')
def _call_timeout(read_timeout) -> httpx.Timeout:
    """Build the httpx timeout used for one non-streaming LLM request.

    Args:
        read_timeout: Read budget in seconds; coerced with ``float()``.

    Returns:
        An ``httpx.Timeout`` whose connect budget is
        ``LLMConfig.CONNECT_TIMEOUT``, with a 10s write and 5s pool budget.
    """
    limits = {
        "connect": LLMConfig.CONNECT_TIMEOUT,
        "read": float(read_timeout),
        "write": 10.0,
        "pool": 5.0,
    }
    return httpx.Timeout(**limits)
def _stream_timeout(read_timeout) -> httpx.Timeout:
    """Build the httpx timeout used for one streaming LLM request.

    Args:
        read_timeout: Read budget in seconds; coerced with ``float()``.

    Returns:
        An ``httpx.Timeout`` whose connect budget is
        ``LLMConfig.CONNECT_TIMEOUT``, with a 30s write and 5s pool budget.
    """
    limits = {
        "connect": LLMConfig.CONNECT_TIMEOUT,
        "read": float(read_timeout),
        "write": 30.0,
        "pool": 5.0,
    }
    return httpx.Timeout(**limits)
# Cache for LLM responses
@@ -276,6 +295,24 @@ def _is_ollama_native_url(url: str) -> bool:
return local_ollama_host and (path == "" or path == "/api" or path.startswith("/api/"))
def _is_ollama_openai_compat_url(url: str) -> bool:
"""Return True for local Ollama's OpenAI-compatible /v1 surface.
Mirrors the host detection used by ``_is_ollama_native_url`` so that the
two helpers stay in lockstep: a localhost Ollama on a non-default port
(custom ``OLLAMA_HOST``, reverse proxy, container port remap) is treated
the same way here as it is on the native ``/api`` path.
"""
try:
parsed = urlparse(url or "")
except Exception:
return False
host = parsed.hostname or ""
path = (parsed.path or "").rstrip("/")
local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434
return local_ollama_host and (path == "/v1" or path.startswith("/v1/"))
def _ollama_api_root(url: str) -> str:
"""Return a native Ollama API root such as https://ollama.com/api."""
url = (url or "").strip().rstrip("/")
@@ -405,6 +442,146 @@ def _host_match(url: str, *domains: str) -> bool:
return any(host == d or host.endswith("." + d) for d in domains)
# Kimi Code subscription keys (api.kimi.com/coding/v1) require a whitelisted
# coding-agent User-Agent; otherwise the API returns 403 access_terminated_error.
# Tried in order; first success is cached per base URL for later requests.
KIMI_CODE_USER_AGENTS: tuple[str, ...] = (
"claude-code/0.1.0",
"claude-code/1.0.0",
"KimiCLI/1.0",
"Kilo-Code/1.0",
"Roo-Code/1.0",
"Cursor/1.0",
)
KIMI_CODE_USER_AGENT = KIMI_CODE_USER_AGENTS[0]
_kimi_code_ua_cache: dict[str, str] = {}
def _is_kimi_code_url(url: str) -> bool:
    """Return True when ``url`` is a kimi.com endpoint whose path contains ``/coding``.

    An empty URL, a host that does not match ``kimi.com``, or a URL that fails
    to parse all yield False.
    """
    if url and _host_match(url, "kimi.com"):
        try:
            path = urlparse(url).path or ""
        except Exception:
            return False
        return "/coding" in path
    return False
def _kimi_code_base_key(url: str) -> str:
"""Normalize a Kimi Code chat/models URL to its OpenAI base (.../coding/v1)."""
parsed = urlparse(url)
path = (parsed.path or "").rstrip("/")
for suffix in ("/chat/completions", "/models", "/completions"):
if path.endswith(suffix):
path = path[: -len(suffix)]
path = path.rstrip("/") or "/coding/v1"
return f"{parsed.scheme}://{parsed.netloc}{path}"
def _is_kimi_code_access_denied(status: int, body: bytes | str) -> bool:
if status != 403:
return False
text = body.decode("utf-8", errors="replace") if isinstance(body, bytes) else (body or "")
lower = text.lower()
return (
"access_terminated_error" in lower
or "coding agents" in lower
or "only available for coding" in lower
)
def _kimi_code_ua_candidates(url: str) -> list[str]:
    """List the User-Agent strings to try for ``url``, best guess first.

    Returns an empty list for non-Kimi-Code URLs. Otherwise returns every entry
    of ``KIMI_CODE_USER_AGENTS``, with the User-Agent cached for this base URL
    (if any) moved to the front.
    """
    if not _is_kimi_code_url(url):
        return []

    preferred = _kimi_code_ua_cache.get(_kimi_code_base_key(url))
    if not preferred:
        return list(KIMI_CODE_USER_AGENTS)

    others = [agent for agent in KIMI_CODE_USER_AGENTS if agent != preferred]
    return [preferred, *others]
def _remember_kimi_code_user_agent(url: str, user_agent: str) -> None:
    """Cache ``user_agent`` as the working User-Agent for ``url``'s normalized base."""
    cache_key = _kimi_code_base_key(url)
    _kimi_code_ua_cache[cache_key] = user_agent
def apply_kimi_code_headers(headers: Optional[Dict], url: str) -> Dict[str, str]:
    """Return a copy of ``headers`` carrying a Kimi Code User-Agent when needed.

    For non-Kimi-Code URLs the copy is returned untouched. For Kimi Code URLs
    the User-Agent cached for the base URL is used when present; otherwise each
    entry of ``KIMI_CODE_USER_AGENTS`` is probed against ``<base>/models`` and
    the first one answered with a status below 400 is cached and applied.

    Probe outcomes:
        * request raises -> try the next User-Agent;
        * 403 coding-agent refusal -> try the next User-Agent;
        * any other status >= 400 -> stop probing.

    When no probe succeeds, ``KIMI_CODE_USER_AGENT`` is applied only if the
    caller did not already supply a ``User-Agent`` key.

    NOTE(review): the probe uses the synchronous ``httpx.get`` (8s timeout per
    attempt); confirm this is acceptable for callers running on an event loop.
    """
    result = dict(headers or {})
    if not _is_kimi_code_url(url):
        return result

    base_key = _kimi_code_base_key(url)
    remembered = _kimi_code_ua_cache.get(base_key)
    if remembered:
        result["User-Agent"] = remembered
        return result

    probe_url = base_key.rstrip("/") + "/models"
    from src.tls_overrides import llm_verify

    for candidate in KIMI_CODE_USER_AGENTS:
        probe_headers = {**result, "User-Agent": candidate}
        try:
            response = httpx.get(probe_url, headers=probe_headers, timeout=8, verify=llm_verify())
        except Exception:
            continue
        if _is_kimi_code_access_denied(response.status_code, response.content):
            logger.debug("Kimi Code rejected User-Agent %s (403), trying next", candidate)
            continue
        if response.status_code >= 400:
            # Failure unrelated to the User-Agent: further probing won't help.
            break
        _remember_kimi_code_user_agent(url, candidate)
        result["User-Agent"] = candidate
        return result

    result.setdefault("User-Agent", KIMI_CODE_USER_AGENT)
    return result
def httpx_get_kimi_aware(url: str, headers: Optional[Dict], **kwargs):
    """``httpx.get`` that cycles Kimi Code User-Agents on coding-agent 403s.

    Non-Kimi-Code URLs get a single plain GET. For Kimi Code URLs each
    candidate User-Agent is tried in turn; the first response that is not the
    coding-agent refusal is returned, and its User-Agent is cached when the
    status is below 400. If every candidate is refused, the last refusal
    response is returned.
    """
    base_headers = apply_kimi_code_headers(headers, url)
    if not _is_kimi_code_url(url):
        return httpx.get(url, headers=base_headers, **kwargs)

    response = None
    for candidate in _kimi_code_ua_candidates(url):
        attempt_headers = {**base_headers, "User-Agent": candidate}
        response = httpx.get(url, headers=attempt_headers, **kwargs)
        if _is_kimi_code_access_denied(response.status_code, response.content):
            continue
        if response.status_code < 400:
            _remember_kimi_code_user_agent(url, candidate)
        return response
    return response
def httpx_post_kimi_aware(url: str, headers: Optional[Dict], **kwargs):
    """``httpx.post`` that cycles Kimi Code User-Agents on coding-agent 403s.

    Non-Kimi-Code URLs get a single plain POST. For Kimi Code URLs each
    candidate User-Agent is tried in turn; the first response that is not the
    coding-agent refusal is returned, and its User-Agent is cached when the
    status is below 400. If every candidate is refused, the last refusal
    response is returned.
    """
    base_headers = apply_kimi_code_headers(headers, url)
    if not _is_kimi_code_url(url):
        return httpx.post(url, headers=base_headers, **kwargs)

    response = None
    for candidate in _kimi_code_ua_candidates(url):
        attempt_headers = {**base_headers, "User-Agent": candidate}
        response = httpx.post(url, headers=attempt_headers, **kwargs)
        if _is_kimi_code_access_denied(response.status_code, response.content):
            continue
        if response.status_code < 400:
            _remember_kimi_code_user_agent(url, candidate)
        return response
    return response
async def httpx_post_kimi_aware_async(client, url: str, headers: Optional[Dict], **kwargs):
    """Async POST via ``client`` that cycles Kimi Code User-Agents on 403 refusals.

    Non-Kimi-Code URLs get a single ``client.post``. For Kimi Code URLs each
    candidate User-Agent is tried in turn; the first response that is not the
    coding-agent refusal is returned, and its User-Agent is cached when the
    status is below 400. If every candidate is refused, the last refusal
    response is returned.
    """
    base_headers = apply_kimi_code_headers(headers, url)
    if not _is_kimi_code_url(url):
        return await client.post(url, headers=base_headers, **kwargs)

    response = None
    for candidate in _kimi_code_ua_candidates(url):
        attempt_headers = {**base_headers, "User-Agent": candidate}
        response = await client.post(url, headers=attempt_headers, **kwargs)
        if _is_kimi_code_access_denied(response.status_code, response.content):
            continue
        if response.status_code < 400:
            _remember_kimi_code_user_agent(url, candidate)
        return response
    return response
def _detect_provider(url: str) -> str:
"""Detect the API provider from a configured endpoint URL.
@@ -426,6 +603,10 @@ def _detect_provider(url: str) -> str:
return "openrouter"
if _host_match(url, "groq.com"):
return "groq"
if _host_match(url, "nvidia.com"):
return "nvidia"
if _host_match(url, "moonshot.ai") or _host_match(url, "moonshot.cn"):
return "moonshot"
from src.chatgpt_subscription import is_chatgpt_subscription_base
if is_chatgpt_subscription_base(url):
return "chatgpt-subscription"
@@ -435,6 +616,53 @@ def _detect_provider(url: str) -> str:
return "openai"
def _is_self_hosted_openai_compatible(url: str) -> bool:
    """Return True for a local/custom OpenAI-compatible server, False for cloud APIs.

    This gates the llama.cpp-server payload extras (``session_id``,
    ``cache_prompt``) used for KV-cache slot affinity (issue #2927). Strict
    cloud providers reject unknown top-level fields (issue #3793), so an
    endpoint qualifies only when all of the following hold:

    * ``_detect_provider`` classifies it as generic ``"openai"``;
    * its host is not ``openai.com``;
    * ``is_local_endpoint`` reports it as local.

    A self-hosted server behind a public hostname therefore loses the affinity
    hint unless ``is_local_endpoint`` still reports it as local -- a lost
    performance hint is preferred over a hard 4xx on every request.
    """
    if _detect_provider(url) != "openai":
        return False
    if _host_match(url, "openai.com"):
        return False

    from src.model_context import is_local_endpoint

    return is_local_endpoint(url)
def _apply_local_cache_affinity(payload: Dict, url: str, session_id: Optional[str]) -> None:
    """Add llama.cpp-server slot-affinity hints to ``payload`` in place.

    Without a stable identifier llama.cpp picks a processing slot itself
    (issue #2927), so consecutive turns of one chat can land on different slots
    and lose their cached prefix. A stable ``session_id`` keeps a conversation
    on the same slot, and ``cache_prompt: true`` asks the server to reuse the
    prefix it already holds.

    Both keys are extensions to the OpenAI schema, so they are set only when a
    ``session_id`` is given and the endpoint is self-hosted OpenAI-compatible.
    Keys already present in ``payload`` are left untouched.
    """
    if session_id and _is_self_hosted_openai_compatible(url):
        payload.setdefault("session_id", str(session_id))
        payload.setdefault("cache_prompt", True)
def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str, str]:
h = {"Content-Type": "application/json"}
if isinstance(headers, dict):
@@ -471,9 +699,16 @@ def _provider_label(url: str) -> str:
if is_copilot_base(url): return "GitHub Copilot"
if _host_match(url, "mistral.ai"): return "Mistral"
if _host_match(url, "deepseek.com"): return "DeepSeek"
if _host_match(url, "nvidia.com"): return "NVIDIA"
if _host_match(url, "googleapis.com"): return "Google"
if _host_match(url, "together.xyz", "together.ai"): return "Together"
if _host_match(url, "fireworks.ai"): return "Fireworks"
if _host_match(url, "kimi.com"):
try:
if "/coding" in (urlparse(url).path or ""):
return "Kimi Code"
except Exception:
pass
if _is_ollama_native_url(url): return "Ollama"
try:
host = (urlparse(url).hostname or "").lower()
@@ -542,8 +777,9 @@ def _build_chatgpt_responses_payload(
}
if not _restricts_temperature(model):
payload["temperature"] = temperature
if max_tokens and max_tokens > 0:
payload["max_output_tokens"] = max_tokens
# ChatGPT Subscription Codex API does not support max_output_tokens —
# passing it returns HTTP 400 "Unsupported parameter: max_output_tokens".
# Do not include it in the payload.
return payload
@@ -613,7 +849,7 @@ def _uses_max_completion_tokens(model: str) -> bool:
# perfectly good model as failing. For these models we omit the field and let
# the API use its required default. (gpt-4.5 is intentionally excluded — it is
# not a reasoning model and accepts temperature normally.)
_FIXED_TEMPERATURE_MODELS = ("o1", "o3", "o4", "gpt-5")
_FIXED_TEMPERATURE_MODELS = ("o1", "o3", "o4", "gpt-5", "kimi-for-coding")
def _restricts_temperature(model: str) -> bool:
"""Check if a model rejects any non-default temperature."""
@@ -622,6 +858,49 @@ def _restricts_temperature(model: str) -> bool:
m = model.lower()
return any(m.startswith(p) or f"/{p}" in m for p in _FIXED_TEMPERATURE_MODELS)
# The official Moonshot API fixes temperature at 1.0 in thinking mode and 0.6
# when thinking is explicitly disabled for Kimi K2.5/K2.6. Any other explicit
# value returns HTTP 400. Odysseus does not currently send the `thinking` mode
# control, so omit temperature and let Moonshot use its default thinking mode.
# Keep the gate provider-specific: self-hosted Kimi deployments may accept
# custom sampling values, and older Moonshot models have different defaults.
def _moonshot_rejects_custom_temperature(provider: str, model: str) -> bool:
"""Check if the official Moonshot API fixes temperature for this model."""
if provider != "moonshot" or not isinstance(model, str):
return False
model_id = model.lower().rsplit("/", 1)[-1]
return bool(re.match(r"^kimi-k2\.(?:5|6)(?:$|[-_:])", model_id))
def _omit_temperature(provider: str, model: str) -> bool:
    """Return whether the request should leave temperature to the provider default.

    True when the model has a fixed temperature in general, or when the
    official Moonshot API pins it for this provider/model pair.
    """
    fixed_for_model = _restricts_temperature(model)
    if fixed_for_model:
        return fixed_for_model
    return _moonshot_rejects_custom_temperature(provider, model)
# Anthropic removed the sampling parameters (temperature, top_p, top_k) starting
# with Claude Opus 4.7. On Opus 4.7 and later, sending `temperature` at all —
# even 0.0 — returns HTTP 400. Earlier Claude models (Opus 4.6 and below, every
# Sonnet/Haiku) still accept temperature in [0.0, 1.0], so the omission must be
# version-gated rather than applied to all `claude-*` models.
def _anthropic_rejects_temperature(model: str) -> bool:
"""Check if a native-Anthropic model rejects the temperature field (Opus 4.7+)."""
if not isinstance(model, str) or not model:
return False
# `(?<![a-z])` anchors "opus" to a word boundary so a substring match like
# `oct-opus`/`octopus-4-8` can't be read as Opus (it would otherwise strip
# temperature). Cap the minor at 1-2 digits and forbid a trailing digit so a
# dated id like `claude-opus-4-20250514` (Opus 4.0) parses as major-only (no
# minor match, kept) instead of reading the date `20250514` as a giant minor
# that would falsely test >= 4.7. Dated 4.7+ snapshots (`claude-opus-4-7-
# 20260201`) keep their explicit minor and are still matched.
match = re.search(r"(?<![a-z])opus[-_]?(\d+)[-_.](\d{1,2})(?!\d)", model.lower())
if not match:
return False
return (int(match.group(1)), int(match.group(2))) >= (4, 7)
# Models that support structured thinking — may output </think> without opening tag
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap", "gemma")
@@ -725,8 +1004,11 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
"model": model,
"messages": chat_messages,
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
"temperature": temperature,
}
# Opus 4.7+ removed the sampling parameters — sending `temperature` (even 0.0)
# returns HTTP 400. Omit it for those models; older Claude models still take it.
if not _anthropic_rejects_temperature(model):
payload["temperature"] = temperature
if system_parts:
system_text = "\n\n".join(system_parts)
# Send `system` as a structured text block so we can attach a prompt-cache
@@ -810,7 +1092,7 @@ def _sanitize_llm_messages(messages: List[Dict]) -> List[Dict]:
(content=None, since Gemini/Ollama reject tool_calls alongside ""). Dropping
it leaves the tool result dangling and breaks the next round.
"""
allowed = {"role", "content", "name", "tool_call_id", "tool_calls", "function_call"}
allowed = {"role", "content", "name", "tool_call_id", "tool_calls", "function_call", "reasoning_content"}
cleaned = []
for msg in messages or []:
if not isinstance(msg, dict):
@@ -1045,7 +1327,7 @@ def list_model_ids(
from src.endpoint_resolver import build_models_url
models_url = build_models_url(base_chat_url)
r = httpx.get(models_url, headers=h, timeout=timeout)
r = httpx_get_kimi_aware(models_url, h, timeout=timeout)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
@@ -1146,14 +1428,14 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
"messages": messages_copy,
"temperature": temperature,
}
if _restricts_temperature(model):
if _omit_temperature(provider, model):
payload.pop("temperature", None)
if max_tokens and max_tokens > 0:
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
payload[tok_key] = max_tokens
try:
note_model_activity(target_url, model)
r = httpx.post(target_url, headers=h, json=payload, timeout=timeout)
r = httpx_post_kimi_aware(target_url, h, json=payload, timeout=timeout)
except Exception as e:
raise HTTPException(502, f"POST {target_url} failed: {e}")
if not r.is_success:
@@ -1247,7 +1529,8 @@ async def llm_call_async(
headers: Optional[Dict] = None,
timeout: int = LLMConfig.STREAM_TIMEOUT,
max_retries: int = LLMConfig.MAX_RETRIES,
prompt_type: Optional[str] = None
prompt_type: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Asynchronous LLM call using httpx with connection pooling, timeout, retry logic, and performance logging."""
provider = _detect_provider(url)
@@ -1339,16 +1622,20 @@ async def llm_call_async(
"messages": messages_copy,
"temperature": temperature,
}
if _restricts_temperature(model):
if _omit_temperature(provider, model):
payload.pop("temperature", None)
if max_tokens and max_tokens > 0:
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
payload[tok_key] = max_tokens
# Suppress thinking for qwen3/gemma4 on Ollama /v1 — same as stream_llm.
if _is_ollama_openai_compat_url(url) and _supports_thinking(model):
payload["think"] = False
_apply_local_cache_affinity(payload, url, session_id)
if _is_host_dead(target_url):
raise HTTPException(503, f"Upstream {_host_key(target_url)} marked unreachable (cooldown active)")
call_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=10.0, pool=5.0)
call_timeout = _call_timeout(timeout)
attempt = 0
while attempt < max_retries:
attempt += 1
@@ -1356,7 +1643,7 @@ async def llm_call_async(
try:
note_model_activity(target_url, model)
client = _get_http_client()
r = await client.post(target_url, headers=h, json=payload, timeout=call_timeout)
r = await httpx_post_kimi_aware_async(client, target_url, h, json=payload, timeout=call_timeout)
duration = time.time() - start
if not r.is_success:
friendly = _format_upstream_error(r.status_code, r.text, target_url)
@@ -1401,7 +1688,7 @@ async def llm_call_async(
async def stream_llm(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
timeout: int = LLMConfig.STREAM_TIMEOUT, prompt_type: Optional[str] = None,
tools: Optional[List[Dict]] = None):
tools: Optional[List[Dict]] = None, session_id: Optional[str] = None):
"""Stream LLM responses with improved error handling.
Yields SSE chunks:
@@ -1452,7 +1739,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
"temperature": temperature,
"stream": True,
}
if _restricts_temperature(model):
if _omit_temperature(provider, model):
payload.pop("temperature", None)
if provider not in {"openrouter", "groq"}:
payload["stream_options"] = {"include_usage": True}
@@ -1461,14 +1748,23 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
payload[tok_key] = max_tokens
if tools:
payload["tools"] = tools
# For Ollama's OpenAI-compat /v1 endpoint with thinking models (qwen3,
# gemma4, etc.), suppress thinking so tool calls aren't swallowed inside
# <think> blocks. Ollama /v1 accepts "think": false as a top-level param.
if _is_ollama_openai_compat_url(url) and _supports_thinking(model):
payload["think"] = False
_apply_local_cache_affinity(payload, url, session_id)
h = _provider_headers(provider, headers)
if provider == "copilot":
from src.copilot import apply_request_headers
apply_request_headers(h, messages_copy)
# Short connect timeout: a reachable peer answers SYN in <100ms even on
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
stream_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=30.0, pool=5.0)
# Connect budget from LLMConfig.CONNECT_TIMEOUT (env LLM_CONNECT_TIMEOUT).
# The dead-host cooldown still bounds a genuinely unreachable upstream, so a
# wider connect budget only affects first contact and stops a brief cold
# connect blip (offshore/public endpoints) surfacing as a 503 on this stream
# path, which -- unlike llm_call -- does not retry the connect.
stream_timeout = _stream_timeout(timeout)
if _is_host_dead(target_url):
yield f'event: error\ndata: {json.dumps({"error": f"Upstream {_host_key(target_url)} unreachable (cooldown active)", "status": 503})}\n\n'
@@ -1744,6 +2040,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
events.append(_stream_delta_event(part))
return events
h = apply_kimi_code_headers(h, target_url)
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
+82 -34
View File
@@ -5,6 +5,7 @@ Query and cache model context window sizes from OpenAI-compatible APIs.
Provides token estimation for context usage tracking.
"""
import ipaddress
import logging
import sys
from typing import Dict, List, Optional, Tuple
@@ -19,7 +20,20 @@ _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.interna
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
"172.30.", "172.31.", "192.168.", "100.")
"172.30.", "172.31.", "192.168.")
# Tailscale uses the CGNAT range 100.64.0.0/10, NOT all of 100.0.0.0/8.
# A bare "100." prefix would classify public addresses (e.g. AWS ranges
# under 100.x outside the CGNAT block) as local; routes/model_routes.py
# already narrows this the same way for endpoint classification.
_TAILSCALE_CGNAT = ipaddress.ip_network("100.64.0.0/10")
def _in_tailscale_range(host: str) -> bool:
try:
return ipaddress.ip_address(host) in _TAILSCALE_CGNAT
except ValueError:
return False
def _normalize_base_for_compare(url: str) -> str:
@@ -64,7 +78,7 @@ def _configured_endpoint_kind(url: str) -> Optional[str]:
return None
def _is_local_endpoint(url: str) -> bool:
def is_local_endpoint(url: str) -> bool:
"""Check if URL points to a local/private/tailscale address."""
kind = _configured_endpoint_kind(url)
if kind in ("api", "proxy"):
@@ -73,7 +87,7 @@ def _is_local_endpoint(url: str) -> bool:
return True
try:
host = urlparse(url).hostname or ""
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) or _in_tailscale_range(host)
except Exception:
return False
@@ -208,7 +222,30 @@ KNOWN_CONTEXT_WINDOWS = {
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
_context_cache: Dict[Tuple[str, str], int] = {}
_context_cache: Dict[Tuple[str, str], Tuple[int, bool]] = {}
def _get_context_length_cached(endpoint_url: str, model: str) -> Tuple[int, bool]:
    """Return ``(context_length, known)`` for ``model`` served at ``endpoint_url``.

    The pair comes from ``_query_context_length``; per the original docstring,
    ``known`` is False only for a bare ``DEFAULT_CONTEXT`` fallback.

    Caching rules:
        * Entries are keyed on ``(endpoint_url, model)`` because the same model
          id can have different windows on different remote endpoints (e.g. a
          capped proxy vs. the full provider; issue #2603).
        * Local endpoints are never cached or served from cache: they can
          restart with a different window under the same model id.
        * A remote ``DEFAULT_CONTEXT`` result is cached only when the endpoint
          kind is configured as ``"api"`` or ``"proxy"``; otherwise it is left
          uncached so the next request retries.
    """
    configured_kind = _configured_endpoint_kind(endpoint_url)
    remote = not is_local_endpoint(endpoint_url)
    cache_key = (endpoint_url, model)

    if remote and cache_key in _context_cache:
        return _context_cache[cache_key]

    ctx, known = _query_context_length(endpoint_url, model)

    if remote:
        worth_caching = ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")
        if worth_caching:
            _context_cache[cache_key] = (ctx, known)

    logger.info(f"Context length for {model}: {ctx}")
    return ctx, known
def get_context_length(endpoint_url: str, model: str) -> int:
@@ -218,24 +255,33 @@ def get_context_length(endpoint_url: str, model: str) -> int:
or context_window fields. Caches result per (endpoint, model).
Falls back to DEFAULT_CONTEXT if unavailable.
"""
configured_kind = _configured_endpoint_kind(endpoint_url)
is_local = _is_local_endpoint(endpoint_url)
# Key on (endpoint_url, model): the same model id can be served by two
# different remote endpoints with different real context windows (e.g. a
# capped proxy vs. the full provider), so caching by model id alone would
# serve one endpoint's window for the other (issue #2603).
cache_key = (endpoint_url, model)
if not is_local and cache_key in _context_cache:
return _context_cache[cache_key]
return _get_context_length_cached(endpoint_url, model)[0]
ctx = _query_context_length(endpoint_url, model)
# Only cache non-default values to allow retry on next request.
# Local endpoints can restart with a different --max-model-len while keeping
# the same model id, so always re-query them instead of serving stale cache.
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
_context_cache[cache_key] = ctx
logger.info(f"Context length for {model}: {ctx}")
return ctx
def get_context_length_known(endpoint_url: str, model: str) -> Tuple[int, bool]:
    """Return ``(context_length, known)`` for ``model`` served at ``endpoint_url``.

    Like ``get_context_length`` but also returns whether the window was actually
    discovered (endpoint-reported or in the known-models table) rather than the
    bare DEFAULT_CONTEXT fallback. Callers that *scale* a budget off the window
    must not trust an unknown value -- a fallback 128K isn't proof the model
    holds 128K (review on #4122).
    """
    # Pure delegation: the shared cached lookup already returns the pair.
    return _get_context_length_cached(endpoint_url, model)
def budget_context_for_model(endpoint_url: str, model: str, *, fallback: int = 0) -> int:
    """Return the context window the agent input budget may be scaled against.

    The window is returned only when the same lookup that produced it also
    proved it (endpoint-reported or known-models table); an unproven window
    yields 0 so auto-scaling stays conservative. Value and ``known`` flag come
    from a single call on purpose -- callers must not pair the flag with a
    context length obtained from a different lookup (review on #4122).

    Args:
        endpoint_url: Endpoint the model is served from.
        model: Model identifier.
        fallback: Value returned when the lookup raises (the caller's
            best-known window), preserving prior behaviour.

    Returns:
        The proven window, 0 when the window is unproven, or ``fallback`` when
        the lookup raised.
    """
    try:
        window, proven = get_context_length_known(endpoint_url, model)
        if proven:
            return window
        return 0
    except Exception:
        # Any probe failure degrades to the caller-supplied value.
        return fallback
def _lookup_known(model: str) -> Optional[int]:
@@ -257,8 +303,9 @@ def _lookup_known(model: str) -> Optional[int]:
return best_ctx
def _query_context_length(endpoint_url: str, model: str) -> int:
"""Query the model API for context length."""
def _query_context_length(endpoint_url: str, model: str) -> Tuple[int, bool]:
"""Query the model API for context length. Returns (context_length, known) where
``known`` is False only for the bare DEFAULT_CONTEXT fallback."""
known = _lookup_known(model)
api_ctx = None
configured_kind = _configured_endpoint_kind(endpoint_url)
@@ -269,11 +316,11 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
if configured_kind in ("api", "proxy"):
if known:
logger.info(f"Using known context window for {model}: {known}")
return known
return DEFAULT_CONTEXT
return known, True
return DEFAULT_CONTEXT, False
# Try llama.cpp /slots endpoint first — reports actual serving context
if _is_local_endpoint(endpoint_url):
if is_local_endpoint(endpoint_url):
try:
base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0]
r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
@@ -283,7 +330,7 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
n_ctx = slots[0].get("n_ctx")
if n_ctx and isinstance(n_ctx, int) and n_ctx > 0:
logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
return n_ctx
return n_ctx, True
except Exception:
pass
@@ -295,7 +342,8 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
if is_copilot_base(endpoint_url):
if known:
logger.info(f"Using known context window for {model}: {known}")
return known or DEFAULT_CONTEXT
return known, True
return DEFAULT_CONTEXT, False
from src.endpoint_resolver import build_models_url
@@ -337,21 +385,21 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
# For local/self-hosted endpoints, trust the API value (user set --max-model-len)
# For cloud APIs, use the larger value (API can report low defaults)
if api_ctx and known:
_is_local = _is_local_endpoint(endpoint_url)
_is_local = is_local_endpoint(endpoint_url)
if _is_local and api_ctx < known:
logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
return api_ctx
return api_ctx, True
result = max(api_ctx, known)
if api_ctx < known:
logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
return result
return result, True
if api_ctx:
return api_ctx
return api_ctx, True
if known:
logger.info(f"Using known context window for {model}: {known}")
return known
return known, True
return DEFAULT_CONTEXT
return DEFAULT_CONTEXT, False
def estimate_tokens(messages: List[Dict]) -> int:
+19
View File
@@ -223,6 +223,25 @@ class ModelDiscovery:
)
return {"hosts": hosts, "items": items}
def warmup_ping_urls(self, limit: int = 5) -> List[str]:
    """Return the ``/models`` URLs of up to ``limit`` discovered endpoints.

    Used by the startup warmup / keepalive loop to prime connections. Each
    discovered item already carries a ``/v1/chat/completions`` url; the
    ``/chat/completions`` part is swapped for the cheap ``/models`` probe.

    Every failure mode degrades to a shorter (possibly empty) list so warmup
    never crashes the caller: a raising or empty discovery result, a missing
    or ``None`` ``items`` value, non-dict items, and non-string urls are all
    skipped instead of raising.

    Args:
        limit: Maximum number of leading discovered items to consider. A
            negative value is treated as 0 (previously it silently dropped the
            last item via negative slicing); ``None`` means no limit.

    Returns:
        The rewritten probe URLs, in discovery order.
    """
    try:
        items = (self.discover_models() or {}).get("items") or []
    except Exception:
        return []
    if not isinstance(items, (list, tuple)):
        return []
    # Negative slice bounds would mean "all but the last N", never the intent.
    bound = limit if limit is None or limit >= 0 else 0
    urls: List[str] = []
    for ep in items[:bound]:
        raw = ep.get("url") if isinstance(ep, dict) else None
        if not isinstance(raw, str):
            continue
        url = raw.replace("/chat/completions", "/models")
        if url:
            urls.append(url)
    return urls
def get_providers(self) -> Dict[str, Any]:
"""Get all available providers"""
discovery = self.discover_models()
+1 -1
View File
@@ -32,7 +32,7 @@ def create_office_document(
DocumentVersion,
Session as DbSession,
)
from src.tool_implementations import set_active_document
from src.agent_tools.document_tools import set_active_document
if not body_text or not body_text.strip():
return None
+32
View File
@@ -0,0 +1,32 @@
"""Compatibility helpers for optional third-party dependencies."""
from __future__ import annotations
import sys
import types
def patch_realesrgan_torchvision_compat() -> None:
    """Register a stand-in ``torchvision.transforms.functional_tensor`` module.

    BasicSR/Real-ESRGAN import ``rgb_to_grayscale`` from that module path. When
    the path is not already in ``sys.modules``, a shim module is installed that
    exposes ``rgb_to_grayscale`` from ``torchvision.transforms.functional`` and
    forwards every other attribute lookup to that same module.

    The function is a no-op when the module is already registered, when
    torchvision cannot be imported, or when ``functional`` has no
    ``rgb_to_grayscale`` attribute.
    """
    target = "torchvision.transforms.functional_tensor"
    if target not in sys.modules:
        try:
            from torchvision.transforms import functional as tv_functional
        except Exception:
            # torchvision missing or broken: nothing to patch.
            return

        grayscale = getattr(tv_functional, "rgb_to_grayscale", None)
        if grayscale is not None:
            compat = types.ModuleType(target)
            compat.rgb_to_grayscale = grayscale
            # Module-level __getattr__ handles names not set on the shim itself.
            compat.__getattr__ = lambda name: getattr(tv_functional, name)
            sys.modules[target] = compat
def prepare_optional_dependency_import(name: str) -> None:
    """Apply known import-time compatibility shims before probing a package.

    Only ``"realesrgan"`` (exact, case-sensitive match) has a shim; every other
    name is a no-op.
    """
    if name != "realesrgan":
        return
    patch_realesrgan_torchvision_compat()
+2 -2
View File
@@ -219,7 +219,7 @@ def create_plain_pdf_document(
pages without form-field overlays.
"""
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
from src.tool_implementations import set_active_document
from src.agent_tools.document_tools import set_active_document
content = render_plain_pdf_markdown(upload_id, title, body_text)
db = SessionLocal()
@@ -402,7 +402,7 @@ def create_form_markdown_document(
inside the content, which the export route looks for.
"""
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
from src.tool_implementations import set_active_document
from src.agent_tools.document_tools import set_active_document
content = render_form_as_markdown(fields, upload_id, title, intro_text=intro_text)
db = SessionLocal()
+24 -1
View File
@@ -221,6 +221,22 @@ class ResearchHandler:
# Task registry — background research with persistence
# ------------------------------------------------------------------
def rename_owner(self, old_owner: str, new_owner: str) -> int:
    """Re-key in-flight research tasks from one owner to another.

    Owner keys are compared after ``str()`` coercion, whitespace stripping and
    lower-casing. Entries in ``self._active_tasks`` that are not dicts are
    ignored. Matching entries have their ``"owner"`` field overwritten with the
    normalized new key.

    Returns:
        The number of entries updated; 0 when either key normalizes to empty.
    """
    source = str(old_owner or "").strip().lower()
    target = str(new_owner or "").strip().lower()
    if not (source and target):
        return 0

    moved = 0
    # Iterate over a snapshot so the registry may change underneath us.
    for task in list(self._active_tasks.values()):
        if isinstance(task, dict) and str(task.get("owner", "")).strip().lower() == source:
            task["owner"] = target
            moved += 1
    return moved
def start_research(
self,
session_id: str,
@@ -390,7 +406,6 @@ class ResearchHandler:
def get_status(self, session_id: str) -> Optional[dict]:
"""Get current research status for a session."""
avg = self.get_avg_duration()
if session_id in self._active_tasks:
entry = self._active_tasks[session_id]
result = {
@@ -399,6 +414,14 @@ class ResearchHandler:
"query": entry["query"],
"started_at": entry["started_at"],
}
# avg_duration is a historical figure over completed reports on
# disk; get_avg_duration() globs and JSON-parses the whole research
# dir, so compute it at most once per active stream (memoized on the
# entry) instead of on every ~1s SSE poll. The disk branch below
# never used it, so it no longer pays that cost at all.
if "_avg_duration" not in entry:
entry["_avg_duration"] = self.get_avg_duration()
avg = entry["_avg_duration"]
if avg is not None:
result["avg_duration"] = round(avg, 1)
return result
+506
View File
@@ -0,0 +1,506 @@
"""Consolidated service health / degraded-state reporting.
ROADMAP: "Better degraded-state reporting for ChromaDB, SearXNG, email, ntfy,
and provider probes." There was no single readout of which subsystems are
actually working — `/api/health` is only a liveness ping and each subsystem's
signal lives in a different module. This collects them into one uniform,
*non-intrusive* report (no test push is sent, no real search is run), so the
admin endpoint built on top of it is safe to poll.
Each probe returns:
{"name": str, "status": "ok"|"degraded"|"down"|"disabled",
"detail": str, "meta": dict}
- ok        — reachable / working
- degraded  — partially working (one of several components down)
- down      — configured & enabled but unreachable / erroring
- disabled  — not configured or turned off (not counted as a failure)
Design notes (driven by review feedback):
- **Bounded wall-clock.** Per-item probes (providers, email accounts) fan out
across a bounded thread pool with a hard total budget (`_FANOUT_BUDGET`);
stragglers are reported as a controlled `timeout` rather than blocking. The
aggregate adds a per-subsystem deadline (`_SUBSYSTEM_DEADLINE`) and an overall
ceiling (`_AGGREGATE_DEADLINE`), so the endpoint cannot hang regardless of how
many endpoints/accounts are configured or how slowly they respond.
- **No secret leakage.** Even though the endpoint is admin-only, the response
never returns credential-bearing URLs or raw exception text: URLs are passed
through `_safe_url` (userinfo / query / fragment stripped) and failures are
mapped to controlled categories via `_classify_error`.
The probe functions take their inputs as parameters (settings dict, account
list, endpoint list, manager objects) and isolate the network call to
``_http_get`` / injected callables, so they unit-test without touching the
network.
"""
import asyncio
import concurrent.futures
import logging
import socket
import ssl
import time
from typing import Any, Callable, Dict, List, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# Status ordering for rolling up an overall verdict. "disabled" is excluded —
# a turned-off feature must never drag the overall status down.
# Higher number = worse; statuses missing from this map are ignored by _rollup().
_SEVERITY = {"ok": 0, "degraded": 1, "down": 2}

# Status tokens placed in the "status" field of every probe result.
OK = "ok"
DEGRADED = "degraded"
DOWN = "down"
DISABLED = "disabled"

# Timing budgets (seconds). _PROBE_TIMEOUT bounds a single network op;
# _FANOUT_BUDGET bounds a whole fan-out (providers/email) regardless of count;
# the aggregate layer adds a per-subsystem deadline and an overall ceiling.
_PROBE_TIMEOUT = 4
# Upper bound on worker threads used by one fan-out (see _bounded_map).
_PROBE_CONCURRENCY = 8
_FANOUT_BUDGET = 8
_SUBSYSTEM_DEADLINE = 10
_AGGREGATE_DEADLINE = 14

# Controlled, secret-free phrasing for each failure category.
# Keys are the category tokens produced by _classify_error() plus the
# probe-specific ones ("no_models", "no_host"); "error" is the generic fallback
# used by _detail_for() for unknown categories.
_ERROR_DETAIL = {
    "timeout": "probe timed out",
    "connection_refused": "connection refused",
    "dns_error": "host could not be resolved",
    "tls_error": "TLS handshake failed",
    "network_error": "network error",
    "http_error": "server returned an error response",
    "auth_or_protocol_error": "authentication or protocol error",
    "no_models": "endpoint returned no models",
    "no_host": "no host configured",
    "error": "probe failed",
}
def _svc(name: str, status: str, detail: str, **meta: Any) -> Dict[str, Any]:
return {"name": name, "status": status, "detail": detail, "meta": dict(meta)}
def _safe_url(url: Optional[str]) -> str:
"""Strip credentials (userinfo), query, and fragment from a URL.
Keeps scheme / host / port / path so the report is still useful, but never
echoes `user:pass@`, `?api_key=`, or `#…` back to the caller. Returns
"<redacted>" if the URL can't be parsed into at least a host.
"""
if not url:
return ""
raw = url.strip()
try:
p = urlparse(raw if "://" in raw else "//" + raw)
host = p.hostname or ""
if not host:
return "<redacted>"
netloc = f"{host}:{p.port}" if p.port else host
path = (p.path or "").rstrip("/")
scheme = f"{p.scheme}://" if p.scheme else ""
return f"{scheme}{netloc}{path}"
except Exception:
return "<redacted>"
def _classify_error(exc: BaseException) -> str:
"""Map an exception to a controlled, secret-free category token.
Never returns `str(exc)` httpx/imaplib exception text can embed the target
URL (which may carry credentials) or server-supplied detail.
"""
if isinstance(exc, (asyncio.TimeoutError, concurrent.futures.TimeoutError,
TimeoutError, socket.timeout)):
return "timeout"
name = type(exc).__name__
mod = (type(exc).__module__ or "")
if isinstance(exc, ssl.SSLError) or "SSL" in name or "Certificate" in name:
return "tls_error"
if isinstance(exc, socket.gaierror) or name in ("gaierror", "herror"):
return "dns_error"
if isinstance(exc, ConnectionRefusedError) or "ConnectionRefused" in name \
or name in ("ConnectError",):
return "connection_refused"
if "Timeout" in name:
return "timeout"
if mod.startswith("imaplib") or name in ("error", "abort", "readonly"):
return "auth_or_protocol_error"
if name == "HTTPStatusError":
return "http_error"
if name in ("ConnectTimeout", "ReadTimeout", "ReadError", "WriteError",
"PoolTimeout", "RemoteProtocolError", "NetworkError",
"ProxyError", "ProtocolError"):
return "network_error"
if isinstance(exc, OSError):
return "network_error"
return "error"
def _detail_for(category: str) -> str:
    """Return the fixed human-readable phrase for a failure category.

    Unknown categories fall back to the generic ``"error"`` phrase.
    """
    generic = _ERROR_DETAIL["error"]
    return _ERROR_DETAIL.get(category, generic)
def _http_get(url: str, timeout: float = _PROBE_TIMEOUT):
    """Single network entry point for the HTTP probes (monkeypatched in tests).

    Args:
        url: Full URL to GET.
        timeout: Timeout in seconds handed to ``httpx.get``.

    Returns:
        Whatever ``httpx.get`` returns; the probes only read ``status_code``
        from it. Exceptions raised by httpx propagate to the caller.
    """
    # Function-scope import: httpx is only needed when a probe actually runs.
    import httpx
    return httpx.get(url, timeout=timeout)
def _bounded_map(items: List[Any], worker: Callable[[int, Any], Dict[str, Any]],
                 *, budget: float = _FANOUT_BUDGET,
                 concurrency: int = _PROBE_CONCURRENCY) -> List[Optional[Dict[str, Any]]]:
    """Run ``worker(index, item)`` across a bounded thread pool, in order.

    ``worker`` is expected to catch its own exceptions and return a per-item
    dict; if it raises anyway, the slot is filled with
    ``{"ok": False, "error": <category>}``.

    Args:
        items: Inputs to process; the result list is index-aligned with it.
        worker: Callable invoked as ``worker(index, item)``.
        budget: Total seconds allowed for the whole fan-out (not per item).
        concurrency: Maximum number of worker threads.

    Returns:
        One entry per input. Items that did not finish within ``budget`` are
        left as ``None`` so the caller can substitute a controlled ``timeout``
        entry.
    """
    total = len(items)
    results: List[Optional[Dict[str, Any]]] = [None] * total
    if not total:
        return results

    pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=max(1, min(concurrency, total)))
    slot_of = {}
    for idx in range(total):
        slot_of[pool.submit(worker, idx, items[idx])] = idx

    try:
        for done in concurrent.futures.as_completed(slot_of, timeout=budget):
            slot = slot_of[done]
            try:
                results[slot] = done.result()
            except Exception as err:  # worker is expected to handle its own errors
                results[slot] = {"ok": False, "error": _classify_error(err)}
    except concurrent.futures.TimeoutError:
        # Budget exhausted: unfinished slots stay None.
        pass
    finally:
        # Do not wait for stragglers; queued-but-unstarted work is cancelled.
        pool.shutdown(wait=False, cancel_futures=True)
    return results
# ── ChromaDB (vector RAG + vector memory) ──
def chromadb_health(rag_manager: Any, memory_vector: Any) -> Dict[str, Any]:
    """Report on the two ChromaDB-backed stores via their ``.healthy`` flags.

    A store that is ``None`` counts as absent; a present store is healthy when
    its ``healthy`` attribute is truthy (a missing attribute counts as
    unhealthy).

    - both absent                      -> disabled
    - every present store healthy      -> ok
    - some present store healthy       -> degraded
    - no present store healthy         -> down

    ``meta`` carries ``rag`` / ``memory`` as ``True``/``False``, or ``None``
    for an absent store.
    """
    if rag_manager is None and memory_vector is None:
        return _svc("chromadb", DISABLED,
                    "Vector RAG and vector memory are not initialized.",
                    rag=None, memory=None)

    meta: Dict[str, Any] = {}
    for key, store in (("rag", rag_manager), ("memory", memory_vector)):
        meta[key] = None if store is None else bool(getattr(store, "healthy", False))

    # Only stores that exist take part in the verdict.
    present = [flag for flag in meta.values() if flag is not None]
    if present and all(present):
        return _svc("chromadb", OK, "Vector stores healthy.", **meta)
    if any(present):
        return _svc("chromadb", DEGRADED,
                    "One vector store is unavailable.", **meta)
    return _svc("chromadb", DOWN, "Vector stores are unavailable.", **meta)
# ── SearXNG ──
def _searxng_instance(settings: Dict[str, Any]) -> str:
"""Mirror src/search/providers.py:_get_search_instance precedence."""
url = (settings.get("search_url") or "").strip()
if url:
return url.rstrip("/")
from src.constants import SEARXNG_INSTANCE
return SEARXNG_INSTANCE.rstrip("/")
def searxng_health(settings: Dict[str, Any],
                   *, http_get: Callable = _http_get) -> Dict[str, Any]:
    """Non-intrusive reachability probe for the configured SearXNG instance.

    Tries ``/healthz`` (accepted on 2xx) and then the instance root (accepted
    on any status in 1..499, i.e. the host answered). No search query is run.
    The instance is probed with its full configured URL, but only the
    sanitized form is returned in ``meta``.

    Args:
        settings: Settings dict; reads ``search_provider`` and ``search_url``.
        http_get: Callable ``(url, timeout=...)`` returning an object with a
            ``status_code`` attribute.

    Returns:
        A service entry: ``disabled`` when another provider is selected or no
        instance is configured, ``ok`` on the first accepted response,
        otherwise ``down`` with the category of the last failure.
    """
    provider = settings.get("search_provider") or "searxng"
    if provider != "searxng":
        return _svc("searxng", DISABLED,
                    f"Search provider is '{provider}', not SearXNG.",
                    provider=provider)

    instance = _searxng_instance(settings)
    if not instance:
        return _svc("searxng", DISABLED, "No SearXNG instance configured.")
    shown = _safe_url(instance)

    def _is_success(code):
        return 200 <= code < 300

    def _host_answered(code):
        return 0 < code < 500

    failure = "error"
    for suffix, acceptable in (("/healthz", _is_success), ("/", _host_answered)):
        try:
            response = http_get(instance + suffix, timeout=_PROBE_TIMEOUT)
            status = getattr(response, "status_code", 0)
            if acceptable(status):
                return _svc("searxng", OK, f"Reachable (HTTP {status}).",
                            instance=shown, probed=suffix, http_status=status)
        except Exception as err:  # connection refused, DNS, timeout, ...
            failure = _classify_error(err)
        else:
            # Got a response, but not one this attempt accepts.
            failure = "http_error"
    return _svc("searxng", DOWN, f"Unreachable ({_detail_for(failure)}).",
                instance=shown, error=failure)
# ── ntfy ──
def _ntfy_integration(integrations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""First enabled ntfy integration with a base_url (matches note_routes)."""
for i in integrations or []:
if (i.get("preset") == "ntfy" and i.get("enabled", True)
and i.get("base_url")):
return i
return None
def ntfy_health(integrations: List[Dict[str, Any]], settings: Dict[str, Any],
                *, http_get: Callable = _http_get) -> Dict[str, Any]:
    """Non-intrusive ntfy probe via the server's ``/v1/health`` route.

    No test notification is POSTed. The probe targets ``scheme://netloc`` of
    the configured ``base_url`` (any path is dropped; the netloc, including
    credentials it may carry, is kept), falling back to the raw value with
    trailing slashes removed when it has no scheme/netloc. ``meta.base`` is
    always the sanitized URL.

    Args:
        integrations: Configured integrations; the first enabled ntfy entry
            with a ``base_url`` is used.
        settings: Settings dict; ``reminder_channel`` is echoed in ``meta``
            (defaulting to ``"browser"``).
        http_get: Callable ``(url, timeout=...)`` returning an object with a
            ``status_code`` attribute.

    Returns:
        ``disabled`` without an ntfy integration, ``ok`` for any non-zero
        status below 500, otherwise ``down`` with a controlled error category.
    """
    channel = settings.get("reminder_channel") or "browser"
    integration = _ntfy_integration(integrations)
    if not integration:
        return _svc("ntfy", DISABLED, "No ntfy integration configured.",
                    reminder_channel=channel)

    configured = (integration.get("base_url") or "").strip()
    parts = urlparse(configured)
    if parts.scheme and parts.netloc:
        origin = f"{parts.scheme}://{parts.netloc}"
    else:
        origin = configured.rstrip("/")

    # Fields shared by every outcome below, in the reported order.
    shared = {"base": _safe_url(configured), "reminder_channel": channel}
    try:
        response = http_get(origin + "/v1/health", timeout=_PROBE_TIMEOUT)
        status = getattr(response, "status_code", 0)
        if status and status < 500:
            return _svc("ntfy", OK, f"Reachable (HTTP {status}).",
                        **shared, http_status=status)
        return _svc("ntfy", DOWN, "Server returned an error response.",
                    **shared, error="http_error")
    except Exception as err:
        category = _classify_error(err)
        return _svc("ntfy", DOWN, f"Unreachable ({_detail_for(category)}).",
                    **shared, error=category)
# ── Email (IMAP) ──
def email_health(accounts: List[Dict[str, Any]],
                 *, connect: Optional[Callable] = None) -> Dict[str, Any]:
    """Try a short IMAP connect+logout per configured account, concurrently.

    All accounts connect -> ok; some fail -> degraded; all fail -> down; no
    account configured -> disabled. The fan-out is bounded by
    ``_FANOUT_BUDGET`` regardless of account count; accounts that do not finish
    in time are reported with the ``timeout`` category. ``meta`` carries only
    the account label and a controlled error category, never credentials or
    raw exception text.

    Args:
        accounts: Account dicts; reads ``account_name``, ``account_id`` and
            ``imap_host``.
        connect: Callable taking an account id and returning a connection with
            ``logout()``. Defaults to ``routes.email_helpers._imap_connect``
            called with ``timeout=_PROBE_TIMEOUT``.
    """
    if not accounts:
        return _svc("email", DISABLED, "No email accounts configured.")

    if connect is None:
        from routes.email_helpers import _imap_connect

        def connect(aid):
            # Impose the service-health budget on the IMAP connect itself.
            return _imap_connect(aid, timeout=_PROBE_TIMEOUT)

    def _label(acc: Dict[str, Any]) -> str:
        return acc.get("account_name") or acc.get("account_id") or "account"

    def _check(_i: int, acc: Dict[str, Any]) -> Dict[str, Any]:
        label = _label(acc)
        if not (acc.get("imap_host") or ""):
            return {"name": label, "ok": False, "error": "no_host"}
        try:
            session = connect(acc.get("account_id"))
        except Exception as err:
            return {"name": label, "ok": False, "error": _classify_error(err)}
        # A failing logout does not make a successful connect a failure.
        try:
            session.logout()
        except Exception:
            pass
        return {"name": label, "ok": True, "error": None}

    outcomes = _bounded_map(accounts, _check, budget=_FANOUT_BUDGET,
                            concurrency=_PROBE_CONCURRENCY)
    per_account = []
    for position, outcome in enumerate(outcomes):
        if outcome is None:
            # Slot never filled: the probe overran the fan-out budget.
            outcome = {"name": _label(accounts[position]), "ok": False,
                       "error": "timeout"}
        per_account.append(outcome)
    return _rollup_items("email", "mailbox(es)", per_account)
# ── Provider endpoints ──
def providers_health(endpoints: List[Dict[str, Any]],
                     *, probe: Optional[Callable] = None) -> Dict[str, Any]:
    """Probe each enabled model endpoint's model list, concurrently.

    ``endpoints`` is a list of plain dicts (``name``, ``base_url``,
    ``api_key``), keeping this decoupled from the ORM. An endpoint is
    reachable when its probe returns a non-empty model list; an empty list is
    reported as ``no_models``. The fan-out is bounded by ``_FANOUT_BUDGET``
    regardless of count, with unfinished endpoints reported as ``timeout``.
    ``meta`` never contains the api_key or a raw URL -- only a display name
    (or the sanitized URL when no name is set) and a controlled error category.

    Args:
        endpoints: Endpoint dicts to probe.
        probe: Callable ``(base_url, api_key, timeout=...)`` returning the
            model list. Defaults to ``routes.model_routes._probe_endpoint``.
    """
    if not endpoints:
        return _svc("providers", DISABLED, "No model endpoints configured.")
    if probe is None:
        from routes.model_routes import _probe_endpoint as probe

    def _label(ep: Dict[str, Any]) -> str:
        return ep.get("name") or _safe_url(ep.get("base_url")) or "endpoint"

    def _check(_i: int, ep: Dict[str, Any]) -> Dict[str, Any]:
        label = _label(ep)
        try:
            listed = probe(ep.get("base_url"), ep.get("api_key"),
                           timeout=_PROBE_TIMEOUT) or []
        except Exception as err:
            return {"name": label, "ok": False, "model_count": 0,
                    "error": _classify_error(err)}
        found = len(listed)
        problem = None if found else "no_models"
        return {"name": label, "ok": bool(found), "model_count": found,
                "error": problem}

    outcomes = _bounded_map(endpoints, _check, budget=_FANOUT_BUDGET,
                            concurrency=_PROBE_CONCURRENCY)
    per_endpoint = []
    for position, outcome in enumerate(outcomes):
        if outcome is None:
            # Slot never filled: the probe overran the fan-out budget.
            outcome = {"name": _label(endpoints[position]), "ok": False,
                       "model_count": 0, "error": "timeout"}
        per_endpoint.append(outcome)
    return _rollup_items("providers", "endpoint(s)", per_endpoint, key="endpoints")
def _rollup_items(name: str, noun: str, items: List[Dict[str, Any]],
                  key: str = "accounts") -> Dict[str, Any]:
    """Shared ok/degraded/down rollup for a list of per-item probe results.

    Args:
        name: Service name for the resulting entry.
        noun: Plural-ish noun used in the detail text (e.g. ``"endpoint(s)"``).
        items: Per-item dicts; an item counts as reachable when ``ok`` is truthy.
        key: ``meta`` key under which ``items`` is returned.

    Returns:
        ``ok`` when every item is reachable, ``down`` when none is, otherwise
        ``degraded``. Note that an empty ``items`` list satisfies "every item"
        and yields ``ok``; the callers in this module return ``disabled``
        before reaching here with no items.
    """
    total = len(items)
    reachable = len([item for item in items if item.get("ok")])
    ratio_text = f"{reachable}/{total} {noun} reachable."

    if reachable == total:
        status, detail = OK, ratio_text
    elif reachable == 0:
        status, detail = DOWN, f"No {noun} reachable."
    else:
        status, detail = DEGRADED, ratio_text
    return _svc(name, status, detail, **{key: items})
# ── Aggregate ──
def _rollup(services: List[Dict[str, Any]]) -> str:
    """Return the worst status among ``services`` according to ``_SEVERITY``.

    Statuses absent from ``_SEVERITY`` (such as ``disabled``) are ignored, so
    the result is ``ok`` when no service reports anything worse.
    """
    worst = OK
    for entry in services:
        status = entry.get("status")
        rank = _SEVERITY.get(status)
        if rank is None or rank <= _SEVERITY[worst]:
            continue
        worst = status
    return worst
def _gather_inputs() -> Dict[str, Any]:
    """Pull live config/account/endpoint lists from the app's data sources.

    Each lookup fails soft: a source that raises is logged at debug level and
    leaves its empty default in place, so a single failure can't take down the
    whole health report.

    Returns:
        A dict with ``settings`` (dict), ``integrations``, ``accounts`` and
        ``endpoints`` (lists). ``endpoints`` holds plain
        ``{"name", "base_url", "api_key"}`` dicts for enabled model endpoints.
    """
    gathered: Dict[str, Any] = {
        "settings": {},
        "integrations": [],
        "accounts": [],
        "endpoints": [],
    }

    try:
        from src.settings import load_settings
        gathered["settings"] = load_settings() or {}
    except Exception as e:
        logger.debug(f"service_health: settings load failed: {e}")

    try:
        from src.integrations import load_integrations
        gathered["integrations"] = load_integrations() or []
    except Exception as e:
        logger.debug(f"service_health: integrations load failed: {e}")

    try:
        from routes.email_helpers import _list_email_accounts
        gathered["accounts"] = _list_email_accounts() or []
    except Exception as e:
        logger.debug(f"service_health: email accounts load failed: {e}")

    try:
        from core.database import SessionLocal, ModelEndpoint
        session = SessionLocal()
        try:
            enabled_rows = session.query(ModelEndpoint).filter(
                ModelEndpoint.is_enabled == True).all()  # noqa: E712
            # Copy into plain dicts so nothing ORM-bound outlives the session.
            gathered["endpoints"] = [
                {"name": row.name, "base_url": row.base_url, "api_key": row.api_key}
                for row in enabled_rows
            ]
        finally:
            session.close()
    except Exception as e:
        logger.debug(f"service_health: endpoint load failed: {e}")

    return gathered
async def _run_subsystem(name: str, fn: Callable, *args: Any) -> Dict[str, Any]:
    """Run one (sync) subsystem probe in a thread under a hard deadline.

    Args:
        name: Service name used if a failure entry has to be synthesized.
        fn: Synchronous probe function.
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        The probe's own result, or a ``down`` entry carrying a controlled
        category when the probe overruns ``_SUBSYSTEM_DEADLINE`` or raises.
    """
    try:
        return await asyncio.wait_for(asyncio.to_thread(fn, *args),
                                      timeout=_SUBSYSTEM_DEADLINE)
    except asyncio.TimeoutError:
        category = "timeout"
    except Exception as err:
        category = _classify_error(err)
    # Only reached on failure; the detail text comes from the fixed table.
    return _svc(name, DOWN, _detail_for(category), error=category)
async def collect_service_health(rag_manager: Any = None,
                                 memory_vector: Any = None) -> Dict[str, Any]:
    """Run every probe and return ``{overall, services, timestamp}``.

    The in-process ChromaDB flags are read synchronously; the four network
    subsystems (searxng, ntfy, email, providers) run concurrently, each under
    ``_SUBSYSTEM_DEADLINE``, with an overall ``_AGGREGATE_DEADLINE`` backstop.
    Per-item probes inside providers/email are themselves bounded by
    ``_FANOUT_BUDGET``.

    Args:
        rag_manager: Vector RAG manager, or ``None`` when not initialized.
        memory_vector: Vector memory store, or ``None`` when not initialized.

    Returns:
        ``overall`` (worst non-disabled status), ``services`` (ChromaDB first,
        then the four network subsystems in the order above) and ``timestamp``
        (timezone-aware UTC, ISO 8601).
    """
    from datetime import datetime, timezone

    inputs = _gather_inputs()
    settings = inputs["settings"]

    # ChromaDB is in-process and synchronous (just reads flags).
    chroma = chromadb_health(rag_manager, memory_vector)

    plan = (
        ("searxng", searxng_health, (settings,)),
        ("ntfy", ntfy_health, (inputs["integrations"], settings)),
        ("email", email_health, (inputs["accounts"],)),
        ("providers", providers_health, (inputs["endpoints"],)),
    )
    pending = [_run_subsystem(label, probe, *probe_args)
               for label, probe, probe_args in plan]
    try:
        results = await asyncio.wait_for(asyncio.gather(*pending),
                                         timeout=_AGGREGATE_DEADLINE)
    except asyncio.TimeoutError:
        # Hard backstop — should not normally fire given per-subsystem deadlines.
        results = [_svc(label, DOWN, _detail_for("timeout"), error="timeout")
                   for label, _probe, _probe_args in plan]

    services = [chroma, *results]
    return {
        "overall": _rollup(services),
        "services": services,
        # Timezone-aware UTC (…+00:00). Avoids the deprecated naive
        # datetime.utcnow() flagged in review (overlaps with #1116).
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
+23 -11
View File
@@ -214,6 +214,24 @@ def _search_like(
return _rows_to_results(db, shaped, query, context_messages)
def _fetch_messages_by_id(db, message_ids):
"""Fetch (message, session_name) for many message ids in a single query.
The FTS search returns a list of hit ids; fetching each row on its own was an
N+1 query (one SELECT per hit). Batch them with one IN(...) query and return
a lookup so the caller can reassemble results in hit (relevance) order.
"""
if not message_ids:
return {}
rows = (
db.query(DBChatMessage, DBSession.name)
.join(DBSession, DBChatMessage.session_id == DBSession.id)
.filter(DBChatMessage.id.in_(message_ids))
.all()
)
return {msg.id: (msg, session_name) for msg, session_name in rows}
def _search_fts(
db,
query: str,
@@ -267,19 +285,13 @@ def _search_fts(
if not hits:
return None
by_id = _fetch_messages_by_id(db, [hit[0] for hit in hits])
rows = []
for hit in hits:
message_id = hit[0]
snippet = hit[1] or ""
row = (
db.query(DBChatMessage, DBSession.name)
.join(DBSession, DBChatMessage.session_id == DBSession.id)
.filter(DBChatMessage.id == message_id)
.first()
)
if row:
msg, session_name = row
rows.append((msg, session_name, snippet))
found = by_id.get(hit[0])
if found:
msg, session_name = found
rows.append((msg, session_name, hit[1] or ""))
return _rows_to_results(db, rows, query, context_messages)
+19 -9
View File
@@ -109,14 +109,22 @@ DEFAULT_SETTINGS = {
"research_run_timeout_seconds": 1800,
"agent_max_tool_calls": 0,
"agent_max_rounds": 20, # per-message agent step cap (clamped 1..200)
# Soft input-token budget for the agent loop. The DEFAULT value (6000) is the
# "auto" sentinel: it means "scale the budget to the model's context window"
# (#1230) — so long-context models aren't capped at 6000. Set ANY OTHER value
# to enforce an explicit cap (clamped to the window only — hard_max does not
# apply to explicit budgets, #1230); set 0 to disable soft-trimming. The
# default is treated as auto because the settings-save path materializes
# defaults, so a persisted 6000 can't be told apart from a deliberate 6000 —
# to pin a budget near the default, use a nearby value (e.g. 5999).
"agent_input_token_budget": 6000,
# Ceiling on the *auto-derived* input budget that #1230 introduced. Has
# no effect when `agent_input_token_budget` is explicitly set (the user's
# value is honoured regardless). Default matches
# `src.context_budget.DEFAULT_HARD_MAX`; lower this for cost-paranoid
# setups, raise it on premium APIs with very large windows that you
# Ceiling on the *auto-derived* input budget; a configurable setting since #1273
# (the merged #1230 left it a module constant). No effect on an explicit budget
# — a deliberate value is honoured (#1230). Default matches
# `src.context_budget.DEFAULT_HARD_MAX`; lower this for
# cost-paranoid setups, raise it on premium APIs with very large windows you
# want to actually use (e.g. 900_000 to fill a 1M-context model). See
# `compute_input_token_budget` in src/context_budget.py.
# `compute_input_token_budget`.
"agent_input_token_hard_max": 200_000,
"agent_stream_timeout_seconds": 300,
# Extra directory roots that read_file / write_file may access, in
@@ -232,8 +240,10 @@ def is_setting_overridden(key: str) -> bool:
``load_settings`` merges DEFAULT_SETTINGS with the saved file, so a value
equal to its default is indistinguishable from "never set" via get_setting.
Callers that need to treat an explicit user choice differently from the
default (e.g. adaptive budgets) use this to read the raw saved file.
Callers that must distinguish an explicit user choice from a default read
the raw saved file via this. (Note: a materialized default is also "present",
so value-sensitive callers should compare against the default see
``context_budget.budget_is_explicit``.)
"""
try:
with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
@@ -292,7 +302,7 @@ def load_features() -> dict:
if not isinstance(saved, dict):
raise ValueError("features must be an object")
merged = {**DEFAULT_FEATURES, **saved}
except (FileNotFoundError, json.JSONDecodeError, ValueError):
except (FileNotFoundError, PermissionError, json.JSONDecodeError, ValueError):
merged = dict(DEFAULT_FEATURES)
_features_cache = (now, merged)
return merged
+11 -1
View File
@@ -12,6 +12,8 @@ tunnel / reverse proxy. Scrubbing is deep (recurses nested dicts/lists) and keye
on secret-shaped names.
"""
import re
_SECRET_KEY_PATTERNS = (
"_api_key", "_apikey", "_password", "_passwd", "_pass", "_pwd",
"_secret", "_client_secret", "_token", "_access_token", "_refresh_token",
@@ -26,8 +28,16 @@ _SENSITIVE_KEY_EXACT = (
)
def _canonical_key_name(name: str) -> str:
"""Normalize common JS-style key names so secret matching is style-agnostic."""
n = (name or "").replace("-", "_")
n = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", n)
n = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", n)
return n.lower()
def is_secret_key(name: str) -> bool:
n = (name or "").lower()
n = _canonical_key_name(name)
if n in _SECRET_KEY_ALLOW:
return False
if n in _SENSITIVE_KEY_EXACT:
+52 -29
View File
@@ -1324,7 +1324,10 @@ class TaskScheduler:
db.commit()
if self._session_manager:
try:
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
self._session_manager.ensure_task_session(
session_id, f"[Task] {task.name}", endpoint_url, model,
owner=task.owner, task=task
)
except Exception:
pass
@@ -1430,6 +1433,7 @@ class TaskScheduler:
task's visible output target.
"""
from core.database import Session as DbSession, ChatMessage, CrewMember
from core.models import ChatMessage as MemChatMessage
output = task.output_target or "session"
if (
@@ -1486,7 +1490,10 @@ class TaskScheduler:
db.commit()
if self._session_manager:
try:
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
self._session_manager.ensure_task_session(
session_id, f"[Task] {task.name}", endpoint_url, model_name,
owner=task.owner, task=task
)
except Exception:
pass
@@ -1495,36 +1502,50 @@ class TaskScheduler:
meta["model"] = model_name
if crew and crew.is_default_assistant:
meta.update({"source": "cron", "task_id": task.id, "task_name": task.name})
msg_meta = json.dumps(meta)
user_content = task.prompt or f"[Task] {task.name}"
user_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="user",
content=user_content,
timestamp=_utcnow(),
meta_data=msg_meta,
)
assistant_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="assistant",
content=result or "",
timestamp=_utcnow(),
meta_data=msg_meta,
)
db.add(user_msg)
db.add(assistant_msg)
db.commit()
if self._session_manager:
# Use SessionManager for persistence so in-memory cache stays in sync
if self._session_manager and session_id:
try:
from core.models import ChatMessage as MemMsg
sess_obj = self._session_manager.get_session(session_id)
sess_obj.history.append(MemMsg(role="user", content=user_msg.content, metadata=meta))
sess_obj.history.append(MemMsg(role="assistant", content=assistant_msg.content, metadata=meta))
self._session_manager.add_message(
session_id,
MemChatMessage(
"user",
task.prompt or f"[Task] {task.name}",
metadata=dict(meta),
),
)
self._session_manager.add_message(
session_id,
MemChatMessage(
"assistant",
result or "",
metadata=dict(meta),
),
)
except Exception:
pass
logger.exception("Failed to deliver task %s through SessionManager", task.id)
else:
# Fallback: raw DB write (no session manager available)
msg_meta = json.dumps(meta)
user_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="user",
content=task.prompt or f"[Task] {task.name}",
timestamp=_utcnow(),
meta_data=msg_meta,
)
assistant_msg = ChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="assistant",
content=result or "",
timestamp=_utcnow(),
meta_data=msg_meta,
)
db.add(user_msg)
db.add(assistant_msg)
db.commit()
@staticmethod
def _is_email_output_target(output: str) -> bool:
@@ -1641,6 +1662,8 @@ class TaskScheduler:
data = json.loads(event_str[6:])
# Capture text from all event types, not just delta
if "delta" in data:
if data.get("thinking"):
continue
full_text += data["delta"]
elif data.get("type") == "tool_output":
# Tool results — capture summary so we have SOMETHING even
+3 -1
View File
@@ -42,7 +42,7 @@ _SOTA_HOSTS = frozenset({
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"generativelanguage.googleapis.com", "api.groq.com",
"openrouter.ai", "ollama.com", "api.venice.ai",
"openrouter.ai", "ollama.com", "api.venice.ai", "api.kimi.com",
})
@@ -594,6 +594,8 @@ async def run_teacher_inline(
"exit_code": payload.get("exit_code"),
})
if "delta" in payload and isinstance(payload["delta"], str):
if payload.get("thinking"):
continue
captured_text_parts.append(payload["delta"])
yield 'data: ' + json.dumps(payload) + '\n\n'
continue
+148 -746
View File
File diff suppressed because it is too large Load Diff
+111 -677
View File
@@ -18,6 +18,40 @@ from core.constants import internal_api_base
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Active email state
# ---------------------------------------------------------------------------
# When the user has an email reader window open, the frontend tells the
# backend about it on each chat submit. Email tools can resolve "this email"
# without guessing a UID. Cleared between requests by chat_routes.
_active_email_ref: Optional[Dict[str, str]] = None
def set_active_email(uid: Optional[str], folder: Optional[str] = None, account: Optional[str] = None,
                     subject: Optional[str] = None, sender: Optional[str] = None) -> None:
    """Record which email is currently open in the UI.

    A falsy ``uid`` clears the stored reference. Otherwise every field is
    stored as a string: ``folder`` falls back to ``"INBOX"`` and the other
    optional fields to ``""``. ``sender`` is stored under the ``"from"`` key.
    """
    global _active_email_ref
    if uid:
        _active_email_ref = {
            "uid": str(uid),
            "folder": str(folder) if folder else "INBOX",
            "account": str(account) if account else "",
            "subject": str(subject) if subject else "",
            "from": str(sender) if sender else "",
        }
    else:
        _active_email_ref = None
def get_active_email() -> Optional[Dict[str, str]]:
    """Return the stored open-email reference, or None when nothing is stored.

    The module-level dict itself is returned, not a copy.
    """
    return _active_email_ref
def clear_active_email() -> None:
    """Drop the stored open-email reference by resetting it to None."""
    global _active_email_ref
    _active_email_ref = None
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
@@ -54,517 +88,6 @@ def _parse_tool_args(content):
args = args["body"]
return args
# ---------------------------------------------------------------------------
# Active document state
# ---------------------------------------------------------------------------
_active_document_id: Optional[str] = None
_active_model: Optional[str] = None
# When the user has an email reader window open, the frontend tells the
# backend about it on each chat submit. We stash it here so email tools
# (reply_to_email, read_email, mark_email) can resolve "this email" / "the
# open one" without the agent guessing a UID. Cleared between requests by
# chat_routes after the agent loop returns.
_active_email_ref: Optional[Dict[str, str]] = None
def set_active_email(uid: Optional[str], folder: Optional[str] = None, account: Optional[str] = None,
                     subject: Optional[str] = None, sender: Optional[str] = None) -> None:
    """Record which email is currently open in the UI.

    A falsy ``uid`` clears the stored reference. Otherwise every field is
    stored as a string: ``folder`` falls back to ``"INBOX"`` and the other
    optional fields to ``""``. ``sender`` is stored under the ``"from"`` key.
    """
    global _active_email_ref
    if uid:
        _active_email_ref = {
            "uid": str(uid),
            "folder": str(folder) if folder else "INBOX",
            "account": str(account) if account else "",
            "subject": str(subject) if subject else "",
            "from": str(sender) if sender else "",
        }
    else:
        _active_email_ref = None
def get_active_email() -> Optional[Dict[str, str]]:
    """Return the stored open-email reference, or None when nothing is stored.

    The module-level dict itself is returned, not a copy.
    """
    return _active_email_ref
def clear_active_email() -> None:
    """Drop the stored open-email reference by resetting it to None."""
    global _active_email_ref
    _active_email_ref = None
def set_active_document(doc_id: Optional[str]) -> None:
    """Set the active document ID for document tool execution.

    The value is kept in a module-level variable; passing None clears it.
    """
    global _active_document_id
    _active_document_id = doc_id
def set_active_model(model: Optional[str]) -> None:
    """Set the current model name for version summaries.

    The value is kept in a module-level variable; the document tools fall
    back to the label 'AI' in their summaries when it is unset.
    """
    global _active_model
    _active_model = model
def get_active_document() -> Optional[str]:
    """Return the current active document ID, or None when none is set."""
    return _active_document_id
def clear_active_document(doc_id: Optional[str] = None) -> bool:
    """Reset the in-memory active-document pointer.

    Without ``doc_id`` the pointer is always cleared. With ``doc_id`` it is
    cleared only if it currently points at that document, so a different
    active document is left alone. Returns True when the pointer was cleared.

    Called when a document is detached from its session or deleted (its tab is
    closed): without this, the stale pointer makes the last-resort doc-injection
    path re-surface a closed document in a later, unrelated chat — even one whose
    session no longer matches, because an unlinked doc has session_id NULL (#1160).
    """
    global _active_document_id
    # Guard clause: a specific id was requested and it is not the active one.
    if doc_id is not None and _active_document_id != doc_id:
        return False
    _active_document_id = None
    return True
def _owned_document_query(query, Document, owner: Optional[str]):
if owner is None:
# A bare Python `False` is not a valid SQL expression — SQLAlchemy 1.4
# deprecates it and 2.0 raises ArgumentError. Use the SQL `false()`
# literal to return zero rows for an unscoped (owner-less) query.
from sqlalchemy import false
return query.filter(false())
return query.filter(Document.owner == owner)
def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False):
    """Fetch the document with ``doc_id`` if it belongs to ``owner``.

    ``active_only`` additionally requires ``Document.is_active``. Returns the
    first matching row, or whatever ``first()`` yields when nothing matches.
    """
    lookup = db.query(Document).filter(Document.id == doc_id)
    if active_only:
        # `== True` is deliberate: it builds a SQL expression, not a Python bool.
        lookup = lookup.filter(Document.is_active == True)
    return _owned_document_query(lookup, Document, owner).first()
def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False):
    """Return ``owner``'s document with the newest ``updated_at``.

    ``active_only`` additionally requires ``Document.is_active``. Returns
    whatever ``first()`` yields when the owner has no matching document.
    """
    candidates = db.query(Document)
    if active_only:
        # `== True` is deliberate: it builds a SQL expression, not a Python bool.
        candidates = candidates.filter(Document.is_active == True)
    scoped = _owned_document_query(candidates, Document, owner)
    return scoped.order_by(Document.updated_at.desc()).first()
# ---------------------------------------------------------------------------
# Document tools — create/update/edit/suggest living documents
# ---------------------------------------------------------------------------
def _sniff_doc_language(text: str) -> str:
    """Best-effort detect a document's language from its content when the model
    didn't specify one. Defaults to 'markdown' (prose). Recognizes the common
    markup/code types the editor supports so e.g. an SVG isn't saved as markdown.

    Checks run in priority order and the first hit wins: email shape, markup
    (looking only at the first 600 characters), JSON, shebang, then
    line-anchored code signatures for python / javascript / sql / css.
    Empty or whitespace-only text yields 'markdown'.
    """
    import json as _json
    import re as _re2

    s = (text or "").strip()
    if not s:
        return "markdown"
    head = s[:600]
    hl = head.lower()
    if _looks_like_email_document(s):
        return "email"
    # Markup (unambiguous). An explicit HTML prologue wins over the `<svg`
    # substring check: otherwise an HTML page that merely embeds an inline
    # <svg> near the top would be classified as a standalone SVG.
    if hl.startswith(("<!doctype html", "<html")):
        return "html"
    if "<svg" in hl:
        return "svg"
    if hl.startswith("<?xml"):
        return "xml"
    if _re2.search(r"<(div|body|head|p|span|table|button|h[1-6]|ul|ol|li|img)\b", hl):
        return "html"
    # JSON: only claimed when the whole text actually parses.
    if s[0] in "{[":
        try:
            _json.loads(s)
            return "json"
        except Exception:
            pass
    # Shebang
    first = s.split("\n", 1)[0].strip().lower()
    if first.startswith("#!"):
        return "python" if "python" in first else "bash"
    # Code by strong leading signals (line-anchored so prose with stray words won't match)
    if _re2.search(r"(?m)^\s*(def \w|class \w|import \w|from \w[\w.]* import )", s):
        return "python"
    if _re2.search(r"(?m)^\s*(function \w|const \w|let \w|export |import .* from )", s):
        return "javascript"
    if _re2.search(r"(?mi)^\s*(select .* from |create table |insert into |update \w)", s):
        return "sql"
    if _re2.search(r"(?m)^[.#]?[\w-]+\s*\{[^{}]*:[^{}]*;", s):
        return "css"
    return "markdown"
def _looks_like_email_document(text: str = "", title: str = "") -> bool:
import re as _re
title_l = (title or "").strip().lower()
if title_l in {"new email", "new mail", "new message"}:
return True
s = (text or "").lstrip()
if "\n---\n" in s and _re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s):
return True
return bool(_re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s))
def _coerce_email_document_content(existing: str, incoming: str) -> str:
    """Force an email document back into the header / ``---`` / body layout.

    Incoming text that already contains the ``---`` separator line is returned
    as-is (stripped). Otherwise the header block is taken from ``existing``
    (or a blank To/Subject header when it has none) and the incoming text is
    used as the body — minus any header-label lines the model dumped on top.
    """
    import re as _re

    new_text = (incoming or "").strip()
    if "\n---\n" in new_text:
        return new_text

    prior = existing or ""
    header = prior.split("\n---\n", 1)[0] if "\n---\n" in prior else "To: \nSubject: "

    body = new_text
    if _looks_like_email_document(new_text):
        header_re = _re.compile(r"^(To|Cc|Bcc|Subject|In-Reply-To|References|X-Source-UID|X-Source-Folder|X-Attachments):", _re.I)
        rows = new_text.splitlines()
        header_rows = [idx for idx, row in enumerate(rows) if header_re.match(row.strip())]
        # Everything after the last header-looking line is the body.
        tail = rows[header_rows[-1] + 1:] if header_rows else rows
        body = "\n".join(tail).strip()
    return header.rstrip() + "\n---\n" + body
async def do_create_document(content_block: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
    """Create a new document. Supports two formats:
    1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
    2) XML-like tags: <title>...</title><language>...</language><content>...</content>
    Some models mix them — strip any XML-style tags and fall back to line parsing.

    Returns a dict with ``action``/``doc_id``/``title``/``language``/
    ``content``/``version`` on success, or ``{"error": ...}`` when there is no
    session, the session does not belong to ``owner``, or the write fails.
    On success the new document becomes the active document.
    """
    import uuid, re as _re
    from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession

    raw = content_block or ""
    # Known languages the editor understands (match the <select> in HTML)
    _KNOWN_LANGS = {
        "python", "javascript", "typescript", "html", "css", "markdown", "json",
        "yaml", "bash", "sql", "rust", "go", "java", "c", "cpp", "xml", "toml",
        "ini", "ruby", "php", "csv", "email", "text", "plain", "svg",
    }
    # Try XML tag extraction first
    title = None
    language = None
    content = None
    mt = _re.search(r"<title>\s*(.*?)\s*</title>", raw, _re.DOTALL | _re.IGNORECASE)
    ml = _re.search(r"<language>\s*(.*?)\s*</language>", raw, _re.DOTALL | _re.IGNORECASE)
    mc = _re.search(r"<content>\s*(.*?)\s*</content>", raw, _re.DOTALL | _re.IGNORECASE)
    if mt or mc:
        title = mt.group(1).strip() if mt else None
        language = ml.group(1).strip().lower() if ml else None
        content = mc.group(1) if mc else None
    # Fall back to line-based parsing. First strip any stray XML-ish tags.
    if title is None or content is None:
        cleaned = _re.sub(r"</?(?:title|language|content)>", "", raw)
        lines = cleaned.strip().split("\n")
        if title is None:
            title = lines[0].strip() if lines else "Untitled"
            lines = lines[1:]
        # Only consume second line as language if it looks like a valid short lang token
        if language is None and lines:
            candidate = lines[0].strip().lower()
            if candidate and len(candidate) < 20 and " " not in candidate and candidate in _KNOWN_LANGS:
                language = candidate
                lines = lines[1:]
        if content is None:
            content = "\n".join(lines)
    # Validate language: must be in known set, else default based on content
    if language and language not in _KNOWN_LANGS:
        language = None
    if not language:
        # No explicit language — sniff it from the content so an SVG / HTML / JSON
        # / code document isn't silently saved as markdown. Prose → markdown.
        language = _sniff_doc_language(content)
    # An email-shaped document (by content or title) is always stored as "email".
    if _looks_like_email_document(content, title):
        language = "email"
    if not title:
        title = "Untitled"
    if not session_id:
        return {"error": "No session context for document creation"}
    db = SessionLocal()
    try:
        doc_id = str(uuid.uuid4())
        ver_id = str(uuid.uuid4())
        # Inherit ownership from the chat session so the doc survives that
        # session later being deleted (session_id → NULL).
        _sess = db.query(DbSession).filter(DbSession.id == session_id).first()
        # A caller-supplied owner must match the session's owner; a missing
        # session is rejected the same way.
        if owner is not None and (not _sess or _sess.owner != owner):
            return {"error": "Cannot create document in another user's session"}
        _owner = _sess.owner if _sess else None
        doc = Document(
            id=doc_id,
            session_id=session_id,
            title=title,
            language=language,
            current_content=content,
            version_count=1,
            is_active=True,
            owner=_owner,
        )
        # Version 1 snapshots the initial content alongside the document row.
        ver = DocumentVersion(
            id=ver_id,
            document_id=doc_id,
            version_number=1,
            content=content,
            summary=f"Created by {_active_model or 'AI'}",
            source="ai",
        )
        db.add(doc)
        db.add(ver)
        db.commit()
        set_active_document(doc_id)
        # Event dispatch is best-effort: a failure is logged at debug level
        # and never fails the create.
        try:
            from src.event_bus import fire_event
            fire_event("document_created", _owner)
        except Exception:
            logger.debug("document_created event dispatch failed", exc_info=True)
        return {
            "action": "create",
            "doc_id": doc_id,
            "title": title,
            "language": language,
            "content": content,
            "version": 1,
        }
    except Exception as e:
        db.rollback()
        return {"error": f"Failed to create document: {e}"}
    finally:
        db.close()
async def do_update_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
    """Replace a document's full text and record the change as a new version.

    The target is ``doc_id``, else the active document; when neither resolves
    to a document owned by ``owner``, the owner's most recently updated
    document is used and becomes the active one. Email documents are kept in
    their header / ``---`` / body layout. Returns the updated document fields,
    or ``{"error": ...}`` when nothing can be updated or the write fails.
    """
    import uuid
    from src.database import SessionLocal, Document, DocumentVersion

    target_id = doc_id or _active_document_id
    db = SessionLocal()
    try:
        doc = _get_owned_document(db, Document, target_id, owner) if target_id else None
        if not doc:
            doc = _most_recent_owned_document(db, Document, owner)
            if not doc:
                return {"error": "No documents exist to update"}
            target_id = doc.id
            set_active_document(target_id)
            logger.info(f"update_document: fell back to most recent doc id={target_id}")

        previous_text = doc.current_content or ""
        if doc.language == "email" or _looks_like_email_document(previous_text, doc.title or ""):
            new_content = _coerce_email_document_content(previous_text, content)
            doc.language = "email"
        else:
            new_content = content.strip()

        next_version = doc.version_count + 1
        db.add(DocumentVersion(
            id=str(uuid.uuid4()),
            document_id=target_id,
            version_number=next_version,
            content=new_content,
            summary=f"Updated by {_active_model or 'AI'}",
            source="ai",
        ))
        doc.current_content = new_content
        doc.version_count = next_version
        db.commit()
        return {
            "action": "update",
            "doc_id": target_id,
            "title": doc.title,
            "language": doc.language,
            "content": new_content,
            "version": next_version,
        }
    except Exception as e:
        db.rollback()
        return {"error": f"Failed to update document: {e}"}
    finally:
        db.close()
def parse_edit_blocks(content: str) -> list:
    """Extract every <<<FIND>>>...<<<REPLACE>>>...<<<END>>> block from ``content``.

    Returns one ``{"find": ..., "replace": ...}`` dict per block, in order of
    appearance; text outside the blocks is ignored.
    """
    block_re = r'<<<FIND>>>\n(.*?)\n<<<REPLACE>>>\n(.*?)\n<<<END>>>'
    return [
        {"find": find_text, "replace": replace_text}
        for find_text, replace_text in re.findall(block_re, content, re.DOTALL)
    ]
async def do_edit_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
    """Apply targeted FIND/REPLACE edits to an existing document.

    ``content`` holds one or more <<<FIND>>>...<<<REPLACE>>>...<<<END>>>
    blocks. Each FIND is replaced once (first occurrence), in order, against
    the progressively edited text. The target is ``doc_id``, else the active
    document, else the owner's most recently updated document.

    Returns the edited document fields plus ``applied``/``skipped`` counts, or
    ``{"error": ...}`` when no blocks parse, no document exists, no FIND
    matched, or the write fails. A new version is recorded only when at least
    one edit applied.
    """
    import uuid
    from src.database import SessionLocal, Document, DocumentVersion

    target_id = doc_id or _active_document_id
    edits = parse_edit_blocks(content)
    if not edits:
        return {"error": "No valid <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks found"}
    db = SessionLocal()
    try:
        doc = None
        if target_id:
            doc = _get_owned_document(db, Document, target_id, owner)
        if not doc:
            # Fallback: most recently updated document. Avoids "no active doc" errors
            # after server restart or when the agent loses track of which doc to edit.
            doc = _most_recent_owned_document(db, Document, owner)
            if doc:
                target_id = doc.id
                set_active_document(target_id)
                logger.info(f"edit_document: fell back to most recent doc id={target_id} title={doc.title!r}")
        if not doc:
            return {"error": "No documents exist to edit"}
        # Treat missing content as empty text: `find in None` would raise
        # TypeError and surface as an opaque "Failed to edit document" error
        # instead of the accurate "no FIND blocks matched" one.
        updated_content = doc.current_content or ""
        applied = 0
        skipped = 0
        for edit in edits:
            _find = edit["find"]
            if _find in updated_content:
                updated_content = updated_content.replace(_find, edit["replace"], 1)
                applied += 1
            else:
                # Defensive: the active-doc context shows a "N\t" line-number
                # gutter for reference. Weaker models sometimes copy that prefix
                # into FIND. If the exact match failed, retry with a leading
                # "<digits><tab>" stripped from each FIND line — but only use it
                # when that stripped form actually matches, so we never corrupt a
                # legitimately tab-prefixed document.
                _stripped = "\n".join(re.sub(r"^\d+\t", "", _l) for _l in _find.split("\n"))
                if _stripped != _find and _stripped in updated_content:
                    updated_content = updated_content.replace(_stripped, edit["replace"], 1)
                    applied += 1
                    logger.info("edit_document: matched after stripping line-number gutter from FIND")
                else:
                    logger.warning(f"edit_document: FIND text not found, skipping: {_find[:80]!r}")
                    skipped += 1
        if applied == 0:
            return {"error": f"No edits applied — none of the FIND blocks matched the document content (skipped {skipped})"}
        new_ver = doc.version_count + 1
        ver = DocumentVersion(
            id=str(uuid.uuid4()),
            document_id=target_id,
            version_number=new_ver,
            content=updated_content,
            summary=f"Edited by {_active_model or 'AI'} ({applied} edit(s))",
            source="ai",
        )
        doc.current_content = updated_content
        doc.version_count = new_ver
        db.add(ver)
        db.commit()
        return {
            "action": "edit",
            "doc_id": target_id,
            "title": doc.title,
            "language": doc.language,
            "content": updated_content,
            "version": new_ver,
            "applied": applied,
            "skipped": skipped,
        }
    except Exception as e:
        db.rollback()
        return {"error": f"Failed to edit document: {e}"}
    finally:
        db.close()
def parse_suggest_blocks(content: str) -> list:
    """Parse <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks.

    Returns one dict per kept suggestion with keys ``id`` (``sugg-1``,
    ``sugg-2``, ... numbered over kept suggestions only), ``find``,
    ``replace`` and ``reason`` (stripped). No-op suggestions are dropped:
    those whose FIND and SUGGEST are equal ignoring surrounding whitespace,
    and those whose reason says nothing should change.
    """
    suggestions = []
    # Phrases meaning "leave it alone". A bare "clear" is deliberately NOT in
    # this list: as a substring it also matched "unclear" and "clearer", which
    # are the most common reasons for a genuine suggestion.
    _skip_phrases = [
        "no change", "already clear", "clear enough", "clear as",
        "fine as", "looks good", "no improvement", "keep as",
    ]
    pattern = r'<<<FIND>>>\n(.*?)\n<<<SUGGEST>>>\n(.*?)\n<<<REASON>>>\n(.*?)\n<<<END>>>'
    for m in re.finditer(pattern, content, re.DOTALL):
        find_text = m.group(1)
        replace_text = m.group(2)
        reason = m.group(3).strip()
        # Skip no-op suggestions where find == replace or reason says no change
        if find_text.strip() == replace_text.strip():
            continue
        reason_l = reason.lower()
        if any(phrase in reason_l for phrase in _skip_phrases):
            continue
        suggestions.append({
            "id": f"sugg-{len(suggestions)+1}",
            "find": find_text,
            "replace": replace_text,
            "reason": reason,
        })
    return suggestions
async def do_suggest_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
    """Create inline suggestions for the active document WITHOUT modifying it.

    ``content`` holds <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>>
    blocks. The target is ``doc_id``, else the active document; unlike the
    edit/update tools there is no most-recent-document fallback. Only
    suggestions whose FIND text occurs in the document are returned.

    Returns ``{"action": "suggest", "doc_id", "suggestions", "count"}``, or
    ``{"error": ...}`` when there is no target, no block parses, the document
    is not found for ``owner``, or nothing matches.
    """
    from src.database import SessionLocal, Document

    target_id = doc_id or _active_document_id
    if not target_id:
        return {"error": "No active document to suggest on"}
    suggestions = parse_suggest_blocks(content)
    if not suggestions:
        return {"error": "No valid <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks found"}
    db = SessionLocal()
    try:
        doc = _get_owned_document(db, Document, target_id, owner)
        if not doc:
            return {"error": f"Document {target_id} not found"}
        # Missing content is treated as empty text; `x in None` would raise a
        # TypeError that nothing here catches.
        doc_text = doc.current_content or ""
        # Validate that FIND text exists in document
        valid = []
        for s in suggestions:
            if s["find"] in doc_text:
                valid.append(s)
            else:
                logger.warning(f"suggest_document: FIND text not found, skipping: {s['find'][:80]!r}")
        if not valid:
            return {"error": "No suggestions matched the document content"}
        return {
            "action": "suggest",
            "doc_id": target_id,
            "suggestions": valid,
            "count": len(valid),
        }
    finally:
        db.close()
# ---------------------------------------------------------------------------
# Search chats
# ---------------------------------------------------------------------------
@@ -1392,147 +915,6 @@ async def do_manage_tokens(content: str, owner: Optional[str] = None) -> Dict:
finally:
db.close()
# ---------------------------------------------------------------------------
# Document management tool (delete, list, organize)
# ---------------------------------------------------------------------------
async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict:
    """Manage documents: list, read/view/open, delete, tidy.

    Output format mirrors `manage_session`: list rows include a
    clickable `[Title](#document-<id>)` anchor + relative timestamps
    so the user can click straight from chat to open the editor.

    ``content`` is the tool-argument payload; ``action`` defaults to "list".
    - list: optional ``search`` (title substring), ``language``, ``limit``
      (default 50); active documents only, most recently updated first.
    - read / view / open / get: needs ``document_id`` (or ``id`` / ``uid``);
      optional ``offset`` and ``limit`` page through the content.
    - delete: soft-deletes (``is_active = False``) the given document, else
      the active one, else the owner's most recently updated active document.
    - tidy: delegates to ``run_document_tidy``.

    Every result carries ``exit_code`` (0 on success, 1 on error).
    """
    from core.database import SessionLocal, Document
    from datetime import datetime, timezone

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}
    action = args.get("action", "list")
    db = SessionLocal()

    def _rel(ts):
        """Render a timestamp as a short relative age ('5m ago', '3d ago', ...)."""
        if not ts:
            return 'never'
        try:
            # Compare like with like: aware timestamps against aware "now",
            # naive ones against naive UTC "now" (utcnow() is deprecated).
            now = datetime.now(timezone.utc)
            if ts.tzinfo is None:
                now = now.replace(tzinfo=None)
            diff = (now - ts).total_seconds()
        except Exception:
            return 'unknown'
        if diff < 60:
            return 'just now'
        if diff < 3600:
            return f'{int(diff / 60)}m ago'
        if diff < 86400:
            return f'{int(diff / 3600)}h ago'
        if diff < 86400 * 7:
            return f'{int(diff / 86400)}d ago'
        return ts.strftime('%Y-%m-%d')

    try:
        if action == "list":
            q = db.query(Document).filter(Document.is_active == True)
            q = _owned_document_query(q, Document, owner)
            if args.get("search"):
                q = q.filter(Document.title.ilike(f"%{args['search']}%"))
            if args.get("language"):
                q = q.filter(Document.language == args["language"])
            docs = q.order_by(Document.updated_at.desc()).limit(args.get("limit", 50)).all()
            if not docs:
                msg = "No documents found" + (f" matching '{args['search']}'" if args.get("search") else "") + "."
                return {"response": msg, "documents": [], "exit_code": 0}
            lines = []
            items = []
            for i, d in enumerate(docs):
                size = len(d.current_content or "")
                lang = d.language or "text"
                ts = getattr(d, 'updated_at', None) or getattr(d, 'created_at', None)
                marker = " ← most recent" if i == 0 else ""
                lines.append(
                    f"- [{d.title}](#document-{d.id}) — {lang}, {size} chars, updated {_rel(ts)}{marker}"
                )
                items.append({"id": d.id, "title": d.title, "language": lang, "size": size})
            header = f"Found {len(docs)} document(s), sorted most-recent first. Click a title to open:"
            return {
                "response": header + "\n" + "\n".join(lines),
                "documents": items,
                "exit_code": 0,
            }
        elif action in ("read", "view", "open", "get"):
            doc_id = args.get("document_id") or args.get("id") or args.get("uid")
            if not doc_id:
                return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
            doc = _get_owned_document(db, Document, doc_id, owner, active_only=True)
            if not doc:
                return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
            body = doc.current_content or ""
            total = len(body)
            # Clamp offset to [0, total] so a far-out offset returns an empty
            # window with a useful "end of document" hint rather than erroring.
            try:
                offset = int(args.get("offset", 0))
            except (TypeError, ValueError):
                offset = 0
            offset = max(0, min(offset, total))
            # Parse the page size as defensively as the offset. A negative
            # limit would slice from the END of the body and return nearly the
            # whole document; zero would never advance next_offset. Both, like
            # an unparseable value, fall back to the default page size.
            try:
                preview_limit = int(args.get("limit", MAX_READ_CHARS))
            except (TypeError, ValueError):
                preview_limit = MAX_READ_CHARS
            if preview_limit <= 0:
                preview_limit = MAX_READ_CHARS
            chunk = body[offset:offset + preview_limit]
            next_offset = offset + len(chunk)
            has_more = next_offset < total
            # Trailing marker — tells the agent (and a curious human) exactly
            # what to pass next to continue paginating.
            if has_more:
                marker = f"\n... ({total - next_offset:,} more chars; pass offset={next_offset} to continue)"
            elif offset > 0:
                marker = f"\n... (end of document, {total:,} chars total)"
            else:
                marker = ""
            preview = chunk + marker
            anchor = f"[{doc.title}](#document-{doc.id})"
            return {
                "response": f"{anchor} — click to open in editor.\n\n```{doc.language or ''}\n{preview}\n```",
                "document": {
                    "id": doc.id,
                    "title": doc.title,
                    "language": doc.language,
                    "size": total,
                    "content": chunk,
                    "offset": offset,
                    "next_offset": next_offset if has_more else None,
                    "truncated": has_more,
                },
                "exit_code": 0,
            }
        elif action == "delete":
            doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
            doc = None
            if doc_id:
                doc = _get_owned_document(db, Document, doc_id, owner)
            if not doc:
                # Fallback: most recently updated doc (likely what the user means)
                doc = _most_recent_owned_document(db, Document, owner, active_only=True)
            if not doc:
                return {"error": "No document to delete", "exit_code": 1}
            title = doc.title
            # Soft delete: the row stays, it is only flagged inactive.
            doc.is_active = False
            db.commit()
            if _active_document_id == doc.id:
                set_active_document(None)
            return {"response": f"Deleted document '{title}'", "exit_code": 0}
        elif action == "tidy":
            from src.document_actions import run_document_tidy
            result = await run_document_tidy(owner or "")
            return {"response": result, "exit_code": 0}
        else:
            return {"error": f"Unknown action: {action}", "exit_code": 1}
    except Exception as e:
        # Keep the traceback in the log; the caller only gets the message.
        logger.exception("manage_documents error: %s", e)
        return {"error": str(e), "exit_code": 1}
    finally:
        db.close()
# ---------------------------------------------------------------------------
# Settings/preferences management tool
# ---------------------------------------------------------------------------
@@ -2097,7 +1479,15 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
"""Handle manage_calendar tool calls: list/create/update/delete calendar events (local SQLite)."""
from datetime import datetime, timedelta
from core.database import SessionLocal, CalendarCal, CalendarEvent, Note
from routes.calendar_routes import _ensure_default_calendar, _parse_dt, _parse_dt_pair, parse_due_for_user, _resolve_base_uid
from routes.calendar_routes import (
_ensure_default_calendar,
_parse_dt,
_parse_dt_pair,
parse_due_for_user,
_resolve_base_uid,
_push_caldav_event_after_commit,
_record_caldav_delete_tombstone,
)
import uuid as _uuid
try:
@@ -2105,6 +1495,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
except ValueError:
return {"error": "Invalid JSON arguments", "exit_code": 1}
# ── Batch normalization ──
# Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
# instead of individual create_event calls. Iterate and create each.
if isinstance(args.get("events"), list) and not args.get("action"):
results = []
for ev in args["events"]:
if not isinstance(ev, dict):
continue
# Normalize start/end from {dateTime: "..."} object to flat string
for field, target in [("start", "dtstart"), ("end", "dtend")]:
val = ev.pop(field, None)
if val and target not in ev:
ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
ev.setdefault("action", "create_event")
r = await do_manage_calendar(json.dumps(ev), owner=owner)
results.append(r)
created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
failed = [r for r in results if r.get("error")]
if not results:
return {"error": "No events to create", "exit_code": 1}
# Surface both successes and failures
parts = []
if created:
summaries = [r.get("response", "") for r in created]
parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
if failed:
first_error = failed[0].get("error", "Unknown error")
parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")
response = "\n\n".join(parts)
# Non-zero exit code for partial or total failure
exit_code = 0 if not failed else 1
return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}
# Normalize action — some models emit hyphens ("list-calendars") instead
# of underscores. Treat them as equivalent so we don't bounce a
# cosmetic typo back to the model and waste a round-trip. Also accept
@@ -2259,6 +1685,9 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
except ValueError as e:
return {"error": f"Invalid date format: {e}", "exit_code": 1}
if end_dt <= start_dt:
end_dt = start_dt + timedelta(days=1)
q = _event_query().filter(
CalendarEvent.dtstart < end_dt,
CalendarEvent.dtend > start_dt,
@@ -2438,6 +1867,7 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
rrule=args.get("rrule", "") or "",
event_type=event_type,
importance=importance,
caldav_sync_pending="create" if cal.source == "caldav" else None,
)
db.add(ev)
reminder_note_id = None
@@ -2452,6 +1882,8 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
dtstart_is_utc and not all_day,
)
db.commit()
if cal.source == "caldav":
await _push_caldav_event_after_commit(owner, uid, "create")
tag_blurb = f" [{event_type}]" if event_type else ""
if minutes_before is None:
reminder_blurb = ""
@@ -2509,7 +1941,12 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
ev.event_type = _tag or None
if args.get("importance") is not None:
ev.importance = args["importance"]
is_caldav = ev.calendar and ev.calendar.source == "caldav"
if is_caldav:
ev.caldav_sync_pending = "update"
db.commit()
if is_caldav:
await _push_caldav_event_after_commit(owner, base_uid, "update")
return {"response": f"Updated event {uid}", "exit_code": 0}
elif action == "delete_event":
@@ -2523,8 +1960,13 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
ev = _event_query().filter(CalendarEvent.uid == base_uid).first()
if not ev:
return {"error": f"Event {uid} not found", "exit_code": 1}
is_caldav = ev.calendar and ev.calendar.source == "caldav" and ev.remote_href
if is_caldav:
_record_caldav_delete_tombstone(db, ev, owner)
db.delete(ev)
db.commit()
if is_caldav:
await _push_caldav_event_after_commit(owner, base_uid, "delete")
return {"response": f"Deleted event {uid}", "exit_code": 0}
else:
@@ -2670,13 +2112,14 @@ async def _cookbook_env_for_host(host: str) -> Dict[str, Any]:
else:
env_prefix = f'eval "$(conda shell.bash hook)" && conda activate {env_path}'
from routes.cookbook_helpers import load_stored_hf_token
return {
"env_prefix": env_prefix,
"env_type": env_kind,
"env_path": env_path,
"gpus": env_root.get("gpus") or "",
"platform": platform,
"hf_token": env_root.get("hfToken") or "",
"hf_token": load_stored_hf_token(),
"ssh_port": ssh_port,
}
@@ -2733,7 +2176,7 @@ async def _ensure_served_endpoint(
try:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{_COOKBOOK_BASE}/api/model-endpoints",
f"{_INTERNAL_BASE}/api/model-endpoints",
data=payload,
headers=_internal_headers(),
)
@@ -4428,24 +3871,16 @@ async def do_manage_contact(content: str, owner: Optional[str] = None) -> Dict:
if action == "add":
email = (args.get("email") or "").strip()
name = (args.get("name") or "").strip() or (email.split("@")[0] if email else "")
address = (args.get("address") or "").strip()
# Need at least one identifying field. Address-only (e.g. a
# business location with no email) is fine as long as there's
# a name.
if not email and not name:
return {"error": "Provide at least name+address or email for add", "exit_code": 1}
# Dedupe by email when one is given.
if email:
existing = await asyncio.to_thread(cc._fetch_contacts)
for c in existing:
if email.lower() in [e.lower() for e in c.get("emails", [])]:
return {"output": f"{email} is already a contact ({c.get('name','')}).", "exit_code": 0}
ok = await asyncio.to_thread(cc._create_contact, name, email, address)
tail = f" <{email}>" if email else ""
if address:
tail += f"{address}"
return {"output": f"{'Added' if ok else 'Failed to add'} {name}{tail}.", "exit_code": 0 if ok else 1}
if not email:
return {"error": "email is required for add", "exit_code": 1}
name = (args.get("name") or "").strip() or email.split("@")[0]
# Dedupe by email (same as the /add route).
existing = await asyncio.to_thread(cc._fetch_contacts)
for c in existing:
if email.lower() in [e.lower() for e in c.get("emails", [])]:
return {"output": f"{email} is already a contact ({c.get('name','')}).", "exit_code": 0}
ok = await asyncio.to_thread(cc._create_contact, name, email)
return {"output": f"{'Added' if ok else 'Failed to add'} {name} <{email}>.", "exit_code": 0 if ok else 1}
if action in ("update", "edit"):
uid = (args.get("uid") or "").strip()
@@ -4457,12 +3892,11 @@ async def do_manage_contact(content: str, owner: Optional[str] = None) -> Dict:
emails = [args["email"]]
emails = [e.strip() for e in (emails or []) if e and e.strip()]
phones = [p.strip() for p in (args.get("phones") or []) if p and p.strip()]
address = (args.get("address") or "").strip()
if not name and not emails and not address:
return {"error": "Provide a name, emails, or address to update", "exit_code": 1}
if not name and not emails:
return {"error": "Provide a name or emails to update", "exit_code": 1}
if not name and emails:
name = emails[0].split("@")[0]
ok = await asyncio.to_thread(cc._update_contact, uid, name, emails, phones, address)
ok = await asyncio.to_thread(cc._update_contact, uid, name, emails, phones)
return {"output": "Contact updated." if ok else "Update failed.", "exit_code": 0 if ok else 1}
if action == "delete":
+7 -2
View File
@@ -67,14 +67,15 @@ COLLECTION_NAME = "odysseus_tool_index"
# Each tool gets a searchable description that helps retrieval.
# These are richer than the system prompt one-liners — they're for embedding.
BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
"bash": "Run shell commands on the server. Install packages, check files, git operations, system info, and process management. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
"bash": "Run shell commands on the server. Install packages, git operations, builds, system info, process management. Prefer a dedicated tool whenever one fits the job (file read/write/edit, search, listing); use bash only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Prefer a dedicated tool for reading, writing, or searching files; use python only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
"web_search": "Quick single web lookup for a fact, current event, latest/current information, or doc mid-task. Use this instead of bash/curl/python/requests for web searches. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
"get_workspace": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without giving a path, instead of asking them.",
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
@@ -395,6 +396,10 @@ class ToolIndex:
"delegate to", "have model"}):
{"chat_with_model", "ask_teacher", "list_models"},
# Deep research intent (incl. common typo "reserach")
frozenset({"web search", "search the web", "search online", "look up",
"google", "latest", "current", "news", "weather",
"forecast", "stock price", "price of"}):
{"web_search", "web_fetch"},
frozenset({"research", "reserach", "reasearch", "look into", "investigate",
"deep dive", "deep research", "find out about", "study up on",
"report on", "do research", "look up everything"}):
+86
View File
@@ -188,6 +188,12 @@ _MISFENCED_WEB_TOOL_NAMES = {
"fetch_url": "web_fetch",
}
_RAW_WEB_JSON_TOOL_RE = re.compile(
r"\b(?:web_search|websearch|google_search|google_search_retrieval|google_search_grounding)\b",
re.IGNORECASE,
)
_RAW_WEB_JSON_ALLOWED_KEYS = {"query", "queries", "time_filter", "freshness", "max_pages"}
# ---------------------------------------------------------------------------
# Parsing functions
@@ -279,6 +285,73 @@ def _parse_misfenced_web_lookup(content: str) -> Optional[ToolBlock]:
return None
return ToolBlock("web_fetch", url)
def _coerce_raw_web_query(value) -> Optional[str]:
if isinstance(value, str) and value.strip():
return value.strip()
if isinstance(value, list):
for item in value:
if isinstance(item, str) and item.strip():
return item.strip()
return None
def _raw_web_json_to_tool_block(payload) -> Optional[ToolBlock]:
    """Convert a decoded JSON object into a ``web_search`` ToolBlock.

    Returns ``None`` unless ``payload`` is a dict whose keys are all in
    ``_RAW_WEB_JSON_ALLOWED_KEYS`` and which carries a usable ``query``
    (or, failing that, ``queries``) value.

    The block content is the bare query string when no optional argument
    survives validation; otherwise it is a JSON object holding the query
    plus the validated ``time_filter`` / ``freshness`` / ``max_pages``.
    """
    if not isinstance(payload, dict):
        return None
    # Any unexpected key means this object is not shaped like web_search args.
    if set(payload) - _RAW_WEB_JSON_ALLOWED_KEYS:
        return None

    query = _coerce_raw_web_query(payload.get("query"))
    if not query:
        query = _coerce_raw_web_query(payload.get("queries"))
    if not query:
        return None

    content = {"query": query}
    for key in ("time_filter", "freshness"):
        value = payload.get(key)
        if not isinstance(value, str):
            continue
        normalized = value.strip().lower()
        if normalized in ("day", "week", "month", "year"):
            content[key] = normalized

    max_pages = payload.get("max_pages")
    # bool is a subclass of int: without the explicit exclusion a JSON
    # `true` would pass the range check and be forwarded as a page count.
    if (
        isinstance(max_pages, int)
        and not isinstance(max_pages, bool)
        and 1 <= max_pages <= 10
    ):
        content["max_pages"] = max_pages

    if len(content) == 1:
        return ToolBlock("web_search", query)
    return ToolBlock("web_search", json.dumps(content))
def _parse_raw_web_json_lookup(text: str) -> Optional[tuple[ToolBlock, tuple[int, int]]]:
    """Recover local text-model web_search calls emitted as prose + bare JSON.

    Some non-native tool models leak the intended call as:

        Need to do web_search for ...
        {"query": "...", "time_filter": "week"}

    Keep this narrower than fenced/tool markup: it only runs when a known web
    tool name appears shortly before a JSON object shaped like web_search args.

    Returns ``(block, (start, end))`` where ``start``/``end`` are the offsets
    of the matched JSON object within ``text``, or ``None`` when nothing
    matches or ``text`` is not a string.
    """
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    for mention in _RAW_WEB_JSON_TOOL_RE.finditer(text):
        # Only braces opening within 1200 chars after the tool mention count.
        window_start = mention.end()
        window_end = min(len(text), window_start + 1200)
        start = text.find("{", window_start, window_end)
        while start != -1:
            try:
                # Decode in place at `start`; `end` is an absolute offset.
                # This avoids copying the tail of `text` per candidate brace.
                parsed, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                block = _raw_web_json_to_tool_block(parsed)
                if block:
                    return block, (start, end)
            start = text.find("{", start + 1, window_end)
    return None
def _parse_tool_call_block(raw: str) -> Optional[ToolBlock]:
"""Parse a [TOOL_CALL] block into a ToolBlock.
@@ -436,6 +509,8 @@ def parse_tool_blocks(text: str, skip_fenced: bool = False) -> List[ToolBlock]:
3. XML-style <tool_call>/<invoke> blocks
4. <tool_code> blocks (MiniMax-M2.5 style)
5. DeepSeek DSML markup (normalized to <invoke> first)
6. Non-native local model fallback: prose mentioning web_search followed by
bare JSON args, e.g. {"query":"...", "time_filter":"week"}
`skip_fenced`: when True, Pattern 1 (fenced ```bash/```python/```json code
blocks) is not matched at all. Native function-calling models (GPT/Claude/
@@ -509,6 +584,12 @@ def parse_tool_blocks(text: str, skip_fenced: bool = False) -> List[ToolBlock]:
if block:
blocks.append(block)
# Pattern 6: local text-model web_search call leaked as prose + bare JSON.
if not blocks and not skip_fenced:
raw_web_json = _parse_raw_web_json_lookup(text)
if raw_web_json:
blocks.append(raw_web_json[0])
return blocks
@@ -532,6 +613,11 @@ def strip_tool_blocks(text: str, skip_fenced: bool = False) -> str:
cleaned = _TOOL_CALL_RE.sub('', cleaned)
cleaned = _XML_TOOL_CALL_RE.sub('', cleaned)
cleaned = _TOOL_CODE_RE.sub('', cleaned)
if not skip_fenced:
raw_web_json = _parse_raw_web_json_lookup(cleaned)
if raw_web_json:
_, (start, end) = raw_web_json
cleaned = cleaned[:start] + cleaned[end:]
# Strip bare <invoke> blocks not wrapped in <tool_call>
cleaned = re.sub(r'<invoke\s+name=["\'].*?</invoke>', '', cleaned, flags=re.DOTALL | re.IGNORECASE)
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
+12 -2
View File
@@ -25,7 +25,7 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "bash",
"description": "Run a shell command (full access)",
"description": "Run a shell command (full access). Prefer a dedicated tool whenever one fits the job (reading, writing, editing, searching, or listing files); use bash only for what no dedicated tool covers (installs, git, builds, running programs, system info). Do NOT create or edit files via bash redirects/heredocs/sed -- use the dedicated file tools.",
"parameters": {
"type": "object",
"properties": {
@@ -39,7 +39,7 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "python",
"description": "Execute Python code to compute a result or test something",
"description": "Execute Python code to compute a result or test something. Prefer a dedicated tool whenever one fits the job (reading, writing, or searching files); use python only for computation, data processing, or scripting no dedicated tool covers.",
"parameters": {
"type": "object",
"properties": {
@@ -141,6 +141,14 @@ FUNCTION_TOOL_SCHEMAS = [
}
}
},
{
"type": "function",
"function": {
"name": "get_workspace",
"description": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without a path, instead of asking them. Takes no arguments.",
"parameters": {"type": "object", "properties": {}, "required": []}
}
},
{
"type": "function",
"function": {
@@ -1247,6 +1255,8 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
content = args.get("path", "")
elif tool_type in ("grep", "glob", "ls"):
content = json.dumps(args) if args else "{}"
elif tool_type == "get_workspace":
content = ""
elif tool_type == "write_file":
content = args.get("path", "") + "\n" + args.get("content", "")
elif tool_type == "edit_file":
+20 -2
View File
@@ -20,6 +20,7 @@ NON_ADMIN_BLOCKED_TOOLS = {
"grep",
"glob",
"ls",
"get_workspace",
"search_chats",
"manage_memory",
"manage_skills",
@@ -66,6 +67,7 @@ PLAN_MODE_READONLY_TOOLS = {
"grep",
"glob",
"ls",
"get_workspace",
"web_search",
"web_fetch",
"search_chats",
@@ -162,13 +164,29 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool:
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
"""Return True for admins, or when auth is not configured yet."""
"""Return True for admins, or in intentional single-user mode.
Single-user mode means the operator explicitly disabled auth
(``AUTH_ENABLED=false``) the local/self-host default where the owner has
full access to their own box.
The pre-setup window (auth ENABLED but no admin created yet) is treated as
NON-admin: returning True there would hand server-execution tools
(``bash``/``python``) to any caller before setup completes. The auth
middleware already 401s ``/api/`` requests pre-setup, so this is
defense-in-depth for callers that bypass it (e.g. trusted loopback).
"""
try:
from src.auth_helpers import _auth_disabled
if _auth_disabled():
return True
from core.auth import AuthManager
auth = AuthManager()
if not auth.is_configured:
return True
return False
return bool(owner and auth.is_admin(owner))
except Exception as exc:
logger.warning("Unable to evaluate owner admin status: %s", exc)
+80
View File
@@ -352,6 +352,86 @@ class UploadHandler:
return dict(info)
return None
def _renamed_upload_index_key(self, key: str, info: Dict[str, Any], old_owner: str, new_owner: str) -> str:
"""Return the storage key to use after renaming an owned upload row."""
if isinstance(key, str) and ":" in key:
owner_part, rest = key.split(":", 1)
if owner_part.strip().lower() == old_owner:
return f"{new_owner}:{rest}"
file_hash = info.get("hash")
if file_hash:
return f"{new_owner}:{file_hash}"
return key
def _unique_upload_index_key(self, base_key: str, used_keys: set, reserved_keys: set, info: Dict[str, Any]) -> str:
"""Choose a deterministic collision key without overwriting an existing row."""
if base_key not in used_keys and base_key not in reserved_keys:
return base_key
upload_id = str(info.get("id") or "renamed").strip() or "renamed"
candidate = f"{base_key}:{upload_id}"
if candidate not in used_keys and candidate not in reserved_keys:
return candidate
index = 2
while True:
candidate = f"{base_key}:{upload_id}:{index}"
if candidate not in used_keys and candidate not in reserved_keys:
return candidate
index += 1
def rename_owner(self, old_owner: str, new_owner: str) -> int:
    """Move upload metadata from ``old_owner`` to ``new_owner``.

    Each row owned by ``old_owner`` (compared case-insensitively on its
    ``owner`` field) gets a copy with the new owner and a new index key.
    Key collisions are resolved without overwriting any other row. The
    index file is rewritten only when at least one row changed.

    Returns:
        The number of rows renamed. This is 0 when either name is blank,
        when both names are equal ignoring case, or when the index is empty.
    """
    source_owner = str(old_owner or "").strip().lower()
    new_owner = str(new_owner or "").strip()
    if not source_owner or not new_owner or source_owner == new_owner.lower():
        return 0

    index_path = os.path.join(self.upload_dir, "uploads.json")
    with self._index_lock:
        index = self._load_upload_index()
        if not index:
            return 0

        all_keys = set(index.keys())
        rebuilt = {}
        moved = 0
        for key, row in index.items():
            owned = (
                isinstance(row, dict)
                and str(row.get("owner", "")).strip().lower() == source_owner
            )
            if not owned:
                # Rows of other owners are carried over unchanged.
                rebuilt[key] = row
                continue

            row_copy = dict(row)
            row_copy["owner"] = new_owner
            wanted_key = self._renamed_upload_index_key(key, row_copy, source_owner, new_owner)
            # Avoid keys already emitted and keys of every other original row.
            final_key = self._unique_upload_index_key(
                wanted_key,
                set(rebuilt.keys()),
                all_keys - {key},
                row_copy,
            )
            if final_key != wanted_key:
                logger.warning(
                    "Upload owner rename key collision for %s -> %s at %s; preserving row as %s",
                    source_owner,
                    new_owner,
                    wanted_key,
                    final_key,
                )
            rebuilt[final_key] = row_copy
            moved += 1

        if moved:
            self._atomic_write_json(index_path, rebuilt)
        return moved
def _find_upload_path(self, upload_id: str) -> Optional[str]:
"""Find an upload file by ID while staying inside upload_dir."""
if not self.validate_upload_id(upload_id):
+24 -1
View File
@@ -9,7 +9,7 @@ from __future__ import annotations
import re
from contextvars import ContextVar
from datetime import datetime, timedelta, timezone
from typing import Optional
from typing import Dict, Optional
_USER_TZ_OFFSET_MIN: ContextVar[Optional[int]] = ContextVar("user_tz_offset_min", default=None)
@@ -136,3 +136,26 @@ def current_datetime_prompt(now_utc: Optional[datetime] = None) -> str:
"When scheduling a task with manage_tasks, scheduled_time is in UTC: "
"convert the user's stated local time using the UTC offset above.\n\n"
)
def current_datetime_context_message(now_utc: Optional[datetime] = None) -> Dict[str, str]:
    """Wrap the current date/time prompt in a standalone chat message.

    The message deliberately uses the ``user`` role rather than ``system``.
    Its text changes every turn because it embeds the clock time. Local
    OpenAI-compatible backends (llama.cpp / LM Studio) key their KV-cache
    prefix off the system message byte-for-byte, so putting the timestamp
    there would invalidate the cached prefix on every request (issue #2927).
    Sending it as a separate message near the end of the array, right before
    the latest user turn, keeps the system prompt byte-identical across
    turns while still giving the model fresh date/time grounding.

    Args:
        now_utc: Passed through unchanged to ``current_datetime_prompt``.

    Returns:
        A dict with ``role`` set to ``"user"`` and ``content`` holding a
        bracketed context label followed by the date/time prompt text.
    """
    label = (
        "[Context — current date/time, refreshed each turn; not part of "
        "your instructions]\n"
    )
    return {"role": "user", "content": label + current_datetime_prompt(now_utc)}
+15 -3
View File
@@ -202,6 +202,18 @@ class WebhookManager:
self._client = httpx.AsyncClient(timeout=10, follow_redirects=False)
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._api_key_manager = api_key_manager
# Strong references to in-flight fire-and-forget tasks. asyncio only
# keeps weak references to tasks, so without this the GC can collect a
# delivery task mid-flight and the webhook is silently never sent.
self._bg_tasks: set = set()
def _spawn_tracked(self, coro):
"""Schedule a background task and hold a strong reference until it
finishes, so it can't be garbage-collected before delivery completes."""
task = asyncio.ensure_future(coro)
self._bg_tasks.add(task)
task.add_done_callback(self._bg_tasks.discard)
return task
def set_loop(self, loop: asyncio.AbstractEventLoop):
self._loop = loop
@@ -223,8 +235,8 @@ class WebhookManager:
if event not in ALLOWED_EVENTS:
return
try:
loop = asyncio.get_running_loop()
loop.create_task(self.fire(event, payload))
asyncio.get_running_loop()
self._spawn_tracked(self.fire(event, payload))
except RuntimeError:
# Called from a sync thread (e.g. sync FastAPI route in threadpool)
if self._loop and self._loop.is_running():
@@ -243,7 +255,7 @@ class WebhookManager:
for wh in matching:
decrypted_secret = self._decrypt_secret(wh.secret)
asyncio.create_task(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
self._spawn_tracked(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
async def deliver_test(self, webhook_id: str, url: str, encrypted_secret: Optional[str]):
"""Public method for the test-webhook route."""
+18 -273
View File
@@ -1,278 +1,23 @@
"""
YouTube handling — transcript extraction, comment fetching (yt-dlp),
and context formatting for LLM injection. Used by chat_handler.py.
"""Compatibility wrapper for the canonical services.youtube.youtube_handler module.
Odysseus historically carried two independent copies of the YouTube handler —
one here under ``src`` and one under ``services.youtube``. They drifted: the
comment-fetch timeout fix landed only in the ``src`` copy, while ``app.py``
calls ``services.youtube.init_youtube()`` at startup. Because the chat flow
imported ``extract_transcript_async`` from ``src.youtube_handler`` (a different
module object), the ``YOUTUBE_AVAILABLE`` / ``YouTubeTranscriptApi`` globals set
by ``init_youtube`` never reached it and transcript extraction always reported
"YouTube transcript API not available".
Keep the old ``src.youtube_handler`` import path working, but make it resolve to
the single source of truth so module state and behavior can't diverge again.
"""
import asyncio
import json
import logging
import shutil
import importlib
import sys
import urllib.parse
from pathlib import Path
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
# Import the canonical module directly (services.youtube.youtube_handler)
# without triggering the heavy services/__init__.py top-level imports.
_youtube_handler = importlib.import_module("services.youtube.youtube_handler")
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
YOUTUBE_INSTRUCTION_PROMPT = """When the user shares a YouTube video, respond with a structured breakdown:
1. **Summary** Concise overview of the video's content and main thesis (2-4 sentences)
2. **Key Points** Bullet list of the most important topics, arguments, or moments
3. **Notable Timestamps** If timestamps are available from the transcript, highlight 3-5 interesting moments with their approximate timestamps (e.g. "03:45 — discusses X")
4. **Audience Reception** If comments are available, summarize what viewers think: general sentiment, top reactions, any debate or controversy
Keep it conversational and concise. Do NOT web search for this video use only the transcript and comments provided."""
# ---------------------------------------------------------------------------
# Init / helpers
# ---------------------------------------------------------------------------
# Will be set at startup by init_youtube()
YouTubeTranscriptApi = None
YOUTUBE_AVAILABLE = False
def _find_ytdlp() -> str:
"""Find the yt-dlp binary: venv bin first, then system PATH."""
venv_bin = Path(sys.executable).parent / "yt-dlp"
if venv_bin.exists():
return str(venv_bin)
found = shutil.which("yt-dlp")
return found or "yt-dlp"
def init_youtube():
    """Import the YouTube transcript API and record whether it is usable.

    On success the API class is stored in the module-level
    ``YouTubeTranscriptApi`` and ``YOUTUBE_AVAILABLE`` is set to True. On
    ImportError a warning is logged and ``YOUTUBE_AVAILABLE`` is set to
    False.
    """
    global YouTubeTranscriptApi, YOUTUBE_AVAILABLE
    try:
        from youtube_transcript_api import YouTubeTranscriptApi as _Api
    except ImportError as e:
        logger.warning(f"youtube-transcript-api not installed: {e}")
        YOUTUBE_AVAILABLE = False
    else:
        YouTubeTranscriptApi = _Api
        YOUTUBE_AVAILABLE = True
        logger.info("YouTube transcript API available")
def is_youtube_url(url: str) -> bool:
    """Return True when ``url`` is a string containing a YouTube domain.

    This is a plain substring check for ``youtube.com`` or ``youtu.be``,
    not a URL parse. Non-string input returns False.
    """
    return isinstance(url, str) and any(
        marker in url for marker in ("youtube.com", "youtu.be")
    )
def extract_youtube_id(url: str) -> Optional[str]:
    """Extract the video ID from a YouTube URL.

    Supported forms:
      * ``youtube.com/watch?v=ID`` (also ``www.``, ``m.`` and ``music.``)
      * ``youtube.com/embed/ID``, ``/shorts/ID``, ``/live/ID``, ``/v/ID``
      * ``youtu.be/ID``

    Returns:
        The video ID, or ``None`` when ``url`` is not a string, cannot be
        parsed, is not a recognized YouTube URL, or carries an empty ID.
    """
    if not isinstance(url, str):
        return None
    try:
        parsed = urllib.parse.urlparse(url)
    except ValueError:
        # urlparse rejects some malformed netlocs (e.g. an unclosed "[").
        return None

    host = parsed.hostname or ""
    video_id = None
    if host in ("www.youtube.com", "youtube.com", "m.youtube.com", "music.youtube.com"):
        if parsed.path == "/watch":
            values = urllib.parse.parse_qs(parsed.query).get("v")
            video_id = values[0] if values else None
        else:
            for prefix in ("/embed/", "/shorts/", "/live/", "/v/"):
                if parsed.path.startswith(prefix):
                    # First path segment after the prefix, so a trailing
                    # slash does not yield an empty ID.
                    video_id = parsed.path[len(prefix):].split("/", 1)[0]
                    break
    elif host == "youtu.be":
        video_id = parsed.path.lstrip("/").split("/", 1)[0]

    # Normalize empty IDs (e.g. "https://youtu.be/") to None.
    return video_id or None
async def extract_transcript_async(
    url: str, video_id: str, max_retries: int = 3
) -> Dict[str, Any]:
    """Fetch a YouTube transcript with retries, without blocking the loop.

    Args:
        url: Full YouTube URL (not used by the fetch itself).
        video_id: Extracted video ID passed to the transcript API.
        max_retries: Number of attempts; waits ``1 * attempt`` seconds
            between failures.

    Returns:
        On success, a dict with ``success`` True plus ``transcript`` (joined
        text, truncated at 8000 characters), ``video_id``, ``language``,
        ``is_generated`` and ``segments`` (text/start/duration/timestamp).
        On failure, a dict with ``success`` False, ``error`` and
        ``transcript`` None.
    """
    if not YOUTUBE_AVAILABLE or YouTubeTranscriptApi is None:
        return {"success": False, "error": "YouTube transcript API not available", "transcript": None}

    def _fetch():
        # Synchronous network call: must not run on the event loop thread.
        return YouTubeTranscriptApi().fetch(video_id)

    for attempt in range(max_retries):
        try:
            transcript = await asyncio.to_thread(_fetch)

            formatted = []
            for snippet in list(transcript):
                text = snippet.text.strip()
                if not text:
                    continue
                start = snippet.start
                formatted.append({
                    "text": text,
                    "start": start,
                    "duration": snippet.duration,
                    "timestamp": f"{int(start // 60):02d}:{int(start % 60):02d}",
                })

            full_text = " ".join(e["text"] for e in formatted)
            max_len = 8000
            if len(full_text) > max_len:
                full_text = full_text[:max_len] + "... [transcript truncated]"

            # Report the metadata of the transcript actually fetched when the
            # API object exposes it; keep the previous constants otherwise.
            # NOTE(review): attribute names follow youtube-transcript-api's
            # FetchedTranscript -- confirm against the pinned version.
            language = getattr(transcript, "language_code", None)
            is_generated = getattr(transcript, "is_generated", None)

            return {
                "success": True,
                "transcript": full_text,
                "video_id": video_id,
                "language": language if isinstance(language, str) and language else "en",
                "is_generated": is_generated if isinstance(is_generated, bool) else False,
                "segments": formatted,
            }
        except Exception as e:
            logger.warning(f"Transcript attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(1 * (attempt + 1))

    return {"success": False, "error": f"Failed after {max_retries} attempts", "transcript": None}
def format_transcript_for_context(
    transcript_data: Dict[str, Any], url: str,
    title: str = "", channel: str = ""
) -> str:
    """Render transcript data as a text block for the LLM context.

    A failed extraction produces a single bracketed notice naming the video
    (when known) and the error. A successful one produces a header of
    metadata lines followed by the timestamped segments. It falls back to
    the plain transcript text when there are no segments or when the
    timestamped rendering pushes the block past 12000 characters.
    """
    if not transcript_data.get("success"):
        header = f" \"{title}\"" if title else ""
        if channel:
            header += f" by {channel}"
        return f"\n[YouTube Video{header}: Transcript unavailable ({transcript_data.get('error', 'Unknown error')}). Use the comments below if available, do NOT web search for this video.]"

    transcript = transcript_data.get("transcript", "")
    video_id = transcript_data.get("video_id", "")
    language = transcript_data.get("language", "unknown")
    is_generated = transcript_data.get("is_generated", False)
    segments = transcript_data.get("segments", [])

    pieces = ["\n[YOUTUBE VIDEO TRANSCRIPT]\n"]
    if title:
        pieces.append(f"Title: {title}\n")
    if channel:
        pieces.append(f"Channel: {channel}\n")
    pieces.append(f"Video ID: {video_id}\n")
    pieces.append(f"Language: {language}\n")
    pieces.append(f"Source: {'Auto-generated' if is_generated else 'Manual'}\n")
    pieces.append(f"URL: {url}\n\n")

    marker = "Timestamped Transcript:\n"
    use_plain_text = not segments
    if segments:
        # Timestamped lines let the model cite moments in the video.
        pieces.append(marker)
        pieces.extend(
            f"[{seg['timestamp']}] {seg['text']}\n"
            for seg in segments
            if isinstance(seg, dict)
        )
    ctx = "".join(pieces)

    if segments and len(ctx) > 12000:
        # Too long with timestamps: cut back to the header and use plain text.
        ctx = ctx[:ctx.index(marker)]
        use_plain_text = True

    if use_plain_text:
        ctx += "Transcript:\n"
        ctx += transcript

    return ctx + "\n[END TRANSCRIPT]\n"
async def fetch_youtube_comments(
    video_id: str, max_comments: int = 25, timeout: int = 30
) -> Dict[str, Any]:
    """Fetch top comments for a YouTube video using yt-dlp.

    Args:
        video_id: YouTube video ID.
        max_comments: Upper bound on comments requested and returned.
        timeout: Seconds to wait for yt-dlp to finish before killing it.

    Returns:
        On success, a dict with ``success`` True, ``comments`` (dicts with
        ``author``/``text``/``likes``, most-liked first), ``count``,
        ``title`` and ``channel``. On failure, a dict with ``success``
        False, ``error`` and an empty ``comments`` list. Failures are
        reported this way rather than raised.
    """
    try:
        cmd = [
            _find_ytdlp(),
            "--skip-download",
            "--write-comments",
            "--extractor-args", f"youtube:max_comments={max_comments},all,100,0",
            "--dump-json",
            "--js-runtimes", "node",
            "--remote-components", "ejs:github",
            f"https://www.youtube.com/watch?v={video_id}",
        ]
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        # Bound the wait on the process actually finishing, not on spawning it.
        # create_subprocess_exec returns as soon as the child starts, so wrapping
        # it in wait_for never enforces the timeout -- proc.communicate() is the
        # blocking step. Kill and reap the child if it overruns so it does not
        # linger after we return.
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(), timeout=timeout
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                # Child exited between the timeout and the kill; still reap it
                # and report the timeout rather than this lookup error.
                pass
            await proc.wait()
            raise

        if proc.returncode != 0:
            # errors="replace": stderr is diagnostic text, not guaranteed UTF-8.
            return {"success": False, "error": f"yt-dlp failed: {stderr.decode(errors='replace')[:200]}", "comments": []}

        data = json.loads(stdout.decode())
        title = data.get("title", "")
        channel = data.get("channel", "") or data.get("uploader", "")
        # `or []`: a present-but-null "comments" field must not break slicing.
        raw_comments = data.get("comments") or []

        comments = []
        for c in raw_comments[:max_comments]:
            if not isinstance(c, dict):
                continue
            text = (c.get("text") or "").strip()
            if not text:
                continue
            comments.append({
                "author": c.get("author") or "Unknown",
                "text": text,
                # A null like_count would make the sort below raise TypeError
                # and discard every comment; treat it as zero likes.
                "likes": c.get("like_count") or 0,
            })

        # Sort by likes descending -- most popular comments first
        comments.sort(key=lambda x: x["likes"], reverse=True)

        return {"success": True, "comments": comments, "count": len(comments),
                "title": title, "channel": channel}

    except asyncio.TimeoutError:
        logger.warning(f"Comment fetch timed out for {video_id}")
        return {"success": False, "error": "Comment fetch timed out", "comments": []}
    except FileNotFoundError:
        logger.warning("yt-dlp not installed — cannot fetch comments")
        return {"success": False, "error": "yt-dlp not installed", "comments": []}
    except Exception as e:
        logger.warning(f"Failed to fetch comments for {video_id}: {e}")
        return {"success": False, "error": str(e), "comments": []}
def format_comments_for_context(comments_data: Dict[str, Any], url: str) -> str:
    """Render fetched YouTube comments as a text block for the LLM context.

    Returns an empty string when the fetch failed or yielded no comments.
    Otherwise it emits a header, the URL, one numbered entry per comment
    (with a like count when non-zero), and an end marker. The body is cut
    at 4000 characters and a truncation notice is added before the marker.
    """
    comments = comments_data.get("comments")
    if not (comments_data.get("success") and comments):
        return ""
    comments = comments_data["comments"]

    lines = [
        f"\n[YOUTUBE VIDEO COMMENTS — Top {len(comments)} by popularity]\n",
        f"URL: {url}\n\n",
    ]
    for position, comment in enumerate(comments, 1):
        likes = comment.get("likes", 0)
        likes_str = f" [{likes} likes]" if likes else ""
        lines.append(f"{position}. @{comment['author']}{likes_str}: {comment['text']}\n\n")

    ctx = "".join(lines)
    if len(ctx) > 4000:
        ctx = ctx[:4000] + "\n[Comments truncated]\n"
    return ctx + "[END COMMENTS]\n"
sys.modules[__name__] = _youtube_handler