diff --git a/src/tool_execution.py b/src/tool_execution.py index 40bca4231..3aef285d8 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -50,8 +50,8 @@ def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]: fromfile=f"a/{label}", tofile=f"b/{label}", lineterm="", )) - added = sum(1 for l in diff_lines if l.startswith("+") and not l.startswith("+++")) - removed = sum(1 for l in diff_lines if l.startswith("-") and not l.startswith("---")) + added = sum(1 for line in diff_lines if line.startswith("+") and not line.startswith("+++")) + removed = sum(1 for line in diff_lines if line.startswith("-") and not line.startswith("---")) truncated = False if len(diff_lines) > MAX_DIFF_LINES: diff_lines = diff_lines[:MAX_DIFF_LINES] @@ -555,7 +555,7 @@ def _parse_write_file(content: str) -> Dict: return {"path": lines[0].strip(), "content": lines[1] if len(lines) > 1 else ""} -_MCP_ARG_PARSERS: Dict[str, callable] = { +_MCP_ARG_PARSERS: Dict[str, Callable[[str], Dict[str, str]]] = { "bash": lambda c: {"command": c}, "python": lambda c: {"code": c}, "web_search": lambda c: {"query": c.split("\n")[0].strip()}, @@ -660,8 +660,6 @@ async def _direct_fallback( are still running, with `{elapsed_s, tail}` payloads. Other tools ignore it. """ - import json as _json - # Inherit env + force a sane terminal so subprocesses that touch # terminfo (anything calling `clear`, `tput`, `os.system("clear")`, # or scripts that probe $TERM) don't spam "TERM environment variable @@ -735,11 +733,11 @@ async def _direct_fallback( _stripped = content.strip() if _stripped.startswith("{"): try: - _a = _json.loads(_stripped) + _a = json.loads(_stripped) raw_path = str(_a.get("path", "")).strip() offset = int(_a.get("offset") or 0) limit = int(_a.get("limit") or 0) - except (_json.JSONDecodeError, TypeError, ValueError): + except (json.JSONDecodeError, TypeError, ValueError): pass try: path = (_resolve_tool_path_in_workspace(workspace, raw_path) @@ -824,8 +822,8 @@ async def _direct_fallback( _s = (content or "").strip() if _s.startswith("{"): try: - args = _json.loads(_s) - except _json.JSONDecodeError: + args = json.loads(_s) + except json.JSONDecodeError: args = {} else: args = {"pattern": _s} @@ -915,8 +913,8 @@ async def _direct_fallback( _s = (content or "").strip() if _s.startswith("{"): try: - args = _json.loads(_s) - except _json.JSONDecodeError: + args = json.loads(_s) + except json.JSONDecodeError: args = {} else: args = {"pattern": _s} @@ -965,8 +963,8 @@ async def _direct_fallback( _s = (content or "").strip() if _s.startswith("{"): try: - raw_path = str(_json.loads(_s).get("path", "")).strip() - except _json.JSONDecodeError: + raw_path = str(json.loads(_s).get("path", "")).strip() + except json.JSONDecodeError: raw_path = "" else: raw_path = _s.split("\n", 1)[0].strip() @@ -1016,7 +1014,7 @@ async def _direct_fallback( # Allow JSON-shaped args: {"query": "...", "time_filter": "day", "max_pages": 7} if raw.startswith("{"): try: - parsed = _json.loads(raw) + parsed = json.loads(raw) if isinstance(parsed, dict) and "query" in parsed: query = str(parsed.get("query", "")).strip() tf = parsed.get("time_filter") or parsed.get("freshness") @@ -1025,7 +1023,7 @@ async def _direct_fallback( mp = parsed.get("max_pages") if isinstance(mp, int) and 1 <= mp <= 10: max_pages = mp - except _json.JSONDecodeError: + except json.JSONDecodeError: pass if not query: query = raw.split("\n")[0].strip() @@ -1055,7 +1053,7 @@ async def _direct_fallback( ) output = text[:MAX_OUTPUT_CHARS] if len(text) > MAX_OUTPUT_CHARS else text if sources: - output += "\n\n" + output += "\n\n" return {"output": output, "exit_code": 0} if tool == "web_fetch": @@ -1068,10 +1066,10 @@ async def _direct_fallback( # Accept either a JSON arg ({"url": "..."}) or a plain URL/domain. if raw.startswith("{"): try: - parsed = _json.loads(raw) + parsed = json.loads(raw) if isinstance(parsed, dict): url = str(parsed.get("url") or "").strip() - except _json.JSONDecodeError: + except json.JSONDecodeError: url = "" if not url: # Non-JSON (or JSON without a usable url): take the first line @@ -1169,8 +1167,7 @@ async def execute_tool_block( # Return a helpful error so the model retries with the correct format. if tool in ("python", "json", "xml") and content.strip().startswith("{") and content.strip().endswith("}"): try: - import json as _json - parsed = _json.loads(content.strip()) + parsed = json.loads(content.strip()) if isinstance(parsed, dict): desc = f"{tool}: misformatted tool call" result = { @@ -1222,11 +1219,10 @@ async def execute_tool_block( # into an `ask_user` SSE event and then ENDS the turn, so the chat waits for # the user's selection (their choice arrives as the next message). if tool == "ask_user": - import json as _json question, options, multi = "", [], False raw = (content or "").strip() try: - parsed = _json.loads(raw) if raw else {} + parsed = json.loads(raw) if raw else {} except (ValueError, TypeError): parsed = {} if isinstance(parsed, dict):