['\"]?)(?P[A-Za-z0-9_]+)(?P=quote)" +) def _ollama_bind_from_cmd(cmd: str | None, *, default_host: str = "127.0.0.1") -> tuple[str, str]: @@ -588,6 +639,22 @@ def _ollama_bind_from_cmd(cmd: str | None, *, default_host: str = "127.0.0.1") - return f"[{host}]" if bracketed_host else host, port +def _normalize_llama_cpp_python_cache_types(cmd: str | None) -> str | None: + """Map llama.cpp KV cache type names to llama-cpp-python's integer enum.""" + if not cmd or "llama_cpp.server" not in cmd: + return cmd + + def repl(match: re.Match[str]) -> str: + value = match.group("value") + mapped = _LLAMA_CPP_PYTHON_GGML_TYPES.get(value.lower()) + if not mapped: + return match.group(0) + quote = match.group("quote") + return f"{match.group('flag')}{match.group('sep')}{quote}{mapped}{quote}" + + return _LLAMA_CPP_PYTHON_TYPE_FLAG_RE.sub(repl, cmd) + + def _check_serve_binary(seg: str) -> None: """Validate that a single command segment starts with an allowlisted binary (after skipping leading env-var assignments like `CUDA_VISIBLE_DEVICES=0`).""" @@ -726,6 +793,7 @@ def _append_llama_cpp_linux_accel_build_lines(runner_lines: list[str]) -> None: runner_lines.append(' done') # rm -rf build so a prior poisoned CMakeCache.txt (e.g. from a failed CUDA # or HIP attempt) doesn't cause the next configure to reuse stale settings. + runner_lines.append(' mkdir -p ~/bin') runner_lines.append(' cd ~/llama.cpp && rm -rf build') runner_lines.append(' if command -v hipconfig &>/dev/null || [ -d /opt/rocm ] || [ -n "$ROCM_PATH" ] || [ -n "$HIP_PATH" ]; then') runner_lines.append(' if command -v hipconfig &>/dev/null; then') @@ -1030,6 +1098,16 @@ def _diagnose_serve_output(text: str) -> dict | None: "vLLM is not installed or not in PATH on this server.", [{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}], ), + ( + r"sgl_kernel[\s\S]*(Python\.h|libnuma\.so\.1|common_ops)|" + r"(Python\.h|libnuma\.so\.1|common_ops)[\s\S]*sgl_kernel|" + r"Please ensure sgl_kernel is properly installed", + "SGLang native dependencies are missing on this server.", + [ + {"label": "install OS packages: libnuma-dev python3.12-dev build-essential", "op": "manual"}, + {"label": "upgrade sglang-kernel after OS packages are installed", "op": "manual"}, + ], + ), ( r"sglang.*command not found|No module named sglang|SGLang is not installed", "SGLang is not installed or not in PATH on this server.", diff --git a/routes/cookbook_output.py b/routes/cookbook_output.py new file mode 100644 index 000000000..b30b18536 --- /dev/null +++ b/routes/cookbook_output.py @@ -0,0 +1,75 @@ +"""Pure helpers for shaping cookbook task output for the status response. + +Kept dependency-free (no FastAPI / SQLAlchemy imports) so the behavior can be +unit-tested without standing up the whole app. +""" + +import re + +_FETCHING_ZERO_FILES_RE = re.compile(r"Fetching\s+0\s+files", re.IGNORECASE) + +# Probe scripts for the dead-session download check, run as +# `python3 -c ` (locally or over SSH). +# cache_root is the task's custom download dir, '' for the default HF cache. +# It has to be passed explicitly: the download runner exports +# HF_HOME= , so that task's cache lives under /hub, and +# the probe process's own environment knows nothing about it. +HF_CACHE_COMPLETE_PROBE = ( + "import os,sys;" + "repo=sys.argv[1];" + "root=os.path.expanduser(sys.argv[2]) if len(sys.argv)>2 and sys.argv[2] else '';" + "base=os.path.join(root,'hub') if root else (os.environ.get('HUGGINGFACE_HUB_CACHE') or os.path.join(os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')), 'hub'));" + "d=os.path.join(base,'models--'+repo.replace('/','--'));" + "snap=os.path.join(d,'snapshots');" + "ok=os.path.isdir(snap) and any(os.path.isdir(os.path.join(snap,x)) and os.listdir(os.path.join(snap,x)) for x in os.listdir(snap));" + "inc=False;" + "blobs=os.path.join(d,'blobs');" + "inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));" + "sys.exit(0 if ok and not inc else 1)" +) + +HF_CACHE_INCOMPLETE_PROBE = ( + "import os,sys;" + "repo=sys.argv[1];" + "root=os.path.expanduser(sys.argv[2]) if len(sys.argv)>2 and sys.argv[2] else '';" + "base=os.path.join(root,'hub') if root else (os.environ.get('HUGGINGFACE_HUB_CACHE') or os.path.join(os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')), 'hub'));" + "d=os.path.join(base,'models--'+repo.replace('/','--'));" + "blobs=os.path.join(d,'blobs');" + "inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));" + "sys.exit(0 if inc else 1)" +) + + +def classify_dead_download(full_snapshot: str): + """Resolve a dead download session's status from its runner markers. + + The runner prints DOWNLOAD_OK only after exiting 0 (and DOWNLOAD_FAILED + otherwise), so the markers stay trustworthy after the tmux pane is gone. + Returns (status, zero_files), or None when the snapshot carries no marker + and the caller has to fall back to the cache probe. Same precedence as + the live-session branch: DOWNLOAD_OK wins, except a "Fetching 0 files" + run is an error (nothing matched the include/quant pattern). + """ + if not full_snapshot: + return None + if "DOWNLOAD_OK" in full_snapshot: + if _FETCHING_ZERO_FILES_RE.search(full_snapshot): + return ("error", True) + return ("completed", False) + if "DOWNLOAD_FAILED" in full_snapshot: + return ("error", False) + return None + + +def error_aware_output_tail(full_snapshot: str, status: str) -> str: + """Return the trailing slice of a task log for the status response. + + Failed tasks return the last 50 lines so the "Copy last 50 lines" action + surfaces the actual error context (stack traces, build output). Running and + other non-error tasks keep the cheaper 12-line tail to limit the payload on + the 10s polling interval. + """ + if not full_snapshot: + return "" + tail_lines = 50 if status == "error" else 12 + return "\n".join(full_snapshot.splitlines()[-tail_lines:]) diff --git a/routes/cookbook_routes.py b/routes/cookbook_routes.py index 872075178..af25dd8e8 100644 --- a/routes/cookbook_routes.py +++ b/routes/cookbook_routes.py @@ -15,9 +15,11 @@ from pathlib import Path from fastapi import APIRouter, HTTPException, Request, Depends from src.auth_helpers import require_user +from src.constants import COOKBOOK_STATE_FILE from pydantic import BaseModel from core.middleware import require_admin +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import ( IS_WINDOWS, detached_popen_kwargs, @@ -28,18 +30,26 @@ from core.platform_compat import ( which_tool, ) from routes.shell_routes import TMUX_LOG_DIR +from routes.cookbook_output import ( + error_aware_output_tail, classify_dead_download, + HF_CACHE_COMPLETE_PROBE, HF_CACHE_INCOMPLETE_PROBE, +) logger = logging.getLogger(__name__) from routes.cookbook_helpers import ( - _SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE, - _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token, - _validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path, + _SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token, + _validate_local_dir, _validate_gpus, _shell_path, _ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase, _safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines, _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, + load_stored_hf_token, + _append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain, + _pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd, + _diagnose_serve_output, run_ssh_command_async, _ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd, + _normalize_llama_cpp_python_cache_types, ModelDownloadRequest, ServeRequest, ) @@ -48,13 +58,13 @@ _HF_TOKEN_STATUS_SNIPPET = ( 'echo "[odysseus] HF token: applied"; ' 'else ' 'echo "[odysseus] HF token: NOT SET — gated/private models will be denied. ' - 'Add one in Odysseus Settings -> Cookbook -> HuggingFace Token."; ' + 'Add one in Odysseus Cookbook -> Settings -> HuggingFace Token."; ' 'fi' ) def setup_cookbook_routes() -> APIRouter: router = APIRouter(tags=["cookbook"]) - _cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json" + _cookbook_state_path = Path(COOKBOOK_STATE_FILE) def _mask_secret(value: str) -> str: if not value: @@ -164,6 +174,16 @@ def setup_cookbook_routes() -> APIRouter: "vLLM is not installed or not in PATH on this server.", [{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}], ), + ( + r"sgl_kernel[\s\S]*(Python\.h|libnuma\.so\.1|common_ops)|" + r"(Python\.h|libnuma\.so\.1|common_ops)[\s\S]*sgl_kernel|" + r"Please ensure sgl_kernel is properly installed", + "SGLang native dependencies are missing on this server.", + [ + {"label": "install OS packages: libnuma-dev python3.12-dev build-essential", "op": "manual"}, + {"label": "upgrade sglang-kernel after OS packages are installed", "op": "manual"}, + ], + ), ( r"sglang.*command not found|No module named sglang|SGLang is not installed", "SGLang is not installed or not in PATH on this server.", @@ -232,14 +252,7 @@ def setup_cookbook_routes() -> APIRouter: return state def _load_stored_hf_token() -> str: - if not _cookbook_state_path.exists(): - return "" - try: - state = json.loads(_cookbook_state_path.read_text(encoding="utf-8")) - env = state.get("env") if isinstance(state, dict) else {} - return _decrypt_secret(env.get("hfToken") if isinstance(env, dict) else "") - except Exception: - return "" + return load_stored_hf_token(state_path=_cookbook_state_path) def _cookbook_ssh_dir() -> Path: # The Docker image keeps cookbook keys under /app/.ssh; that path only @@ -354,7 +367,11 @@ def setup_cookbook_routes() -> APIRouter: # all output to the log the poller reads. Paths handed to bash use # POSIX form + shell-quoting so drive paths / spaces survive. inner = TMUX_LOG_DIR / f"{session_id}_run.sh" - inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8") + pp = shlex.quote(pid_path.as_posix()) + inner.write_text( + f"printf '%s\\n' \"$$\" > {pp}\n" + "\n".join(bash_lines) + "\n", + encoding="utf-8", + ) lp = shlex.quote(log_path.as_posix()) ip = shlex.quote(inner.as_posix()) script_path = TMUX_LOG_DIR / f"{session_id}.sh" @@ -406,8 +423,8 @@ def setup_cookbook_routes() -> APIRouter: else: _validate_repo_id(req.repo_id) _validate_include(req.include) - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.local_dir = _validate_local_dir(req.local_dir) req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token()) _validate_token(req.hf_token) @@ -738,9 +755,8 @@ def setup_cookbook_routes() -> APIRouter: # Validate shell-bound inputs, matching the sibling list_gpus endpoint — # `host`/`ssh_port` are interpolated into an ssh command below, so an # unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection. - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True) model_dirs = [] @@ -889,11 +905,16 @@ def setup_cookbook_routes() -> APIRouter: # listening" check without requiring ss/netstat/nmap. ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if ssh_port and str(ssh_port) != "22": - if not _SSH_PORT_RE.match(str(ssh_port)): + try: + ssh_port = validate_ssh_port(ssh_port) + except HTTPException: return None ssh_base.extend(["-p", str(ssh_port)]) - host_arg = remote - if not _REMOTE_HOST_RE.match(host_arg): + try: + host_arg = validate_remote_host(remote) + except HTTPException: + return None + if not host_arg: return None probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1)) script = ( @@ -1196,8 +1217,8 @@ def setup_cookbook_routes() -> APIRouter: """ require_admin(request) # Defence-in-depth: reject values that could break out of shell contexts. - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.gpus = _validate_gpus(req.gpus) req.hf_token = req.hf_token or _load_stored_hf_token() _validate_token(req.hf_token) @@ -1208,6 +1229,7 @@ def setup_cookbook_routes() -> APIRouter: # many downstream `"engine" in req.cmd` membership checks can't hit # `TypeError: argument of type 'NoneType'` (a 500 instead of a clean 400). req.cmd = _validate_serve_cmd(req.cmd) or "" + req.cmd = _normalize_llama_cpp_python_cache_types(req.cmd) or "" req.cmd = _venv_safe_local_pip_install_cmd( req.cmd, local=not bool(req.remote_host), @@ -1637,12 +1659,11 @@ def setup_cookbook_routes() -> APIRouter: async def server_setup(request: Request, req: SetupRequest): """Install required dependencies on a remote server via SSH.""" require_admin(request) - host = _validate_remote_host(req.host) + host = validate_remote_host(req.host) if not host: raise HTTPException(400, "host is required") port = req.ssh_port - if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port): - raise HTTPException(400, "Invalid ssh_port") + port = validate_ssh_port(port) pf = f"-p {port} " if port and port != "22" else "" # Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux @@ -1886,9 +1907,8 @@ def setup_cookbook_routes() -> APIRouter: `busy` is True when free_mb/total_mb < 0.5. """ require_admin(request) - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits" nvidia_error = None try: @@ -2045,9 +2065,8 @@ def setup_cookbook_routes() -> APIRouter: sig = (req.signal or "TERM").upper() if sig not in ("TERM", "KILL", "INT"): raise HTTPException(400, "signal must be TERM, KILL, or INT") - host = _validate_remote_host(req.host) - if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(req.host) + req.ssh_port = validate_ssh_port(req.ssh_port) kill_cmd = f"kill -{sig} {req.pid}" try: if host: @@ -2381,14 +2400,19 @@ def setup_cookbook_routes() -> APIRouter: host = (srv.get("host") or "").strip() if not host: continue # local-only entry; the /proc scan handles it - if not _REMOTE_HOST_RE.match(host): + try: + host = validate_remote_host(host) + except HTTPException: continue sport = str(srv.get("port") or "").strip() ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if sport and sport != "22": - if not _SSH_PORT_RE.match(sport): + try: + sport = validate_ssh_port(sport) + except HTTPException: continue - ssh_base.extend(["-p", sport]) + if sport != "22": + ssh_base.extend(["-p", sport]) try: ls = subprocess.run( @@ -2802,30 +2826,20 @@ def setup_cookbook_routes() -> APIRouter: def _cookbook_tasks_status_sync(): import subprocess - def _download_cache_complete(repo_id: str, remote_host: str = "", ssh_port: str = "") -> bool: + def _download_cache_complete(repo_id: str, remote_host: str = "", ssh_port: str = "", cache_root: str = "") -> bool: """Best-effort check for a completed HF cache entry. tmux output can stop at a stale progress line if the pane/session disappears before Cookbook captures the final DOWNLOAD_OK marker. In that case, trust the cache shape: a snapshot directory with files and no *.incomplete blobs means HuggingFace finished materializing the - model. + model. cache_root is the task's custom download dir — the runner + pointed HF_HOME there, so the cache lives under /hub, + not wherever this probe's environment says. """ if not repo_id or "/" not in repo_id: return False - py = ( - "import os,sys;" - "repo=sys.argv[1];" - "base=os.environ.get('HUGGINGFACE_HUB_CACHE') or os.path.join(os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')), 'hub');" - "d=os.path.join(base,'models--'+repo.replace('/','--'));" - "snap=os.path.join(d,'snapshots');" - "ok=os.path.isdir(snap) and any(os.path.isdir(os.path.join(snap,x)) and os.listdir(os.path.join(snap,x)) for x in os.listdir(snap));" - "inc=False;" - "blobs=os.path.join(d,'blobs');" - "inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));" - "sys.exit(0 if ok and not inc else 1)" - ) - cmd = ["python3", "-c", py, repo_id] + cmd = ["python3", "-c", HF_CACHE_COMPLETE_PROBE, repo_id, cache_root or ""] try: if remote_host: ssh_base = ["ssh"] @@ -2839,7 +2853,7 @@ def setup_cookbook_routes() -> APIRouter: except Exception: return False - def _download_cache_incomplete(repo_id: str, remote_host: str = "", ssh_port: str = "") -> bool: + def _download_cache_incomplete(repo_id: str, remote_host: str = "", ssh_port: str = "", cache_root: str = "") -> bool: """Best-effort check for resumable HF partial blobs. A lost SSH/tmux session can leave a real download still incomplete. @@ -2848,16 +2862,7 @@ def setup_cookbook_routes() -> APIRouter: """ if not repo_id or "/" not in repo_id: return False - py = ( - "import os,sys;" - "repo=sys.argv[1];" - "base=os.environ.get('HUGGINGFACE_HUB_CACHE') or os.path.join(os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')), 'hub');" - "d=os.path.join(base,'models--'+repo.replace('/','--'));" - "blobs=os.path.join(d,'blobs');" - "inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));" - "sys.exit(0 if inc else 1)" - ) - cmd = ["python3", "-c", py, repo_id] + cmd = ["python3", "-c", HF_CACHE_INCOMPLETE_PROBE, repo_id, cache_root or ""] try: if remote_host: ssh_base = ["ssh"] @@ -2929,12 +2934,18 @@ def setup_cookbook_routes() -> APIRouter: if not _SESSION_ID_RE.match(session_id): logger.warning(f"Skipping task with unsafe session_id: {session_id!r}") continue - if remote and not _REMOTE_HOST_RE.match(remote): - logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") - continue - if _tport and not _SSH_PORT_RE.match(str(_tport)): - logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") - continue + if remote: + try: + remote = validate_remote_host(remote) + except HTTPException: + logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") + continue + if _tport: + try: + _tport = validate_ssh_port(str(_tport)) + except HTTPException: + logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") + continue if task_platform == "windows" and remote: # Windows: check PID file + Get-Process, read log tail sd = "$env:TEMP\\odysseus-sessions" @@ -3047,6 +3058,7 @@ def setup_cookbook_routes() -> APIRouter: # snapshot to classify (DOWNLOAD_OK / exit marker) — evaluate it even # when the PID is gone instead of blindly reporting "stopped". download_zero_files = False + exit_code = None status = "unknown" download_has_ok = task_type == "download" and "DOWNLOAD_OK" in full_snapshot download_has_failed = task_type == "download" and "DOWNLOAD_FAILED" in full_snapshot @@ -3055,7 +3067,7 @@ def setup_cookbook_routes() -> APIRouter: and ( ".incomplete" in full_snapshot or bool(re.search(r'model-\d+-of-\d+\.[A-Za-z0-9_.-]+:\s+(?:[0-9]|[1-8][0-9])%', full_snapshot)) - or _download_cache_incomplete(_payload.get("repo_id") or model, remote, str(_tport or "")) + or _download_cache_incomplete(_payload.get("repo_id") or model, remote, str(_tport or ""), _payload.get("local_dir") or "") ) ) if is_alive or (local_win_task and full_snapshot): @@ -3096,11 +3108,19 @@ def setup_cookbook_routes() -> APIRouter: else: status = "running" else: - # Session is dead — check if it completed or crashed - if ( + # Session is dead — check if it completed or crashed. The + # runner markers in the retained output are conclusive + # (DOWNLOAD_OK only prints after exit 0), so check them before + # the cache probe, which can't see ollama pulls at all. + marker = classify_dead_download(full_snapshot) if task_type == "download" else None + if marker is not None: + status, download_zero_files = marker + if status == "completed" and not progress_text: + progress_text = "Download complete" + elif ( task_type == "download" and not download_has_incomplete_evidence - and _download_cache_complete(_payload.get("repo_id") or model, remote, str(_tport or "")) + and _download_cache_complete(_payload.get("repo_id") or model, remote, str(_tport or ""), _payload.get("local_dir") or "") ): status = "completed" if not progress_text: @@ -3120,7 +3140,7 @@ def setup_cookbook_routes() -> APIRouter: status = "error" if download_zero_files: diagnosis = {"message": "No matching files were downloaded. The model repo or filename/quant pattern may be wrong (for example a ':Q4_K_M' tag that does not exist in the repo). Check the repo and the include/quant pattern."} - output_tail = "\n".join(full_snapshot.splitlines()[-12:]) if full_snapshot else "" + output_tail = error_aware_output_tail(full_snapshot, status) results.append({ "session_id": session_id, @@ -3131,6 +3151,7 @@ def setup_cookbook_routes() -> APIRouter: "phase": serve_phase, "diagnosis": diagnosis, "output_tail": output_tail, + "exit_code": exit_code, "cmd": _payload.get("_cmd") or "", "tps": phase_info.get("tps"), "reqs": phase_info.get("reqs"), diff --git a/routes/diagnostics_routes.py b/routes/diagnostics_routes.py index daebef8d2..e6167a80f 100644 --- a/routes/diagnostics_routes.py +++ b/routes/diagnostics_routes.py @@ -1,12 +1,13 @@ """Diagnostics routes — /api/db/stats, /api/rag/stats, /api/test/youtube, /api/test-research.""" import logging +import os from typing import Dict, Any from fastapi import APIRouter, HTTPException, Form, Request from services.youtube.youtube_handler import extract_youtube_id, extract_transcript_async -from core.constants import DEFAULT_HOST +from core.constants import DEFAULT_HOST, DATA_DIR from core.middleware import require_admin logger = logging.getLogger(__name__) @@ -16,9 +17,42 @@ def setup_diagnostics_routes( rag_manager, rag_available: bool, research_handler, + memory_vector=None, ) -> APIRouter: router = APIRouter(tags=["diagnostics"]) + @router.get("/api/diagnostics/services") + async def get_service_health(request: Request) -> Dict[str, Any]: + """Consolidated degraded-state report for ChromaDB, SearXNG, email, + ntfy, and provider endpoints. Non-intrusive probes — safe to poll.""" + require_admin(request) + from src.service_health import collect_service_health + return await collect_service_health(rag_manager, memory_vector) + + @router.get("/api/diagnostics/logs") + async def get_diagnostics_logs(request: Request, limit: int = 200) -> Dict[str, Any]: + require_admin(request) + limit = max(1, min(limit, 1000)) + try: + log_file = os.path.join(DATA_DIR, "logs", "app.log") + if not os.path.exists(log_file): + return {"status": "success", "logs": []} + + # Safe tail read of the log file (max 5MB via rotation) + with open(log_file, "r", encoding="utf-8", errors="ignore") as f: + lines = f.readlines() + + tail_lines = lines[-limit:] if len(lines) > limit else lines + tail_lines = [line.rstrip('\r\n') for line in tail_lines] + + return { + "status": "success", + "logs": tail_lines + } + except Exception as e: + logger.error(f"Diagnostics logs retrieval error: {e}") + raise HTTPException(500, f"Failed to retrieve logs: {str(e)}") + @router.get("/api/db/stats") async def get_database_stats(request: Request) -> Dict[str, Any]: require_admin(request) diff --git a/routes/document_routes.py b/routes/document_routes.py index cb41108e0..e4598d925 100644 --- a/routes/document_routes.py +++ b/routes/document_routes.py @@ -108,10 +108,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter: # to markdown for prose. language = req.language if not language: - from src.tool_implementations import _looks_like_email_document, _sniff_doc_language + from src.agent_tools.document_tools import _looks_like_email_document, _sniff_doc_language language = _sniff_doc_language(req.content) else: - from src.tool_implementations import _looks_like_email_document + from src.agent_tools.document_tools import _looks_like_email_document if _looks_like_email_document(req.content, req.title): language = "email" @@ -643,7 +643,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter: # in-memory active-doc pointer so the last-resort injection # path doesn't re-surface this doc in a later chat (#1160). try: - from src.tool_implementations import clear_active_document + from src.agent_tools.document_tools import clear_active_document clear_active_document(doc_id) except Exception: pass @@ -672,7 +672,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter: # Closed/deleted — drop the in-memory active-doc pointer so it isn't # re-injected into a later, unrelated chat (#1160). try: - from src.tool_implementations import clear_active_document + from src.agent_tools.document_tools import clear_active_document clear_active_document(doc_id) except Exception: pass diff --git a/routes/email_helpers.py b/routes/email_helpers.py index 890680a87..b3df6a560 100644 --- a/routes/email_helpers.py +++ b/routes/email_helpers.py @@ -304,6 +304,7 @@ OWNER_SCOPED_EMAIL_CACHE_TABLES = { "email_ai_replies", "email_calendar_extractions", "email_urgency_alerts", + "sender_signatures", } @@ -341,6 +342,55 @@ def _ensure_owner_scoped_email_cache_table(conn, table: str, create_sql: str, co _lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}") +def _ensure_sender_signatures_table(conn): + """Create/migrate learned sender signatures to an owner-scoped cache.""" + create_sql = """ + CREATE TABLE IF NOT EXISTS sender_signatures ( + from_address TEXT, + owner TEXT DEFAULT '', + signature_text TEXT, + sample_count INTEGER, + last_built_at TEXT NOT NULL, + model_used TEXT, + source TEXT, + PRIMARY KEY (from_address, owner) + ) + """ + conn.execute(create_sql) + try: + info = conn.execute("PRAGMA table_info(sender_signatures)").fetchall() + cols = [r[1] for r in info] + pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])] + if "owner" in cols and pk_cols == ["from_address", "owner"]: + return + + conn.execute("ALTER TABLE sender_signatures RENAME TO sender_signatures__old") + conn.execute(create_sql) + old_cols = [r[1] for r in conn.execute("PRAGMA table_info(sender_signatures__old)").fetchall()] + copy_cols = [ + c for c in ( + "from_address", + "signature_text", + "sample_count", + "last_built_at", + "model_used", + "source", + ) + if c in old_cols + ] + source_owner = "COALESCE(owner, '')" if "owner" in old_cols else "''" + conn.execute( + f"INSERT OR IGNORE INTO sender_signatures " + f"({', '.join([*copy_cols, 'owner'])}) " + f"SELECT {', '.join([*copy_cols, source_owner])} " + f"FROM sender_signatures__old" + ) + conn.execute("DROP TABLE sender_signatures__old") + except Exception as _mig_e: + import logging as _lg + _lg.getLogger(__name__).warning(f"sender_signatures owner-migration skipped: {_mig_e}") + + def attachment_extract_dir(folder: str, uid: str) -> Path: """Containment-safe extraction directory for an attachment. @@ -559,20 +609,10 @@ def _init_scheduled_db(): conn.execute("ALTER TABLE email_boundaries ADD COLUMN turns_json TEXT") except Exception: pass - # Per-sender signature cache. Populated by `learn_sender_signatures` - # action: the LLM extracts the common trailing block across N emails - # from each sender; the renderer folds it consistently for every - # future email from that address. - conn.execute(""" - CREATE TABLE IF NOT EXISTS sender_signatures ( - from_address TEXT PRIMARY KEY, - signature_text TEXT, - sample_count INTEGER, - last_built_at TEXT NOT NULL, - model_used TEXT, - source TEXT - ) - """) + # Per-sender signature cache. Populated by `learn_sender_signatures`. + # Message sender addresses are global, so signatures must be scoped to the + # mailbox owner before `/read` returns them to the renderer. + _ensure_sender_signatures_table(conn) conn.commit() conn.close() @@ -762,10 +802,14 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int imaplib._MAXLINE = 50_000_000 return conn -def _imap_connect(account_id: str | None = None, owner: str = ""): +def _imap_connect(account_id: str | None = None, owner: str = "", + timeout: int = _IMAP_TIMEOUT_SECONDS): # SECURITY: passing `owner` scopes the fallback config lookup so a brand # new user doesn't get connected against another user's default mailbox # when they have no account configured. + # + # `timeout` is overridable so short-lived callers (e.g. the service-health + # probe) can impose a tighter budget than the default IMAP timeout. cfg = _get_email_config(account_id, owner=owner) # Connection mode: # STARTTLS on → plain + upgrade @@ -778,7 +822,7 @@ def _imap_connect(account_id: str | None = None, owner: str = ""): cfg["imap_host"], cfg["imap_port"], starttls=bool(cfg.get("imap_starttls")), - timeout=_IMAP_TIMEOUT_SECONDS, + timeout=timeout, ) try: conn.login(cfg["imap_user"], cfg["imap_password"]) diff --git a/routes/email_routes.py b/routes/email_routes.py index 1c5e1e6a4..0871b5656 100644 --- a/routes/email_routes.py +++ b/routes/email_routes.py @@ -249,6 +249,41 @@ def _uid_from_fetch_meta(meta_b: bytes) -> str: return m.group(1).decode() if m else "" +_FETCH_SEQ_RE = re.compile(rb"^(\d+)\s+\(") + + +def _group_uid_fetch_records(msg_data) -> list: + """Group an imaplib UID FETCH response into per-message (meta, payload). + + imaplib yields an interleaved list: ``(meta, literal)`` tuples for + attributes that carry a literal (``RFC822.HEADER {n}`` etc.) plus bare + ``bytes`` elements for everything the server sends outside a literal. + Where each attribute lands is server-specific: Dovecot sends FLAGS + *before* the header literal (so it ends up inside the tuple meta), while + Gmail sends FLAGS *after* it, arriving as a bare ``b' FLAGS (\\Seen))'`` + element. Dropping bare elements therefore silently loses FLAGS on Gmail + and every message renders as unread/unflagged. + + A tuple whose meta starts with a sequence number opens a new record; + every other part — continuation tuple or bare bytes — is folded into the + current record's meta so attribute regexes see the full meta text. + Plain ``b')'`` terminators get folded in too, which is harmless. + """ + grouped: list = [] # list of (meta_bytes, payload_bytes_or_None) + for part in (msg_data or []): + if isinstance(part, tuple): + meta_b = part[0] if isinstance(part[0], (bytes, bytearray)) else str(part[0]).encode() + if _FETCH_SEQ_RE.match(meta_b): + grouped.append((meta_b, part[1])) + elif grouped: + cur_meta, cur_payload = grouped[-1] + grouped[-1] = (cur_meta + b" " + meta_b, cur_payload or part[1]) + elif isinstance(part, (bytes, bytearray)) and grouped: + cur_meta, cur_payload = grouped[-1] + grouped[-1] = (cur_meta + b" " + bytes(part), cur_payload) + return grouped + + def _smtp_ready(cfg: dict) -> bool: return bool(cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password")) @@ -799,20 +834,11 @@ def setup_email_routes(): except Exception as e: logger.warning(f"Batch fetch failed, falling back to per-UID: {e}") status, msg_data = "NO", [] - # imaplib batch responses interleave (meta, payload) tuples and - # `b')'` terminators. Group by message: each tuple where the - # meta begins with a seq number starts a new message record. - seq_re = re.compile(rb'^(\d+)\s+\(') - grouped = [] # list of (meta_str, payload_bytes) - for part in (msg_data or []): - if isinstance(part, tuple): - meta_b = part[0] if isinstance(part[0], (bytes, bytearray)) else str(part[0]).encode() - if seq_re.match(meta_b): - grouped.append((meta_b, part[1])) - elif grouped: - # continuation of previous message — concatenate meta info if any - cur_meta, cur_payload = grouped[-1] - grouped[-1] = (cur_meta + b" " + meta_b, cur_payload or part[1]) + # Group the batched response into per-message (meta, payload) + # records. Bare bytes parts must be kept: Gmail returns FLAGS + # after the header literal as a bare element, and dropping it + # rendered every Gmail message as unread/unflagged. + grouped = _group_uid_fetch_records(msg_data) if status != "OK" and not grouped: conn.logout() @@ -1061,7 +1087,10 @@ def setup_email_routes(): return {"contacts": [], "error": "Mail operation failed"} @router.get("/search") - async def search_emails( + # Sync def: the body is blocking IMAP I/O with no awaits. As `async def` it ran + # directly on the event loop and stalled the whole app during a search; as a sync + # def FastAPI runs it in a threadpool, keeping the loop responsive. + def search_emails( q: str = Query(""), folder: str = Query("INBOX"), limit: int = Query(50), @@ -1123,14 +1152,15 @@ def setup_email_routes(): continue raw_header = None flags = "" - for part in msg_data: - if isinstance(part, tuple): - meta = part[0].decode() if isinstance(part[0], bytes) else str(part[0]) - if b"RFC822.HEADER" in part[0] if isinstance(part[0], bytes) else "RFC822.HEADER" in meta: - raw_header = part[1] - flag_match = re.search(r'FLAGS \(([^)]*)\)', meta) - if flag_match: - flags = flag_match.group(1) + # Same Gmail caveat as the list route: FLAGS may + # arrive after the header literal, so group bare + # parts back into the message meta before scanning. + for meta_b, payload in _group_uid_fetch_records(msg_data): + if payload and b"RFC822.HEADER" in meta_b: + raw_header = payload + flag_match = re.search(rb'FLAGS \(([^)]*)\)', meta_b) + if flag_match: + flags = flag_match.group(1).decode(errors="replace") if not raw_header: continue msg = email_mod.message_from_bytes(raw_header) @@ -1279,8 +1309,9 @@ def setup_email_routes(): try: if sender_addr: _rs = _c.execute( - "SELECT signature_text FROM sender_signatures WHERE from_address = ?", - (sender_addr.lower().strip(),), + f"SELECT signature_text FROM sender_signatures " + f"WHERE from_address = ? AND {owner_clause}", + (sender_addr.lower().strip(), *owner_params), ).fetchone() if _rs and _rs[0]: cached_sender_sig = _rs[0] @@ -1756,7 +1787,9 @@ def setup_email_routes(): return {"success": False, "error": "Mail operation failed"} @router.post("/archive/{uid}") - async def archive_email(uid: str, folder: str = Query("INBOX"), account_id: str | None = Query(None), owner: str = Depends(require_owner)): + # Sync def: blocking IMAP I/O with no awaits — see search_emails above. Runs in a + # threadpool instead of blocking the event loop. + def archive_email(uid: str, folder: str = Query("INBOX"), account_id: str | None = Query(None), owner: str = Depends(require_owner)): """Move email to Archive folder.""" try: with _imap(account_id, owner=owner) as conn: diff --git a/routes/gallery_helpers.py b/routes/gallery_helpers.py index 5cab62791..e4005b8a7 100644 --- a/routes/gallery_helpers.py +++ b/routes/gallery_helpers.py @@ -11,6 +11,7 @@ from typing import Dict, Any, Optional from pydantic import BaseModel from core.database import GalleryImage +from src.auth_helpers import _auth_disabled logger = logging.getLogger(__name__) @@ -120,19 +121,18 @@ def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any } -def _owner_filter(q, user): +def _owner_filter(q, user, model_cls=GalleryImage): """Apply owner filtering to a gallery query. - When auth is disabled (single-user mode) get_current_user returns None - and there is no per-user scoping. The main library list and stats already - treat None as "show everything" (`if user is not None`), so this helper - must too — otherwise the tag/model filter sidebars come back empty and the - tag-cleanup endpoints (clear-user-tags, clear-ai-tags, dedupe-tags) - silently affect zero rows in the most common self-hosted deployment. + ``get_current_user`` returns None both in auth-disabled single-user mode + and when auth is enabled but no current user was resolved. Preserve the + single-user behavior, but fail closed for auth-enabled null-user states. """ - if user is None: + if user is not None: + return q.filter(model_cls.owner == user) + if _auth_disabled(): return q - return q.filter(GalleryImage.owner == user) + return q.filter(False) diff --git a/routes/gallery_routes.py b/routes/gallery_routes.py index 43999344e..c641912dc 100644 --- a/routes/gallery_routes.py +++ b/routes/gallery_routes.py @@ -19,6 +19,7 @@ from src.upload_limits import ( GALLERY_TRANSFORM_UPLOAD_MAX_BYTES, ) from src.constants import GENERATED_IMAGES_DIR +from src.optional_deps import patch_realesrgan_torchvision_compat from routes.gallery_helpers import ( GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size, @@ -108,6 +109,32 @@ def _visible_image_endpoint_for_base(db, base: str, owner: str | None): return fallback +async def _fetch_result_image_b64(url: str) -> Optional[str]: + """Fetch an image URL returned in an upstream response body, base64-encoded + (or None on a non-200). + + The URL comes from the diffusion/OpenAI server's response, not from our own + config, so a malicious or compromised endpoint could otherwise steer this + fetch at an internal or cloud-metadata address. Validate it the same way the + client-supplied endpoint is validated before the first request. + """ + import base64 + import httpx + from src.url_safety import check_outbound_url + + ok, reason = check_outbound_url( + url, + block_private=os.getenv("IMAGE_BLOCK_PRIVATE_IPS", "false").lower() == "true", + ) + if not ok: + raise HTTPException(502, f"Upstream returned an unsafe image URL: {reason}") + async with httpx.AsyncClient(timeout=60) as c2: + ir = await c2.get(url) + if ir.status_code == 200: + return base64.b64encode(ir.content).decode() + return None + + def setup_gallery_routes() -> APIRouter: router = APIRouter(tags=["gallery"]) @@ -476,8 +503,7 @@ def setup_gallery_routes() -> APIRouter: .outerjoin(DbSession, GalleryImage.session_id == DbSession.id) .filter(GalleryImage.is_active == True) ) - if user is not None: - q = q.filter(GalleryImage.owner == user) + q = _owner_filter(q, user) # Search filter (prompt + tags + ai_tags) if search: @@ -579,28 +605,26 @@ def setup_gallery_routes() -> APIRouter: db = SessionLocal() try: q = db.query(GalleryAlbum) - if user: - q = q.filter(GalleryAlbum.owner == user) + q = _owner_filter(q, user, GalleryAlbum) albums = q.order_by(GalleryAlbum.created_at.desc()).all() result = [] for a in albums: _count_q = db.query(GalleryImage).filter( GalleryImage.album_id == a.id, GalleryImage.is_active == True ) - if user: - _count_q = _count_q.filter(GalleryImage.owner == user) + _count_q = _owner_filter(_count_q, user) count = _count_q.count() cover_url = None if a.cover_id: - cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first() + cover_q = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id) + cover = _owner_filter(cover_q, user).first() if cover: cover_url = f"/api/generated-image/{cover.filename}" elif count > 0: _cover_q = db.query(GalleryImage).filter( GalleryImage.album_id == a.id, GalleryImage.is_active == True ) - if user: - _cover_q = _cover_q.filter(GalleryImage.owner == user) + _cover_q = _owner_filter(_cover_q, user) first = _cover_q.order_by(GalleryImage.created_at.desc()).first() if first: cover_url = f"/api/generated-image/{first.filename}" @@ -643,10 +667,9 @@ def setup_gallery_routes() -> APIRouter: base = db.query(GalleryImage).filter(GalleryImage.is_active == True) size_q = db.query(func.sum(GalleryImage.file_size)).filter(GalleryImage.is_active == True) album_q = db.query(GalleryAlbum) - if user: - base = base.filter(GalleryImage.owner == user) - size_q = size_q.filter(GalleryImage.owner == user) - album_q = album_q.filter(GalleryAlbum.owner == user) + base = _owner_filter(base, user) + size_q = _owner_filter(size_q, user) + album_q = _owner_filter(album_q, user, GalleryAlbum) total = base.count() total_size = size_q.scalar() or 0 fav_count = base.filter(GalleryImage.favorite == True).count() @@ -674,8 +697,7 @@ def setup_gallery_routes() -> APIRouter: GalleryImage.is_active == True, (GalleryImage.ai_tags == None) | (GalleryImage.ai_tags == ""), ) - if user: - q = q.filter(GalleryImage.owner == user) + q = _owner_filter(q, user) if album_id: q = q.filter(GalleryImage.album_id == album_id) untagged = q.count() @@ -909,15 +931,23 @@ def setup_gallery_routes() -> APIRouter: raise HTTPException(404, "Image not found") img_filename = img.filename - # Remove the file from disk - img_path = _gallery_image_path(img_filename) - if img_path.exists(): - img_path.unlink() - - # Soft-delete the record + # Soft-delete the record first; the DB is the source of truth. img.is_active = False db.commit() + # Only after the soft-delete commit succeeds do we remove the file. + # If the file were deleted first and the commit then failed/rolled + # back, the still-active record would point at a missing file. + # Best-effort so a missing or locked file can't 500 a delete that + # already succeeded logically. Uses the path-confined resolver so a + # malformed stored filename can't escape generated_images. + try: + img_path = _gallery_image_path(img_filename) + if img_path.exists(): + img_path.unlink() + except Exception as e: + logger.warning(f"Could not remove gallery image file for {img_filename}: {e}") + # Strip stale chat-history references so the image bubble # (and its prompt caption) doesn't come back after a server # reboot replays the session. We remove the matching tool @@ -1147,10 +1177,7 @@ def setup_gallery_routes() -> APIRouter: if item.get("b64_json"): raw_b64 = item["b64_json"] elif item.get("url"): - async with httpx.AsyncClient(timeout=60) as c2: - img_r = await c2.get(item["url"]) - if img_r.status_code == 200: - raw_b64 = base64.b64encode(img_r.content).decode() + raw_b64 = await _fetch_result_image_b64(item["url"]) if not raw_b64: raise HTTPException(502, "OpenAI returned no image") @@ -1211,7 +1238,7 @@ def setup_gallery_routes() -> APIRouter: original and regenerates `strength` fraction. With strength ~0.4 you get edge blending + lighting unification while keeping the composition recognisable.""" - import httpx, base64 as _b64 + import httpx user = require_privilege(request, "can_generate_images") body = await request.json() @@ -1387,10 +1414,9 @@ def setup_gallery_routes() -> APIRouter: if item.get("b64_json"): return {"image": item["b64_json"]} if item.get("url"): - async with httpx.AsyncClient(timeout=60) as c2: - ir = await c2.get(item["url"]) - if ir.status_code == 200: - return {"image": _b64.b64encode(ir.content).decode()} + img_b64 = await _fetch_result_image_b64(item["url"]) + if img_b64: + return {"image": img_b64} last_err = f"{path}: server returned no image" except httpx.ConnectError as e: raise HTTPException(502, f"Can't reach diffusion server at {base}: {e}") @@ -1450,6 +1476,7 @@ def setup_gallery_routes() -> APIRouter: img_bytes = base64.b64decode(image_b64) src = Image.open(io.BytesIO(img_bytes)).convert("RGB") try: + patch_realesrgan_torchvision_compat() from realesrgan import RealESRGANer except ImportError: return {"error": "realesrgan not installed. Install it from Cookbook → Dependencies (search 'realesrgan')."} @@ -1499,6 +1526,7 @@ def setup_gallery_routes() -> APIRouter: img_bytes = base64.b64decode(image_b64) src = Image.open(io.BytesIO(img_bytes)).convert("RGB") try: + patch_realesrgan_torchvision_compat() from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer except ImportError: diff --git a/routes/hwfit_routes.py b/routes/hwfit_routes.py index 4879d3610..5e38b9ca3 100644 --- a/routes/hwfit_routes.py +++ b/routes/hwfit_routes.py @@ -1,7 +1,9 @@ import re from copy import deepcopy -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port # Backends the manual hardware simulator accepts. Must stay a subset of what @@ -11,6 +13,14 @@ from fastapi import APIRouter _MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"} +def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]: + host_value = validate_remote_host(host) or "" + port_value = validate_ssh_port(ssh_port) or "" + if port_value and not host_value: + raise HTTPException(400, "ssh_port requires host") + return host_value, port_value + + def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""): """Manual hardware is a "what if I had this setup" simulator — REPLACES the detected hardware entirely instead of adding to it. @@ -105,6 +115,7 @@ def setup_hwfit_routes(): """Detect and return current system hardware info. Pass host=user@server for remote. fresh=true bypasses the per-host cache (the Rescan button).""" from services.hwfit.hardware import detect_system + host, ssh_port = _validate_detection_target(host, ssh_port) return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) @router.get("/models") @@ -118,6 +129,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.fit import rank_models from services.hwfit.models import get_models, model_catalog_path + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} @@ -165,8 +177,14 @@ def setup_hwfit_routes(): system["gpu_name"] = g["name"] system["active_group"] = {**g, "use_count": n} - if gpu_count != "": - n = int(gpu_count) + # Parse the optional count defensively (matches the gpu_group guard + # above): a non-numeric query param previously raised ValueError -> + # HTTP 500. A malformed value is ignored, same as omitting it. + try: + n = int(gpu_count) if gpu_count != "" else None + except ValueError: + n = None + if n is not None: if n == 0: # RAM-only mode: rank against system memory, offload allowed. system["has_gpu"] = False @@ -229,6 +247,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.models import get_models from services.hwfit.profiles import compute_serve_profiles + host, ssh_port = _validate_detection_target(host, ssh_port) system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) if system.get("error"): return {"system": system, "profiles": [], "error": system["error"]} @@ -279,6 +298,7 @@ def setup_hwfit_routes(): """Rank image generation models against detected hardware.""" from services.hwfit.hardware import detect_system from services.hwfit.image_models import rank_image_models + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} diff --git a/routes/mcp_routes.py b/routes/mcp_routes.py index ca2722b5b..a0ade88b6 100644 --- a/routes/mcp_routes.py +++ b/routes/mcp_routes.py @@ -108,6 +108,12 @@ def _load_disabled_map(): db.close() +def _mcp_oauth_redirect_uri() -> str: + """Shared callback URL for legacy Google and generic MCP OAuth flows.""" + from src.mcp_oauth import REDIRECT_URI + return REDIRECT_URI + + def setup_mcp_routes(mcp_manager: McpManager): """Setup MCP routes with the provided manager.""" @@ -445,9 +451,9 @@ def setup_mcp_routes(mcp_manager: McpManager): client_id = keys["client_id"] scopes = oauth_cfg.get("scopes", []) - # For Desktop App creds, redirect to localhost — the user will + # For Desktop App creds, default to localhost — the user will # paste the resulting URL back if they're on a different device. - redirect_uri = "http://localhost:7000/api/mcp/oauth/callback" + redirect_uri = _mcp_oauth_redirect_uri() params = { "client_id": client_id, @@ -469,7 +475,7 @@ def setup_mcp_routes(mcp_manager: McpManager): return RedirectResponse(auth_url) else: # Remote device — show paste-back page - return HTMLResponse(_oauth_authorize_page(auth_url, server_id, host)) + return HTMLResponse(_oauth_authorize_page(auth_url, server_id, host, redirect_uri)) finally: db.close() @@ -536,7 +542,7 @@ def setup_mcp_routes(mcp_manager: McpManager): client_id = keys["client_id"] client_secret = keys["client_secret"] - redirect_uri = "http://localhost:7000/api/mcp/oauth/callback" + redirect_uri = _mcp_oauth_redirect_uri() async with httpx.AsyncClient() as client: resp = await client.post( @@ -603,13 +609,19 @@ def setup_mcp_routes(mcp_manager: McpManager): return router -def _oauth_authorize_page(auth_url: str, server_id: str, host: str) -> str: +def _oauth_authorize_page( + auth_url: str, + server_id: str, + host: str, + redirect_uri: str = "http://localhost:7000/api/mcp/oauth/callback", +) -> str: """Page with Google sign-in link and URL paste-back form for remote access.""" # Escape values interpolated into the page: `host` comes from the request # Host header and `server_id` from the OAuth state — neither is trusted. auth_url = html.escape(auth_url, quote=True) server_id = html.escape(server_id, quote=True) host = html.escape(host, quote=True) + redirect_uri = html.escape(redirect_uri, quote=True) return f""" Authorize — Odysseus @@ -654,7 +666,7 @@ def _oauth_authorize_page(auth_url: str, server_id: str, host: str) -> str: """ diff --git a/routes/memory_routes.py b/routes/memory_routes.py index 7be3c6d32..e788f82d2 100644 --- a/routes/memory_routes.py +++ b/routes/memory_routes.py @@ -29,6 +29,7 @@ from src.llm_core import llm_call_async from services.memory.memory_extractor import audit_memories from src.auth_helpers import get_current_user, require_user from src.endpoint_resolver import resolve_endpoint +from src.task_endpoint import resolve_task_endpoint from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES logger = logging.getLogger(__name__) @@ -105,6 +106,13 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM if memory_manager.find_duplicates(text, user_mem): return {"ok": True, "count": len(user_mem), "message": "Memory already exists"} + if memory_data.session_id: + try: + session_obj = session_manager.get_session(memory_data.session_id) + except KeyError: + raise HTTPException(404, "Session not found") + _assert_session_owner(session_obj, user) + new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user) if memory_data.session_id: new_entry["session_id"] = memory_data.session_id @@ -163,8 +171,17 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM session_id = memory.get("session_id") if session_id and session_id in session_manager.sessions: - session = session_manager.get_session(session_id) - memory["session_name"] = session.name if session else f"Session {session_id[:6]}" + try: + session = session_manager.get_session(session_id) + if session: + _assert_session_owner(session, user) + memory["session_name"] = session.name if session else f"Session {session_id[:6]}" + except KeyError: + memory["session_name"] = "Unknown" + except HTTPException as exc: + if exc.status_code != 404: + raise + memory["session_name"] = "Unknown" else: memory["session_name"] = "Unknown" @@ -224,14 +241,18 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM } messages = [system_msg] + sess.get_context_messages() + t_url, t_model, t_headers = resolve_task_endpoint( + sess.endpoint_url, sess.model, sess.headers, owner=_owner(request) + ) + try: suggestion_text = await llm_call_async( - sess.endpoint_url, - sess.model, + t_url, + t_model, messages, temperature=0.2, max_tokens=500, - headers=sess.headers, + headers=t_headers, ) try: suggestions = json.loads(suggestion_text) @@ -262,42 +283,50 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM endpoint_url = model = None headers = {} - # Try default model from settings first - settings = _load_settings() - ep_id = settings.get("default_endpoint_id", "") - default_model = settings.get("default_model", "") - if ep_id: - db = SessionLocal() - try: - ep = db.query(ModelEndpoint).filter( - ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True - ).first() - if ep: - base = _normalize_base(ep.base_url) - endpoint_url = build_chat_url(base) - model = default_model - if not model and ep.models: - try: - models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models - if models: - model = models[0] - except Exception: - pass - if ep.api_key: - headers = {"Authorization": f"Bearer {ep.api_key}"} - finally: - db.close() + # Try utility model from settings first — memory audit is a background + # task and should prefer the lighter utility model over the main chat model. + from src.task_endpoint import resolve_task_endpoint + user = _owner(request) + t_url, t_model, t_headers = resolve_task_endpoint(owner=user) + if t_url and t_model: + endpoint_url, model, headers = t_url, t_model, t_headers + else: + # Fall back to default model if no task/utility model configured + settings = _load_settings() + ep_id = settings.get("default_endpoint_id", "") + default_model = settings.get("default_model", "") + if ep_id: + db = SessionLocal() + try: + ep = db.query(ModelEndpoint).filter( + ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True + ).first() + if ep: + base = _normalize_base(ep.base_url) + endpoint_url = build_chat_url(base) + model = default_model + if not model and ep.models: + try: + models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models + if models: + model = models[0] + except Exception: + pass + if ep.api_key: + headers = {"Authorization": f"Bearer {ep.api_key}"} + finally: + db.close() - # Fall back to session model if no default configured - if not endpoint_url and session: - try: - sess = session_manager.get_session(session) - _assert_session_owner(sess, _owner(request)) - endpoint_url = sess.endpoint_url - model = sess.model - headers = sess.headers - except KeyError: - pass + # Fall back to session model if no default configured + if not endpoint_url and session: + try: + sess = session_manager.get_session(session) + _assert_session_owner(sess, _owner(request)) + endpoint_url = sess.endpoint_url + model = sess.model + headers = sess.headers + except KeyError: + pass if not endpoint_url or not model: raise HTTPException(400, "No default model configured — set one in Settings") @@ -344,13 +373,14 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM try: sess = session_manager.get_session(session) _assert_session_owner(sess, _owner(request)) - endpoint_url = sess.endpoint_url - model = sess.model - headers = sess.headers + endpoint_url, model, headers = resolve_task_endpoint( + sess.endpoint_url, sess.model, sess.headers, owner=_owner(request) + ) except KeyError: - raise HTTPException(404, "Session not found — needed for LLM config") + logger.warning("Session %s not found, falling back to utility endpoint", session) + endpoint_url, model, headers = resolve_endpoint("utility", owner=_owner(request)) else: - endpoint_url, model, headers = resolve_endpoint("utility", owner=_owner(request)) + endpoint_url, model, headers = resolve_task_endpoint(owner=_owner(request)) if not endpoint_url or not model: raise HTTPException(400, "No LLM model configured. Set a default model in Settings.") diff --git a/routes/model_routes.py b/routes/model_routes.py index 864035884..b5bd6ead8 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -123,6 +123,21 @@ def _clear_user_pref_endpoint_refs(all_prefs: dict, ep_id: str) -> int: return cleared_users +def _default_endpoint_needs_assignment(current_default_id: str, enabled_endpoint_ids) -> bool: + """Whether the global default chat endpoint should be (re)assigned. + + True when nothing is configured yet, or the configured default no longer + resolves to an enabled endpoint (e.g. the user disabled it). Without the + second case, adding a new endpoint after disabling the previous default + leaves `default_endpoint_id` pointing at the disabled endpoint, so features + that read the raw setting (Memory → Tidy) fail with "No default model + configured" even though an enabled endpoint exists. See #3586. + """ + if not current_default_id: + return True + return current_default_id not in enabled_endpoint_ids + + # Loopback hosts a user might type for a local model server (LM Studio, # llama.cpp, vLLM, …). Inside Docker these point at the *container*, not the # host the server actually runs on. @@ -233,6 +248,9 @@ _PROVIDER_CURATED = { "zai-coding": [ "glm-5.1", "glm-5v-turbo", "glm-5-turbo", "glm-4.7", "glm-4.5-air", ], + "kimi-code": [ + "kimi-for-coding", + ], "deepseek": [ "deepseek-chat", "deepseek-reasoner", ], @@ -283,6 +301,7 @@ _HOST_TO_CURATED = ( ("fireworks.ai", "fireworks"), ("googleapis.com", "google"), ("x.ai", "xai"), + ("nvidia.com", "nvidia"), ("openrouter.ai", "openrouter"), ("ollama.com", "ollama"), ) @@ -299,6 +318,8 @@ def _match_provider_curated(base_url: str, provider: str) -> str: parsed = urlparse(base_url) if _host_match(base_url, "z.ai") and "/api/coding" in (parsed.path or ""): return "zai-coding" + if _host_match(base_url, "kimi.com") and "/coding" in (parsed.path or ""): + return "kimi-code" for domain, key in _HOST_TO_CURATED: if _host_match(base_url, domain): return key @@ -477,10 +498,17 @@ _NON_CHAT_PREFIXES = ( "dall-e", "tts-", "whisper", "text-embedding", "embedding", "davinci", "babbage", "moderation", "omni-moderation", "sora", "gpt-image", "chatgpt-image", + # embedding / retrieval / non-chat models (common across providers) + "snowflake/arctic-embed", "nvidia/nv-embed", "embed", ) _NON_CHAT_CONTAINS = ( "-realtime", "-transcribe", "-tts", "-codex", - "codex-", + "codex-", "content-safety", "-safety", "-reward", "nvclip", + "kosmos", "fuyu", "deplot", "vila", "neva", + "gliner", "riva", "-parse", "-embedqa", "-nemoretriever", + "topic-control", "calibration", + "ai-synthetic-video", "cosmos-reason2", + "bge", "llama-guard", ) _NON_CHAT_EXACT_PREFIXES = ( "gpt-audio", # gpt-audio, gpt-audio-mini etc. (not gpt-4o-audio-preview which is chat) @@ -680,6 +708,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis """Probe a base URL's /models endpoint and return list of model IDs. For Anthropic, queries their /v1/models API, falling back to hardcoded list.""" from src.endpoint_resolver import resolve_url + from src.llm_core import httpx_get_kimi_aware base = resolve_url(_normalize_base(base_url)) provider = _safe_detect_provider(base) if provider == "chatgpt-subscription": @@ -715,7 +744,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis url = _safe_build_models_url(base) headers = _safe_build_headers(api_key, base) try: - r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify()) + r = httpx_get_kimi_aware(url, headers, timeout=timeout, verify=llm_verify()) r.raise_for_status() data = r.json() # OpenAI format: {"data": [{"id": "model-name"}]} @@ -731,7 +760,12 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis for _e in _PROVIDER_CURATED.get(_ck, []): if _e not in set(models) and not any(m.startswith(_e) for m in models): models.append(_e) - return models + if _host_match(base, "kimi.com") and "/coding" in (urlparse(base).path or ""): + _ck = _match_provider_curated(base, None) + for _e in _PROVIDER_CURATED.get(_ck, []): + if _e not in set(models) and not any(m.startswith(_e) for m in models): + models.append(_e) + return [m for m in models if _is_chat_model(m)] except httpx.HTTPStatusError as e: if api_key: status = e.response.status_code if e.response is not None else "unknown" @@ -755,7 +789,7 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis data = r.json() models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")] if models: - return models + return [m for m in models if _is_chat_model(m)] except Exception as e: logger.debug(f"Ollama /api/tags probe failed for {base}: {e}") # Fall back to curated list if the provider has a URL-based match (e.g. z.ai has no /models endpoint) @@ -847,15 +881,52 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str: - """Return a provider-aware error message for failed endpoint probes.""" + """Return a provider-aware error message for failed endpoint probes. + + Surfaces the URL we actually probed and, when the endpoint looks like + LM Studio (port 1234 or hostname match), adds a hint about loading a + model and confirming the Developer Server is running. The user previously + saw a generic "No models found for that provider/key" with no way to + tell whether the URL was wrong, the server was down, or the server was + reachable but had no model loaded (issue #25). + """ ping = ping or {} error = ping.get("error") + from src.endpoint_resolver import build_models_url + try: + probed = build_models_url(base_url) or base_url + except Exception: + probed = base_url parsed = urlparse(base_url) host = (parsed.hostname or "").lower() is_ollama = parsed.port == 11434 or "ollama" in host or "ollama" in base_url.lower() + is_lmstudio = ( + parsed.port == 1234 + or "lmstudio" in host + or "lm-studio" in host + or "lm_studio" in host + ) + + if is_lmstudio: + parts = [ + "LM Studio is reachable, but no models were reported.", + f"Probed {probed}.", + ] + if error: + parts.append(f"Last probe error: {error}.") + parts.append( + "Open LM Studio, load at least one model, and confirm the " + "Developer Server is running on port 1234." + ) + parts.append( + "Base URL should be http://localhost:1234/v1 (native) or " + "http://host.docker.internal:1234/v1 (Docker)." + ) + return " ".join(parts) if is_ollama: parts = ["No Ollama models found for that endpoint."] + parts.append(f"Probed {probed}.") if error: parts.append(f"Last probe error: {error}.") parts.append("Check that Ollama is running and that the base URL is correct.") @@ -865,9 +936,9 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> return " ".join(parts) if error: - return f"No models found for that provider/key. Last probe error: {error}." + return f"No models found for that provider/key. Probed {probed}. Last probe error: {error}." - return "No models found for that provider/key." + return f"No models found for that provider/key. Probed {probed}." def _normalize_model_ids(value): @@ -1719,12 +1790,19 @@ def setup_model_routes(model_discovery): ) db.add(ep) db.commit() - # Auto-set as default chat endpoint if none configured yet. Seed - # the first CHAT model (not raw model_ids[0]) so we don't pin the - # global default to an embedding/tts/etc. entry a provider happens - # to list first. + # Auto-set as default chat endpoint when none is usable yet — either + # nothing is configured, or the configured default points at an + # endpoint that is now missing/disabled (#3586). Seed the first CHAT + # model (not raw model_ids[0]) so we don't pin the global default to + # an embedding/tts/etc. entry a provider happens to list first. settings = _load_settings() - if not settings.get("default_endpoint_id"): + enabled_ids = { + e.id + for e in db.query(ModelEndpoint).filter( + ModelEndpoint.is_enabled == True # noqa: E712 + ).all() + } + if _default_endpoint_needs_assignment(settings.get("default_endpoint_id") or "", enabled_ids): from src.endpoint_resolver import _first_chat_model settings["default_endpoint_id"] = ep.id settings["default_model"] = _first_chat_model(model_ids) or "" diff --git a/routes/personal_routes.py b/routes/personal_routes.py index c32f5ffe1..a078e580c 100644 --- a/routes/personal_routes.py +++ b/routes/personal_routes.py @@ -160,8 +160,11 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available): JSON response confirming removal """ try: - if not directory: - raise HTTPException(400, "Directory path is required") + # Confine to PERSONAL_DIR — parity with add_directory_to_rag (which + # resolves the path the same way). Without this, an arbitrary or + # `..`-escaping path is passed straight to + # personal_docs_manager.remove_directory / rag.remove_directory. + directory = _resolve_allowed_personal_dir(directory) logger.info(f"Removing directory from RAG: {directory}") diff --git a/routes/session_routes.py b/routes/session_routes.py index 811a40bbe..1fb2a487a 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -11,7 +11,7 @@ from core.session_manager import SessionManager from core.models import ChatMessage from src.request_models import SessionResponse from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive -from src.auth_helpers import get_current_user, effective_user, _auth_disabled +from src.auth_helpers import get_current_user, effective_user, _auth_disabled, owner_filter from src.session_actions import is_session_recently_active @@ -258,7 +258,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ last_msg_map = {} mode_map = {} msg_count_map = {} - rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all() + q = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False) + q = owner_filter(q, DbSession, user) + rows = q.all() for row in rows: folder_map[row.id] = row.folder token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0) @@ -277,17 +279,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ # Sessions with active documents that have content from sqlalchemy import func doc_session_ids = set( - r[0] for r in db.query(Document.session_id) - .filter(Document.is_active == True, - Document.current_content != None, - func.trim(Document.current_content) != "", - Document.owner == user) + r[0] for r in owner_filter( + db.query(Document.session_id) + .filter(Document.is_active == True, + Document.current_content != None, + func.trim(Document.current_content) != ""), + Document, user) .distinct().all() ) img_session_ids = set( - r[0] for r in db.query(GalleryImage.session_id) - .filter(GalleryImage.session_id != None, - GalleryImage.owner == user) + r[0] for r in owner_filter( + db.query(GalleryImage.session_id) + .filter(GalleryImage.session_id != None), + GalleryImage, user) .distinct().all() ) finally: diff --git a/routes/shell_routes.py b/routes/shell_routes.py index a3126abbb..b4e52325d 100644 --- a/routes/shell_routes.py +++ b/routes/shell_routes.py @@ -1,6 +1,7 @@ """Shell routes — user-facing command execution endpoint.""" import asyncio +import importlib import json import logging import os @@ -14,6 +15,7 @@ from collections import namedtuple from pathlib import Path from typing import Dict, Any from core.platform_compat import IS_APPLE_SILICON, which_tool +from src.optional_deps import prepare_optional_dependency_import # POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist # on Windows, so importing them unconditionally crashed app startup there @@ -149,6 +151,11 @@ def _pip_dist_name(pkg: dict) -> str: return (pkg.get("name") or "").replace("_", "-") +def _import_optional_dependency_for_status(name: str): + prepare_optional_dependency_import(name) + return importlib.import_module(name) + + def _package_installed_from_probe(name: str, probe: dict) -> bool: """Return whether an optional dependency is usable by Cookbook. @@ -970,7 +977,6 @@ def setup_shell_routes() -> APIRouter: """ _require_admin(request) _reject_cross_site(request) - import importlib import importlib.metadata as importlib_metadata import shlex import json as _json @@ -1057,6 +1063,13 @@ def setup_shell_routes() -> APIRouter: "category": "Image", "target": "remote", }, + { + "name": "transformers", + "pip": "transformers", + "desc": "Hugging Face model components used by SD/Flux pipelines and image tools", + "category": "Image", + "target": "remote", + }, { "name": "rembg", "pip": "rembg[gpu]", @@ -1202,7 +1215,7 @@ def setup_shell_routes() -> APIRouter: pkg["status_note"] = _package_status_note("vllm", probe) else: try: - importlib.import_module(pkg["name"]) + _import_optional_dependency_for_status(pkg["name"]) importlib_metadata.version(_pip_dist_name(pkg)) pkg["installed"] = True except ImportError: @@ -1251,6 +1264,7 @@ def setup_shell_routes() -> APIRouter: "sglang[all]", "diffusers", "diffusers[torch]", + "transformers", "TTS", "bark", "faster-whisper", diff --git a/routes/webhook_routes.py b/routes/webhook_routes.py index da6288e7a..77902c24b 100644 --- a/routes/webhook_routes.py +++ b/routes/webhook_routes.py @@ -198,6 +198,8 @@ def setup_webhook_routes( "opencode-go": "https://opencode.ai/zen/go/v1", "fireworks": "https://api.fireworks.ai/inference/v1", "venice": "https://api.venice.ai/api/v1", + "kimi-code": "https://api.kimi.com/coding/v1", + "kimicode": "https://api.kimi.com/coding/v1", } # Model prefix → provider mapping for auto-detection @@ -210,6 +212,8 @@ def setup_webhook_routes( "mistral": "mistral", "llama": "groq", "mixtral": "groq", + "kimi-for-coding": "kimi-code", + "kimi": "kimi-code", } def _resolve_base_url(model: Optional[str], provider: Optional[str]) -> Optional[str]: diff --git a/routes/workspace_routes.py b/routes/workspace_routes.py new file mode 100644 index 000000000..ef70e78c2 --- /dev/null +++ b/routes/workspace_routes.py @@ -0,0 +1,85 @@ +"""Workspace API - browse server directories to pick a tool workspace folder.""" +import os +from fastapi import APIRouter, Request, HTTPException, Query + +from src.auth_helpers import get_current_user +from src.tool_security import owner_is_admin_or_single_user + +# Cap entries returned per directory (mirrors filesystem_tools._CODENAV_MAX_HITS). +# A huge directory shouldn't dump thousands of rows into the picker; the user can +# type/paste a path to jump straight in instead. +_MAX_BROWSE_DIRS = 500 + + +def setup_workspace_routes(): + router = APIRouter(prefix="/api/workspace", tags=["workspace"]) + + @router.get("/browse") + def browse(request: Request, path: str = Query(default="")): + """List subdirectories of `path` (default: home) so the UI can navigate + the server filesystem and pick a workspace folder. Directories only. + + ADMIN-ONLY: this enumerates the server filesystem, so it is gated the + same way the file/shell tools are (read_file/write_file/bash are in + NON_ADMIN_BLOCKED_TOOLS). A non-admin who can't use those tools must not + be able to map the host's directory tree either. + """ + owner = get_current_user(request) + if not owner_is_admin_or_single_user(owner): + raise HTTPException(status_code=403, detail="Workspace browsing is admin-only") + + # Resolve symlinks so the reported path is canonical and the UI navigates + # real directories (defends against symlink games in displayed paths). + target = os.path.realpath(os.path.expanduser(path.strip() or "~")) + if not os.path.isdir(target): + target = os.path.realpath(os.path.expanduser("~")) + + dirs = [] + try: + with os.scandir(target) as it: + for entry in it: + try: + # Don't follow symlinks when classifying - a symlinked + # dir is skipped rather than letting the browser wander + # off via a link. Hidden entries are omitted. + if entry.is_dir(follow_symlinks=False) and not entry.name.startswith("."): + # Build the child path server-side with os.path.join + # so it's correct on Windows (backslashes) and Linux. + dirs.append({"name": entry.name, "path": os.path.join(target, entry.name)}) + except OSError: + continue + except (PermissionError, OSError): + dirs = [] + + dirs_sorted = sorted(dirs, key=lambda d: d["name"].lower()) + truncated = len(dirs_sorted) > _MAX_BROWSE_DIRS + parent = os.path.dirname(target) + from src.tool_execution import vet_workspace + return { + "path": target, + "parent": parent if parent and parent != target else None, + "dirs": dirs_sorted[:_MAX_BROWSE_DIRS], + "truncated": truncated, + # Whether this directory may be bound as a workspace (filesystem + # roots and sensitive dirs may be browsed through but not chosen). + "selectable": vet_workspace(target) is not None, + } + + @router.get("/vet") + def vet(request: Request, path: str = Query(default="")): + """Validate a workspace path without binding it. + + The UI calls this before persisting a manually typed path (/workspace + set) so a typo, file path, deleted folder, sensitive dir, or filesystem + root is rejected up front with the canonical path returned on success, + instead of being stored client-side and silently dropped at chat time. + Admin-gated like /browse: it confirms path existence on the host. + """ + owner = get_current_user(request) + if not owner_is_admin_or_single_user(owner): + raise HTTPException(status_code=403, detail="Workspace selection is admin-only") + from src.tool_execution import vet_workspace + resolved = vet_workspace(path) + return {"ok": resolved is not None, "path": resolved} + + return router diff --git a/scripts/agent_migration_manifest.py b/scripts/agent_migration_manifest.py new file mode 100755 index 000000000..82b5d24a7 --- /dev/null +++ b/scripts/agent_migration_manifest.py @@ -0,0 +1,635 @@ +#!/usr/bin/env python3 +"""Build a neutral agent migration manifest. + +This helper is intentionally read-only. It does not import the Odysseus +application package, write to data/, call an LLM, or apply anything. It turns +common agent export shapes into a portable JSON manifest that Odysseus can +preview or import later. +""" +from __future__ import annotations + +import argparse +import hashlib +import json +import mimetypes +import sys +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Iterable + + +SCHEMA_VERSION = "agent-migration.v1" +TEXT_EXTENSIONS = { + ".cfg", + ".conf", + ".csv", + ".json", + ".log", + ".md", + ".markdown", + ".py", + ".rst", + ".toml", + ".txt", + ".yaml", + ".yml", +} + + +@dataclass(frozen=True) +class InputWarning: + path: str + message: str + + +def utc_now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def sha256_text(text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + +def sha256_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def sha256_path(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +def stable_id(kind: str, source_name: str, *parts: Any) -> str: + raw = "\x1f".join([kind, source_name, *[str(part) for part in parts]]) + return f"{kind}:{hashlib.sha256(raw.encode('utf-8')).hexdigest()[:16]}" + + +def read_json(path: Path) -> Any: + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def normalize_category(value: Any) -> str: + category = str(value or "fact").strip().lower() + return category or "fact" + + +def normalize_memory_text(item: Any) -> str: + if isinstance(item, str): + return item.strip() + if isinstance(item, dict): + for key in ("text", "content", "memory", "value"): + value = item.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + +def memory_metadata(item: Any, source_path: Path, index: int) -> dict[str, Any]: + metadata: dict[str, Any] = { + "source_path": str(source_path), + "source_index": index, + } + if isinstance(item, dict): + for key in ("id", "timestamp", "created_at", "updated_at", "source", "tags", "pinned"): + if key in item: + metadata[f"source_{key}"] = item.get(key) + return metadata + + +def payload_items(payload: Any, keys: tuple[str, ...]) -> Any: + if isinstance(payload, dict): + for key in keys: + if isinstance(payload.get(key), list): + return payload[key] + return payload + + +def collect_memory_json(path: Path, source_name: str) -> tuple[list[dict[str, Any]], list[InputWarning]]: + warnings: list[InputWarning] = [] + try: + payload = read_json(path) + except Exception as exc: + return [], [InputWarning(str(path), f"could not read JSON: {exc}")] + + payload = payload_items(payload, ("memories", "memory", "items", "data")) + + if not isinstance(payload, list): + return [], [InputWarning(str(path), "expected a JSON list or an object containing a memory list")] + + items: list[dict[str, Any]] = [] + seen: set[str] = set() + for index, item in enumerate(payload): + text = normalize_memory_text(item) + if not text: + warnings.append(InputWarning(str(path), f"skipped memory at index {index}: missing text")) + continue + digest = sha256_text(text.strip().lower()) + if digest in seen: + warnings.append(InputWarning(str(path), f"skipped duplicate memory at index {index}")) + continue + seen.add(digest) + category = normalize_category(item.get("category") if isinstance(item, dict) else "fact") + source = str(item.get("source") or source_name) if isinstance(item, dict) else source_name + items.append( + { + "id": stable_id("memory", source_name, path, index, digest), + "kind": "memory", + "text": text, + "category": category, + "source": source, + "metadata": memory_metadata(item, path, index), + } + ) + return items, warnings + + +def normalize_timestamp(value: Any) -> str | None: + if value is None or value == "": + return None + if isinstance(value, (int, float)): + try: + return ( + datetime.fromtimestamp(float(value), timezone.utc) + .replace(microsecond=0) + .isoformat() + .replace("+00:00", "Z") + ) + except (OverflowError, OSError, ValueError): + return str(value) + return str(value) + + +def normalize_role(value: Any) -> str: + role = str(value or "unknown").strip().lower() + if role in {"human", "user"}: + return "user" + if role in {"assistant", "ai", "bot", "model"}: + return "assistant" + if role in {"system", "tool"}: + return role + return role or "unknown" + + +def content_part_text(part: Any) -> str: + if isinstance(part, str): + return part + if isinstance(part, dict): + for key in ("text", "content", "value"): + value = part.get(key) + if isinstance(value, str): + return value + if part.get("type") == "text" and isinstance(part.get("text"), str): + return part["text"] + return "" + + +def normalize_message_text(message: dict[str, Any]) -> str: + content = message.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + return "\n".join(text for text in (content_part_text(part).strip() for part in content) if text) + if isinstance(content, dict): + parts = content.get("parts") + if isinstance(parts, list): + return "\n".join(text for text in (content_part_text(part).strip() for part in parts) if text) + for key in ("text", "content", "value"): + value = content.get(key) + if isinstance(value, str): + return value + for key in ("text", "body", "message"): + value = message.get(key) + if isinstance(value, str): + return value + return "" + + +def normalize_message(message: dict[str, Any]) -> dict[str, Any] | None: + author = message.get("author") if isinstance(message.get("author"), dict) else {} + role = ( + message.get("role") + or message.get("sender") + or message.get("speaker") + or author.get("role") + or author.get("name") + ) + text = normalize_message_text(message).strip() + if not text: + return None + normalized: dict[str, Any] = { + "role": normalize_role(role), + "text": text, + } + timestamp = normalize_timestamp(message.get("created_at") or message.get("create_time") or message.get("timestamp")) + if timestamp: + normalized["created_at"] = timestamp + message_id = message.get("id") + if message_id is not None: + normalized["source_id"] = str(message_id) + return normalized + + +def chatgpt_mapping_messages(conversation: dict[str, Any]) -> list[dict[str, Any]]: + mapping = conversation.get("mapping") + if not isinstance(mapping, dict): + return [] + rows: list[tuple[float, int, dict[str, Any]]] = [] + for index, node in enumerate(mapping.values()): + if not isinstance(node, dict) or not isinstance(node.get("message"), dict): + continue + message = node["message"] + sort_value = message.get("create_time") + try: + sort_key = float(sort_value) + except (TypeError, ValueError): + sort_key = float(index) + normalized = normalize_message(message) + if normalized: + rows.append((sort_key, index, normalized)) + return [row[2] for row in sorted(rows, key=lambda row: (row[0], row[1]))] + + +def conversation_messages(conversation: dict[str, Any]) -> tuple[list[dict[str, Any]], str]: + mapped = chatgpt_mapping_messages(conversation) + if mapped: + return mapped, "chatgpt_mapping" + for key in ("messages", "chat_messages", "turns"): + raw_messages = conversation.get(key) + if isinstance(raw_messages, list): + messages = [ + normalized + for raw in raw_messages + if isinstance(raw, dict) + for normalized in [normalize_message(raw)] + if normalized + ] + return messages, key + return [], "unknown" + + +def conversation_title(conversation: dict[str, Any], index: int) -> str: + for key in ("title", "name", "summary"): + value = conversation.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return f"Conversation {index + 1}" + + +def collect_conversation_json( + path: Path, + source_name: str, + *, + include_content: bool = False, + max_messages: int = 2000, +) -> tuple[list[dict[str, Any]], list[InputWarning]]: + warnings: list[InputWarning] = [] + try: + payload = read_json(path) + except Exception as exc: + return [], [InputWarning(str(path), f"could not read JSON: {exc}")] + + payload = payload_items(payload, ("conversations", "conversation", "items", "data")) + if isinstance(payload, dict): + payload = [payload] + if not isinstance(payload, list): + return [], [InputWarning(str(path), "expected a JSON list or an object containing a conversation list")] + + items: list[dict[str, Any]] = [] + for index, conversation in enumerate(payload): + if not isinstance(conversation, dict): + warnings.append(InputWarning(str(path), f"skipped conversation at index {index}: expected object")) + continue + messages, format_hint = conversation_messages(conversation) + if not messages: + warnings.append(InputWarning(str(path), f"skipped conversation at index {index}: no text messages found")) + continue + title = conversation_title(conversation, index) + source_id = conversation.get("id") or conversation.get("uuid") or conversation.get("conversation_id") + text_digest = sha256_text("\n".join(f"{msg['role']}:{msg['text']}" for msg in messages)) + metadata: dict[str, Any] = { + "source_path": str(path), + "source_index": index, + "source_format": format_hint, + "message_count": len(messages), + "text_sha256": text_digest, + "content_included": False, + } + if source_id is not None: + metadata["source_id"] = str(source_id) + for key in ("create_time", "created_at", "update_time", "updated_at"): + timestamp = normalize_timestamp(conversation.get(key)) + if timestamp: + metadata[f"source_{key}"] = timestamp + item: dict[str, Any] = { + "id": stable_id("conversation", source_name, path, source_id or index, text_digest), + "kind": "conversation_thread", + "title": title, + "source": source_name, + "metadata": metadata, + } + if include_content: + if len(messages) > max_messages: + warnings.append( + InputWarning( + str(path), + f"skipped conversation content at index {index}: over {max_messages} messages", + ) + ) + else: + item["messages"] = messages + item["metadata"]["content_included"] = True + items.append(item) + return items, warnings + + +def parse_skill_frontmatter(text: str) -> dict[str, Any]: + if not text.startswith("---"): + return {} + end = text.find("\n---", 3) + if end < 0: + return {} + frontmatter: dict[str, Any] = {} + for line in text[3:end].strip().splitlines(): + if not line.strip() or line.lstrip().startswith("#") or ":" not in line: + continue + key, value = line.split(":", 1) + key = key.strip() + value = value.strip().strip('"').strip("'") + if key: + frontmatter[key] = value + return frontmatter + + +def collect_skill_dir(path: Path, source_name: str) -> tuple[list[dict[str, Any]], list[InputWarning]]: + warnings: list[InputWarning] = [] + if path.is_symlink(): + return [], [InputWarning(str(path), "skills path is a symlink; skipped")] + if not path.exists(): + return [], [InputWarning(str(path), "skills directory does not exist")] + if not path.is_dir(): + return [], [InputWarning(str(path), "skills path is not a directory")] + + items: list[dict[str, Any]] = [] + for skill_path in sorted(path.rglob("SKILL.md")): + if skill_path.is_symlink(): + warnings.append(InputWarning(str(skill_path), "skipped symlinked skill file")) + continue + try: + text = skill_path.read_text(encoding="utf-8") + except Exception as exc: + warnings.append(InputWarning(str(skill_path), f"could not read skill: {exc}")) + continue + frontmatter = parse_skill_frontmatter(text) + name = str(frontmatter.get("name") or skill_path.parent.name).strip() or skill_path.parent.name + items.append( + { + "id": stable_id("skill", source_name, skill_path, sha256_text(text)), + "kind": "skill", + "name": name, + "category": str(frontmatter.get("category") or "general"), + "source": source_name, + "format": "SKILL.md", + "content": text, + "metadata": { + "source_path": str(skill_path), + "sha256": sha256_text(text), + "frontmatter": frontmatter, + }, + } + ) + return items, warnings + + +def looks_textual(path: Path) -> bool: + if path.suffix.lower() in TEXT_EXTENSIONS: + return True + guessed, _ = mimetypes.guess_type(str(path)) + return bool(guessed and (guessed.startswith("text/") or guessed in {"application/json"})) + + +def iter_archive_dir(path: Path) -> Iterable[Path | InputWarning]: + try: + children = sorted(path.iterdir()) + except Exception as exc: + yield InputWarning(str(path), f"could not scan archive directory: {exc}") + return + for child in children: + if child.is_symlink(): + yield InputWarning(str(child), "skipped symlinked archive path") + continue + if child.is_file(): + yield child + elif child.is_dir(): + yield from iter_archive_dir(child) + + +def iter_archive_files(paths: Iterable[Path]) -> Iterable[Path | InputWarning]: + for path in paths: + if path.is_symlink(): + yield InputWarning(str(path), "skipped symlinked archive path") + continue + if path.is_file(): + yield path + elif path.is_dir(): + yield from iter_archive_dir(path) + + +def collect_archive_paths( + paths: list[Path], + source_name: str, + *, + include_content: bool = False, + max_bytes: int = 256_000, +) -> tuple[list[dict[str, Any]], list[InputWarning]]: + warnings: list[InputWarning] = [] + items: list[dict[str, Any]] = [] + existing_paths: list[Path] = [] + for path in paths: + if path.is_symlink(): + warnings.append(InputWarning(str(path), "archive path is a symlink; skipped")) + continue + if not path.exists(): + warnings.append(InputWarning(str(path), "archive path does not exist")) + continue + if not path.is_file() and not path.is_dir(): + warnings.append(InputWarning(str(path), "archive path is not a file or directory")) + continue + existing_paths.append(path) + + for entry in iter_archive_files(existing_paths): + if isinstance(entry, InputWarning): + warnings.append(entry) + continue + path = entry + if not looks_textual(path): + warnings.append(InputWarning(str(path), "skipped non-text archive file")) + continue + try: + st = path.stat() + except Exception as exc: + warnings.append(InputWarning(str(path), f"could not stat archive file: {exc}")) + continue + size = st.st_size + try: + file_hash = sha256_path(path) + except Exception as exc: + warnings.append(InputWarning(str(path), f"could not hash archive file: {exc}")) + continue + if include_content and size > max_bytes: + warnings.append(InputWarning(str(path), f"skipped archive content over {max_bytes} bytes")) + archive_item: dict[str, Any] = { + "id": stable_id("archive", source_name, path, file_hash), + "kind": "archive_document", + "title": path.name, + "source": source_name, + "metadata": { + "source_path": str(path), + "size_bytes": size, + "sha256": file_hash, + }, + } + if include_content and size <= max_bytes: + try: + archive_item["content"] = path.read_text(encoding="utf-8") + except UnicodeDecodeError: + archive_item["content"] = path.read_text(encoding="utf-8", errors="replace") + archive_item["metadata"]["decoded_with_replacement"] = True + items.append(archive_item) + return items, warnings + + +def build_manifest(args) -> dict[str, Any]: + warnings: list[InputWarning] = [] + items: list[dict[str, Any]] = [] + + for path in args.memory_json: + collected, got_warnings = collect_memory_json(path, args.source_name) + items.extend(collected) + warnings.extend(got_warnings) + + for path in args.skills_dir: + collected, got_warnings = collect_skill_dir(path, args.source_name) + items.extend(collected) + warnings.extend(got_warnings) + + for path in args.conversation_json: + collected, got_warnings = collect_conversation_json( + path, + args.source_name, + include_content=args.include_conversation_content, + max_messages=args.max_conversation_messages, + ) + items.extend(collected) + warnings.extend(got_warnings) + + if args.archive: + collected, got_warnings = collect_archive_paths( + args.archive, + args.source_name, + include_content=args.include_archive_content, + max_bytes=args.max_archive_bytes, + ) + items.extend(collected) + warnings.extend(got_warnings) + + counts: dict[str, int] = {} + for item in items: + counts[item["kind"]] = counts.get(item["kind"], 0) + 1 + + return { + "schema_version": SCHEMA_VERSION, + "generated_at": utc_now_iso(), + "source": { + "name": args.source_name, + "kind": args.source_kind, + }, + "summary": { + "item_count": len(items), + "counts_by_kind": counts, + "warning_count": len(warnings), + }, + "items": items, + "warnings": [{"path": warning.path, "message": warning.message} for warning in warnings], + } + + +def parse_args(argv: list[str] | None = None): + parser = argparse.ArgumentParser(description="Build a neutral Odysseus agent migration manifest.") + parser.add_argument("--source-name", default="agent-export", help="Human-readable source name.") + parser.add_argument("--source-kind", default="generic", help="Source adapter kind, e.g. generic, openclaw, hermes.") + parser.add_argument( + "--memory-json", + action="append", + type=Path, + default=[], + help="JSON memory export. May be a list, or an object containing memories/items/data.", + ) + parser.add_argument( + "--skills-dir", + action="append", + type=Path, + default=[], + help="Directory containing SKILL.md files. Scanned recursively.", + ) + parser.add_argument( + "--archive", + action="append", + type=Path, + default=[], + help="Text/Markdown/JSON file or directory to preserve as archive documents.", + ) + parser.add_argument( + "--conversation-json", + action="append", + type=Path, + default=[], + help="Conversation export JSON. Supports generic message lists and ChatGPT-style conversations.json.", + ) + parser.add_argument( + "--include-archive-content", + action="store_true", + help="Embed archive document content in the manifest. By default only metadata is included.", + ) + parser.add_argument( + "--max-archive-bytes", + type=int, + default=256_000, + help="Maximum bytes to embed per archive file when --include-archive-content is used.", + ) + parser.add_argument( + "--include-conversation-content", + action="store_true", + help="Embed normalized conversation messages. By default only thread metadata is included.", + ) + parser.add_argument( + "--max-conversation-messages", + type=int, + default=2000, + help="Maximum messages to embed per conversation when --include-conversation-content is used.", + ) + parser.add_argument("--output", type=Path, help="Write manifest JSON to this path instead of stdout.") + parser.add_argument("--compact", action="store_true", help="Write compact JSON without indentation.") + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + manifest = build_manifest(args) + text = json.dumps(manifest, ensure_ascii=False, sort_keys=True, separators=(",", ":")) if args.compact else ( + json.dumps(manifest, ensure_ascii=False, indent=2, sort_keys=True) + "\n" + ) + if args.output: + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(text, encoding="utf-8") + else: + sys.stdout.write(text) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/services/hwfit/hardware.py b/services/hwfit/hardware.py index 47ec94d44..9d868f257 100644 --- a/services/hwfit/hardware.py +++ b/services/hwfit/hardware.py @@ -611,6 +611,93 @@ def _cache_key(host: str, ssh_port: str, platform_name: str): ) +def _is_containerized(): + """Best-effort check for whether the local Odysseus process is running in a container.""" + if _remote_host: + return False + + if os.path.exists("/.dockerenv"): + return True + + try: + with open("/proc/1/cgroup", encoding="utf-8", errors="replace") as f: + text = f.read().lower() + return any(marker in text for marker in ("docker", "containerd", "kubepods")) + except Exception: + return False + + +def _hardware_visibility_warning(result): + """Return a non-blocking UX warning when detected hardware may only be container-visible.""" + if not isinstance(result, dict): + return None + + if result.get("manual_hardware"): + return None + + if not result.get("containerized"): + return None + + if result.get("gpu_error"): + return None + + if not result.get("has_gpu"): + return { + "code": "container_no_gpu_visible", + "severity": "warning", + "title": "No GPU visible inside Docker", + "message": ( + "Cookbook is scanning hardware from inside the Odysseus container. " + "If your host has a GPU, Docker may not be exposing it to the container, " + "so model recommendations may be CPU-only or too conservative." + ), + "actions": [ + "manual_hardware", + "rescan", + "copy_diagnostics", + ], + } + + total_ram = result.get("total_ram_gb") or 0 + if total_ram and total_ram <= 8: + return { + "code": "container_low_ram_visible", + "severity": "info", + "title": "Container-visible RAM may be lower than host RAM", + "message": ( + "Cookbook is seeing the RAM available inside the container. " + "If your host has more memory, validate host RAM separately or use Manual Hardware." + ), + "actions": [ + "manual_hardware", + "rescan", + "copy_diagnostics", + ], + } + + return None + + +def _attach_probe_context(result, host=""): + """Attach probe-scope metadata and optional hardware visibility warning.""" + if not isinstance(result, dict) or result.get("error"): + return result + + is_remote = bool(host) + containerized = False if is_remote else _is_containerized() + + result["probe_scope"] = "remote" if is_remote else ("container" if containerized else "native") + result["containerized"] = containerized + + warning = _hardware_visibility_warning(result) + if warning: + result["hardware_visibility_warning"] = warning + else: + result.pop("hardware_visibility_warning", None) + + return result + + def detect_system(host="", ssh_port="", platform="", fresh=False): """Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely changes, and probing a remote host over SSH is slow). Pass fresh=True to @@ -635,6 +722,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False): if _remote_platform == "windows" and _remote_host: result = _detect_windows() if result: + result = _attach_probe_context(result, host=host) _remote_host = None _remote_platform = None _cache_by_host[cache_key] = (now, result) @@ -653,6 +741,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False): if not _remote_host and os.name == "nt": result = _detect_windows() if result: + result = _attach_probe_context(result, host=host) _cache_by_host[cache_key] = (now, result) return result # PowerShell probe failed entirely — fall through to the generic path @@ -714,6 +803,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False): "gpu_error": _last_gpu_error, } + result = _attach_probe_context(result, host=host) _remote_host = None _remote_platform = None _cache_by_host[cache_key] = (now, result) diff --git a/services/hwfit/profiles.py b/services/hwfit/profiles.py index 87aa147fe..337af7648 100644 --- a/services/hwfit/profiles.py +++ b/services/hwfit/profiles.py @@ -188,12 +188,18 @@ def compute_serve_profiles(system, model, serve_weights_gb=None, serve_quant=Non # Shrink context if even the chosen KV won't fit alongside weights. # Start from the smaller of the profile's target and the model's limit. cur_ctx = min(ctx, model_ctx_max) - while cur_ctx >= 8192: + # Floor the context-shrink loop at 8192, but never above the model's own + # trained limit. A model with a sub-8192 context (e.g. a 2048-token + # SmolLM) starts below 8192, so a hard-coded 8192 guard skipped the loop + # entirely and produced NO profile — the serve UI then fell back to + # manual flags even though the model fits the GPU trivially. + ctx_floor = min(8192, model_ctx_max) + while cur_ctx >= ctx_floor: kv = _kv_gb(model, cur_ctx, kv_type) n_cpu_moe, fits = _cpu_moe_for_budget(model, quant, kv, budget, fixed_gb=serve_weights_gb) est = _weights_gb(model, quant, serve_weights_gb) + kv + 0.6 # If a non-MoE model can't fit even fully offloaded, try less context. - if model.get("is_moe") or fits or cur_ctx <= 8192: + if model.get("is_moe") or fits or cur_ctx <= ctx_floor: profiles.append({ "key": key, "label": label, diff --git a/services/memory/skill_extractor.py b/services/memory/skill_extractor.py index 79e4c67c2..3c6b7c59c 100644 --- a/services/memory/skill_extractor.py +++ b/services/memory/skill_extractor.py @@ -66,41 +66,57 @@ def _has_duplicate_title(skills, title: str) -> bool: def _extract_json_object(text: str) -> Optional[dict]: """Best-effort extraction of a JSON object from an LLM response. - The response may be wrapped in code fences or surrounded by prose, and some - models emit a stray brace in the prose before the real object - (e.g. "uses {placeholder} then {...}"). Slicing first-'{' .. last-'}' then - grabs an unparseable span and the skill is silently lost. Try the whole - string first, then each '{' start position in turn, returning the first - candidate that parses to a JSON object (dict). Returns None if none do. + The response may be wrapped in code fences or surrounded by prose. Uses + json.JSONDecoder().raw_decode() to locate the boundaries of complete JSON + objects starting at each '{' position. Nested objects are filtered out to + keep only top-level candidates. If multiple non-overlapping valid JSON + objects are found, it is treated as ambiguous and returns None. Otherwise, + returns the single valid candidate dictionary. """ if not text: return None s = text.strip() if s.startswith("```"): s = s.split("\n", 1)[-1].rsplit("```", 1)[0].strip() - end = s.rfind("}") - if end == -1: + + decoder = json.JSONDecoder() + candidates = [] + + start = s.find("{") + while start != -1: + try: + obj, idx = decoder.raw_decode(s[start:]) + end_pos = start + idx + if isinstance(obj, dict): + candidates.append((start, end_pos, obj)) + except (json.JSONDecodeError, ValueError): + pass + start = s.find("{", start + 1) + + # Filter out nested candidates to identify top-level dictionaries + top_level = [] + for c in candidates: + is_nested = False + for other in candidates: + if other == c: + continue + if other[0] <= c[0] and c[1] <= other[1]: + is_nested = True + break + if not is_nested: + top_level.append(c) + + if not top_level: return None - def _as_dict(candidate): - try: - obj = json.loads(candidate) - except (json.JSONDecodeError, ValueError): - return None - return obj if isinstance(obj, dict) else None + if len(top_level) > 1: + logger.debug( + "[skill-extract] Found multiple non-overlapping JSON objects: %s", + [item[2].get("title") for item in top_level] + ) + return None - # The clean, common case: the whole (de-fenced) string is the object. - obj = _as_dict(s) - if obj is not None: - return obj - # Otherwise scan each '{' candidate up to the last '}'. - start = s.find("{") - while 0 <= start < end: - obj = _as_dict(s[start : end + 1]) - if obj is not None: - return obj - start = s.find("{", start + 1) - return None + return top_level[0][2] async def maybe_extract_skill( diff --git a/services/memory/skills.py b/services/memory/skills.py index 9cfe801e1..5baaa88c5 100644 --- a/services/memory/skills.py +++ b/services/memory/skills.py @@ -603,7 +603,6 @@ class SkillsManager: escalation) — those are work-in-progress and pollute the prompt with half-finished procedures. """ - active_toolsets = active_toolsets or [] out = [] for s in self.load(owner=owner): status = s.get("status") @@ -617,13 +616,16 @@ class SkillsManager: # Platform gating if platform and s.get("platforms") and platform not in s["platforms"]: continue - # requires_toolsets: hide unless every required toolset is active + # requires_toolsets: hide unless every required toolset is active. + # active_toolsets=None means the caller doesn't know the active + # set (API listings, chat preface) — don't gate in that case; + # only an explicit list filters. req = s.get("requires_toolsets") or [] - if req and not all(t in active_toolsets for t in req): + if req and active_toolsets is not None and not all(t in active_toolsets for t in req): continue # fallback_for_toolsets: hide when any of those toolsets is active fb = s.get("fallback_for_toolsets") or [] - if fb and any(t in active_toolsets for t in fb): + if fb and active_toolsets and any(t in active_toolsets for t in fb): continue out.append({ "name": s["name"], diff --git a/services/research/research_handler.py b/services/research/research_handler.py index bd4c6bb15..2521f61e1 100644 --- a/services/research/research_handler.py +++ b/services/research/research_handler.py @@ -285,6 +285,7 @@ class ResearchHandler: query, report, stats, elapsed, findings=researcher.findings, evolving_report=researcher.evolving_report, + analyzed_urls=getattr(researcher, "analyzed_urls", None), ) except Exception as e: @@ -331,7 +332,8 @@ class ResearchHandler: def _format_research_report( self, query: str, full_report: str, stats: dict, elapsed: float, - findings: list = None, evolving_report: str = None, + findings: Optional[list] = None, evolving_report: Optional[str] = None, + analyzed_urls: Optional[list] = None, ) -> str: """Format research report with sources list and expandable raw findings.""" summary_lines = [ @@ -342,20 +344,34 @@ class ResearchHandler: ] summary_text = " | ".join(summary_lines) - # Build sources list with clickable links + # Build sources list with clickable links. Keep the curated Sources + # section filtered for citation quality, but also list every unique URL + # the research run inspected so the "URLs Analyzed" count is auditable. sources_section = "" - if findings: + analyzed_urls_section = "" + url_items = analyzed_urls if analyzed_urls is not None else findings + if findings or url_items: seen_urls = set() source_lines = [] - for f in findings: + analyzed_seen = set() + analyzed_lines = [] + for f in findings or []: url = f.get("url", "") title = f.get("title", "") or url summary = f.get("summary", "") or f.get("evidence", "") if url and url not in seen_urls and not is_low_quality(summary): seen_urls.add(url) source_lines.append(f"- [{title}]({url})") + for item in url_items or []: + url = item.get("url", "") + title = item.get("title", "") or url + if url and url not in analyzed_seen: + analyzed_seen.add(url) + analyzed_lines.append(f"{len(analyzed_lines) + 1}. [{title}]({url})") if source_lines: sources_section = "\n### Sources\n\n" + "\n".join(source_lines) + "\n" + if analyzed_lines: + analyzed_urls_section = "\n### Analyzed URLs\n\n" + "\n".join(analyzed_lines) + "\n" # Build raw findings section (individual extractions per source) raw_findings_section = "" @@ -391,6 +407,7 @@ class ResearchHandler: {full_report} {sources_section} +{analyzed_urls_section} {collected_section} --- diff --git a/services/search/content.py b/services/search/content.py index 2c1f5f64c..ac9b4a99c 100644 --- a/services/search/content.py +++ b/services/search/content.py @@ -299,6 +299,40 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> _cache_result(cache_file, cache_key, result, url) return result + # Plain-text / Markdown / JSON handling. Sources like + # raw.githubusercontent.com serve Markdown as `text/plain`, JSON APIs and + # raw config files serve `application/json`, and a lot of code and tool + # docs live in `.md` / `.txt`. These have no HTML structure, so the HTML + # branch below would extract nothing and report "no readable text content". + # Return the body verbatim instead. The `is_html` guard keeps real HTML + # (including `application/xhtml+xml`) on the parsing path; the `json` check + # covers `application/json` and `+json` suffixes; the URL-suffix fallback + # catches servers that mislabel text files as `application/octet-stream`. + is_html = "html" in content_type + is_json = "json" in content_type + url_path = url.lower().split("?", 1)[0].split("#", 1)[0] + looks_like_text_file = url_path.endswith( + (".md", ".markdown", ".txt", ".text", ".json", ".jsonl") + ) + if not is_html and (content_type.startswith("text/") or is_json or looks_like_text_file): + text_body = (response.text or "").strip() + result = { + "url": url, + "title": os.path.basename(url_path) or url, + "content": text_body, + "lists": [], + "tables": [], + "code_blocks": [], + "meta_description": "", + "meta_keywords": "", + "js_rendered": False, + "js_message": "", + "success": bool(text_body), + "error": "" if text_body else "Empty response body", + } + _cache_result(cache_file, cache_key, result, url) + return result + # HTML handling try: soup = BeautifulSoup(response.text, "html.parser") diff --git a/services/search/providers.py b/services/search/providers.py index f2d4a583b..b913e1c6f 100644 --- a/services/search/providers.py +++ b/services/search/providers.py @@ -134,9 +134,10 @@ _NEWS_HINTS = ("news", "nyheter", "headlines", "breaking", "latest", "today", "i _GENERAL_ENGINES = os.environ.get("SEARXNG_GENERAL_ENGINES", "bing,mojeek,presearch") -def searxng_search_api(query: str, count: int = 10, categories: str = "general", +def searxng_search_api(query: str, count: Optional[int] = None, categories: str = "general", time_filter: Optional[str] = None) -> List[dict]: """Search using SearXNG JSON API. Returns list of {title, url, snippet}.""" + count = count if count is not None else _get_result_count() instance = _get_search_instance() api_key = "" headers = {"User-Agent": "Mozilla/5.0"} @@ -282,8 +283,9 @@ def searxng_search(query, max_results=10): # ── Brave ── -def brave_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]: +def brave_search(query: str, count: Optional[int] = None, time_filter: Optional[str] = None) -> List[dict]: """Search using Brave API with key from admin settings or env var.""" + count = count if count is not None else _get_result_count() api_key = _get_provider_key("brave") or os.environ.get("DATA_BRAVE_API_KEY") or "" return _brave_search_impl(query, count, time_filter, search_config={"brave_api_key": api_key}) @@ -381,9 +383,9 @@ def _resolve_ddg_redirect(raw: str) -> str: return resolved -def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]: +def duckduckgo_search(query: str, count: Optional[int] = None, time_filter: Optional[str] = None) -> List[dict]: """Search using DuckDuckGo via the duckduckgo-search library. No API key needed.""" - + count = count if count is not None else _get_result_count() def _html_fallback() -> List[dict]: try: response = httpx.get( @@ -415,7 +417,7 @@ def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = return [] try: - from duckduckgo_search import DDGS + from ddgs import DDGS except ImportError: logger.warning("duckduckgo-search package not installed; using HTML fallback") return _html_fallback() @@ -452,7 +454,7 @@ def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = # ── Google Programmable Search Engine ── -def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]: +def google_pse_search(query: str, count: Optional[int] = None, time_filter: Optional[str] = None) -> List[dict]: """Search using Google PSE (Custom Search JSON API). Requires two keys in settings: @@ -460,6 +462,7 @@ def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] = - google_pse_cx: Programmable Search Engine ID (cx) Or env vars GOOGLE_API_KEY and GOOGLE_PSE_CX. """ + count = count if count is not None else _get_result_count() settings = _get_search_settings() api_key = _get_provider_key("google_pse") or os.environ.get("GOOGLE_API_KEY", "") cx = (settings.get("google_pse_cx") or "").strip() or os.environ.get("GOOGLE_PSE_CX", "") @@ -522,8 +525,9 @@ def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] = # ── Tavily ── -def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]: +def tavily_search(query: str, count: Optional[int] = None, time_filter: Optional[str] = None) -> List[dict]: """Search using Tavily API. Requires search_api_key or TAVILY_API_KEY env var.""" + count = count if count is not None else _get_result_count() api_key = _get_provider_key("tavily") or os.environ.get("TAVILY_API_KEY", "") if not api_key: logger.warning("Tavily: no API key configured") @@ -580,8 +584,9 @@ def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None # ── Serper.dev ── -def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]: +def serper_search(query: str, count: Optional[int] = None, time_filter: Optional[str] = None) -> List[dict]: """Search using Serper.dev API. Requires search_api_key or SERPER_API_KEY env var.""" + count = count if count is not None else _get_result_count() api_key = _get_provider_key("serper") or os.environ.get("SERPER_API_KEY", "") if not api_key: logger.warning("Serper: no API key configured") diff --git a/services/youtube/youtube_handler.py b/services/youtube/youtube_handler.py index b36989e8d..d1b1e9b91 100644 --- a/services/youtube/youtube_handler.py +++ b/services/youtube/youtube_handler.py @@ -64,20 +64,40 @@ def is_youtube_url(url: str) -> bool: return "youtube.com" in url or "youtu.be" in url +# youtube.com-shaped hosts. music.youtube.com serves the same /watch and +# /shorts paths, so links shared from YouTube Music must resolve too. +_YT_HOSTS = ("www.youtube.com", "youtube.com", "m.youtube.com", "music.youtube.com") +# Path prefixes whose first following segment is the video id. Covers the +# /embed/ player, Shorts (/shorts/), live streams (/live/), and the legacy +# /v/ embed — all of which `is_youtube_url` already treats as YouTube, so +# they must be extractable or the link is silently dropped (neither web-fetched +# nor transcript-fetched) by the chat pipeline. +_YT_PATH_PREFIXES = ("/embed/", "/shorts/", "/live/", "/v/") + + def extract_youtube_id(url: str) -> Optional[str]: - """Extract YouTube video ID from various URL formats.""" + """Extract a YouTube video ID from the common URL shapes: + watch?v=, youtu.be/, /embed/ , /shorts/ , /live/ , /v/ , + across youtube.com / m.youtube.com / music.youtube.com / youtu.be.""" if not isinstance(url, str): return None parsed = urllib.parse.urlparse(url) - if parsed.hostname in ("www.youtube.com", "youtube.com", "m.youtube.com"): + host = (parsed.hostname or "").lower() + if host in _YT_HOSTS: if parsed.path == "/watch": params = urllib.parse.parse_qs(parsed.query) - if "v" in params: + if params.get("v"): return params["v"][0] - elif parsed.path.startswith("/embed/"): - return parsed.path.split("/")[-1] - elif parsed.hostname == "youtu.be": - return parsed.path[1:] + else: + for prefix in _YT_PATH_PREFIXES: + if parsed.path.startswith(prefix): + vid = parsed.path[len(prefix):].split("/")[0] + if vid: + return vid + elif host == "youtu.be": + vid = parsed.path.lstrip("/").split("/")[0] + if vid: + return vid return None @@ -170,6 +190,8 @@ def format_transcript_for_context( if segments: ctx += "Timestamped Transcript:\n" for seg in segments: + if not isinstance(seg, dict): + continue ctx += f"[{seg['timestamp']}] {seg['text']}\n" # Check length — fall back to plain text if too long if len(ctx) > 12000: @@ -202,15 +224,24 @@ async def fetch_youtube_comments( f"https://www.youtube.com/watch?v={video_id}", ] - proc = await asyncio.wait_for( - asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ), - timeout=timeout, + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) - stdout, stderr = await proc.communicate() + # Bound the wait on the process actually finishing, not on spawning it. + # create_subprocess_exec returns as soon as the child starts, so wrapping + # it in wait_for never enforces the timeout — proc.communicate() is the + # blocking step. Kill and reap the child if it overruns so it does not + # linger after we return. + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + raise if proc.returncode != 0: return {"success": False, "error": f"yt-dlp failed: {stderr.decode()[:200]}", "comments": []} diff --git a/src/action_intents.py b/src/action_intents.py index ea0cbc86d..3b9c3cc73 100644 --- a/src/action_intents.py +++ b/src/action_intents.py @@ -91,6 +91,9 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple( ("ui", "tool or feature toggle request", r"\b(?:disable|enable|turn\s+(?:on|off))\s+(?:the\s+)?(?:shell|search|web|browser|documents?|memory|skills|images?|calendar|email|mail|research|incognito)\b"), # Deep research jobs, not quick conceptual mentions of research. + ("web", "explicit web search request", rf"{_PLEASE}(?:do|run|use|perform|make)\s+(?:a\s+)?(?:web\s+search|search\s+the\s+web)\b.+"), + ("web", "web lookup imperative request", rf"{_PLEASE}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"), + ("web", "assistant web lookup request", rf"{_ACTION_QUESTION}(?:web\s+search|search\s+the\s+web|search\s+online|look\s+up|google)\b.+"), ("research", "deep research imperative request", rf"{_PLEASE}(?:research|deep\s+dive|look\s+into|investigate)\s+.+"), ("research", "assistant deep research request", rf"{_ACTION_QUESTION}(?:research|do\s+research|deep\s+dive|look\s+into|investigate)\s+.+"), diff --git a/src/agent_loop.py b/src/agent_loop.py index acb35e7b1..f600ac598 100644 --- a/src/agent_loop.py +++ b/src/agent_loop.py @@ -21,7 +21,7 @@ from src.settings import get_setting from src.prompt_security import untrusted_context_message from src.tool_security import blocked_tools_for_owner, plan_mode_disabled_tools from src.tool_policy import GUIDE_ONLY_DIRECTIVE, ToolPolicy -from src.tool_utils import get_mcp_manager +from src.tool_utils import _truncate, get_mcp_manager from src.agent_tools import ( parse_tool_blocks, strip_tool_blocks, @@ -262,6 +262,11 @@ _DOMAIN_RULES = { - Use `manage_settings` for preferences and tool enable/disable. - Use named tools over `app_api` when a named wrapper exists. - `app_api` is only for safe UI/API actions without a named tool; do not use it for shell, package installs, engine rebuilds, or sensitive auth/admin paths.""", + "contacts": """\ +## Contacts rules +- Use `resolve_contact` to look up a contact's email or phone number by name. Searches the CardDAV address book and sent email history. +- Use `manage_contact` to list, add, update, or delete contacts in the address book. +- Do NOT use `manage_memory` for contact lookups — contact details live in the address book, not memory.""", } _DOMAIN_TOOL_MAP = { @@ -272,8 +277,9 @@ _DOMAIN_TOOL_MAP = { "notes_calendar_tasks": {"manage_notes", "manage_calendar", "manage_tasks"}, "ui": {"ui_control"}, "sessions": {"create_session", "list_sessions", "manage_session", "send_to_session", "search_chats"}, - "files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls"}, + "files": {"bash", "python", "read_file", "write_file", "edit_file", "grep", "glob", "ls", "get_workspace"}, "settings": {"manage_settings", "manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "app_api"}, + "contacts": {"resolve_contact", "manage_contact"}, } def _domain_rules_for_tools(tool_names: set) -> list[str]: @@ -309,6 +315,7 @@ NEVER pipe multi-line Python through `python -c "..."` — shell quoting eats re ``` Execute Python code. Use for computation, data processing, scripting. NOT for writing code for the user (use create_document for that). Same sandbox limits as bash — no TTY, no GUI, no `input()`; for anything the user should interact with, generate a single HTML file with inline JS instead. +Prefer a dedicated tool whenever one fits the job (reading, searching, or writing files); use python only for computation/processing no dedicated tool covers - not for reading or writing files. Do NOT use Python/requests for web lookup/search/latest/current requests when `web_search` or `web_fetch` is available.""", "web_search": """\ @@ -347,6 +354,11 @@ Write content to a file. First line is the path, rest is the content.""", ``` Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""", + "get_workspace": """\ +```get_workspace +``` +Return the absolute path of the active workspace folder. File tools are CONFINED to it (paths can be RELATIVE to it); the shell starts there (cwd) but is NOT sandboxed. Call this first when the user says "the project"/"the code"/"this folder" without a path, instead of asking them. No arguments.""", + "create_document": """\ ```create_document @@ -598,7 +610,7 @@ _API_HOSTS = frozenset([ "api.deepseek.com", "deepseek.com", "api.together.xyz", "api.fireworks.ai", "api.perplexity.ai", "api.x.ai", - "ollama.com", "api.venice.ai", + "ollama.com", "api.venice.ai", "api.kimi.com", "api.githubcopilot.com", # Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.). # Without these, `_is_api_model` falls back to keyword sniffing on the @@ -785,6 +797,12 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o domains.add("documents") if has(r"\b(search|web|google|look up|latest|news|current|weather|forecast|stock price|price of|website|url|https?://|www\.)\b"): domains.add("web") + if has( + r"\b(wyszukaj|wyszukać|wyszukac)\b.*\b(internet|internecie|online|web)\b", + r"\b(sprawd[zź]|znajd[zź])\b.*\b(internet|internecie|online|web)\b", + r"\b(aktualn\w*|bieżąc\w*|biezac\w*|dzisiaj|teraz)\b.*\b(pogod\w*|temperatur\w*)\b", + ): + domains.add("web") if has(r"\b(research|deep dive|investigate|look into)\b"): domains.add("web") if has(r"\b(open|show|toggle|turn on|turn off|disable|enable|switch model|change model|settings|theme|panel)\b"): @@ -795,6 +813,8 @@ def _classify_agent_request(messages: List[Dict], last_user: str) -> Dict[str, o domains.add("files") if has(r"\b(endpoint|api token|mcp|webhook|preference|configure|config|setting)\b"): domains.add("settings") + if has(r"\b(contact|contacts|phone|phone number|address book|vcard)\b"): + domains.add("contacts") low_signal = not continuation and not domains return { @@ -860,7 +880,7 @@ def _build_system_prompt( _ov_sig = _hl.sha256(_json.dumps(get_builtin_overrides() or {}, sort_keys=True).encode()).hexdigest() except Exception: _ov_sig = "" - cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, suppress_local_context) + cache_key = (frozenset(disabled_tools or []), bool(mcp_mgr), needs_admin, _rt_key, compact, _ov_sig, owner, suppress_local_context) if _cached_base_prompt and _cached_base_prompt_key == cache_key and not active_document: agent_prompt = _cached_base_prompt # Skill index is user-editable (name + description), so it must never @@ -868,7 +888,7 @@ def _build_system_prompt( # when the cache hits. _, _skill_index_block = _build_base_prompt( disabled_tools, mcp_mgr, needs_admin, relevant_tools, - mcp_disabled_map=mcp_disabled_map, compact=compact, + mcp_disabled_map=mcp_disabled_map, compact=compact, owner=owner, suppress_local_context=suppress_local_context, ) else: @@ -879,6 +899,7 @@ def _build_system_prompt( relevant_tools, mcp_disabled_map=mcp_disabled_map, compact=compact, + owner=owner, suppress_local_context=suppress_local_context, ) if not active_document: @@ -894,9 +915,20 @@ def _build_system_prompt( # Current date/time for every agent request. This is user-local when the # browser provided timezone headers, with a server-local fallback. + # + # IMPORTANT: this is intentionally NOT prepended into agent_prompt (the + # system message) anymore. Its text changes every minute, and local + # OpenAI-compatible backends (llama.cpp / LM Studio) key their KV-cache + # prefix off the system message byte-for-byte — mixing ever-changing + # timestamp text into the (already large, tool-laden) agent system prompt + # would invalidate the cached prefix on every single request, forcing a + # full prompt re-evaluation each turn (issue #2927). It's built here as a + # standalone *user*-role message and inserted near the end of the array, + # right alongside _doc_message / _skills_message, below. + _datetime_message = None try: - from src.user_time import current_datetime_prompt - agent_prompt = current_datetime_prompt() + agent_prompt + from src.user_time import current_datetime_context_message + _datetime_message = current_datetime_context_message() except Exception: pass @@ -1296,6 +1328,9 @@ def _build_system_prompt( last_user_idx += 1 if _skills_message: merged.insert(last_user_idx, _skills_message) + last_user_idx += 1 + if _datetime_message: + merged.insert(last_user_idx, _datetime_message) return merged, mcp_schemas @@ -1314,6 +1349,7 @@ def _build_base_prompt( relevant_tools=None, mcp_disabled_map=None, compact: bool = False, + owner: Optional[str] = None, suppress_local_context: bool = False, ): """Build the agent prompt with only relevant tools included. @@ -1373,7 +1409,7 @@ def _build_base_prompt( from src.constants import DATA_DIR _sm = SkillsManager(DATA_DIR) active_tools = list(set(TOOL_SECTIONS.keys()) - set(disabled or [])) - skill_idx = _sm.index_for(owner=None, active_toolsets=active_tools) + skill_idx = _sm.index_for(owner=owner, active_toolsets=active_tools) if skill_idx: lines = ["## Available skills", "Procedures the assistant should consult before doing domain work. " @@ -1782,10 +1818,10 @@ async def stream_agent_loop( owner: Optional[str] = None, relevant_tools: Optional[Set[str]] = None, fallbacks: Optional[List[tuple]] = None, - workspace: Optional[str] = None, plan_mode: bool = False, approved_plan: Optional[str] = None, tool_policy: Optional[ToolPolicy] = None, + workspace: Optional[str] = None, _is_teacher_run: bool = False, ) -> AsyncGenerator[str, None]: """Streaming agent loop generator. @@ -1854,8 +1890,21 @@ async def stream_agent_loop( logger.info(f"[tool-rag] Using caller-provided relevant_tools ({len(_relevant_tools)} tools)") if not guide_only and not _relevant_tools and bool(_intent.get("low_signal")): from src.tool_index import ALWAYS_AVAILABLE - _relevant_tools = set(ALWAYS_AVAILABLE) - logger.info("[tool-rag] Low-signal agent message; skipping retrieval and using always-available tools only") + if workspace: + # An active workspace IS the file-work signal: a vague "look at the + # project" means explore this folder. Surface only the READ-ONLY file + # tools (intersection with the plan-mode read-only allowlist) so the + # agent can investigate; write/shell tools stay out until the request + # actually calls for them (RAG retrieval adds those on a real ask). + _relevant_tools = set(ALWAYS_AVAILABLE) + from src.tool_security import PLAN_MODE_READONLY_TOOLS + _relevant_tools |= (_DOMAIN_TOOL_MAP["files"] & PLAN_MODE_READONLY_TOOLS) + logger.info("[tool-rag] Low-signal but workspace active; including read-only file tools") + else: + # Don't short-circuit: fall through to RAG retrieval below. + # Non-English queries are flagged low_signal by the English-only + # intent classifier, but fastembed retrieval works across languages. + logger.info("[tool-rag] Low-signal query; will run RAG retrieval") if not guide_only and not _relevant_tools: try: from src.tool_index import get_tool_index, ALWAYS_AVAILABLE @@ -1930,6 +1979,44 @@ async def stream_agent_loop( if _relevant_tools is not None and active_document is not None: _relevant_tools.update({"edit_document", "update_document", "suggest_document"}) + # The skill index injected by _build_system_prompt tells the model to + # call `manage_skills action=view`, and Jaccard-matched skills are pasted + # into the prompt as procedures to follow — but neither path goes through + # tool selection, so the model can be handed a procedure naming tools + # (grep, read_file, ...) that aren't in its schema list. Keep the schemas + # in lockstep: manage_skills is callable whenever any skill is indexed, + # and a matched skill's declared requires_toolsets ride along with it. + if not guide_only and _relevant_tools is not None: + try: + from services.memory.skills import SkillsManager + from src.constants import DATA_DIR + _skills_on = True + try: + from routes.prefs_routes import _load_for_user as _load_prefs + _skills_on = (_load_prefs(owner) or {}).get("skills_enabled", True) + except Exception: + pass + _sm = SkillsManager(DATA_DIR) + _owner_skills = _sm.load(owner=owner) if _skills_on else [] + if _owner_skills: + _relevant_tools.add("manage_skills") + if _retrieval_query: + # Validate against every known executable tool, not just + # TOOL_SECTIONS — code-nav tools (grep/glob/ls) ship as + # schemas without a prompt-prose section. + from src.tool_policy import known_tool_names + _known = known_tool_names() + for _sk in _sm.get_relevant_skills( + _retrieval_query, skills=_owner_skills, + threshold=0.25, max_items=3, + ): + _relevant_tools.update( + t for t in (_sk.get("requires_toolsets") or []) + if t in _known + ) + except Exception as _e: + logger.debug(f"[tool-rag] skill-aware tool include skipped: {_e}") + if _relevant_tools is not None: logger.info("[agent-intent] selected_tools=%s", sorted(_relevant_tools)[:50]) @@ -1980,6 +2067,10 @@ async def stream_agent_loop( # and can override this list for users who know their setup. _model_no_tools = any(kw in _model_lc for kw in ( "deepseek-r1", + # Open-weight GPT-OSS models are commonly served through llama.cpp / + # llama-cpp-python. Their names contain "gpt-o", but they do not use + # OpenAI's native tool-call channel unless the endpoint opts in. + "gpt-oss", )) # Native Ollama endpoints (/api/chat) handle tool schemas differently from # the OpenAI-compat path. Models like gemma4, qwen3.5, ministral respond to @@ -2011,27 +2102,6 @@ async def stream_agent_loop( suppress_local_context=guide_only, active_email=active_email, ) - if workspace and not guide_only: - # PREPEND (not append) so it dominates the large base prompt — appended - # at the end, small models ignored it and asked the user for code. The - # folder IS the project; the agent must explore it, not ask. - _ws_note = ( - f"## ACTIVE WORKSPACE — READ FIRST\n" - f"The user is working in this folder: {workspace}\n" - f"It IS the project. bash/python run with cwd set here and " - f"read_file/write_file are confined to it (paths outside are rejected).\n" - f"When the user says \"the code\" / \"this project\" / \"the workspace\" " - f"or asks to review/find/edit something WITHOUT a path, they mean THIS " - f"folder. Do NOT ask the user for code or a path, and do NOT read a file " - f"literally named \"workspace\". ALWAYS start by exploring it yourself: " - f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then " - f"read_file the relevant ones by path RELATIVE to the workspace." - ) - if messages and messages[0].get("role") == "system": - messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "") - else: - messages.insert(0, {"role": "system", "content": _ws_note}) - logger.info("[workspace] active for this turn: %s", workspace) if plan_mode and not guide_only: # Steer the model to investigate-then-propose. Hard tool gating handles # every write path except shell; this directive is what keeps the @@ -2063,30 +2133,34 @@ async def stream_agent_loop( _t3 = time.time() try: from src.context_compactor import trim_for_context - from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX - from src.settings import is_setting_overridden + from src.context_budget import compute_input_token_budget, DEFAULT_HARD_MAX, DEFAULT_BUDGET, budget_is_explicit as _budget_is_explicit + from src.model_context import budget_context_for_model - soft_budget = int(get_setting("agent_input_token_budget", 6000) or 0) + soft_budget = int(get_setting("agent_input_token_budget", DEFAULT_BUDGET) or 0) if soft_budget > 0: before_trim_tokens = estimate_tokens(messages) reserve_tokens = min(max(max_tokens or 1024, 512), 2048) - # Honour the configurable ceiling for the auto-derived budget path. - # No-op when the user has an explicit `agent_input_token_budget` - # (that branch ignores hard_max). Falls back to DEFAULT_HARD_MAX - # on missing/malformed values so misconfig can't zero the budget. + # Ceiling for the auto-derived budget (no effect on an explicit budget; + # see #1230). Falls back to DEFAULT_HARD_MAX on missing/malformed values + # so misconfig can't zero the budget. try: hard_max = int(get_setting("agent_input_token_hard_max", DEFAULT_HARD_MAX) or DEFAULT_HARD_MAX) except (TypeError, ValueError): hard_max = DEFAULT_HARD_MAX if hard_max <= 0: hard_max = DEFAULT_HARD_MAX - # Scale the default budget to the model's context window so long-context - # models aren't silently capped at 6000; an explicit user setting is - # still honoured (clamped to the window). (#1170) + # Default value = auto sentinel (scale to the window); any other value = + # explicit cap. Value-based, not presence-based, because the save path + # materializes defaults so a persisted default must still read as auto (#4121). + budget_is_explicit = _budget_is_explicit(soft_budget) + # Scale only off a window we actually discovered, bound to the value it + # proves (else 0) — not the passed-in context_length, which can be stale + # or unset for some callers (#4122 review). + ctx_for_budget = budget_context_for_model(endpoint_url, model, fallback=context_length) effective_budget = compute_input_token_budget( soft_budget, - context_length, - is_setting_overridden("agent_input_token_budget"), + ctx_for_budget, + budget_is_explicit, hard_max=hard_max, ) trimmed_messages = trim_for_context( @@ -2161,11 +2235,12 @@ async def stream_agent_loop( # tool, so we don't nudge on harmless transitional text like "let me # know what you think". _INTENT_RE = re.compile( - r"(?:^|\n)\s*(?:let me|i'?ll|i will|going to|let's)\s+" + r"(?:^|\n)\s*(?:let me|i'?ll|i will|i need to|we need to|need to|" + r"i should|we should|i must|we must|going to|let's)\s+" r"(?:tail|check|investigate|look at|see|tail|read|fetch|inspect|" r"verify|diagnose|examine|debug|capture|grab|pull|view|run|call|" r"trigger|launch|start|kick off|stop|kill|restart|adopt|serve|" - r"register|adopt|list|search|find|query|hit|ping|test)" + r"register|adopt|list|search|find|query|hit|ping|test|use|perform|do)" r"\b[^.\n]{0,140}", re.IGNORECASE, ) @@ -2206,9 +2281,17 @@ async def stream_agent_loop( elif _is_api_model: # Filter schemas by RAG-selected tools (if available) if _relevant_tools: + # _build_base_prompt unions _ADMIN_TOOLS into the prompt + # sections when admin intent fires — the schema list must + # offer the same names, or the model reads prose describing + # tools it cannot call and substitutes the nearest schema + # it does have (e.g. manage_memory for manage_skills). + _schema_names = set(_relevant_tools) + if _needs_admin: + _schema_names |= _ADMIN_TOOLS base_schemas = [ s for s in FUNCTION_TOOL_SCHEMAS - if s.get("function", {}).get("name") in _relevant_tools + if s.get("function", {}).get("name") in _schema_names ] _mcp_filtered = [ s for s in mcp_schemas @@ -2254,6 +2337,7 @@ async def stream_agent_loop( prompt_type=prompt_type if round_num == 1 else None, tools=all_tool_schemas if all_tool_schemas else None, timeout=agent_stream_timeout, + session_id=session_id, ): if time.time() > _round_deadline: logger.warning(f"[agent] round {round_num} stream exceeded wall-clock deadline; cutting off") @@ -2743,6 +2827,46 @@ async def stream_agent_loop( ) desc, result = await _tool_task + # A skill the model just loaded can prescribe tools that weren't + # RAG-selected this turn (declared via requires_toolsets in its + # frontmatter). Union them into the selection so the NEXT round's + # schema list includes them — otherwise the model reads "use + # grep" from the skill it fetched but has no grep schema to call. + if ( + block.tool_type == "manage_skills" + and _relevant_tools is not None + and not result.get("error") + ): + _ms_args = {} + _ms_raw = (block.content or "").strip() + if _ms_raw.startswith("{"): + try: + _ms_args = json.loads(_ms_raw) + except json.JSONDecodeError: + _ms_args = {} + _ms_name = str(_ms_args.get("name", "") or "").strip() + if _ms_name and _ms_args.get("action") in ("view", "view_ref"): + try: + from services.memory.skills import SkillsManager as _SkM + from src.constants import DATA_DIR as _DD + from src.tool_policy import known_tool_names as _ktn + _known = _ktn() + for _sk in _SkM(_DD).load(owner=owner): + if _sk.get("name") == _ms_name: + _new = { + t for t in (_sk.get("requires_toolsets") or []) + if t in _known and t not in _relevant_tools + } + if _new: + _relevant_tools.update(_new) + logger.info( + "[tool-rag] skill '%s' unlocked tools for next round: %s", + _ms_name, sorted(_new), + ) + break + except Exception as _e: + logger.debug(f"skill requires_toolsets unlock skipped: {_e}") + # Extract structured web sources from web_search tool output. # web_search returns {"output": ..., "exit_code": 0}; check "output" # first so the marker is found and stripped even @@ -2833,18 +2957,20 @@ async def stream_agent_loop( # On a bash/python timeout the result carries error + (often # empty) stdout/stderr; fall back to the error so the "timed # out" reason reaches the UI instead of a blank result. - output_text = (result["stdout"] or result["stderr"] or result.get("error", ""))[:2000] + raw = result["stdout"] or result["stderr"] or result.get("error", "") + output_text = _truncate(raw) elif "output" in result: # bash / python canonical result: {"output": ..., "exit_code": ...} - output_text = (result["output"] or "")[:2000] + raw = result["output"] or "" + output_text = _truncate(raw) elif "response" in result: # AI interaction tools (chat_with_model, send_to_session) label = result.get("model", result.get("session_name", "AI")) - output_text = f"{label}: {result['response']}"[:4000] + output_text = _truncate(f"{label}: {result['response']}") elif "content" in result: - output_text = result["content"][:2000] + output_text = _truncate(result["content"]) elif "results" in result: - output_text = result["results"][:4000] + output_text = _truncate(result["results"]) elif "session_id" in result and "name" in result: output_text = f"Session created: {result['name']} (id: {result['session_id']})" elif "success" in result: @@ -2854,7 +2980,7 @@ async def stream_agent_loop( else f"Error: {result.get('error', '')}" ) elif "error" in result: - output_text = result["error"][:2000] + output_text = _truncate(result["error"]) # Emit tool_output (include ui_event data if present) tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")} diff --git a/src/agent_tools.py b/src/agent_tools/__init__.py similarity index 76% rename from src/agent_tools.py rename to src/agent_tools/__init__.py index c7eea4541..52fe4a99c 100644 --- a/src/agent_tools.py +++ b/src/agent_tools/__init__.py @@ -18,6 +18,30 @@ from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager logger = logging.getLogger(__name__) +from .subprocess_tools import BashTool, PythonTool +from .web_tools import WebSearchTool, WebFetchTool +from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool, GetWorkspaceTool +from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool + +TOOL_HANDLERS = { + "bash": BashTool().execute, + "python": PythonTool().execute, + "web_search": WebSearchTool().execute, + "web_fetch": WebFetchTool().execute, + "read_file": ReadFileTool().execute, + "write_file": WriteFileTool().execute, + "edit_file": EditFileTool().execute, + "ls": LsTool().execute, + "glob": GlobTool().execute, + "grep": GrepTool().execute, + "create_document": CreateDocumentTool().execute, + "update_document": UpdateDocumentTool().execute, + "edit_document": EditDocumentTool().execute, + "suggest_document": SuggestDocumentTool().execute, + "manage_documents": ManageDocumentTool().execute, + "get_workspace": GetWorkspaceTool().execute, +} + # --------------------------------------------------------------------------- # Constants (re-exported for backward compatibility — single source of truth # is src.constants; always prefer importing from there for new code) @@ -28,7 +52,7 @@ PYTHON_TIMEOUT = 30 # Tool types that trigger execution TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file", - "grep", "glob", "ls", + "grep", "glob", "ls", "get_workspace", "create_document", "update_document", "edit_document", "search_chats", "chat_with_model", "create_session", "list_sessions", @@ -92,15 +116,14 @@ from src.tool_execution import ( # noqa: E402, F401 format_tool_result, ) +# Document functions +from .document_tools import ( + set_active_document, + set_active_model +) + # Implementations from src.tool_implementations import ( # noqa: E402, F401 - set_active_document, - set_active_model, - get_active_document, - do_create_document, - do_update_document, - do_edit_document, - do_suggest_document, do_search_chats, do_manage_skills, do_manage_tasks, @@ -108,7 +131,6 @@ from src.tool_implementations import ( # noqa: E402, F401 do_manage_mcp, do_manage_webhooks, do_manage_tokens, - do_manage_documents, do_manage_settings, do_api_call, ) diff --git a/src/agent_tools/document_tools.py b/src/agent_tools/document_tools.py new file mode 100644 index 000000000..33b10c8d3 --- /dev/null +++ b/src/agent_tools/document_tools.py @@ -0,0 +1,644 @@ +from typing import Any, Dict, List, Optional +import logging +import re +import json +from src.constants import MAX_READ_CHARS + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Active document state +# --------------------------------------------------------------------------- + +_active_document_id: Optional[str] = None +_active_model: Optional[str] = None + + +def set_active_document(doc_id: Optional[str]): + """Set the active document ID for document tool execution.""" + global _active_document_id + _active_document_id = doc_id + + +def set_active_model(model: Optional[str]): + """Set the current model name for version summaries.""" + global _active_model + _active_model = model + + +def get_active_document(): + return _active_document_id + + +def clear_active_document(doc_id: Optional[str] = None) -> bool: + """Clear the in-memory active-document pointer. + + With ``doc_id`` given, only clears when it matches the current pointer, so a + different active document is left untouched. Returns True if it was cleared. + + Called when a document is detached from its session or deleted (its tab is + closed): without this, the stale pointer makes the last-resort doc-injection + path re-surface a closed document in a later, unrelated chat — even one whose + session no longer matches — because an unlinked doc has session_id NULL (#1160). + """ + global _active_document_id + if doc_id is None or _active_document_id == doc_id: + _active_document_id = None + return True + return False + + +def _owned_document_query(query, Document, owner: Optional[str]): + if owner is None: + # A bare Python `False` is not a valid SQL expression — SQLAlchemy 1.4 + # deprecates it and 2.0 raises ArgumentError. Use the SQL `false()` + # literal to return zero rows for an unscoped (owner-less) query. + from sqlalchemy import false + return query.filter(false()) + return query.filter(Document.owner == owner) + + +def _get_owned_document(db, Document, doc_id: str, owner: Optional[str], active_only: bool = False): + q = db.query(Document).filter(Document.id == doc_id) + if active_only: + q = q.filter(Document.is_active == True) + q = _owned_document_query(q, Document, owner) + return q.first() + + +def _most_recent_owned_document(db, Document, owner: Optional[str], active_only: bool = False): + q = db.query(Document) + if active_only: + q = q.filter(Document.is_active == True) + q = _owned_document_query(q, Document, owner) + return q.order_by(Document.updated_at.desc()).first() + + +# --------------------------------------------------------------------------- +# Document tools — create/update/edit/suggest living documents +# --------------------------------------------------------------------------- + +def _sniff_doc_language(text: str) -> str: + """Best-effort detect a document's language from its content when the model + didn't specify one. Defaults to 'markdown' (prose). Recognizes the common + markup/code types the editor supports so e.g. an SVG isn't saved as markdown.""" + import json as _json, re as _re2 + s = (text or "").strip() + if not s: + return "markdown" + head = s[:600] + hl = head.lower() + if _looks_like_email_document(s): + return "email" + # Markup (unambiguous) + if "