mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-27 07:05:23 -04:00
Merge remote-tracking branch 'origin/dev' into test-main-dev-merge-20260615
# Conflicts: # src/tool_implementations.py # static/js/research/panel.js
This commit is contained in:
@@ -91,6 +91,9 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple(
|
||||
("ui", "tool or feature toggle request", r"\b(?:disable|enable|turn\s+(?:on|off))\s+(?:the\s+)?(?:shell|search|web|browser|documents?|memory|skills|images?|calendar|email|mail|research|incognito)\b"),
|
||||
|
||||
# Deep research jobs, not quick conceptual mentions of research.
|
||||
("web", "explicit web search request", rf"{_PLEASE}(?:do|run|use|perform|make)\s+(?:a\s+)?(?:web\s+search|search\s+the\s+web)\b.+"),
|
||||
("web", "web lookup imperative request", rf"{_PLEASE}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"),
|
||||
("web", "assistant web lookup request", rf"{_ACTION_QUESTION}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"),
|
||||
("research", "deep research imperative request", rf"{_PLEASE}(?:research|deep\s+dive|look\s+into|investigate)\s+.+"),
|
||||
("research", "assistant deep research request", rf"{_ACTION_QUESTION}(?:research|do\s+research|deep\s+dive|look\s+into|investigate)\s+.+"),
|
||||
|
||||
|
||||
+179
-53
@@ -21,7 +21,7 @@ from src.settings import get_setting
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from src.tool_security import blocked_tools_for_owner, plan_mode_disabled_tools
|
||||
from src.tool_policy import GUIDE_ONLY_DIRECTIVE, ToolPolicy
|
||||
from src.tool_utils import get_mcp_manager
|
||||
from src.tool_utils import _truncate, get_mcp_manager
|
||||
from src.agent_tools import (
|
||||
parse_tool_blocks,
|
||||
strip_tool_blocks,
|
||||
@@ -262,6 +262,11 @@ _DOMAIN_RULES = {
|
||||
- Use `manage_settings` for preferences and tool enable/disable.
|
||||
- Use named tools over `app_api` when a named wrapper exists.
|
||||
- `app_api` is only for safe UI/API actions without a named tool; do not use it for shell, package installs, engine rebuilds, or sensitive auth/admin paths.""",
|
||||
"contacts": """\
|
||||
## Contacts rules
|
||||
- Use `resolve_contact` to look up a contact's email or phone number by name. Searches the CardDAV address book and sent email history.
|
||||
- Use `manage_contact` to list, add, update, or delete contacts in the address book.
|
||||
- Do NOT use `manage_memory` for contact lookups — contact details live in the address book, not memory.""",
|
||||
}
|
||||
|
||||
_DOMAIN_TOOL_MAP = {
|
||||
@@ -272,8 +277,9 @@ _DOMAIN_TOOL_MAP = {
|
||||
"notes_calendar_tasks": {"manage_notes", "manage_calendar", "manage_tasks"},
|
||||
"ui": {"ui_control"},
|
||||
"sessions": {"create_session", "list_sessions", "manage_session", "send_to_session", "search_chats"},
|
||||
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls"},
|
||||
"files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls", "get_workspace"},
|
||||
"settings": {"manage_settings", "manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "app_api"},
|
||||
"contacts": {"resolve_contact", "manage_contact"},
|
||||
}
|
||||
|
||||
def _domain_rules_for_tools(tool_names: set) -> list[str]:
|
||||
@@ -309,6 +315,7 @@ NEVER pipe multi-line Python through `python -c "..."` — shell quoting eats re
|
||||
<python code>
|
||||
```
|
||||
Execute Python code. Use for computation, data processing, scripting. NOT for writing code for the user (use create_document for that). Same sandbox limits as bash — no TTY, no GUI, no `input()`; for anything the user should interact with, generate a single HTML file with inline JS instead.
|
||||
Prefer a dedicated tool whenever one fits the job (reading, searching, or writing files); use python only for computation/processing no dedicated tool covers - not for reading or writing files.
|
||||
Do NOT use Python/requests for web lookup/search/latest/current requests when `web_search` or `web_fetch` is available.""",
|
||||
|
||||
"web_search": """\
|
||||
@@ -347,6 +354,11 @@ Write content to a file. First line is the path, rest is the content.""",
|
||||
```
|
||||
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
|
||||
|
||||
"get_workspace": """\
|
||||
```get_workspace
|
||||
```
|
||||
Return the absolute path of the active workspace folder. File tools are CONFINED to it (paths can be RELATIVE to it); the shell starts there (cwd) but is NOT sandboxed. Call this first when the user says "the project"/"the code"/"this folder" without a path, instead of asking them. No arguments.""",
|
||||
|
||||
"create_document": """\
|
||||
```create_document
|
||||
<title>
|
||||
@@ -598,7 +610,7 @@ _API_HOSTS = frozenset([
|
||||
"api.deepseek.com", "deepseek.com",
|
||||
"api.together.xyz", "api.fireworks.ai",
|
||||
"api.perplexity.ai", "api.x.ai",
|
||||
"ollama.com", "api.venice.ai",
|
||||
"ollama.com", "api.venice.ai", "api.kimi.com",
|
||||
"api.githubcopilot.com",
|
||||
# Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.).
|
||||
# Without these, `_is_api_model` falls back to keyword sniffing on the
|
||||
@@ -785,6 +797,12 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o
|
||||
domains.add("documents")
|
||||
if has(r"\b(search|web|google|look up|latest|news|current|weather|forecast|stock price|price of|website|url|https?://|www\.)\b"):
|
||||
domains.add("web")
|
||||
if has(
|
||||
r"\b(wyszukaj|wyszukać|wyszukac)\b.*\b(internet|internecie|online|web)\b",
|
||||
r"\b(sprawd[zź]|znajd[zź])\b.*\b(internet|internecie|online|web)\b",
|
||||
r"\b(aktualn\w*|bieżąc\w*|biezac\w*|dzisiaj|teraz)\b.*\b(pogod\w*|temperatur\w*)\b",
|
||||
):
|
||||
domains.add("web")
|
||||
if has(r"\b(research|deep dive|investigate|look into)\b"):
|
||||
domains.add("web")
|
||||
if has(r"\b(open|show|toggle|turn on|turn off|disable|enable|switch model|change model|settings|theme|panel)\b"):
|
||||
@@ -795,6 +813,8 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o
|
||||
domains.add("files")
|
||||
if has(r"\b(endpoint|api token|mcp|webhook|preference|configure|config|setting)\b"):
|
||||
domains.add("settings")
|
||||
if has(r"\b(contact|contacts|phone|phone number|address book|vcard)\b"):
|
||||
domains.add("contacts")
|
||||
|
||||
low_signal = not continuation and not domains
|
||||
return {
|
||||
@@ -860,7 +880,7 @@ def _build_system_prompt(
|
||||
_ov_sig = _hl.sha256(_json.dumps(get_builtin_overrides() or {}, sort_keys=True).encode()).hexdigest()
|
||||
except Exception:
|
||||
_ov_sig = ""
|
||||
cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, suppress_local_context)
|
||||
cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, owner, suppress_local_context)
|
||||
if _cached_base_prompt and _cached_base_prompt_key == cache_key and not active_document:
|
||||
agent_prompt = _cached_base_prompt
|
||||
# Skill index is user-editable (name + description), so it must never
|
||||
@@ -868,7 +888,7 @@ def _build_system_prompt(
|
||||
# when the cache hits.
|
||||
_, _skill_index_block = _build_base_prompt(
|
||||
disabled_tools, mcp_mgr, needs_admin, relevant_tools,
|
||||
mcp_disabled_map=mcp_disabled_map, compact=compact,
|
||||
mcp_disabled_map=mcp_disabled_map, compact=compact, owner=owner,
|
||||
suppress_local_context=suppress_local_context,
|
||||
)
|
||||
else:
|
||||
@@ -879,6 +899,7 @@ def _build_system_prompt(
|
||||
relevant_tools,
|
||||
mcp_disabled_map=mcp_disabled_map,
|
||||
compact=compact,
|
||||
owner=owner,
|
||||
suppress_local_context=suppress_local_context,
|
||||
)
|
||||
if not active_document:
|
||||
@@ -894,9 +915,20 @@ def _build_system_prompt(
|
||||
|
||||
# Current date/time for every agent request. This is user-local when the
|
||||
# browser provided timezone headers, with a server-local fallback.
|
||||
#
|
||||
# IMPORTANT: this is intentionally NOT prepended into agent_prompt (the
|
||||
# system message) anymore. Its text changes every minute, and local
|
||||
# OpenAI-compatible backends (llama.cpp / LM Studio) key their KV-cache
|
||||
# prefix off the system message byte-for-byte — mixing ever-changing
|
||||
# timestamp text into the (already large, tool-laden) agent system prompt
|
||||
# would invalidate the cached prefix on every single request, forcing a
|
||||
# full prompt re-evaluation each turn (issue #2927). It's built here as a
|
||||
# standalone *user*-role message and inserted near the end of the array,
|
||||
# right alongside _doc_message / _skills_message, below.
|
||||
_datetime_message = None
|
||||
try:
|
||||
from src.user_time import current_datetime_prompt
|
||||
agent_prompt = current_datetime_prompt() + agent_prompt
|
||||
from src.user_time import current_datetime_context_message
|
||||
_datetime_message = current_datetime_context_message()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1296,6 +1328,9 @@ def _build_system_prompt(
|
||||
last_user_idx += 1
|
||||
if _skills_message:
|
||||
merged.insert(last_user_idx, _skills_message)
|
||||
last_user_idx += 1
|
||||
if _datetime_message:
|
||||
merged.insert(last_user_idx, _datetime_message)
|
||||
|
||||
return merged, mcp_schemas
|
||||
|
||||
@@ -1314,6 +1349,7 @@ def _build_base_prompt(
|
||||
relevant_tools=None,
|
||||
mcp_disabled_map=None,
|
||||
compact: bool = False,
|
||||
owner: Optional[str] = None,
|
||||
suppress_local_context: bool = False,
|
||||
):
|
||||
"""Build the agent prompt with only relevant tools included.
|
||||
@@ -1373,7 +1409,7 @@ def _build_base_prompt(
|
||||
from src.constants import DATA_DIR
|
||||
_sm = SkillsManager(DATA_DIR)
|
||||
active_tools = list(set(TOOL_SECTIONS.keys()) - set(disabled or []))
|
||||
skill_idx = _sm.index_for(owner=None, active_toolsets=active_tools)
|
||||
skill_idx = _sm.index_for(owner=owner, active_toolsets=active_tools)
|
||||
if skill_idx:
|
||||
lines = ["## Available skills",
|
||||
"Procedures the assistant should consult before doing domain work. "
|
||||
@@ -1782,10 +1818,10 @@ async def stream_agent_loop(
|
||||
owner: Optional[str] = None,
|
||||
relevant_tools: Optional[Set[str]] = None,
|
||||
fallbacks: Optional[List[tuple]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
plan_mode: bool = False,
|
||||
approved_plan: Optional[str] = None,
|
||||
tool_policy: Optional[ToolPolicy] = None,
|
||||
workspace: Optional[str] = None,
|
||||
_is_teacher_run: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming agent loop generator.
|
||||
@@ -1854,8 +1890,21 @@ async def stream_agent_loop(
|
||||
logger.info(f"[tool-rag] Using caller-provided relevant_tools ({len(_relevant_tools)} tools)")
|
||||
if not guide_only and not _relevant_tools and bool(_intent.get("low_signal")):
|
||||
from src.tool_index import ALWAYS_AVAILABLE
|
||||
_relevant_tools = set(ALWAYS_AVAILABLE)
|
||||
logger.info("[tool-rag] Low-signal agent message; skipping retrieval and using always-available tools only")
|
||||
if workspace:
|
||||
# An active workspace IS the file-work signal: a vague "look at the
|
||||
# project" means explore this folder. Surface only the READ-ONLY file
|
||||
# tools (intersection with the plan-mode read-only allowlist) so the
|
||||
# agent can investigate; write/shell tools stay out until the request
|
||||
# actually calls for them (RAG retrieval adds those on a real ask).
|
||||
_relevant_tools = set(ALWAYS_AVAILABLE)
|
||||
from src.tool_security import PLAN_MODE_READONLY_TOOLS
|
||||
_relevant_tools |= (_DOMAIN_TOOL_MAP["files"] & PLAN_MODE_READONLY_TOOLS)
|
||||
logger.info("[tool-rag] Low-signal but workspace active; including read-only file tools")
|
||||
else:
|
||||
# Don't short-circuit: fall through to RAG retrieval below.
|
||||
# Non-English queries are flagged low_signal by the English-only
|
||||
# intent classifier, but fastembed retrieval works across languages.
|
||||
logger.info("[tool-rag] Low-signal query; will run RAG retrieval")
|
||||
if not guide_only and not _relevant_tools:
|
||||
try:
|
||||
from src.tool_index import get_tool_index, ALWAYS_AVAILABLE
|
||||
@@ -1930,6 +1979,44 @@ async def stream_agent_loop(
|
||||
if _relevant_tools is not None and active_document is not None:
|
||||
_relevant_tools.update({"edit_document", "update_document", "suggest_document"})
|
||||
|
||||
# The skill index injected by _build_system_prompt tells the model to
|
||||
# call `manage_skills action=view`, and Jaccard-matched skills are pasted
|
||||
# into the prompt as procedures to follow — but neither path goes through
|
||||
# tool selection, so the model can be handed a procedure naming tools
|
||||
# (grep, read_file, ...) that aren't in its schema list. Keep the schemas
|
||||
# in lockstep: manage_skills is callable whenever any skill is indexed,
|
||||
# and a matched skill's declared requires_toolsets ride along with it.
|
||||
if not guide_only and _relevant_tools is not None:
|
||||
try:
|
||||
from services.memory.skills import SkillsManager
|
||||
from src.constants import DATA_DIR
|
||||
_skills_on = True
|
||||
try:
|
||||
from routes.prefs_routes import _load_for_user as _load_prefs
|
||||
_skills_on = (_load_prefs(owner) or {}).get("skills_enabled", True)
|
||||
except Exception:
|
||||
pass
|
||||
_sm = SkillsManager(DATA_DIR)
|
||||
_owner_skills = _sm.load(owner=owner) if _skills_on else []
|
||||
if _owner_skills:
|
||||
_relevant_tools.add("manage_skills")
|
||||
if _retrieval_query:
|
||||
# Validate against every known executable tool, not just
|
||||
# TOOL_SECTIONS — code-nav tools (grep/glob/ls) ship as
|
||||
# schemas without a prompt-prose section.
|
||||
from src.tool_policy import known_tool_names
|
||||
_known = known_tool_names()
|
||||
for _sk in _sm.get_relevant_skills(
|
||||
_retrieval_query, skills=_owner_skills,
|
||||
threshold=0.25, max_items=3,
|
||||
):
|
||||
_relevant_tools.update(
|
||||
t for t in (_sk.get("requires_toolsets") or [])
|
||||
if t in _known
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug(f"[tool-rag] skill-aware tool include skipped: {_e}")
|
||||
|
||||
if _relevant_tools is not None:
|
||||
logger.info("[agent-intent] selected_tools=%s", sorted(_relevant_tools)[:50])
|
||||
|
||||
@@ -1980,6 +2067,10 @@ async def stream_agent_loop(
|
||||
# and can override this list for users who know their setup.
|
||||
_model_no_tools = any(kw in _model_lc for kw in (
|
||||
"deepseek-r1",
|
||||
# Open-weight GPT-OSS models are commonly served through llama.cpp /
|
||||
# llama-cpp-python. Their names contain "gpt-o", but they do not use
|
||||
# OpenAI's native tool-call channel unless the endpoint opts in.
|
||||
"gpt-oss",
|
||||
))
|
||||
# Native Ollama endpoints (/api/chat) handle tool schemas differently from
|
||||
# the OpenAI-compat path. Models like gemma4, qwen3.5, ministral respond to
|
||||
@@ -2011,27 +2102,6 @@ async def stream_agent_loop(
|
||||
suppress_local_context=guide_only,
|
||||
active_email=active_email,
|
||||
)
|
||||
if workspace and not guide_only:
|
||||
# PREPEND (not append) so it dominates the large base prompt — appended
|
||||
# at the end, small models ignored it and asked the user for code. The
|
||||
# folder IS the project; the agent must explore it, not ask.
|
||||
_ws_note = (
|
||||
f"## ACTIVE WORKSPACE — READ FIRST\n"
|
||||
f"The user is working in this folder: {workspace}\n"
|
||||
f"It IS the project. bash/python run with cwd set here and "
|
||||
f"read_file/write_file are confined to it (paths outside are rejected).\n"
|
||||
f"When the user says \"the code\" / \"this project\" / \"the workspace\" "
|
||||
f"or asks to review/find/edit something WITHOUT a path, they mean THIS "
|
||||
f"folder. Do NOT ask the user for code or a path, and do NOT read a file "
|
||||
f"literally named \"workspace\". ALWAYS start by exploring it yourself: "
|
||||
f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then "
|
||||
f"read_file the relevant ones by path RELATIVE to the workspace."
|
||||
)
|
||||
if messages and messages[0].get("role") == "system":
|
||||
messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "")
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": _ws_note})
|
||||
logger.info("[workspace] active for this turn: %s", workspace)
|
||||
if plan_mode and not guide_only:
|
||||
# Steer the model to investigate-then-propose. Hard tool gating handles
|
||||
# every write path except shell; this directive is what keeps the
|
||||
@@ -2063,30 +2133,34 @@ async def stream_agent_loop(
|
||||
_t3 = time.time()
|
||||
try:
|
||||
from src.context_compactor import trim_for_context
|
||||
from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX
|
||||
from src.settings import is_setting_overridden
|
||||
from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX, DEFAULT_BUDGET, budget_is_explicit as _budget_is_explicit
|
||||
from src.model_context import budget_context_for_model
|
||||
|
||||
soft_budget = int(get_setting("agent_input_token_budget", 6000) or 0)
|
||||
soft_budget = int(get_setting("agent_input_token_budget", DEFAULT_BUDGET) or 0)
|
||||
if soft_budget > 0:
|
||||
before_trim_tokens = estimate_tokens(messages)
|
||||
reserve_tokens = min(max(max_tokens or 1024, 512), 2048)
|
||||
# Honour the configurable ceiling for the auto-derived budget path.
|
||||
# No-op when the user has an explicit `agent_input_token_budget`
|
||||
# (that branch ignores hard_max). Falls back to DEFAULT_HARD_MAX
|
||||
# on missing/malformed values so misconfig can't zero the budget.
|
||||
# Ceiling for the auto-derived budget (no effect on an explicit budget;
|
||||
# see #1230). Falls back to DEFAULT_HARD_MAX on missing/malformed values
|
||||
# so misconfig can't zero the budget.
|
||||
try:
|
||||
hard_max = int(get_setting("agent_input_token_hard_max", DEFAULT_HARD_MAX) or DEFAULT_HARD_MAX)
|
||||
except (TypeError, ValueError):
|
||||
hard_max = DEFAULT_HARD_MAX
|
||||
if hard_max <= 0:
|
||||
hard_max = DEFAULT_HARD_MAX
|
||||
# Scale the default budget to the model's context window so long-context
|
||||
# models aren't silently capped at 6000; an explicit user setting is
|
||||
# still honoured (clamped to the window). (#1170)
|
||||
# Default value = auto sentinel (scale to the window); any other value =
|
||||
# explicit cap. Value-based, not presence-based, because the save path
|
||||
# materializes defaults so a persisted default must still read as auto (#4121).
|
||||
budget_is_explicit = _budget_is_explicit(soft_budget)
|
||||
# Scale only off a window we actually discovered, bound to the value it
|
||||
# proves (else 0) — not the passed-in context_length, which can be stale
|
||||
# or unset for some callers (#4122 review).
|
||||
ctx_for_budget = budget_context_for_model(endpoint_url, model, fallback=context_length)
|
||||
effective_budget = compute_input_token_budget(
|
||||
soft_budget,
|
||||
context_length,
|
||||
is_setting_overridden("agent_input_token_budget"),
|
||||
ctx_for_budget,
|
||||
budget_is_explicit,
|
||||
hard_max=hard_max,
|
||||
)
|
||||
trimmed_messages = trim_for_context(
|
||||
@@ -2161,11 +2235,12 @@ async def stream_agent_loop(
|
||||
# tool, so we don't nudge on harmless transitional text like "let me
|
||||
# know what you think".
|
||||
_INTENT_RE = re.compile(
|
||||
r"(?:^|\n)\s*(?:let me|i'?ll|i will|going to|let's)\s+"
|
||||
r"(?:^|\n)\s*(?:let me|i'?ll|i will|i need to|we need to|need to|"
|
||||
r"i should|we should|i must|we must|going to|let's)\s+"
|
||||
r"(?:tail|check|investigate|look at|see|tail|read|fetch|inspect|"
|
||||
r"verify|diagnose|examine|debug|capture|grab|pull|view|run|call|"
|
||||
r"trigger|launch|start|kick off|stop|kill|restart|adopt|serve|"
|
||||
r"register|adopt|list|search|find|query|hit|ping|test)"
|
||||
r"register|adopt|list|search|find|query|hit|ping|test|use|perform|do)"
|
||||
r"\b[^.\n]{0,140}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@@ -2206,9 +2281,17 @@ async def stream_agent_loop(
|
||||
elif _is_api_model:
|
||||
# Filter schemas by RAG-selected tools (if available)
|
||||
if _relevant_tools:
|
||||
# _build_base_prompt unions _ADMIN_TOOLS into the prompt
|
||||
# sections when admin intent fires — the schema list must
|
||||
# offer the same names, or the model reads prose describing
|
||||
# tools it cannot call and substitutes the nearest schema
|
||||
# it does have (e.g. manage_memory for manage_skills).
|
||||
_schema_names = set(_relevant_tools)
|
||||
if _needs_admin:
|
||||
_schema_names |= _ADMIN_TOOLS
|
||||
base_schemas = [
|
||||
s for s in FUNCTION_TOOL_SCHEMAS
|
||||
if s.get("function", {}).get("name") in _relevant_tools
|
||||
if s.get("function", {}).get("name") in _schema_names
|
||||
]
|
||||
_mcp_filtered = [
|
||||
s for s in mcp_schemas
|
||||
@@ -2254,6 +2337,7 @@ async def stream_agent_loop(
|
||||
prompt_type=prompt_type if round_num == 1 else None,
|
||||
tools=all_tool_schemas if all_tool_schemas else None,
|
||||
timeout=agent_stream_timeout,
|
||||
session_id=session_id,
|
||||
):
|
||||
if time.time() > _round_deadline:
|
||||
logger.warning(f"[agent] round {round_num} stream exceeded wall-clock deadline; cutting off")
|
||||
@@ -2743,6 +2827,46 @@ async def stream_agent_loop(
|
||||
)
|
||||
desc, result = await _tool_task
|
||||
|
||||
# A skill the model just loaded can prescribe tools that weren't
|
||||
# RAG-selected this turn (declared via requires_toolsets in its
|
||||
# frontmatter). Union them into the selection so the NEXT round's
|
||||
# schema list includes them — otherwise the model reads "use
|
||||
# grep" from the skill it fetched but has no grep schema to call.
|
||||
if (
|
||||
block.tool_type == "manage_skills"
|
||||
and _relevant_tools is not None
|
||||
and not result.get("error")
|
||||
):
|
||||
_ms_args = {}
|
||||
_ms_raw = (block.content or "").strip()
|
||||
if _ms_raw.startswith("{"):
|
||||
try:
|
||||
_ms_args = json.loads(_ms_raw)
|
||||
except json.JSONDecodeError:
|
||||
_ms_args = {}
|
||||
_ms_name = str(_ms_args.get("name", "") or "").strip()
|
||||
if _ms_name and _ms_args.get("action") in ("view", "view_ref"):
|
||||
try:
|
||||
from services.memory.skills import SkillsManager as _SkM
|
||||
from src.constants import DATA_DIR as _DD
|
||||
from src.tool_policy import known_tool_names as _ktn
|
||||
_known = _ktn()
|
||||
for _sk in _SkM(_DD).load(owner=owner):
|
||||
if _sk.get("name") == _ms_name:
|
||||
_new = {
|
||||
t for t in (_sk.get("requires_toolsets") or [])
|
||||
if t in _known and t not in _relevant_tools
|
||||
}
|
||||
if _new:
|
||||
_relevant_tools.update(_new)
|
||||
logger.info(
|
||||
"[tool-rag] skill '%s' unlocked tools for next round: %s",
|
||||
_ms_name, sorted(_new),
|
||||
)
|
||||
break
|
||||
except Exception as _e:
|
||||
logger.debug(f"skill requires_toolsets unlock skipped: {_e}")
|
||||
|
||||
# Extract structured web sources from web_search tool output.
|
||||
# web_search returns {"output": ..., "exit_code": 0}; check "output"
|
||||
# first so the <!-- SOURCES:…--> marker is found and stripped even
|
||||
@@ -2833,18 +2957,20 @@ async def stream_agent_loop(
|
||||
# On a bash/python timeout the result carries error + (often
|
||||
# empty) stdout/stderr; fall back to the error so the "timed
|
||||
# out" reason reaches the UI instead of a blank result.
|
||||
output_text = (result["stdout"] or result["stderr"] or result.get("error", ""))[:2000]
|
||||
raw = result["stdout"] or result["stderr"] or result.get("error", "")
|
||||
output_text = _truncate(raw)
|
||||
elif "output" in result:
|
||||
# bash / python canonical result: {"output": ..., "exit_code": ...}
|
||||
output_text = (result["output"] or "")[:2000]
|
||||
raw = result["output"] or ""
|
||||
output_text = _truncate(raw)
|
||||
elif "response" in result:
|
||||
# AI interaction tools (chat_with_model, send_to_session)
|
||||
label = result.get("model", result.get("session_name", "AI"))
|
||||
output_text = f"{label}: {result['response']}"[:4000]
|
||||
output_text = _truncate(f"{label}: {result['response']}")
|
||||
elif "content" in result:
|
||||
output_text = result["content"][:2000]
|
||||
output_text = _truncate(result["content"])
|
||||
elif "results" in result:
|
||||
output_text = result["results"][:4000]
|
||||
output_text = _truncate(result["results"])
|
||||
elif "session_id" in result and "name" in result:
|
||||
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
|
||||
elif "success" in result:
|
||||
@@ -2854,7 +2980,7 @@ async def stream_agent_loop(
|
||||
else f"Error: {result.get('error', '')}"
|
||||
)
|
||||
elif "error" in result:
|
||||
output_text = result["error"][:2000]
|
||||
output_text = _truncate(result["error"])
|
||||
|
||||
# Emit tool_output (include ui_event data if present)
|
||||
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
|
||||
|
||||
@@ -18,6 +18,30 @@ from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .subprocess_tools import BashTool, PythonTool
|
||||
from .web_tools import WebSearchTool, WebFetchTool
|
||||
from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool, GetWorkspaceTool
|
||||
from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool
|
||||
|
||||
TOOL_HANDLERS = {
|
||||
"bash": BashTool().execute,
|
||||
"python": PythonTool().execute,
|
||||
"web_search": WebSearchTool().execute,
|
||||
"web_fetch": WebFetchTool().execute,
|
||||
"read_file": ReadFileTool().execute,
|
||||
"write_file": WriteFileTool().execute,
|
||||
"edit_file": EditFileTool().execute,
|
||||
"ls": LsTool().execute,
|
||||
"glob": GlobTool().execute,
|
||||
"grep": GrepTool().execute,
|
||||
"create_document": CreateDocumentTool().execute,
|
||||
"update_document": UpdateDocumentTool().execute,
|
||||
"edit_document": EditDocumentTool().execute,
|
||||
"suggest_document": SuggestDocumentTool().execute,
|
||||
"manage_documents": ManageDocumentTool().execute,
|
||||
"get_workspace": GetWorkspaceTool().execute,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants (re-exported for backward compatibility — single source of truth
|
||||
# is src.constants; always prefer importing from there for new code)
|
||||
@@ -28,7 +52,7 @@ PYTHON_TIMEOUT = 30
|
||||
|
||||
# Tool types that trigger execution
|
||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||
"grep", "glob", "ls",
|
||||
"grep", "glob", "ls", "get_workspace",
|
||||
"create_document", "update_document", "edit_document",
|
||||
"search_chats",
|
||||
"chat_with_model", "create_session", "list_sessions",
|
||||
@@ -92,15 +116,14 @@ from src.tool_execution import ( # noqa: E402, F401
|
||||
format_tool_result,
|
||||
)
|
||||
|
||||
# Document functions
|
||||
from .document_tools import (
|
||||
set_active_document,
|
||||
set_active_model
|
||||
)
|
||||
|
||||
# Implementations
|
||||
from src.tool_implementations import ( # noqa: E402, F401
|
||||
set_active_document,
|
||||
set_active_model,
|
||||
get_active_document,
|
||||
do_create_document,
|
||||
do_update_document,
|
||||
do_edit_document,
|
||||
do_suggest_document,
|
||||
do_search_chats,
|
||||
do_manage_skills,
|
||||
do_manage_tasks,
|
||||
@@ -108,7 +131,6 @@ from src.tool_implementations import ( # noqa: E402, F401
|
||||
do_manage_mcp,
|
||||
do_manage_webhooks,
|
||||
do_manage_tokens,
|
||||
do_manage_documents,
|
||||
do_manage_settings,
|
||||
do_api_call,
|
||||
)
|
||||
@@ -0,0 +1,644 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
from src.constants import MAX_READ_CHARS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active document state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_active_document_id: Optional[str] = None
|
||||
_active_model: Optional[str] = None
|
||||
|
||||
|
||||
def set_active_document(doc_id: Optional[str]):
|
||||
"""Set the active document ID for document tool execution."""
|
||||
global _active_document_id
|
||||
_active_document_id = doc_id
|
||||
|
||||
|
||||
def set_active_model(model: Optional[str]):
|
||||
"""Set the current model name for version summaries."""
|
||||
global _active_model
|
||||
_active_model = model
|
||||
|
||||
|
||||
def get_active_document():
|
||||
return _active_document_id
|
||||
|
||||
|
||||
def clear_active_document(doc_id: Optional[str] = None) -> bool:
|
||||
"""Clear the in-memory active-document pointer.
|
||||
|
||||
With ``doc_id`` given, only clears when it matches the current pointer, so a
|
||||
different active document is left untouched. Returns True if it was cleared.
|
||||
|
||||
Called when a document is detached from its session or deleted (its tab is
|
||||
closed): without this, the stale pointer makes the last-resort doc-injection
|
||||
path re-surface a closed document in a later, unrelated chat — even one whose
|
||||
session no longer matches — because an unlinked doc has session_id NULL (#1160).
|
||||
"""
|
||||
global _active_document_id
|
||||
if doc_id is None or _active_document_id == doc_id:
|
||||
_active_document_id = None
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _owned_document_query(query, Document, owner: Optional[str]):
|
||||
if owner is None:
|
||||
# A bare Python `False` is not a valid SQL expression — SQLAlchemy 1.4
|
||||
# deprecates it and 2.0 raises ArgumentError. Use the SQL `false()`
|
||||
# literal to return zero rows for an unscoped (owner-less) query.
|
||||
from sqlalchemy import false
|
||||
return query.filter(false())
|
||||
return query.filter(Document.owner == owner)
|
||||
|
||||
|
||||
def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False):
|
||||
q = db.query(Document).filter(Document.id == doc_id)
|
||||
if active_only:
|
||||
q = q.filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
return q.first()
|
||||
|
||||
|
||||
def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False):
|
||||
q = db.query(Document)
|
||||
if active_only:
|
||||
q = q.filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
return q.order_by(Document.updated_at.desc()).first()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document tools — create/update/edit/suggest living documents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sniff_doc_language(text: str) -> str:
|
||||
"""Best-effort detect a document's language from its content when the model
|
||||
didn't specify one. Defaults to 'markdown' (prose). Recognizes the common
|
||||
markup/code types the editor supports so e.g. an SVG isn't saved as markdown."""
|
||||
import json as _json, re as _re2
|
||||
s = (text or "").strip()
|
||||
if not s:
|
||||
return "markdown"
|
||||
head = s[:600]
|
||||
hl = head.lower()
|
||||
if _looks_like_email_document(s):
|
||||
return "email"
|
||||
# Markup (unambiguous)
|
||||
if "<svg" in hl:
|
||||
return "svg"
|
||||
if hl.startswith("<?xml"):
|
||||
return "xml"
|
||||
if (hl.startswith("<!doctype html") or hl.startswith("<html")
|
||||
or _re2.search(r"<(div|body|head|p|span|table|button|h[1-6]|ul|ol|li|img)\b", hl)):
|
||||
return "html"
|
||||
# JSON
|
||||
if s[0] in "{[":
|
||||
try:
|
||||
_json.loads(s)
|
||||
return "json"
|
||||
except Exception:
|
||||
pass
|
||||
# Shebang
|
||||
first = s.split("\n", 1)[0].strip().lower()
|
||||
if first.startswith("#!"):
|
||||
return "python" if "python" in first else "bash"
|
||||
# Code by strong leading signals (line-anchored so prose with stray words won't match)
|
||||
if _re2.search(r"(?m)^\s*(def \w|class \w|import \w|from \w[\w.]* import )", s):
|
||||
return "python"
|
||||
if _re2.search(r"(?m)^\s*(function \w|const \w|let \w|export |import .* from )", s):
|
||||
return "javascript"
|
||||
if _re2.search(r"(?mi)^\s*(select .* from |create table |insert into |update \w)", s):
|
||||
return "sql"
|
||||
if _re2.search(r"(?m)^[.#]?[\w-]+\s*\{[^{}]*:[^{}]*;", s):
|
||||
return "css"
|
||||
return "markdown"
|
||||
|
||||
def _looks_like_email_document(text: str = "", title: str = "") -> bool:
|
||||
import re as _re
|
||||
title_l = (title or "").strip().lower()
|
||||
if title_l in {"new email", "new mail", "new message"}:
|
||||
return True
|
||||
s = (text or "").lstrip()
|
||||
if "\n---\n" in s and _re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s):
|
||||
return True
|
||||
return bool(_re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s))
|
||||
|
||||
def _coerce_email_document_content(existing: str, incoming: str) -> str:
|
||||
"""Keep email docs in the To/Subject/---/body shape even if a model writes
|
||||
only the body or dumps header labels without the separator."""
|
||||
import re as _re
|
||||
old = existing or ""
|
||||
new = (incoming or "").strip()
|
||||
if "\n---\n" in new:
|
||||
return new
|
||||
header = old.split("\n---\n", 1)[0] if "\n---\n" in old else "To: \nSubject: "
|
||||
if _looks_like_email_document(new):
|
||||
lines = new.splitlines()
|
||||
last_header_idx = -1
|
||||
header_re = _re.compile(r"^(To|Cc|Bcc|Subject|In-Reply-To|References|X-Source-UID|X-Source-Folder|X-Attachments):", _re.I)
|
||||
for i, line in enumerate(lines):
|
||||
if header_re.match(line.strip()):
|
||||
last_header_idx = i
|
||||
body_lines = lines[last_header_idx + 1:] if last_header_idx >= 0 else lines
|
||||
while body_lines and not body_lines[0].strip():
|
||||
body_lines.pop(0)
|
||||
body = "\n".join(body_lines).strip()
|
||||
else:
|
||||
body = new
|
||||
return header.rstrip() + "\n---\n" + body
|
||||
|
||||
def _parse_tool_args(content):
|
||||
"""Parse a tool-call argument blob.
|
||||
|
||||
Accepts either a JSON string or an already-decoded dict. Unwraps the
|
||||
common `{"body": {...}}` envelope that smaller models emit when they
|
||||
read tool descriptions like "Body is JSON: {...}" literally — they
|
||||
pass `body` as a field name rather than treating it as a noun.
|
||||
|
||||
Returns a dict on success, raises ValueError on bad JSON.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
args = json.loads(content) if content.strip() else {}
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
raise ValueError(str(e))
|
||||
elif isinstance(content, dict):
|
||||
args = content
|
||||
else:
|
||||
args = {}
|
||||
# Unwrap {"body": {...}} envelope — but only if `body` is the sole key
|
||||
# and points at a dict. We don't want to clobber a legitimate `body`
|
||||
# field on tools where it's a real arg (e.g. send_email body text).
|
||||
if (
|
||||
isinstance(args, dict)
|
||||
and len(args) == 1
|
||||
and "body" in args
|
||||
and isinstance(args["body"], dict)
|
||||
and "action" in args["body"] # extra safety: only unwrap if the inner dict looks like a tool call
|
||||
):
|
||||
args = args["body"]
|
||||
return args
|
||||
|
||||
def parse_edit_blocks(content: str) -> list:
|
||||
"""Parse <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks."""
|
||||
edits = []
|
||||
pattern = r'<<<FIND>>>\n(.*?)\n<<<REPLACE>>>\n(.*?)\n<<<END>>>'
|
||||
for m in re.finditer(pattern, content, re.DOTALL):
|
||||
edits.append({"find": m.group(1), "replace": m.group(2)})
|
||||
return edits
|
||||
|
||||
def parse_suggest_blocks(content: str) -> list:
|
||||
"""Parse <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks."""
|
||||
suggestions = []
|
||||
_skip_phrases = ["no change", "clear", "fine as", "looks good", "no improvement", "keep as"]
|
||||
pattern = r'<<<FIND>>>\n(.*?)\n<<<SUGGEST>>>\n(.*?)\n<<<REASON>>>\n(.*?)\n<<<END>>>'
|
||||
for m in re.finditer(pattern, content, re.DOTALL):
|
||||
find_text = m.group(1)
|
||||
replace_text = m.group(2)
|
||||
reason = m.group(3).strip()
|
||||
# Skip no-op suggestions where find == replace or reason says no change
|
||||
if find_text.strip() == replace_text.strip():
|
||||
continue
|
||||
if any(phrase in reason.lower() for phrase in _skip_phrases):
|
||||
continue
|
||||
suggestions.append({
|
||||
"id": f"sugg-{len(suggestions)+1}",
|
||||
"find": find_text,
|
||||
"replace": replace_text,
|
||||
"reason": reason,
|
||||
})
|
||||
return suggestions
|
||||
|
||||
|
||||
class CreateDocumentTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
"""Create a new document. Supports two formats:
|
||||
1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
|
||||
2) XML-like tags: <title>...</title><language>...</language><content>...</content>
|
||||
Some models mix them — strip any XML-style tags and fall back to line parsing."""
|
||||
import uuid, re as _re
|
||||
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
|
||||
|
||||
raw = content or ""
|
||||
session_id = ctx.get("session_id")
|
||||
owner = ctx.get("owner")
|
||||
|
||||
# Known languages the editor understands (match the <select> in HTML)
|
||||
_KNOWN_LANGS = {
|
||||
"python", "javascript", "typescript", "html", "css", "markdown", "json",
|
||||
"yaml", "bash", "sql", "rust", "go", "java", "c", "cpp", "xml", "toml",
|
||||
"ini", "ruby", "php", "csv", "email", "text", "plain", "svg",
|
||||
}
|
||||
|
||||
# Try XML tag extraction first
|
||||
title = None
|
||||
language = None
|
||||
content = None
|
||||
mt = _re.search(r"<title>\s*(.*?)\s*</title>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
ml = _re.search(r"<language>\s*(.*?)\s*</language>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
mc = _re.search(r"<content>\s*(.*?)\s*</content>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
if mt or mc:
|
||||
title = mt.group(1).strip() if mt else None
|
||||
language = ml.group(1).strip().lower() if ml else None
|
||||
content = mc.group(1) if mc else None
|
||||
|
||||
# Fall back to line-based parsing. First strip any stray XML-ish tags.
|
||||
if title is None or content is None:
|
||||
cleaned = _re.sub(r"</?(?:title|language|content)>", "", raw)
|
||||
lines = cleaned.strip().split("\n")
|
||||
if title is None:
|
||||
title = lines[0].strip() if lines else "Untitled"
|
||||
lines = lines[1:]
|
||||
# Only consume second line as language if it looks like a valid short lang token
|
||||
if language is None and lines:
|
||||
candidate = lines[0].strip().lower()
|
||||
if candidate and len(candidate) < 20 and " " not in candidate and candidate in _KNOWN_LANGS:
|
||||
language = candidate
|
||||
lines = lines[1:]
|
||||
if content is None:
|
||||
content = "\n".join(lines)
|
||||
|
||||
# Validate language: must be in known set, else default based on content
|
||||
if language and language not in _KNOWN_LANGS:
|
||||
language = None
|
||||
if not language:
|
||||
# No explicit language — sniff it from the content so an SVG / HTML / JSON
|
||||
# / code document isn't silently saved as markdown. Prose → markdown.
|
||||
language = _sniff_doc_language(content)
|
||||
if _looks_like_email_document(content, title):
|
||||
language = "email"
|
||||
|
||||
if not title:
|
||||
title = "Untitled"
|
||||
|
||||
if not session_id:
|
||||
return {"error": "No session context for document creation"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc_id = str(uuid.uuid4())
|
||||
ver_id = str(uuid.uuid4())
|
||||
|
||||
# Inherit ownership from the chat session so the doc survives that
|
||||
# session later being deleted (session_id → NULL).
|
||||
_sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if owner is not None and (not _sess or _sess.owner != owner):
|
||||
return {"error": "Cannot create document in another user's session"}
|
||||
_owner = _sess.owner if _sess else None
|
||||
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
session_id=session_id,
|
||||
title=title,
|
||||
language=language,
|
||||
current_content=content,
|
||||
version_count=1,
|
||||
is_active=True,
|
||||
owner=_owner,
|
||||
)
|
||||
ver = DocumentVersion(
|
||||
id=ver_id,
|
||||
document_id=doc_id,
|
||||
version_number=1,
|
||||
content=content,
|
||||
summary=f"Created by {_active_model or 'AI'}",
|
||||
source="ai",
|
||||
)
|
||||
db.add(doc)
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
set_active_document(doc_id)
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("document_created", _owner)
|
||||
except Exception:
|
||||
logger.debug("document_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {
|
||||
"action": "create",
|
||||
"doc_id": doc_id,
|
||||
"title": title,
|
||||
"language": language,
|
||||
"content": content,
|
||||
"version": 1,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to create document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
class UpdateDocumentTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
"""Update an existing document. Content = full new document text."""
|
||||
import uuid
|
||||
from src.database import SessionLocal, Document, DocumentVersion
|
||||
|
||||
target_id = ctx.get("doc_id", None) or _active_document_id
|
||||
owner = ctx.get("owner")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = None
|
||||
if target_id:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
doc = _most_recent_owned_document(db, Document, owner)
|
||||
if doc:
|
||||
target_id = doc.id
|
||||
set_active_document(target_id)
|
||||
logger.info(f"update_document: fell back to most recent doc id={target_id}")
|
||||
if not doc:
|
||||
return {"error": "No documents exist to update"}
|
||||
|
||||
is_email_doc = doc.language == "email" or _looks_like_email_document(doc.current_content or "", doc.title or "")
|
||||
new_content = _coerce_email_document_content(doc.current_content or "", content) if is_email_doc else content.strip()
|
||||
if is_email_doc:
|
||||
doc.language = "email"
|
||||
|
||||
new_ver = doc.version_count + 1
|
||||
ver = DocumentVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
document_id=target_id,
|
||||
version_number=new_ver,
|
||||
content=new_content,
|
||||
summary=f"Updated by {_active_model or 'AI'}",
|
||||
source="ai",
|
||||
)
|
||||
doc.current_content = new_content
|
||||
doc.version_count = new_ver
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"action": "update",
|
||||
"doc_id": target_id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"content": new_content,
|
||||
"version": new_ver,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to update document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
class EditDocumentTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
"""Apply targeted FIND/REPLACE edits to an existing document."""
|
||||
import uuid
|
||||
from src.database import SessionLocal, Document, DocumentVersion
|
||||
|
||||
target_id = ctx.get("doc_id", None) or _active_document_id
|
||||
owner = ctx.get("owner")
|
||||
|
||||
edits = parse_edit_blocks(content)
|
||||
if not edits:
|
||||
return {"error": "No valid <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks found"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = None
|
||||
if target_id:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
# Fallback: most recently updated document. Avoids "no active doc" errors
|
||||
# after server restart or when the agent loses track of which doc to edit.
|
||||
doc = _most_recent_owned_document(db, Document, owner)
|
||||
if doc:
|
||||
target_id = doc.id
|
||||
set_active_document(target_id)
|
||||
logger.info(f"edit_document: fell back to most recent doc id={target_id} title={doc.title!r}")
|
||||
if not doc:
|
||||
return {"error": "No documents exist to edit"}
|
||||
|
||||
updated_content = doc.current_content
|
||||
applied = 0
|
||||
skipped = 0
|
||||
for edit in edits:
|
||||
_find = edit["find"]
|
||||
if _find in updated_content:
|
||||
updated_content = updated_content.replace(_find, edit["replace"], 1)
|
||||
applied += 1
|
||||
else:
|
||||
# Defensive: the active-doc context shows a "N\t" line-number
|
||||
# gutter for reference. Weaker models sometimes copy that prefix
|
||||
# into FIND. If the exact match failed, retry with a leading
|
||||
# "<digits><tab>" stripped from each FIND line — but only use it
|
||||
# when that stripped form actually matches, so we never corrupt a
|
||||
# legitimately tab-prefixed document.
|
||||
_stripped = "\n".join(re.sub(r"^\d+\t", "", _l) for _l in _find.split("\n"))
|
||||
if _stripped != _find and _stripped in updated_content:
|
||||
updated_content = updated_content.replace(_stripped, edit["replace"], 1)
|
||||
applied += 1
|
||||
logger.info("edit_document: matched after stripping line-number gutter from FIND")
|
||||
else:
|
||||
logger.warning(f"edit_document: FIND text not found, skipping: {_find[:80]!r}")
|
||||
skipped += 1
|
||||
|
||||
if applied == 0:
|
||||
return {"error": f"No edits applied — none of the FIND blocks matched the document content (skipped {skipped})"}
|
||||
|
||||
new_ver = doc.version_count + 1
|
||||
ver = DocumentVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
document_id=target_id,
|
||||
version_number=new_ver,
|
||||
content=updated_content,
|
||||
summary=f"Edited by {_active_model or 'AI'} ({applied} edit(s))",
|
||||
source="ai",
|
||||
)
|
||||
doc.current_content = updated_content
|
||||
doc.version_count = new_ver
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"action": "edit",
|
||||
"doc_id": target_id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"content": updated_content,
|
||||
"version": new_ver,
|
||||
"applied": applied,
|
||||
"skipped": skipped,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to edit document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
class SuggestDocumentTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
"""Create inline suggestions for the active document WITHOUT modifying it."""
|
||||
from src.database import SessionLocal, Document
|
||||
|
||||
target_id = ctx.get("doc_id", None) or _active_document_id
|
||||
owner = ctx.get("owner")
|
||||
|
||||
if not target_id:
|
||||
return {"error": "No active document to suggest on"}
|
||||
|
||||
suggestions = parse_suggest_blocks(content)
|
||||
if not suggestions:
|
||||
return {"error": "No valid <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks found"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
return {"error": f"Document {target_id} not found"}
|
||||
|
||||
# Validate that FIND text exists in document
|
||||
valid = []
|
||||
for s in suggestions:
|
||||
if s["find"] in doc.current_content:
|
||||
valid.append(s)
|
||||
else:
|
||||
logger.warning(f"suggest_document: FIND text not found, skipping: {s['find'][:80]!r}")
|
||||
|
||||
if not valid:
|
||||
return {"error": "No suggestions matched the document content"}
|
||||
|
||||
return {
|
||||
"action": "suggest",
|
||||
"doc_id": target_id,
|
||||
"suggestions": valid,
|
||||
"count": len(valid),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document management tool (delete, list, organize)
|
||||
# ---------------------------------------------------------------------------
|
||||
class ManageDocumentTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
"""Manage documents: list, read/view/open, delete, tidy.
|
||||
|
||||
Output format mirrors `manage_session`: list rows include a
|
||||
clickable `[Title](#document-<id>)` anchor + relative timestamps
|
||||
so the user can click straight from chat to open the editor.
|
||||
"""
|
||||
from core.database import SessionLocal, Document
|
||||
from datetime import datetime, timezone
|
||||
|
||||
owner = ctx.get("owner")
|
||||
|
||||
try:
|
||||
args = _parse_tool_args(content)
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
action = args.get("action", "list")
|
||||
db = SessionLocal()
|
||||
|
||||
def _rel(ts):
|
||||
if not ts:
|
||||
return 'never'
|
||||
try:
|
||||
now = datetime.now(timezone.utc) if ts.tzinfo is not None else datetime.utcnow()
|
||||
diff = (now - ts).total_seconds()
|
||||
except Exception:
|
||||
return 'unknown'
|
||||
if diff < 60: return 'just now'
|
||||
if diff < 3600: return f'{int(diff / 60)}m ago'
|
||||
if diff < 86400: return f'{int(diff / 3600)}h ago'
|
||||
if diff < 86400 * 7: return f'{int(diff / 86400)}d ago'
|
||||
return ts.strftime('%Y-%m-%d')
|
||||
|
||||
try:
|
||||
if action == "list":
|
||||
q = db.query(Document).filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
if args.get("search"):
|
||||
q = q.filter(Document.title.ilike(f"%{args['search']}%"))
|
||||
if args.get("language"):
|
||||
q = q.filter(Document.language == args["language"])
|
||||
docs = q.order_by(Document.updated_at.desc()).limit(args.get("limit", 50)).all()
|
||||
if not docs:
|
||||
msg = "No documents found" + (f" matching '{args['search']}'" if args.get("search") else "") + "."
|
||||
return {"response": msg, "documents": [], "exit_code": 0}
|
||||
lines = []
|
||||
items = []
|
||||
for i, d in enumerate(docs):
|
||||
size = len(d.current_content or "")
|
||||
lang = d.language or "text"
|
||||
ts = getattr(d, 'updated_at', None) or getattr(d, 'created_at', None)
|
||||
marker = " ← most recent" if i == 0 else ""
|
||||
lines.append(
|
||||
f"- [{d.title}](#document-{d.id}) — {lang}, {size} chars, updated {_rel(ts)}{marker}"
|
||||
)
|
||||
items.append({"id": d.id, "title": d.title, "language": lang, "size": size})
|
||||
header = f"Found {len(docs)} document(s), sorted most-recent first. Click a title to open:"
|
||||
return {
|
||||
"response": header + "\n" + "\n".join(lines),
|
||||
"documents": items,
|
||||
"exit_code": 0,
|
||||
}
|
||||
|
||||
elif action in ("read", "view", "open", "get"):
|
||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid")
|
||||
if not doc_id:
|
||||
return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
|
||||
doc = _get_owned_document(db, Document, doc_id, owner, active_only=True)
|
||||
if not doc:
|
||||
return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
|
||||
body = doc.current_content or ""
|
||||
preview_limit = int(args.get("limit", MAX_READ_CHARS))
|
||||
truncated = len(body) > preview_limit
|
||||
preview = body[:preview_limit] + (f"\n... (truncated, {len(body)} chars total)" if truncated else "")
|
||||
anchor = f"[{doc.title}](#document-{doc.id})"
|
||||
return {
|
||||
"response": f"{anchor} — click to open in editor.\n\n```{doc.language or ''}\n{preview}\n```",
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"size": len(body),
|
||||
"content": preview,
|
||||
"truncated": truncated,
|
||||
},
|
||||
"exit_code": 0,
|
||||
}
|
||||
|
||||
elif action == "delete":
|
||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
|
||||
doc = None
|
||||
if doc_id:
|
||||
doc = _get_owned_document(db, Document, doc_id, owner)
|
||||
if not doc:
|
||||
# Fallback: most recently updated doc (likely what the user means)
|
||||
doc = _most_recent_owned_document(db, Document, owner, active_only=True)
|
||||
if not doc:
|
||||
return {"error": "No document to delete", "exit_code": 1}
|
||||
title = doc.title
|
||||
doc.is_active = False
|
||||
db.commit()
|
||||
if _active_document_id == doc.id:
|
||||
set_active_document(None)
|
||||
return {"response": f"Deleted document '{title}'", "exit_code": 0}
|
||||
|
||||
elif action == "tidy":
|
||||
from src.document_actions import run_document_tidy
|
||||
result = await run_document_tidy(owner or "")
|
||||
return {"response": result, "exit_code": 0}
|
||||
|
||||
else:
|
||||
return {"error": f"Unknown action: {action}", "exit_code": 1}
|
||||
except Exception as e:
|
||||
logger.error(f"manage_documents error: {e}")
|
||||
return {"error": str(e), "exit_code": 1}
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,398 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import difflib
|
||||
import fnmatch
|
||||
import shutil
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
|
||||
from src.constants import MAX_READ_CHARS, MAX_DIFF_LINES, MAX_OUTPUT_CHARS
|
||||
|
||||
_CODENAV_SKIP_DIRS = frozenset({
|
||||
".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
|
||||
".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
|
||||
".next", ".cache", "site-packages", ".idea", ".tox",
|
||||
})
|
||||
_CODENAV_MAX_HITS = 200
|
||||
_CODENAV_MAX_LINE = 400
|
||||
|
||||
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
|
||||
if old == new:
|
||||
return None
|
||||
old_lines = old.splitlines()
|
||||
new_lines = new.splitlines()
|
||||
label = path or "file"
|
||||
diff_lines = list(difflib.unified_diff(
|
||||
old_lines, new_lines,
|
||||
fromfile=f"a/{label}", tofile=f"b/{label}",
|
||||
lineterm="",
|
||||
))
|
||||
added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++"))
|
||||
removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---"))
|
||||
truncated = False
|
||||
if len(diff_lines) > MAX_DIFF_LINES:
|
||||
diff_lines = diff_lines[:MAX_DIFF_LINES]
|
||||
truncated = True
|
||||
text = "\n".join(diff_lines)
|
||||
if truncated:
|
||||
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
|
||||
return {
|
||||
"text": text,
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"new_file": old == "",
|
||||
"file": os.path.basename(path) or (path or "file"),
|
||||
}
|
||||
|
||||
class EditFileTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
try:
|
||||
args = json.loads(content) if content.strip().startswith("{") else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
raw_path = (args.get("path") or "").strip()
|
||||
old = args.get("old_string", "")
|
||||
new = args.get("new_string", "")
|
||||
replace_all = bool(args.get("replace_all", False))
|
||||
if not raw_path:
|
||||
return {"error": "edit_file: path required", "exit_code": 1}
|
||||
try:
|
||||
path = _resolve_tool_path(raw_path)
|
||||
except ValueError as e:
|
||||
return {"error": f"edit_file: {e}", "exit_code": 1}
|
||||
if old == "":
|
||||
return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
|
||||
if old == new:
|
||||
return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
|
||||
|
||||
def _apply():
|
||||
"""Helper function that performs the actual string replacement and file writing logic."""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
original = f.read()
|
||||
count = original.count(old)
|
||||
if count == 0:
|
||||
return original, None, "not_found"
|
||||
if count > 1 and not replace_all:
|
||||
return original, None, f"not_unique:{count}"
|
||||
updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(updated)
|
||||
return original, updated, "ok"
|
||||
|
||||
try:
|
||||
original, updated, status = await asyncio.to_thread(_apply)
|
||||
except FileNotFoundError:
|
||||
return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
|
||||
except (IsADirectoryError, UnicodeDecodeError):
|
||||
return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
|
||||
except PermissionError:
|
||||
return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
|
||||
|
||||
if status == "not_found":
|
||||
return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
|
||||
if status.startswith("not_unique"):
|
||||
n = status.split(":", 1)[1]
|
||||
return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
|
||||
|
||||
n = original.count(old)
|
||||
result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
|
||||
diff = _unified_diff(original, updated, path)
|
||||
if diff:
|
||||
result["diff"] = diff
|
||||
return result
|
||||
|
||||
class ReadFileTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
|
||||
_stripped = content.strip()
|
||||
if _stripped.startswith("{"):
|
||||
try:
|
||||
_a = json.loads(_stripped)
|
||||
raw_path = str(_a.get("path", "")).strip()
|
||||
offset = int(_a.get("offset") or 0)
|
||||
limit = int(_a.get("limit") or 0)
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
try:
|
||||
path = _resolve_tool_path(raw_path)
|
||||
except ValueError as e:
|
||||
return {"error": f"read_file: {e}", "exit_code": 1}
|
||||
try:
|
||||
def _read():
|
||||
if offset > 0 or limit > 0:
|
||||
start = max(offset, 1)
|
||||
out, n, budget = [], 0, MAX_READ_CHARS
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
for i, line in enumerate(f, 1):
|
||||
if i < start:
|
||||
continue
|
||||
if limit > 0 and n >= limit:
|
||||
break
|
||||
out.append(line)
|
||||
n += 1
|
||||
budget -= len(line)
|
||||
if budget <= 0:
|
||||
out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
|
||||
break
|
||||
return "".join(out)
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read(MAX_READ_CHARS + 1)
|
||||
data = await asyncio.to_thread(_read)
|
||||
except FileNotFoundError:
|
||||
return {"error": f"read_file: {path}: not found", "exit_code": 1}
|
||||
except PermissionError:
|
||||
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
|
||||
except IsADirectoryError:
|
||||
return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
|
||||
if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
|
||||
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
|
||||
return {"output": data, "exit_code": 0}
|
||||
|
||||
class WriteFileTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
lines = content.split("\n", 1)
|
||||
raw_path = lines[0].strip()
|
||||
body = lines[1] if len(lines) > 1 else ""
|
||||
try:
|
||||
path = _resolve_tool_path(raw_path)
|
||||
except ValueError as e:
|
||||
return {"error": f"write_file: {e}", "exit_code": 1}
|
||||
try:
|
||||
def _write():
|
||||
old = ""
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
old = f.read()
|
||||
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
|
||||
old = ""
|
||||
d = os.path.dirname(path)
|
||||
if d:
|
||||
os.makedirs(d, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(body)
|
||||
return old, len(body)
|
||||
old_content, size = await asyncio.to_thread(_write)
|
||||
except PermissionError:
|
||||
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
|
||||
diff = _unified_diff(old_content, body, path)
|
||||
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
|
||||
if diff:
|
||||
result["diff"] = diff
|
||||
return result
|
||||
|
||||
class LsTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
raw_path = ""
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
raw_path = str(json.loads(_s).get("path", "")).strip()
|
||||
except json.JSONDecodeError:
|
||||
raw_path = ""
|
||||
else:
|
||||
raw_path = _s.split("\n", 1)[0].strip()
|
||||
try:
|
||||
root = _resolve_search_root(raw_path)
|
||||
except ValueError as e:
|
||||
return {"error": f"ls: {e}", "exit_code": 1}
|
||||
|
||||
def _ls():
|
||||
if not os.path.isdir(root):
|
||||
return None, f"ls: {root}: not a directory"
|
||||
rows = []
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.name.startswith("."):
|
||||
continue
|
||||
try:
|
||||
is_dir = entry.is_dir(follow_symlinks=False)
|
||||
size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0
|
||||
except OSError:
|
||||
continue
|
||||
rows.append((is_dir, entry.name, size))
|
||||
except (PermissionError, OSError) as _e:
|
||||
return None, f"ls: {_e}"
|
||||
rows.sort(key=lambda r: (not r[0], r[1].lower()))
|
||||
lines = [f"{root}:"]
|
||||
for is_dir, name, size in rows[:_CODENAV_MAX_HITS]:
|
||||
lines.append(f" {name}/" if is_dir else f" {name} ({size} B)")
|
||||
if len(rows) > _CODENAV_MAX_HITS:
|
||||
lines.append(f" ... [{len(rows) - _CODENAV_MAX_HITS} more]")
|
||||
if not rows:
|
||||
lines.append(" (empty)")
|
||||
return "\n".join(lines), None
|
||||
|
||||
out, err = await asyncio.to_thread(_ls)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
class GlobTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
args = {}
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
args = json.loads(_s)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = {"pattern": _s}
|
||||
pattern = str(args.get("pattern", "")).strip()
|
||||
if not pattern:
|
||||
return {"error": "glob: pattern is required", "exit_code": 1}
|
||||
try:
|
||||
root = _resolve_search_root(str(args.get("path", "")))
|
||||
except ValueError as e:
|
||||
return {"error": f"glob: {e}", "exit_code": 1}
|
||||
|
||||
def _glob():
|
||||
from pathlib import Path
|
||||
base = Path(root)
|
||||
if not base.is_dir():
|
||||
return None, f"glob: {root}: not a directory"
|
||||
matched = []
|
||||
try:
|
||||
for p in base.rglob(pattern):
|
||||
if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
|
||||
continue
|
||||
try:
|
||||
mtime = p.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0
|
||||
matched.append((mtime, str(p)))
|
||||
if len(matched) > _CODENAV_MAX_HITS * 5:
|
||||
break
|
||||
except (OSError, ValueError) as _e:
|
||||
return None, f"glob: {_e}"
|
||||
matched.sort(key=lambda t: t[0], reverse=True)
|
||||
return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
|
||||
|
||||
paths, err = await asyncio.to_thread(_glob)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
if not paths:
|
||||
return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
|
||||
out = "\n".join(paths)
|
||||
if len(paths) >= _CODENAV_MAX_HITS:
|
||||
out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
class GrepTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import _resolve_tool_path, _resolve_search_root, _truncate
|
||||
args: Dict[str, Any] = {}
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
args = json.loads(_s)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = {"pattern": _s}
|
||||
pattern = str(args.get("pattern", "")).strip()
|
||||
if not pattern:
|
||||
return {"error": "grep: pattern is required", "exit_code": 1}
|
||||
ignore_case = bool(args.get("ignore_case"))
|
||||
glob_pat = str(args.get("glob", "") or "").strip()
|
||||
try:
|
||||
max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
|
||||
except (TypeError, ValueError):
|
||||
max_hits = _CODENAV_MAX_HITS
|
||||
max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
|
||||
try:
|
||||
root = _resolve_search_root(str(args.get("path", "")))
|
||||
except ValueError as e:
|
||||
return {"error": f"grep: {e}", "exit_code": 1}
|
||||
|
||||
def _grep():
|
||||
import re as _re
|
||||
import shutil
|
||||
rg = shutil.which("rg")
|
||||
if rg:
|
||||
cmd = [rg, "--line-number", "--no-heading", "--color=never",
|
||||
"--max-count", str(max_hits)]
|
||||
if ignore_case:
|
||||
cmd.append("--ignore-case")
|
||||
if glob_pat:
|
||||
cmd += ["--glob", glob_pat]
|
||||
for _d in _CODENAV_SKIP_DIRS:
|
||||
cmd += ["--glob", f"!**/{_d}/**"]
|
||||
cmd += ["--regexp", pattern, root]
|
||||
try:
|
||||
import subprocess
|
||||
p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
|
||||
lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
|
||||
return lines, None
|
||||
except subprocess.TimeoutExpired:
|
||||
return None, "grep: timed out"
|
||||
except Exception as _e:
|
||||
return None, f"grep: {_e}"
|
||||
try:
|
||||
rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
|
||||
except _re.error as _e:
|
||||
return None, f"grep: bad pattern: {_e}"
|
||||
hits = []
|
||||
if os.path.isfile(root):
|
||||
file_iter = [root]
|
||||
else:
|
||||
file_iter = []
|
||||
for dp, dns, fns in os.walk(root):
|
||||
dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
|
||||
for fn in fns:
|
||||
if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
|
||||
continue
|
||||
file_iter.append(os.path.join(dp, fn))
|
||||
for fp in file_iter:
|
||||
if len(hits) >= max_hits:
|
||||
break
|
||||
try:
|
||||
with open(fp, "r", encoding="utf-8", errors="strict") as f:
|
||||
for i, line in enumerate(f, 1):
|
||||
if rx.search(line):
|
||||
hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
|
||||
if len(hits) >= max_hits:
|
||||
break
|
||||
except (UnicodeDecodeError, OSError):
|
||||
continue
|
||||
return hits, None
|
||||
|
||||
lines, err = await asyncio.to_thread(_grep)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
if not lines:
|
||||
return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
|
||||
out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
|
||||
if len(lines) >= max_hits:
|
||||
out += f"\n... [capped at {max_hits} matches]"
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
class GetWorkspaceTool:
|
||||
"""Report the active workspace folder (no args). File tools are confined to
|
||||
it; the shell starts there (cwd) but is NOT sandboxed."""
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import get_active_workspace
|
||||
ws = get_active_workspace()
|
||||
if ws:
|
||||
return {
|
||||
"output": f"{ws}\n(File tools are confined to this folder; the shell starts "
|
||||
f"here but is not sandboxed and can reach outside it.)",
|
||||
"exit_code": 0,
|
||||
}
|
||||
return {
|
||||
"output": "No workspace is set. File tools use the default allowed roots; "
|
||||
"resolve paths from the user or use absolute paths.",
|
||||
"exit_code": 0,
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import collections
|
||||
from typing import Optional, Callable, Awaitable, Tuple, Dict
|
||||
from src.constants import MAX_OUTPUT_CHARS
|
||||
|
||||
DEFAULT_BASH_TIMEOUT = 60 * 60 # 1 hour
|
||||
DEFAULT_PYTHON_TIMEOUT = 60 * 60
|
||||
|
||||
PROGRESS_INTERVAL_S = 2.0
|
||||
PROGRESS_TAIL_LINES = 12
|
||||
|
||||
async def _run_subprocess_streaming(
|
||||
proc: asyncio.subprocess.Process,
|
||||
*,
|
||||
timeout: float,
|
||||
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
) -> Tuple[str, str, Optional[int], bool]:
|
||||
started = time.time()
|
||||
stdout_full: list[str] = []
|
||||
stderr_full: list[str] = []
|
||||
tail = collections.deque(maxlen=PROGRESS_TAIL_LINES)
|
||||
|
||||
async def _reader(stream, full_buf, label: str):
|
||||
if stream is None:
|
||||
return
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
decoded = line.decode("utf-8", errors="replace").rstrip("\n")
|
||||
full_buf.append(decoded)
|
||||
if label == "err":
|
||||
tail.append(f"! {decoded}")
|
||||
else:
|
||||
tail.append(decoded)
|
||||
|
||||
async def _progress_emitter():
|
||||
await asyncio.sleep(PROGRESS_INTERVAL_S)
|
||||
while True:
|
||||
if progress_cb:
|
||||
try:
|
||||
await progress_cb({
|
||||
"elapsed_s": round(time.time() - started, 1),
|
||||
"tail": "\n".join(list(tail)),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(PROGRESS_INTERVAL_S)
|
||||
|
||||
rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out"))
|
||||
rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err"))
|
||||
prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None
|
||||
|
||||
timed_out = False
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2)
|
||||
except Exception:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
try:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await asyncio.wait_for(proc.wait(), timeout=2)
|
||||
except Exception:
|
||||
pass
|
||||
for t in (rd_out, rd_err):
|
||||
t.cancel()
|
||||
if prog_task is not None:
|
||||
prog_task.cancel()
|
||||
raise
|
||||
finally:
|
||||
if prog_task is not None and not prog_task.done():
|
||||
prog_task.cancel()
|
||||
try:
|
||||
await prog_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
for t in (rd_out, rd_err):
|
||||
try:
|
||||
await asyncio.wait_for(t, timeout=1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (
|
||||
"\n".join(stdout_full),
|
||||
"\n".join(stderr_full),
|
||||
proc.returncode,
|
||||
timed_out,
|
||||
)
|
||||
|
||||
class BashTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import agent_cwd, _truncate
|
||||
progress_cb = ctx.get("progress_cb")
|
||||
_subproc_env = ctx.get("subproc_env")
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
content,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=_subproc_env,
|
||||
cwd=agent_cwd(),
|
||||
)
|
||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||
proc,
|
||||
timeout=DEFAULT_BASH_TIMEOUT,
|
||||
progress_cb=progress_cb,
|
||||
)
|
||||
if timed_out:
|
||||
return {"error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
|
||||
output = stdout.rstrip()
|
||||
err = stderr.rstrip()
|
||||
if err:
|
||||
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
|
||||
output = _truncate(output, MAX_OUTPUT_CHARS)
|
||||
return {"output": output or "(no output)", "exit_code": rc or 0}
|
||||
|
||||
class PythonTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.tool_execution import agent_cwd, _truncate
|
||||
progress_cb = ctx.get("progress_cb")
|
||||
_subproc_env = ctx.get("subproc_env")
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
(sys.executable or "python"), "-I", "-c", content,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=_subproc_env,
|
||||
cwd=agent_cwd(),
|
||||
)
|
||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||
proc,
|
||||
timeout=DEFAULT_PYTHON_TIMEOUT,
|
||||
progress_cb=progress_cb,
|
||||
)
|
||||
if timed_out:
|
||||
return {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)}
|
||||
output = stdout.rstrip()
|
||||
err = stderr.rstrip()
|
||||
if err:
|
||||
output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err
|
||||
output = _truncate(output, MAX_OUTPUT_CHARS)
|
||||
return {"output": output or "(no output)", "exit_code": rc or 0}
|
||||
@@ -0,0 +1,101 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.constants import MAX_OUTPUT_CHARS
|
||||
|
||||
class WebSearchTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.search import comprehensive_web_search
|
||||
raw = content.strip()
|
||||
query = raw
|
||||
time_filter = None
|
||||
max_pages = 5
|
||||
if raw.startswith("{"):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict) and "query" in parsed:
|
||||
query = str(parsed.get("query", "")).strip()
|
||||
tf = parsed.get("time_filter") or parsed.get("freshness")
|
||||
if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"):
|
||||
time_filter = tf.lower()
|
||||
mp = parsed.get("max_pages")
|
||||
if isinstance(mp, int) and 1 <= mp <= 10:
|
||||
max_pages = mp
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
if not query:
|
||||
query = raw.split("\n")[0].strip()
|
||||
if time_filter is None:
|
||||
q_lc = query.lower()
|
||||
if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")):
|
||||
time_filter = "day"
|
||||
elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")):
|
||||
time_filter = "week"
|
||||
elif any(kw in q_lc for kw in ("this month", "past month")):
|
||||
time_filter = "month"
|
||||
elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"):
|
||||
time_filter = "week"
|
||||
loop = asyncio.get_running_loop()
|
||||
text, sources = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
None,
|
||||
lambda: comprehensive_web_search(
|
||||
query,
|
||||
max_pages=max_pages,
|
||||
time_filter=time_filter,
|
||||
return_sources=True,
|
||||
),
|
||||
),
|
||||
timeout=30,
|
||||
)
|
||||
output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text
|
||||
if sources:
|
||||
output += "\n\n<!-- SOURCES:" + json.dumps(sources) + " -->"
|
||||
return {"output": output, "exit_code": 0}
|
||||
|
||||
class WebFetchTool:
|
||||
async def execute(self, content: str, ctx: dict) -> dict:
|
||||
from src.search.content import fetch_webpage_content
|
||||
raw = content.strip()
|
||||
url = ""
|
||||
if raw.startswith("{"):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
url = str(parsed.get("url") or "").strip()
|
||||
except json.JSONDecodeError:
|
||||
url = ""
|
||||
if not url:
|
||||
url = raw.split("\n")[0].strip()
|
||||
if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")):
|
||||
return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1}
|
||||
low = url.lower()
|
||||
if "://" in low and not low.startswith(("http://", "https://")):
|
||||
return {"error": f"web_fetch: unsupported URL scheme (only http/https): {url[:80]}", "exit_code": 1}
|
||||
if not low.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)),
|
||||
timeout=30,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1}
|
||||
except Exception as e:
|
||||
return {"error": f"web_fetch: {url}: {e}", "exit_code": 1}
|
||||
err = result.get("error")
|
||||
text = (result.get("content") or "").strip()
|
||||
title = result.get("title") or ""
|
||||
|
||||
if not text:
|
||||
if err:
|
||||
return {"error": f"web_fetch: {url}: {err}", "exit_code": 1}
|
||||
return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1}
|
||||
|
||||
header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n"
|
||||
output = header + text
|
||||
if len(output) > MAX_OUTPUT_CHARS:
|
||||
output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]"
|
||||
return {"output": output, "exit_code": 0}
|
||||
@@ -24,7 +24,9 @@ MAX_PIPELINE_STEPS = 10
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Global managers (set from app.py, same pattern as _mcp_manager)
|
||||
# ---------------------------------------------------------------------------
|
||||
# _session_manager is kept as a local cache for performance (avoiding
|
||||
# repeated get_session_manager_instance() calls). It's synced with
|
||||
# the authoritative singleton in core.models.
|
||||
_session_manager = None
|
||||
_memory_manager = None
|
||||
_memory_vector = None
|
||||
@@ -33,11 +35,15 @@ _personal_docs_manager = None
|
||||
|
||||
|
||||
def set_session_manager(mgr):
|
||||
"""Set the global session manager. Syncs local cache + core singleton."""
|
||||
global _session_manager
|
||||
_session_manager = mgr
|
||||
from core.models import set_session_manager_instance
|
||||
set_session_manager_instance(mgr)
|
||||
|
||||
|
||||
def get_session_manager():
|
||||
"""Get the global session manager."""
|
||||
return _session_manager
|
||||
|
||||
|
||||
@@ -966,16 +972,15 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner
|
||||
memories = [m for m in memories if m.get("category", "").lower() == category_filter]
|
||||
if not memories:
|
||||
return {"results": "No memories found" + (f" in category '{category_filter}'" if category_filter else "") + "."}
|
||||
|
||||
result_lines = [f"Found {len(memories)} memory entries:\n"]
|
||||
for m in memories[:100]:
|
||||
for m in memories:
|
||||
cat = m.get("category", "fact")
|
||||
mid = m.get("id", "?")[:8]
|
||||
text = m.get("text", "")
|
||||
if len(text) > 150:
|
||||
text = text[:150] + "..."
|
||||
result_lines.append(f"- [{cat}] `{mid}` — {text}")
|
||||
if len(memories) > 100:
|
||||
result_lines.append(f"... and {len(memories) - 100} more")
|
||||
return {"results": "\n".join(result_lines)}
|
||||
|
||||
elif action == "add":
|
||||
|
||||
+16
-2
@@ -4,6 +4,8 @@ import logging
|
||||
from typing import Dict
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from core.platform_compat import safe_chmod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class APIKeyManager:
|
||||
@@ -15,12 +17,20 @@ class APIKeyManager:
|
||||
def get_or_create_key(self) -> bytes:
|
||||
"""Get or create encryption key for API keys"""
|
||||
if os.path.exists(self.key_file):
|
||||
# Older versions wrote .key with the process umask (often 0o644,
|
||||
# i.e. group/world-readable). Re-restrict on read so existing
|
||||
# installs heal without needing the key to be regenerated.
|
||||
safe_chmod(self.key_file, 0o600)
|
||||
with open(self.key_file, 'rb') as f:
|
||||
return f.read()
|
||||
else:
|
||||
key = Fernet.generate_key()
|
||||
with open(self.key_file, 'wb') as f:
|
||||
f.write(key)
|
||||
# This key decrypts every stored provider credential, so restrict it
|
||||
# to the owner (0o600) — it must not be group/world-readable. No-op
|
||||
# on Windows (files there are ACL-restricted to the user already).
|
||||
safe_chmod(self.key_file, 0o600)
|
||||
return key
|
||||
|
||||
def encrypt_api_key(self, api_key: str) -> str:
|
||||
@@ -57,7 +67,12 @@ class APIKeyManager:
|
||||
# Legacy/wrong shape (e.g. a list) — .items() would raise. Ignore it.
|
||||
logger.warning("API keys file has unexpected shape (%s); ignoring", type(encrypted_keys).__name__)
|
||||
return {}
|
||||
return encrypted_keys
|
||||
|
||||
return {
|
||||
str(provider): key
|
||||
for provider, key in encrypted_keys.items()
|
||||
if isinstance(key, str)
|
||||
}
|
||||
|
||||
def save(self, provider: str, api_key: str):
|
||||
"""Save encrypted API key to file.
|
||||
@@ -82,4 +97,3 @@ class APIKeyManager:
|
||||
except (InvalidToken, ValueError) as e:
|
||||
logger.warning("Failed to decrypt API key for %s: %s", provider, e)
|
||||
return decrypted
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ async def _drain_agent(sess, messages):
|
||||
if "delta" in d:
|
||||
delta = d.get("delta")
|
||||
if isinstance(delta, str):
|
||||
if d.get("thinking"):
|
||||
continue
|
||||
full += delta
|
||||
elif d.get("type") == "agent_step":
|
||||
round_num = d.get("round", round_num)
|
||||
|
||||
+32
-16
@@ -579,6 +579,24 @@ def _classify_event_heuristic(summary: str) -> tuple:
|
||||
return etype, None
|
||||
|
||||
|
||||
def _memory_context_lines(mems, limit: int = 40) -> list:
|
||||
"""Render Memory rows into short personal-context bullets for event classify.
|
||||
|
||||
Reads the Memory ORM `text` column. The previous inline code read a
|
||||
non-existent `content` attribute, so it raised AttributeError on the first
|
||||
row, the surrounding except swallowed it, and the classifier ran with no
|
||||
personal context at all. getattr keeps it robust to future schema drift.
|
||||
"""
|
||||
lines: list = []
|
||||
for m in mems:
|
||||
c = (getattr(m, "text", "") or "").strip()
|
||||
if c:
|
||||
lines.append(f"- {c[:200]}")
|
||||
if len(lines) >= limit:
|
||||
break
|
||||
return lines
|
||||
|
||||
|
||||
async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
"""Hybrid classification of upcoming calendar events: fast heuristic for
|
||||
obvious cases, LLM fallback for ambiguous ones. Assigns event_type +
|
||||
@@ -614,16 +632,11 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
try:
|
||||
from core.database import Memory as _Mem
|
||||
_mems = db.query(_Mem).filter(_Mem.owner == owner).limit(60).all() if owner else []
|
||||
if _mems:
|
||||
_lines = []
|
||||
for m in _mems:
|
||||
c = (m.content or "").strip()
|
||||
if c:
|
||||
_lines.append(f"- {c[:200]}")
|
||||
if _lines:
|
||||
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines[:40]) + "\n\n"
|
||||
_lines = _memory_context_lines(_mems)
|
||||
if _lines:
|
||||
_memory_context = "USER CONTEXT (relationships, work, life):\n" + "\n".join(_lines) + "\n\n"
|
||||
except Exception as _me:
|
||||
logger.debug(f"Could not load memory for classify: {_me}")
|
||||
logger.warning(f"Could not load memory for classify: {_me}")
|
||||
|
||||
classified_h = 0
|
||||
classified_llm = 0
|
||||
@@ -796,14 +809,14 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
||||
import email as _email_mod
|
||||
import asyncio as _aio
|
||||
from datetime import datetime as _dt, timedelta as _td
|
||||
from routes.email_helpers import _imap_connect, SCHEDULED_DB
|
||||
from routes.email_helpers import _email_cache_owner_clause, _imap_connect, SCHEDULED_DB
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
# 1. Pull recent UIDs + From headers cheaply (header-only fetch).
|
||||
def _pull_headers():
|
||||
results = []
|
||||
conn = _imap_connect(None)
|
||||
conn = _imap_connect(None, owner=owner)
|
||||
try:
|
||||
conn.select("INBOX", readonly=True)
|
||||
status, data = conn.search(None, "ALL")
|
||||
@@ -855,9 +868,11 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
||||
# 3. Eligibility: ≥3 emails AND (no cache OR cache > 30 days old).
|
||||
try:
|
||||
conn = _sql3.connect(SCHEDULED_DB)
|
||||
owner_clause, owner_params = _email_cache_owner_clause(owner)
|
||||
cached = {
|
||||
r[0]: r[1] for r in conn.execute(
|
||||
"SELECT from_address, last_built_at FROM sender_signatures"
|
||||
f"SELECT from_address, last_built_at FROM sender_signatures WHERE {owner_clause}",
|
||||
owner_params,
|
||||
).fetchall()
|
||||
}
|
||||
conn.close()
|
||||
@@ -888,7 +903,7 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
||||
|
||||
def _fetch_bodies(_msgs):
|
||||
bodies = []
|
||||
conn2 = _imap_connect(None)
|
||||
conn2 = _imap_connect(None, owner=owner)
|
||||
try:
|
||||
conn2.select("INBOX", readonly=True)
|
||||
for mm in _msgs:
|
||||
@@ -965,11 +980,12 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
||||
|
||||
try:
|
||||
conn = _sql3.connect(SCHEDULED_DB)
|
||||
owner_value = (owner or "").strip()
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO sender_signatures "
|
||||
"(from_address, signature_text, sample_count, last_built_at, model_used, source) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(addr, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
|
||||
"(from_address, owner, signature_text, sample_count, last_built_at, model_used, source) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
(addr, owner_value, cached_sig, len(bodies), _dt.utcnow().isoformat(), model, "llm"),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
+84
-6
@@ -5,11 +5,13 @@ Auto-registration of built-in MCP servers on startup.
|
||||
Each server runs as a stdio subprocess managed by McpManager.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
from core.platform_compat import IS_WINDOWS, which_tool
|
||||
|
||||
@@ -196,18 +198,29 @@ def _npx_package_from_args(args):
|
||||
async def _is_npx_package_cached(npx_path, package_spec, timeout_s=5):
|
||||
"""Probe whether an npx package is already in the local cache.
|
||||
|
||||
Runs `npx --no-install <pkg> --version`. --no-install tells npx to
|
||||
fail instead of downloading, so a cache miss returns fast. We treat
|
||||
"exited 0 with non-empty stdout" as proof of a working cached copy.
|
||||
Anything else (non-zero exit, empty stdout, timeout, missing npx,
|
||||
network error) means we should skip the server.
|
||||
First checks the local `_npx` cache for an installed package. If the
|
||||
package is not found there, falls back to `npx --no-install <pkg>
|
||||
--version` so older npm layouts still work without downloading.
|
||||
"""
|
||||
if _is_package_in_npx_cache(package_spec):
|
||||
return True
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
npx_path, "--no-install", package_spec, "--version",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
except NotImplementedError:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[npx_path, "--no-install", package_spec, "--version"],
|
||||
capture_output=True,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, OSError, ValueError):
|
||||
return False
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
except (OSError, ValueError):
|
||||
return False
|
||||
try:
|
||||
@@ -220,3 +233,68 @@ async def _is_npx_package_cached(npx_path, package_spec, timeout_s=5):
|
||||
pass
|
||||
return False
|
||||
return proc.returncode == 0 and bool(stdout.strip())
|
||||
|
||||
|
||||
def _is_package_in_npx_cache(package_spec):
|
||||
"""Return True when npm's `_npx` cache already contains package_spec."""
|
||||
package_name = _npx_package_name(package_spec)
|
||||
if not package_name:
|
||||
return False
|
||||
|
||||
for cache_root in _npm_cache_roots():
|
||||
npx_root = os.path.join(cache_root, "_npx")
|
||||
if _npx_cache_contains_package(npx_root, package_name):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _npx_package_name(package_spec):
|
||||
"""Strip a version/range suffix from an npm package spec."""
|
||||
if not package_spec:
|
||||
return ""
|
||||
if package_spec.startswith("@"):
|
||||
parts = package_spec.split("@", 2)
|
||||
if len(parts) >= 3:
|
||||
return f"@{parts[1]}"
|
||||
return package_spec
|
||||
return package_spec.split("@", 1)[0]
|
||||
|
||||
|
||||
def _npm_cache_roots():
|
||||
roots = []
|
||||
configured = os.environ.get("npm_config_cache")
|
||||
if configured:
|
||||
roots.append(os.path.expanduser(configured))
|
||||
roots.append(os.path.join(os.path.expanduser("~"), ".npm"))
|
||||
local_app_data = os.environ.get("LOCALAPPDATA")
|
||||
if local_app_data:
|
||||
roots.append(os.path.join(local_app_data, "npm-cache"))
|
||||
return list(dict.fromkeys(roots))
|
||||
|
||||
|
||||
def _npx_cache_contains_package(npx_root, package_name):
|
||||
if not os.path.isdir(npx_root):
|
||||
return False
|
||||
package_path = os.path.join("node_modules", *package_name.split("/"), "package.json")
|
||||
try:
|
||||
entries = list(os.scandir(npx_root))
|
||||
except OSError:
|
||||
return False
|
||||
for entry in entries:
|
||||
try:
|
||||
is_dir = entry.is_dir()
|
||||
except OSError:
|
||||
continue
|
||||
cached_name = _cached_package_name(os.path.join(entry.path, package_path))
|
||||
if is_dir and cached_name == package_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _cached_package_name(package_json_path):
|
||||
try:
|
||||
with open(package_json_path, encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
except (OSError, ValueError):
|
||||
return ""
|
||||
return str(data.get("name", "")).strip()
|
||||
|
||||
+178
-1
@@ -128,6 +128,17 @@ def validate_caldav_url(raw_url: str) -> str:
|
||||
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
||||
|
||||
|
||||
def _event_etag(obj) -> str:
|
||||
"""Best-effort ETag extraction from python-caldav resources."""
|
||||
try:
|
||||
etag = getattr(obj, "etag", None)
|
||||
if callable(etag):
|
||||
etag = etag()
|
||||
return str(etag or "")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||
"""Deterministic local id for a remote CalDAV calendar, scoped to owner
|
||||
and account so two users — or one user with two accounts — pointing at
|
||||
@@ -316,11 +327,12 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
color="#5b8abf",
|
||||
source="caldav",
|
||||
account_id=account_id or None,
|
||||
caldav_base_url=remote_url,
|
||||
)
|
||||
db.add(local_cal)
|
||||
db.commit()
|
||||
else:
|
||||
# Refresh display name and stamp account_id if missing.
|
||||
# Refresh display name and stamp CalDAV metadata if missing.
|
||||
changed = False
|
||||
if local_cal.name != display_name:
|
||||
local_cal.name = display_name
|
||||
@@ -328,6 +340,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
if account_id and not local_cal.account_id:
|
||||
local_cal.account_id = account_id
|
||||
changed = True
|
||||
if local_cal.caldav_base_url != remote_url:
|
||||
local_cal.caldav_base_url = remote_url
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
result["calendars"] += 1
|
||||
@@ -395,6 +410,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
|
||||
existing = _find_existing_event(db, pending, uid_val, local_cal.id)
|
||||
if existing:
|
||||
if existing.caldav_sync_pending in {"create", "update"}:
|
||||
result["events"] += 1
|
||||
continue
|
||||
existing.calendar_id = local_cal.id
|
||||
existing.summary = summary
|
||||
existing.description = description
|
||||
@@ -405,6 +423,9 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
existing.is_utc = row_is_utc
|
||||
existing.rrule = rrule
|
||||
existing.origin = "caldav"
|
||||
existing.remote_href = str(getattr(obj, "url", "") or "") or None
|
||||
existing.remote_etag = _event_etag(obj) or None
|
||||
existing.caldav_sync_pending = None
|
||||
else:
|
||||
new_ev = CalendarEvent(
|
||||
uid=uid_val,
|
||||
@@ -418,6 +439,8 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
is_utc=row_is_utc,
|
||||
rrule=rrule,
|
||||
origin="caldav",
|
||||
remote_href=str(getattr(obj, "url", "") or "") or None,
|
||||
remote_etag=_event_etag(obj) or None,
|
||||
)
|
||||
db.add(new_ev)
|
||||
pending[uid_val] = new_ev
|
||||
@@ -442,6 +465,8 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
CalendarEvent.origin == "caldav",
|
||||
CalendarEvent.dtstart >= start,
|
||||
CalendarEvent.dtstart <= end,
|
||||
CalendarEvent.remote_href.isnot(None),
|
||||
CalendarEvent.caldav_sync_pending.is_(None),
|
||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||
).all()
|
||||
for ev in stale:
|
||||
@@ -458,6 +483,92 @@ def _sync_blocking(owner: str, url: str, username: str, password: str, account_i
|
||||
return result
|
||||
|
||||
|
||||
def _event_payload(ev) -> dict:
|
||||
return {
|
||||
"uid": ev.uid,
|
||||
"summary": ev.summary,
|
||||
"description": ev.description,
|
||||
"location": ev.location,
|
||||
"dtstart": ev.dtstart,
|
||||
"dtend": ev.dtend,
|
||||
"all_day": ev.all_day,
|
||||
"is_utc": ev.is_utc,
|
||||
"rrule": ev.rrule or "",
|
||||
}
|
||||
|
||||
|
||||
def _load_event_for_writeback(owner: str, uid: str) -> tuple[str, str, dict] | None:
|
||||
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ev = (
|
||||
db.query(CalendarEvent)
|
||||
.join(CalendarCal)
|
||||
.filter(CalendarEvent.uid == uid, CalendarCal.owner == owner)
|
||||
.first()
|
||||
)
|
||||
if not ev or not ev.calendar or ev.calendar.source != "caldav":
|
||||
return None
|
||||
return ev.calendar.source, ev.calendar.id, _event_payload(ev)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _load_delete_for_writeback(owner: str, uid: str) -> tuple[str, str, dict] | None:
|
||||
from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
tombstone = db.query(CalendarDeletedEvent).filter(
|
||||
CalendarDeletedEvent.uid == uid,
|
||||
CalendarDeletedEvent.owner == owner,
|
||||
).first()
|
||||
if tombstone:
|
||||
return "caldav", tombstone.calendar_id, {"uid": uid}
|
||||
|
||||
ev = (
|
||||
db.query(CalendarEvent)
|
||||
.join(CalendarCal)
|
||||
.filter(CalendarEvent.uid == uid, CalendarCal.owner == owner)
|
||||
.first()
|
||||
)
|
||||
if not ev or not ev.calendar or ev.calendar.source != "caldav":
|
||||
return None
|
||||
return ev.calendar.source, ev.calendar.id, {"uid": uid}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _pending_writeback_uids(owner: str) -> tuple[list[str], list[str]]:
|
||||
from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = (
|
||||
db.query(CalendarEvent.uid)
|
||||
.join(CalendarCal)
|
||||
.filter(
|
||||
CalendarCal.owner == owner,
|
||||
CalendarCal.source == "caldav",
|
||||
CalendarEvent.status != "cancelled",
|
||||
(
|
||||
(CalendarEvent.caldav_sync_pending.isnot(None))
|
||||
| (CalendarEvent.remote_href.is_(None))
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
delete_rows = (
|
||||
db.query(CalendarDeletedEvent.uid)
|
||||
.filter(CalendarDeletedEvent.owner == owner)
|
||||
.all()
|
||||
)
|
||||
return [row[0] for row in rows], [row[0] for row in delete_rows]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _load_caldav_accounts(owner: str) -> list:
|
||||
"""Return the list of CalDAV accounts for *owner*, auto-migrating the legacy
|
||||
single-account ``caldav`` key to the new ``caldav_accounts`` list on first call.
|
||||
@@ -533,3 +644,69 @@ async def sync_caldav(owner: str) -> dict:
|
||||
for err in result.get("errors", []):
|
||||
totals["errors"].append(f"{label}: {err}")
|
||||
return totals
|
||||
|
||||
|
||||
async def push_event_create(owner: str, uid: str) -> dict:
|
||||
loaded = _load_event_for_writeback(owner, uid)
|
||||
if not loaded:
|
||||
return {"ok": True, "skipped": True}
|
||||
source, calendar_id, payload = loaded
|
||||
from src.caldav_writeback import writeback_event
|
||||
return await writeback_event(owner, source, calendar_id, payload)
|
||||
|
||||
|
||||
async def push_event_update(owner: str, uid: str) -> dict:
|
||||
return await push_event_create(owner, uid)
|
||||
|
||||
|
||||
async def push_event_delete(owner: str, uid: str) -> dict:
|
||||
loaded = _load_delete_for_writeback(owner, uid)
|
||||
if not loaded:
|
||||
return {"ok": True, "skipped": True}
|
||||
source, calendar_id, payload = loaded
|
||||
from src.caldav_writeback import writeback_event
|
||||
return await writeback_event(owner, source, calendar_id, payload, delete=True)
|
||||
|
||||
|
||||
async def push_pending_events(owner: str) -> dict:
|
||||
result = {"events": 0, "errors": []}
|
||||
uids, delete_uids = _pending_writeback_uids(owner)
|
||||
for event_uid in uids:
|
||||
try:
|
||||
out = await push_event_update(owner, event_uid)
|
||||
if out.get("ok"):
|
||||
result["events"] += 1
|
||||
elif not out.get("skipped"):
|
||||
result["errors"].append(f"{event_uid}: {str(out.get('error') or out)[:160]}")
|
||||
except Exception as e:
|
||||
logger.warning("CalDAV pending push failed for uid=%s: %s", event_uid, e)
|
||||
result["errors"].append(f"{event_uid}: {str(e)[:160]}")
|
||||
for event_uid in delete_uids:
|
||||
try:
|
||||
out = await push_event_delete(owner, event_uid)
|
||||
if out.get("ok"):
|
||||
result["events"] += 1
|
||||
elif not out.get("skipped"):
|
||||
result["errors"].append(f"{event_uid}: {str(out.get('error') or out)[:160]}")
|
||||
except Exception as e:
|
||||
logger.warning("CalDAV pending delete failed for uid=%s: %s", event_uid, e)
|
||||
result["errors"].append(f"{event_uid}: {str(e)[:160]}")
|
||||
return result
|
||||
|
||||
|
||||
async def sync_caldav_direction(owner: str, direction: str = "pull") -> dict:
|
||||
direction = (direction or "pull").strip().lower()
|
||||
if direction == "pull":
|
||||
return await sync_caldav(owner)
|
||||
if direction == "push":
|
||||
return await push_pending_events(owner)
|
||||
if direction == "both":
|
||||
pushed = await push_pending_events(owner)
|
||||
pulled = await sync_caldav(owner)
|
||||
return {"push": pushed, "pull": pulled}
|
||||
return {
|
||||
"calendars": 0,
|
||||
"events": 0,
|
||||
"deleted": 0,
|
||||
"errors": [f"Unsupported CalDAV sync direction: {direction}"],
|
||||
}
|
||||
|
||||
+92
-6
@@ -89,6 +89,23 @@ def find_remote_calendar(calendars, local_cal_id: str, owner: str = "", account_
|
||||
return None
|
||||
|
||||
|
||||
def _resource_href(obj) -> str:
|
||||
try:
|
||||
return str(getattr(obj, "url", "") or "")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _resource_etag(obj) -> str:
|
||||
try:
|
||||
etag = getattr(obj, "etag", None)
|
||||
if callable(etag):
|
||||
etag = etag()
|
||||
return str(etag or "")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||
owner: str = "", account_id: str = "") -> dict:
|
||||
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
||||
@@ -105,6 +122,7 @@ def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||
remote = find_remote_calendar(calendars, local_cal_id, owner=owner, account_id=account_id)
|
||||
if remote is None:
|
||||
return {"ok": False, "error": "remote calendar not found"}
|
||||
remote_url = str(getattr(remote, "url", "") or "")
|
||||
|
||||
try:
|
||||
existing = remote.event_by_uid(uid)
|
||||
@@ -113,17 +131,34 @@ def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||
|
||||
if delete:
|
||||
if existing is None:
|
||||
return {"ok": True, "note": "already absent on remote"}
|
||||
return {"ok": True, "note": "already absent on remote", "calendar_url": remote_url}
|
||||
existing.delete()
|
||||
return {"ok": True}
|
||||
return {
|
||||
"ok": True,
|
||||
"calendar_url": remote_url,
|
||||
"remote_href": _resource_href(existing),
|
||||
"remote_etag": _resource_etag(existing),
|
||||
}
|
||||
|
||||
ical = build_event_ical(ev)
|
||||
if existing is not None:
|
||||
existing.data = ical
|
||||
existing.save()
|
||||
return {"ok": True, "updated": True}
|
||||
remote.save_event(ical)
|
||||
return {"ok": True, "created": True}
|
||||
return {
|
||||
"ok": True,
|
||||
"updated": True,
|
||||
"calendar_url": remote_url,
|
||||
"remote_href": _resource_href(existing),
|
||||
"remote_etag": _resource_etag(existing),
|
||||
}
|
||||
created = remote.save_event(ical)
|
||||
return {
|
||||
"ok": True,
|
||||
"created": True,
|
||||
"calendar_url": remote_url,
|
||||
"remote_href": _resource_href(created),
|
||||
"remote_etag": _resource_etag(created),
|
||||
}
|
||||
|
||||
|
||||
def _discover_calendars(client):
|
||||
@@ -154,6 +189,54 @@ def _writeback_blocking(local_cal_id, ev, delete, url, username, password,
|
||||
owner=owner, account_id=account_id)
|
||||
|
||||
|
||||
def _persist_writeback_result(owner: str, calendar_id: str, uid: str, result: dict, *, delete: bool) -> None:
|
||||
from core.database import CalendarCal, CalendarDeletedEvent, CalendarEvent, SessionLocal
|
||||
|
||||
if not uid or not isinstance(result, dict):
|
||||
return
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
calendar = db.query(CalendarCal).filter(
|
||||
CalendarCal.id == calendar_id,
|
||||
CalendarCal.owner == owner,
|
||||
).first()
|
||||
if calendar and result.get("calendar_url"):
|
||||
calendar.caldav_base_url = result.get("calendar_url")
|
||||
|
||||
if delete:
|
||||
tombstone = db.query(CalendarDeletedEvent).filter(
|
||||
CalendarDeletedEvent.uid == uid,
|
||||
CalendarDeletedEvent.owner == owner,
|
||||
).first()
|
||||
if result.get("ok"):
|
||||
if tombstone:
|
||||
db.delete(tombstone)
|
||||
elif tombstone:
|
||||
tombstone.last_error = str(result.get("error") or result)[:500]
|
||||
db.commit()
|
||||
return
|
||||
|
||||
event = (
|
||||
db.query(CalendarEvent)
|
||||
.join(CalendarCal)
|
||||
.filter(CalendarEvent.uid == uid, CalendarCal.owner == owner)
|
||||
.first()
|
||||
)
|
||||
if event and result.get("ok"):
|
||||
if result.get("remote_href"):
|
||||
event.remote_href = result.get("remote_href")
|
||||
if result.get("remote_etag"):
|
||||
event.remote_etag = result.get("remote_etag")
|
||||
event.caldav_sync_pending = None
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.exception("CalDAV write-back metadata persistence failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||
ev: dict, *, delete: bool = False) -> dict:
|
||||
"""Best-effort push of a local change to the remote CalDAV server.
|
||||
@@ -204,9 +287,12 @@ async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||
result = await asyncio.to_thread(
|
||||
_writeback_blocking, calendar_id, ev, delete, url, user, pw, owner, acc_id
|
||||
)
|
||||
_persist_writeback_result(owner, calendar_id, (ev or {}).get("uid", ""), result, delete=delete)
|
||||
if not result.get("ok"):
|
||||
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception("CalDAV write-back raised")
|
||||
return {"ok": False, "error": str(e)[:200]}
|
||||
result = {"ok": False, "error": str(e)[:200]}
|
||||
_persist_writeback_result(owner, calendar_id, (ev or {}).get("uid", ""), result, delete=delete)
|
||||
return result
|
||||
|
||||
+13
-9
@@ -175,6 +175,19 @@ class ChatProcessor:
|
||||
|
||||
Returns:
|
||||
Tuple of (preface messages, rag_sources list)
|
||||
|
||||
Note on KV-cache friendliness: the ``system``-role messages assembled
|
||||
here are later concatenated into a single system message and sent as
|
||||
the very first thing in the payload (see ``llm_core``'s "consolidate
|
||||
system messages" step). Local OpenAI-compatible backends (llama.cpp /
|
||||
LM Studio) key their KV cache off the byte-identical token prefix, so
|
||||
*anything* that changes turn-to-turn — timestamps, retrieved snippets,
|
||||
per-turn counts — must NOT be folded into a system message here. Such
|
||||
content belongs in a separate ``user``/context message appended near
|
||||
the end of the array (see ``current_datetime_context_message`` and
|
||||
``untrusted_context_message`` callers in ``build_chat_context``),
|
||||
which keeps the static system prefix byte-identical across turns of
|
||||
the same session and lets the backend reuse its cached prefix.
|
||||
"""
|
||||
preface = []
|
||||
rag_sources = []
|
||||
@@ -185,15 +198,6 @@ class ChatProcessor:
|
||||
"role": "system",
|
||||
"content": preset_system_prompt
|
||||
})
|
||||
if not agent_mode:
|
||||
try:
|
||||
from src.user_time import current_datetime_prompt
|
||||
preface.append({
|
||||
"role": "system",
|
||||
"content": current_datetime_prompt(),
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to add current date/time context", exc_info=True)
|
||||
preface.append({
|
||||
"role": "system",
|
||||
"content": UNTRUSTED_CONTEXT_POLICY,
|
||||
|
||||
+27
-7
@@ -31,16 +31,22 @@ def compute_input_token_budget(
|
||||
|
||||
Args:
|
||||
configured: the value read from settings (may be the default).
|
||||
context_length: the model's discovered context window (0/unknown if none).
|
||||
explicit: True if the user explicitly set ``agent_input_token_budget``.
|
||||
context_length: the model's discovered context window. Pass 0 when the
|
||||
window is unknown / only a bare fallback — auto-scaling then stays
|
||||
conservative instead of trusting an unproven window (review on #4122).
|
||||
explicit: True if the user set a NON-default budget. The default value is
|
||||
the "auto" sentinel (scale to the window); any other value is an
|
||||
explicit cap. (A deliberately-chosen default can't be distinguished
|
||||
from a materialized default by value, so the default reads as auto.)
|
||||
|
||||
Rules:
|
||||
- Explicit user budget is honoured exactly, only clamped to the model's
|
||||
window when that window is known (never send more than the model holds).
|
||||
- Otherwise (default), scale to ``headroom`` of the context window, capped
|
||||
at ``hard_max`` — so long-context models use their capacity.
|
||||
- When the window is unknown, fall back to the configured/default value
|
||||
(preserving the previous behaviour).
|
||||
window when that window is known (the user's deliberate choice wins;
|
||||
``hard_max`` is an auto-budget ceiling only — see #1230).
|
||||
- Otherwise (auto), scale to ``headroom`` of the context window, capped at
|
||||
``hard_max`` — so long-context models use their capacity.
|
||||
- When the window is unknown (context_length <= 0), use the conservative
|
||||
``default`` budget and do NOT scale off the fallback.
|
||||
"""
|
||||
configured = int(configured or 0)
|
||||
context_length = int(context_length or 0)
|
||||
@@ -53,3 +59,17 @@ def compute_input_token_budget(
|
||||
return max(1, min(scaled, hard_max))
|
||||
|
||||
return configured if configured > 0 else default
|
||||
|
||||
|
||||
def budget_is_explicit(configured: int, *, default: int = DEFAULT_BUDGET) -> bool:
|
||||
"""Whether a configured agent_input_token_budget is a deliberate explicit cap.
|
||||
|
||||
The default value is the "auto" sentinel (scale to the model's window), so only
|
||||
a NON-default positive value counts as explicit. This keys off the VALUE, not
|
||||
settings *presence* — the settings-save path materializes every default into
|
||||
settings.json, so a persisted default must still read as auto (the regression
|
||||
#4121 / #1230 are about). Centralised here so the materialized-default contract
|
||||
is unit-testable and can't silently regress to a presence check.
|
||||
"""
|
||||
configured = int(configured or 0)
|
||||
return configured > 0 and configured != default
|
||||
|
||||
@@ -244,9 +244,17 @@ def trim_for_context(messages: List[Dict], context_length: int, reserve_tokens:
|
||||
protected_tokens = estimate_tokens(protected_msgs)
|
||||
budget -= protected_tokens
|
||||
|
||||
# Priority: keep first system msg (preset prompt), drop others (memory, RAG, memo)
|
||||
essential_system = system_msgs[:1] if system_msgs else []
|
||||
extra_system = system_msgs[1:]
|
||||
# Priority: keep first system msg (preset prompt), drop others (memory, RAG, memo).
|
||||
# Exception: a research-spinoff primer (the seeded report that grounds a
|
||||
# "Discuss" chat) must never be dropped — it is the conversation's whole
|
||||
# knowledge base. Treat any system message carrying research_spinoff_from
|
||||
# metadata as essential alongside the leading system prompt.
|
||||
def _is_research_primer(m):
|
||||
return bool((m.get("metadata") or {}).get("research_spinoff_from"))
|
||||
_primers = [m for m in system_msgs if _is_research_primer(m)]
|
||||
_non_primer = [m for m in system_msgs if not _is_research_primer(m)]
|
||||
essential_system = (_non_primer[:1] if _non_primer else []) + _primers
|
||||
extra_system = _non_primer[1:]
|
||||
|
||||
# Try dropping extra system messages one by one (from the end)
|
||||
trimmed = essential_system + convo_msgs
|
||||
@@ -438,8 +446,8 @@ def _update_session_history(session, split_point: int, summary: str,
|
||||
)
|
||||
new_history = system_prefix + [summary_msg] + recent_history
|
||||
try:
|
||||
from core import models as _core_models
|
||||
manager = getattr(_core_models, "_session_manager", None)
|
||||
from core.models import get_session_manager_instance
|
||||
manager = get_session_manager_instance()
|
||||
except Exception:
|
||||
manager = None
|
||||
if manager and getattr(session, "id", None):
|
||||
|
||||
@@ -136,7 +136,8 @@ async def _tick() -> None:
|
||||
return
|
||||
try:
|
||||
state = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("cookbook_serve_lifecycle: state file unreadable (%s), skipping tick", e)
|
||||
return
|
||||
tasks = state.get("tasks") or []
|
||||
now_ms = int(time.time() * 1000)
|
||||
@@ -178,8 +179,26 @@ async def _tick() -> None:
|
||||
if stopped_any:
|
||||
try:
|
||||
from core.atomic_io import atomic_write_json
|
||||
state["tasks"] = tasks
|
||||
atomic_write_json(state_path, state)
|
||||
# Re-read the state file so concurrent UI writes (task adds,
|
||||
# status flips, config edits) are not silently overwritten.
|
||||
# Apply only our stop mutations to the fresh snapshot.
|
||||
try:
|
||||
fresh = json.loads(state_path.read_text(encoding="utf-8"))
|
||||
fresh_tasks = fresh.get("tasks") or []
|
||||
except Exception:
|
||||
fresh = state
|
||||
fresh_tasks = tasks
|
||||
stopped_sids = {sid for sid, _, _ in to_stop}
|
||||
for ft in fresh_tasks:
|
||||
if not isinstance(ft, dict):
|
||||
continue
|
||||
ft_sid = ft.get("sessionId") or ft.get("id")
|
||||
if ft_sid in stopped_sids:
|
||||
ft["status"] = "stopped"
|
||||
ft["_scheduledStopAtMs"] = None
|
||||
ft["_lastStatusFlipAt"] = now_ms
|
||||
fresh["tasks"] = fresh_tasks
|
||||
atomic_write_json(state_path, fresh)
|
||||
except Exception as e:
|
||||
logger.warning(f"cookbook_serve_lifecycle: state write failed: {e}")
|
||||
|
||||
|
||||
@@ -232,6 +232,7 @@ class DeepResearcher:
|
||||
self._start_time: float = 0
|
||||
self.queries_used: Set[str] = set()
|
||||
self.urls_fetched: Set[str] = set()
|
||||
self.analyzed_urls: List[Dict[str, str]] = []
|
||||
self.round_count: int = 0
|
||||
# Track which search providers actually returned results during the
|
||||
# run, in arrival order — surfaced in the visual report so users can
|
||||
@@ -525,6 +526,10 @@ class DeepResearcher:
|
||||
if url and url not in self.urls_fetched:
|
||||
urls_to_fetch.append(r)
|
||||
self.urls_fetched.add(url)
|
||||
self.analyzed_urls.append({
|
||||
"url": url,
|
||||
"title": r.get("title", "") or url,
|
||||
})
|
||||
if len(urls_to_fetch) >= self.max_urls_per_round * len(queries):
|
||||
break
|
||||
|
||||
|
||||
+11
-2
@@ -196,13 +196,22 @@ def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any],
|
||||
try:
|
||||
chroma_client.delete_collection(name)
|
||||
restored = chroma_client.get_or_create_collection(name=name, metadata=current)
|
||||
old_embeddings = preserved.get("embeddings") or []
|
||||
if ids and docs and old_embeddings:
|
||||
# chromadb returns embeddings as a numpy ndarray, whose truth value
|
||||
# is ambiguous — `preserved.get("embeddings") or []` and a bare
|
||||
# `if ... and old_embeddings:` both raise ValueError, which aborts
|
||||
# the restore and loses the rows the reset was supposed to keep.
|
||||
# Use explicit None/len checks instead.
|
||||
old_embeddings = preserved.get("embeddings")
|
||||
if old_embeddings is None:
|
||||
old_embeddings = []
|
||||
if ids and docs and len(old_embeddings):
|
||||
for start in range(0, len(ids), 100):
|
||||
batch_ids = ids[start:start + 100]
|
||||
batch_docs = docs[start:start + 100]
|
||||
batch_metas = metas[start:start + 100]
|
||||
batch_embeddings = old_embeddings[start:start + 100]
|
||||
if hasattr(batch_embeddings, "tolist"):
|
||||
batch_embeddings = batch_embeddings.tolist()
|
||||
if len(batch_metas) < len(batch_ids):
|
||||
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||
restored.add(
|
||||
|
||||
+27
-14
@@ -12,7 +12,7 @@ from typing import Optional, Tuple, Dict
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from src.llm_core import _detect_provider, _host_match, _ollama_api_root
|
||||
from src.llm_core import _detect_provider, _host_match, _is_kimi_code_url, KIMI_CODE_USER_AGENT, _ollama_api_root
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -183,7 +183,16 @@ def build_chat_url(base: str) -> str:
|
||||
|
||||
|
||||
def build_models_url(base: str) -> Optional[str]:
|
||||
"""Return the provider-specific model-list endpoint URL for a base."""
|
||||
"""Return the provider-specific model-list endpoint URL for a base.
|
||||
|
||||
For OpenAI-compatible servers (LM Studio, llama.cpp, vLLM,
|
||||
text-generation-webui, etc.) the model list is exposed at ``/v1/models``.
|
||||
When the user-supplied base has no path — e.g. ``http://localhost:1234`` —
|
||||
we still need to land on ``/v1/models`` (issue #25); insert the ``/v1``
|
||||
segment only when the path is empty, leaving any explicit non-empty path
|
||||
untouched (so custom prefixes like ``/openai`` or ``/api/openai/v1`` keep
|
||||
their semantics).
|
||||
"""
|
||||
base = normalize_base(resolve_url(base))
|
||||
provider = _detect_provider(base)
|
||||
if provider == "anthropic":
|
||||
@@ -192,6 +201,12 @@ def build_models_url(base: str) -> Optional[str]:
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
if provider == "chatgpt-subscription":
|
||||
return None
|
||||
# Generic OpenAI-compatible fallback: ensure the path lands on /v1/models
|
||||
# when the user omitted a path entirely. If a non-empty path is already
|
||||
# present (e.g. /openai, /api/openai/v1, /v1), trust the caller — the
|
||||
# /models suffix is appended as-is and the caller's prefix is preserved.
|
||||
if not urlparse(base).path:
|
||||
base = base + "/v1"
|
||||
return base + "/models"
|
||||
|
||||
|
||||
@@ -215,6 +230,8 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
if provider == "openrouter":
|
||||
headers.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
|
||||
headers.setdefault("X-OpenRouter-Title", "Odysseus")
|
||||
if _is_kimi_code_url(base):
|
||||
headers.setdefault("User-Agent", KIMI_CODE_USER_AGENT)
|
||||
return headers
|
||||
|
||||
|
||||
@@ -250,27 +267,23 @@ def resolve_endpoint(
|
||||
ep_id = _stg(f"{setting_prefix}_endpoint_id")
|
||||
model = _stg(f"{setting_prefix}_model")
|
||||
|
||||
# If the specific endpoint is not configured, but the caller provided a
|
||||
# Fall back to utility model for task/research/auto-naming if not specifically configured.
|
||||
if not ep_id and setting_prefix not in ("utility", "default"):
|
||||
ep_id = _stg("utility_endpoint_id")
|
||||
model = _stg("utility_model")
|
||||
|
||||
# If the endpoint is STILL not configured, but the caller provided a
|
||||
# valid fallback (e.g. the active session model), use that immediately.
|
||||
# This prevents background tasks from jumping to the global default_model
|
||||
# when the user is mid-conversation with a different model.
|
||||
if not ep_id and fallback_url and fallback_model:
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
|
||||
# Unset Utility means "same as Default Chat Model".
|
||||
if setting_prefix == "utility" and not ep_id:
|
||||
# Unset Utility (or anything else that didn't have a fallback) means "same as Default Chat Model".
|
||||
if not ep_id:
|
||||
ep_id = _stg("default_endpoint_id")
|
||||
model = _stg("default_model")
|
||||
|
||||
# Fall back to utility model for task/research/auto-naming if not specifically configured.
|
||||
# If Utility itself is unset, the block above makes that resolve to Default Chat.
|
||||
if not ep_id and setting_prefix != "utility":
|
||||
ep_id = _stg("utility_endpoint_id")
|
||||
model = _stg("utility_model")
|
||||
if not ep_id:
|
||||
ep_id = _stg("default_endpoint_id")
|
||||
model = _stg("default_model")
|
||||
|
||||
if not ep_id:
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
|
||||
|
||||
+79
-5
@@ -6,6 +6,7 @@ import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.atomic_io import atomic_write_json
|
||||
from core.platform_compat import safe_chmod
|
||||
@@ -258,6 +259,11 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
integration.setdefault("name", "")
|
||||
integration.setdefault("base_url", "")
|
||||
|
||||
if not isinstance(integration.get("name"), str) or not integration["name"].strip():
|
||||
raise HTTPException(400, "Integration name is required")
|
||||
if not isinstance(integration.get("base_url"), str) or not integration["base_url"].strip():
|
||||
raise HTTPException(400, "Integration base URL is required")
|
||||
|
||||
integrations = load_integrations()
|
||||
integrations.append(integration)
|
||||
save_integrations(integrations)
|
||||
@@ -266,6 +272,11 @@ def add_integration(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
def update_integration(integration_id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update fields on an existing integration. Returns updated integration or None."""
|
||||
if "name" in data and (not isinstance(data["name"], str) or not data["name"].strip()):
|
||||
raise HTTPException(400, "Integration name is required")
|
||||
if "base_url" in data and (not isinstance(data["base_url"], str) or not data["base_url"].strip()):
|
||||
raise HTTPException(400, "Integration base URL is required")
|
||||
|
||||
integrations = load_integrations()
|
||||
for item in integrations:
|
||||
if item.get("id") == integration_id:
|
||||
@@ -411,17 +422,80 @@ async def execute_api_call(
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
formatted = json.dumps(data, indent=2, ensure_ascii=False)
|
||||
full = json.dumps(data, indent=2, ensure_ascii=False)
|
||||
if len(full) > 12000:
|
||||
if isinstance(data, list):
|
||||
# Binary-search for the largest prefix such that the
|
||||
# final array (prefix + sentinel) fits within the limit.
|
||||
# Pre-compute the sentinel so we know its serialized size.
|
||||
sentinel_placeholder = {
|
||||
"_truncated": True,
|
||||
"total_items": len(data),
|
||||
"shown_items": 0,
|
||||
}
|
||||
# Overhead: the sentinel appears as an extra array element.
|
||||
# Add a conservative padding for the separating comma,
|
||||
# newline, and indentation characters (~6 chars).
|
||||
sentinel_overhead = len(
|
||||
json.dumps(sentinel_placeholder, indent=2, ensure_ascii=False)
|
||||
) + 6
|
||||
budget = 12000 - sentinel_overhead
|
||||
lo, hi = 0, len(data)
|
||||
while lo < hi:
|
||||
mid = (lo + hi + 1) // 2
|
||||
candidate = json.dumps(
|
||||
data[:mid], indent=2, ensure_ascii=False
|
||||
)
|
||||
if len(candidate) < budget:
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid - 1
|
||||
sentinel = {
|
||||
"_truncated": True,
|
||||
"total_items": len(data),
|
||||
"shown_items": lo,
|
||||
}
|
||||
formatted = json.dumps(
|
||||
data[:lo] + [sentinel], indent=2, ensure_ascii=False
|
||||
)
|
||||
elif isinstance(data, dict):
|
||||
# Truncate dict entries until the result fits, then add
|
||||
# the _truncated marker. Walk keys in insertion order.
|
||||
DICT_LIMIT = 12000
|
||||
kept: dict = {}
|
||||
for k, v in data.items():
|
||||
candidate = json.dumps(
|
||||
{**kept, k: v, "_truncated": True},
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if len(candidate) <= DICT_LIMIT:
|
||||
kept[k] = v
|
||||
else:
|
||||
break
|
||||
formatted = json.dumps(
|
||||
{**kept, "_truncated": True}, indent=2, ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
total = len(full)
|
||||
formatted = full[:12000] + f"\n... (truncated, {total} chars total)"
|
||||
else:
|
||||
formatted = full
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
formatted = response.text
|
||||
if len(formatted) > 12000:
|
||||
total = len(formatted)
|
||||
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
|
||||
elif "text/html" in content_type:
|
||||
formatted = _strip_html_tags(response.text)
|
||||
if len(formatted) > 12000:
|
||||
total = len(formatted)
|
||||
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
|
||||
else:
|
||||
formatted = response.text
|
||||
|
||||
# Truncate
|
||||
if len(formatted) > 12000:
|
||||
formatted = formatted[:12000] + "\n... (truncated)"
|
||||
if len(formatted) > 12000:
|
||||
total = len(formatted)
|
||||
formatted = formatted[:12000] + f"\n... (truncated, {total} chars total)"
|
||||
|
||||
output = f"HTTP {status}\n{formatted}"
|
||||
|
||||
|
||||
+314
-17
@@ -7,6 +7,7 @@ import logging
|
||||
import hashlib
|
||||
import threading
|
||||
import re
|
||||
import os
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, Dict, List, Tuple
|
||||
from src.model_context import get_context_length, DEFAULT_CONTEXT
|
||||
@@ -22,6 +23,24 @@ class LLMConfig:
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 0.5
|
||||
STREAM_TIMEOUT = 300
|
||||
# TCP+TLS connect budget for a SINGLE attempt. The old hard-coded 3.0s
|
||||
# assumed LAN/Tailscale peers ('SYN in <100ms'); it is too tight for public
|
||||
# cloud endpoints (offshore APIs take ~0.5-1.5s cold, with jitter), so a
|
||||
# brief blip on the first connect of an idle chat surfaced as a 503 on the
|
||||
# streaming path (which, unlike llm_call, does not retry the connect). A
|
||||
# genuinely dead upstream stays bounded by the dead-host cooldown. Override
|
||||
# with env LLM_CONNECT_TIMEOUT (seconds).
|
||||
CONNECT_TIMEOUT = float(os.getenv('LLM_CONNECT_TIMEOUT', '10') or '10')
|
||||
|
||||
|
||||
def _call_timeout(read_timeout) -> httpx.Timeout:
|
||||
"""Per-request timeout for non-streaming LLM calls (connect from config)."""
|
||||
return httpx.Timeout(connect=LLMConfig.CONNECT_TIMEOUT, read=float(read_timeout), write=10.0, pool=5.0)
|
||||
|
||||
|
||||
def _stream_timeout(read_timeout) -> httpx.Timeout:
|
||||
"""Per-request timeout for streaming LLM calls (connect from config)."""
|
||||
return httpx.Timeout(connect=LLMConfig.CONNECT_TIMEOUT, read=float(read_timeout), write=30.0, pool=5.0)
|
||||
|
||||
|
||||
# Cache for LLM responses
|
||||
@@ -276,6 +295,24 @@ def _is_ollama_native_url(url: str) -> bool:
|
||||
return local_ollama_host and (path == "" or path == "/api" or path.startswith("/api/"))
|
||||
|
||||
|
||||
def _is_ollama_openai_compat_url(url: str) -> bool:
|
||||
"""Return True for local Ollama's OpenAI-compatible /v1 surface.
|
||||
|
||||
Mirrors the host detection used by ``_is_ollama_native_url`` so that the
|
||||
two helpers stay in lockstep: a localhost Ollama on a non-default port
|
||||
(custom ``OLLAMA_HOST``, reverse proxy, container port remap) is treated
|
||||
the same way here as it is on the native ``/api`` path.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url or "")
|
||||
except Exception:
|
||||
return False
|
||||
host = parsed.hostname or ""
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
local_ollama_host = host in {"localhost", "127.0.0.1", "0.0.0.0", "::1"} or parsed.port == 11434
|
||||
return local_ollama_host and (path == "/v1" or path.startswith("/v1/"))
|
||||
|
||||
|
||||
def _ollama_api_root(url: str) -> str:
|
||||
"""Return a native Ollama API root such as https://ollama.com/api."""
|
||||
url = (url or "").strip().rstrip("/")
|
||||
@@ -405,6 +442,146 @@ def _host_match(url: str, *domains: str) -> bool:
|
||||
return any(host == d or host.endswith("." + d) for d in domains)
|
||||
|
||||
|
||||
# Kimi Code subscription keys (api.kimi.com/coding/v1) require a whitelisted
|
||||
# coding-agent User-Agent; otherwise the API returns 403 access_terminated_error.
|
||||
# Tried in order; first success is cached per base URL for later requests.
|
||||
KIMI_CODE_USER_AGENTS: tuple[str, ...] = (
|
||||
"claude-code/0.1.0",
|
||||
"claude-code/1.0.0",
|
||||
"KimiCLI/1.0",
|
||||
"Kilo-Code/1.0",
|
||||
"Roo-Code/1.0",
|
||||
"Cursor/1.0",
|
||||
)
|
||||
KIMI_CODE_USER_AGENT = KIMI_CODE_USER_AGENTS[0]
|
||||
_kimi_code_ua_cache: dict[str, str] = {}
|
||||
|
||||
|
||||
def _is_kimi_code_url(url: str) -> bool:
|
||||
if not url or not _host_match(url, "kimi.com"):
|
||||
return False
|
||||
try:
|
||||
return "/coding" in (urlparse(url).path or "")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _kimi_code_base_key(url: str) -> str:
|
||||
"""Normalize a Kimi Code chat/models URL to its OpenAI base (.../coding/v1)."""
|
||||
parsed = urlparse(url)
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
for suffix in ("/chat/completions", "/models", "/completions"):
|
||||
if path.endswith(suffix):
|
||||
path = path[: -len(suffix)]
|
||||
path = path.rstrip("/") or "/coding/v1"
|
||||
return f"{parsed.scheme}://{parsed.netloc}{path}"
|
||||
|
||||
|
||||
def _is_kimi_code_access_denied(status: int, body: bytes | str) -> bool:
|
||||
if status != 403:
|
||||
return False
|
||||
text = body.decode("utf-8", errors="replace") if isinstance(body, bytes) else (body or "")
|
||||
lower = text.lower()
|
||||
return (
|
||||
"access_terminated_error" in lower
|
||||
or "coding agents" in lower
|
||||
or "only available for coding" in lower
|
||||
)
|
||||
|
||||
|
||||
def _kimi_code_ua_candidates(url: str) -> list[str]:
|
||||
if not _is_kimi_code_url(url):
|
||||
return []
|
||||
base_key = _kimi_code_base_key(url)
|
||||
cached = _kimi_code_ua_cache.get(base_key)
|
||||
if cached:
|
||||
return [cached] + [ua for ua in KIMI_CODE_USER_AGENTS if ua != cached]
|
||||
return list(KIMI_CODE_USER_AGENTS)
|
||||
|
||||
|
||||
def _remember_kimi_code_user_agent(url: str, user_agent: str) -> None:
|
||||
_kimi_code_ua_cache[_kimi_code_base_key(url)] = user_agent
|
||||
|
||||
|
||||
def apply_kimi_code_headers(headers: Optional[Dict], url: str) -> Dict[str, str]:
|
||||
"""Pick a Kimi Code User-Agent (cached probe when possible)."""
|
||||
h = dict(headers or {})
|
||||
if not _is_kimi_code_url(url):
|
||||
return h
|
||||
base_key = _kimi_code_base_key(url)
|
||||
cached = _kimi_code_ua_cache.get(base_key)
|
||||
if cached:
|
||||
h["User-Agent"] = cached
|
||||
return h
|
||||
models_url = base_key.rstrip("/") + "/models"
|
||||
from src.tls_overrides import llm_verify
|
||||
for ua in KIMI_CODE_USER_AGENTS:
|
||||
trial = dict(h)
|
||||
trial["User-Agent"] = ua
|
||||
try:
|
||||
r = httpx.get(models_url, headers=trial, timeout=8, verify=llm_verify())
|
||||
except Exception:
|
||||
continue
|
||||
if _is_kimi_code_access_denied(r.status_code, r.content):
|
||||
logger.debug("Kimi Code rejected User-Agent %s (403), trying next", ua)
|
||||
continue
|
||||
if r.status_code < 400:
|
||||
_remember_kimi_code_user_agent(url, ua)
|
||||
h["User-Agent"] = ua
|
||||
return h
|
||||
break
|
||||
h.setdefault("User-Agent", KIMI_CODE_USER_AGENT)
|
||||
return h
|
||||
|
||||
|
||||
def httpx_get_kimi_aware(url: str, headers: Optional[Dict], **kwargs):
|
||||
h = apply_kimi_code_headers(headers, url)
|
||||
if not _is_kimi_code_url(url):
|
||||
return httpx.get(url, headers=h, **kwargs)
|
||||
last = None
|
||||
for ua in _kimi_code_ua_candidates(url):
|
||||
trial = dict(h)
|
||||
trial["User-Agent"] = ua
|
||||
last = httpx.get(url, headers=trial, **kwargs)
|
||||
if not _is_kimi_code_access_denied(last.status_code, last.content):
|
||||
if last.status_code < 400:
|
||||
_remember_kimi_code_user_agent(url, ua)
|
||||
return last
|
||||
return last
|
||||
|
||||
|
||||
def httpx_post_kimi_aware(url: str, headers: Optional[Dict], **kwargs):
|
||||
h = apply_kimi_code_headers(headers, url)
|
||||
if not _is_kimi_code_url(url):
|
||||
return httpx.post(url, headers=h, **kwargs)
|
||||
last = None
|
||||
for ua in _kimi_code_ua_candidates(url):
|
||||
trial = dict(h)
|
||||
trial["User-Agent"] = ua
|
||||
last = httpx.post(url, headers=trial, **kwargs)
|
||||
if not _is_kimi_code_access_denied(last.status_code, last.content):
|
||||
if last.status_code < 400:
|
||||
_remember_kimi_code_user_agent(url, ua)
|
||||
return last
|
||||
return last
|
||||
|
||||
|
||||
async def httpx_post_kimi_aware_async(client, url: str, headers: Optional[Dict], **kwargs):
|
||||
h = apply_kimi_code_headers(headers, url)
|
||||
if not _is_kimi_code_url(url):
|
||||
return await client.post(url, headers=h, **kwargs)
|
||||
last = None
|
||||
for ua in _kimi_code_ua_candidates(url):
|
||||
trial = dict(h)
|
||||
trial["User-Agent"] = ua
|
||||
last = await client.post(url, headers=trial, **kwargs)
|
||||
if not _is_kimi_code_access_denied(last.status_code, last.content):
|
||||
if last.status_code < 400:
|
||||
_remember_kimi_code_user_agent(url, ua)
|
||||
return last
|
||||
return last
|
||||
|
||||
|
||||
def _detect_provider(url: str) -> str:
|
||||
"""Detect the API provider from a configured endpoint URL.
|
||||
|
||||
@@ -426,6 +603,10 @@ def _detect_provider(url: str) -> str:
|
||||
return "openrouter"
|
||||
if _host_match(url, "groq.com"):
|
||||
return "groq"
|
||||
if _host_match(url, "nvidia.com"):
|
||||
return "nvidia"
|
||||
if _host_match(url, "moonshot.ai") or _host_match(url, "moonshot.cn"):
|
||||
return "moonshot"
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
if is_chatgpt_subscription_base(url):
|
||||
return "chatgpt-subscription"
|
||||
@@ -435,6 +616,53 @@ def _detect_provider(url: str) -> str:
|
||||
return "openai"
|
||||
|
||||
|
||||
def _is_self_hosted_openai_compatible(url: str) -> bool:
|
||||
"""True for custom/local OpenAI-compatible servers (llama.cpp, LM Studio,
|
||||
vLLM, text-generation-webui, etc.) as opposed to cloud APIs.
|
||||
|
||||
Used to gate llama.cpp-server-specific payload extras (``session_id``,
|
||||
``cache_prompt``) used for KV-cache slot affinity (issue #2927). Strict
|
||||
cloud providers reject unrecognized top-level fields (api.openai.com
|
||||
returns 400, Mistral returns 422 "extra_forbidden", issue #3793), and any
|
||||
unknown OpenAI-compatible host used to be treated as self-hosted, so those
|
||||
fields leaked to every strict provider added as a custom endpoint.
|
||||
|
||||
A server only counts as self-hosted when it also resolves as local:
|
||||
loopback/private/tailscale host, or the endpoint explicitly configured
|
||||
with kind "local". A self-hosted server exposed via a public hostname
|
||||
loses the affinity hint unless its endpoint kind is set to "local" -
|
||||
a lost perf hint, versus a hard 4xx on every request the other way.
|
||||
"""
|
||||
if _detect_provider(url) != "openai" or _host_match(url, "openai.com"):
|
||||
return False
|
||||
from src.model_context import is_local_endpoint
|
||||
return is_local_endpoint(url)
|
||||
|
||||
|
||||
def _apply_local_cache_affinity(payload: Dict, url: str, session_id: Optional[str]) -> None:
|
||||
"""Add llama.cpp-server slot-affinity hints to an outgoing payload, in place.
|
||||
|
||||
As diagnosed in issue #2927, llama.cpp assigns requests to processing
|
||||
slots via LRU when no stable identifier is present ("session_id=<empty>
|
||||
server-selected (LCP/LRU)"), which means consecutive turns of the same
|
||||
chat can land on different slots and lose their cached prefix entirely.
|
||||
Sending a stable ``session_id`` (derived from the Odysseus session) lets
|
||||
the server keep routing the same conversation to the same slot, and
|
||||
``cache_prompt: true`` asks it to retain/reuse the prefix it already has.
|
||||
|
||||
Both fields are llama.cpp / LM Studio extensions to the OpenAI schema; we
|
||||
only set them for self-hosted OpenAI-compatible endpoints (never
|
||||
api.openai.com or other cloud providers, which reject unrecognized
|
||||
top-level request fields).
|
||||
"""
|
||||
if not session_id:
|
||||
return
|
||||
if not _is_self_hosted_openai_compatible(url):
|
||||
return
|
||||
payload.setdefault("session_id", str(session_id))
|
||||
payload.setdefault("cache_prompt", True)
|
||||
|
||||
|
||||
def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str, str]:
|
||||
h = {"Content-Type": "application/json"}
|
||||
if isinstance(headers, dict):
|
||||
@@ -471,9 +699,16 @@ def _provider_label(url: str) -> str:
|
||||
if is_copilot_base(url): return "GitHub Copilot"
|
||||
if _host_match(url, "mistral.ai"): return "Mistral"
|
||||
if _host_match(url, "deepseek.com"): return "DeepSeek"
|
||||
if _host_match(url, "nvidia.com"): return "NVIDIA"
|
||||
if _host_match(url, "googleapis.com"): return "Google"
|
||||
if _host_match(url, "together.xyz", "together.ai"): return "Together"
|
||||
if _host_match(url, "fireworks.ai"): return "Fireworks"
|
||||
if _host_match(url, "kimi.com"):
|
||||
try:
|
||||
if "/coding" in (urlparse(url).path or ""):
|
||||
return "Kimi Code"
|
||||
except Exception:
|
||||
pass
|
||||
if _is_ollama_native_url(url): return "Ollama"
|
||||
try:
|
||||
host = (urlparse(url).hostname or "").lower()
|
||||
@@ -542,8 +777,9 @@ def _build_chatgpt_responses_payload(
|
||||
}
|
||||
if not _restricts_temperature(model):
|
||||
payload["temperature"] = temperature
|
||||
if max_tokens and max_tokens > 0:
|
||||
payload["max_output_tokens"] = max_tokens
|
||||
# ChatGPT Subscription Codex API does not support max_output_tokens —
|
||||
# passing it returns HTTP 400 "Unsupported parameter: max_output_tokens".
|
||||
# Do not include it in the payload.
|
||||
return payload
|
||||
|
||||
|
||||
@@ -613,7 +849,7 @@ def _uses_max_completion_tokens(model: str) -> bool:
|
||||
# perfectly good model as failing. For these models we omit the field and let
|
||||
# the API use its required default. (gpt-4.5 is intentionally excluded — it is
|
||||
# not a reasoning model and accepts temperature normally.)
|
||||
_FIXED_TEMPERATURE_MODELS = ("o1", "o3", "o4", "gpt-5")
|
||||
_FIXED_TEMPERATURE_MODELS = ("o1", "o3", "o4", "gpt-5", "kimi-for-coding")
|
||||
|
||||
def _restricts_temperature(model: str) -> bool:
|
||||
"""Check if a model rejects any non-default temperature."""
|
||||
@@ -622,6 +858,49 @@ def _restricts_temperature(model: str) -> bool:
|
||||
m = model.lower()
|
||||
return any(m.startswith(p) or f"/{p}" in m for p in _FIXED_TEMPERATURE_MODELS)
|
||||
|
||||
|
||||
# The official Moonshot API fixes temperature at 1.0 in thinking mode and 0.6
|
||||
# when thinking is explicitly disabled for Kimi K2.5/K2.6. Any other explicit
|
||||
# value returns HTTP 400. Odysseus does not currently send the `thinking` mode
|
||||
# control, so omit temperature and let Moonshot use its default thinking mode.
|
||||
# Keep the gate provider-specific: self-hosted Kimi deployments may accept
|
||||
# custom sampling values, and older Moonshot models have different defaults.
|
||||
def _moonshot_rejects_custom_temperature(provider: str, model: str) -> bool:
|
||||
"""Check if the official Moonshot API fixes temperature for this model."""
|
||||
if provider != "moonshot" or not isinstance(model, str):
|
||||
return False
|
||||
model_id = model.lower().rsplit("/", 1)[-1]
|
||||
return bool(re.match(r"^kimi-k2\.(?:5|6)(?:$|[-_:])", model_id))
|
||||
|
||||
|
||||
def _omit_temperature(provider: str, model: str) -> bool:
|
||||
"""Check if a request should use the provider's default temperature."""
|
||||
return _restricts_temperature(model) or _moonshot_rejects_custom_temperature(
|
||||
provider, model
|
||||
)
|
||||
|
||||
|
||||
# Anthropic removed the sampling parameters (temperature, top_p, top_k) starting
|
||||
# with Claude Opus 4.7. On Opus 4.7 and later, sending `temperature` at all —
|
||||
# even 0.0 — returns HTTP 400. Earlier Claude models (Opus 4.6 and below, every
|
||||
# Sonnet/Haiku) still accept temperature in [0.0, 1.0], so the omission must be
|
||||
# version-gated rather than applied to all `claude-*` models.
|
||||
def _anthropic_rejects_temperature(model: str) -> bool:
|
||||
"""Check if a native-Anthropic model rejects the temperature field (Opus 4.7+)."""
|
||||
if not isinstance(model, str) or not model:
|
||||
return False
|
||||
# `(?<![a-z])` anchors "opus" to a word boundary so a substring match like
|
||||
# `oct-opus`/`octopus-4-8` can't be read as Opus (it would otherwise strip
|
||||
# temperature). Cap the minor at 1-2 digits and forbid a trailing digit so a
|
||||
# dated id like `claude-opus-4-20250514` (Opus 4.0) parses as major-only (no
|
||||
# minor match, kept) instead of reading the date `20250514` as a giant minor
|
||||
# that would falsely test >= 4.7. Dated 4.7+ snapshots (`claude-opus-4-7-
|
||||
# 20260201`) keep their explicit minor and are still matched.
|
||||
match = re.search(r"(?<![a-z])opus[-_]?(\d+)[-_.](\d{1,2})(?!\d)", model.lower())
|
||||
if not match:
|
||||
return False
|
||||
return (int(match.group(1)), int(match.group(2))) >= (4, 7)
|
||||
|
||||
# Models that support structured thinking — may output </think> without opening tag
|
||||
_THINKING_MODEL_PATTERNS = ("qwen3", "qwq", "deepseek-r1", "deepseek-reasoner", "minimax", "m2-reap", "gemma")
|
||||
|
||||
@@ -725,8 +1004,11 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
|
||||
"model": model,
|
||||
"messages": chat_messages,
|
||||
"max_tokens": max_tokens if max_tokens and max_tokens > 0 else 4096,
|
||||
"temperature": temperature,
|
||||
}
|
||||
# Opus 4.7+ removed the sampling parameters — sending `temperature` (even 0.0)
|
||||
# returns HTTP 400. Omit it for those models; older Claude models still take it.
|
||||
if not _anthropic_rejects_temperature(model):
|
||||
payload["temperature"] = temperature
|
||||
if system_parts:
|
||||
system_text = "\n\n".join(system_parts)
|
||||
# Send `system` as a structured text block so we can attach a prompt-cache
|
||||
@@ -810,7 +1092,7 @@ def _sanitize_llm_messages(messages: List[Dict]) -> List[Dict]:
|
||||
(content=None, since Gemini/Ollama reject tool_calls alongside ""). Dropping
|
||||
it leaves the tool result dangling and breaks the next round.
|
||||
"""
|
||||
allowed = {"role", "content", "name", "tool_call_id", "tool_calls", "function_call"}
|
||||
allowed = {"role", "content", "name", "tool_call_id", "tool_calls", "function_call", "reasoning_content"}
|
||||
cleaned = []
|
||||
for msg in messages or []:
|
||||
if not isinstance(msg, dict):
|
||||
@@ -1045,7 +1327,7 @@ def list_model_ids(
|
||||
from src.endpoint_resolver import build_models_url
|
||||
|
||||
models_url = build_models_url(base_chat_url)
|
||||
r = httpx.get(models_url, headers=h, timeout=timeout)
|
||||
r = httpx_get_kimi_aware(models_url, h, timeout=timeout)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
@@ -1146,14 +1428,14 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
|
||||
"messages": messages_copy,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if _restricts_temperature(model):
|
||||
if _omit_temperature(provider, model):
|
||||
payload.pop("temperature", None)
|
||||
if max_tokens and max_tokens > 0:
|
||||
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
|
||||
payload[tok_key] = max_tokens
|
||||
try:
|
||||
note_model_activity(target_url, model)
|
||||
r = httpx.post(target_url, headers=h, json=payload, timeout=timeout)
|
||||
r = httpx_post_kimi_aware(target_url, h, json=payload, timeout=timeout)
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"POST {target_url} failed: {e}")
|
||||
if not r.is_success:
|
||||
@@ -1247,7 +1529,8 @@ async def llm_call_async(
|
||||
headers: Optional[Dict] = None,
|
||||
timeout: int = LLMConfig.STREAM_TIMEOUT,
|
||||
max_retries: int = LLMConfig.MAX_RETRIES,
|
||||
prompt_type: Optional[str] = None
|
||||
prompt_type: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Asynchronous LLM call using httpx with connection pooling, timeout, retry logic, and performance logging."""
|
||||
provider = _detect_provider(url)
|
||||
@@ -1339,16 +1622,20 @@ async def llm_call_async(
|
||||
"messages": messages_copy,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if _restricts_temperature(model):
|
||||
if _omit_temperature(provider, model):
|
||||
payload.pop("temperature", None)
|
||||
if max_tokens and max_tokens > 0:
|
||||
tok_key = "max_completion_tokens" if _uses_max_completion_tokens(model) else "max_tokens"
|
||||
payload[tok_key] = max_tokens
|
||||
# Suppress thinking for qwen3/gemma4 on Ollama /v1 — same as stream_llm.
|
||||
if _is_ollama_openai_compat_url(url) and _supports_thinking(model):
|
||||
payload["think"] = False
|
||||
_apply_local_cache_affinity(payload, url, session_id)
|
||||
|
||||
if _is_host_dead(target_url):
|
||||
raise HTTPException(503, f"Upstream {_host_key(target_url)} marked unreachable (cooldown active)")
|
||||
|
||||
call_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=10.0, pool=5.0)
|
||||
call_timeout = _call_timeout(timeout)
|
||||
attempt = 0
|
||||
while attempt < max_retries:
|
||||
attempt += 1
|
||||
@@ -1356,7 +1643,7 @@ async def llm_call_async(
|
||||
try:
|
||||
note_model_activity(target_url, model)
|
||||
client = _get_http_client()
|
||||
r = await client.post(target_url, headers=h, json=payload, timeout=call_timeout)
|
||||
r = await httpx_post_kimi_aware_async(client, target_url, h, json=payload, timeout=call_timeout)
|
||||
duration = time.time() - start
|
||||
if not r.is_success:
|
||||
friendly = _format_upstream_error(r.status_code, r.text, target_url)
|
||||
@@ -1401,7 +1688,7 @@ async def llm_call_async(
|
||||
async def stream_llm(url: str, model: str, messages: List[Dict], temperature: float = LLMConfig.DEFAULT_TEMPERATURE,
|
||||
max_tokens: int = LLMConfig.DEFAULT_MAX_TOKENS, headers: Optional[Dict] = None,
|
||||
timeout: int = LLMConfig.STREAM_TIMEOUT, prompt_type: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None):
|
||||
tools: Optional[List[Dict]] = None, session_id: Optional[str] = None):
|
||||
"""Stream LLM responses with improved error handling.
|
||||
|
||||
Yields SSE chunks:
|
||||
@@ -1452,7 +1739,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
if _restricts_temperature(model):
|
||||
if _omit_temperature(provider, model):
|
||||
payload.pop("temperature", None)
|
||||
if provider not in {"openrouter", "groq"}:
|
||||
payload["stream_options"] = {"include_usage": True}
|
||||
@@ -1461,14 +1748,23 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
payload[tok_key] = max_tokens
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
# For Ollama's OpenAI-compat /v1 endpoint with thinking models (qwen3,
|
||||
# gemma4, etc.), suppress thinking so tool calls aren't swallowed inside
|
||||
# <think> blocks. Ollama /v1 accepts "think": false as a top-level param.
|
||||
if _is_ollama_openai_compat_url(url) and _supports_thinking(model):
|
||||
payload["think"] = False
|
||||
_apply_local_cache_affinity(payload, url, session_id)
|
||||
h = _provider_headers(provider, headers)
|
||||
if provider == "copilot":
|
||||
from src.copilot import apply_request_headers
|
||||
apply_request_headers(h, messages_copy)
|
||||
|
||||
# Short connect timeout: a reachable peer answers SYN in <100ms even on
|
||||
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
|
||||
stream_timeout = httpx.Timeout(connect=3.0, read=float(timeout), write=30.0, pool=5.0)
|
||||
# Connect budget from LLMConfig.CONNECT_TIMEOUT (env LLM_CONNECT_TIMEOUT).
|
||||
# The dead-host cooldown still bounds a genuinely unreachable upstream, so a
|
||||
# wider connect budget only affects first contact and stops a brief cold
|
||||
# connect blip (offshore/public endpoints) surfacing as a 503 on this stream
|
||||
# path, which -- unlike llm_call -- does not retry the connect.
|
||||
stream_timeout = _stream_timeout(timeout)
|
||||
|
||||
if _is_host_dead(target_url):
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Upstream {_host_key(target_url)} unreachable (cooldown active)", "status": 503})}\n\n'
|
||||
@@ -1744,6 +2040,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
events.append(_stream_delta_event(part))
|
||||
return events
|
||||
|
||||
h = apply_kimi_code_headers(h, target_url)
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
|
||||
+82
-34
@@ -5,6 +5,7 @@ Query and cache model context window sizes from OpenAI-compatible APIs.
|
||||
Provides token estimation for context usage tracking.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@@ -19,7 +20,20 @@ _LOCAL_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0", "::1", "host.docker.interna
|
||||
_PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.",
|
||||
"172.25.", "172.26.", "172.27.", "172.28.", "172.29.",
|
||||
"172.30.", "172.31.", "192.168.", "100.")
|
||||
"172.30.", "172.31.", "192.168.")
|
||||
|
||||
# Tailscale uses the CGNAT range 100.64.0.0/10, NOT all of 100.0.0.0/8.
|
||||
# A bare "100." prefix would classify public addresses (e.g. AWS ranges
|
||||
# under 100.x outside the CGNAT block) as local; routes/model_routes.py
|
||||
# already narrows this the same way for endpoint classification.
|
||||
_TAILSCALE_CGNAT = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
|
||||
def _in_tailscale_range(host: str) -> bool:
|
||||
try:
|
||||
return ipaddress.ip_address(host) in _TAILSCALE_CGNAT
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_base_for_compare(url: str) -> str:
|
||||
@@ -64,7 +78,7 @@ def _configured_endpoint_kind(url: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _is_local_endpoint(url: str) -> bool:
|
||||
def is_local_endpoint(url: str) -> bool:
|
||||
"""Check if URL points to a local/private/tailscale address."""
|
||||
kind = _configured_endpoint_kind(url)
|
||||
if kind in ("api", "proxy"):
|
||||
@@ -73,7 +87,7 @@ def _is_local_endpoint(url: str) -> bool:
|
||||
return True
|
||||
try:
|
||||
host = urlparse(url).hostname or ""
|
||||
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES)
|
||||
return host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES) or _in_tailscale_range(host)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -208,7 +222,30 @@ KNOWN_CONTEXT_WINDOWS = {
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
_context_cache: Dict[Tuple[str, str], int] = {}
|
||||
_context_cache: Dict[Tuple[str, str], Tuple[int, bool]] = {}
|
||||
|
||||
|
||||
def _get_context_length_cached(endpoint_url: str, model: str) -> Tuple[int, bool]:
|
||||
"""Return (context_length, known). ``known`` is False only when the value is a
|
||||
bare DEFAULT_CONTEXT fallback (no endpoint report and not in the known table)."""
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
is_local = is_local_endpoint(endpoint_url)
|
||||
# Key on (endpoint_url, model): the same model id can be served by two
|
||||
# different remote endpoints with different real context windows (e.g. a
|
||||
# capped proxy vs. the full provider), so caching by model id alone would
|
||||
# serve one endpoint's window for the other (issue #2603).
|
||||
cache_key = (endpoint_url, model)
|
||||
if not is_local and cache_key in _context_cache:
|
||||
return _context_cache[cache_key]
|
||||
|
||||
ctx, known = _query_context_length(endpoint_url, model)
|
||||
# Only cache non-default values to allow retry on next request.
|
||||
# Local endpoints can restart with a different --max-model-len while keeping
|
||||
# the same model id, so always re-query them instead of serving stale cache.
|
||||
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
||||
_context_cache[cache_key] = (ctx, known)
|
||||
logger.info(f"Context length for {model}: {ctx}")
|
||||
return ctx, known
|
||||
|
||||
|
||||
def get_context_length(endpoint_url: str, model: str) -> int:
|
||||
@@ -218,24 +255,33 @@ def get_context_length(endpoint_url: str, model: str) -> int:
|
||||
or context_window fields. Caches result per (endpoint, model).
|
||||
Falls back to DEFAULT_CONTEXT if unavailable.
|
||||
"""
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
is_local = _is_local_endpoint(endpoint_url)
|
||||
# Key on (endpoint_url, model): the same model id can be served by two
|
||||
# different remote endpoints with different real context windows (e.g. a
|
||||
# capped proxy vs. the full provider), so caching by model id alone would
|
||||
# serve one endpoint's window for the other (issue #2603).
|
||||
cache_key = (endpoint_url, model)
|
||||
if not is_local and cache_key in _context_cache:
|
||||
return _context_cache[cache_key]
|
||||
return _get_context_length_cached(endpoint_url, model)[0]
|
||||
|
||||
ctx = _query_context_length(endpoint_url, model)
|
||||
# Only cache non-default values to allow retry on next request.
|
||||
# Local endpoints can restart with a different --max-model-len while keeping
|
||||
# the same model id, so always re-query them instead of serving stale cache.
|
||||
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
||||
_context_cache[cache_key] = ctx
|
||||
logger.info(f"Context length for {model}: {ctx}")
|
||||
return ctx
|
||||
|
||||
def get_context_length_known(endpoint_url: str, model: str) -> Tuple[int, bool]:
|
||||
"""Like ``get_context_length`` but also returns whether the window was actually
|
||||
discovered (endpoint-reported or in the known-models table) rather than the bare
|
||||
DEFAULT_CONTEXT fallback. Callers that *scale* a budget off the window must not
|
||||
trust an unknown value — a fallback 128K isn't proof the model holds 128K
|
||||
(review on #4122)."""
|
||||
return _get_context_length_cached(endpoint_url, model)
|
||||
|
||||
|
||||
def budget_context_for_model(endpoint_url: str, model: str, *, fallback: int = 0) -> int:
|
||||
"""Context window to scale the agent input budget against.
|
||||
|
||||
Returns the *freshly discovered* window when it was actually proven
|
||||
(endpoint-reported / known table), else 0 so auto-scaling stays conservative.
|
||||
Crucially this binds the ``known`` flag to the value it proves — callers must
|
||||
not pair this flag with a context length from a *different* lookup (a stale
|
||||
local re-query, or a caller that didn't pass one), which would budget off an
|
||||
unproven number (review on #4122). On probe error, returns ``fallback`` (the
|
||||
caller's best-known value) to preserve prior behaviour."""
|
||||
try:
|
||||
ctx, known = get_context_length_known(endpoint_url, model)
|
||||
return ctx if known else 0
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
|
||||
def _lookup_known(model: str) -> Optional[int]:
|
||||
@@ -257,8 +303,9 @@ def _lookup_known(model: str) -> Optional[int]:
|
||||
return best_ctx
|
||||
|
||||
|
||||
def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
"""Query the model API for context length."""
|
||||
def _query_context_length(endpoint_url: str, model: str) -> Tuple[int, bool]:
|
||||
"""Query the model API for context length. Returns (context_length, known) where
|
||||
``known`` is False only for the bare DEFAULT_CONTEXT fallback."""
|
||||
known = _lookup_known(model)
|
||||
api_ctx = None
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
@@ -269,11 +316,11 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
if configured_kind in ("api", "proxy"):
|
||||
if known:
|
||||
logger.info(f"Using known context window for {model}: {known}")
|
||||
return known
|
||||
return DEFAULT_CONTEXT
|
||||
return known, True
|
||||
return DEFAULT_CONTEXT, False
|
||||
|
||||
# Try llama.cpp /slots endpoint first — reports actual serving context
|
||||
if _is_local_endpoint(endpoint_url):
|
||||
if is_local_endpoint(endpoint_url):
|
||||
try:
|
||||
base = endpoint_url.split("/v1")[0] if "/v1" in endpoint_url else endpoint_url.rsplit("/", 1)[0]
|
||||
r = httpx.get(f"{base}/slots", timeout=REQUEST_TIMEOUT)
|
||||
@@ -283,7 +330,7 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
n_ctx = slots[0].get("n_ctx")
|
||||
if n_ctx and isinstance(n_ctx, int) and n_ctx > 0:
|
||||
logger.info(f"llama.cpp /slots reports n_ctx={n_ctx} for {model}")
|
||||
return n_ctx
|
||||
return n_ctx, True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -295,7 +342,8 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
if is_copilot_base(endpoint_url):
|
||||
if known:
|
||||
logger.info(f"Using known context window for {model}: {known}")
|
||||
return known or DEFAULT_CONTEXT
|
||||
return known, True
|
||||
return DEFAULT_CONTEXT, False
|
||||
|
||||
from src.endpoint_resolver import build_models_url
|
||||
|
||||
@@ -337,21 +385,21 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
# For local/self-hosted endpoints, trust the API value (user set --max-model-len)
|
||||
# For cloud APIs, use the larger value (API can report low defaults)
|
||||
if api_ctx and known:
|
||||
_is_local = _is_local_endpoint(endpoint_url)
|
||||
_is_local = is_local_endpoint(endpoint_url)
|
||||
if _is_local and api_ctx < known:
|
||||
logger.info(f"Local endpoint reports {api_ctx} for {model} (known max: {known}) — using API value")
|
||||
return api_ctx
|
||||
return api_ctx, True
|
||||
result = max(api_ctx, known)
|
||||
if api_ctx < known:
|
||||
logger.info(f"API reported {api_ctx} for {model}, using known {known} instead")
|
||||
return result
|
||||
return result, True
|
||||
if api_ctx:
|
||||
return api_ctx
|
||||
return api_ctx, True
|
||||
if known:
|
||||
logger.info(f"Using known context window for {model}: {known}")
|
||||
return known
|
||||
return known, True
|
||||
|
||||
return DEFAULT_CONTEXT
|
||||
return DEFAULT_CONTEXT, False
|
||||
|
||||
|
||||
def estimate_tokens(messages: List[Dict]) -> int:
|
||||
|
||||
@@ -223,6 +223,25 @@ class ModelDiscovery:
|
||||
)
|
||||
return {"hosts": hosts, "items": items}
|
||||
|
||||
def warmup_ping_urls(self, limit: int = 5) -> List[str]:
|
||||
"""The ``/models`` URLs of up to ``limit`` discovered endpoints.
|
||||
|
||||
Used by the startup warmup / keepalive loop to prime connections. Each
|
||||
discovered item already carries a ``/v1/chat/completions`` url; swap the
|
||||
suffix for the cheap ``/models`` probe. Failures degrade to an empty list
|
||||
so warmup never crashes the caller.
|
||||
"""
|
||||
try:
|
||||
items = (self.discover_models() or {}).get("items", [])
|
||||
except Exception:
|
||||
return []
|
||||
urls: List[str] = []
|
||||
for ep in items[:limit]:
|
||||
url = (ep.get("url") or "").replace("/chat/completions", "/models")
|
||||
if url:
|
||||
urls.append(url)
|
||||
return urls
|
||||
|
||||
def get_providers(self) -> Dict[str, Any]:
|
||||
"""Get all available providers"""
|
||||
discovery = self.discover_models()
|
||||
|
||||
+1
-1
@@ -32,7 +32,7 @@ def create_office_document(
|
||||
DocumentVersion,
|
||||
Session as DbSession,
|
||||
)
|
||||
from src.tool_implementations import set_active_document
|
||||
from src.agent_tools.document_tools import set_active_document
|
||||
|
||||
if not body_text or not body_text.strip():
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Compatibility helpers for optional third-party dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
|
||||
def patch_realesrgan_torchvision_compat() -> None:
|
||||
"""Restore the torchvision import path expected by BasicSR/Real-ESRGAN."""
|
||||
module_name = "torchvision.transforms.functional_tensor"
|
||||
if module_name in sys.modules:
|
||||
return
|
||||
try:
|
||||
from torchvision.transforms import functional
|
||||
except Exception:
|
||||
return
|
||||
|
||||
rgb_to_grayscale = getattr(functional, "rgb_to_grayscale", None)
|
||||
if rgb_to_grayscale is None:
|
||||
return
|
||||
|
||||
shim = types.ModuleType(module_name)
|
||||
shim.rgb_to_grayscale = rgb_to_grayscale
|
||||
shim.__getattr__ = lambda name: getattr(functional, name)
|
||||
sys.modules[module_name] = shim
|
||||
|
||||
|
||||
def prepare_optional_dependency_import(name: str) -> None:
|
||||
"""Apply known import-time compatibility shims before probing a package."""
|
||||
if name == "realesrgan":
|
||||
patch_realesrgan_torchvision_compat()
|
||||
+2
-2
@@ -219,7 +219,7 @@ def create_plain_pdf_document(
|
||||
pages without form-field overlays.
|
||||
"""
|
||||
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
|
||||
from src.tool_implementations import set_active_document
|
||||
from src.agent_tools.document_tools import set_active_document
|
||||
|
||||
content = render_plain_pdf_markdown(upload_id, title, body_text)
|
||||
db = SessionLocal()
|
||||
@@ -402,7 +402,7 @@ def create_form_markdown_document(
|
||||
inside the content, which the export route looks for.
|
||||
"""
|
||||
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
|
||||
from src.tool_implementations import set_active_document
|
||||
from src.agent_tools.document_tools import set_active_document
|
||||
|
||||
content = render_form_as_markdown(fields, upload_id, title, intro_text=intro_text)
|
||||
db = SessionLocal()
|
||||
|
||||
+24
-1
@@ -221,6 +221,22 @@ class ResearchHandler:
|
||||
# Task registry — background research with persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||
"""Move in-flight research tasks from one owner key to another."""
|
||||
old_key = str(old_owner or "").strip().lower()
|
||||
new_key = str(new_owner or "").strip().lower()
|
||||
if not old_key or not new_key:
|
||||
return 0
|
||||
|
||||
changed = 0
|
||||
for entry in list(self._active_tasks.values()):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if str(entry.get("owner", "")).strip().lower() == old_key:
|
||||
entry["owner"] = new_key
|
||||
changed += 1
|
||||
return changed
|
||||
|
||||
def start_research(
|
||||
self,
|
||||
session_id: str,
|
||||
@@ -390,7 +406,6 @@ class ResearchHandler:
|
||||
|
||||
def get_status(self, session_id: str) -> Optional[dict]:
|
||||
"""Get current research status for a session."""
|
||||
avg = self.get_avg_duration()
|
||||
if session_id in self._active_tasks:
|
||||
entry = self._active_tasks[session_id]
|
||||
result = {
|
||||
@@ -399,6 +414,14 @@ class ResearchHandler:
|
||||
"query": entry["query"],
|
||||
"started_at": entry["started_at"],
|
||||
}
|
||||
# avg_duration is a historical figure over completed reports on
|
||||
# disk; get_avg_duration() globs and JSON-parses the whole research
|
||||
# dir, so compute it at most once per active stream (memoized on the
|
||||
# entry) instead of on every ~1s SSE poll. The disk branch below
|
||||
# never used it, so it no longer pays that cost at all.
|
||||
if "_avg_duration" not in entry:
|
||||
entry["_avg_duration"] = self.get_avg_duration()
|
||||
avg = entry["_avg_duration"]
|
||||
if avg is not None:
|
||||
result["avg_duration"] = round(avg, 1)
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,506 @@
|
||||
"""Consolidated service health / degraded-state reporting.
|
||||
|
||||
ROADMAP: "Better degraded-state reporting for ChromaDB, SearXNG, email, ntfy,
|
||||
and provider probes." There was no single readout of which subsystems are
|
||||
actually working — `/api/health` is only a liveness ping and each subsystem's
|
||||
signal lives in a different module. This collects them into one uniform,
|
||||
*non-intrusive* report (no test push is sent, no real search is run), so the
|
||||
admin endpoint built on top of it is safe to poll.
|
||||
|
||||
Each probe returns:
|
||||
|
||||
{"name": str, "status": "ok"|"degraded"|"down"|"disabled",
|
||||
"detail": str, "meta": dict}
|
||||
|
||||
- ok — reachable / working
|
||||
- degraded — partially working (one of several components down)
|
||||
- down — configured & enabled but unreachable / erroring
|
||||
- disabled — not configured or turned off (not counted as a failure)
|
||||
|
||||
Design notes (driven by review feedback):
|
||||
|
||||
- **Bounded wall-clock.** Per-item probes (providers, email accounts) fan out
|
||||
across a bounded thread pool with a hard total budget (`_FANOUT_BUDGET`);
|
||||
stragglers are reported as a controlled `timeout` rather than blocking. The
|
||||
aggregate adds a per-subsystem deadline (`_SUBSYSTEM_DEADLINE`) and an overall
|
||||
ceiling (`_AGGREGATE_DEADLINE`), so the endpoint cannot hang regardless of how
|
||||
many endpoints/accounts are configured or how slowly they respond.
|
||||
- **No secret leakage.** Even though the endpoint is admin-only, the response
|
||||
never returns credential-bearing URLs or raw exception text: URLs are passed
|
||||
through `_safe_url` (userinfo / query / fragment stripped) and failures are
|
||||
mapped to controlled categories via `_classify_error`.
|
||||
|
||||
The probe functions take their inputs as parameters (settings dict, account
|
||||
list, endpoint list, manager objects) and isolate the network call to
|
||||
``_http_get`` / injected callables, so they unit-test without touching the
|
||||
network.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Status ordering for rolling up an overall verdict. "disabled" is excluded —
|
||||
# a turned-off feature must never drag the overall status down.
|
||||
_SEVERITY = {"ok": 0, "degraded": 1, "down": 2}
|
||||
|
||||
OK = "ok"
|
||||
DEGRADED = "degraded"
|
||||
DOWN = "down"
|
||||
DISABLED = "disabled"
|
||||
|
||||
# Timing budgets (seconds). _PROBE_TIMEOUT bounds a single network op;
|
||||
# _FANOUT_BUDGET bounds a whole fan-out (providers/email) regardless of count;
|
||||
# the aggregate layer adds a per-subsystem deadline and an overall ceiling.
|
||||
_PROBE_TIMEOUT = 4
|
||||
_PROBE_CONCURRENCY = 8
|
||||
_FANOUT_BUDGET = 8
|
||||
_SUBSYSTEM_DEADLINE = 10
|
||||
_AGGREGATE_DEADLINE = 14
|
||||
|
||||
# Controlled, secret-free phrasing for each failure category.
|
||||
_ERROR_DETAIL = {
|
||||
"timeout": "probe timed out",
|
||||
"connection_refused": "connection refused",
|
||||
"dns_error": "host could not be resolved",
|
||||
"tls_error": "TLS handshake failed",
|
||||
"network_error": "network error",
|
||||
"http_error": "server returned an error response",
|
||||
"auth_or_protocol_error": "authentication or protocol error",
|
||||
"no_models": "endpoint returned no models",
|
||||
"no_host": "no host configured",
|
||||
"error": "probe failed",
|
||||
}
|
||||
|
||||
|
||||
def _svc(name: str, status: str, detail: str, **meta: Any) -> Dict[str, Any]:
|
||||
return {"name": name, "status": status, "detail": detail, "meta": dict(meta)}
|
||||
|
||||
|
||||
def _safe_url(url: Optional[str]) -> str:
|
||||
"""Strip credentials (userinfo), query, and fragment from a URL.
|
||||
|
||||
Keeps scheme / host / port / path so the report is still useful, but never
|
||||
echoes `user:pass@`, `?api_key=…`, or `#…` back to the caller. Returns
|
||||
"<redacted>" if the URL can't be parsed into at least a host.
|
||||
"""
|
||||
if not url:
|
||||
return ""
|
||||
raw = url.strip()
|
||||
try:
|
||||
p = urlparse(raw if "://" in raw else "//" + raw)
|
||||
host = p.hostname or ""
|
||||
if not host:
|
||||
return "<redacted>"
|
||||
netloc = f"{host}:{p.port}" if p.port else host
|
||||
path = (p.path or "").rstrip("/")
|
||||
scheme = f"{p.scheme}://" if p.scheme else ""
|
||||
return f"{scheme}{netloc}{path}"
|
||||
except Exception:
|
||||
return "<redacted>"
|
||||
|
||||
|
||||
def _classify_error(exc: BaseException) -> str:
|
||||
"""Map an exception to a controlled, secret-free category token.
|
||||
|
||||
Never returns `str(exc)` — httpx/imaplib exception text can embed the target
|
||||
URL (which may carry credentials) or server-supplied detail.
|
||||
"""
|
||||
if isinstance(exc, (asyncio.TimeoutError, concurrent.futures.TimeoutError,
|
||||
TimeoutError, socket.timeout)):
|
||||
return "timeout"
|
||||
name = type(exc).__name__
|
||||
mod = (type(exc).__module__ or "")
|
||||
if isinstance(exc, ssl.SSLError) or "SSL" in name or "Certificate" in name:
|
||||
return "tls_error"
|
||||
if isinstance(exc, socket.gaierror) or name in ("gaierror", "herror"):
|
||||
return "dns_error"
|
||||
if isinstance(exc, ConnectionRefusedError) or "ConnectionRefused" in name \
|
||||
or name in ("ConnectError",):
|
||||
return "connection_refused"
|
||||
if "Timeout" in name:
|
||||
return "timeout"
|
||||
if mod.startswith("imaplib") or name in ("error", "abort", "readonly"):
|
||||
return "auth_or_protocol_error"
|
||||
if name == "HTTPStatusError":
|
||||
return "http_error"
|
||||
if name in ("ConnectTimeout", "ReadTimeout", "ReadError", "WriteError",
|
||||
"PoolTimeout", "RemoteProtocolError", "NetworkError",
|
||||
"ProxyError", "ProtocolError"):
|
||||
return "network_error"
|
||||
if isinstance(exc, OSError):
|
||||
return "network_error"
|
||||
return "error"
|
||||
|
||||
|
||||
def _detail_for(category: str) -> str:
|
||||
return _ERROR_DETAIL.get(category, _ERROR_DETAIL["error"])
|
||||
|
||||
|
||||
def _http_get(url: str, timeout: float = _PROBE_TIMEOUT):
|
||||
"""Single network entry point for the HTTP probes (monkeypatched in tests)."""
|
||||
import httpx
|
||||
return httpx.get(url, timeout=timeout)
|
||||
|
||||
|
||||
def _bounded_map(items: List[Any], worker: Callable[[int, Any], Dict[str, Any]],
|
||||
*, budget: float = _FANOUT_BUDGET,
|
||||
concurrency: int = _PROBE_CONCURRENCY) -> List[Optional[Dict[str, Any]]]:
|
||||
"""Run ``worker(index, item)`` across a bounded thread pool, in order.
|
||||
|
||||
`worker` must catch its own exceptions and return a per-item dict. Any item
|
||||
not finished within `budget` seconds *in total* is left as ``None`` (the
|
||||
caller substitutes a controlled `timeout` entry). The pool is shut down with
|
||||
``wait=False`` so stragglers never block the response — their own per-op
|
||||
timeout reaps them shortly after.
|
||||
"""
|
||||
n = len(items)
|
||||
out: List[Optional[Dict[str, Any]]] = [None] * n
|
||||
if n == 0:
|
||||
return out
|
||||
ex = concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(concurrency, n)))
|
||||
futures = {ex.submit(worker, i, items[i]): i for i in range(n)}
|
||||
try:
|
||||
for fut in concurrent.futures.as_completed(futures, timeout=budget):
|
||||
i = futures[fut]
|
||||
try:
|
||||
out[i] = fut.result()
|
||||
except Exception as e: # worker is expected to handle its own errors
|
||||
out[i] = {"ok": False, "error": _classify_error(e)}
|
||||
except concurrent.futures.TimeoutError:
|
||||
pass # unfinished items stay None → marked timeout by the caller
|
||||
finally:
|
||||
ex.shutdown(wait=False, cancel_futures=True)
|
||||
return out
|
||||
|
||||
|
||||
# ── ChromaDB (vector RAG + vector memory) ──
|
||||
|
||||
def chromadb_health(rag_manager: Any, memory_vector: Any) -> Dict[str, Any]:
|
||||
"""Report on the two ChromaDB-backed stores via their `.healthy` flags.
|
||||
|
||||
Both absent → disabled (Chroma/embeddings not installed or off).
|
||||
Both healthy → ok. One down → degraded. Both present but unhealthy → down.
|
||||
"""
|
||||
rag_present = rag_manager is not None
|
||||
mem_present = memory_vector is not None
|
||||
if not rag_present and not mem_present:
|
||||
return _svc("chromadb", DISABLED,
|
||||
"Vector RAG and vector memory are not initialized.",
|
||||
rag=None, memory=None)
|
||||
|
||||
rag_ok = bool(rag_present and getattr(rag_manager, "healthy", False))
|
||||
mem_ok = bool(mem_present and getattr(memory_vector, "healthy", False))
|
||||
meta = {"rag": rag_ok if rag_present else None,
|
||||
"memory": mem_ok if mem_present else None}
|
||||
|
||||
healthy = [ok for ok in (rag_ok if rag_present else None,
|
||||
mem_ok if mem_present else None) if ok is not None]
|
||||
if healthy and all(healthy):
|
||||
return _svc("chromadb", OK, "Vector stores healthy.", **meta)
|
||||
if any(healthy):
|
||||
return _svc("chromadb", DEGRADED,
|
||||
"One vector store is unavailable.", **meta)
|
||||
return _svc("chromadb", DOWN, "Vector stores are unavailable.", **meta)
|
||||
|
||||
|
||||
# ── SearXNG ──
|
||||
|
||||
def _searxng_instance(settings: Dict[str, Any]) -> str:
|
||||
"""Mirror src/search/providers.py:_get_search_instance precedence."""
|
||||
url = (settings.get("search_url") or "").strip()
|
||||
if url:
|
||||
return url.rstrip("/")
|
||||
from src.constants import SEARXNG_INSTANCE
|
||||
return SEARXNG_INSTANCE.rstrip("/")
|
||||
|
||||
|
||||
def searxng_health(settings: Dict[str, Any],
|
||||
*, http_get: Callable = _http_get) -> Dict[str, Any]:
|
||||
"""Non-intrusive reachability probe for the configured SearXNG instance.
|
||||
|
||||
Tries `/healthz` (2xx), falling back to the instance root (any non-5xx means
|
||||
the host answered). No search query is run. The configured instance is
|
||||
probed in full, but only its sanitized form is returned in `meta`.
|
||||
"""
|
||||
provider = (settings.get("search_provider") or "searxng")
|
||||
if provider != "searxng":
|
||||
return _svc("searxng", DISABLED,
|
||||
f"Search provider is '{provider}', not SearXNG.",
|
||||
provider=provider)
|
||||
instance = _searxng_instance(settings)
|
||||
if not instance:
|
||||
return _svc("searxng", DISABLED, "No SearXNG instance configured.")
|
||||
safe_instance = _safe_url(instance)
|
||||
last_category = "error"
|
||||
for path, accept in (("/healthz", lambda c: 200 <= c < 300),
|
||||
("/", lambda c: 0 < c < 500)):
|
||||
try:
|
||||
r = http_get(instance + path, timeout=_PROBE_TIMEOUT)
|
||||
code = getattr(r, "status_code", 0)
|
||||
if accept(code):
|
||||
return _svc("searxng", OK, f"Reachable (HTTP {code}).",
|
||||
instance=safe_instance, probed=path, http_status=code)
|
||||
last_category = "http_error"
|
||||
except Exception as e: # connection refused, DNS, timeout, …
|
||||
last_category = _classify_error(e)
|
||||
return _svc("searxng", DOWN, f"Unreachable ({_detail_for(last_category)}).",
|
||||
instance=safe_instance, error=last_category)
|
||||
|
||||
|
||||
# ── ntfy ──
|
||||
|
||||
def _ntfy_integration(integrations: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""First enabled ntfy integration with a base_url (matches note_routes)."""
|
||||
for i in integrations or []:
|
||||
if (i.get("preset") == "ntfy" and i.get("enabled", True)
|
||||
and i.get("base_url")):
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def ntfy_health(integrations: List[Dict[str, Any]], settings: Dict[str, Any],
|
||||
*, http_get: Callable = _http_get) -> Dict[str, Any]:
|
||||
"""Non-intrusive ntfy probe via the server's built-in `/v1/health` route.
|
||||
|
||||
No test notification is POSTed — `/v1/health` returns `{"healthy":true}`
|
||||
without publishing to a topic. The request keeps whatever credentials the
|
||||
configured base_url carries, but `meta.base` is sanitized.
|
||||
"""
|
||||
channel = settings.get("reminder_channel") or "browser"
|
||||
intg = _ntfy_integration(integrations)
|
||||
if not intg:
|
||||
return _svc("ntfy", DISABLED, "No ntfy integration configured.",
|
||||
reminder_channel=channel)
|
||||
raw = (intg.get("base_url") or "").strip()
|
||||
parsed = urlparse(raw)
|
||||
probe_base = (f"{parsed.scheme}://{parsed.netloc}"
|
||||
if parsed.scheme and parsed.netloc else raw.rstrip("/"))
|
||||
safe_base = _safe_url(raw)
|
||||
try:
|
||||
r = http_get(probe_base + "/v1/health", timeout=_PROBE_TIMEOUT)
|
||||
code = getattr(r, "status_code", 0)
|
||||
if code and code < 500:
|
||||
return _svc("ntfy", OK, f"Reachable (HTTP {code}).",
|
||||
base=safe_base, reminder_channel=channel, http_status=code)
|
||||
return _svc("ntfy", DOWN, "Server returned an error response.",
|
||||
base=safe_base, reminder_channel=channel, error="http_error")
|
||||
except Exception as e:
|
||||
category = _classify_error(e)
|
||||
return _svc("ntfy", DOWN, f"Unreachable ({_detail_for(category)}).",
|
||||
base=safe_base, reminder_channel=channel, error=category)
|
||||
|
||||
|
||||
# ── Email (IMAP) ──
|
||||
|
||||
def email_health(accounts: List[Dict[str, Any]],
|
||||
*, connect: Optional[Callable] = None) -> Dict[str, Any]:
|
||||
"""Try a short IMAP connect+logout per configured account, concurrently.
|
||||
|
||||
All connect → ok. Some fail → degraded. All fail → down. No account
|
||||
configured → disabled. Bounded by `_FANOUT_BUDGET` regardless of count.
|
||||
`meta` carries only the account label and a controlled error category —
|
||||
never credentials or raw exception text.
|
||||
"""
|
||||
if not accounts:
|
||||
return _svc("email", DISABLED, "No email accounts configured.")
|
||||
if connect is None:
|
||||
from routes.email_helpers import _imap_connect
|
||||
# Impose the service-health budget on the IMAP connect itself.
|
||||
connect = lambda aid: _imap_connect(aid, timeout=_PROBE_TIMEOUT) # noqa: E731
|
||||
|
||||
def _label(acc: Dict[str, Any]) -> str:
|
||||
return acc.get("account_name") or acc.get("account_id") or "account"
|
||||
|
||||
def _check(_i: int, acc: Dict[str, Any]) -> Dict[str, Any]:
|
||||
name = _label(acc)
|
||||
if not (acc.get("imap_host") or ""):
|
||||
return {"name": name, "ok": False, "error": "no_host"}
|
||||
try:
|
||||
conn = connect(acc.get("account_id"))
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
return {"name": name, "ok": True, "error": None}
|
||||
except Exception as e:
|
||||
return {"name": name, "ok": False, "error": _classify_error(e)}
|
||||
|
||||
raw = _bounded_map(accounts, _check, budget=_FANOUT_BUDGET,
|
||||
concurrency=_PROBE_CONCURRENCY)
|
||||
per_account = [r if r is not None
|
||||
else {"name": _label(accounts[i]), "ok": False, "error": "timeout"}
|
||||
for i, r in enumerate(raw)]
|
||||
return _rollup_items("email", "mailbox(es)", per_account)
|
||||
|
||||
|
||||
# ── Provider endpoints ──
|
||||
|
||||
def providers_health(endpoints: List[Dict[str, Any]],
|
||||
*, probe: Optional[Callable] = None) -> Dict[str, Any]:
|
||||
"""Probe each enabled model endpoint's model list, concurrently.
|
||||
|
||||
`endpoints` is a list of plain dicts ({name, base_url, api_key}) so this
|
||||
stays decoupled from the ORM and trivially testable. Non-empty model list
|
||||
→ reachable. Bounded by `_FANOUT_BUDGET` regardless of count. `meta` never
|
||||
contains api_key or raw URLs — only a display name (or a sanitized URL when
|
||||
no name is set) and a controlled error category.
|
||||
"""
|
||||
if not endpoints:
|
||||
return _svc("providers", DISABLED, "No model endpoints configured.")
|
||||
if probe is None:
|
||||
from routes.model_routes import _probe_endpoint as probe
|
||||
|
||||
def _label(ep: Dict[str, Any]) -> str:
|
||||
return ep.get("name") or _safe_url(ep.get("base_url")) or "endpoint"
|
||||
|
||||
def _check(_i: int, ep: Dict[str, Any]) -> Dict[str, Any]:
|
||||
name = _label(ep)
|
||||
try:
|
||||
models = probe(ep.get("base_url"), ep.get("api_key"),
|
||||
timeout=_PROBE_TIMEOUT) or []
|
||||
except Exception as e:
|
||||
return {"name": name, "ok": False, "model_count": 0,
|
||||
"error": _classify_error(e)}
|
||||
count = len(models)
|
||||
return {"name": name, "ok": bool(count), "model_count": count,
|
||||
"error": None if count else "no_models"}
|
||||
|
||||
raw = _bounded_map(endpoints, _check, budget=_FANOUT_BUDGET,
|
||||
concurrency=_PROBE_CONCURRENCY)
|
||||
per_endpoint = [r if r is not None
|
||||
else {"name": _label(endpoints[i]), "ok": False,
|
||||
"model_count": 0, "error": "timeout"}
|
||||
for i, r in enumerate(raw)]
|
||||
return _rollup_items("providers", "endpoint(s)", per_endpoint, key="endpoints")
|
||||
|
||||
|
||||
def _rollup_items(name: str, noun: str, items: List[Dict[str, Any]],
|
||||
key: str = "accounts") -> Dict[str, Any]:
|
||||
"""Shared ok/degraded/down rollup for a list of per-item probe results."""
|
||||
total = len(items)
|
||||
ok_count = sum(1 for it in items if it.get("ok"))
|
||||
if ok_count == total:
|
||||
status, detail = OK, f"{ok_count}/{total} {noun} reachable."
|
||||
elif ok_count == 0:
|
||||
status, detail = DOWN, f"No {noun} reachable."
|
||||
else:
|
||||
status, detail = DEGRADED, f"{ok_count}/{total} {noun} reachable."
|
||||
return _svc(name, status, detail, **{key: items})
|
||||
|
||||
|
||||
# ── Aggregate ──
|
||||
|
||||
def _rollup(services: List[Dict[str, Any]]) -> str:
|
||||
worst = OK
|
||||
for s in services:
|
||||
sev = _SEVERITY.get(s.get("status"))
|
||||
if sev is not None and sev > _SEVERITY[worst]:
|
||||
worst = s["status"]
|
||||
return worst
|
||||
|
||||
|
||||
def _gather_inputs() -> Dict[str, Any]:
|
||||
"""Pull live config/account/endpoint lists from the app's data sources.
|
||||
|
||||
Each lookup fails soft: a broken source yields an empty/neutral value so a
|
||||
single failure can't take down the whole health report.
|
||||
"""
|
||||
settings: Dict[str, Any] = {}
|
||||
integrations: List[Dict[str, Any]] = []
|
||||
accounts: List[Dict[str, Any]] = []
|
||||
endpoints: List[Dict[str, Any]] = []
|
||||
try:
|
||||
from src.settings import load_settings
|
||||
settings = load_settings() or {}
|
||||
except Exception as e:
|
||||
logger.debug(f"service_health: settings load failed: {e}")
|
||||
try:
|
||||
from src.integrations import load_integrations
|
||||
integrations = load_integrations() or []
|
||||
except Exception as e:
|
||||
logger.debug(f"service_health: integrations load failed: {e}")
|
||||
try:
|
||||
from routes.email_helpers import _list_email_accounts
|
||||
accounts = _list_email_accounts() or []
|
||||
except Exception as e:
|
||||
logger.debug(f"service_health: email accounts load failed: {e}")
|
||||
try:
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True).all() # noqa: E712
|
||||
endpoints = [{"name": r.name, "base_url": r.base_url,
|
||||
"api_key": r.api_key} for r in rows]
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"service_health: endpoint load failed: {e}")
|
||||
return {"settings": settings, "integrations": integrations,
|
||||
"accounts": accounts, "endpoints": endpoints}
|
||||
|
||||
|
||||
async def _run_subsystem(name: str, fn: Callable, *args: Any) -> Dict[str, Any]:
|
||||
"""Run one (sync) subsystem probe in a thread under a hard deadline.
|
||||
|
||||
A subsystem that overruns `_SUBSYSTEM_DEADLINE` (or raises) becomes a
|
||||
controlled `down`/`timeout` entry instead of hanging or leaking the error.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(asyncio.to_thread(fn, *args),
|
||||
timeout=_SUBSYSTEM_DEADLINE)
|
||||
except asyncio.TimeoutError:
|
||||
return _svc(name, DOWN, _detail_for("timeout"), error="timeout")
|
||||
except Exception as e:
|
||||
category = _classify_error(e)
|
||||
return _svc(name, DOWN, _detail_for(category), error=category)
|
||||
|
||||
|
||||
async def collect_service_health(rag_manager: Any = None,
|
||||
memory_vector: Any = None) -> Dict[str, Any]:
|
||||
"""Run every probe and return {overall, services, timestamp}.
|
||||
|
||||
Bounded end-to-end: in-process ChromaDB flags are read synchronously; the
|
||||
four network subsystems run concurrently, each under `_SUBSYSTEM_DEADLINE`,
|
||||
with an overall `_AGGREGATE_DEADLINE` backstop. Per-item probes inside
|
||||
providers/email are themselves bounded by `_FANOUT_BUDGET`.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
inputs = _gather_inputs()
|
||||
settings = inputs["settings"]
|
||||
|
||||
# ChromaDB is in-process and synchronous (just reads flags).
|
||||
chroma = chromadb_health(rag_manager, memory_vector)
|
||||
|
||||
names = ["searxng", "ntfy", "email", "providers"]
|
||||
coros = [
|
||||
_run_subsystem("searxng", searxng_health, settings),
|
||||
_run_subsystem("ntfy", ntfy_health, inputs["integrations"], settings),
|
||||
_run_subsystem("email", email_health, inputs["accounts"]),
|
||||
_run_subsystem("providers", providers_health, inputs["endpoints"]),
|
||||
]
|
||||
try:
|
||||
results = await asyncio.wait_for(asyncio.gather(*coros),
|
||||
timeout=_AGGREGATE_DEADLINE)
|
||||
except asyncio.TimeoutError:
|
||||
# Hard backstop — should not normally fire given per-subsystem deadlines.
|
||||
results = [_svc(n, DOWN, _detail_for("timeout"), error="timeout")
|
||||
for n in names]
|
||||
|
||||
services = [chroma, *results]
|
||||
return {
|
||||
"overall": _rollup(services),
|
||||
"services": services,
|
||||
# Timezone-aware UTC (…+00:00). Avoids the deprecated naive
|
||||
# datetime.utcnow() flagged in review (overlaps with #1116).
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
+23
-11
@@ -214,6 +214,24 @@ def _search_like(
|
||||
return _rows_to_results(db, shaped, query, context_messages)
|
||||
|
||||
|
||||
def _fetch_messages_by_id(db, message_ids):
|
||||
"""Fetch (message, session_name) for many message ids in a single query.
|
||||
|
||||
The FTS search returns a list of hit ids; fetching each row on its own was an
|
||||
N+1 query (one SELECT per hit). Batch them with one IN(...) query and return
|
||||
a lookup so the caller can reassemble results in hit (relevance) order.
|
||||
"""
|
||||
if not message_ids:
|
||||
return {}
|
||||
rows = (
|
||||
db.query(DBChatMessage, DBSession.name)
|
||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
||||
.filter(DBChatMessage.id.in_(message_ids))
|
||||
.all()
|
||||
)
|
||||
return {msg.id: (msg, session_name) for msg, session_name in rows}
|
||||
|
||||
|
||||
def _search_fts(
|
||||
db,
|
||||
query: str,
|
||||
@@ -267,19 +285,13 @@ def _search_fts(
|
||||
if not hits:
|
||||
return None
|
||||
|
||||
by_id = _fetch_messages_by_id(db, [hit[0] for hit in hits])
|
||||
rows = []
|
||||
for hit in hits:
|
||||
message_id = hit[0]
|
||||
snippet = hit[1] or ""
|
||||
row = (
|
||||
db.query(DBChatMessage, DBSession.name)
|
||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
||||
.filter(DBChatMessage.id == message_id)
|
||||
.first()
|
||||
)
|
||||
if row:
|
||||
msg, session_name = row
|
||||
rows.append((msg, session_name, snippet))
|
||||
found = by_id.get(hit[0])
|
||||
if found:
|
||||
msg, session_name = found
|
||||
rows.append((msg, session_name, hit[1] or ""))
|
||||
return _rows_to_results(db, rows, query, context_messages)
|
||||
|
||||
|
||||
|
||||
+19
-9
@@ -109,14 +109,22 @@ DEFAULT_SETTINGS = {
|
||||
"research_run_timeout_seconds": 1800,
|
||||
"agent_max_tool_calls": 0,
|
||||
"agent_max_rounds": 20, # per-message agent step cap (clamped 1..200)
|
||||
# Soft input-token budget for the agent loop. The DEFAULT value (6000) is the
|
||||
# "auto" sentinel: it means "scale the budget to the model's context window"
|
||||
# (#1230) — so long-context models aren't capped at 6000. Set ANY OTHER value
|
||||
# to enforce an explicit cap (clamped to the window only — hard_max does not
|
||||
# apply to explicit budgets, #1230); set 0 to disable soft-trimming. The
|
||||
# default is treated as auto because the settings-save path materializes
|
||||
# defaults, so a persisted 6000 can't be told apart from a deliberate 6000 —
|
||||
# to pin a budget near the default, use a nearby value (e.g. 5999).
|
||||
"agent_input_token_budget": 6000,
|
||||
# Ceiling on the *auto-derived* input budget that #1230 introduced. Has
|
||||
# no effect when `agent_input_token_budget` is explicitly set (the user's
|
||||
# value is honoured regardless). Default matches
|
||||
# `src.context_budget.DEFAULT_HARD_MAX`; lower this for cost-paranoid
|
||||
# setups, raise it on premium APIs with very large windows that you
|
||||
# Ceiling on the *auto-derived* input budget; a configurable setting since #1273
|
||||
# (the merged #1230 left it a module constant). No effect on an explicit budget
|
||||
# — a deliberate value is honoured (#1230). Default matches
|
||||
# `src.context_budget.DEFAULT_HARD_MAX`; lower this for
|
||||
# cost-paranoid setups, raise it on premium APIs with very large windows you
|
||||
# want to actually use (e.g. 900_000 to fill a 1M-context model). See
|
||||
# `compute_input_token_budget` in src/context_budget.py.
|
||||
# `compute_input_token_budget`.
|
||||
"agent_input_token_hard_max": 200_000,
|
||||
"agent_stream_timeout_seconds": 300,
|
||||
# Extra directory roots that read_file / write_file may access, in
|
||||
@@ -232,8 +240,10 @@ def is_setting_overridden(key: str) -> bool:
|
||||
|
||||
``load_settings`` merges DEFAULT_SETTINGS with the saved file, so a value
|
||||
equal to its default is indistinguishable from "never set" via get_setting.
|
||||
Callers that need to treat an explicit user choice differently from the
|
||||
default (e.g. adaptive budgets) use this to read the raw saved file.
|
||||
Callers that must distinguish an explicit user choice from a default read
|
||||
the raw saved file via this. (Note: a materialized default is also "present",
|
||||
so value-sensitive callers should compare against the default — see
|
||||
``context_budget.budget_is_explicit``.)
|
||||
"""
|
||||
try:
|
||||
with open(SETTINGS_FILE, "r", encoding="utf-8") as f:
|
||||
@@ -292,7 +302,7 @@ def load_features() -> dict:
|
||||
if not isinstance(saved, dict):
|
||||
raise ValueError("features must be an object")
|
||||
merged = {**DEFAULT_FEATURES, **saved}
|
||||
except (FileNotFoundError, json.JSONDecodeError, ValueError):
|
||||
except (FileNotFoundError, PermissionError, json.JSONDecodeError, ValueError):
|
||||
merged = dict(DEFAULT_FEATURES)
|
||||
_features_cache = (now, merged)
|
||||
return merged
|
||||
|
||||
+11
-1
@@ -12,6 +12,8 @@ tunnel / reverse proxy. Scrubbing is deep (recurses nested dicts/lists) and keye
|
||||
on secret-shaped names.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
_SECRET_KEY_PATTERNS = (
|
||||
"_api_key", "_apikey", "_password", "_passwd", "_pass", "_pwd",
|
||||
"_secret", "_client_secret", "_token", "_access_token", "_refresh_token",
|
||||
@@ -26,8 +28,16 @@ _SENSITIVE_KEY_EXACT = (
|
||||
)
|
||||
|
||||
|
||||
def _canonical_key_name(name: str) -> str:
|
||||
"""Normalize common JS-style key names so secret matching is style-agnostic."""
|
||||
n = (name or "").replace("-", "_")
|
||||
n = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", n)
|
||||
n = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", n)
|
||||
return n.lower()
|
||||
|
||||
|
||||
def is_secret_key(name: str) -> bool:
|
||||
n = (name or "").lower()
|
||||
n = _canonical_key_name(name)
|
||||
if n in _SECRET_KEY_ALLOW:
|
||||
return False
|
||||
if n in _SENSITIVE_KEY_EXACT:
|
||||
|
||||
+52
-29
@@ -1324,7 +1324,10 @@ class TaskScheduler:
|
||||
db.commit()
|
||||
if self._session_manager:
|
||||
try:
|
||||
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
|
||||
self._session_manager.ensure_task_session(
|
||||
session_id, f"[Task] {task.name}", endpoint_url, model,
|
||||
owner=task.owner, task=task
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1430,6 +1433,7 @@ class TaskScheduler:
|
||||
task's visible output target.
|
||||
"""
|
||||
from core.database import Session as DbSession, ChatMessage, CrewMember
|
||||
from core.models import ChatMessage as MemChatMessage
|
||||
|
||||
output = task.output_target or "session"
|
||||
if (
|
||||
@@ -1486,7 +1490,10 @@ class TaskScheduler:
|
||||
db.commit()
|
||||
if self._session_manager:
|
||||
try:
|
||||
self._session_manager.sessions[session_id] = self._session_manager._db_to_session(sess)
|
||||
self._session_manager.ensure_task_session(
|
||||
session_id, f"[Task] {task.name}", endpoint_url, model_name,
|
||||
owner=task.owner, task=task
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -1495,36 +1502,50 @@ class TaskScheduler:
|
||||
meta["model"] = model_name
|
||||
if crew and crew.is_default_assistant:
|
||||
meta.update({"source": "cron", "task_id": task.id, "task_name": task.name})
|
||||
msg_meta = json.dumps(meta)
|
||||
user_content = task.prompt or f"[Task] {task.name}"
|
||||
user_msg = ChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="user",
|
||||
content=user_content,
|
||||
timestamp=_utcnow(),
|
||||
meta_data=msg_meta,
|
||||
)
|
||||
assistant_msg = ChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content=result or "",
|
||||
timestamp=_utcnow(),
|
||||
meta_data=msg_meta,
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.add(assistant_msg)
|
||||
db.commit()
|
||||
|
||||
if self._session_manager:
|
||||
# Use SessionManager for persistence so in-memory cache stays in sync
|
||||
if self._session_manager and session_id:
|
||||
try:
|
||||
from core.models import ChatMessage as MemMsg
|
||||
sess_obj = self._session_manager.get_session(session_id)
|
||||
sess_obj.history.append(MemMsg(role="user", content=user_msg.content, metadata=meta))
|
||||
sess_obj.history.append(MemMsg(role="assistant", content=assistant_msg.content, metadata=meta))
|
||||
self._session_manager.add_message(
|
||||
session_id,
|
||||
MemChatMessage(
|
||||
"user",
|
||||
task.prompt or f"[Task] {task.name}",
|
||||
metadata=dict(meta),
|
||||
),
|
||||
)
|
||||
self._session_manager.add_message(
|
||||
session_id,
|
||||
MemChatMessage(
|
||||
"assistant",
|
||||
result or "",
|
||||
metadata=dict(meta),
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Failed to deliver task %s through SessionManager", task.id)
|
||||
else:
|
||||
# Fallback: raw DB write (no session manager available)
|
||||
msg_meta = json.dumps(meta)
|
||||
user_msg = ChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="user",
|
||||
content=task.prompt or f"[Task] {task.name}",
|
||||
timestamp=_utcnow(),
|
||||
meta_data=msg_meta,
|
||||
)
|
||||
assistant_msg = ChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content=result or "",
|
||||
timestamp=_utcnow(),
|
||||
meta_data=msg_meta,
|
||||
)
|
||||
db.add(user_msg)
|
||||
db.add(assistant_msg)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def _is_email_output_target(output: str) -> bool:
|
||||
@@ -1641,6 +1662,8 @@ class TaskScheduler:
|
||||
data = json.loads(event_str[6:])
|
||||
# Capture text from all event types, not just delta
|
||||
if "delta" in data:
|
||||
if data.get("thinking"):
|
||||
continue
|
||||
full_text += data["delta"]
|
||||
elif data.get("type") == "tool_output":
|
||||
# Tool results — capture summary so we have SOMETHING even
|
||||
|
||||
@@ -42,7 +42,7 @@ _SOTA_HOSTS = frozenset({
|
||||
"api.together.xyz", "api.fireworks.ai",
|
||||
"api.perplexity.ai", "api.x.ai",
|
||||
"generativelanguage.googleapis.com", "api.groq.com",
|
||||
"openrouter.ai", "ollama.com", "api.venice.ai",
|
||||
"openrouter.ai", "ollama.com", "api.venice.ai", "api.kimi.com",
|
||||
})
|
||||
|
||||
|
||||
@@ -594,6 +594,8 @@ async def run_teacher_inline(
|
||||
"exit_code": payload.get("exit_code"),
|
||||
})
|
||||
if "delta" in payload and isinstance(payload["delta"], str):
|
||||
if payload.get("thinking"):
|
||||
continue
|
||||
captured_text_parts.append(payload["delta"])
|
||||
yield 'data: ' + json.dumps(payload) + '\n\n'
|
||||
continue
|
||||
|
||||
+148
-746
File diff suppressed because it is too large
Load Diff
+111
-677
@@ -18,6 +18,40 @@ from core.constants import internal_api_base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active email state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# When the user has an email reader window open, the frontend tells the
|
||||
# backend about it on each chat submit. Email tools can resolve "this email"
|
||||
# without guessing a UID. Cleared between requests by chat_routes.
|
||||
_active_email_ref: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
def set_active_email(uid: Optional[str], folder: Optional[str] = None, account: Optional[str] = None,
|
||||
subject: Optional[str] = None, sender: Optional[str] = None) -> None:
|
||||
"""Stash the email currently open in the UI. None clears it."""
|
||||
global _active_email_ref
|
||||
if not uid:
|
||||
_active_email_ref = None
|
||||
return
|
||||
_active_email_ref = {
|
||||
"uid": str(uid),
|
||||
"folder": str(folder or "INBOX"),
|
||||
"account": str(account or ""),
|
||||
"subject": str(subject or ""),
|
||||
"from": str(sender or ""),
|
||||
}
|
||||
|
||||
|
||||
def get_active_email() -> Optional[Dict[str, str]]:
|
||||
return _active_email_ref
|
||||
|
||||
|
||||
def clear_active_email() -> None:
|
||||
global _active_email_ref
|
||||
_active_email_ref = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Argument parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -54,517 +88,6 @@ def _parse_tool_args(content):
|
||||
args = args["body"]
|
||||
return args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active document state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_active_document_id: Optional[str] = None
|
||||
_active_model: Optional[str] = None
|
||||
# When the user has an email reader window open, the frontend tells the
|
||||
# backend about it on each chat submit. We stash it here so email tools
|
||||
# (reply_to_email, read_email, mark_email) can resolve "this email" / "the
|
||||
# open one" without the agent guessing a UID. Cleared between requests by
|
||||
# chat_routes after the agent loop returns.
|
||||
_active_email_ref: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
def set_active_email(uid: Optional[str], folder: Optional[str] = None, account: Optional[str] = None,
|
||||
subject: Optional[str] = None, sender: Optional[str] = None) -> None:
|
||||
"""Stash the email currently open in the UI. None clears it."""
|
||||
global _active_email_ref
|
||||
if not uid:
|
||||
_active_email_ref = None
|
||||
return
|
||||
_active_email_ref = {
|
||||
"uid": str(uid),
|
||||
"folder": str(folder or "INBOX"),
|
||||
"account": str(account or ""),
|
||||
"subject": str(subject or ""),
|
||||
"from": str(sender or ""),
|
||||
}
|
||||
|
||||
|
||||
def get_active_email() -> Optional[Dict[str, str]]:
|
||||
return _active_email_ref
|
||||
|
||||
|
||||
def clear_active_email() -> None:
|
||||
global _active_email_ref
|
||||
_active_email_ref = None
|
||||
|
||||
|
||||
def set_active_document(doc_id: Optional[str]):
|
||||
"""Set the active document ID for document tool execution."""
|
||||
global _active_document_id
|
||||
_active_document_id = doc_id
|
||||
|
||||
|
||||
def set_active_model(model: Optional[str]):
|
||||
"""Set the current model name for version summaries."""
|
||||
global _active_model
|
||||
_active_model = model
|
||||
|
||||
|
||||
def get_active_document():
|
||||
return _active_document_id
|
||||
|
||||
|
||||
def clear_active_document(doc_id: Optional[str] = None) -> bool:
|
||||
"""Clear the in-memory active-document pointer.
|
||||
|
||||
With ``doc_id`` given, only clears when it matches the current pointer, so a
|
||||
different active document is left untouched. Returns True if it was cleared.
|
||||
|
||||
Called when a document is detached from its session or deleted (its tab is
|
||||
closed): without this, the stale pointer makes the last-resort doc-injection
|
||||
path re-surface a closed document in a later, unrelated chat — even one whose
|
||||
session no longer matches — because an unlinked doc has session_id NULL (#1160).
|
||||
"""
|
||||
global _active_document_id
|
||||
if doc_id is None or _active_document_id == doc_id:
|
||||
_active_document_id = None
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _owned_document_query(query, Document, owner: Optional[str]):
|
||||
if owner is None:
|
||||
# A bare Python `False` is not a valid SQL expression — SQLAlchemy 1.4
|
||||
# deprecates it and 2.0 raises ArgumentError. Use the SQL `false()`
|
||||
# literal to return zero rows for an unscoped (owner-less) query.
|
||||
from sqlalchemy import false
|
||||
return query.filter(false())
|
||||
return query.filter(Document.owner == owner)
|
||||
|
||||
|
||||
def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False):
|
||||
q = db.query(Document).filter(Document.id == doc_id)
|
||||
if active_only:
|
||||
q = q.filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
return q.first()
|
||||
|
||||
|
||||
def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False):
|
||||
q = db.query(Document)
|
||||
if active_only:
|
||||
q = q.filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
return q.order_by(Document.updated_at.desc()).first()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document tools — create/update/edit/suggest living documents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sniff_doc_language(text: str) -> str:
|
||||
"""Best-effort detect a document's language from its content when the model
|
||||
didn't specify one. Defaults to 'markdown' (prose). Recognizes the common
|
||||
markup/code types the editor supports so e.g. an SVG isn't saved as markdown."""
|
||||
import json as _json, re as _re2
|
||||
s = (text or "").strip()
|
||||
if not s:
|
||||
return "markdown"
|
||||
head = s[:600]
|
||||
hl = head.lower()
|
||||
if _looks_like_email_document(s):
|
||||
return "email"
|
||||
# Markup (unambiguous)
|
||||
if "<svg" in hl:
|
||||
return "svg"
|
||||
if hl.startswith("<?xml"):
|
||||
return "xml"
|
||||
if (hl.startswith("<!doctype html") or hl.startswith("<html")
|
||||
or _re2.search(r"<(div|body|head|p|span|table|button|h[1-6]|ul|ol|li|img)\b", hl)):
|
||||
return "html"
|
||||
# JSON
|
||||
if s[0] in "{[":
|
||||
try:
|
||||
_json.loads(s)
|
||||
return "json"
|
||||
except Exception:
|
||||
pass
|
||||
# Shebang
|
||||
first = s.split("\n", 1)[0].strip().lower()
|
||||
if first.startswith("#!"):
|
||||
return "python" if "python" in first else "bash"
|
||||
# Code by strong leading signals (line-anchored so prose with stray words won't match)
|
||||
if _re2.search(r"(?m)^\s*(def \w|class \w|import \w|from \w[\w.]* import )", s):
|
||||
return "python"
|
||||
if _re2.search(r"(?m)^\s*(function \w|const \w|let \w|export |import .* from )", s):
|
||||
return "javascript"
|
||||
if _re2.search(r"(?mi)^\s*(select .* from |create table |insert into |update \w)", s):
|
||||
return "sql"
|
||||
if _re2.search(r"(?m)^[.#]?[\w-]+\s*\{[^{}]*:[^{}]*;", s):
|
||||
return "css"
|
||||
return "markdown"
|
||||
|
||||
|
||||
def _looks_like_email_document(text: str = "", title: str = "") -> bool:
|
||||
import re as _re
|
||||
title_l = (title or "").strip().lower()
|
||||
if title_l in {"new email", "new mail", "new message"}:
|
||||
return True
|
||||
s = (text or "").lstrip()
|
||||
if "\n---\n" in s and _re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s):
|
||||
return True
|
||||
return bool(_re.search(r"(?im)^To:\s*", s) and _re.search(r"(?im)^Subject:\s*", s))
|
||||
|
||||
|
||||
def _coerce_email_document_content(existing: str, incoming: str) -> str:
|
||||
"""Keep email docs in the To/Subject/---/body shape even if a model writes
|
||||
only the body or dumps header labels without the separator."""
|
||||
import re as _re
|
||||
old = existing or ""
|
||||
new = (incoming or "").strip()
|
||||
if "\n---\n" in new:
|
||||
return new
|
||||
header = old.split("\n---\n", 1)[0] if "\n---\n" in old else "To: \nSubject: "
|
||||
if _looks_like_email_document(new):
|
||||
lines = new.splitlines()
|
||||
last_header_idx = -1
|
||||
header_re = _re.compile(r"^(To|Cc|Bcc|Subject|In-Reply-To|References|X-Source-UID|X-Source-Folder|X-Attachments):", _re.I)
|
||||
for i, line in enumerate(lines):
|
||||
if header_re.match(line.strip()):
|
||||
last_header_idx = i
|
||||
body_lines = lines[last_header_idx + 1:] if last_header_idx >= 0 else lines
|
||||
while body_lines and not body_lines[0].strip():
|
||||
body_lines.pop(0)
|
||||
body = "\n".join(body_lines).strip()
|
||||
else:
|
||||
body = new
|
||||
return header.rstrip() + "\n---\n" + body
|
||||
|
||||
|
||||
async def do_create_document(content_block: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Create a new document. Supports two formats:
|
||||
1) Line-based: line 1 = title, line 2 (optional) = language, rest = content
|
||||
2) XML-like tags: <title>...</title><language>...</language><content>...</content>
|
||||
Some models mix them — strip any XML-style tags and fall back to line parsing."""
|
||||
import uuid, re as _re
|
||||
from src.database import SessionLocal, Document, DocumentVersion, Session as DbSession
|
||||
|
||||
raw = content_block or ""
|
||||
|
||||
# Known languages the editor understands (match the <select> in HTML)
|
||||
_KNOWN_LANGS = {
|
||||
"python", "javascript", "typescript", "html", "css", "markdown", "json",
|
||||
"yaml", "bash", "sql", "rust", "go", "java", "c", "cpp", "xml", "toml",
|
||||
"ini", "ruby", "php", "csv", "email", "text", "plain", "svg",
|
||||
}
|
||||
|
||||
# Try XML tag extraction first
|
||||
title = None
|
||||
language = None
|
||||
content = None
|
||||
mt = _re.search(r"<title>\s*(.*?)\s*</title>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
ml = _re.search(r"<language>\s*(.*?)\s*</language>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
mc = _re.search(r"<content>\s*(.*?)\s*</content>", raw, _re.DOTALL | _re.IGNORECASE)
|
||||
if mt or mc:
|
||||
title = mt.group(1).strip() if mt else None
|
||||
language = ml.group(1).strip().lower() if ml else None
|
||||
content = mc.group(1) if mc else None
|
||||
|
||||
# Fall back to line-based parsing. First strip any stray XML-ish tags.
|
||||
if title is None or content is None:
|
||||
cleaned = _re.sub(r"</?(?:title|language|content)>", "", raw)
|
||||
lines = cleaned.strip().split("\n")
|
||||
if title is None:
|
||||
title = lines[0].strip() if lines else "Untitled"
|
||||
lines = lines[1:]
|
||||
# Only consume second line as language if it looks like a valid short lang token
|
||||
if language is None and lines:
|
||||
candidate = lines[0].strip().lower()
|
||||
if candidate and len(candidate) < 20 and " " not in candidate and candidate in _KNOWN_LANGS:
|
||||
language = candidate
|
||||
lines = lines[1:]
|
||||
if content is None:
|
||||
content = "\n".join(lines)
|
||||
|
||||
# Validate language: must be in known set, else default based on content
|
||||
if language and language not in _KNOWN_LANGS:
|
||||
language = None
|
||||
if not language:
|
||||
# No explicit language — sniff it from the content so an SVG / HTML / JSON
|
||||
# / code document isn't silently saved as markdown. Prose → markdown.
|
||||
language = _sniff_doc_language(content)
|
||||
if _looks_like_email_document(content, title):
|
||||
language = "email"
|
||||
|
||||
if not title:
|
||||
title = "Untitled"
|
||||
|
||||
if not session_id:
|
||||
return {"error": "No session context for document creation"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc_id = str(uuid.uuid4())
|
||||
ver_id = str(uuid.uuid4())
|
||||
|
||||
# Inherit ownership from the chat session so the doc survives that
|
||||
# session later being deleted (session_id → NULL).
|
||||
_sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if owner is not None and (not _sess or _sess.owner != owner):
|
||||
return {"error": "Cannot create document in another user's session"}
|
||||
_owner = _sess.owner if _sess else None
|
||||
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
session_id=session_id,
|
||||
title=title,
|
||||
language=language,
|
||||
current_content=content,
|
||||
version_count=1,
|
||||
is_active=True,
|
||||
owner=_owner,
|
||||
)
|
||||
ver = DocumentVersion(
|
||||
id=ver_id,
|
||||
document_id=doc_id,
|
||||
version_number=1,
|
||||
content=content,
|
||||
summary=f"Created by {_active_model or 'AI'}",
|
||||
source="ai",
|
||||
)
|
||||
db.add(doc)
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
set_active_document(doc_id)
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("document_created", _owner)
|
||||
except Exception:
|
||||
logger.debug("document_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {
|
||||
"action": "create",
|
||||
"doc_id": doc_id,
|
||||
"title": title,
|
||||
"language": language,
|
||||
"content": content,
|
||||
"version": 1,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to create document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def do_update_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Update an existing document. Content = full new document text."""
|
||||
import uuid
|
||||
from src.database import SessionLocal, Document, DocumentVersion
|
||||
|
||||
target_id = doc_id or _active_document_id
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = None
|
||||
if target_id:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
doc = _most_recent_owned_document(db, Document, owner)
|
||||
if doc:
|
||||
target_id = doc.id
|
||||
set_active_document(target_id)
|
||||
logger.info(f"update_document: fell back to most recent doc id={target_id}")
|
||||
if not doc:
|
||||
return {"error": "No documents exist to update"}
|
||||
|
||||
is_email_doc = doc.language == "email" or _looks_like_email_document(doc.current_content or "", doc.title or "")
|
||||
new_content = _coerce_email_document_content(doc.current_content or "", content) if is_email_doc else content.strip()
|
||||
if is_email_doc:
|
||||
doc.language = "email"
|
||||
|
||||
new_ver = doc.version_count + 1
|
||||
ver = DocumentVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
document_id=target_id,
|
||||
version_number=new_ver,
|
||||
content=new_content,
|
||||
summary=f"Updated by {_active_model or 'AI'}",
|
||||
source="ai",
|
||||
)
|
||||
doc.current_content = new_content
|
||||
doc.version_count = new_ver
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"action": "update",
|
||||
"doc_id": target_id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"content": new_content,
|
||||
"version": new_ver,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to update document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def parse_edit_blocks(content: str) -> list:
|
||||
"""Parse <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks."""
|
||||
edits = []
|
||||
pattern = r'<<<FIND>>>\n(.*?)\n<<<REPLACE>>>\n(.*?)\n<<<END>>>'
|
||||
for m in re.finditer(pattern, content, re.DOTALL):
|
||||
edits.append({"find": m.group(1), "replace": m.group(2)})
|
||||
return edits
|
||||
|
||||
|
||||
async def do_edit_document(content: str, doc_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Apply targeted FIND/REPLACE edits to an existing document."""
|
||||
import uuid
|
||||
from src.database import SessionLocal, Document, DocumentVersion
|
||||
|
||||
target_id = doc_id or _active_document_id
|
||||
|
||||
edits = parse_edit_blocks(content)
|
||||
if not edits:
|
||||
return {"error": "No valid <<<FIND>>>...<<<REPLACE>>>...<<<END>>> blocks found"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = None
|
||||
if target_id:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
# Fallback: most recently updated document. Avoids "no active doc" errors
|
||||
# after server restart or when the agent loses track of which doc to edit.
|
||||
doc = _most_recent_owned_document(db, Document, owner)
|
||||
if doc:
|
||||
target_id = doc.id
|
||||
set_active_document(target_id)
|
||||
logger.info(f"edit_document: fell back to most recent doc id={target_id} title={doc.title!r}")
|
||||
if not doc:
|
||||
return {"error": "No documents exist to edit"}
|
||||
|
||||
updated_content = doc.current_content
|
||||
applied = 0
|
||||
skipped = 0
|
||||
for edit in edits:
|
||||
_find = edit["find"]
|
||||
if _find in updated_content:
|
||||
updated_content = updated_content.replace(_find, edit["replace"], 1)
|
||||
applied += 1
|
||||
else:
|
||||
# Defensive: the active-doc context shows a "N\t" line-number
|
||||
# gutter for reference. Weaker models sometimes copy that prefix
|
||||
# into FIND. If the exact match failed, retry with a leading
|
||||
# "<digits><tab>" stripped from each FIND line — but only use it
|
||||
# when that stripped form actually matches, so we never corrupt a
|
||||
# legitimately tab-prefixed document.
|
||||
_stripped = "\n".join(re.sub(r"^\d+\t", "", _l) for _l in _find.split("\n"))
|
||||
if _stripped != _find and _stripped in updated_content:
|
||||
updated_content = updated_content.replace(_stripped, edit["replace"], 1)
|
||||
applied += 1
|
||||
logger.info("edit_document: matched after stripping line-number gutter from FIND")
|
||||
else:
|
||||
logger.warning(f"edit_document: FIND text not found, skipping: {_find[:80]!r}")
|
||||
skipped += 1
|
||||
|
||||
if applied == 0:
|
||||
return {"error": f"No edits applied — none of the FIND blocks matched the document content (skipped {skipped})"}
|
||||
|
||||
new_ver = doc.version_count + 1
|
||||
ver = DocumentVersion(
|
||||
id=str(uuid.uuid4()),
|
||||
document_id=target_id,
|
||||
version_number=new_ver,
|
||||
content=updated_content,
|
||||
summary=f"Edited by {_active_model or 'AI'} ({applied} edit(s))",
|
||||
source="ai",
|
||||
)
|
||||
doc.current_content = updated_content
|
||||
doc.version_count = new_ver
|
||||
db.add(ver)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"action": "edit",
|
||||
"doc_id": target_id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"content": updated_content,
|
||||
"version": new_ver,
|
||||
"applied": applied,
|
||||
"skipped": skipped,
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
return {"error": f"Failed to edit document: {e}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def parse_suggest_blocks(content: str) -> list:
|
||||
"""Parse <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks."""
|
||||
suggestions = []
|
||||
_skip_phrases = ["no change", "clear", "fine as", "looks good", "no improvement", "keep as"]
|
||||
pattern = r'<<<FIND>>>\n(.*?)\n<<<SUGGEST>>>\n(.*?)\n<<<REASON>>>\n(.*?)\n<<<END>>>'
|
||||
for m in re.finditer(pattern, content, re.DOTALL):
|
||||
find_text = m.group(1)
|
||||
replace_text = m.group(2)
|
||||
reason = m.group(3).strip()
|
||||
# Skip no-op suggestions where find == replace or reason says no change
|
||||
if find_text.strip() == replace_text.strip():
|
||||
continue
|
||||
if any(phrase in reason.lower() for phrase in _skip_phrases):
|
||||
continue
|
||||
suggestions.append({
|
||||
"id": f"sugg-{len(suggestions)+1}",
|
||||
"find": find_text,
|
||||
"replace": replace_text,
|
||||
"reason": reason,
|
||||
})
|
||||
return suggestions
|
||||
|
||||
|
||||
async def do_suggest_document(content: str, doc_id: str = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Create inline suggestions for the active document WITHOUT modifying it."""
|
||||
from src.database import SessionLocal, Document
|
||||
|
||||
target_id = doc_id or _active_document_id
|
||||
if not target_id:
|
||||
return {"error": "No active document to suggest on"}
|
||||
|
||||
suggestions = parse_suggest_blocks(content)
|
||||
if not suggestions:
|
||||
return {"error": "No valid <<<FIND>>>...<<<SUGGEST>>>...<<<REASON>>>...<<<END>>> blocks found"}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
doc = _get_owned_document(db, Document, target_id, owner)
|
||||
if not doc:
|
||||
return {"error": f"Document {target_id} not found"}
|
||||
|
||||
# Validate that FIND text exists in document
|
||||
valid = []
|
||||
for s in suggestions:
|
||||
if s["find"] in doc.current_content:
|
||||
valid.append(s)
|
||||
else:
|
||||
logger.warning(f"suggest_document: FIND text not found, skipping: {s['find'][:80]!r}")
|
||||
|
||||
if not valid:
|
||||
return {"error": "No suggestions matched the document content"}
|
||||
|
||||
return {
|
||||
"action": "suggest",
|
||||
"doc_id": target_id,
|
||||
"suggestions": valid,
|
||||
"count": len(valid),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Search chats
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1392,147 +915,6 @@ async def do_manage_tokens(content: str, owner: Optional[str] = None) -> Dict:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document management tool (delete, list, organize)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_manage_documents(content: str, owner: Optional[str] = None) -> Dict:
|
||||
"""Manage documents: list, read/view/open, delete, tidy.
|
||||
|
||||
Output format mirrors `manage_session`: list rows include a
|
||||
clickable `[Title](#document-<id>)` anchor + relative timestamps
|
||||
so the user can click straight from chat to open the editor.
|
||||
"""
|
||||
from core.database import SessionLocal, Document
|
||||
from datetime import datetime, timezone
|
||||
|
||||
try:
|
||||
args = _parse_tool_args(content)
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
action = args.get("action", "list")
|
||||
db = SessionLocal()
|
||||
|
||||
def _rel(ts):
|
||||
if not ts:
|
||||
return 'never'
|
||||
try:
|
||||
now = datetime.now(timezone.utc) if ts.tzinfo is not None else datetime.utcnow()
|
||||
diff = (now - ts).total_seconds()
|
||||
except Exception:
|
||||
return 'unknown'
|
||||
if diff < 60: return 'just now'
|
||||
if diff < 3600: return f'{int(diff / 60)}m ago'
|
||||
if diff < 86400: return f'{int(diff / 3600)}h ago'
|
||||
if diff < 86400 * 7: return f'{int(diff / 86400)}d ago'
|
||||
return ts.strftime('%Y-%m-%d')
|
||||
|
||||
try:
|
||||
if action == "list":
|
||||
q = db.query(Document).filter(Document.is_active == True)
|
||||
q = _owned_document_query(q, Document, owner)
|
||||
if args.get("search"):
|
||||
q = q.filter(Document.title.ilike(f"%{args['search']}%"))
|
||||
if args.get("language"):
|
||||
q = q.filter(Document.language == args["language"])
|
||||
docs = q.order_by(Document.updated_at.desc()).limit(args.get("limit", 50)).all()
|
||||
if not docs:
|
||||
msg = "No documents found" + (f" matching '{args['search']}'" if args.get("search") else "") + "."
|
||||
return {"response": msg, "documents": [], "exit_code": 0}
|
||||
lines = []
|
||||
items = []
|
||||
for i, d in enumerate(docs):
|
||||
size = len(d.current_content or "")
|
||||
lang = d.language or "text"
|
||||
ts = getattr(d, 'updated_at', None) or getattr(d, 'created_at', None)
|
||||
marker = " ← most recent" if i == 0 else ""
|
||||
lines.append(
|
||||
f"- [{d.title}](#document-{d.id}) — {lang}, {size} chars, updated {_rel(ts)}{marker}"
|
||||
)
|
||||
items.append({"id": d.id, "title": d.title, "language": lang, "size": size})
|
||||
header = f"Found {len(docs)} document(s), sorted most-recent first. Click a title to open:"
|
||||
return {
|
||||
"response": header + "\n" + "\n".join(lines),
|
||||
"documents": items,
|
||||
"exit_code": 0,
|
||||
}
|
||||
|
||||
elif action in ("read", "view", "open", "get"):
|
||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid")
|
||||
if not doc_id:
|
||||
return {"error": "Need document_id (use action=list to find one)", "exit_code": 1}
|
||||
doc = _get_owned_document(db, Document, doc_id, owner, active_only=True)
|
||||
if not doc:
|
||||
return {"error": f"Document '{doc_id}' not found", "exit_code": 1}
|
||||
body = doc.current_content or ""
|
||||
total = len(body)
|
||||
# Clamp offset to [0, total] so a far-out offset returns an empty
|
||||
# window with a useful "end of document" hint rather than erroring.
|
||||
try: offset = int(args.get("offset", 0))
|
||||
except (TypeError, ValueError): offset = 0
|
||||
offset = max(0, min(offset, total))
|
||||
preview_limit = int(args.get("limit", MAX_READ_CHARS))
|
||||
chunk = body[offset:offset + preview_limit]
|
||||
next_offset = offset + len(chunk)
|
||||
has_more = next_offset < total
|
||||
# Trailing marker — tells the agent (and a curious human) exactly
|
||||
# what to pass next to continue paginating.
|
||||
if has_more:
|
||||
marker = f"\n... ({total - next_offset:,} more chars; pass offset={next_offset} to continue)"
|
||||
elif offset > 0:
|
||||
marker = f"\n... (end of document, {total:,} chars total)"
|
||||
else:
|
||||
marker = ""
|
||||
preview = chunk + marker
|
||||
anchor = f"[{doc.title}](#document-{doc.id})"
|
||||
return {
|
||||
"response": f"{anchor} — click to open in editor.\n\n```{doc.language or ''}\n{preview}\n```",
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"size": total,
|
||||
"content": chunk,
|
||||
"offset": offset,
|
||||
"next_offset": next_offset if has_more else None,
|
||||
"truncated": has_more,
|
||||
},
|
||||
"exit_code": 0,
|
||||
}
|
||||
|
||||
elif action == "delete":
|
||||
doc_id = args.get("document_id") or args.get("id") or args.get("uid") or _active_document_id
|
||||
doc = None
|
||||
if doc_id:
|
||||
doc = _get_owned_document(db, Document, doc_id, owner)
|
||||
if not doc:
|
||||
# Fallback: most recently updated doc (likely what the user means)
|
||||
doc = _most_recent_owned_document(db, Document, owner, active_only=True)
|
||||
if not doc:
|
||||
return {"error": "No document to delete", "exit_code": 1}
|
||||
title = doc.title
|
||||
doc.is_active = False
|
||||
db.commit()
|
||||
if _active_document_id == doc.id:
|
||||
set_active_document(None)
|
||||
return {"response": f"Deleted document '{title}'", "exit_code": 0}
|
||||
|
||||
elif action == "tidy":
|
||||
from src.document_actions import run_document_tidy
|
||||
result = await run_document_tidy(owner or "")
|
||||
return {"response": result, "exit_code": 0}
|
||||
|
||||
else:
|
||||
return {"error": f"Unknown action: {action}", "exit_code": 1}
|
||||
except Exception as e:
|
||||
logger.error(f"manage_documents error: {e}")
|
||||
return {"error": str(e), "exit_code": 1}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Settings/preferences management tool
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -2097,7 +1479,15 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
"""Handle manage_calendar tool calls: list/create/update/delete calendar events (local SQLite)."""
|
||||
from datetime import datetime, timedelta
|
||||
from core.database import SessionLocal, CalendarCal, CalendarEvent, Note
|
||||
from routes.calendar_routes import _ensure_default_calendar, _parse_dt, _parse_dt_pair, parse_due_for_user, _resolve_base_uid
|
||||
from routes.calendar_routes import (
|
||||
_ensure_default_calendar,
|
||||
_parse_dt,
|
||||
_parse_dt_pair,
|
||||
parse_due_for_user,
|
||||
_resolve_base_uid,
|
||||
_push_caldav_event_after_commit,
|
||||
_record_caldav_delete_tombstone,
|
||||
)
|
||||
import uuid as _uuid
|
||||
|
||||
try:
|
||||
@@ -2105,6 +1495,42 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
# ── Batch normalization ──
|
||||
# Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
|
||||
# instead of individual create_event calls. Iterate and create each.
|
||||
if isinstance(args.get("events"), list) and not args.get("action"):
|
||||
results = []
|
||||
for ev in args["events"]:
|
||||
if not isinstance(ev, dict):
|
||||
continue
|
||||
# Normalize start/end from {dateTime: "..."} object to flat string
|
||||
for field, target in [("start", "dtstart"), ("end", "dtend")]:
|
||||
val = ev.pop(field, None)
|
||||
if val and target not in ev:
|
||||
ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
|
||||
ev.setdefault("action", "create_event")
|
||||
r = await do_manage_calendar(json.dumps(ev), owner=owner)
|
||||
results.append(r)
|
||||
created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
|
||||
failed = [r for r in results if r.get("error")]
|
||||
|
||||
if not results:
|
||||
return {"error": "No events to create", "exit_code": 1}
|
||||
|
||||
# Surface both successes and failures
|
||||
parts = []
|
||||
if created:
|
||||
summaries = [r.get("response", "") for r in created]
|
||||
parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
|
||||
if failed:
|
||||
first_error = failed[0].get("error", "Unknown error")
|
||||
parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")
|
||||
|
||||
response = "\n\n".join(parts)
|
||||
# Non-zero exit code for partial or total failure
|
||||
exit_code = 0 if not failed else 1
|
||||
return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}
|
||||
|
||||
# Normalize action — some models emit hyphens ("list-calendars") instead
|
||||
# of underscores. Treat them as equivalent so we don't bounce a
|
||||
# cosmetic typo back to the model and waste a round-trip. Also accept
|
||||
@@ -2259,6 +1685,9 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
except ValueError as e:
|
||||
return {"error": f"Invalid date format: {e}", "exit_code": 1}
|
||||
|
||||
if end_dt <= start_dt:
|
||||
end_dt = start_dt + timedelta(days=1)
|
||||
|
||||
q = _event_query().filter(
|
||||
CalendarEvent.dtstart < end_dt,
|
||||
CalendarEvent.dtend > start_dt,
|
||||
@@ -2438,6 +1867,7 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
rrule=args.get("rrule", "") or "",
|
||||
event_type=event_type,
|
||||
importance=importance,
|
||||
caldav_sync_pending="create" if cal.source == "caldav" else None,
|
||||
)
|
||||
db.add(ev)
|
||||
reminder_note_id = None
|
||||
@@ -2452,6 +1882,8 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
dtstart_is_utc and not all_day,
|
||||
)
|
||||
db.commit()
|
||||
if cal.source == "caldav":
|
||||
await _push_caldav_event_after_commit(owner, uid, "create")
|
||||
tag_blurb = f" [{event_type}]" if event_type else ""
|
||||
if minutes_before is None:
|
||||
reminder_blurb = ""
|
||||
@@ -2509,7 +1941,12 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
ev.event_type = _tag or None
|
||||
if args.get("importance") is not None:
|
||||
ev.importance = args["importance"]
|
||||
is_caldav = ev.calendar and ev.calendar.source == "caldav"
|
||||
if is_caldav:
|
||||
ev.caldav_sync_pending = "update"
|
||||
db.commit()
|
||||
if is_caldav:
|
||||
await _push_caldav_event_after_commit(owner, base_uid, "update")
|
||||
return {"response": f"Updated event {uid}", "exit_code": 0}
|
||||
|
||||
elif action == "delete_event":
|
||||
@@ -2523,8 +1960,13 @@ async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
|
||||
ev = _event_query().filter(CalendarEvent.uid == base_uid).first()
|
||||
if not ev:
|
||||
return {"error": f"Event {uid} not found", "exit_code": 1}
|
||||
is_caldav = ev.calendar and ev.calendar.source == "caldav" and ev.remote_href
|
||||
if is_caldav:
|
||||
_record_caldav_delete_tombstone(db, ev, owner)
|
||||
db.delete(ev)
|
||||
db.commit()
|
||||
if is_caldav:
|
||||
await _push_caldav_event_after_commit(owner, base_uid, "delete")
|
||||
return {"response": f"Deleted event {uid}", "exit_code": 0}
|
||||
|
||||
else:
|
||||
@@ -2670,13 +2112,14 @@ async def _cookbook_env_for_host(host: str) -> Dict[str, Any]:
|
||||
else:
|
||||
env_prefix = f'eval "$(conda shell.bash hook)" && conda activate {env_path}'
|
||||
|
||||
from routes.cookbook_helpers import load_stored_hf_token
|
||||
return {
|
||||
"env_prefix": env_prefix,
|
||||
"env_type": env_kind,
|
||||
"env_path": env_path,
|
||||
"gpus": env_root.get("gpus") or "",
|
||||
"platform": platform,
|
||||
"hf_token": env_root.get("hfToken") or "",
|
||||
"hf_token": load_stored_hf_token(),
|
||||
"ssh_port": ssh_port,
|
||||
}
|
||||
|
||||
@@ -2733,7 +2176,7 @@ async def _ensure_served_endpoint(
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
f"{_COOKBOOK_BASE}/api/model-endpoints",
|
||||
f"{_INTERNAL_BASE}/api/model-endpoints",
|
||||
data=payload,
|
||||
headers=_internal_headers(),
|
||||
)
|
||||
@@ -4428,24 +3871,16 @@ async def do_manage_contact(content: str, owner: Optional[str] = None) -> Dict:
|
||||
|
||||
if action == "add":
|
||||
email = (args.get("email") or "").strip()
|
||||
name = (args.get("name") or "").strip() or (email.split("@")[0] if email else "")
|
||||
address = (args.get("address") or "").strip()
|
||||
# Need at least one identifying field. Address-only (e.g. a
|
||||
# business location with no email) is fine as long as there's
|
||||
# a name.
|
||||
if not email and not name:
|
||||
return {"error": "Provide at least name+address or email for add", "exit_code": 1}
|
||||
# Dedupe by email when one is given.
|
||||
if email:
|
||||
existing = await asyncio.to_thread(cc._fetch_contacts)
|
||||
for c in existing:
|
||||
if email.lower() in [e.lower() for e in c.get("emails", [])]:
|
||||
return {"output": f"{email} is already a contact ({c.get('name','')}).", "exit_code": 0}
|
||||
ok = await asyncio.to_thread(cc._create_contact, name, email, address)
|
||||
tail = f" <{email}>" if email else ""
|
||||
if address:
|
||||
tail += f" — {address}"
|
||||
return {"output": f"{'Added' if ok else 'Failed to add'} {name}{tail}.", "exit_code": 0 if ok else 1}
|
||||
if not email:
|
||||
return {"error": "email is required for add", "exit_code": 1}
|
||||
name = (args.get("name") or "").strip() or email.split("@")[0]
|
||||
# Dedupe by email (same as the /add route).
|
||||
existing = await asyncio.to_thread(cc._fetch_contacts)
|
||||
for c in existing:
|
||||
if email.lower() in [e.lower() for e in c.get("emails", [])]:
|
||||
return {"output": f"{email} is already a contact ({c.get('name','')}).", "exit_code": 0}
|
||||
ok = await asyncio.to_thread(cc._create_contact, name, email)
|
||||
return {"output": f"{'Added' if ok else 'Failed to add'} {name} <{email}>.", "exit_code": 0 if ok else 1}
|
||||
|
||||
if action in ("update", "edit"):
|
||||
uid = (args.get("uid") or "").strip()
|
||||
@@ -4457,12 +3892,11 @@ async def do_manage_contact(content: str, owner: Optional[str] = None) -> Dict:
|
||||
emails = [args["email"]]
|
||||
emails = [e.strip() for e in (emails or []) if e and e.strip()]
|
||||
phones = [p.strip() for p in (args.get("phones") or []) if p and p.strip()]
|
||||
address = (args.get("address") or "").strip()
|
||||
if not name and not emails and not address:
|
||||
return {"error": "Provide a name, emails, or address to update", "exit_code": 1}
|
||||
if not name and not emails:
|
||||
return {"error": "Provide a name or emails to update", "exit_code": 1}
|
||||
if not name and emails:
|
||||
name = emails[0].split("@")[0]
|
||||
ok = await asyncio.to_thread(cc._update_contact, uid, name, emails, phones, address)
|
||||
ok = await asyncio.to_thread(cc._update_contact, uid, name, emails, phones)
|
||||
return {"output": "Contact updated." if ok else "Update failed.", "exit_code": 0 if ok else 1}
|
||||
|
||||
if action == "delete":
|
||||
|
||||
+7
-2
@@ -67,14 +67,15 @@ COLLECTION_NAME = "odysseus_tool_index"
|
||||
# Each tool gets a searchable description that helps retrieval.
|
||||
# These are richer than the system prompt one-liners — they're for embedding.
|
||||
BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
|
||||
"bash": "Run shell commands on the server. Install packages, check files, git operations, system info, and process management. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||
"bash": "Run shell commands on the server. Install packages, git operations, builds, system info, process management. Prefer a dedicated tool whenever one fits the job (file read/write/edit, search, listing); use bash only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||
"python": "Execute Python code for computation, data processing, math, scripting, and parsing. Not for writing code for the user. Prefer a dedicated tool for reading, writing, or searching files; use python only for what no dedicated tool covers. Do not use for web lookup/search; use web_search or web_fetch when web tools are available.",
|
||||
"web_search": "Quick single web lookup for a fact, current event, latest/current information, or doc mid-task. Use this instead of bash/curl/python/requests for web searches. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
|
||||
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
|
||||
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
|
||||
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
|
||||
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
|
||||
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
|
||||
"get_workspace": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without giving a path, instead of asking them.",
|
||||
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
|
||||
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
|
||||
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
|
||||
@@ -395,6 +396,10 @@ class ToolIndex:
|
||||
"delegate to", "have model"}):
|
||||
{"chat_with_model", "ask_teacher", "list_models"},
|
||||
# Deep research intent (incl. common typo "reserach")
|
||||
frozenset({"web search", "search the web", "search online", "look up",
|
||||
"google", "latest", "current", "news", "weather",
|
||||
"forecast", "stock price", "price of"}):
|
||||
{"web_search", "web_fetch"},
|
||||
frozenset({"research", "reserach", "reasearch", "look into", "investigate",
|
||||
"deep dive", "deep research", "find out about", "study up on",
|
||||
"report on", "do research", "look up everything"}):
|
||||
|
||||
@@ -188,6 +188,12 @@ _MISFENCED_WEB_TOOL_NAMES = {
|
||||
"fetch_url": "web_fetch",
|
||||
}
|
||||
|
||||
_RAW_WEB_JSON_TOOL_RE = re.compile(
|
||||
r"\b(?:web_search|websearch|google_search|google_search_retrieval|google_search_grounding)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_RAW_WEB_JSON_ALLOWED_KEYS = {"query", "queries", "time_filter", "freshness", "max_pages"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsing functions
|
||||
@@ -279,6 +285,73 @@ def _parse_misfenced_web_lookup(content: str) -> Optional[ToolBlock]:
|
||||
return None
|
||||
return ToolBlock("web_fetch", url)
|
||||
|
||||
|
||||
def _coerce_raw_web_query(value) -> Optional[str]:
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, str) and item.strip():
|
||||
return item.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _raw_web_json_to_tool_block(payload) -> Optional[ToolBlock]:
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
if set(payload) - _RAW_WEB_JSON_ALLOWED_KEYS:
|
||||
return None
|
||||
|
||||
query = _coerce_raw_web_query(payload.get("query"))
|
||||
if not query:
|
||||
query = _coerce_raw_web_query(payload.get("queries"))
|
||||
if not query:
|
||||
return None
|
||||
|
||||
content = {"query": query}
|
||||
for key in ("time_filter", "freshness"):
|
||||
value = payload.get(key)
|
||||
if isinstance(value, str) and value.strip().lower() in ("day", "week", "month", "year"):
|
||||
content[key] = value.strip().lower()
|
||||
|
||||
max_pages = payload.get("max_pages")
|
||||
if isinstance(max_pages, int) and 1 <= max_pages <= 10:
|
||||
content["max_pages"] = max_pages
|
||||
|
||||
if len(content) == 1:
|
||||
return ToolBlock("web_search", query)
|
||||
return ToolBlock("web_search", json.dumps(content))
|
||||
|
||||
|
||||
def _parse_raw_web_json_lookup(text: str) -> Optional[tuple[ToolBlock, tuple[int, int]]]:
|
||||
"""Recover local text-model web_search calls emitted as prose + bare JSON.
|
||||
|
||||
Some non-native tool models leak the intended call as:
|
||||
|
||||
Need to do web_search for ...
|
||||
{"query": "...", "time_filter": "week"}
|
||||
|
||||
Keep this narrower than fenced/tool markup: it only runs when a known web
|
||||
tool name appears shortly before a JSON object shaped like web_search args.
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
for mention in _RAW_WEB_JSON_TOOL_RE.finditer(text):
|
||||
search_start = mention.end()
|
||||
search_end = min(len(text), search_start + 1200)
|
||||
for brace in re.finditer(r"\{", text[search_start:search_end]):
|
||||
start = search_start + brace.start()
|
||||
try:
|
||||
parsed, end = decoder.raw_decode(text[start:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
block = _raw_web_json_to_tool_block(parsed)
|
||||
if block:
|
||||
return block, (start, start + end)
|
||||
return None
|
||||
|
||||
def _parse_tool_call_block(raw: str) -> Optional[ToolBlock]:
|
||||
"""Parse a [TOOL_CALL] block into a ToolBlock.
|
||||
|
||||
@@ -436,6 +509,8 @@ def parse_tool_blocks(text: str, skip_fenced: bool = False) -> List[ToolBlock]:
|
||||
3. XML-style <tool_call>/<invoke> blocks
|
||||
4. <tool_code> blocks (MiniMax-M2.5 style)
|
||||
5. DeepSeek DSML markup (normalized to <invoke> first)
|
||||
6. Non-native local model fallback: prose mentioning web_search followed by
|
||||
bare JSON args, e.g. {"query":"...", "time_filter":"week"}
|
||||
|
||||
`skip_fenced`: when True, Pattern 1 (fenced ```bash/```python/```json code
|
||||
blocks) is not matched at all. Native function-calling models (GPT/Claude/
|
||||
@@ -509,6 +584,12 @@ def parse_tool_blocks(text: str, skip_fenced: bool = False) -> List[ToolBlock]:
|
||||
if block:
|
||||
blocks.append(block)
|
||||
|
||||
# Pattern 6: local text-model web_search call leaked as prose + bare JSON.
|
||||
if not blocks and not skip_fenced:
|
||||
raw_web_json = _parse_raw_web_json_lookup(text)
|
||||
if raw_web_json:
|
||||
blocks.append(raw_web_json[0])
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
@@ -532,6 +613,11 @@ def strip_tool_blocks(text: str, skip_fenced: bool = False) -> str:
|
||||
cleaned = _TOOL_CALL_RE.sub('', cleaned)
|
||||
cleaned = _XML_TOOL_CALL_RE.sub('', cleaned)
|
||||
cleaned = _TOOL_CODE_RE.sub('', cleaned)
|
||||
if not skip_fenced:
|
||||
raw_web_json = _parse_raw_web_json_lookup(cleaned)
|
||||
if raw_web_json:
|
||||
_, (start, end) = raw_web_json
|
||||
cleaned = cleaned[:start] + cleaned[end:]
|
||||
# Strip bare <invoke> blocks not wrapped in <tool_call>
|
||||
cleaned = re.sub(r'<invoke\s+name=["\'].*?</invoke>', '', cleaned, flags=re.DOTALL | re.IGNORECASE)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
|
||||
|
||||
+12
-2
@@ -25,7 +25,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"description": "Run a shell command (full access)",
|
||||
"description": "Run a shell command (full access). Prefer a dedicated tool whenever one fits the job (reading, writing, editing, searching, or listing files); use bash only for what no dedicated tool covers (installs, git, builds, running programs, system info). Do NOT create or edit files via bash redirects/heredocs/sed -- use the dedicated file tools.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -39,7 +39,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "python",
|
||||
"description": "Execute Python code to compute a result or test something",
|
||||
"description": "Execute Python code to compute a result or test something. Prefer a dedicated tool whenever one fits the job (reading, writing, or searching files); use python only for computation, data processing, or scripting no dedicated tool covers.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -141,6 +141,14 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_workspace",
|
||||
"description": "Return the absolute path of the active workspace folder the user is working in. File tools are confined to it; the shell starts there but is not sandboxed. Call this first when the user refers to 'the project'/'the code'/'this folder' without a path, instead of asking them. Takes no arguments.",
|
||||
"parameters": {"type": "object", "properties": {}, "required": []}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -1247,6 +1255,8 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
|
||||
content = args.get("path", "")
|
||||
elif tool_type in ("grep", "glob", "ls"):
|
||||
content = json.dumps(args) if args else "{}"
|
||||
elif tool_type == "get_workspace":
|
||||
content = ""
|
||||
elif tool_type == "write_file":
|
||||
content = args.get("path", "") + "\n" + args.get("content", "")
|
||||
elif tool_type == "edit_file":
|
||||
|
||||
+20
-2
@@ -20,6 +20,7 @@ NON_ADMIN_BLOCKED_TOOLS = {
|
||||
"grep",
|
||||
"glob",
|
||||
"ls",
|
||||
"get_workspace",
|
||||
"search_chats",
|
||||
"manage_memory",
|
||||
"manage_skills",
|
||||
@@ -66,6 +67,7 @@ PLAN_MODE_READONLY_TOOLS = {
|
||||
"grep",
|
||||
"glob",
|
||||
"ls",
|
||||
"get_workspace",
|
||||
"web_search",
|
||||
"web_fetch",
|
||||
"search_chats",
|
||||
@@ -162,13 +164,29 @@ def is_public_blocked_tool(tool_name: Optional[str]) -> bool:
|
||||
|
||||
|
||||
def owner_is_admin_or_single_user(owner: Optional[str]) -> bool:
|
||||
"""Return True for admins, or when auth is not configured yet."""
|
||||
"""Return True for admins, or in intentional single-user mode.
|
||||
|
||||
Single-user mode means the operator explicitly disabled auth
|
||||
(``AUTH_ENABLED=false``) — the local/self-host default where the owner has
|
||||
full access to their own box.
|
||||
|
||||
The pre-setup window (auth ENABLED but no admin created yet) is treated as
|
||||
NON-admin: returning True there would hand server-execution tools
|
||||
(``bash``/``python``) to any caller before setup completes. The auth
|
||||
middleware already 401s ``/api/`` requests pre-setup, so this is
|
||||
defense-in-depth for callers that bypass it (e.g. trusted loopback).
|
||||
"""
|
||||
try:
|
||||
from src.auth_helpers import _auth_disabled
|
||||
|
||||
if _auth_disabled():
|
||||
return True
|
||||
|
||||
from core.auth import AuthManager
|
||||
|
||||
auth = AuthManager()
|
||||
if not auth.is_configured:
|
||||
return True
|
||||
return False
|
||||
return bool(owner and auth.is_admin(owner))
|
||||
except Exception as exc:
|
||||
logger.warning("Unable to evaluate owner admin status: %s", exc)
|
||||
|
||||
@@ -352,6 +352,86 @@ class UploadHandler:
|
||||
return dict(info)
|
||||
return None
|
||||
|
||||
def _renamed_upload_index_key(self, key: str, info: Dict[str, Any], old_owner: str, new_owner: str) -> str:
|
||||
"""Return the storage key to use after renaming an owned upload row."""
|
||||
if isinstance(key, str) and ":" in key:
|
||||
owner_part, rest = key.split(":", 1)
|
||||
if owner_part.strip().lower() == old_owner:
|
||||
return f"{new_owner}:{rest}"
|
||||
file_hash = info.get("hash")
|
||||
if file_hash:
|
||||
return f"{new_owner}:{file_hash}"
|
||||
return key
|
||||
|
||||
def _unique_upload_index_key(self, base_key: str, used_keys: set, reserved_keys: set, info: Dict[str, Any]) -> str:
|
||||
"""Choose a deterministic collision key without overwriting an existing row."""
|
||||
if base_key not in used_keys and base_key not in reserved_keys:
|
||||
return base_key
|
||||
|
||||
upload_id = str(info.get("id") or "renamed").strip() or "renamed"
|
||||
candidate = f"{base_key}:{upload_id}"
|
||||
if candidate not in used_keys and candidate not in reserved_keys:
|
||||
return candidate
|
||||
|
||||
index = 2
|
||||
while True:
|
||||
candidate = f"{base_key}:{upload_id}:{index}"
|
||||
if candidate not in used_keys and candidate not in reserved_keys:
|
||||
return candidate
|
||||
index += 1
|
||||
|
||||
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||
"""Rename upload metadata ownership from old_owner to new_owner.
|
||||
|
||||
Upload rows are keyed by owner-qualified hashes for dedupe and also
|
||||
carry an `owner` field for access checks. Both must move together when
|
||||
usernames change.
|
||||
"""
|
||||
old_owner_normalized = str(old_owner or "").strip().lower()
|
||||
new_owner = str(new_owner or "").strip()
|
||||
if not old_owner_normalized or not new_owner:
|
||||
return 0
|
||||
if old_owner_normalized == new_owner.lower():
|
||||
return 0
|
||||
|
||||
uploads_db_path = os.path.join(self.upload_dir, "uploads.json")
|
||||
with self._index_lock:
|
||||
current = self._load_upload_index()
|
||||
if not current:
|
||||
return 0
|
||||
|
||||
updated = {}
|
||||
renamed = 0
|
||||
original_keys = set(current.keys())
|
||||
|
||||
for key, info in current.items():
|
||||
new_key = key
|
||||
new_info = info
|
||||
if isinstance(info, dict) and str(info.get("owner", "")).strip().lower() == old_owner_normalized:
|
||||
new_info = dict(info)
|
||||
new_info["owner"] = new_owner
|
||||
base_key = self._renamed_upload_index_key(key, new_info, old_owner_normalized, new_owner)
|
||||
new_key = self._unique_upload_index_key(
|
||||
base_key,
|
||||
set(updated.keys()),
|
||||
original_keys - {key},
|
||||
new_info,
|
||||
)
|
||||
if new_key != base_key:
|
||||
logger.warning(
|
||||
"Upload owner rename key collision for %s -> %s at %s; preserving row as %s",
|
||||
old_owner_normalized,
|
||||
new_owner,
|
||||
base_key,
|
||||
new_key,
|
||||
)
|
||||
renamed += 1
|
||||
updated[new_key] = new_info
|
||||
|
||||
if renamed:
|
||||
self._atomic_write_json(uploads_db_path, updated)
|
||||
return renamed
|
||||
|
||||
def _find_upload_path(self, upload_id: str) -> Optional[str]:
|
||||
"""Find an upload file by ID while staying inside upload_dir."""
|
||||
if not self.validate_upload_id(upload_id):
|
||||
|
||||
+24
-1
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
_USER_TZ_OFFSET_MIN: ContextVar[Optional[int]] = ContextVar("user_tz_offset_min", default=None)
|
||||
@@ -136,3 +136,26 @@ def current_datetime_prompt(now_utc: Optional[datetime] = None) -> str:
|
||||
"When scheduling a task with manage_tasks, scheduled_time is in UTC: "
|
||||
"convert the user's stated local time using the UTC offset above.\n\n"
|
||||
)
|
||||
|
||||
|
||||
def current_datetime_context_message(now_utc: Optional[datetime] = None) -> Dict[str, str]:
|
||||
"""Build the current-date/time context as a standalone chat message.
|
||||
|
||||
This intentionally returns a ``user``-role message rather than a
|
||||
``system``-role one. The text changes every turn (it embeds the current
|
||||
clock time down to the minute), and local OpenAI-compatible backends
|
||||
(llama.cpp / LM Studio) key their KV-cache prefix off the system message
|
||||
byte-for-byte — folding ever-changing timestamp text into the system
|
||||
message would invalidate the cached prefix on every single request (see
|
||||
issue #2927). Keeping it as a separate message placed near the end of the
|
||||
array (right before the latest user turn) lets the static system prompt
|
||||
stay byte-identical across turns while the model still gets fresh
|
||||
date/time grounding for relative-date reasoning.
|
||||
"""
|
||||
return {
|
||||
"role": "user",
|
||||
"content": (
|
||||
"[Context — current date/time, refreshed each turn; not part of "
|
||||
"your instructions]\n" + current_datetime_prompt(now_utc)
|
||||
),
|
||||
}
|
||||
|
||||
+15
-3
@@ -202,6 +202,18 @@ class WebhookManager:
|
||||
self._client = httpx.AsyncClient(timeout=10, follow_redirects=False)
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._api_key_manager = api_key_manager
|
||||
# Strong references to in-flight fire-and-forget tasks. asyncio only
|
||||
# keeps weak references to tasks, so without this the GC can collect a
|
||||
# delivery task mid-flight and the webhook is silently never sent.
|
||||
self._bg_tasks: set = set()
|
||||
|
||||
def _spawn_tracked(self, coro):
|
||||
"""Schedule a background task and hold a strong reference until it
|
||||
finishes, so it can't be garbage-collected before delivery completes."""
|
||||
task = asyncio.ensure_future(coro)
|
||||
self._bg_tasks.add(task)
|
||||
task.add_done_callback(self._bg_tasks.discard)
|
||||
return task
|
||||
|
||||
def set_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
@@ -223,8 +235,8 @@ class WebhookManager:
|
||||
if event not in ALLOWED_EVENTS:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self.fire(event, payload))
|
||||
asyncio.get_running_loop()
|
||||
self._spawn_tracked(self.fire(event, payload))
|
||||
except RuntimeError:
|
||||
# Called from a sync thread (e.g. sync FastAPI route in threadpool)
|
||||
if self._loop and self._loop.is_running():
|
||||
@@ -243,7 +255,7 @@ class WebhookManager:
|
||||
|
||||
for wh in matching:
|
||||
decrypted_secret = self._decrypt_secret(wh.secret)
|
||||
asyncio.create_task(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
|
||||
self._spawn_tracked(self._deliver(wh.id, wh.url, decrypted_secret, event, payload))
|
||||
|
||||
async def deliver_test(self, webhook_id: str, url: str, encrypted_secret: Optional[str]):
|
||||
"""Public method for the test-webhook route."""
|
||||
|
||||
+18
-273
@@ -1,278 +1,23 @@
|
||||
"""
|
||||
YouTube handling — transcript extraction, comment fetching (yt-dlp),
|
||||
and context formatting for LLM injection. Used by chat_handler.py.
|
||||
"""Compatibility wrapper for the canonical services.youtube.youtube_handler module.
|
||||
|
||||
Odysseus historically carried two independent copies of the YouTube handler —
|
||||
one here under ``src`` and one under ``services.youtube``. They drifted: the
|
||||
comment-fetch timeout fix landed only in the ``src`` copy, while ``app.py``
|
||||
calls ``services.youtube.init_youtube()`` at startup. Because the chat flow
|
||||
imported ``extract_transcript_async`` from ``src.youtube_handler`` (a different
|
||||
module object), the ``YOUTUBE_AVAILABLE`` / ``YouTubeTranscriptApi`` globals set
|
||||
by ``init_youtube`` never reached it and transcript extraction always reported
|
||||
"YouTube transcript API not available".
|
||||
|
||||
Keep the old ``src.youtube_handler`` import path working, but make it resolve to
|
||||
the single source of truth so module state and behavior can't diverge again.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import importlib
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import the canonical module directly (services.youtube.youtube_handler)
|
||||
# without triggering the heavy services/__init__.py top-level imports.
|
||||
_youtube_handler = importlib.import_module("services.youtube.youtube_handler")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
YOUTUBE_INSTRUCTION_PROMPT = """When the user shares a YouTube video, respond with a structured breakdown:
|
||||
|
||||
1. **Summary** — Concise overview of the video's content and main thesis (2-4 sentences)
|
||||
2. **Key Points** — Bullet list of the most important topics, arguments, or moments
|
||||
3. **Notable Timestamps** — If timestamps are available from the transcript, highlight 3-5 interesting moments with their approximate timestamps (e.g. "03:45 — discusses X")
|
||||
4. **Audience Reception** — If comments are available, summarize what viewers think: general sentiment, top reactions, any debate or controversy
|
||||
|
||||
Keep it conversational and concise. Do NOT web search for this video — use only the transcript and comments provided."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Will be set at startup by init_youtube()
|
||||
YouTubeTranscriptApi = None
|
||||
YOUTUBE_AVAILABLE = False
|
||||
|
||||
|
||||
def _find_ytdlp() -> str:
|
||||
"""Find the yt-dlp binary: venv bin first, then system PATH."""
|
||||
venv_bin = Path(sys.executable).parent / "yt-dlp"
|
||||
if venv_bin.exists():
|
||||
return str(venv_bin)
|
||||
found = shutil.which("yt-dlp")
|
||||
return found or "yt-dlp"
|
||||
|
||||
|
||||
def init_youtube():
|
||||
"""Import and cache the YouTube transcript API."""
|
||||
global YouTubeTranscriptApi, YOUTUBE_AVAILABLE
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi as _Api
|
||||
YouTubeTranscriptApi = _Api
|
||||
YOUTUBE_AVAILABLE = True
|
||||
logger.info("YouTube transcript API available")
|
||||
except ImportError as e:
|
||||
logger.warning(f"youtube-transcript-api not installed: {e}")
|
||||
YOUTUBE_AVAILABLE = False
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
if not isinstance(url, str):
|
||||
return False
|
||||
return "youtube.com" in url or "youtu.be" in url
|
||||
|
||||
|
||||
def extract_youtube_id(url: str) -> Optional[str]:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname in ("www.youtube.com", "youtube.com", "m.youtube.com"):
|
||||
if parsed.path == "/watch":
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
if "v" in params:
|
||||
return params["v"][0]
|
||||
elif parsed.path.startswith("/embed/"):
|
||||
return parsed.path.split("/")[-1]
|
||||
elif parsed.hostname == "youtu.be":
|
||||
return parsed.path[1:]
|
||||
return None
|
||||
|
||||
|
||||
async def extract_transcript_async(
|
||||
url: str, video_id: str, max_retries: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Async YouTube transcript extraction with retries.
|
||||
|
||||
Args:
|
||||
url: Full YouTube URL
|
||||
video_id: Extracted video ID
|
||||
max_retries: Number of attempts
|
||||
|
||||
Returns:
|
||||
Dict with success/error/transcript keys
|
||||
"""
|
||||
if not YOUTUBE_AVAILABLE or YouTubeTranscriptApi is None:
|
||||
return {"success": False, "error": "YouTube transcript API not available", "transcript": None}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript = api.fetch(video_id)
|
||||
transcript_list = list(transcript)
|
||||
|
||||
formatted = []
|
||||
for snippet in transcript_list:
|
||||
text = snippet.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
start = snippet.start
|
||||
formatted.append({
|
||||
"text": text,
|
||||
"start": start,
|
||||
"duration": snippet.duration,
|
||||
"timestamp": f"{int(start // 60):02d}:{int(start % 60):02d}",
|
||||
})
|
||||
|
||||
full_text = " ".join(e["text"] for e in formatted)
|
||||
max_len = 8000
|
||||
if len(full_text) > max_len:
|
||||
full_text = full_text[:max_len] + "... [transcript truncated]"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": full_text,
|
||||
"video_id": video_id,
|
||||
"language": "en",
|
||||
"is_generated": False,
|
||||
"segments": formatted,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Transcript attempt {attempt + 1} failed: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
|
||||
return {"success": False, "error": f"Failed after {max_retries} attempts", "transcript": None}
|
||||
|
||||
|
||||
def format_transcript_for_context(
|
||||
transcript_data: Dict[str, Any], url: str,
|
||||
title: str = "", channel: str = ""
|
||||
) -> str:
|
||||
"""Format transcript data for inclusion in LLM context."""
|
||||
if not transcript_data.get("success"):
|
||||
header = ""
|
||||
if title:
|
||||
header = f" \"{title}\""
|
||||
if channel:
|
||||
header += f" by {channel}"
|
||||
return f"\n[YouTube Video{header}: Transcript unavailable ({transcript_data.get('error', 'Unknown error')}). Use the comments below if available, do NOT web search for this video.]"
|
||||
|
||||
transcript = transcript_data.get("transcript", "")
|
||||
video_id = transcript_data.get("video_id", "")
|
||||
language = transcript_data.get("language", "unknown")
|
||||
is_generated = transcript_data.get("is_generated", False)
|
||||
segments = transcript_data.get("segments", [])
|
||||
|
||||
ctx = "\n[YOUTUBE VIDEO TRANSCRIPT]\n"
|
||||
if title:
|
||||
ctx += f"Title: {title}\n"
|
||||
if channel:
|
||||
ctx += f"Channel: {channel}\n"
|
||||
ctx += f"Video ID: {video_id}\n"
|
||||
ctx += f"Language: {language}\n"
|
||||
ctx += f"Source: {'Auto-generated' if is_generated else 'Manual'}\n"
|
||||
ctx += f"URL: {url}\n\n"
|
||||
# Include timestamped segments for the LLM to reference
|
||||
if segments:
|
||||
ctx += "Timestamped Transcript:\n"
|
||||
for seg in segments:
|
||||
if not isinstance(seg, dict):
|
||||
continue
|
||||
ctx += f"[{seg['timestamp']}] {seg['text']}\n"
|
||||
# Check length — fall back to plain text if too long
|
||||
if len(ctx) > 12000:
|
||||
ctx = ctx[:ctx.index("Timestamped Transcript:\n")]
|
||||
ctx += "Transcript:\n"
|
||||
ctx += transcript
|
||||
else:
|
||||
ctx += "Transcript:\n"
|
||||
ctx += transcript
|
||||
ctx += "\n[END TRANSCRIPT]\n"
|
||||
return ctx
|
||||
|
||||
|
||||
async def fetch_youtube_comments(
|
||||
video_id: str, max_comments: int = 25, timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch top comments for a YouTube video using yt-dlp.
|
||||
|
||||
Returns dict with 'success', 'comments' list, 'error'.
|
||||
"""
|
||||
try:
|
||||
cmd = [
|
||||
_find_ytdlp(),
|
||||
"--skip-download",
|
||||
"--write-comments",
|
||||
"--extractor-args", f"youtube:max_comments={max_comments},all,100,0",
|
||||
"--dump-json",
|
||||
"--js-runtimes", "node",
|
||||
"--remote-components", "ejs:github",
|
||||
f"https://www.youtube.com/watch?v={video_id}",
|
||||
]
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
# Bound the wait on the process actually finishing, not on spawning it.
|
||||
# create_subprocess_exec returns as soon as the child starts, so wrapping
|
||||
# it in wait_for never enforces the timeout — proc.communicate() is the
|
||||
# blocking step. Kill and reap the child if it overruns so it does not
|
||||
# linger after we return.
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
raise
|
||||
|
||||
if proc.returncode != 0:
|
||||
return {"success": False, "error": f"yt-dlp failed: {stderr.decode()[:200]}", "comments": []}
|
||||
|
||||
data = json.loads(stdout.decode())
|
||||
title = data.get("title", "")
|
||||
channel = data.get("channel", "") or data.get("uploader", "")
|
||||
raw_comments = data.get("comments", [])
|
||||
|
||||
comments = []
|
||||
for c in raw_comments[:max_comments]:
|
||||
text = (c.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
comments.append({
|
||||
"author": c.get("author", "Unknown"),
|
||||
"text": text,
|
||||
"likes": c.get("like_count", 0),
|
||||
})
|
||||
|
||||
# Sort by likes descending — most popular comments first
|
||||
comments.sort(key=lambda x: x.get("likes", 0), reverse=True)
|
||||
|
||||
return {"success": True, "comments": comments, "count": len(comments),
|
||||
"title": title, "channel": channel}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Comment fetch timed out for {video_id}")
|
||||
return {"success": False, "error": "Comment fetch timed out", "comments": []}
|
||||
except FileNotFoundError:
|
||||
logger.warning("yt-dlp not installed — cannot fetch comments")
|
||||
return {"success": False, "error": "yt-dlp not installed", "comments": []}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch comments for {video_id}: {e}")
|
||||
return {"success": False, "error": str(e), "comments": []}
|
||||
|
||||
|
||||
def format_comments_for_context(comments_data: Dict[str, Any], url: str) -> str:
|
||||
"""Format YouTube comments for inclusion in LLM context."""
|
||||
if not comments_data.get("success") or not comments_data.get("comments"):
|
||||
return ""
|
||||
|
||||
comments = comments_data["comments"]
|
||||
ctx = f"\n[YOUTUBE VIDEO COMMENTS — Top {len(comments)} by popularity]\n"
|
||||
ctx += f"URL: {url}\n\n"
|
||||
|
||||
for i, c in enumerate(comments, 1):
|
||||
likes = c.get("likes", 0)
|
||||
likes_str = f" [{likes} likes]" if likes else ""
|
||||
ctx += f"{i}. @{c['author']}{likes_str}: {c['text']}\n\n"
|
||||
|
||||
if len(ctx) > 4000:
|
||||
ctx = ctx[:4000] + "\n[Comments truncated]\n"
|
||||
|
||||
ctx += "[END COMMENTS]\n"
|
||||
return ctx
|
||||
sys.modules[__name__] = _youtube_handler
|
||||
|
||||
Reference in New Issue
Block a user