diff --git a/.gitignore b/.gitignore index c48f6cd61..846e6cf74 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,4 @@ docs/windows-port/ compound.config.json *.error.log _scratch/ +/odysseus/ diff --git a/src/agent_tools.py b/src/agent_tools/__init__.py similarity index 87% rename from src/agent_tools.py rename to src/agent_tools/__init__.py index c7eea4541..a90a061e5 100644 --- a/src/agent_tools.py +++ b/src/agent_tools/__init__.py @@ -18,6 +18,23 @@ from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager logger = logging.getLogger(__name__) +from .subprocess_tools import BashTool, PythonTool +from .web_tools import WebSearchTool, WebFetchTool +from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool + +TOOL_HANDLERS = { + "bash": BashTool().execute, + "python": PythonTool().execute, + "web_search": WebSearchTool().execute, + "web_fetch": WebFetchTool().execute, + "read_file": ReadFileTool().execute, + "write_file": WriteFileTool().execute, + "edit_file": EditFileTool().execute, + "ls": LsTool().execute, + "glob": GlobTool().execute, + "grep": GrepTool().execute, +} + # --------------------------------------------------------------------------- # Constants (re-exported for backward compatibility — single source of truth # is src.constants; always prefer importing from there for new code) diff --git a/src/agent_tools/filesystem_tools.py b/src/agent_tools/filesystem_tools.py new file mode 100644 index 000000000..3b5425242 --- /dev/null +++ b/src/agent_tools/filesystem_tools.py @@ -0,0 +1,419 @@ +import asyncio +import json +import os +import difflib +import fnmatch +import shutil +from typing import Optional, Dict, Any, Tuple + +from src.constants import MAX_READ_CHARS, MAX_DIFF_LINES, MAX_OUTPUT_CHARS + +_CODENAV_SKIP_DIRS = frozenset({ + ".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__", + ".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build", + ".next", ".cache", 
"site-packages", ".idea", ".tox", +}) +_CODENAV_MAX_HITS = 200 +_CODENAV_MAX_LINE = 400 + +def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]: + if old == new: + return None + old_lines = old.splitlines() + new_lines = new.splitlines() + label = path or "file" + diff_lines = list(difflib.unified_diff( + old_lines, new_lines, + fromfile=f"a/{label}", tofile=f"b/{label}", + lineterm="", + )) + added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++")) + removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---")) + truncated = False + if len(diff_lines) > MAX_DIFF_LINES: + diff_lines = diff_lines[:MAX_DIFF_LINES] + truncated = True + text = "\n".join(diff_lines) + if truncated: + text += f"\n… diff truncated at {MAX_DIFF_LINES} lines" + return { + "text": text, + "added": added, + "removed": removed, + "new_file": old == "", + "file": os.path.basename(path) or (path or "file"), + } + +class EditFileTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + try: + args = json.loads(content) if content.strip().startswith("{") else {} + except (json.JSONDecodeError, TypeError): + args = {} + raw_path = (args.get("path") or "").strip() + old = args.get("old_string", "") + new = args.get("new_string", "") + replace_all = bool(args.get("replace_all", False)) + if not raw_path: + return {"error": "edit_file: path required", "exit_code": 1} + try: + path = (_resolve_tool_path_in_workspace(workspace, raw_path) + if workspace else _resolve_tool_path(raw_path)) + except ValueError as e: + return {"error": f"edit_file: {e}", "exit_code": 1} + if old == "": + return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1} + if old == new: + return {"error": "edit_file: 
old_string and new_string are identical", "exit_code": 1} + + def _apply(): + """Helper function that performs the actual string replacement and file writing logic.""" + with open(path, "r", encoding="utf-8") as f: + original = f.read() + count = original.count(old) + if count == 0: + return original, None, "not_found" + if count > 1 and not replace_all: + return original, None, f"not_unique:{count}" + updated = original.replace(old, new) if replace_all else original.replace(old, new, 1) + with open(path, "w", encoding="utf-8") as f: + f.write(updated) + return original, updated, "ok" + + try: + original, updated, status = await asyncio.to_thread(_apply) + except FileNotFoundError: + return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1} + except (IsADirectoryError, UnicodeDecodeError): + return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1} + except PermissionError: + return {"error": f"edit_file: {path}: permission denied", "exit_code": 1} + except OSError as e: + return {"error": f"edit_file: {path}: {e}", "exit_code": 1} + + if status == "not_found": + return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1} + if status.startswith("not_unique"): + n = status.split(":", 1)[1] + return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). 
Add surrounding context or set replace_all=true.", "exit_code": 1} + + n = original.count(old) + result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0} + diff = _unified_diff(original, updated, path) + if diff: + result["diff"] = diff + return result + +class ReadFileTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0 + _stripped = content.strip() + if _stripped.startswith("{"): + try: + _a = json.loads(_stripped) + raw_path = str(_a.get("path", "")).strip() + offset = int(_a.get("offset") or 0) + limit = int(_a.get("limit") or 0) + except (json.JSONDecodeError, TypeError, ValueError): + pass + try: + path = (_resolve_tool_path_in_workspace(workspace, raw_path) + if workspace else _resolve_tool_path(raw_path)) + except ValueError as e: + return {"error": f"read_file: {e}", "exit_code": 1} + try: + def _read(): + if offset > 0 or limit > 0: + start = max(offset, 1) + out, n, budget = [], 0, MAX_READ_CHARS + with open(path, "r", encoding="utf-8", errors="replace") as f: + for i, line in enumerate(f, 1): + if i < start: + continue + if limit > 0 and n >= limit: + break + out.append(line) + n += 1 + budget -= len(line) + if budget <= 0: + out.append(f"\n... 
[truncated at {MAX_READ_CHARS} chars]") + break + return "".join(out) + with open(path, "r", encoding="utf-8", errors="replace") as f: + return f.read(MAX_READ_CHARS + 1) + data = await asyncio.to_thread(_read) + except FileNotFoundError: + return {"error": f"read_file: {path}: not found", "exit_code": 1} + except PermissionError: + return {"error": f"read_file: {path}: permission denied", "exit_code": 1} + except IsADirectoryError: + return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1} + except OSError as e: + return {"error": f"read_file: {path}: {e}", "exit_code": 1} + if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS: + data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]" + return {"output": data, "exit_code": 0} + +class WriteFileTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + lines = content.split("\n", 1) + raw_path = lines[0].strip() + body = lines[1] if len(lines) > 1 else "" + try: + path = (_resolve_tool_path_in_workspace(workspace, raw_path) + if workspace else _resolve_tool_path(raw_path)) + except ValueError as e: + return {"error": f"write_file: {e}", "exit_code": 1} + try: + def _write(): + old = "" + try: + with open(path, "r", encoding="utf-8") as f: + old = f.read() + except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError): + old = "" + d = os.path.dirname(path) + if d: + os.makedirs(d, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + f.write(body) + return old, len(body) + old_content, size = await asyncio.to_thread(_write) + except PermissionError: + return {"error": f"write_file: {path}: permission denied", "exit_code": 1} + except OSError as e: + return {"error": f"write_file: {path}: {e}", "exit_code": 1} + diff = _unified_diff(old_content, body, path) + 
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0} + if diff: + result["diff"] = diff + return result + +class LsTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + raw_path = "" + _s = (content or "").strip() + if _s.startswith("{"): + try: + raw_path = str(json.loads(_s).get("path", "")).strip() + except json.JSONDecodeError: + raw_path = "" + else: + raw_path = _s.split("\n", 1)[0].strip() + try: + root = _resolve_search_root(raw_path) + except ValueError as e: + return {"error": f"ls: {e}", "exit_code": 1} + + def _ls(): + if not os.path.isdir(root): + return None, f"ls: {root}: not a directory" + rows = [] + try: + with os.scandir(root) as it: + for entry in it: + if entry.name.startswith("."): + continue + try: + is_dir = entry.is_dir(follow_symlinks=False) + size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0 + except OSError: + continue + rows.append((is_dir, entry.name, size)) + except (PermissionError, OSError) as _e: + return None, f"ls: {_e}" + rows.sort(key=lambda r: (not r[0], r[1].lower())) + lines = [f"{root}:"] + for is_dir, name, size in rows[:_CODENAV_MAX_HITS]: + lines.append(f" {name}/" if is_dir else f" {name} ({size} B)") + if len(rows) > _CODENAV_MAX_HITS: + lines.append(f" ... 
[{len(rows) - _CODENAV_MAX_HITS} more]") + if not rows: + lines.append(" (empty)") + return "\n".join(lines), None + + out, err = await asyncio.to_thread(_ls) + if err: + return {"error": err, "exit_code": 1} + return {"output": _truncate(out), "exit_code": 0} + +class GlobTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + args = {} + _s = (content or "").strip() + if _s.startswith("{"): + try: + args = json.loads(_s) + except json.JSONDecodeError: + args = {} + else: + args = {"pattern": _s} + pattern = str(args.get("pattern", "")).strip() + if not pattern: + return {"error": "glob: pattern is required", "exit_code": 1} + try: + root = _resolve_search_root(str(args.get("path", ""))) + except ValueError as e: + return {"error": f"glob: {e}", "exit_code": 1} + + def _glob(): + from pathlib import Path + base = Path(root) + if not base.is_dir(): + return None, f"glob: {root}: not a directory" + matched = [] + try: + for p in base.rglob(pattern): + if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS: + continue + try: + mtime = p.stat().st_mtime + except OSError: + mtime = 0 + matched.append((mtime, str(p))) + if len(matched) > _CODENAV_MAX_HITS * 5: + break + except (OSError, ValueError) as _e: + return None, f"glob: {_e}" + matched.sort(key=lambda t: t[0], reverse=True) + return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None + + paths, err = await asyncio.to_thread(_glob) + if err: + return {"error": err, "exit_code": 1} + if not paths: + return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0} + out = "\n".join(paths) + if len(paths) >= _CODENAV_MAX_HITS: + out += f"\n... 
[capped at {_CODENAV_MAX_HITS} files]" + return {"output": _truncate(out), "exit_code": 0} + +class GrepTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import ( + _resolve_tool_path, + _resolve_tool_path_in_workspace, + _resolve_search_root, + _truncate + ) + workspace = ctx.get("workspace") + args: Dict[str, Any] = {} + _s = (content or "").strip() + if _s.startswith("{"): + try: + args = json.loads(_s) + except json.JSONDecodeError: + args = {} + else: + args = {"pattern": _s} + pattern = str(args.get("pattern", "")).strip() + if not pattern: + return {"error": "grep: pattern is required", "exit_code": 1} + ignore_case = bool(args.get("ignore_case")) + glob_pat = str(args.get("glob", "") or "").strip() + try: + max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS) + except (TypeError, ValueError): + max_hits = _CODENAV_MAX_HITS + max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS)) + try: + root = _resolve_search_root(str(args.get("path", ""))) + except ValueError as e: + return {"error": f"grep: {e}", "exit_code": 1} + + def _grep(): + import re as _re + import shutil + rg = shutil.which("rg") + if rg: + cmd = [rg, "--line-number", "--no-heading", "--color=never", + "--max-count", str(max_hits)] + if ignore_case: + cmd.append("--ignore-case") + if glob_pat: + cmd += ["--glob", glob_pat] + for _d in _CODENAV_SKIP_DIRS: + cmd += ["--glob", f"!**/{_d}/**"] + cmd += ["--regexp", pattern, root] + try: + import subprocess + p = subprocess.run(cmd, capture_output=True, text=True, timeout=20) + lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits] + return lines, None + except subprocess.TimeoutExpired: + return None, "grep: timed out" + except Exception as _e: + return None, f"grep: {_e}" + try: + rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0) + except _re.error as _e: + return None, f"grep: bad pattern: {_e}" + hits = [] + if os.path.isfile(root): + file_iter = [root] + else: + 
file_iter = [] + for dp, dns, fns in os.walk(root): + dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS] + for fn in fns: + if glob_pat and not fnmatch.fnmatch(fn, glob_pat): + continue + file_iter.append(os.path.join(dp, fn)) + for fp in file_iter: + if len(hits) >= max_hits: + break + try: + with open(fp, "r", encoding="utf-8", errors="strict") as f: + for i, line in enumerate(f, 1): + if rx.search(line): + hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}") + if len(hits) >= max_hits: + break + except (UnicodeDecodeError, OSError): + continue + return hits, None + + lines, err = await asyncio.to_thread(_grep) + if err: + return {"error": err, "exit_code": 1} + if not lines: + return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0} + out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines) + if len(lines) >= max_hits: + out += f"\n... [capped at {max_hits} matches]" + return {"output": _truncate(out), "exit_code": 0} diff --git a/src/agent_tools/subprocess_tools.py b/src/agent_tools/subprocess_tools.py new file mode 100644 index 000000000..6b5972030 --- /dev/null +++ b/src/agent_tools/subprocess_tools.py @@ -0,0 +1,155 @@ +import asyncio +import sys +import time +import collections +from typing import Optional, Callable, Awaitable, Tuple, Dict +from src.constants import MAX_OUTPUT_CHARS + +DEFAULT_BASH_TIMEOUT = 60 * 60 # 1 hour +DEFAULT_PYTHON_TIMEOUT = 60 * 60 + +PROGRESS_INTERVAL_S = 2.0 +PROGRESS_TAIL_LINES = 12 + +async def _run_subprocess_streaming( + proc: asyncio.subprocess.Process, + *, + timeout: float, + progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None, +) -> Tuple[str, str, Optional[int], bool]: + started = time.time() + stdout_full: list[str] = [] + stderr_full: list[str] = [] + tail = collections.deque(maxlen=PROGRESS_TAIL_LINES) + + async def _reader(stream, full_buf, label: str): + if stream is None: + return + while True: + line = await stream.readline() + if not line: + break + decoded = 
line.decode("utf-8", errors="replace").rstrip("\n") + full_buf.append(decoded) + if label == "err": + tail.append(f"! {decoded}") + else: + tail.append(decoded) + + async def _progress_emitter(): + await asyncio.sleep(PROGRESS_INTERVAL_S) + while True: + if progress_cb: + try: + await progress_cb({ + "elapsed_s": round(time.time() - started, 1), + "tail": "\n".join(list(tail)), + }) + except Exception: + pass + await asyncio.sleep(PROGRESS_INTERVAL_S) + + rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out")) + rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err")) + prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None + + timed_out = False + try: + await asyncio.wait_for(proc.wait(), timeout=timeout) + except asyncio.TimeoutError: + timed_out = True + try: + proc.kill() + except Exception: + pass + try: + await asyncio.wait_for(proc.wait(), timeout=2) + except Exception: + pass + except asyncio.CancelledError: + try: + proc.kill() + except Exception: + pass + try: + await asyncio.wait_for(proc.wait(), timeout=2) + except Exception: + pass + for t in (rd_out, rd_err): + t.cancel() + if prog_task is not None: + prog_task.cancel() + raise + finally: + if prog_task is not None and not prog_task.done(): + prog_task.cancel() + try: + await prog_task + except (asyncio.CancelledError, Exception): + pass + for t in (rd_out, rd_err): + try: + await asyncio.wait_for(t, timeout=1) + except Exception: + pass + + return ( + "\n".join(stdout_full), + "\n".join(stderr_full), + proc.returncode, + timed_out, + ) + +class BashTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import _AGENT_WORKDIR, _truncate + progress_cb = ctx.get("progress_cb") + workspace = ctx.get("workspace") + _subproc_env = ctx.get("subproc_env") + proc = await asyncio.create_subprocess_shell( + content, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=_subproc_env, + cwd=workspace or 
_AGENT_WORKDIR, + ) + stdout, stderr, rc, timed_out = await _run_subprocess_streaming( + proc, + timeout=DEFAULT_BASH_TIMEOUT, + progress_cb=progress_cb, + ) + if timed_out: + return {"error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)} + output = stdout.rstrip() + err = stderr.rstrip() + if err: + output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err + output = _truncate(output, MAX_OUTPUT_CHARS) + return {"output": output or "(no output)", "exit_code": rc or 0} + +class PythonTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.tool_execution import _AGENT_WORKDIR, _truncate + progress_cb = ctx.get("progress_cb") + workspace = ctx.get("workspace") + _subproc_env = ctx.get("subproc_env") + proc = await asyncio.create_subprocess_exec( + (sys.executable or "python"), "-I", "-c", content, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=_subproc_env, + cwd=workspace or _AGENT_WORKDIR, + ) + stdout, stderr, rc, timed_out = await _run_subprocess_streaming( + proc, + timeout=DEFAULT_PYTHON_TIMEOUT, + progress_cb=progress_cb, + ) + if timed_out: + return {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)} + output = stdout.rstrip() + err = stderr.rstrip() + if err: + output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err + output = _truncate(output, MAX_OUTPUT_CHARS) + return {"output": output or "(no output)", "exit_code": rc or 0} diff --git a/src/agent_tools/web_tools.py b/src/agent_tools/web_tools.py new file mode 100644 index 000000000..87a4b697f --- /dev/null +++ b/src/agent_tools/web_tools.py @@ -0,0 +1,101 @@ +import asyncio +import json +from typing import Dict, Any + +from src.constants import 
MAX_OUTPUT_CHARS + +class WebSearchTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.search import comprehensive_web_search + raw = content.strip() + query = raw + time_filter = None + max_pages = 5 + if raw.startswith("{"): + try: + parsed = json.loads(raw) + if isinstance(parsed, dict) and "query" in parsed: + query = str(parsed.get("query", "")).strip() + tf = parsed.get("time_filter") or parsed.get("freshness") + if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"): + time_filter = tf.lower() + mp = parsed.get("max_pages") + if isinstance(mp, int) and 1 <= mp <= 10: + max_pages = mp + except json.JSONDecodeError: + pass + if not query: + query = raw.split("\n")[0].strip() + if time_filter is None: + q_lc = query.lower() + if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")): + time_filter = "day" + elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")): + time_filter = "week" + elif any(kw in q_lc for kw in ("this month", "past month")): + time_filter = "month" + elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"): + time_filter = "week" + loop = asyncio.get_running_loop() + text, sources = await asyncio.wait_for( + loop.run_in_executor( + None, + lambda: comprehensive_web_search( + query, + max_pages=max_pages, + time_filter=time_filter, + return_sources=True, + ), + ), + timeout=30, + ) + output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text + if sources: + output += "\n\n" + return {"output": output, "exit_code": 0} + +class WebFetchTool: + async def execute(self, content: str, ctx: dict) -> dict: + from src.search.content import fetch_webpage_content + raw = content.strip() + url = "" + if raw.startswith("{"): + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + url = str(parsed.get("url") or "").strip() + except json.JSONDecodeError: + url = "" + if not url: + url = 
raw.split("\n")[0].strip() + if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")): + return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1} + low = url.lower() + if "://" in low and not low.startswith(("http://", "https://")): + return {"error": f"web_fetch: unsupported URL scheme (only http/https): {url[:80]}", "exit_code": 1} + if not low.startswith(("http://", "https://")): + url = "https://" + url + loop = asyncio.get_running_loop() + try: + result = await asyncio.wait_for( + loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)), + timeout=30, + ) + except asyncio.TimeoutError: + return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1} + except Exception as e: + return {"error": f"web_fetch: {url}: {e}", "exit_code": 1} + err = result.get("error") + text = (result.get("content") or "").strip() + title = result.get("title") or "" + + if not text: + if err: + return {"error": f"web_fetch: {url}: {err}", "exit_code": 1} + return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1} + + header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n" + output = header + text + if len(output) > MAX_OUTPUT_CHARS: + output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]" + return {"output": output, "exit_code": 0} diff --git a/src/tool_execution.py b/src/tool_execution.py index 704f3f48e..662cc7268 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -18,6 +18,8 @@ import sys import time from typing import Any, Awaitable, Callable, Dict, Optional, Tuple + + from src.tool_security import is_public_blocked_tool, owner_is_admin_or_single_user from src.tool_policy import ToolPolicy from src.constants import MAX_OUTPUT_CHARS, MAX_READ_CHARS, MAX_DIFF_LINES, DATA_DIR @@ -31,105 +33,6 @@ from src.tool_utils import _truncate, get_mcp_manager _AGENT_WORKDIR = DATA_DIR -def _unified_diff(old: str, new: 
str, path: str) -> Optional[Dict[str, Any]]: - """Build a unified diff of a file write for display in the chat. - - Returns {"text": , "added": N, "removed": M, "new_file": bool} - or None when there's no textual change. Truncates very large diffs. - """ - if old == new: - return None - import difflib - - old_lines = old.splitlines() - new_lines = new.splitlines() - label = path or "file" - diff_lines = list(difflib.unified_diff( - old_lines, new_lines, - fromfile=f"a/{label}", tofile=f"b/{label}", - lineterm="", - )) - added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++")) - removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---")) - truncated = False - if len(diff_lines) > MAX_DIFF_LINES: - diff_lines = diff_lines[:MAX_DIFF_LINES] - truncated = True - text = "\n".join(diff_lines) - if truncated: - text += f"\n… diff truncated at {MAX_DIFF_LINES} lines" - return { - "text": text, - "added": added, - "removed": removed, - "new_file": old == "", - "file": os.path.basename(path) or (path or "file"), - } - - -async def _do_edit_file(content: str) -> Dict[str, Any]: - """Exact string-replacement edit of an on-disk file. - - content is JSON: {"path", "old_string", "new_string", "replace_all"?}. - Fails if old_string is missing or non-unique (unless replace_all) so the - model can't silently edit the wrong place. Returns a unified diff for the UI. - """ - try: - args = json.loads(content) if content.strip().startswith("{") else {} - except (json.JSONDecodeError, TypeError): - args = {} - raw_path = (args.get("path") or "").strip() - old = args.get("old_string", "") - new = args.get("new_string", "") - replace_all = bool(args.get("replace_all", False)) - if not raw_path: - return {"error": "edit_file: path required", "exit_code": 1} - # Allowlist + sensitive-file policy as read/write_file. 
- try: - path = _resolve_tool_path(raw_path) - except ValueError as e: - return {"error": f"edit_file: {e}", "exit_code": 1} - if old == "": - return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1} - if old == new: - return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1} - - def _apply(): - with open(path, "r", encoding="utf-8") as f: - original = f.read() - count = original.count(old) - if count == 0: - return original, None, "not_found" - if count > 1 and not replace_all: - return original, None, f"not_unique:{count}" - updated = original.replace(old, new) if replace_all else original.replace(old, new, 1) - with open(path, "w", encoding="utf-8") as f: - f.write(updated) - return original, updated, "ok" - - try: - original, updated, status = await asyncio.to_thread(_apply) - except FileNotFoundError: - return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1} - except (IsADirectoryError, UnicodeDecodeError): - return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1} - except PermissionError: - return {"error": f"edit_file: {path}: permission denied", "exit_code": 1} - except OSError as e: - return {"error": f"edit_file: {path}: {e}", "exit_code": 1} - - if status == "not_found": - return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1} - if status.startswith("not_unique"): - n = status.split(":", 1)[1] - return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). 
Add surrounding context or set replace_all=true.", "exit_code": 1} - - n = original.count(old) - result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0} - diff = _unified_diff(original, updated, path) - if diff: - result["diff"] = diff - return result # --------------------------------------------------------------------------- # Path confinement for read_file / write_file @@ -269,40 +172,46 @@ def _resolve_tool_path(raw_path: str) -> str: ) -# Bash + python tools used to share a single 60s timeout. That's -# enough for one-shot commands but starves real workloads (pip -# install, ffmpeg conversions, etc.) — and worse, the agent saw the -# 60s timeout and went silent because it had nothing to report. -# The new default is intentionally generous: long enough that real -# work isn't killed mid-flight, but bounded so a runaway process -# (infinite loop, hung connect, etc.) eventually frees the worker. -# The user can cancel sooner via the chat stop button — when the -# SSE stream is torn down, the asyncio task running the subprocess -# gets cancelled and the subprocess is killed by the finally block. -DEFAULT_BASH_TIMEOUT = 60 * 60 # 1 hour -DEFAULT_PYTHON_TIMEOUT = 60 * 60 +def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str: + """Confine a model-supplied path to the active workspace. + + Layered on top of upstream's path policy: the workspace is the allowed + root (relative paths resolve under it; paths that escape it are rejected), + and the sensitive-file deny list (.ssh, .gnupg, id_rsa, …) still applies + inside it. When no workspace is set, callers use _resolve_tool_path (the + default data/tmp allowlist) instead. 
+ """ + if raw_path is None or not str(raw_path).strip(): + raise ValueError("path is required") + base = os.path.realpath(workspace) + expanded = os.path.expanduser(str(raw_path).strip()) + candidate = expanded if os.path.isabs(expanded) else os.path.join(base, expanded) + resolved = os.path.realpath(candidate) + if _is_sensitive_path(resolved): + raise ValueError( + f"path '{raw_path}' is inside a sensitive directory " + f"(e.g. .ssh, .gnupg) or matches a sensitive filename" + ) + if resolved != base: + # normcase so containment holds on case-insensitive filesystems + # (Windows, default macOS): it lowercases on Windows and is a no-op on + # POSIX. commonpath raises ValueError across Windows drives (C: vs D:) + # or mixed abs/rel — both mean "outside", so the except rejects them. + nbase = os.path.normcase(base) + try: + if os.path.commonpath([os.path.normcase(resolved), nbase]) != nbase: + raise ValueError + except ValueError: + raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})") + return resolved + + + +def get_mcp_manager(): + from src import agent_tools + return agent_tools.get_mcp_manager() -# How often to push a progress event while a long-running subprocess -# is still in flight. The frontend cares about "alive" more than -# "every-byte" — 2s is the sweet spot. -PROGRESS_INTERVAL_S = 2.0 -# Tail buffer size — we keep the most recent N lines of stdout + -# stderr so the progress event includes a "what's it doing right now" -# snippet without dragging the whole output along. -PROGRESS_TAIL_LINES = 12 -# Directories ignored by the code-nav tools' Python fallbacks so results aren't -# polluted by VCS internals / dependency trees / build caches. ripgrep already -# honours .gitignore; this is the parity floor for the no-rg path (and the -# explicit excludes passed to rg so it skips them even without a .gitignore). 
-_CODENAV_SKIP_DIRS = frozenset({ - ".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__", - ".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build", - ".next", ".cache", "site-packages", ".idea", ".tox", -}) -# Per-tool result caps (keep tool output cheap + model-friendly). -_CODENAV_MAX_HITS = 200 -_CODENAV_MAX_LINE = 400 def _resolve_search_root(raw_path: str) -> str: @@ -320,116 +229,6 @@ def _resolve_search_root(raw_path: str) -> str: logger = logging.getLogger(__name__) -async def _run_subprocess_streaming( - proc: asyncio.subprocess.Process, - *, - timeout: float, - progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None, -) -> Tuple[str, str, Optional[int], bool]: - """Run a subprocess to completion, streaming progress. - - Reads stdout + stderr line-by-line into ring buffers so a - periodic progress callback can emit a "tail" of recent output - without waiting for the full result. Returns - (full_stdout, full_stderr, return_code, timed_out). - - `timed_out=True` means the process was killed because it ran - past `timeout` seconds. Whatever output we'd buffered up to - that point is still returned. - """ - started = time.time() - stdout_full: list[str] = [] - stderr_full: list[str] = [] - tail = collections.deque(maxlen=PROGRESS_TAIL_LINES) - - async def _reader(stream, full_buf, label: str): - if stream is None: - return - while True: - line = await stream.readline() - if not line: - break - decoded = line.decode("utf-8", errors="replace").rstrip("\n") - full_buf.append(decoded) - if label == "err": - tail.append(f"! {decoded}") - else: - tail.append(decoded) - - async def _progress_emitter(): - # Skip the first push — many commands finish well under - # PROGRESS_INTERVAL_S and a 0-second "progress" event would - # just add UI churn. 
- await asyncio.sleep(PROGRESS_INTERVAL_S) - while True: - if progress_cb: - try: - await progress_cb({ - "elapsed_s": round(time.time() - started, 1), - "tail": "\n".join(list(tail)), - }) - except Exception: - # Progress is best-effort — never let a UI hiccup - # break the underlying subprocess. - pass - await asyncio.sleep(PROGRESS_INTERVAL_S) - - rd_out = asyncio.create_task(_reader(proc.stdout, stdout_full, "out")) - rd_err = asyncio.create_task(_reader(proc.stderr, stderr_full, "err")) - prog_task = asyncio.create_task(_progress_emitter()) if progress_cb else None - - timed_out = False - try: - await asyncio.wait_for(proc.wait(), timeout=timeout) - except asyncio.TimeoutError: - timed_out = True - try: - proc.kill() - except Exception: - pass - try: - await asyncio.wait_for(proc.wait(), timeout=2) - except Exception: - pass - except asyncio.CancelledError: - # User hit stop / SSE stream torn down. Kill the child so it - # doesn't keep running orphaned. Re-raise so the agent loop's - # cancellation propagates as the user expects. - try: - proc.kill() - except Exception: - pass - try: - await asyncio.wait_for(proc.wait(), timeout=2) - except Exception: - pass - # Best-effort: stop the readers + emitter before re-raising. - for t in (rd_out, rd_err): - t.cancel() - if prog_task is not None: - prog_task.cancel() - raise - finally: - if prog_task is not None and not prog_task.done(): - prog_task.cancel() - try: - await prog_task - except (asyncio.CancelledError, Exception): - pass - # Wait for readers to finish draining the pipes. 
- for t in (rd_out, rd_err): - try: - await asyncio.wait_for(t, timeout=1) - except Exception: - pass - - return ( - "\n".join(stdout_full), - "\n".join(stderr_full), - proc.returncode, - timed_out, - ) - _ADMIN_TOOLS = { "app_api", "manage_endpoints", @@ -593,24 +392,8 @@ async def _direct_fallback( tool: str, content: str, progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None, + workspace: Optional[str] = None, ) -> Optional[Dict]: - """In-process execution path for the eight tools that used to live as - stdio MCP servers under mcp_servers/. Those servers were deleted in - favor of native execution; this function is now the canonical path, - not a fallback. The name is kept for backwards compat with callers. - - `progress_cb` is called periodically while bash/python subprocesses - are still running, with `{elapsed_s, tail}` payloads. Other tools - ignore it. - """ - # Inherit env + force a sane terminal so subprocesses that touch - # terminfo (anything calling `clear`, `tput`, `os.system("clear")`, - # or scripts that probe $TERM) don't spam "TERM environment variable - # not set" errors. The agent's bash/python tool calls run with PIPE - # stdin/stdout (no real TTY), so curses/termios still won't work — - # but at least non-interactive code with incidental TERM lookups - # stops failing. COLUMNS/LINES give terminal-width-aware tools (less, - # rich, etc.) reasonable defaults instead of 0×0. 
_subproc_env = { **os.environ, "TERM": "xterm-256color", @@ -620,444 +403,16 @@ async def _direct_fallback( } try: - if tool == "bash": - proc = await asyncio.create_subprocess_shell( - content, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=_subproc_env, - cwd=_AGENT_WORKDIR, - ) - stdout, stderr, rc, timed_out = await _run_subprocess_streaming( - proc, - timeout=DEFAULT_BASH_TIMEOUT, - progress_cb=progress_cb, - ) - if timed_out: - return {"error": f"bash: timed out after {DEFAULT_BASH_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)} - output = stdout.rstrip() - err = stderr.rstrip() - if err: - output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err - output = _truncate(output, MAX_OUTPUT_CHARS) - return {"output": output or "(no output)", "exit_code": rc or 0} + ctx = { + "progress_cb": progress_cb, + "workspace": workspace, + "subproc_env": _subproc_env, + } - if tool == "python": - # Run user code in a subprocess so an infinite loop or crash - # can't take the whole server down. -I = isolated mode (skip - # user site, no PYTHONPATH inheritance) for hygiene. - proc = await asyncio.create_subprocess_exec( - # Use the running interpreter — there is no `python3.exe` on - # Windows, which made the agent's `python` tool fail there. 
- (sys.executable or "python"), "-I", "-c", content, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=_subproc_env, - cwd=_AGENT_WORKDIR, - ) - stdout, stderr, rc, timed_out = await _run_subprocess_streaming( - proc, - timeout=DEFAULT_PYTHON_TIMEOUT, - progress_cb=progress_cb, - ) - if timed_out: - return {"error": f"python: timed out after {DEFAULT_PYTHON_TIMEOUT}s — process killed", "exit_code": 124, "stdout": _truncate(stdout, MAX_OUTPUT_CHARS), "stderr": _truncate(stderr, MAX_OUTPUT_CHARS)} - output = stdout.rstrip() - err = stderr.rstrip() - if err: - output = (output + "\nSTDERR: " + err).strip() if output else "STDERR: " + err - output = _truncate(output, MAX_OUTPUT_CHARS) - return {"output": output or "(no output)", "exit_code": rc or 0} + from src.agent_tools import TOOL_HANDLERS + if tool in TOOL_HANDLERS: + return await TOOL_HANDLERS[tool](content, ctx) - if tool == "read_file": - # Args: plain path on line 1 (back-compat) OR JSON - # {path, offset?, limit?} where offset/limit are a 1-based line range. - raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0 - _stripped = content.strip() - if _stripped.startswith("{"): - try: - _a = json.loads(_stripped) - raw_path = str(_a.get("path", "")).strip() - offset = int(_a.get("offset") or 0) - limit = int(_a.get("limit") or 0) - except (json.JSONDecodeError, TypeError, ValueError): - pass - try: - path = _resolve_tool_path(raw_path) - except ValueError as e: - return {"error": f"read_file: {e}", "exit_code": 1} - try: - # Run blocking read in a thread to keep the loop responsive. - def _read(): - if offset > 0 or limit > 0: - # Line-range read: slice [offset, offset+limit). 
- start = max(offset, 1) - out, n, budget = [], 0, MAX_READ_CHARS - with open(path, "r", encoding="utf-8", errors="replace") as f: - for i, line in enumerate(f, 1): - if i < start: - continue - if limit > 0 and n >= limit: - break - out.append(line) - n += 1 - budget -= len(line) - if budget <= 0: - out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]") - break - return "".join(out) - with open(path, "r", encoding="utf-8", errors="replace") as f: - return f.read(MAX_READ_CHARS + 1) - data = await asyncio.to_thread(_read) - except FileNotFoundError: - return {"error": f"read_file: {path}: not found", "exit_code": 1} - except PermissionError: - return {"error": f"read_file: {path}: permission denied", "exit_code": 1} - except IsADirectoryError: - return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1} - except OSError as e: - return {"error": f"read_file: {path}: {e}", "exit_code": 1} - if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS: - data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]" - return {"output": data, "exit_code": 0} - - if tool == "write_file": - lines = content.split("\n", 1) - raw_path = lines[0].strip() - body = lines[1] if len(lines) > 1 else "" - try: - path = _resolve_tool_path(raw_path) - except ValueError as e: - return {"error": f"write_file: {e}", "exit_code": 1} - try: - def _write(): - # Capture prior content (best-effort, text) so we can show a - # before/after diff. Missing/binary file → treat as empty. 
- old = "" - try: - with open(path, "r", encoding="utf-8") as f: - old = f.read() - except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError): - old = "" - d = os.path.dirname(path) - if d: - os.makedirs(d, exist_ok=True) - with open(path, "w", encoding="utf-8") as f: - f.write(body) - return old, len(body) - old_content, size = await asyncio.to_thread(_write) - except PermissionError: - return {"error": f"write_file: {path}: permission denied", "exit_code": 1} - except OSError as e: - return {"error": f"write_file: {path}: {e}", "exit_code": 1} - diff = _unified_diff(old_content, body, path) - result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0} - if diff: - result["diff"] = diff - return result - - if tool == "grep": - # Args (JSON): {pattern, path?, glob?, ignore_case?, max_results?}. - # Bare string → treated as the pattern. - args: Dict[str, Any] = {} - _s = (content or "").strip() - if _s.startswith("{"): - try: - args = json.loads(_s) - except json.JSONDecodeError: - args = {} - else: - args = {"pattern": _s} - pattern = str(args.get("pattern", "")).strip() - if not pattern: - return {"error": "grep: pattern is required", "exit_code": 1} - ignore_case = bool(args.get("ignore_case")) - glob_pat = str(args.get("glob", "") or "").strip() - try: - max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS) - except (TypeError, ValueError): - max_hits = _CODENAV_MAX_HITS - max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS)) - try: - root = _resolve_search_root(str(args.get("path", ""))) - except ValueError as e: - return {"error": f"grep: {e}", "exit_code": 1} - - def _grep(): - import re as _re - import shutil - rg = shutil.which("rg") - if rg: - cmd = [rg, "--line-number", "--no-heading", "--color=never", - "--max-count", str(max_hits)] - if ignore_case: - cmd.append("--ignore-case") - if glob_pat: - cmd += ["--glob", glob_pat] - # Exclude junk dirs even when the tree has no .gitignore, so - # results match the Python 
fallback's skip set. - for _d in _CODENAV_SKIP_DIRS: - cmd += ["--glob", f"!**/{_d}/**"] - cmd += ["--regexp", pattern, root] - try: - import subprocess - p = subprocess.run(cmd, capture_output=True, text=True, timeout=20) - lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits] - return lines, None - except subprocess.TimeoutExpired: - return None, "grep: timed out" - except Exception as _e: - return None, f"grep: {_e}" - # Python fallback (no ripgrep): walk + regex. - try: - rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0) - except _re.error as _e: - return None, f"grep: bad pattern: {_e}" - import fnmatch - hits = [] - if os.path.isfile(root): - file_iter = [root] - else: - file_iter = [] - for dp, dns, fns in os.walk(root): - dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS] - for fn in fns: - if glob_pat and not fnmatch.fnmatch(fn, glob_pat): - continue - file_iter.append(os.path.join(dp, fn)) - for fp in file_iter: - if len(hits) >= max_hits: - break - try: - with open(fp, "r", encoding="utf-8", errors="strict") as f: - for i, line in enumerate(f, 1): - if rx.search(line): - hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}") - if len(hits) >= max_hits: - break - except (UnicodeDecodeError, OSError): - continue # skip binary / unreadable - return hits, None - - lines, err = await asyncio.to_thread(_grep) - if err: - return {"error": err, "exit_code": 1} - if not lines: - return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0} - out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines) - if len(lines) >= max_hits: - out += f"\n... 
[capped at {max_hits} matches]" - return {"output": _truncate(out), "exit_code": 0} - - if tool == "glob": - args = {} - _s = (content or "").strip() - if _s.startswith("{"): - try: - args = json.loads(_s) - except json.JSONDecodeError: - args = {} - else: - args = {"pattern": _s} - pattern = str(args.get("pattern", "")).strip() - if not pattern: - return {"error": "glob: pattern is required", "exit_code": 1} - try: - root = _resolve_search_root(str(args.get("path", ""))) - except ValueError as e: - return {"error": f"glob: {e}", "exit_code": 1} - - def _glob(): - from pathlib import Path - base = Path(root) - if not base.is_dir(): - return None, f"glob: {root}: not a directory" - matched = [] - try: - for p in base.rglob(pattern): - if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS: - continue - try: - mtime = p.stat().st_mtime - except OSError: - mtime = 0 - matched.append((mtime, str(p))) - if len(matched) > _CODENAV_MAX_HITS * 5: - break - except (OSError, ValueError) as _e: - return None, f"glob: {_e}" - matched.sort(key=lambda t: t[0], reverse=True) # newest first - return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None - - paths, err = await asyncio.to_thread(_glob) - if err: - return {"error": err, "exit_code": 1} - if not paths: - return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0} - out = "\n".join(paths) - if len(paths) >= _CODENAV_MAX_HITS: - out += f"\n... 
[capped at {_CODENAV_MAX_HITS} files]" - return {"output": _truncate(out), "exit_code": 0} - - if tool == "ls": - raw_path = "" - _s = (content or "").strip() - if _s.startswith("{"): - try: - raw_path = str(json.loads(_s).get("path", "")).strip() - except json.JSONDecodeError: - raw_path = "" - else: - raw_path = _s.split("\n", 1)[0].strip() - try: - root = _resolve_search_root(raw_path) - except ValueError as e: - return {"error": f"ls: {e}", "exit_code": 1} - - def _ls(): - if not os.path.isdir(root): - return None, f"ls: {root}: not a directory" - rows = [] - try: - with os.scandir(root) as it: - for entry in it: - if entry.name.startswith("."): - continue - try: - is_dir = entry.is_dir(follow_symlinks=False) - size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0 - except OSError: - continue - rows.append((is_dir, entry.name, size)) - except (PermissionError, OSError) as _e: - return None, f"ls: {_e}" - rows.sort(key=lambda r: (not r[0], r[1].lower())) # dirs first, then name - lines = [f"{root}:"] - for is_dir, name, size in rows[:_CODENAV_MAX_HITS]: - lines.append(f" {name}/" if is_dir else f" {name} ({size} B)") - if len(rows) > _CODENAV_MAX_HITS: - lines.append(f" ... 
[{len(rows) - _CODENAV_MAX_HITS} more]") - if not rows: - lines.append(" (empty)") - return "\n".join(lines), None - - out, err = await asyncio.to_thread(_ls) - if err: - return {"error": err, "exit_code": 1} - return {"output": _truncate(out), "exit_code": 0} - - if tool == "web_search": - from src.search import comprehensive_web_search - raw = content.strip() - query = raw - time_filter = None - max_pages = 5 - # Allow JSON-shaped args: {"query": "...", "time_filter": "day", "max_pages": 7} - if raw.startswith("{"): - try: - parsed = json.loads(raw) - if isinstance(parsed, dict) and "query" in parsed: - query = str(parsed.get("query", "")).strip() - tf = parsed.get("time_filter") or parsed.get("freshness") - if isinstance(tf, str) and tf.lower() in ("day", "week", "month", "year"): - time_filter = tf.lower() - mp = parsed.get("max_pages") - if isinstance(mp, int) and 1 <= mp <= 10: - max_pages = mp - except json.JSONDecodeError: - pass - if not query: - query = raw.split("\n")[0].strip() - # Auto-detect freshness from query phrasing when not explicit - if time_filter is None: - q_lc = query.lower() - if any(kw in q_lc for kw in ("today", "latest", "breaking", "this morning", "right now", "currently")): - time_filter = "day" - elif any(kw in q_lc for kw in ("this week", "past week", "recent news", "last few days")): - time_filter = "week" - elif any(kw in q_lc for kw in ("this month", "past month")): - time_filter = "month" - elif " news" in q_lc or q_lc.startswith("news ") or q_lc.endswith(" news"): - time_filter = "week" - loop = asyncio.get_running_loop() - text, sources = await asyncio.wait_for( - loop.run_in_executor( - None, - lambda: comprehensive_web_search( - query, - max_pages=max_pages, - time_filter=time_filter, - return_sources=True, - ), - ), - timeout=30, - ) - output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text - if sources: - output += "\n\n" - return {"output": output, "exit_code": 0} - - if tool == "web_fetch": - # 
Lightweight single-URL fetch. Wraps the SSRF-safe fetcher used - # by deep research, so private/loopback/metadata addresses are - # already blocked there. - from src.search.content import fetch_webpage_content - raw = content.strip() - url = "" - # Accept either a JSON arg ({"url": "..."}) or a plain URL/domain. - if raw.startswith("{"): - try: - parsed = json.loads(raw) - if isinstance(parsed, dict): - url = str(parsed.get("url") or "").strip() - except json.JSONDecodeError: - url = "" - if not url: - # Non-JSON (or JSON without a usable url): take the first line - # only, so a URL followed by commentary still parses. - url = raw.split("\n")[0].strip() - # Reject anything that isn't a single bare URL/domain token. - if not url or url.startswith("{") or any(c in url for c in (" ", "\t", "\n")): - return {"error": "web_fetch: provide a single URL or domain, e.g. example.com", "exit_code": 1} - low = url.lower() - if "://" in low and not low.startswith(("http://", "https://")): - return {"error": f"web_fetch: unsupported URL scheme (only http/https): {url[:80]}", "exit_code": 1} - # Accept bare domains like "example.com" by defaulting to https. - if not low.startswith(("http://", "https://")): - url = "https://" + url - loop = asyncio.get_running_loop() - try: - result = await asyncio.wait_for( - loop.run_in_executor(None, lambda: fetch_webpage_content(url, timeout=10)), - timeout=30, - ) - except asyncio.TimeoutError: - return {"error": f"web_fetch: timed out fetching {url}", "exit_code": 1} - except Exception as e: - # Direct URL fetches can hit bot protection / auth walls - # (e.g. eBay 403). Treat that as a tool failure the model can - # reason around, not an uncaught chat-stream 500. 
- return {"error": f"web_fetch: {url}: {e}", "exit_code": 1} - err = result.get("error") - text = (result.get("content") or "").strip() - title = result.get("title") or "" - - if not text: - if err: - return {"error": f"web_fetch: {url}: {err}", "exit_code": 1} - # No extractable text: non-HTML body, or a pure client-rendered - # shell. The agent can fall back to the builtin_browser tool. - return {"error": f"web_fetch: {url}: no readable text content (not HTML, or the page needs JS/login)", "exit_code": 1} - - header = (f"# {title}\n" if title else "") + f"Source: {url}\n\n" - output = header + text - if len(output) > MAX_OUTPUT_CHARS: - output = output[:MAX_OUTPUT_CHARS] + "\n\n[...truncated]" - return {"output": output, "exit_code": 0} - - # manage_memory / generate_image still live as MCP servers - # (mcp_servers/{memory,image_gen}_server.py); the MCP path above - # handles them. except Exception as e: return {"error": f"{tool}: {e}", "exit_code": 1} @@ -1072,9 +427,10 @@ async def execute_tool_block( block: Any, session_id: Optional[str] = None, disabled_tools: Optional[set] = None, - tool_policy: Optional[ToolPolicy] = None, owner: Optional[str] = None, progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None, + workspace: Optional[str] = None, + tool_policy: Optional[Any] = None, ) -> Tuple[str, Dict]: """Execute a single tool block. Returns (description, result_dict). 
@@ -1130,18 +486,21 @@ async def execute_tool_block( pass # Reject tools that the user has disabled for this request - if tool_policy and tool_policy.blocks(tool): - desc = f"{tool}: BLOCKED" - result = {"error": tool_policy.reason_for(tool), "exit_code": 1} - logger.info("Tool blocked by policy: %s", tool) - return desc, result - if disabled_tools and tool in disabled_tools: desc = f"{tool}: BLOCKED" result = {"error": f"Tool '{tool}' is disabled by user.", "exit_code": 1} logger.info(f"Tool blocked by user: {tool}") return desc, result + if tool_policy and tool_policy.blocks(tool): + desc = f"{tool}: BLOCKED" + result = { + "error": f"Execution of tool '{tool}' is forbidden by the active guide-only policy.", + "exit_code": 1, + } + logger.warning("Tool policy blocked tool=%s", tool) + return desc, result + if tool in _ADMIN_TOOLS and not _owner_is_admin(owner): desc = f"{tool}: BLOCKED" result = {"error": f"Tool '{tool}' requires an admin user.", "exit_code": 1} @@ -1381,7 +740,7 @@ async def execute_tool_block( desc = "edit_image" result = await do_edit_image(content, owner=owner) elif tool == "edit_file": - result = await _do_edit_file(content) + result = await _direct_fallback(tool, content, workspace=workspace) or {"error": "edit failed", "exit_code": 1} desc = result.get("output") or result.get("error") or "edit_file" elif tool == "trigger_research": desc = "trigger_research" diff --git a/tests/test_edit_file.py b/tests/test_edit_file.py index e35530ac2..6af22fb5d 100644 --- a/tests/test_edit_file.py +++ b/tests/test_edit_file.py @@ -11,7 +11,7 @@ from src.tool_security import ( is_public_blocked_tool, blocked_tools_for_owner, ) -from src.tool_execution import _do_edit_file +from src.agent_tools.filesystem_tools import EditFileTool from src.agent_tools import ToolBlock @@ -60,7 +60,7 @@ async def test_edit_file_blocked_at_execution_for_non_admin(monkeypatch): async def test_edit_file_success(): p = os.path.join("/tmp", "ef_ok.py") open(p, "w").write("def 
f():\n return 1\n") - res = await _do_edit_file(json.dumps({"path": p, "old_string": "return 1", "new_string": "return 2"})) + res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "return 1", "new_string": "return 2"}), {}) assert res["exit_code"] == 0 assert open(p).read() == "def f():\n return 2\n" assert res["diff"]["added"] == 1 and res["diff"]["removed"] == 1 and res["diff"]["file"] == "ef_ok.py" @@ -71,7 +71,7 @@ async def test_edit_file_success(): async def test_edit_file_not_found(): p = os.path.join("/tmp", "ef_nf.txt") open(p, "w").write("hello\n") - res = await _do_edit_file(json.dumps({"path": p, "old_string": "nope", "new_string": "x"})) + res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "nope", "new_string": "x"}), {}) assert res["exit_code"] == 1 and "not found" in res["error"] os.unlink(p) @@ -80,15 +80,15 @@ async def test_edit_file_not_found(): async def test_edit_file_non_unique(): p = os.path.join("/tmp", "ef_dup.txt") open(p, "w").write("x\nx\n") - res = await _do_edit_file(json.dumps({"path": p, "old_string": "x", "new_string": "y"})) + res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "x", "new_string": "y"}), {}) assert res["exit_code"] == 1 and "not unique" in res["error"] # replace_all resolves it - res = await _do_edit_file(json.dumps({"path": p, "old_string": "x", "new_string": "y", "replace_all": True})) + res = await EditFileTool().execute(json.dumps({"path": p, "old_string": "x", "new_string": "y", "replace_all": True}), {}) assert res["exit_code"] == 0 and open(p).read() == "y\ny\n" os.unlink(p) @pytest.mark.asyncio async def test_edit_file_outside_allowed_roots(): - res = await _do_edit_file(json.dumps({"path": "/etc/hosts", "old_string": "x", "new_string": "y"})) + res = await EditFileTool().execute(json.dumps({"path": "/etc/hosts", "old_string": "x", "new_string": "y"}), {}) assert res["exit_code"] == 1 and ("outside the allowed roots" in res["error"] or "sensitive" 
in res["error"])