mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-27 07:05:23 -04:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 537f7180e6 | |||
| bd0c67b6d3 | |||
| ff5bcd9864 |
@@ -176,7 +176,6 @@ class AuthManager:
|
||||
)
|
||||
old_user = "admin"
|
||||
old_hash = self._config["password_hash"]
|
||||
with self._config_lock:
|
||||
self._config = {
|
||||
"users": {
|
||||
old_user: {
|
||||
@@ -205,7 +204,6 @@ class AuthManager:
|
||||
continue
|
||||
normalized[key] = data
|
||||
if removed or normalized != users:
|
||||
with self._config_lock:
|
||||
self._config["users"] = normalized
|
||||
self._save()
|
||||
if removed:
|
||||
|
||||
@@ -299,16 +299,6 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
||||
```
|
||||
4. Install the `mkcert` CA on any other device you want to access Odysseus from (e.g., for iOS, email the `rootCA.pem` to yourself, install the profile, and trust it in Certificate Trust Settings).
|
||||
|
||||
### Common self-host traps (30-second fixes)
|
||||
A grab-bag of small gotchas that otherwise turn into long debugging sessions.
|
||||
|
||||
- **`AUTH_ENABLED=false` is ignored / you're still forced to log in (Windows).** If you edited `.env` in Notepad, it may have been saved with a UTF-8 **BOM** (byte-order mark). The invisible BOM character is prepended to the first key, so it is read as `\ufeffAUTH_ENABLED` rather than `AUTH_ENABLED` and is never matched. Odysseus loads `.env` with `encoding="utf-8-sig"` to tolerate a leading BOM, but the safe fix is to re-save `.env` as **UTF-8 without BOM** (VS Code: *Save with Encoding → UTF-8*).
|
||||
- **macOS: the app isn't at `http://localhost:7000`.** macOS AirPlay Receiver usually holds port `7000`, so the macOS start script serves on **`7860`** instead — open `http://localhost:7860`. To use `7000`, free it (System Settings → General → AirDrop & Handoff → turn off *AirPlay Receiver*) and set `APP_PORT=7000`.
|
||||
- **Copy buttons do nothing over a plain-HTTP Tailscale/LAN URL.** Browsers only expose the clipboard API (`navigator.clipboard`) on **secure origins** — HTTPS, or `localhost`. Over `http://100.x.y.z:7860` it is blocked. Serve over HTTPS (see *HTTPS + LAN/Tailscale exposure* above); `localhost` is exempt, so copy still works on the host itself.
|
||||
- **Self-hosted ntfy reminders don't reach your phone.** Two things: (1) the bundled ntfy binds to loopback by default — to reach it from your phone set `NTFY_BIND` to your host/Tailscale IP and `NTFY_BASE_URL` to the same server URL in `.env`, then recreate the ntfy container (see the `NTFY_*` block in `.env.example`); (2) in the ntfy **Android** app, subscribe to the topic with **Instant delivery** enabled — non-`ntfy.sh` servers don't get instant push otherwise.
|
||||
- **Local mail (Dovecot) login fails: "Plaintext authentication disallowed on non-encrypted connections."** Your IMAP/SMTP server is refusing cleartext auth over an unencrypted link. Prefer enabling TLS on the mail server; on a trusted LAN only, you can allow cleartext (Dovecot: `disable_plaintext_auth = no`).
|
||||
- **Calendar/contacts (Radicale) won't sync.** Point Odysseus at the **full collection URL** with its trailing slash — e.g. `http://host:5232/<user>/<collection-id>/` — not just the server root. Radicale shows this address for each calendar/address book in its web UI.
|
||||
|
||||
### Optional Dependencies
|
||||
`requirements-optional.txt` contains packages that unlock extra features. It is not installed by default.
|
||||
|
||||
|
||||
@@ -1310,6 +1310,8 @@ def setup_chat_routes(
|
||||
"doc_stream_open", "doc_stream_delta",
|
||||
"doc_update", "doc_suggestions", "ui_control",
|
||||
"rounds_exhausted",
|
||||
"loop_breaker_triggered",
|
||||
"intent_nudge_exhausted",
|
||||
"ask_user",
|
||||
"plan_update",
|
||||
):
|
||||
|
||||
@@ -561,7 +561,7 @@ def _bash_squote(v: str) -> str:
|
||||
# Allow-list of binaries permitted as the leading token of `req.cmd` for /api/model/serve.
|
||||
# Anything else is rejected before the cmd is interpolated into a tmux/PowerShell wrapper.
|
||||
_SERVE_CMD_ALLOWLIST = {
|
||||
"vllm", "llama-server", "llama-server.exe", "llama_server", "llama.cpp", "ollama",
|
||||
"vllm", "llama-server", "llama_server", "llama.cpp", "ollama",
|
||||
"python", "python3",
|
||||
"sglang", "lmdeploy",
|
||||
"node", "npx",
|
||||
|
||||
@@ -73,9 +73,6 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
return "stored"
|
||||
return f"{value[:4]}...{value[-4:]}"
|
||||
|
||||
def _client_host_platform() -> str:
    """Return the host-platform tag sent to browser clients.

    Yields ``"windows"`` when the module-level ``IS_WINDOWS`` flag is truthy
    and an empty string otherwise.
    """
    if IS_WINDOWS:
        return "windows"
    return ""
|
||||
|
||||
def _decrypt_secret(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
@@ -248,15 +245,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"""Return cookbook state without raw secrets for browser clients."""
|
||||
_strip_task_secrets(state)
|
||||
env = state.get("env") if isinstance(state, dict) else None
|
||||
if isinstance(state, dict) and not isinstance(env, dict):
|
||||
env = {}
|
||||
state["env"] = env
|
||||
if isinstance(env, dict):
|
||||
token = _decrypt_secret(env.get("hfToken"))
|
||||
env.pop("hfToken", None)
|
||||
env["hfTokenConfigured"] = bool(token)
|
||||
env["hfTokenMasked"] = _mask_secret(token)
|
||||
env["hostPlatform"] = _client_host_platform()
|
||||
return state
|
||||
|
||||
def _state_for_storage(state, on_disk=None):
|
||||
@@ -275,7 +268,6 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
env.pop("hfToken", None)
|
||||
env.pop("hfTokenMasked", None)
|
||||
env.pop("hfTokenConfigured", None)
|
||||
env.pop("hostPlatform", None)
|
||||
return state
|
||||
|
||||
def _load_stored_hf_token() -> str:
|
||||
@@ -1487,10 +1479,6 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# shell resolves the bundled python3/hf, mirroring the download flow.
|
||||
if not remote:
|
||||
runner_lines.append(_local_tooling_path_export(sys.executable))
|
||||
if local_windows:
|
||||
# Detached Git Bash runs do not always inherit recently edited
|
||||
# user PATH entries from the already-running Odysseus process.
|
||||
runner_lines.append('export PATH="$HOME/bin:$HOME/llama.cpp/build-cuda/bin/Release:$HOME/llama.cpp/build/bin/Release:$HOME/llama.cpp/build/bin/Debug:$HOME/llama.cpp/build/bin:$PATH"')
|
||||
runner_lines.append("export FLASHINFER_DISABLE_VERSION_CHECK=1")
|
||||
if req.hf_token:
|
||||
runner_lines.append(f"export HF_TOKEN='{_bash_squote(req.hf_token)}'")
|
||||
@@ -1505,8 +1493,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(_HF_TOKEN_STATUS_SNIPPET)
|
||||
handled_ollama_serve = False
|
||||
# Auto-install inference engine if missing
|
||||
local_windows_llama_cmd = local_windows and ("llama_cpp" in req.cmd or "llama-server" in req.cmd)
|
||||
if ("llama_cpp" in req.cmd or "llama-server" in req.cmd) and not local_windows_llama_cmd:
|
||||
if "llama_cpp" in req.cmd or "llama-server" in req.cmd:
|
||||
# Prefer the NATIVE llama-server binary — its minja templating
|
||||
# renders modern GGUF chat templates that the Python bindings'
|
||||
# Jinja2 rejects (do_tojson ensure_ascii). Build it once from
|
||||
@@ -2409,8 +2396,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
try:
|
||||
return _state_for_client(json.loads(_cookbook_state_path.read_text(encoding="utf-8")))
|
||||
except Exception:
|
||||
return _state_for_client({})
|
||||
return _state_for_client({})
|
||||
return {}
|
||||
return {}
|
||||
|
||||
@router.post("/api/cookbook/state")
|
||||
async def save_cookbook_state(request: Request):
|
||||
|
||||
+8
-36
@@ -64,21 +64,6 @@ ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
||||
EMAIL_READ_ATTACHMENT_VERSION = 2
|
||||
|
||||
|
||||
def _coerce_port(value, default):
|
||||
"""Coerce a user-supplied port to int.
|
||||
|
||||
Returns ``(port, error)``. A missing or blank value yields ``default``; a
|
||||
non-numeric value yields ``(None, message)`` so callers can return a clean
|
||||
error instead of letting ``int()`` raise and surface as an HTTP 500.
|
||||
"""
|
||||
if value in (None, ""):
|
||||
return default, None
|
||||
try:
|
||||
return int(value), None
|
||||
except (TypeError, ValueError):
|
||||
return None, f"Invalid port {value!r}; must be a whole number"
|
||||
|
||||
|
||||
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
||||
aliases = [owner or ""]
|
||||
try:
|
||||
@@ -3344,12 +3329,6 @@ def setup_email_routes():
|
||||
name = (data.get("name") or "").strip()
|
||||
if not name:
|
||||
return {"ok": False, "error": "name required"}
|
||||
imap_port, port_err = _coerce_port(data.get("imap_port"), 993)
|
||||
if port_err:
|
||||
return {"ok": False, "error": port_err}
|
||||
smtp_port, port_err = _coerce_port(data.get("smtp_port"), 465)
|
||||
if port_err:
|
||||
return {"ok": False, "error": port_err}
|
||||
db = SessionLocal()
|
||||
try:
|
||||
row = EmailAccount(
|
||||
@@ -3358,13 +3337,13 @@ def setup_email_routes():
|
||||
is_default=bool(data.get("is_default", False)),
|
||||
enabled=bool(data.get("enabled", True)),
|
||||
imap_host=(data.get("imap_host") or "").strip(),
|
||||
imap_port=imap_port,
|
||||
imap_port=int(data.get("imap_port") or 993),
|
||||
imap_user=(data.get("imap_user") or "").strip(),
|
||||
imap_password=_enc(data.get("imap_password") or ""),
|
||||
imap_starttls=bool(data.get("imap_starttls", True)),
|
||||
smtp_host=(data.get("smtp_host") or "").strip(),
|
||||
smtp_port=smtp_port,
|
||||
smtp_security=_smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": smtp_port}),
|
||||
smtp_port=int(data.get("smtp_port") or 465),
|
||||
smtp_security=_smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or 465}),
|
||||
smtp_user=(data.get("smtp_user") or "").strip(),
|
||||
smtp_password=_enc(data.get("smtp_password") or ""),
|
||||
from_address=(data.get("from_address") or "").strip(),
|
||||
@@ -3408,10 +3387,7 @@ def setup_email_routes():
|
||||
setattr(row, key, (data[key] or "").strip())
|
||||
for key in ("imap_port", "smtp_port"):
|
||||
if data.get(key) not in (None, ""):
|
||||
port, port_err = _coerce_port(data.get(key), None)
|
||||
if port_err:
|
||||
return {"ok": False, "error": port_err}
|
||||
setattr(row, key, port)
|
||||
setattr(row, key, int(data[key]))
|
||||
if "smtp_security" in data:
|
||||
row.smtp_security = _smtp_security_mode({"smtp_security": data.get("smtp_security"), "smtp_port": data.get("smtp_port") or row.smtp_port})
|
||||
for key in ("imap_starttls", "enabled"):
|
||||
@@ -3515,14 +3491,12 @@ def setup_email_routes():
|
||||
smtp_result = None
|
||||
|
||||
imap_host = (body.get("imap_host") or "").strip()
|
||||
imap_port, imap_port_err = _coerce_port(body.get("imap_port"), 993)
|
||||
imap_port = int(body.get("imap_port") or 993)
|
||||
imap_user = (body.get("imap_user") or "").strip()
|
||||
imap_pass = body.get("imap_password") or ""
|
||||
imap_starttls = bool(body.get("imap_starttls"))
|
||||
|
||||
if imap_port_err:
|
||||
imap_result = {"ok": False, "error": imap_port_err}
|
||||
elif not (imap_host and imap_user and imap_pass):
|
||||
if not (imap_host and imap_user and imap_pass):
|
||||
imap_result = {"ok": False, "error": "Need IMAP host, username, and password"}
|
||||
else:
|
||||
# Connection mode resolution:
|
||||
@@ -3549,10 +3523,8 @@ def setup_email_routes():
|
||||
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
|
||||
|
||||
smtp_host = (body.get("smtp_host") or "").strip()
|
||||
smtp_port, smtp_port_err = _coerce_port(body.get("smtp_port"), 465)
|
||||
if smtp_host and smtp_port_err:
|
||||
smtp_result = {"ok": False, "error": smtp_port_err}
|
||||
elif smtp_host:
|
||||
if smtp_host:
|
||||
smtp_port = int(body.get("smtp_port") or 465)
|
||||
smtp_security = _smtp_security_mode({"smtp_security": body.get("smtp_security"), "smtp_port": smtp_port})
|
||||
smtp_user = (body.get("smtp_user") or imap_user).strip()
|
||||
smtp_pass = body.get("smtp_password") or imap_pass
|
||||
|
||||
+230
-52
@@ -38,6 +38,167 @@ from src.agent_tools import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redaction patterns for common secret-bearing shapes. Explicit and tested
|
||||
# (see tests/test_loop_guard_signals.py) rather than one clever broad regex —
|
||||
# safety first, but we try not to mangle harmless prose. Applied in order.
|
||||
_REDACTED = "[redacted]"
|
||||
|
||||
# Cookie: ... / Set-Cookie: ... — redact the rest of the line (cookies hold spaces).
|
||||
_SENSITIVE_COOKIE_RE = re.compile(
|
||||
r"(?i)\b((?:set-)?cookie\s*[:=]\s*)[^\r\n]+"
|
||||
)
|
||||
# URL credentials, e.g. postgres://user:pass@host/db. The password half allows
|
||||
# inner colons (postgres://user:pa:ss@host/db) but still stops at / and @.
|
||||
_SENSITIVE_URL_CRED_RE = re.compile(
|
||||
r"(?i)\b([a-z][a-z0-9+.\-]*://)[^\s:/@]+:[^\s/@]+@"
|
||||
)
|
||||
# Prefix-only discovery regexes. Each matches the key and its separator (the part
|
||||
# we KEEP); the value that follows is found by a linear scanner rather than by a
|
||||
# regex, so there is no backtracking-prone quantifier over uncontrolled input.
|
||||
#
|
||||
# Authorization: Bearer <tok> / Authorization: Basic "two word secret"
|
||||
_AUTH_PREFIX_RE = re.compile(
|
||||
r"(?i)authorization\s*[:=]\s*(?:bearer|basic)\s+"
|
||||
)
|
||||
# Provider-prefixed env names, e.g. OPENAI_API_KEY=..., AWS_SECRET_ACCESS_KEY=...,
|
||||
# GITHUB_TOKEN=... — require a sensitive suffix preceded by `_` so benign names
|
||||
# that merely end in KEY (MONKEY, TURKEY) are left alone.
|
||||
_ENV_PREFIX_RE = re.compile(
|
||||
r"(?:export\s+)?\b[A-Z][A-Z0-9_]*"
|
||||
r"_(?:KEY|TOKEN|SECRET|PASSWORD|PASSWD|PWD|CREDENTIALS?)\s*=\s*"
|
||||
)
|
||||
# Generic sensitive key, e.g. password=..., api_key: ..., client_secret=...
|
||||
_KEY_PREFIX_RE = re.compile(
|
||||
r"(?i)\b(?:password|passwd|pwd|token|api[_-]?key|client_secret|secret)\b\s*[:=]\s*"
|
||||
)
|
||||
# Obvious provider-shaped bare tokens (no surrounding key needed).
|
||||
_SENSITIVE_BARE_TOKEN_RE = re.compile(
|
||||
r"\b("
|
||||
r"sk-[A-Za-z0-9_\-]{16,}" # OpenAI / Anthropic style
|
||||
r"|gh[pousr]_[A-Za-z0-9]{20,}" # GitHub PAT
|
||||
r"|xox[baprs]-[A-Za-z0-9\-]{10,}" # Slack
|
||||
r"|AKIA[0-9A-Z]{16}" # AWS access key id
|
||||
r"|hf_[A-Za-z0-9]{16,}" # Hugging Face token
|
||||
r"|AIza[0-9A-Za-z_\-]{20,}" # Google API key
|
||||
r")\b"
|
||||
)
|
||||
|
||||
|
||||
def _consume_secret_value_end(text: str, start: int) -> int:
|
||||
"""Return the exclusive end index of the secret value beginning at ``start``.
|
||||
|
||||
If the value is quoted, scan to the matching unescaped quote (backslash
|
||||
escapes are skipped two chars at a time). Otherwise scan to the first
|
||||
whitespace, comma, or semicolon. The scan is linear in the length of the
|
||||
input, so it cannot exhibit catastrophic backtracking.
|
||||
"""
|
||||
n = len(text)
|
||||
if start >= n:
|
||||
return start
|
||||
quote = text[start]
|
||||
if quote in ("'", '"'):
|
||||
i = start + 1
|
||||
while i < n:
|
||||
ch = text[i]
|
||||
if ch == "\\":
|
||||
i += 2
|
||||
continue
|
||||
if ch == quote:
|
||||
return i + 1
|
||||
i += 1
|
||||
return n # unterminated quote: redact to the end
|
||||
i = start
|
||||
while i < n and not text[i].isspace() and text[i] not in (",", ";"):
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _redact_after_prefix(text: str, prefix_re: "re.Pattern") -> str:
    """Replace the value after every ``prefix_re`` match with the placeholder.

    The matched prefix (key plus separator) is kept verbatim; the value that
    follows is delimited by ``_consume_secret_value_end`` and swapped for
    ``_REDACTED``. A prefix with an empty value is left as-is.
    """
    total = len(text)
    chunks = []
    cursor = 0
    while cursor < total:
        hit = prefix_re.search(text, cursor)
        if hit is None:
            # No further prefixes: keep the remainder untouched.
            chunks.append(text[cursor:])
            break

        value_start = hit.end()
        chunks.append(text[cursor:value_start])
        value_stop = _consume_secret_value_end(text, value_start)
        if value_stop > value_start:
            chunks.append(_REDACTED)
            cursor = value_stop
            continue

        # Empty value: copy the single delimiter character (if any) and move
        # one position on so the next search makes forward progress.
        chunks.append(text[value_start:value_start + 1])
        cursor = value_start + 1
    return "".join(chunks)
|
||||
|
||||
|
||||
def _redact_private_keys(text: str) -> str:
|
||||
"""Replace PEM private-key blocks with a placeholder via linear scanning.
|
||||
|
||||
Finds ``-----BEGIN `` markers, verifies the header names a PRIVATE KEY,
|
||||
locates the matching ``-----END `` marker, and collapses the whole block.
|
||||
No regex is used, so the (multi-line, uncontrolled) body cannot trigger
|
||||
polynomial matching.
|
||||
"""
|
||||
begin_marker = "-----BEGIN "
|
||||
end_marker = "-----END "
|
||||
dash = "-----"
|
||||
max_header = 64 # generous bound on "[TYPE ]PRIVATE KEY"
|
||||
result = []
|
||||
pos = 0
|
||||
while True:
|
||||
begin = text.find(begin_marker, pos)
|
||||
if begin == -1:
|
||||
result.append(text[pos:])
|
||||
return "".join(result)
|
||||
header_start = begin + len(begin_marker)
|
||||
header_close = text.find(dash, header_start)
|
||||
if (
|
||||
header_close == -1
|
||||
or header_close - header_start > max_header
|
||||
or not text[header_start:header_close].endswith("PRIVATE KEY")
|
||||
):
|
||||
result.append(text[pos:header_start])
|
||||
pos = header_start
|
||||
continue
|
||||
end = text.find(end_marker, header_close)
|
||||
if end == -1:
|
||||
result.append(text[pos:])
|
||||
return "".join(result)
|
||||
end_header_start = end + len(end_marker)
|
||||
end_close = text.find(dash, end_header_start)
|
||||
if (
|
||||
end_close == -1
|
||||
or end_close - end_header_start > max_header
|
||||
or not text[end_header_start:end_close].endswith("PRIVATE KEY")
|
||||
):
|
||||
result.append(text[pos:header_start])
|
||||
pos = header_start
|
||||
continue
|
||||
result.append(text[pos:begin])
|
||||
result.append("[redacted private key]")
|
||||
pos = end_close + len(dash)
|
||||
|
||||
|
||||
def _redact_sensitive_text(value: object) -> str:
    """Return ``str(value)`` with obvious credential values masked.

    ``None`` becomes an empty string. Otherwise the passes run in a fixed
    order: PEM private keys, Authorization headers, cookie lines, URL
    credentials, sensitive environment assignments, generic sensitive keys,
    and finally bare provider-shaped tokens.
    """
    if value is None:
        return ""

    scrubbed = _redact_private_keys(str(value))
    scrubbed = _redact_after_prefix(scrubbed, _AUTH_PREFIX_RE)

    # Regex passes keep the captured prefix (group 1) and mask what follows.
    cookie_template = r"\1" + _REDACTED
    url_template = r"\1" + _REDACTED + "@"
    scrubbed = _SENSITIVE_COOKIE_RE.sub(cookie_template, scrubbed)
    scrubbed = _SENSITIVE_URL_CRED_RE.sub(url_template, scrubbed)

    for prefix_re in (_ENV_PREFIX_RE, _KEY_PREFIX_RE):
        scrubbed = _redact_after_prefix(scrubbed, prefix_re)

    return _SENSITIVE_BARE_TOKEN_RE.sub(_REDACTED, scrubbed)
|
||||
|
||||
|
||||
def _load_mcp_disabled_map() -> Dict[str, set]:
|
||||
"""Load per-server disabled tool sets from the database."""
|
||||
@@ -755,38 +916,6 @@ def _extract_last_user_message(messages: List[Dict]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _strip_think_blocks(text: str) -> str:
|
||||
"""Linear-time equivalent of
|
||||
``re.sub(r'<think>.*?</think>', '', text, flags=DOTALL|IGNORECASE)``.
|
||||
|
||||
The lazy regex rescans to end-of-string from every ``<think>`` opener when
|
||||
a closer is missing -> O(n^2) on untrusted model output (prompt injection
|
||||
can echo thousands of openers). This forward-only scan pairs each opener
|
||||
with the next closer in a single pass. Output is byte-for-byte identical to
|
||||
the original narrow regex: only literal ``<think>``/``</think>`` (any case)
|
||||
are matched, a dangling opener with no closer is left intact, and an orphan
|
||||
``</think>`` is never stripped.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
lowered = text.lower()
|
||||
parts = []
|
||||
pos = 0
|
||||
while True:
|
||||
start = lowered.find("<think>", pos)
|
||||
if start == -1:
|
||||
parts.append(text[pos:])
|
||||
break
|
||||
end = lowered.find("</think>", start + 7)
|
||||
if end == -1:
|
||||
# No closer for this opener: lazy regex matches nothing here.
|
||||
parts.append(text[pos:])
|
||||
break
|
||||
parts.append(text[pos:start])
|
||||
pos = end + 8 # len("</think>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
_LOW_SIGNAL_RE = re.compile(r"^[\W_]*$", re.UNICODE)
|
||||
_CASUAL_OPENING_RE = re.compile(
|
||||
r"^\s*(?:h+i+|hey+|hello+|yo+|sup+|what'?s up|wass?up|hiya|howdy|"
|
||||
@@ -1869,7 +1998,7 @@ async def _run_verifier_subagent(
|
||||
except Exception as e:
|
||||
logger.warning(f"[agent] verifier subagent failed: {e}")
|
||||
return []
|
||||
raw = _strip_think_blocks(raw or "")
|
||||
raw = re.sub(r"<think>.*?</think>", "", raw or "", flags=re.DOTALL | re.IGNORECASE)
|
||||
last_v = None
|
||||
for line in raw.splitlines():
|
||||
if "VERIFICATION:" in line:
|
||||
@@ -2487,10 +2616,12 @@ async def stream_agent_loop(
|
||||
# signatures + consecutive no-text tool rounds to bail early.
|
||||
_recent_call_sigs = collections.deque(maxlen=6)
|
||||
_stuck_rounds = 0
|
||||
_MAX_STUCK_ROUNDS = 4 # consecutive no-progress rounds before loop-breaker bails
|
||||
# Frequency of each exact call signature (tool + args), for the runaway
|
||||
# backstop. Counting identical repeats — not distinct same-tool calls —
|
||||
# lets a legit batch (e.g. 18 calendar events at once) through.
|
||||
_call_freq: collections.Counter = collections.Counter()
|
||||
_THINK_RE = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
|
||||
_force_answer = False # set by loop-breaker → next round runs with NO tools
|
||||
# Supervisor: how many times we've nudged the model after it announced
|
||||
# an action without emitting the tool call. Capped to prevent a model
|
||||
@@ -2828,7 +2959,7 @@ async def stream_agent_loop(
|
||||
if tool_blocks:
|
||||
logger.info(f"[agent] force-answer round {round_num}: discarding {len(tool_blocks)} ignored tool call(s)")
|
||||
tool_blocks = []
|
||||
if not _strip_think_blocks(strip_tool_blocks(round_response)).strip():
|
||||
if not _THINK_RE.sub("", strip_tool_blocks(round_response)).strip():
|
||||
# The model burned its budget gathering data but never wrote a
|
||||
# final answer (common with weaker models on multi-source
|
||||
# briefings). Salvage it: one blunt non-streaming synthesis call
|
||||
@@ -2851,7 +2982,7 @@ async def stream_agent_loop(
|
||||
url=endpoint_url, model=model, messages=_synth_messages,
|
||||
headers=headers, temperature=0.3, max_tokens=max_tokens, timeout=60,
|
||||
)
|
||||
_synth = _strip_think_blocks(strip_tool_blocks(_raw or "")).strip()
|
||||
_synth = _THINK_RE.sub("", strip_tool_blocks(_raw or "")).strip()
|
||||
except Exception as _e:
|
||||
logger.warning(f"[agent] grace synthesis failed: {_e}")
|
||||
if _synth:
|
||||
@@ -2913,7 +3044,7 @@ async def stream_agent_loop(
|
||||
# the model fix them (capped, and it must do new effectful work
|
||||
# to re-trigger). Skipped on force-answer rounds (no tools to
|
||||
# fix with), pure Q&A, and when the toggle is off.
|
||||
_claimed_done = bool(_strip_think_blocks(cleaned_round).strip())
|
||||
_claimed_done = bool(_THINK_RE.sub("", cleaned_round).strip())
|
||||
if (_effectful_used and not _force_answer
|
||||
and _claimed_done
|
||||
and _verifier_rounds < _VERIFIER_MAX_ROUNDS
|
||||
@@ -2957,23 +3088,28 @@ async def stream_agent_loop(
|
||||
# actual tool now") and loop again. Capped at
|
||||
# _MAX_INTENT_NUDGES so a model that genuinely cannot use the
|
||||
# tool doesn't pin us in a forever loop.
|
||||
_intent_text = _strip_think_blocks(cleaned_round).strip()
|
||||
_intent_text = _THINK_RE.sub("", cleaned_round).strip()
|
||||
_intent_match = _INTENT_RE.search(_intent_text) if _intent_text else None
|
||||
# Only nudge when the round REALLY looks like an unfinished
|
||||
# promise: short response (<400 chars), no fenced code/answer,
|
||||
# and an action-intent phrase was matched. Long answers that
|
||||
# happen to contain "let me know" are not stalls.
|
||||
_looks_like_promise = (
|
||||
_promise_shape = (
|
||||
not guide_only
|
||||
and _intent_match is not None
|
||||
and len(_intent_text) < 400
|
||||
and "```" not in _intent_text
|
||||
and _intent_nudge_count < _MAX_INTENT_NUDGES
|
||||
)
|
||||
_looks_like_promise = _promise_shape and _intent_nudge_count < _MAX_INTENT_NUDGES
|
||||
if _looks_like_promise:
|
||||
_intent_nudge_count += 1
|
||||
_matched_phrase = _intent_match.group(0).strip()
|
||||
logger.info(f"[agent] intent-without-action nudge #{_intent_nudge_count} on round {round_num}: {_matched_phrase!r}")
|
||||
# Don't log the matched phrase — it's raw model text that may
|
||||
# carry credentials. Structural metadata only.
|
||||
logger.info(
|
||||
"[agent] intent-without-action nudge #%d on round %d",
|
||||
_intent_nudge_count, round_num,
|
||||
)
|
||||
_lower_phrase = _matched_phrase.lower()
|
||||
_cookbook_log_hint = ""
|
||||
if any(_word in _lower_phrase for _word in ("log", "logs", "output", "tail", "status")):
|
||||
@@ -2999,6 +3135,24 @@ async def stream_agent_loop(
|
||||
# Visible signal in the stream so the user knows we caught it.
|
||||
yield f'data: {json.dumps({"type": "agent_step", "round": round_num + 1})}\n\n'
|
||||
continue
|
||||
# The model keeps announcing actions it never takes and we've spent
|
||||
# every nudge — surface why the turn is ending instead of letting it
|
||||
# look like a clean completion.
|
||||
if _promise_shape and _intent_nudge_count >= _MAX_INTENT_NUDGES:
|
||||
_matched_phrase = _intent_match.group(0).strip()
|
||||
_matched_phrase_safe = _redact_sensitive_text(_matched_phrase)
|
||||
_in_message = (
|
||||
f"Intent-nudge cap reached on round {round_num}: the model "
|
||||
f"announced an action ({_matched_phrase_safe!r}) without a tool call "
|
||||
f"after {_intent_nudge_count} nudge(s); ending the turn."
|
||||
)
|
||||
# Do not log the matched phrase, even redacted. It is raw model
|
||||
# text and may contain credentials; keep logs structural only.
|
||||
logger.warning(
|
||||
"[agent] intent-nudge cap exhausted on round %d (%d/%d)",
|
||||
round_num, _intent_nudge_count, _MAX_INTENT_NUDGES,
|
||||
)
|
||||
yield f'data: {json.dumps({"type": "intent_nudge_exhausted", "round": round_num, "nudges": _intent_nudge_count, "max_nudges": _MAX_INTENT_NUDGES, "message": _in_message})}\n\n'
|
||||
break # no tools — done
|
||||
|
||||
# ── Loop-breaker (Terminus-style stall detector) ──────────────
|
||||
@@ -3020,7 +3174,7 @@ async def stream_agent_loop(
|
||||
# "Real" answer text = round text minus <think> blocks. Empty-think
|
||||
# rounds (just "<think>\n\n</think>" + a tool call) must not read as
|
||||
# progress, so strip think before checking.
|
||||
_real_text = _strip_think_blocks(cleaned_round).strip()
|
||||
_real_text = _THINK_RE.sub("", cleaned_round).strip()
|
||||
# Circling = repeating a recent call with nothing written. Any
|
||||
# progress (a NEW distinct call, or actual answer text) resets it.
|
||||
if _is_repeat and not _real_text:
|
||||
@@ -3031,10 +3185,23 @@ async def stream_agent_loop(
|
||||
# Distinct calls to one tool (a real batch) are legitimate work, so we
|
||||
# count identical call signatures, not raw per-tool-type totals.
|
||||
_runaway = _detect_runaway_call(_call_freq)
|
||||
if _stuck_rounds >= 4 or _runaway:
|
||||
if _stuck_rounds >= _MAX_STUCK_ROUNDS or _runaway:
|
||||
reason = (f"calling {_runaway} with identical arguments over and over" if _runaway
|
||||
else "repeating the same tool calls without new progress")
|
||||
logger.warning(f"[agent] loop-breaker tripped on round {round_num} ({reason}); sig={_sig[:80]!r}")
|
||||
_lb_message = (
|
||||
f"Loop-breaker stopped the agent on round {round_num}: {reason}. "
|
||||
"Forced one tool-free round to converge on an answer or state what's blocked."
|
||||
)
|
||||
# Log structural metadata only — `_sig` is raw tool-call content
|
||||
# that may carry credentials.
|
||||
logger.warning(
|
||||
"[agent] loop-breaker tripped on round %d (%s); "
|
||||
"stuck_rounds=%d/%d runaway=%r",
|
||||
round_num, reason, _stuck_rounds, _MAX_STUCK_ROUNDS, _runaway,
|
||||
)
|
||||
# Surface the stop cause to the stream so the user (and journalctl)
|
||||
# can tell a guard fired, not a clean completion.
|
||||
yield f'data: {json.dumps({"type": "loop_breaker_triggered", "round": round_num, "reason": reason, "stuck_rounds": _stuck_rounds, "max_stuck_rounds": _MAX_STUCK_ROUNDS, "runaway": _runaway, "message": _lb_message})}\n\n'
|
||||
# The model has been executing tools, so its results are already
|
||||
# in context. Force ONE tool-free round to converge: write the
|
||||
# answer from what it has, or state plainly what's blocking it.
|
||||
@@ -3113,6 +3280,10 @@ async def stream_agent_loop(
|
||||
cmd_display = block.content.split("\n")[0].strip()[:80]
|
||||
else:
|
||||
cmd_display = block.content.strip()
|
||||
# The display string is streamed (tool_start/tool_output) and persisted;
|
||||
# redact any secrets in it. block.content itself is left untouched so
|
||||
# tool execution still sees the real command.
|
||||
cmd_display = _redact_sensitive_text(cmd_display)
|
||||
|
||||
if tool_policy and tool_policy.blocks(block.tool_type):
|
||||
desc = f"{block.tool_type}: BLOCKED"
|
||||
@@ -3158,8 +3329,15 @@ async def stream_agent_loop(
|
||||
evt = await _progress_q.get()
|
||||
if evt is None:
|
||||
break
|
||||
# Redact secrets in the live tail before streaming — the
|
||||
# final tool_output is redacted, so the progress tail must
|
||||
# be too, or a secret could flash by mid-run. Copy so we
|
||||
# don't mutate the tool's own event payload.
|
||||
_evt = dict(evt)
|
||||
if isinstance(_evt.get("tail"), str):
|
||||
_evt["tail"] = _redact_sensitive_text(_evt["tail"])
|
||||
yield (
|
||||
f'data: {json.dumps({"type": "tool_progress", "tool": block.tool_type, "round": round_num, **evt})}\n\n'
|
||||
f'data: {json.dumps({"type": "tool_progress", "tool": block.tool_type, "round": round_num, **_evt})}\n\n'
|
||||
)
|
||||
desc, result = await _tool_task
|
||||
|
||||
@@ -3225,7 +3403,7 @@ async def stream_agent_loop(
|
||||
result["results"] = _clean
|
||||
elif "stdout" in result:
|
||||
result["stdout"] = _clean
|
||||
except (json.JSONDecodeError, Exception):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Emit doc-specific event for document tools — the frontend
|
||||
@@ -3295,29 +3473,29 @@ async def stream_agent_loop(
|
||||
# empty) stdout/stderr; fall back to the error so the "timed
|
||||
# out" reason reaches the UI instead of a blank result.
|
||||
raw = result["stdout"] or result["stderr"] or result.get("error", "")
|
||||
output_text = _truncate(raw)
|
||||
output_text = _truncate(_redact_sensitive_text(raw))
|
||||
elif "output" in result:
|
||||
# bash / python canonical result: {"output": ..., "exit_code": ...}
|
||||
raw = result["output"] or ""
|
||||
output_text = _truncate(raw)
|
||||
output_text = _truncate(_redact_sensitive_text(raw))
|
||||
elif "response" in result:
|
||||
# AI interaction tools (chat_with_model, send_to_session)
|
||||
label = result.get("model", result.get("session_name", "AI"))
|
||||
output_text = _truncate(f"{label}: {result['response']}")
|
||||
output_text = _truncate(_redact_sensitive_text(f"{label}: {result['response']}"))
|
||||
elif "content" in result:
|
||||
output_text = _truncate(result["content"])
|
||||
output_text = _truncate(_redact_sensitive_text(result["content"]))
|
||||
elif "results" in result:
|
||||
output_text = _truncate(result["results"])
|
||||
output_text = _truncate(_redact_sensitive_text(result["results"]))
|
||||
elif "session_id" in result and "name" in result:
|
||||
output_text = f"Session created: {result['name']} (id: {result['session_id']})"
|
||||
elif "success" in result:
|
||||
output_text = (
|
||||
f"Written: {result.get('path', '')}"
|
||||
if result["success"]
|
||||
else f"Error: {result.get('error', '')}"
|
||||
else f"Error: {_redact_sensitive_text(result.get('error', ''))}"
|
||||
)
|
||||
elif "error" in result:
|
||||
output_text = _truncate(result["error"])
|
||||
output_text = _truncate(_redact_sensitive_text(result["error"]))
|
||||
|
||||
# Emit tool_output (include ui_event data if present)
|
||||
tool_output_data = {"type": "tool_output", "tool": block.tool_type, "command": cmd_display, "output": output_text, "exit_code": result.get("exit_code")}
|
||||
|
||||
+16
-41
@@ -55,8 +55,6 @@ class EmbeddingClient:
|
||||
# of stalling startup ~30s per probe. Read stays generous for a real
|
||||
# endpoint (embedding a short string returns in well under a second).
|
||||
self._client = httpx.Client(timeout=httpx.Timeout(connect=3.0, read=10.0, write=5.0, pool=3.0))
|
||||
self._batch_size = max(1, int(os.getenv("EMBEDDING_BATCH_SIZE", "8")))
|
||||
self._max_chars = max(200, int(os.getenv("EMBEDDING_MAX_CHARS", "900")))
|
||||
|
||||
def get_sentence_embedding_dimension(self) -> int:
|
||||
"""Probe the endpoint for embedding dimension if not yet known."""
|
||||
@@ -75,10 +73,23 @@ class EmbeddingClient:
|
||||
if not texts:
|
||||
return np.array([], dtype="float32")
|
||||
|
||||
# Batch in chunks of 64 to avoid oversized requests
|
||||
all_vecs = []
|
||||
for i in range(0, len(texts), self._batch_size):
|
||||
batch = texts[i : i + self._batch_size]
|
||||
all_vecs.extend(self._embed_batch(batch))
|
||||
for i in range(0, len(texts), 64):
|
||||
batch = texts[i : i + 64]
|
||||
resp = self._client.post(
|
||||
self.url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||
json={"input": batch, "model": self.model},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
embeddings = data.get("data", [])
|
||||
embeddings.sort(key=lambda e: e.get("index", 0))
|
||||
for emb in embeddings:
|
||||
all_vecs.append(emb["embedding"])
|
||||
|
||||
vecs = np.array(all_vecs, dtype="float32")
|
||||
|
||||
@@ -92,42 +103,6 @@ class EmbeddingClient:
|
||||
|
||||
return vecs
|
||||
|
||||
def _embed_batch(self, batch: List[str]) -> List[List[float]]:
|
||||
try:
|
||||
return self._post_embeddings(batch)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else None
|
||||
if status != 400:
|
||||
raise
|
||||
if len(batch) > 1:
|
||||
vecs = []
|
||||
for text in batch:
|
||||
vecs.extend(self._embed_batch([text]))
|
||||
return vecs
|
||||
text = batch[0]
|
||||
trimmed = text[: self._max_chars]
|
||||
if trimmed != text:
|
||||
logger.warning(
|
||||
"Embedding input exceeded endpoint context; retrying with %d chars",
|
||||
len(trimmed),
|
||||
)
|
||||
return self._post_embeddings([trimmed])
|
||||
raise
|
||||
|
||||
def _post_embeddings(self, batch: List[str]) -> List[List[float]]:
|
||||
resp = self._client.post(
|
||||
self.url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||
json={"input": batch, "model": self.model},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
embeddings = data.get("data", [])
|
||||
embeddings.sort(key=lambda e: e.get("index", 0))
|
||||
return [emb["embedding"] for emb in embeddings]
|
||||
|
||||
|
||||
class FastEmbedClient:
|
||||
"""Local embedding client using fastembed (ONNX). No external service needed."""
|
||||
|
||||
@@ -152,7 +152,6 @@ DEFAULT_SETTINGS = {
|
||||
"utility_model_fallbacks": [],
|
||||
"teacher_model": "",
|
||||
"teacher_enabled": False,
|
||||
"teacher_tier2_enabled": False,
|
||||
# Skills: minimum self-reported confidence for an auto-written (LLM-authored)
|
||||
# DRAFT skill to be injected into the agent prompt. Published skills always
|
||||
# qualify. Keeps low-confidence auto-skills out of context until they're
|
||||
|
||||
+7
-102
@@ -366,71 +366,6 @@ def _format_trace(tool_results: List[Dict[str, Any]], agent_reply: str) -> str:
|
||||
return f"<<<UNTRUSTED_TRACE>>>\n{trace}\n<<<END_UNTRUSTED_TRACE>>>"
|
||||
|
||||
|
||||
# Prompt for the Tier 2 self-evaluation call. The model is asked to answer
# with a single word ("failure" or "ok"); evaluate_turn_llm() below fills the
# three placeholders and compares the normalized answer against "failure".
_EVALUATE_TURN_LLM_PROMPT = """\
You are an independent auditor evaluating a student AI agent's turn.
Given the original request, the trace of tool calls and results, and the agent's final reply, determine whether the agent failed, gave up because it lacks the tools/capability/information, or encountered an error.

Respond with exactly one of these two words:
- "failure" if the agent failed, gave up, encountered an error, or asked the user for clarification/missing tools.
- "ok" if the agent successfully completed the task or is making correct progress.

ORIGINAL USER REQUEST:
{user_request}

AGENT TRACE:
{trace}

AGENT REPLY:
{agent_reply}

EVALUATION:"""


async def evaluate_turn_llm(
    user_request: str,
    tool_results: List[Dict[str, Any]],
    agent_reply: str,
    student_endpoint_url: str,
    owner: Optional[str] = None,
) -> Tuple[str, Optional[str]]:
    """Ask an LLM whether the agent's turn should count as a failure.

    The endpoint is resolved through ``resolve_endpoint("utility", ...)``
    with ``student_endpoint_url`` passed as the fallback URL.

    Returns:
        ``("failure", reason)`` only when the model's answer, stripped of
        whitespace and surrounding quotes and lower-cased, is exactly
        ``"failure"``. Every other outcome returns ``("ok", None)``: no
        endpoint or model resolved, an empty answer, any other answer, or an
        exception during the call (which is logged as a warning).
    """
    from src.endpoint_resolver import resolve_endpoint
    from src.llm_core import llm_call_async

    url, model, headers = resolve_endpoint(
        "utility",
        fallback_url=student_endpoint_url,
        owner=owner,
    )
    if not (url and model):
        return ("ok", None)

    prompt = _EVALUATE_TURN_LLM_PROMPT.format(
        user_request=user_request or "(no user request)",
        trace=_format_trace(tool_results, agent_reply),
        agent_reply=agent_reply or "(no agent reply)",
    )
    messages = [{"role": "user", "content": prompt}]

    try:
        response = await llm_call_async(url, model, messages, headers=headers, timeout=20)
        verdict = response.strip().strip("'\"").lower() if response else ""
        if verdict == "failure":
            return ("failure", f"LLM evaluation flagged failure: {response.strip()}")
    except Exception as e:
        # Best-effort check: an evaluator outage must not fail the turn.
        logger.warning(f"Tier 2 LLM self-eval failed: {e}")

    return ("ok", None)
|
||||
|
||||
|
||||
|
||||
async def escalate_and_learn(
|
||||
user_request: str,
|
||||
tool_results: List[Dict[str, Any]],
|
||||
@@ -524,34 +459,15 @@ def maybe_escalate(
|
||||
|
||||
# Gate 3: regex eval — only escalate on detected failure.
|
||||
status, reason = evaluate_turn_regex(tool_results, agent_reply)
|
||||
if status == "failure":
|
||||
if status != "failure":
|
||||
return None
|
||||
|
||||
# Fire async — don't block the user's chat.
|
||||
return asyncio.create_task(
|
||||
escalate_and_learn(user_request, tool_results, agent_reply, reason or "", owner),
|
||||
name="teacher_escalation",
|
||||
)
|
||||
|
||||
# Gate 4: Tier 2 LLM self-evaluation requires teacher_tier2_enabled
|
||||
if not get_setting("teacher_tier2_enabled", False):
|
||||
return None
|
||||
|
||||
# Tier 2: LLM self-evaluation background task
|
||||
async def evaluate_and_maybe_escalate():
|
||||
llm_status, llm_reason = await evaluate_turn_llm(
|
||||
user_request=user_request,
|
||||
tool_results=tool_results,
|
||||
agent_reply=agent_reply,
|
||||
student_endpoint_url=student_endpoint_url,
|
||||
owner=owner,
|
||||
)
|
||||
if llm_status == "failure":
|
||||
await escalate_and_learn(user_request, tool_results, agent_reply, llm_reason or "", owner)
|
||||
|
||||
return asyncio.create_task(
|
||||
evaluate_and_maybe_escalate(),
|
||||
name="teacher_escalation_tier2",
|
||||
)
|
||||
|
||||
|
||||
# ── Inline teacher takeover (visible in chat stream) ───────────────
|
||||
|
||||
@@ -585,6 +501,10 @@ async def run_teacher_inline(
|
||||
except Exception:
|
||||
return
|
||||
|
||||
status, reason = evaluate_turn_regex(student_tool_events, student_reply)
|
||||
if status != "failure":
|
||||
return
|
||||
|
||||
# Extract original user request — last user-role message
|
||||
user_request = ""
|
||||
for m in reversed(student_messages):
|
||||
@@ -601,21 +521,6 @@ async def run_teacher_inline(
|
||||
)
|
||||
break
|
||||
|
||||
status, reason = evaluate_turn_regex(student_tool_events, student_reply)
|
||||
if status != "failure":
|
||||
# Tier 2: LLM self-evaluation check requires teacher_tier2_enabled
|
||||
if not get_setting("teacher_tier2_enabled", False):
|
||||
return
|
||||
status, reason = await evaluate_turn_llm(
|
||||
user_request=user_request,
|
||||
tool_results=student_tool_events,
|
||||
agent_reply=student_reply,
|
||||
student_endpoint_url=student_endpoint_url,
|
||||
owner=owner,
|
||||
)
|
||||
if status != "failure":
|
||||
return
|
||||
|
||||
# Resolve teacher endpoint
|
||||
try:
|
||||
from src.ai_interaction import _resolve_model
|
||||
|
||||
+3459
-68
File diff suppressed because it is too large
Load Diff
@@ -1,32 +0,0 @@
|
||||
"""Tool implementation package, split by domain (slice 1, #4082/#4071).
|
||||
|
||||
Public tool functions live in domain modules. ``src.tool_implementations``
|
||||
re-exports from here for backward compatibility.
|
||||
"""
|
||||
from src.tools._common import _parse_tool_args # noqa: F401
|
||||
from src.tools.system import ( # noqa: F401
|
||||
do_manage_skills, _skill_dump, do_manage_tasks,
|
||||
do_api_call, do_app_api,
|
||||
)
|
||||
from src.tools.cookbook import ( # noqa: F401
|
||||
do_download_model, do_serve_model, do_list_served_models,
|
||||
do_stop_served_model, do_tail_serve_output, do_list_downloads,
|
||||
do_cancel_download, do_search_hf_models, do_adopt_served_model,
|
||||
do_list_cookbook_servers, do_list_serve_presets, do_serve_preset,
|
||||
do_list_cached_models,
|
||||
_cookbook_servers, _resolve_cookbook_host, _cookbook_env_for_host,
|
||||
_infer_serve_port, _infer_serve_host, _ensure_served_endpoint,
|
||||
_cookbook_register_task, _cookbook_apply_retry_suggestion,
|
||||
_scan_running_model_processes, _cookbook_kill_session,
|
||||
_MODEL_PROCESS_PATTERNS,
|
||||
)
|
||||
from src.tools.search import do_search_chats # noqa: F401
|
||||
from src.tools.notes import do_manage_notes # noqa: F401
|
||||
from src.tools.calendar import do_manage_calendar # noqa: F401
|
||||
from src.tools.image import do_edit_image # noqa: F401
|
||||
from src.tools.research import do_manage_research, do_trigger_research # noqa: F401
|
||||
from src.tools.contacts import do_resolve_contact, do_manage_contact # noqa: F401
|
||||
from src.tools.vault import ( # noqa: F401
|
||||
_load_vault_config, _run_bw,
|
||||
do_vault_search, do_vault_get, do_vault_unlock,
|
||||
)
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Shared helpers used across tool implementation domains.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Domain modules under src/tools/ import from here.
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from core.constants import internal_api_base
|
||||
from src.tool_utils import _parse_tool_args # noqa: F401 — single source of the tool-arg parser; tool_utils is a leaf module (imports nothing from src)
|
||||
|
||||
|
||||
# In-process loopback base for agent tools that call Odysseus's own API
|
||||
# (cookbook state, model serve, gallery, email, calendar). We ride the
|
||||
# per-process internal token so require_admin lets us through. See
|
||||
# core/middleware.py. Resolution (override / APP_PORT / 7000) lives in
|
||||
# core.constants.internal_api_base().
|
||||
_INTERNAL_BASE = internal_api_base()
|
||||
|
||||
|
||||
def _internal_headers(owner: Optional[str] = None) -> Dict[str, str]:
    """Build the headers for an in-process call to the app's own API.

    The result always carries the internal tool token under
    ``INTERNAL_TOOL_HEADER``; ``X-Odysseus-Owner`` is added only when
    ``owner`` is truthy.
    """
    from core.middleware import INTERNAL_TOOL_HEADER, INTERNAL_TOOL_TOKEN

    if not owner:
        return {INTERNAL_TOOL_HEADER: INTERNAL_TOOL_TOKEN}
    return {INTERNAL_TOOL_HEADER: INTERNAL_TOOL_TOKEN, "X-Odysseus-Owner": owner}
|
||||
@@ -1,522 +0,0 @@
|
||||
"""Calendar-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the manage_calendar tool (CalDAV-backed event CRUD).
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def do_manage_calendar(content: str, owner: Optional[str] = None) -> Dict:
    """Handle manage_calendar tool calls: list/create/update/delete calendar events (local SQLite).

    Args:
        content: Raw tool arguments, parsed with ``_parse_tool_args``. The
            ``action`` key selects the operation: ``list_calendars``,
            ``list_events`` (the default), ``create_event``, ``update_event``
            or ``delete_event``. Hyphens are treated as underscores and the
            short forms ``create``/``update``/``delete``/``list`` are accepted.
            A payload of the form ``{"events": [...]}`` without an ``action``
            is treated as a batch of ``create_event`` calls.
        owner: When not ``None``, calendar, event and reminder-note queries
            are filtered to this owner.

    Returns:
        A dict that always has ``exit_code`` (0 on success, 1 on failure) and
        either ``response`` (human-readable text, plus action-specific keys
        such as ``events``, ``calendars`` or ``uid``) or ``error``.
    """
    from datetime import datetime, timedelta, timezone
    from core.database import SessionLocal, CalendarCal, CalendarEvent, Note
    from routes.calendar_routes import (
        _ensure_default_calendar,
        _parse_dt,
        _parse_dt_pair,
        parse_due_for_user,
        _resolve_base_uid,
        _push_caldav_event_after_commit,
        _record_caldav_delete_tombstone,
    )
    import uuid as _uuid

    def _utcnow_naive() -> datetime:
        # Same value as the deprecated datetime.utcnow(): current UTC time as
        # a naive datetime, so it stays comparable with the naive datetimes
        # used everywhere else in this function.
        return datetime.now(timezone.utc).replace(tzinfo=None)

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    # ── Batch normalization ──
    # Some models (e.g. deepseek-v4-flash) emit {"events": [{...}, ...]}
    # instead of individual create_event calls. Iterate and create each.
    if isinstance(args.get("events"), list) and not args.get("action"):
        results = []
        for ev in args["events"]:
            if not isinstance(ev, dict):
                continue
            # Normalize start/end from {dateTime: "..."} object to flat string
            for field, target in [("start", "dtstart"), ("end", "dtend")]:
                val = ev.pop(field, None)
                if val and target not in ev:
                    ev[target] = val.get("dateTime", val) if isinstance(val, dict) else val
            ev.setdefault("action", "create_event")
            r = await do_manage_calendar(json.dumps(ev), owner=owner)
            results.append(r)
        created = [r for r in results if r.get("exit_code") == 0 and not r.get("error")]
        failed = [r for r in results if r.get("error")]

        if not results:
            return {"error": "No events to create", "exit_code": 1}

        # Surface both successes and failures
        parts = []
        if created:
            summaries = [r.get("response", "") for r in created]
            parts.append(f"Created {len(created)} event(s):\n" + "\n".join(summaries))
        if failed:
            first_error = failed[0].get("error", "Unknown error")
            parts.append(f"Failed to create {len(failed)} event(s). First error: {first_error}")

        response = "\n\n".join(parts)
        # Non-zero exit code for partial or total failure
        exit_code = 0 if not failed else 1
        return {"response": response, "exit_code": exit_code, "created_count": len(created), "failed_count": len(failed)}

    # Normalize action — some models emit hyphens ("list-calendars") instead
    # of underscores. Treat them as equivalent so we don't bounce a
    # cosmetic typo back to the model and waste a round-trip. Also accept
    # short forms (`create`, `update`, `delete`) as aliases for the
    # full `<verb>_event` names — models keep emitting the short forms.
    action = (args.get("action") or "list_events").replace("-", "_").strip().lower()
    _ACTION_ALIASES = {
        "create": "create_event",
        "update": "update_event",
        "delete": "delete_event",
        "list": "list_events",
    }
    action = _ACTION_ALIASES.get(action, action)
    db = SessionLocal()

    def _calendar_query():
        # Calendars visible to this call (owner-scoped when owner is given).
        q = db.query(CalendarCal)
        if owner is not None:
            q = q.filter(CalendarCal.owner == owner)
        return q

    def _event_query():
        # Events joined to their calendar so the owner filter can apply.
        q = db.query(CalendarEvent).join(CalendarCal)
        if owner is not None:
            q = q.filter(CalendarCal.owner == owner)
        return q

    def _reminder_minutes(raw_args) -> Optional[int]:
        # Resolve "how many minutes before the event" from any of the accepted
        # argument aliases, falling back to the description when it mentions
        # a reminder/alarm. Returns None when no reminder was requested.
        raw = (
            raw_args.get("reminder_minutes")
            or raw_args.get("remind_before_minutes")
            or raw_args.get("alarm_minutes")
            or raw_args.get("reminder")
            or raw_args.get("alarm")
        )
        if raw in (None, ""):
            desc = str(raw_args.get("description") or "")
            if re.search(r"\b(remind|reminder|alarm)\b", desc, re.I):
                raw = desc
        if raw in (None, "", False):
            return None
        if raw is True:
            return 10
        if isinstance(raw, (int, float)):
            return max(0, int(raw))
        text = str(raw).strip().lower()
        if text in {"none", "no", "off", "false"}:
            return None
        m = re.search(r"(\d+)\s*(?:minutes?|mins?|m)\b", text)
        if m:
            return max(0, int(m.group(1)))
        m = re.search(r"(\d+)\s*(?:hours?|hrs?|h)\b", text)
        if m:
            return max(0, int(m.group(1)) * 60)
        if text.isdigit():
            return max(0, int(text))
        return None

    def _event_description(raw_args, minutes_before: Optional[int]) -> str:
        # Drop a description that is nothing but the reminder instruction
        # (e.g. "remind 10 minutes ..."); it was consumed as the reminder.
        desc = str(raw_args.get("description", "") or "")
        if minutes_before is None:
            return desc
        reminder_only = re.compile(
            r"^\s*(?:remind(?:er)?|alarm)\s*:?\s*\d+\s*"
            r"(?:minutes?|mins?|m|hours?|hrs?|h)\b.*$",
            re.I,
        )
        return "" if reminder_only.match(desc) else desc

    def _parse_event_dt(raw: str) -> tuple[datetime, bool]:
        """Parse agent event datetimes in the user's timezone when available."""
        return _parse_dt_pair(parse_due_for_user(raw))

    def _first_nonempty_arg(*names: str):
        # First argument among `names` whose value is neither None nor "".
        for name in names:
            value = args.get(name)
            if value not in (None, ""):
                return value
        return None

    def _create_calendar_reminder(summary: str, location: str, dtstart: datetime,
                                  all_day: bool, minutes_before: int,
                                  is_utc: bool = False) -> tuple[Optional[str], Optional[str]]:
        # Adds a "Reminder: ..." todo Note to the session (the caller commits).
        # Returns (note_id, skipped_reason); an existing matching note is
        # reported with its id and a "duplicate" reason.
        remind_at = dtstart - timedelta(minutes=minutes_before)
        now = _utcnow_naive() if is_utc else datetime.now()
        if dtstart <= now:
            return None, "event already passed"
        if remind_at <= now:
            # If the requested "before" time already passed but the event is
            # still upcoming, create an immediate Note reminder instead of
            # silently dropping it.
            remind_at = now
        start_fmt = dtstart.strftime("%a %b %d") if all_day else dtstart.strftime("%a %b %d %H:%M")
        loc = f" @ {location}" if location else ""
        text = f"{summary}{loc} — {start_fmt}"
        due_date = remind_at.isoformat() + ("Z" if is_utc else "")
        expected_title = f"Reminder: {summary}"
        existing_q = db.query(Note).filter(
            Note.archived == False,  # noqa: E712
            Note.due_date == due_date,
        )
        if owner is not None:
            existing_q = existing_q.filter(Note.owner == owner)
        target_title = re.sub(r"^\s*reminder\s*:\s*", "", expected_title.strip().lower())
        for existing in existing_q.limit(25).all():
            existing_title = re.sub(r"^\s*reminder\s*:\s*", "", (existing.title or "").strip().lower())
            if existing_title == target_title:
                return existing.id, "duplicate reminder already exists"
        note = Note(
            id=str(_uuid.uuid4()),
            owner=owner,
            title=expected_title,
            items=json.dumps([{"text": text, "done": False, "checked": False}]),
            note_type="todo",
            label="calendar",
            due_date=due_date,
            source="calendar",
        )
        db.add(note)
        return note.id, None

    try:
        if action == "list_calendars":
            _ensure_default_calendar(db, owner)
            cals = _calendar_query().all()
            result = [{"name": c.name, "href": c.id} for c in cals]
            if result:
                lines = [f"Found {len(result)} calendar(s):"]
                for c in result:
                    lines.append(f"- {c['name']} ({c['href'][:8]})")
                response_text = "\n".join(lines)
            else:
                response_text = "No calendars found."
            return {"response": response_text, "calendars": result, "exit_code": 0}

        elif action == "list_events":
            try:
                start_raw = _first_nonempty_arg(
                    "start", "start_date", "range_start", "from", "dtstart", "since"
                )
                end_raw = _first_nonempty_arg(
                    "end", "end_date", "range_end", "to", "dtend", "until"
                )
                if start_raw:
                    start_dt = _parse_dt(start_raw)
                else:
                    start_dt = _utcnow_naive().replace(hour=0, minute=0, second=0, microsecond=0)
                if end_raw:
                    end_dt = _parse_dt(end_raw)
                else:
                    end_dt = start_dt + timedelta(days=14)
            except ValueError as e:
                return {"error": f"Invalid date format: {e}", "exit_code": 1}

            # An empty or inverted range is widened to a single day.
            if end_dt <= start_dt:
                end_dt = start_dt + timedelta(days=1)

            q = _event_query().filter(
                CalendarEvent.dtstart < end_dt,
                CalendarEvent.dtend > start_dt,
                CalendarEvent.status != "cancelled",
            )
            calendar_filter = args.get("calendar")
            if calendar_filter:
                q = q.filter(
                    (CalendarEvent.calendar_id == calendar_filter) |
                    (CalendarCal.name == calendar_filter)
                )
            rows = q.order_by(CalendarEvent.dtstart).all()
            events = []
            for ev in rows:
                if ev.all_day:
                    s, e = ev.dtstart.strftime("%Y-%m-%d"), ev.dtend.strftime("%Y-%m-%d")
                else:
                    suffix = "Z" if getattr(ev, "is_utc", False) else ""
                    s, e = ev.dtstart.isoformat() + suffix, ev.dtend.isoformat() + suffix
                events.append({
                    "uid": ev.uid, "summary": ev.summary or "", "dtstart": s, "dtend": e,
                    "all_day": ev.all_day, "description": ev.description or "",
                    "location": ev.location or "",
                    "calendar": ev.calendar.name if ev.calendar else "",
                    "calendar_href": ev.calendar_id,
                    "event_type": ev.event_type or "",
                    "importance": ev.importance or "normal",
                })
            if not events:
                response_text = f"No events between {start_dt.date().isoformat()} and {end_dt.date().isoformat()}."
            else:
                lines = [f"Found {len(events)} event(s) between {start_dt.date().isoformat()} and {end_dt.date().isoformat()}:"]
                for ev in events:
                    when = ev["dtstart"]
                    when_str = f"{when} (all day)" if ev.get("all_day") else f"{when} -> {ev.get('dtend', '')}"
                    # Clickable anchor — opens the calendar on the event's day.
                    line = f"- {when_str}: [{ev['summary']}](#event-{ev['uid']})"
                    if ev.get("event_type"):
                        line += f" #{ev['event_type']}"
                    if ev.get("importance") and ev["importance"] != "normal":
                        line += f" !{ev['importance']}"
                    if ev.get("location"):
                        line += f" @ {ev['location']}"
                    if ev.get("calendar"):
                        line += f" ({ev['calendar']})"
                    if ev.get("description"):
                        desc = ev["description"].strip().replace("\n", " ")
                        if len(desc) > 120:
                            desc = desc[:117] + "..."
                        line += f"\n {desc}"
                    lines.append(line)
                response_text = "\n".join(lines)
            return {"response": response_text, "events": events, "exit_code": 0}

        elif action == "create_event":
            summary = args.get("summary")
            # Accept the various names models like to use for the start
            # field: dtstart (canonical), start, start_time, when.
            dtstart_str = (args.get("dtstart") or args.get("start")
                           or args.get("start_time") or args.get("when"))
            if not summary or not dtstart_str:
                return {"error": "summary and dtstart are required", "exit_code": 1}

            # Accept either an href OR a calendar name/short-id like "Main"
            # or "62e545d8" — saves the model from having to memorize hrefs
            # after a `list_calendars` call returned short prefixes.
            cal_href = args.get("calendar_href") or args.get("calendar")
            cal = None
            if cal_href:
                cal = (_calendar_query()
                       .filter(CalendarCal.id == cal_href)
                       .first())
                if not cal:
                    # Try by name (case-insensitive) or by short-id prefix
                    cal = (_calendar_query()
                           .filter(CalendarCal.name.ilike(cal_href))
                           .first())
                if not cal:
                    cal = (_calendar_query()
                           .filter(CalendarCal.id.like(f"{cal_href}%"))
                           .first())
            if not cal:
                cal = _ensure_default_calendar(db, owner)

            all_day = bool(args.get("all_day", False))
            try:
                dtstart, dtstart_is_utc = _parse_event_dt(dtstart_str)
            except ValueError as e:
                return {"error": f"Could not parse dtstart {dtstart_str!r}: {e}", "exit_code": 1}
            dtend_raw = args.get("dtend") or args.get("end") or args.get("end_time")
            if dtend_raw:
                try:
                    dtend, dtend_is_utc = _parse_event_dt(dtend_raw)
                    dtstart_is_utc = dtstart_is_utc or dtend_is_utc
                except ValueError as e:
                    return {"error": f"Could not parse dtend {dtend_raw!r}: {e}", "exit_code": 1}
            else:
                # Support duration: "1h", "30m", "90min", "1hr30m".
                # str() guards against models sending a bare number, which
                # previously crashed on .strip(); a unit-less value matches
                # neither pattern and falls through to the default length.
                dur = str(args.get("duration") or "").strip().lower()
                delta = None
                if dur:
                    h = re.search(r'(\d+)\s*(?:h|hr|hours?)', dur)
                    m = re.search(r'(\d+)\s*(?:m|min|minutes?)', dur)
                    secs = (int(h.group(1)) * 3600 if h else 0) + (int(m.group(1)) * 60 if m else 0)
                    if secs > 0:
                        delta = timedelta(seconds=secs)
                if delta is not None:
                    dtend = dtstart + delta
                elif all_day:
                    dtend = dtstart + timedelta(days=1)
                else:
                    dtend = dtstart + timedelta(hours=1)

            # Dedup: if a non-cancelled event with the same title + start time already
            # exists, return its UID instead of creating a fresh copy. Prevents the
            # email triage from multiplying events when several emails reference the
            # same meeting. Compare case-insensitively since LLM-extracted titles
            # can vary in capitalisation.
            from sqlalchemy import func as _func
            existing = (
                _event_query()
                .filter(
                    CalendarEvent.dtstart == dtstart,
                    CalendarEvent.status != "cancelled",
                    _func.lower(CalendarEvent.summary) == summary.lower(),
                )
                .first()
            )
            if existing is not None:
                reminder_note_id = None
                reminder_skipped_reason = None
                minutes_before = _reminder_minutes(args)
                if minutes_before is not None:
                    reminder_note_id, reminder_skipped_reason = _create_calendar_reminder(
                        existing.summary or summary,
                        existing.location or "",
                        existing.dtstart,
                        existing.all_day,
                        minutes_before,
                        bool(existing.is_utc),
                    )
                    if reminder_note_id:
                        db.commit()
                reminder_text = ""
                if minutes_before is not None:
                    reminder_text = (
                        f"; reminder set {minutes_before} min before"
                        if reminder_note_id
                        else f"; reminder not set ({reminder_skipped_reason or 'reminder time already passed'})"
                    )
                return {
                    "response": (
                        f"Event already exists: '{summary}' on {dtstart_str}"
                        + reminder_text
                    ),
                    "uid": existing.uid,
                    "reminder_note_id": reminder_note_id,
                    "reminder_skipped_reason": reminder_skipped_reason,
                    "duplicate": True,
                    "exit_code": 0,
                }

            # Optional tag/category and importance — friendly aliases.
            event_type = (args.get("event_type") or args.get("tag")
                          or args.get("category") or args.get("type") or "") or None
            importance = args.get("importance") or "normal"
            minutes_before = _reminder_minutes(args)

            uid = str(_uuid.uuid4())
            ev = CalendarEvent(
                uid=uid, calendar_id=cal.id, summary=summary,
                description=_event_description(args, minutes_before),
                location=args.get("location", "") or "",
                dtstart=dtstart, dtend=dtend, all_day=all_day,
                is_utc=dtstart_is_utc and not all_day,
                rrule=args.get("rrule", "") or "",
                event_type=event_type,
                importance=importance,
                caldav_sync_pending="create" if cal.source == "caldav" else None,
            )
            db.add(ev)
            reminder_note_id = None
            reminder_skipped_reason = None
            if minutes_before is not None:
                reminder_note_id, reminder_skipped_reason = _create_calendar_reminder(
                    summary,
                    args.get("location", "") or "",
                    dtstart,
                    all_day,
                    minutes_before,
                    dtstart_is_utc and not all_day,
                )
            db.commit()
            if cal.source == "caldav":
                await _push_caldav_event_after_commit(owner, uid, "create")
            tag_blurb = f" [{event_type}]" if event_type else ""
            if minutes_before is None:
                reminder_blurb = ""
            elif reminder_note_id:
                reminder_blurb = f" with reminder {minutes_before} min before"
            else:
                reminder_blurb = f" without reminder ({reminder_skipped_reason or 'reminder time already passed'})"
            # Return a clickable anchor so the agent can surface a link
            # that opens the calendar on that day. See the markdown
            # anchor convention ([Name](#event-<uid>)).
            return {
                "response": f"Created event [{summary}](#event-{uid}){tag_blurb} on {dtstart_str}{reminder_blurb}",
                "uid": uid,
                "anchor": f"[{summary}](#event-{uid})",
                "reminder_note_id": reminder_note_id,
                "reminder_skipped_reason": reminder_skipped_reason,
                "exit_code": 0,
            }

        elif action == "update_event":
            uid = args.get("uid")
            if not uid:
                return {"error": "uid is required", "exit_code": 1}
            try:
                base_uid = _resolve_base_uid(uid)
            except ValueError as e:
                return {"error": str(e), "exit_code": 1}
            ev = _event_query().filter(CalendarEvent.uid == base_uid).first()
            if not ev:
                return {"error": f"Event {uid} not found", "exit_code": 1}
            if args.get("summary") is not None:
                ev.summary = args["summary"]
            if args.get("description") is not None:
                ev.description = args["description"]
            if args.get("location") is not None:
                ev.location = args["location"]
            if args.get("dtstart") is not None:
                # Anchor naive/natural-language input to the USER's timezone and
                # refresh is_utc, exactly like create_event. Parsing with the
                # raw server-local _parse_dt here (and never touching is_utc)
                # silently shifted an updated event by the user's UTC offset.
                _eff_all_day = (
                    args["all_day"] if args.get("all_day") is not None else ev.all_day
                )
                ev.dtstart, _su = _parse_event_dt(args["dtstart"])
                ev.is_utc = bool(_su and not _eff_all_day)
            if args.get("dtend") is not None:
                # The UTC flag of the end time is not used on update.
                ev.dtend, _ = _parse_event_dt(args["dtend"])
            if args.get("all_day") is not None:
                ev.all_day = args["all_day"]
            # Tag/category + importance updates (any of these aliases).
            _tag = (args.get("event_type") or args.get("tag")
                    or args.get("category") or args.get("type"))
            if _tag is not None:
                ev.event_type = _tag or None
            if args.get("importance") is not None:
                ev.importance = args["importance"]
            is_caldav = ev.calendar and ev.calendar.source == "caldav"
            if is_caldav:
                ev.caldav_sync_pending = "update"
            db.commit()
            if is_caldav:
                await _push_caldav_event_after_commit(owner, base_uid, "update")
            return {"response": f"Updated event {uid}", "exit_code": 0}

        elif action == "delete_event":
            uid = args.get("uid")
            if not uid:
                return {"error": "uid is required", "exit_code": 1}
            try:
                base_uid = _resolve_base_uid(uid)
            except ValueError as e:
                return {"error": str(e), "exit_code": 1}
            ev = _event_query().filter(CalendarEvent.uid == base_uid).first()
            if not ev:
                return {"error": f"Event {uid} not found", "exit_code": 1}
            is_caldav = ev.calendar and ev.calendar.source == "caldav" and ev.remote_href
            if is_caldav:
                # Record the tombstone before the row is deleted.
                _record_caldav_delete_tombstone(db, ev, owner)
            db.delete(ev)
            db.commit()
            if is_caldav:
                await _push_caldav_event_after_commit(owner, base_uid, "delete")
            return {"response": f"Deleted event {uid}", "exit_code": 0}

        else:
            return {
                "error": f"Unknown action: {action}. Use list_events, create_event, update_event, delete_event, list_calendars",
                "exit_code": 1,
            }

    except Exception as e:
        db.rollback()
        # logger.exception keeps the traceback, which the previous
        # logger.error(f"...") call discarded.
        logger.exception("manage_calendar error: %s", e)
        return {"error": str(e), "exit_code": 1}
    finally:
        db.close()
|
||||
@@ -1,148 +0,0 @@
|
||||
"""Contacts-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the resolve_contact and manage_contact (CardDAV CRUD) tools.
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
``_INTERNAL_BASE`` still lives in tool_implementations.py and is pulled
|
||||
back function-locally where needed.
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
|
||||
async def do_resolve_contact(content: str, owner: Optional[str] = None) -> Dict:
    """Look up a contact by name across CardDAV and email history.

    Sources are consulted in order and merged into one result keyed by email
    address (or by phone number for CardDAV contacts that have no email):

    1. CardDAV contacts, fetched in-process via ``contacts_routes``.
    2. Email history, via the internal ``/api/email/resolve-contact`` route.

    An entry found in CardDAV is never overwritten by an email-history hit.
    Both lookups are best-effort: a failure in one is logged at debug level
    and the other source still contributes.

    Args:
        content: JSON tool arguments; ``name`` is required.
        owner: Calling user. Accepted for signature parity with the other
            tools; this function does not read it.

    Returns:
        ``{"output": <text>, "exit_code": 0}`` with one line per match (or a
        "no contacts" message), or ``{"error": ..., "exit_code": 1}`` when the
        arguments are not valid JSON or ``name`` is missing.
    """
    import logging

    import httpx
    from src.tool_implementations import _INTERNAL_BASE  # shared constant, still lives in the facade

    log = logging.getLogger(__name__)

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}
    name = args.get("name", "")
    if not name:
        return {"error": "name is required", "exit_code": 1}

    contacts = {}  # email_or_phone -> {name, source, phone?}

    # 1. CardDAV (Radicale) — structured contacts. Call in-process: a
    # server-side httpx GET to /api/contacts/search carries no session
    # cookie and would 401 under require_user.
    try:
        import asyncio
        from routes import contacts_routes as cc

        all_contacts = await asyncio.to_thread(cc._fetch_contacts)
        q = name.lower()
        for c in (all_contacts or []):
            hay_name = (c.get("name") or "").lower()
            # `or []` (not a .get default) so an explicit None emails field
            # cannot raise and abort the scan of every remaining contact.
            emails = c.get("emails") or []
            match = q in hay_name or any(q in (e or "").lower() for e in emails)
            if not match:
                continue
            has_email = False
            for email in emails:
                email = (email or "").strip().lower()
                if email and "@" in email:
                    contacts[email] = {"name": c.get("name") or email, "source": "contacts"}
                    has_email = True
            # Fall back to phone numbers when the contact has no email address
            if not has_email:
                for phone in (c.get("phones") or []):
                    phone = (phone or "").strip()
                    if phone:
                        contacts[phone] = {"name": c.get("name") or phone, "source": "contacts", "phone": phone}
    except Exception:
        # Best-effort source: keep going, but leave a trace for debugging.
        log.debug("resolve_contact: CardDAV lookup failed", exc_info=True)

    async with httpx.AsyncClient(timeout=30) as client:
        # 2. Email history (sent/received)
        try:
            resp = await client.get(f"{_INTERNAL_BASE}/api/email/resolve-contact", params={"name": name})
            if resp.status_code == 200:
                for c in (resp.json().get("contacts") or []):
                    email = (c.get("email") or "").strip().lower()
                    if email and email not in contacts:
                        contacts[email] = {"name": c.get("name") or email, "source": "email history"}
        except Exception:
            # Best-effort source: keep going, but leave a trace for debugging.
            log.debug("resolve_contact: email-history lookup failed", exc_info=True)

    if not contacts:
        return {"output": f"No contacts found matching '{name}'.", "exit_code": 0}

    lines = [f"Contacts matching '{name}':"]
    for key, info in contacts.items():
        if info.get("phone"):
            lines.append(f"- {info['name']} — phone: {info['phone']} ({info['source']})")
        else:
            lines.append(f"- {info['name']} <{key}> ({info['source']})")
    return {"output": "\n".join(lines), "exit_code": 0}
|
||||
|
||||
|
||||
async def do_manage_contact(content: str, owner: Optional[str] = None) -> Dict:
    """Add / update / delete / list CardDAV contacts.

    Calls the contacts helpers IN-PROCESS rather than over HTTP — a
    server-side httpx call to /api/contacts/* carries no session cookie and
    would be rejected by require_user (401), so the tool would see zero
    contacts even though the browser-side UI works fine.

    Args:
        content: JSON tool arguments. ``action`` is one of ``list``, ``add``
            (needs ``email``, optional ``name``), ``update``/``edit`` (needs
            ``uid`` plus ``name`` and/or ``emails``/``email``, optional
            ``phones``) or ``delete`` (needs ``uid``).
        owner: Calling user. Accepted for signature parity with the other
            tools; this function does not read it.

    Returns:
        ``{"output": ..., "exit_code": 0 or 1}`` for a completed operation, or
        ``{"error": ..., "exit_code": 1}`` for invalid arguments, an unknown
        action, or an exception raised by the contacts helpers.
    """
    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}
    action = (args.get("action") or "").strip().lower()
    try:
        from routes import contacts_routes as cc
    except Exception as e:
        return {"error": f"Contacts module unavailable: {e}", "exit_code": 1}
    # The contacts helpers are sync (httpx blocking calls to CardDAV) — run
    # them in a thread so we don't block the event loop.
    import asyncio

    def _as_str_list(value) -> list:
        # Normalise a tool argument to a list of non-empty stripped strings.
        # A bare scalar becomes a one-element list; iterating a plain string
        # would otherwise yield one "address" per character.
        if value is None:
            return []
        if not isinstance(value, (list, tuple, set)):
            value = [value]
        cleaned = (str(v).strip() for v in value if v is not None)
        return [v for v in cleaned if v]

    try:
        if action == "list":
            rows = await asyncio.to_thread(cc._fetch_contacts, True)
            if not rows:
                return {"output": "No contacts.", "exit_code": 0}
            lines = [f"{len(rows)} contacts:"]
            for c in rows:
                em = ", ".join(c.get("emails") or [])
                lines.append(f"- {c.get('name') or '(no name)'} <{em}> [uid={c.get('uid','')}]")
            return {"output": "\n".join(lines), "exit_code": 0}

        if action == "add":
            email = (args.get("email") or "").strip()
            if not email:
                return {"error": "email is required for add", "exit_code": 1}
            name = (args.get("name") or "").strip() or email.split("@")[0]
            # Dedupe by email (same as the /add route).
            existing = await asyncio.to_thread(cc._fetch_contacts)
            wanted = email.lower()
            # Tolerate a None contact list, a None emails field and None
            # entries so the dedupe pass cannot abort the add.
            for c in (existing or []):
                if wanted in [(e or "").lower() for e in (c.get("emails") or [])]:
                    return {"output": f"{email} is already a contact ({c.get('name','')}).", "exit_code": 0}
            ok = await asyncio.to_thread(cc._create_contact, name, email)
            return {"output": f"{'Added' if ok else 'Failed to add'} {name} <{email}>.", "exit_code": 0 if ok else 1}

        if action in ("update", "edit"):
            uid = (args.get("uid") or "").strip()
            if not uid:
                return {"error": "uid is required for update (use action=list to find it)", "exit_code": 1}
            name = (args.get("name") or "").strip()
            emails = args.get("emails")
            if emails is None and args.get("email"):
                emails = [args["email"]]
            emails = _as_str_list(emails)
            phones = _as_str_list(args.get("phones"))
            if not name and not emails:
                return {"error": "Provide a name or emails to update", "exit_code": 1}
            if not name and emails:
                name = emails[0].split("@")[0]
            ok = await asyncio.to_thread(cc._update_contact, uid, name, emails, phones)
            return {"output": "Contact updated." if ok else "Update failed.", "exit_code": 0 if ok else 1}

        if action == "delete":
            uid = (args.get("uid") or "").strip()
            if not uid:
                return {"error": "uid is required for delete (use action=list to find it)", "exit_code": 1}
            ok = await asyncio.to_thread(cc._delete_contact, uid)
            return {"output": "Contact deleted." if ok else "Delete failed.", "exit_code": 0 if ok else 1}

        return {"error": f"Unknown action '{action}'. Use list, add, update, or delete.", "exit_code": 1}
    except Exception as e:
        return {"error": f"Contact operation failed: {e}", "exit_code": 1}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,39 +0,0 @@
|
||||
"""Image-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the edit_image (gallery) tool.
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
``_INTERNAL_BASE`` still lives in tool_implementations.py and is pulled back
|
||||
function-locally here.
|
||||
"""
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
|
||||
async def do_edit_image(content: str, owner: Optional[str] = None) -> Dict:
    """Edit a gallery image (upscale, rembg, inpaint, harmonize).

    Posts ``{"image_id", "prompt"?, "scale"?}`` to the internal
    ``/api/gallery/<action>`` route and reports the new image id.

    Args:
        content: JSON tool arguments; ``image_id`` and ``action`` are
            required, ``prompt`` and ``scale`` are forwarded when truthy.
        owner: Calling user. Accepted for signature parity with the other
            tools; this function does not read it.

    Returns:
        ``{"output": ..., "exit_code": 0}`` when the response carries
        ``success`` or ``id``; otherwise ``{"error": ..., "exit_code": 1}``
        (invalid arguments, a failed edit, or any request/decoding error).
    """
    import re

    import httpx
    from src.tool_implementations import _INTERNAL_BASE  # shared constant, still lives in the facade

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}
    image_id = args.get("image_id", "")
    action = args.get("action", "")
    if not image_id or not action:
        return {"error": "image_id and action are required", "exit_code": 1}
    # SECURITY: `action` is agent-supplied and becomes a URL path segment.
    # Allow only a bare token so values such as "../settings" or "x?y=1"
    # cannot redirect the POST to a different internal route.
    if not isinstance(action, str) or not re.fullmatch(r"[A-Za-z0-9_-]+", action):
        return {"error": "Invalid action.", "exit_code": 1}
    payload = {"image_id": image_id}
    if args.get("prompt"):
        payload["prompt"] = args["prompt"]
    if args.get("scale"):
        payload["scale"] = args["scale"]
    try:
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.post(f"{_INTERNAL_BASE}/api/gallery/{action}", json=payload)
            data = resp.json()
            if data.get("success") or data.get("id"):
                return {"output": f"Image edited ({action}). New image ID: {data.get('id', '?')}", "exit_code": 0}
            return {"error": data.get("error", f"{action} failed"), "exit_code": 1}
    except Exception as e:
        return {"error": str(e), "exit_code": 1}
|
||||
@@ -1,254 +0,0 @@
|
||||
"""Notes-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the manage_notes tool (notes + checklists CRUD).
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def do_manage_notes(content: str, owner: Optional[str] = None) -> Dict:
    """Handle manage_notes tool calls: CRUD on notes and checklists.

    Args:
        content: JSON tool arguments. ``action`` selects the operation:
            ``list``, ``add``, ``update``, ``delete`` or ``toggle_item``
            (plus the aliases in ``_NOTE_ACTION_ALIASES``; hyphens are
            treated as underscores). ``update``/``delete``/``toggle_item``
            address a note by ``id``, which may be a prefix of the full id.
        owner: Calling user. When set, queries are restricted to notes owned
            by that user; new notes are stored with this owner.

    Returns:
        ``list`` returns ``{"results": <text>}`` (or a "No notes found."
        response). Other successful actions return
        ``{"response": ..., "exit_code": 0}``; ``add`` additionally returns
        ``note_id``, ``note_title`` and ``open_url`` (or ``duplicate`` when an
        equivalent reminder already exists). Failures return
        ``{"error": ..., "exit_code": 1}``.
    """
    import uuid as _uuid
    from core.database import SessionLocal, Note
    from sqlalchemy.orm.attributes import flag_modified

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    # Action aliases — match what models actually emit. `create` is the most
    # common alternative to `add`. Hyphenated forms also accepted.
    action = (args.get("action") or "").replace("-", "_").strip().lower()
    _NOTE_ACTION_ALIASES = {
        "create": "add",
        "new": "add",
        "save": "add",
        "remind": "add",
        "remove": "delete",
        "remove_item": "toggle_item",
    }
    action = _NOTE_ACTION_ALIASES.get(action, action)
    db = SessionLocal()

    def _norm_note_title(value: str) -> str:
        # Lower-case, drop a leading "reminder:" prefix and collapse runs of
        # whitespace so near-identical reminder titles compare equal.
        text = (value or "").strip().lower()
        text = re.sub(r"^\s*reminder\s*:\s*", "", text)
        return re.sub(r"\s+", " ", text)

    def _note_visible_to_owner(note, owner_value: Optional[str]) -> bool:
        # Empty owner_value is single-user / auth-disabled mode. A real
        # authenticated owner must match exactly; null/empty legacy rows are not
        # shared between accounts.
        if not owner_value:
            return True
        return getattr(note, "owner", None) == owner_value

    def _note_by_prefix(note_id: str):
        # First note whose id starts with `note_id`, scoped to `owner` when set.
        if not note_id:
            return None
        q = db.query(Note).filter(Note.id.startswith(note_id))
        if owner:
            q = q.filter(Note.owner == owner)
        return q.first()

    try:
        if action == "list":
            q = db.query(Note)
            if owner is not None:
                q = q.filter(Note.owner == owner)
            if args.get("label"):
                q = q.filter(Note.label == args["label"])
            show_archived = args.get("archived", False)
            q = q.filter(Note.archived == show_archived)
            notes = q.order_by(Note.pinned.desc(), Note.updated_at.desc()).all()
            if not notes:
                return {"response": "No notes found.", "exit_code": 0}
            lines = []
            for n in notes:
                pin = " [PINNED]" if n.pinned else ""
                typ = " [checklist]" if n.note_type == "checklist" else ""
                lbl = f" #{n.label}" if n.label else ""
                title = n.title or "(untitled)"
                lines.append(f"- [{n.id[:8]}] **{title}**{pin}{typ}{lbl}")
                if n.note_type == "checklist" and n.items:
                    try:
                        items = json.loads(n.items)
                        for i, item in enumerate(items):
                            mark = "x" if item.get("done") else " "
                            lines.append(f"  [{mark}] {i}: {item.get('text', '')}")
                    except (json.JSONDecodeError, TypeError, AttributeError):
                        # Malformed checklist payload (bad JSON, or entries
                        # that are not dicts): skip its items, keep listing.
                        pass
                elif n.content:
                    snippet = n.content[:80].replace("\n", " ")
                    lines.append(f"  {snippet}")
            return {"results": "\n".join(lines)}

        elif action == "add":
            # Accept the various field names models emit: `text` is the most
            # common stand-in for "title or body content" when the model
            # treats the note as a single string. If text was supplied and
            # neither title nor content, use it as the title.
            title = (args.get("title") or "").strip()
            content_raw = args.get("content")
            text_raw = args.get("text") or args.get("body")
            if not title and not content_raw and text_raw:
                title = text_raw.strip()
            elif not content_raw and text_raw:
                content_raw = text_raw
            # Accept both `items` (legacy/internal field) and `checklist_items`
            # (the schema-exposed name used by native function calls). Models
            # following the schema emit `checklist_items`; older code paths
            # and direct API callers still use `items`.
            items_raw = args.get("checklist_items")
            if items_raw is None:
                items_raw = args.get("items")
            items_json = json.dumps(items_raw) if items_raw is not None else None
            note_type = args.get("note_type", "checklist" if items_raw else "note")
            # Accept natural-language due_date ("tomorrow at 1pm") in
            # addition to ISO. Use the user-tz-aware parser so the LLM's
            # naive times ("today at 9pm") are anchored to the USER's clock,
            # not the server's. Returns ISO with explicit offset so frontend
            # `new Date()` resolves the right absolute moment regardless of
            # where the user is.
            due_raw = args.get("due_date")
            due_iso = None
            if due_raw:
                try:
                    from routes.calendar_routes import parse_due_for_user as _pdt_user
                    due_iso = _pdt_user(due_raw)
                except Exception:
                    due_iso = due_raw  # fall through; trust the model
            if due_iso and title:
                # Calendar event reminders are represented as Notes. If the
                # model creates a calendar event with reminder_minutes and then
                # also creates a separate note reminder for the same title/time,
                # keep the existing note so the user gets only one dispatch.
                existing_q = db.query(Note).filter(
                    Note.archived == False,  # noqa: E712
                    Note.due_date == due_iso,
                )
                if owner is not None:
                    existing_q = existing_q.filter(Note.owner == owner)
                target_title = _norm_note_title(title)
                for existing in existing_q.limit(25).all():
                    if _norm_note_title(existing.title or "") == target_title:
                        return {
                            "response": f"Reminder already exists: \"{existing.title or title}\" (id: {existing.id[:8]})",
                            "note_id": existing.id,
                            "duplicate": True,
                            "exit_code": 0,
                        }
            note = Note(
                id=str(_uuid.uuid4()),
                owner=owner,
                title=title,
                content=content_raw,
                items=items_json,
                note_type=note_type,
                color=args.get("color"),
                label=args.get("label"),
                pinned=args.get("pinned", False),
                due_date=due_iso,
                source="agent",
                session_id=args.get("session_id"),
            )
            db.add(note)
            db.commit()
            # Return note_id so the chat-side renderer can build a real
            # "View note" button that opens the notes modal at this id.
            # Previously the create response only included a prose
            # confirmation; the model would type "View note" as a markdown
            # link with no target, leaving the user with a click that
            # did nothing and uncertainty about whether the note was made.
            return {
                "response": f"Note created: \"{title or '(untitled)'}\" (id: {note.id[:8]})",
                "note_id": note.id,
                "note_title": title or "",
                # Two fragment parameters joined by "&": open=notes, note=<id>.
                "open_url": f"/#open=notes&note={note.id}",
                "exit_code": 0,
            }

        elif action == "update":
            note_id = args.get("id", "")
            note = _note_by_prefix(note_id)
            if not note:
                return {"error": f"Note '{note_id}' not found", "exit_code": 1}
            if not _note_visible_to_owner(note, owner):
                return {"error": "Note not found", "exit_code": 1}
            for field in ("title", "content", "note_type", "color", "label"):
                if field in args and args[field] is not None:
                    setattr(note, field, args[field])
            # Parse due_date the same way the `add` action does. The schema
            # advertises natural language ("tomorrow at 9am"), and naive ISO
            # strings need the user's tz offset attached so the frontend's
            # `new Date()` resolves the right absolute moment. Storing the raw
            # value here left updated reminders as unparseable literals that
            # never fired.
            if args.get("due_date") is not None:
                due_raw = args["due_date"]
                try:
                    from routes.calendar_routes import parse_due_for_user as _pdt_user
                    note.due_date = _pdt_user(due_raw)
                except Exception:
                    note.due_date = due_raw  # fall through; trust the model
            new_items = args.get("checklist_items")
            if new_items is None:
                new_items = args.get("items")
            if new_items is not None:
                note.items = json.dumps(new_items)
                flag_modified(note, "items")
            if "pinned" in args:
                note.pinned = args["pinned"]
            if "archived" in args:
                note.archived = args["archived"]
            db.commit()
            return {"response": f"Note updated: \"{note.title or '(untitled)'}\"", "exit_code": 0}

        elif action == "delete":
            note_id = args.get("id", "")
            note = _note_by_prefix(note_id)
            if not note:
                return {"error": f"Note '{note_id}' not found", "exit_code": 1}
            if not _note_visible_to_owner(note, owner):
                return {"error": "Note not found", "exit_code": 1}
            title = note.title
            db.delete(note)
            db.commit()
            return {"response": f"Deleted note: \"{title or '(untitled)'}\"", "exit_code": 0}

        elif action == "toggle_item":
            note_id = args.get("id", "")
            # Models frequently send the index as a string ("1"); coerce it so
            # the range check below compares integers.
            try:
                index = int(args.get("index", 0))
            except (TypeError, ValueError):
                return {"error": "index must be an integer", "exit_code": 1}
            note = _note_by_prefix(note_id)
            if not note:
                return {"error": f"Note '{note_id}' not found", "exit_code": 1}
            if not _note_visible_to_owner(note, owner):
                return {"error": "Note not found", "exit_code": 1}
            if not note.items:
                return {"error": "Note has no checklist items", "exit_code": 1}
            items = json.loads(note.items)
            if index < 0 or index >= len(items):
                return {"error": f"Item index {index} out of range (0-{len(items)-1})", "exit_code": 1}
            items[index]["done"] = not items[index].get("done", False)
            note.items = json.dumps(items)
            flag_modified(note, "items")
            db.commit()
            mark = "done" if items[index]["done"] else "undone"
            return {"response": f"Item '{items[index].get('text', '')}' marked {mark}", "exit_code": 0}

        else:
            return {"error": f"Unknown action: {action}. Use list/add/update/delete/toggle_item", "exit_code": 1}
    except Exception as e:
        # Discard any half-applied changes before the session is closed.
        db.rollback()
        logger.error(f"manage_notes error: {e}")
        return {"error": str(e), "exit_code": 1}
    finally:
        db.close()
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Research-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the manage_research (library CRUD) and trigger_research (live job)
|
||||
tools.
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
``_internal_headers`` and ``_INTERNAL_BASE`` still live in
|
||||
tool_implementations.py and are pulled back function-locally where needed.
|
||||
"""
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.constants import DEEP_RESEARCH_DIR
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
|
||||
async def do_manage_research(content: str, owner: Optional[str] = None) -> Dict:
    """List, read/open, or delete saved deep-research results from the Library.

    Args (JSON): {"action": "list|read|delete", "id": "<id>", "search": "..."}.
    Research is stored as data/deep_research/<id>.json (query, summary, sources).

    ``content`` that is not a JSON object is treated as empty arguments, which
    selects the default ``list`` action. ``session_id`` and ``research_id``
    are accepted as aliases for ``id``; ``open``/``view``/``get`` are aliases
    for ``read``. ``owner`` is accepted for signature parity with the other
    tools and is not read here.

    Returns ``{"output": ..., "exit_code": 0}`` on success and
    ``{"error": ..., "exit_code": 1}`` on failure.
    """
    import json as _json
    from pathlib import Path as _Path

    try:
        args = _parse_tool_args(content) if content.strip().startswith("{") else {}
    except ValueError:
        args = {}
    if not isinstance(args, dict):
        args = {}
    # str(...) so a numeric id or action cannot crash on .strip()/.lower().
    action = str(args.get("action") or "list").strip().lower()
    rid = str(args.get("id") or args.get("session_id") or args.get("research_id") or "").strip()
    data_dir = _Path(DEEP_RESEARCH_DIR)

    # SECURITY: the research id is interpolated straight into a filesystem
    # path (data/deep_research/<rid>.json) for read AND delete. Without this
    # gate an agent-supplied id like "../settings" or "../../etc/passwd"
    # escapes the research dir — reading exfiltrates arbitrary *.json into
    # chat, deleting unlinks arbitrary *.json on disk. Allow only a bare
    # token (research session ids are hex/uuid/slug — no separators).
    if rid and not re.fullmatch(r"[A-Za-z0-9_-]+", rid):
        return {"error": "Invalid research id.", "exit_code": 1}

    def _load(p):
        # Parsed JSON object from `p`, or None when the file is unreadable,
        # not valid JSON, or valid JSON that is not an object.
        try:
            loaded = _json.loads(p.read_text(encoding="utf-8"))
        except Exception:
            return None
        return loaded if isinstance(loaded, dict) else None

    def _sources(d):
        # The record's sources list; anything that is not a list counts as none.
        found = d.get("sources")
        return found if isinstance(found, list) else []

    if action in ("read", "open", "view", "get"):
        if not rid:
            return {"error": "Provide the research id (from action='list').", "exit_code": 1}
        p = data_dir / f"{rid}.json"
        if not p.exists():
            return {"error": f"Research '{rid}' not found.", "exit_code": 1}
        d = _load(p) or {}
        summary = d.get("result") or d.get("raw_report") or d.get("summary") or d.get("report") or "(no report body)"
        srcs = [s for s in _sources(d) if isinstance(s, dict)]
        out = f"# {d.get('query', '(untitled)')}\n\n{summary}"
        if srcs:
            out += "\n\nSources:\n" + "\n".join(
                f"- {s.get('title') or s.get('url', '')}: {s.get('url', '')}" for s in srcs[:30]
            )
        return {"output": out[:16000], "exit_code": 0}

    if action == "delete":
        if not rid:
            return {"error": "Provide the research id to delete (from action='list').", "exit_code": 1}
        p = data_dir / f"{rid}.json"
        if p.exists():
            try:
                p.unlink()
            except Exception as e:
                return {"error": f"Failed to delete: {e}", "exit_code": 1}
            return {"output": f"Deleted research '{rid}'.", "exit_code": 0}
        return {"error": f"Research '{rid}' not found.", "exit_code": 1}

    # default: list — clickable [query](#research-<id>) rows, most-recent first
    search = str(args.get("search") or "").lower()
    items = []
    if data_dir.exists():
        for p in data_dir.glob("*.json"):
            d = _load(p)
            if not d:
                continue
            # A null or non-string query must not break the whole listing.
            q = d.get("query") or ""
            if not isinstance(q, str):
                q = str(q)
            if search and search not in q.lower():
                continue
            items.append((d.get("completed_at", 0) or 0, p.stem, q, len(_sources(d))))
    items.sort(reverse=True)
    if not items:
        return {"output": "No research found in the library." + (f" (search: {search})" if search else ""), "exit_code": 0}
    rows = "\n".join(f"- [{q or '(untitled)'}](#research-{sid}) — {n} sources" for _, sid, q, n in items[:50])
    return {"output": f"Research library ({len(items)} item{'s' if len(items) != 1 else ''}):\n{rows}", "exit_code": 0}
|
||||
|
||||
|
||||
async def do_trigger_research(content: str, owner: Optional[str] = None) -> Dict:
    """Kick off a live deep-research job via ``/api/research/start``.

    This is the same route the Deep Research sidebar's 'Research' button
    uses, so the resulting session shows up there and can be streamed,
    instead of becoming a scheduled task that never surfaces.

    Args:
        content: JSON tool arguments. ``topic`` (or ``query``) is required;
            ``max_rounds``, ``max_time``, ``category`` and ``search_provider``
            are forwarded when supplied.
        owner: Calling user, passed to ``_internal_headers``.

    Returns:
        On success a dict with ``output``, ``session_id``, ``anchor``,
        ``ui_event``, ``research_session_id`` and ``exit_code`` 0; otherwise
        ``{"error": ..., "exit_code": 1}``.
    """
    import httpx
    from src.tool_implementations import _internal_headers, _INTERNAL_BASE  # shared constants, still live in the facade

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    topic = args.get("topic", "") or args.get("query", "")
    if not topic:
        return {"error": "topic (or query) is required", "exit_code": 1}

    payload: Dict[str, Any] = {"query": topic}

    # Optional knobs the research panel supports. Integer knobs are dropped
    # when they cannot be converted; string knobs are forwarded when truthy.
    for int_knob in ("max_rounds", "max_time"):
        raw_value = args.get(int_knob)
        if raw_value is None:
            continue
        try:
            payload[int_knob] = int(raw_value)
        except (ValueError, TypeError):
            pass
    for passthrough_knob in ("category", "search_provider"):
        knob_value = args.get(passthrough_knob)
        if knob_value:
            payload[passthrough_knob] = knob_value

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(
                f"{_INTERNAL_BASE}/api/research/start",
                json=payload,
                headers=_internal_headers(owner),
            )
            if resp.status_code >= 400:
                return {"error": f"research/start returned HTTP {resp.status_code}: {resp.text[:200]}", "exit_code": 1}
            sid = resp.json().get("session_id", "?")
            anchor = f"[{topic}](#research-{sid})"
            message = (
                f"Deep research started: {anchor}. "
                "Click to open the Deep Research sidebar and watch progress / read the report."
            )
            return {
                "output": message,
                "session_id": sid,
                "anchor": anchor,
                # UI hint so the frontend can open/refresh the research panel.
                "ui_event": "research_started",
                "research_session_id": sid,
                "exit_code": 0,
            }
    except Exception as e:
        return {"error": str(e), "exit_code": 1}
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Search-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the search_chats tool.
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def do_search_chats(query: str, limit: int = 20, owner: str | None = None) -> Dict:
    """Search past session transcripts, restricted to the caller's sessions.

    An owner filter is required here: without one, `search_chats` results
    used to expose EVERY user's chat history to the agent (v2 review
    HIGH-11). `tool_execution.execute_tool_block` now passes the owner
    through; legacy callers that omit it still work but only see
    legacy/null-owner rows.

    Returns ``{"results": <markdown text>}`` on success, or
    ``{"error": ..., "exit_code": 1}`` when the search raises.
    """
    try:
        from src.session_search import search_session_messages

        hits = search_session_messages(query, limit=limit, owner=owner)
        if not hits:
            return {"results": f"No chats found matching \"{query}\"."}

        # Keep only the first hit per session so each chat is linked once.
        first_hit_by_session = {}
        for hit in hits:
            first_hit_by_session.setdefault(hit.session_id, hit)

        report = [f"Found {len(first_hit_by_session)} session(s) matching \"{query}\":\n"]
        for sid, hit in first_hit_by_session.items():
            report.append(f"- **{hit.session_name}** (#{sid})")
            report.append(f"  Link: [Open chat](#{sid})")
            report.append(f"  Match ({hit.role}): {hit.content_snippet}")
            # Nearest neighbours only: last message before, first message after.
            if hit.context_before:
                prev_msg = hit.context_before[-1]
                report.append(f"  Before ({prev_msg['role']}): {prev_msg['content'][:180]}")
            if hit.context_after:
                next_msg = hit.context_after[0]
                report.append(f"  After ({next_msg['role']}): {next_msg['content'][:180]}")
            report.append("")

        return {"results": "\n".join(report)}
    except Exception as e:
        logger.error(f"search_chats failed: {e}")
        return {"error": str(e), "exit_code": 1}
|
||||
@@ -1,700 +0,0 @@
|
||||
"""System-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the skills/tasks tools plus the generic API bridges (api_call, app_api).
|
||||
The admin manage_* tools (endpoints, mcp, webhooks, tokens, settings) live in
|
||||
``src.agent_tools.admin_tools`` after the upstream registry migration (#3629);
|
||||
``src.tool_implementations`` re-exports both sets for backward compatibility.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skills management tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_manage_skills(content: str, owner: Optional[str] = None) -> Dict:
|
||||
"""Handle manage_skills tool calls.
|
||||
|
||||
SKILL.md-backed CRUD with progressive disclosure (Hermes-style). Actions:
|
||||
|
||||
list / index — Level 0: name + description summary.
|
||||
view {name} — Level 1: full SKILL.md.
|
||||
view_ref {name, path} — Level 2: a sub-file under the skill dir.
|
||||
add {name, description, when_to_use, procedure[], pitfalls[],
|
||||
verification[], tags[], category, status}
|
||||
— Create a new skill (draft by default).
|
||||
patch {name, old_string, new_string}
|
||||
— Token-efficient surgical edit on the
|
||||
raw SKILL.md text. Fails on ambiguous
|
||||
`old_string` (multiple matches).
|
||||
edit {name, content} — Replace the entire SKILL.md.
|
||||
publish {name} — Flip status: draft -> published.
|
||||
delete {name} — Remove the skill directory.
|
||||
search {query} — Relevance match on published skills.
|
||||
"""
|
||||
try:
|
||||
args = _parse_tool_args(content)
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
action = (args.get("action") or "").lower()
|
||||
from services.memory.skills import SkillsManager
|
||||
from services.memory.skill_format import Skill, slugify
|
||||
from src.constants import DATA_DIR
|
||||
sm = SkillsManager(DATA_DIR)
|
||||
|
||||
# Accept legacy `skill_id` as an alias for `name`.
|
||||
name = (args.get("name") or args.get("skill_id") or "").strip()
|
||||
|
||||
if action in ("list", "index", ""):
|
||||
all_skills = sm.load(owner=owner)
|
||||
if not all_skills:
|
||||
return {"results": "No skills yet. Create one with action='add'."}
|
||||
published = [s for s in all_skills if s.get("status") == "published"]
|
||||
drafts = [s for s in all_skills if s.get("status") == "draft"]
|
||||
lines = []
|
||||
if published:
|
||||
lines.append("## Published")
|
||||
for s in sorted(published, key=lambda x: x["name"]):
|
||||
lines.append(f"- **{s['name']}** ({s.get('category','general')}): {s.get('description','')}")
|
||||
if drafts:
|
||||
lines.append("\n## Drafts")
|
||||
for s in sorted(drafts, key=lambda x: x["name"]):
|
||||
lines.append(f"- **{s['name']}** [draft]: {s.get('description','')}")
|
||||
return {"results": "\n".join(lines) if lines else "No skills yet."}
|
||||
|
||||
if action == "view":
|
||||
if not name:
|
||||
return {"error": "name is required for view", "exit_code": 1}
|
||||
md = sm.read_skill_md(name, owner=owner)
|
||||
if md is None:
|
||||
return {"error": f"Skill {name!r} not found", "exit_code": 1}
|
||||
return {"results": md}
|
||||
|
||||
if action == "view_ref":
|
||||
if not name:
|
||||
return {"error": "name is required for view_ref", "exit_code": 1}
|
||||
ref = (args.get("path") or "").strip()
|
||||
if not ref:
|
||||
return {"error": "path is required for view_ref", "exit_code": 1}
|
||||
text = sm.read_skill_reference(name, ref, owner=owner)
|
||||
if text is None:
|
||||
return {"error": f"Reference {ref!r} not found under {name!r}", "exit_code": 1}
|
||||
return {"results": text}
|
||||
|
||||
if action == "add":
|
||||
if not name:
|
||||
return {
|
||||
"error": "name is required for add. Provide the exact slug the user should see, then report the returned name.",
|
||||
"exit_code": 1,
|
||||
}
|
||||
proc = args.get("procedure")
|
||||
if proc is None:
|
||||
proc = args.get("steps") or []
|
||||
if not proc and not args.get("body_extra") and not args.get("solution"):
|
||||
return {"error": "procedure (or solution body) is required", "exit_code": 1}
|
||||
# Same auto-publish gate as the extractor path — when the user
|
||||
# has auto_approve_skills on and the caller didn't pin an explicit
|
||||
# status, publish immediately. Audit later demotes/removes on fail.
|
||||
_status_arg = args.get("status")
|
||||
if not _status_arg:
|
||||
try:
|
||||
from routes.prefs_routes import _load_for_user as _load_prefs
|
||||
_prefs = _load_prefs(owner) or {}
|
||||
_status_arg = "published" if _prefs.get("auto_approve_skills", True) else "draft"
|
||||
except Exception:
|
||||
_status_arg = "draft"
|
||||
entry = sm.add_skill(
|
||||
name=args.get("name"),
|
||||
description=(args.get("description") or args.get("title") or "").strip(),
|
||||
category=args.get("category") or "general",
|
||||
tags=args.get("tags") or [],
|
||||
platforms=args.get("platforms") or [],
|
||||
requires_toolsets=args.get("requires_toolsets") or [],
|
||||
fallback_for_toolsets=args.get("fallback_for_toolsets") or [],
|
||||
when_to_use=(args.get("when_to_use") if args.get("when_to_use") is not None
|
||||
else args.get("problem", "")),
|
||||
procedure=proc,
|
||||
pitfalls=args.get("pitfalls") or [],
|
||||
verification=args.get("verification") or [],
|
||||
status=_status_arg,
|
||||
version=args.get("version") or "1.0.0",
|
||||
confidence=args.get("confidence", 0.8),
|
||||
source=args.get("source", "learned"),
|
||||
teacher_model=args.get("teacher_model"),
|
||||
owner=owner,
|
||||
title=args.get("title", ""),
|
||||
problem=args.get("problem", ""),
|
||||
solution=args.get("solution", ""),
|
||||
steps=args.get("steps") or [],
|
||||
)
|
||||
if entry.get("_deduped"):
|
||||
return {"results": (
|
||||
f"A near-identical skill already exists: `{entry['name']}` — not creating "
|
||||
f"a duplicate. View or edit it with action='view', name='{entry['name']}'."
|
||||
)}
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("skill_added", owner)
|
||||
except Exception:
|
||||
logger.debug("skill_added event dispatch failed", exc_info=True)
|
||||
verify_hint = ""
|
||||
if entry.get("status") == "draft":
|
||||
verify_hint = (
|
||||
"\n\nThis skill is a DRAFT. Run through the procedure once to verify, "
|
||||
f"then publish with action='publish', name='{entry['name']}'."
|
||||
)
|
||||
return {"results": f"Created skill `{entry['name']}` — {entry.get('description','')}{verify_hint}"}
|
||||
|
||||
if action == "edit":
|
||||
if not name:
|
||||
return {"error": "name is required for edit", "exit_code": 1}
|
||||
new_content = args.get("content")
|
||||
if not isinstance(new_content, str) or not new_content.strip():
|
||||
return {"error": "content (full SKILL.md) is required for edit", "exit_code": 1}
|
||||
try:
|
||||
sk_new = Skill.from_markdown(new_content)
|
||||
except Exception as e:
|
||||
return {"error": f"Could not parse content as SKILL.md: {e}", "exit_code": 1}
|
||||
sk_new.name = slugify(sk_new.name or name)
|
||||
existing = sm.load(owner=owner)
|
||||
match = next((s for s in existing if s.get("name") == name), None)
|
||||
if not match:
|
||||
return {"error": f"Skill {name!r} not found", "exit_code": 1}
|
||||
if not sk_new.owner:
|
||||
sk_new.owner = match.get("owner") or owner
|
||||
ok = sm.update_skill(name, _skill_dump(sk_new), owner=owner)
|
||||
return {"results": f"Edited skill `{sk_new.name}`."} if ok else {"error": "Update failed", "exit_code": 1}
|
||||
|
||||
if action == "patch":
|
||||
if not name:
|
||||
return {"error": "name is required for patch", "exit_code": 1}
|
||||
old = args.get("old_string")
|
||||
new_str = args.get("new_string", "")
|
||||
if not isinstance(old, str) or not old:
|
||||
return {"error": "old_string is required and must be non-empty", "exit_code": 1}
|
||||
md = sm.read_skill_md(name, owner=owner)
|
||||
if md is None:
|
||||
return {"error": f"Skill {name!r} not found", "exit_code": 1}
|
||||
count = md.count(old)
|
||||
if count == 0:
|
||||
return {"error": "old_string not found in SKILL.md", "exit_code": 1}
|
||||
if count > 1:
|
||||
return {"error": f"old_string is ambiguous (appears {count} times). Make it more specific.", "exit_code": 1}
|
||||
new_md = md.replace(old, new_str, 1)
|
||||
try:
|
||||
sk_new = Skill.from_markdown(new_md)
|
||||
except Exception as e:
|
||||
return {"error": f"Patched content is not valid SKILL.md: {e}", "exit_code": 1}
|
||||
sk_new.name = slugify(sk_new.name or name)
|
||||
ok = sm.update_skill(name, _skill_dump(sk_new), owner=owner)
|
||||
return {"results": f"Patched skill `{sk_new.name}`."} if ok else {"error": "Patch update failed", "exit_code": 1}
|
||||
|
||||
if action == "publish":
|
||||
if not name:
|
||||
return {"error": "name is required for publish", "exit_code": 1}
|
||||
all_skills = sm.load(owner=owner)
|
||||
match = next((s for s in all_skills if s.get("name") == name), None)
|
||||
if not match:
|
||||
return {"error": f"Skill {name!r} not found", "exit_code": 1}
|
||||
updates = {"status": "published"}
|
||||
if args.get("confidence") is not None:
|
||||
updates["confidence"] = max(0.0, min(1.0, float(args["confidence"])))
|
||||
sm.update_skill(name, updates, owner=owner)
|
||||
return {"results": f"✅ Published `{name}`. It now appears in the skills index for future turns."}
|
||||
|
||||
if action == "delete":
|
||||
if not name:
|
||||
return {"error": "name is required for delete", "exit_code": 1}
|
||||
ok = sm.delete_skill(name, owner=owner)
|
||||
return {"results": f"Deleted skill `{name}`."} if ok else {"error": f"Skill {name!r} not found", "exit_code": 1}
|
||||
|
||||
if action == "search":
|
||||
query = (args.get("query") or "").strip()
|
||||
if not query:
|
||||
return {"error": "query is required for search", "exit_code": 1}
|
||||
results = sm.get_relevant_skills(query, sm.load(owner=owner), max_items=5)
|
||||
if not results:
|
||||
return {"results": "No matching skills found."}
|
||||
lines = []
|
||||
for sk in results:
|
||||
proc = sk.get("procedure") or sk.get("steps") or []
|
||||
steps_str = " → ".join(proc[:5])
|
||||
lines.append(f"**{sk['name']}**: {sk.get('description','')}\n When: {sk.get('when_to_use','')}\n Steps: {steps_str}")
|
||||
return {"results": "\n\n".join(lines)}
|
||||
|
||||
return {
|
||||
"error": (
|
||||
f"Unknown action: {action!r}. "
|
||||
"Use one of: list, view, view_ref, add, edit, patch, publish, delete, search."
|
||||
),
|
||||
"exit_code": 1,
|
||||
}
|
||||
|
||||
|
||||
def _skill_dump(sk) -> Dict:
    """Flatten a parsed Skill into the field dict handed to `update_skill`.

    Each attribute listed below is read from ``sk`` and stored under a key
    of the same name, in this order. A missing attribute raises
    ``AttributeError``.
    """
    exported_fields = (
        "name",
        "description",
        "version",
        "category",
        "tags",
        "platforms",
        "requires_toolsets",
        "fallback_for_toolsets",
        "status",
        "confidence",
        "source",
        "teacher_model",
        "owner",
        "when_to_use",
        "procedure",
        "pitfalls",
        "verification",
        "body_extra",
    )
    return {field: getattr(sk, field) for field in exported_fields}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task management tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_manage_tasks(content: str, owner: Optional[str] = None) -> Dict:
    """Handle manage_tasks tool calls: CRUD and control of scheduled tasks.

    Args:
        content: Tool-call arguments, parsed with ``_parse_tool_args``. The
            ``action`` key selects the operation: ``list`` (default),
            ``create``, ``edit``, ``delete``, ``pause``, ``resume`` or ``run``.
        owner: When given, ``list`` is filtered to this owner and the other
            actions refuse tasks that carry a different owner. Tasks with no
            owner, and calls with no owner, are not restricted.

    Returns:
        A dict with ``exit_code`` 0 and a ``response`` message on success
        (plus ``tasks`` for list and ``task_id`` for create), or ``exit_code``
        1 and an ``error`` message on failure. Unexpected exceptions are
        logged and reported as an error dict rather than raised.
    """
    import uuid as _uuid
    from core.database import SessionLocal, ScheduledTask
    from src.task_scheduler import compute_next_run

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    # `or` rather than a .get() default: a present-but-null action means list.
    action = args.get("action") or "list"
    db = SessionLocal()

    def _owned_task(verb: str):
        """Return (task, None) for args['task_id'], or (None, error_dict)."""
        task_id = args.get("task_id")
        if not task_id:
            return None, {"error": f"task_id is required for {verb}", "exit_code": 1}
        task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
        if not task:
            return None, {"error": f"Task {task_id} not found", "exit_code": 1}
        if owner and task.owner and task.owner != owner:
            return None, {"error": "Access denied", "exit_code": 1}
        return task, None

    try:
        if action == "list":
            q = db.query(ScheduledTask)
            if owner:
                q = q.filter(ScheduledTask.owner == owner)
            tasks = q.order_by(ScheduledTask.created_at.desc()).all()
            task_list = [
                {
                    "id": t.id,
                    "name": t.name,
                    "status": t.status,
                    "task_type": t.task_type or "llm",
                    "action": t.action,
                    "trigger_type": t.trigger_type or "schedule",
                    "schedule": t.schedule,
                    "trigger_event": t.trigger_event,
                    "trigger_count": t.trigger_count,
                    "next_run": (t.next_run.isoformat() + "Z") if t.next_run else None,
                    "last_run": (t.last_run.isoformat() + "Z") if t.last_run else None,
                    "run_count": t.run_count or 0,
                }
                for t in tasks
            ]
            return {"response": f"Found {len(task_list)} tasks", "tasks": task_list, "exit_code": 0}

        if action == "create":
            task_type = args.get("task_type", "llm")
            trigger_type = args.get("trigger_type", "schedule")

            if task_type in ("llm", "research") and not args.get("prompt"):
                return {"error": "Prompt is required for llm/research tasks", "exit_code": 1}
            if task_type == "action" and not args.get("action_name"):
                return {"error": "action_name is required for action tasks", "exit_code": 1}

            # Resolve the schedule defaults once and persist exactly what
            # next_run was computed from, so a later resume/edit recomputes
            # from the same values instead of from NULL columns.
            schedule = None
            scheduled_time = None
            next_run = None
            if trigger_type == "schedule":
                schedule = args.get("schedule") or "daily"
                scheduled_time = args.get("scheduled_time") or "09:00"
                next_run = compute_next_run(
                    schedule, scheduled_time, args.get("scheduled_day"),
                )

            task_id = str(_uuid.uuid4())
            # Guard each fallback with `or`: args.get("prompt", default) returns
            # None when the key is present but null, and None[:50] raises.
            name = args.get("name") or (args.get("prompt") or args.get("action_name") or "Task")[:50]

            task = ScheduledTask(
                id=task_id,
                owner=owner,
                name=name,
                prompt=args.get("prompt"),
                task_type=task_type,
                action=args.get("action_name"),
                schedule=schedule,
                scheduled_time=scheduled_time,
                scheduled_day=args.get("scheduled_day"),
                trigger_type=trigger_type,
                trigger_event=args.get("trigger_event"),
                trigger_count=args.get("trigger_count"),
                trigger_counter=0,
                next_run=next_run,
                status="active",
                output_target=args.get("output_target", "session"),
            )
            db.add(task)
            db.commit()
            return {"response": f"Created task '{name}' (id: {task_id})", "task_id": task_id, "exit_code": 0}

        if action == "edit":
            task, err = _owned_task("edit")
            if err:
                return err

            # (argument key, model attribute); only non-null arguments apply.
            simple_fields = (
                ("name", "name"),
                ("prompt", "prompt"),
                ("output_target", "output_target"),
                ("task_type", "task_type"),
                ("action_name", "action"),
                ("trigger_type", "trigger_type"),
                ("trigger_event", "trigger_event"),
                ("trigger_count", "trigger_count"),
            )
            changed = []
            for arg_key, attr in simple_fields:
                if args.get(arg_key) is not None:
                    setattr(task, attr, args[arg_key])
                    changed.append(attr)

            schedule_changed = False
            for field in ("schedule", "scheduled_time", "scheduled_day"):
                if args.get(field) is not None:
                    setattr(task, field, args[field])
                    changed.append(field)
                    schedule_changed = True

            # next_run is only refreshed when a schedule field itself changed.
            if schedule_changed and (task.trigger_type or "schedule") == "schedule":
                task.next_run = compute_next_run(
                    task.schedule, task.scheduled_time, task.scheduled_day,
                )

            db.commit()
            return {"response": f"Updated task '{task.name}': {', '.join(changed)}", "exit_code": 0}

        if action == "delete":
            task, err = _owned_task("delete")
            if err:
                return err
            name = task.name  # read before the row is deleted
            db.delete(task)
            db.commit()
            return {"response": f"Deleted task '{name}'", "exit_code": 0}

        if action in ("pause", "resume"):
            task, err = _owned_task(action)
            if err:
                return err

            if action == "pause":
                task.status = "paused"
            else:
                task.status = "active"
                # Resuming re-derives next_run for schedule-triggered tasks.
                if (task.trigger_type or "schedule") == "schedule":
                    task.next_run = compute_next_run(
                        task.schedule, task.scheduled_time, task.scheduled_day,
                    )
            db.commit()
            return {"response": f"Task '{task.name}' {action}d", "exit_code": 0}

        if action == "run":
            task, err = _owned_task("run")
            if err:
                return err

            from src.event_bus import get_task_scheduler
            scheduler = get_task_scheduler()
            if not scheduler:
                return {"error": "Task scheduler not available", "exit_code": 1}
            started = await scheduler.run_task_now(args.get("task_id"))
            if started:
                return {"response": f"Task '{task.name}' triggered", "exit_code": 0}
            return {"error": "Task is already running", "exit_code": 1}

        return {"error": f"Unknown action: {action}", "exit_code": 1}

    except Exception as e:
        # Tool boundary: report the failure to the caller, keep the traceback
        # in the log.
        logger.exception("manage_tasks error: %s", e)
        return {"error": str(e), "exit_code": 1}
    finally:
        db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API call tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_api_call(content: str) -> Dict:
    """Execute an API call to a registered integration.

    ``content`` is either a JSON object with ``integration`` and optional
    ``method``, ``path``, ``params``, ``body`` and ``headers`` keys, or a
    line-based form::

        <integration>
        <METHOD> <path>
        <JSON body, possibly spanning several lines>

    The integration is matched by exact ``id`` or case-insensitive ``name``.

    Returns:
        The result of ``execute_api_call``, or an error dict with
        ``exit_code`` 1 when no integration matches.
    """
    from src.integrations import execute_api_call, load_integrations

    try:
        args = json.loads(content)
    except json.JSONDecodeError:
        args = None

    # Valid JSON that is not an object (a bare string, number, list) cannot
    # carry arguments either, so it takes the line-based path as well.
    if not isinstance(args, dict):
        lines = content.strip().split("\n")
        args = {"integration": lines[0].strip()}
        if len(lines) > 1:
            verb, _, route = lines[1].strip().partition(" ")
            args["method"] = verb or "GET"
            args["path"] = route.strip() or "/"
        if len(lines) > 2:
            try:
                args["body"] = json.loads("\n".join(lines[2:]))
            except json.JSONDecodeError:
                # Best effort: an unparsable body is dropped, the call
                # proceeds without one.
                pass

    # A present-but-null integration must not reach .lower().
    integration_name = str(args.get("integration") or "").strip()
    wanted = integration_name.lower()
    integrations = load_integrations()
    intg = next(
        (
            i for i in integrations
            if i.get("id") == integration_name
            or str(i.get("name") or "").lower() == wanted
        ),
        None,
    )
    if not intg:
        available = ", ".join(
            str(i.get("name") or "") for i in integrations if i.get("enabled", True)
        )
        return {"error": f"No integration matching '{integration_name}'. Available: {available or 'none configured'}", "exit_code": 1}

    return await execute_api_call(
        intg["id"],
        args.get("method") or "GET",
        args.get("path") or "/",
        params=args.get("params"),
        body=args.get("body"),
        extra_headers=args.get("headers"),
    )
|
||||
|
||||
|
||||
# Paths the generic `app_api` tool will refuse to call. Auth/token/user
# administration and host shell execution are too risky to route through an
# agent surface even when the agent is admin-context; accidental account or
# command mistakes have permanent blast radius.
#
# Entries are compared with `str.startswith` against the request path (see
# `do_app_api`), so each one blocks the bare path and everything beneath it,
# for every HTTP method.
_APP_API_BLOCKLIST_PREFIXES = (
    "/api/auth",  # login/logout/password
    "/api/users",  # user CRUD (bare /api/users list+create+delete must also block)
    "/api/tokens",  # api token mgmt (bare /api/tokens list+create must also block)
    "/api/admin",  # admin one-shots (wipe etc.)
    "/api/shell",  # host shell execution must stay behind named command tooling
    "/api/backup/restore",  # destructive restore
)
|
||||
|
||||
# (method, prefix) pairs to refuse specifically. Used for endpoints
# where GET is fine but writes are destructive or host-control shaped.
# Saw the agent wipe cookbook_state.json (presets + tasks) by POSTing
# {"tasks": []} to /api/cookbook/state, which overwrote the whole file.
# Use dedicated tools or UI flows instead.
#
# A request is refused when its upper-cased method equals the first element
# and its path starts with the second (see `do_app_api`).
_APP_API_BLOCKLIST_METHOD_PATH = (
    ("GET", "/api/email/accounts"),  # owner-filtered in tool context; use list_email_accounts MCP tool
    ("POST", "/api/cookbook/state"),  # whole-file overwrite — agent must use serve_preset/serve_model instead
    ("DELETE", "/api/cookbook/state"),
    # Host-control routes: package install, engine rebuild, and process
    # signalling should not be reachable through the generic API bridge.
    ("POST", "/api/cookbook/packages/install"),
    ("POST", "/api/cookbook/rebuild-engine"),
    ("POST", "/api/cookbook/kill-pid"),
    # Use the named tools (download_model / serve_model) — they handle
    # host-name resolution, per-host env_prefix, AND register the task
    # in cookbook state so it shows in the UI + list_downloads. Hitting
    # the raw endpoint via app_api skips all of that → orphan task.
    ("POST", "/api/model/download"),
    ("POST", "/api/model/serve"),
    # Use trigger_research — it returns a UI hint so the Deep Research
    # sidebar surfaces the session. Raw start works but the agent
    # fumbles the payload + the session doesn't reliably show up.
    ("POST", "/api/research/start"),
    # Use the named tools — they handle owner attribution, natural-
    # language due_date parsing, timezone, dedup, and tag/category
    # normalization. Hitting the raw endpoint via app_api saves a
    # note/event with the wrong fields, no reminder, or the wrong tz.
    ("POST", "/api/notes"),
    ("PUT", "/api/notes"),
    ("DELETE", "/api/notes"),
    ("POST", "/api/calendar/events"),
    ("PUT", "/api/calendar/events"),
    ("DELETE", "/api/calendar/events"),
)
|
||||
|
||||
|
||||
def _app_api_checked_route(path: str) -> Optional[str]:
    """Return the route of *path* in the form used for blocklist checks.

    The query string and fragment are dropped, percent-escapes are decoded
    once, and repeated or trailing slashes are collapsed. Returns ``None``
    when the route contains a ``.`` or ``..`` segment or a backslash: an HTTP
    client may collapse dot-segments before sending, so such a path could
    pass a prefix check and still reach a blocked endpoint.
    """
    from urllib.parse import unquote

    route = unquote(path.split("?", 1)[0].split("#", 1)[0])
    segments = route.split("/")
    if "\\" in route or any(seg in (".", "..") for seg in segments):
        return None
    return "/" + "/".join(seg for seg in segments if seg)


async def do_app_api(content: str, owner: Optional[str] = None) -> Dict:
    """Generic loopback to allowed internal Odysseus API endpoints. Lets the
    agent reach the full UI-button surface (cookbook, email, notes,
    calendar, skills, sessions, gallery, research, etc.) without us
    landing a named tool wrapper for every one.

    Args (JSON):
      action: "call" (default) | "endpoints"
      path:   "/api/cookbook/gpus"          # required for call
      method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE"  (default GET)
      body:   <object>                       # JSON body for POST/PUT/PATCH
      query:  <object>                       # querystring params
      filter: <string>                       # endpoints only: substring of path/summary

    The `endpoints` action returns the OpenAPI surface (method + path +
    summary) so the agent can discover what's reachable. A blocklist
    refuses sensitive auth/user/admin/shell paths and method-specific
    host-control routes to keep blast radius bounded. Paths containing
    "." / ".." segments are refused outright so the blocklist cannot be
    sidestepped by path traversal.

    Returns a dict with ``exit_code`` 0 and ``output`` on success, or
    ``exit_code`` 1 and ``error`` on failure (HTTP status >= 400 included).
    """
    # `_internal_headers` and `_INTERNAL_BASE` still live in
    # tool_implementations.py (shared by many domain tools). Function-local
    # import avoids a top-level circular dependency until a later task
    # relocates them.
    from src.tool_implementations import _internal_headers, _INTERNAL_BASE

    import httpx

    try:
        args = _parse_tool_args(content) if content.strip() else {}
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    action = (args.get("action") or "call").lower()
    base = _INTERNAL_BASE

    if action == "endpoints":
        # Fetch FastAPI's OpenAPI schema so the agent can discover any
        # endpoint without us pre-listing them. Filter by an optional
        # `filter` keyword (substring match on path or summary).
        kw = (args.get("filter") or "").lower()
        try:
            async with httpx.AsyncClient(timeout=15) as client:
                resp = await client.get(f"{base}/openapi.json",
                                        headers=_internal_headers())
                data = resp.json()
        except Exception as e:
            return {"error": f"OpenAPI fetch failed: {e}", "exit_code": 1}

        rows: List[Dict[str, Any]] = []
        for path, methods in (data.get("paths") or {}).items():
            if not isinstance(methods, dict):
                continue
            # Blocked routes are hidden from discovery as well.
            if any(path.startswith(p) for p in _APP_API_BLOCKLIST_PREFIXES):
                continue
            for method, op in methods.items():
                if method.lower() not in ("get", "post", "put", "patch", "delete"):
                    continue
                if any(method.upper() == m and path.startswith(p) for m, p in _APP_API_BLOCKLIST_METHOD_PATH):
                    continue
                summary = (op or {}).get("summary") or (op or {}).get("description") or ""
                if isinstance(summary, str):
                    # First line only, capped, to keep the listing compact.
                    summary = summary.strip().split("\n")[0][:140]
                if kw and kw not in path.lower() and kw not in (summary or "").lower():
                    continue
                rows.append({"method": method.upper(), "path": path, "summary": summary})

        rows.sort(key=lambda r: (r["path"], r["method"]))
        if not rows:
            return {"output": f"No endpoints match filter {kw!r}." if kw else "No endpoints found.", "exit_code": 0}
        lines = [f"{len(rows)} endpoint(s)" + (f" matching {kw!r}" if kw else "") + ":"]
        for r in rows[:200]:
            line = f"  {r['method']:6s} {r['path']}"
            if r["summary"]:
                line += f"  — {r['summary']}"
            lines.append(line)
        if len(rows) > 200:
            lines.append(f"  ...({len(rows) - 200} more — filter to narrow)")
        return {"output": "\n".join(lines), "endpoints": rows, "exit_code": 0}

    # action == "call"
    path = args.get("path") or ""
    if not path:
        return {"error": "path is required (e.g. '/api/cookbook/gpus')", "exit_code": 1}
    path = str(path)
    if not path.startswith("/"):
        path = "/" + path

    # All blocklist checks run against the canonical route, not the raw
    # string, so traversal/encoding tricks cannot dodge the prefixes.
    route = _app_api_checked_route(path)
    if route is None:
        return {"error": f"Path blocked for safety: {path}. '.' and '..' path segments are not allowed via app_api.", "exit_code": 1}
    if any(route.startswith(p) for p in _APP_API_BLOCKLIST_PREFIXES):
        return {"error": f"Path blocked for safety: {path}. Sensitive endpoints are off-limits via app_api.", "exit_code": 1}

    method = (args.get("method") or "GET").upper()
    if method not in ("GET", "POST", "PUT", "PATCH", "DELETE"):
        return {"error": f"Unsupported method: {method}", "exit_code": 1}

    if any(method == m and route.startswith(p) for m, p in _APP_API_BLOCKLIST_METHOD_PATH):
        if "/api/email/accounts" in route:
            return {"error": "Don't use /api/email/accounts via app_api — it is owner-filtered in tool context and may return empty. Use the `list_email_accounts` email tool, then pass `account` to list_emails/read_email.", "exit_code": 1}
        if "/api/cookbook/packages/install" in route:
            return {"error": "Don't POST /api/cookbook/packages/install via app_api — package installation is host code execution. Use the dedicated Cookbook dependency UI/flow instead.", "exit_code": 1}
        if "/api/cookbook/rebuild-engine" in route:
            return {"error": "Don't POST /api/cookbook/rebuild-engine via app_api — engine rebuild mutates local or remote host state. Use the dedicated Cookbook UI/flow instead.", "exit_code": 1}
        if "/api/cookbook/kill-pid" in route:
            return {"error": "Don't POST /api/cookbook/kill-pid via app_api — process signalling is host control. Use the dedicated Cookbook stop/diagnostic flow instead.", "exit_code": 1}
        if "/api/model/download" in route:
            return {"error": "Don't POST /api/model/download directly — use the `download_model` tool (it resolves the server name, sets the venv env_prefix, and registers the task so it shows in the UI).", "exit_code": 1}
        if "/api/model/serve" in route:
            return {"error": "Don't POST /api/model/serve directly — use the `serve_model` or `serve_preset` tool (handles host resolution, env_prefix, and cookbook tracking).", "exit_code": 1}
        if "/api/research/start" in route:
            return {"error": "Don't POST /api/research/start directly — use the `trigger_research` tool (it surfaces the session in the Deep Research sidebar).", "exit_code": 1}
        if "/api/notes" in route:
            return {"error": "Don't hit /api/notes via app_api — use the `manage_notes` tool. It accepts natural-language due_date ('11pm today', 'tomorrow at 9am'), fires reminders from the due_date itself (no separate calendar event), and uses the caller's timezone. The raw endpoint requires ISO-UTC + a separate calendar event, both of which the agent tends to get wrong.", "exit_code": 1}
        if "/api/calendar/events" in route:
            return {"error": "Don't hit /api/calendar/events via app_api — use the `manage_calendar` tool. It handles tz-aware natural-language datetimes and reminder_minutes correctly. If the user wants a note + reminder, prefer `manage_notes` with due_date — it bundles both.", "exit_code": 1}
        # Only the /api/cookbook/state entries are left at this point.
        return {"error": f"{method} {path} is blocked — it overwrites the whole cookbook state file. Use list_serve_presets / serve_preset / serve_model instead.", "exit_code": 1}

    body = args.get("body")
    query = args.get("query") or None
    # Pass owner so the backend impersonates the user — without this,
    # POSTs (notes, calendar, todos, ...) get owner="internal-tool"
    # and the user that asked for them can't see the result.
    headers = {**_internal_headers(owner=owner), "Content-Type": "application/json"}

    try:
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.request(
                method, f"{base}{path}",
                json=body if body is not None and method in ("POST", "PUT", "PATCH") else None,
                params=query,
                headers=headers,
            )
        # Try to parse JSON; fall back to raw text.
        try:
            payload = resp.json()
            preview = json.dumps(payload, indent=2, default=str)
            if len(preview) > 4000:
                preview = preview[:4000] + "\n... (truncated)"
        except Exception:
            payload = None
            preview = (resp.text or "")[:4000]
        if resp.status_code >= 400:
            return {
                "error": f"{method} {path} -> HTTP {resp.status_code}",
                "status_code": resp.status_code,
                "body": preview,
                "exit_code": 1,
            }
        return {
            "output": f"{method} {path} -> {resp.status_code}\n{preview}",
            "status_code": resp.status_code,
            "json": payload,
            "exit_code": 0,
        }
    except Exception as e:
        return {"error": f"{method} {path} failed: {e}", "exit_code": 1}
|
||||
@@ -1,189 +0,0 @@
|
||||
"""Vault-domain tool implementations.
|
||||
|
||||
Extracted from tool_implementations.py as part of slice 1 (#4082/#4071).
|
||||
Holds the Bitwarden CLI wrappers (vault_search / vault_get / vault_unlock)
|
||||
and their helpers (_load_vault_config, _run_bw).
|
||||
``src.tool_implementations`` re-exports these for backward compatibility.
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.constants import VAULT_FILE
|
||||
from src.tools._common import _parse_tool_args
|
||||
|
||||
|
||||
def _load_vault_config() -> Dict:
    """Load the vault config object stored as JSON at ``VAULT_FILE``.

    Returns:
        The parsed JSON object, or ``{}`` when the file is missing,
        unreadable, not valid JSON/UTF-8, or does not hold a JSON object.
        Callers use ``.get`` on the result, so a non-dict is never returned.
    """
    from pathlib import Path

    try:
        # Read directly instead of exists()-then-read: one syscall fewer and
        # no window for the file to vanish in between.
        data = json.loads(Path(VAULT_FILE).read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: bad JSON or bad UTF-8.
        return {}
    return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
async def _run_bw(args: list, session: Optional[str] = None, input_text: Optional[str] = None) -> tuple:
    """Run a ``bw`` CLI command and capture its output.

    Args:
        args: Arguments placed after ``bw`` on the command line.
        session: If set, exported to the child as ``BW_SESSION``.
        input_text: If non-empty, encoded and written to the child's stdin.

    Returns:
        ``(stdout, stderr, returncode)`` with both streams decoded
        (undecodable bytes replaced) and stripped. If the ``bw`` executable
        cannot be started, returns ``("", <message>, 127)`` instead of
        raising, so callers' ``rc != 0`` handling covers that case too.
    """
    import asyncio
    import os as _os

    # The child gets the parent's environment plus the optional session key.
    env = dict(_os.environ)
    if session:
        env["BW_SESSION"] = session

    payload = input_text.encode() if input_text else None
    try:
        proc = await asyncio.create_subprocess_exec(
            "bw", *args,
            stdin=asyncio.subprocess.PIPE if payload is not None else None,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
        )
    except OSError as exc:
        # Missing or non-executable binary.
        return "", f"could not start bw CLI: {exc}", 127

    stdout, stderr = await proc.communicate(input=payload)
    return (
        stdout.decode(errors="replace").strip(),
        stderr.decode(errors="replace").strip(),
        proc.returncode,
    )
|
||||
|
||||
|
||||
async def do_vault_search(content: str, owner: Optional[str] = None) -> Dict:
    """Search the vault by keyword. Returns matching item names + URLs, NO passwords.

    Args:
        content: Tool-call arguments (parsed with ``_parse_tool_args``);
            ``query`` is required.
        owner: Accepted for signature parity with the other vault tools; not
            used by the search itself.

    Returns:
        ``{"output": ..., "exit_code": 0}`` listing up to 20 matches (short
        id, name, username, first URL), or an error dict with ``exit_code`` 1
        when arguments are invalid, the vault has no session key, or ``bw``
        fails or prints something that is not a JSON list.
    """
    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    # A present-but-null or non-string query must not crash on .strip().
    query = str(args.get("query") or "").strip()
    if not query:
        return {"error": "query is required", "exit_code": 1}

    cfg = _load_vault_config()
    session = cfg.get("session")
    if not session:
        return {"error": "Vault is locked. Run vault_unlock or provide session key in settings.", "exit_code": 1}

    stdout, stderr, rc = await _run_bw(["list", "items", "--search", query], session=session)
    if rc != 0:
        return {"error": f"bw failed: {stderr[:300]}", "exit_code": 1}

    try:
        items = json.loads(stdout)
    except json.JSONDecodeError:
        return {"error": "Failed to parse bw output", "exit_code": 1}
    if not isinstance(items, list):
        return {"error": "Failed to parse bw output", "exit_code": 1}

    if not items:
        return {"output": f"No vault items match '{query}'.", "exit_code": 0}

    max_shown = 20
    lines = [f"Found {len(items)} item(s) matching '{query}':"]
    for it in items[:max_shown]:
        if not isinstance(it, dict):
            continue
        # `or` fallbacks: bw may emit explicit nulls for these fields.
        item_id = str(it.get("id") or "?")
        name = it.get("name") or "?"
        login = it.get("login") or {}
        username = login.get("username") or ""
        uris = login.get("uris") or []
        url = ((uris[0] or {}).get("uri") or "") if uris else ""
        parts = [f"[{item_id[:8]}] {name}"]
        if username:
            parts.append(f"user: {username}")
        if url:
            parts.append(f"url: {url}")
        lines.append("- " + " · ".join(parts))
    if len(items) > max_shown:
        lines.append(f"...({len(items) - max_shown} more — refine the query to narrow)")
    lines.append("\nUse vault_get(item_id, reason) to retrieve the password.")
    return {"output": "\n".join(lines), "exit_code": 0}
|
||||
|
||||
|
||||
async def do_vault_get(content: str, owner: Optional[str] = None) -> Dict:
    """Retrieve a full vault entry (including password) by item ID. Logs access to assistant chat.

    Args:
        content: Raw tool arguments, decoded by ``_parse_tool_args``; must
            provide non-empty ``item_id`` and ``reason`` entries.
        owner: When truthy, the retrieval (item name and reason, never the
            secret) is reported through ``log_to_assistant``.

    Returns:
        ``{"output": str, "exit_code": 0}`` containing name, username,
        password and, when present, TOTP secret, URLs and notes; or
        ``{"error": str, "exit_code": 1}`` on invalid arguments, a missing
        session key, a failing ``bw`` call, or unparseable ``bw`` output.
    """
    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    def _text_arg(key: str) -> str:
        # Tolerate null / non-string values instead of crashing on .strip().
        value = args.get(key) if isinstance(args, dict) else None
        return "" if value is None else str(value).strip()

    item_id = _text_arg("item_id")
    reason = _text_arg("reason")
    if not item_id:
        return {"error": "item_id is required", "exit_code": 1}
    if not reason:
        return {"error": "reason is required — explain WHY you need this password", "exit_code": 1}

    cfg = _load_vault_config()
    session = cfg.get("session")
    if not session:
        return {"error": "Vault is locked. Unlock first.", "exit_code": 1}

    stdout, stderr, rc = await _run_bw(["get", "item", item_id], session=session)
    if rc != 0:
        return {"error": f"bw failed: {stderr[:300]}", "exit_code": 1}

    try:
        item = json.loads(stdout)
    except json.JSONDecodeError:
        return {"error": "Failed to parse bw output", "exit_code": 1}
    if not isinstance(item, dict):
        # Valid JSON but not the expected single item object.
        return {"error": "Failed to parse bw output", "exit_code": 1}

    login = item.get("login")
    if not isinstance(login, dict):
        login = {}
    name = item.get("name") or "?"

    # Audit log to assistant chat. Best-effort: a logging failure must not
    # block the retrieval, but it is reported instead of being dropped.
    try:
        from src.assistant_log import log_to_assistant

        if owner:
            log_to_assistant(
                owner,
                f"Retrieved password for **{name}** — reason: {reason}",
                category="Vault",
            )
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "Vault audit log failed for item %r", name, exc_info=True
        )

    # ``or`` instead of a .get default: the key may be present with an
    # explicit JSON null, which a .get default would render as "None".
    output = [
        f"Vault item: {name}",
        f"Username: {login.get('username') or '(none)'}",
        f"Password: {login.get('password') or '(none)'}",
    ]
    if login.get("totp"):
        output.append(f"TOTP secret: {login['totp']}")

    uris = login.get("uris")
    if isinstance(uris, list):
        # Skip malformed entries and null URIs; str.join rejects None.
        uri_texts = [
            str(u["uri"]) for u in uris if isinstance(u, dict) and u.get("uri")
        ]
        if uri_texts:
            output.append("URLs: " + ", ".join(uri_texts))

    if item.get("notes"):
        output.append(f"Notes: {item['notes']}")

    return {"output": "\n".join(output), "exit_code": 0}
|
||||
|
||||
|
||||
async def do_vault_unlock(content: str, owner: Optional[str] = None) -> Dict:
    """Unlock the vault using a master password. Stores the resulting session key.

    Args:
        content: Raw tool arguments, decoded by ``_parse_tool_args``; must
            provide a non-empty string ``master_password``.
        owner: Unused here; kept so every vault tool shares one signature.

    Returns:
        ``{"output": str, "exit_code": 0}`` once the session key has been
        written to ``VAULT_FILE``, or ``{"error": str, "exit_code": 1}`` on
        invalid arguments, a failed or empty ``bw unlock``, or when the
        session file cannot be written.
    """
    import os as _os
    from datetime import datetime as _dt, timezone as _tz
    from pathlib import Path

    try:
        args = _parse_tool_args(content)
    except ValueError:
        return {"error": "Invalid JSON arguments", "exit_code": 1}

    master_password = args.get("master_password", "") if isinstance(args, dict) else ""
    if not isinstance(master_password, str) or not master_password:
        return {"error": "master_password is required", "exit_code": 1}

    # Do not pass the master password as an argv element. Local process lists
    # can expose argv to other users; stdin keeps the secret out of `ps`.
    stdout, stderr, rc = await _run_bw(["unlock", "--raw"], input_text=master_password + "\n")
    if rc != 0:
        return {"error": f"Unlock failed: {stderr[:300]}", "exit_code": 1}

    session = stdout.strip()
    if not session:
        return {"error": "bw returned empty session", "exit_code": 1}

    # Merge the session into the existing vault.json, if there is a usable one.
    p = Path(VAULT_FILE)
    cfg = {}
    if p.exists():
        try:
            loaded = json.loads(p.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or corrupt file: start from an empty config.
            loaded = None
        if isinstance(loaded, dict):
            cfg = loaded
    cfg["session"] = session
    # Naive UTC timestamp, same text format datetime.utcnow() produced.
    cfg["unlocked_at"] = _dt.now(_tz.utc).replace(tzinfo=None).isoformat()
    payload = json.dumps(cfg, indent=2)

    try:
        # Create the file owner-only from the start so the session key is
        # never on disk with default-umask permissions.
        fd = _os.open(str(p), _os.O_WRONLY | _os.O_CREAT | _os.O_TRUNC, 0o600)
        with _os.fdopen(fd, "w", encoding="utf-8") as fh:
            try:
                # The creation mode is ignored for a pre-existing file, so
                # tighten it explicitly before the secret is written.
                _os.chmod(str(p), 0o600)
            except (OSError, NotImplementedError):
                pass
            fh.write(payload)
    except OSError as exc:
        return {"error": f"Failed to save session: {exc}", "exit_code": 1}

    return {"output": "Vault unlocked. Session saved.", "exit_code": 0}
|
||||
+13
-48
@@ -113,10 +113,6 @@ class UploadHandler:
|
||||
self.file_detector = None
|
||||
logger.warning("python-magic not available, falling back to basic detection")
|
||||
|
||||
# In-memory index cache to avoid O(N) disk I/O on every request
|
||||
self._index_cache: Optional[Dict[str, Any]] = None
|
||||
self._index_mtime: float = 0.0
|
||||
|
||||
def inside_base_dir(self, path: str) -> bool:
|
||||
"""Check if path is inside base directory"""
|
||||
base = os.path.realpath(self.base_dir)
|
||||
@@ -321,13 +317,6 @@ class UploadHandler:
|
||||
except OSError:
|
||||
pass
|
||||
os.replace(tmp, path)
|
||||
# Update cache if this is the main index
|
||||
if path.endswith("uploads.json"):
|
||||
self._index_cache = data
|
||||
try:
|
||||
self._index_mtime = os.path.getmtime(path)
|
||||
except OSError:
|
||||
self._index_mtime = time.time()
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp)
|
||||
@@ -336,40 +325,22 @@ class UploadHandler:
|
||||
raise
|
||||
|
||||
def _load_upload_index(self) -> Dict[str, Any]:
|
||||
"""Load the upload index from disk/cache. Uses mtime-based validation
|
||||
to avoid redundant parsing on hot paths.
|
||||
"""
|
||||
uploads_db_path = os.path.join(self.upload_dir, "uploads.json")
|
||||
if not os.path.exists(uploads_db_path):
|
||||
self._index_cache = {}
|
||||
self._index_mtime = 0.0
|
||||
return {}
|
||||
|
||||
# Check cache validity
|
||||
try:
|
||||
mtime = os.path.getmtime(uploads_db_path)
|
||||
if self._index_cache is not None and mtime <= self._index_mtime:
|
||||
return self._index_cache
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
|
||||
# Try the live file first, fall back to the .bak sibling if the
|
||||
# live file is truncated/corrupted.
|
||||
# live file is truncated/corrupted (e.g. a previous writer was
|
||||
# SIGKILL'd mid-rename before the new code path was deployed).
|
||||
for candidate in (uploads_db_path, uploads_db_path + ".bak"):
|
||||
if not os.path.exists(candidate):
|
||||
continue
|
||||
try:
|
||||
with open(candidate, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
self._index_cache = data
|
||||
self._index_mtime = mtime
|
||||
return data
|
||||
return data if isinstance(data, dict) else {}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read uploads database ({candidate}): {e}")
|
||||
continue
|
||||
|
||||
self._index_cache = {}
|
||||
return {}
|
||||
|
||||
def get_upload_info(self, upload_id: str) -> Optional[Dict[str, Any]]:
|
||||
@@ -382,23 +353,14 @@ class UploadHandler:
|
||||
return None
|
||||
|
||||
def _renamed_upload_index_key(self, key: str, info: Dict[str, Any], old_owner: str, new_owner: str) -> str:
|
||||
"""Return the storage key to use after renaming an owned upload row.
|
||||
|
||||
Harden against usernames with colons by using the explicit metadata
|
||||
fields instead of trying to parse the key string.
|
||||
"""
|
||||
"""Return the storage key to use after renaming an owned upload row."""
|
||||
if isinstance(key, str) and ":" in key:
|
||||
owner_part, rest = key.split(":", 1)
|
||||
if owner_part.strip().lower() == old_owner:
|
||||
return f"{new_owner}:{rest}"
|
||||
file_hash = info.get("hash")
|
||||
if file_hash:
|
||||
return f"{new_owner}:{file_hash}"
|
||||
|
||||
# Fallback for rows without an explicit hash (should not happen in modern Odysseus)
|
||||
if isinstance(key, str) and ":" in key:
|
||||
# Join all but the last part if there are multiple colons
|
||||
parts = key.rsplit(":", 1)
|
||||
if len(parts) == 2:
|
||||
owner_part, rest = parts[0], parts[1]
|
||||
if owner_part.strip().lower() == old_owner.strip().lower():
|
||||
return f"{new_owner}:{rest}"
|
||||
return key
|
||||
|
||||
def _unique_upload_index_key(self, base_key: str, used_keys: set, reserved_keys: set, info: Dict[str, Any]) -> str:
|
||||
@@ -581,8 +543,11 @@ class UploadHandler:
|
||||
total_size = 0
|
||||
file_types = {}
|
||||
|
||||
files = self._load_upload_index()
|
||||
if files:
|
||||
uploads_db_path = os.path.join(self.upload_dir, "uploads.json")
|
||||
if os.path.exists(uploads_db_path):
|
||||
with open(uploads_db_path, "r", encoding="utf-8") as f:
|
||||
files = json.load(f)
|
||||
|
||||
total_files = len(files)
|
||||
for file_info in files.values():
|
||||
total_size += file_info.get("size", 0)
|
||||
|
||||
@@ -1920,6 +1920,23 @@ import { wireArrowUpRecall, getLastUserMessageFromChatHistory } from './composer
|
||||
_chatBox.appendChild(note);
|
||||
try { note.scrollIntoView({ block: 'end', behavior: 'smooth' }); } catch (_) { uiModule.scrollHistory && uiModule.scrollHistory(); }
|
||||
}
|
||||
} else if (json.type === 'loop_breaker_triggered' || json.type === 'intent_nudge_exhausted') {
|
||||
// A loop guard ended the turn — surface why so it isn't mistaken
|
||||
// for a clean completion or a silent stall.
|
||||
const _chatBox = document.getElementById('chat-history');
|
||||
if (!_isBg && _chatBox) {
|
||||
const note = document.createElement('div');
|
||||
note.className = 'stopped-indicator loop-guard-stop';
|
||||
const label = document.createElement('span');
|
||||
label.className = 'rounds-exhausted-label';
|
||||
label.textContent = json.message ||
|
||||
(json.type === 'loop_breaker_triggered'
|
||||
? 'Stopped by the loop-breaker (no new progress).'
|
||||
: 'Stopped: announced an action but never called the tool.');
|
||||
note.appendChild(label);
|
||||
_chatBox.appendChild(note);
|
||||
try { note.scrollIntoView({ block: 'end', behavior: 'smooth' }); } catch (_) { uiModule.scrollHistory && uiModule.scrollHistory(); }
|
||||
}
|
||||
} else if (json.type === 'model_actual') {
|
||||
if (!_isBg && holder) {
|
||||
holder._requestedModel = json.requested_model || holder._requestedModel || modelName;
|
||||
|
||||
+13
-29
@@ -76,7 +76,7 @@ function _platformIcon(platform) {
|
||||
return '';
|
||||
}
|
||||
|
||||
export let _envState = { env: 'none', envPath: '', hfToken: '', hfTokenConfigured: false, hfTokenMasked: '', gpus: '', remoteHost: '', servers: [], modelPaths: [], platform: '', hostPlatform: '', defaultServer: '' };
|
||||
export let _envState = { env: 'none', envPath: '', hfToken: '', hfTokenConfigured: false, hfTokenMasked: '', gpus: '', remoteHost: '', servers: [], modelPaths: [], platform: '', defaultServer: '' };
|
||||
let _lastCacheHostVal = null;
|
||||
let _cookbookOpeningSpinners = [];
|
||||
export function _lastCacheHost() { return _lastCacheHostVal; }
|
||||
@@ -213,13 +213,8 @@ function _getPort(hostOrTask) {
|
||||
|
||||
/** Get platform for a given host (or task object). Returns 'windows', 'termux', 'linux', or '' */
|
||||
export function _getPlatform(hostOrTask) {
|
||||
if (hostOrTask === 'local') return _envState.hostPlatform || '';
|
||||
if (!hostOrTask) return _envState.remoteHost ? (_envState.platform || '') : (_envState.hostPlatform || '');
|
||||
if (typeof hostOrTask === 'object') {
|
||||
const taskHost = hostOrTask.remoteServerKey || hostOrTask.remoteHost || '';
|
||||
if (!taskHost || taskHost === 'local') return _envState.hostPlatform || '';
|
||||
return hostOrTask.platform || _getPlatform(taskHost);
|
||||
}
|
||||
if (!hostOrTask) return _envState.platform || '';
|
||||
if (typeof hostOrTask === 'object') return hostOrTask.platform || _getPlatform(hostOrTask.remoteServerKey || hostOrTask.remoteHost);
|
||||
const selected = hostOrTask === _envState.remoteHost ? _selectedServer() : null;
|
||||
const srv = selected || _serverByVal(hostOrTask);
|
||||
return srv?.platform || '';
|
||||
@@ -643,12 +638,7 @@ export function _buildServeCmd(f, modelName, backend) {
|
||||
// GPU list — read from gpus (button strip); fall back to gpu_id for
|
||||
// backward-compat with older saved presets that pre-date the removal.
|
||||
const gpuId = (f.gpus || f.gpu_id || '').toString().trim();
|
||||
const _targetHost = Object.prototype.hasOwnProperty.call(f, 'host')
|
||||
? String(f.host || '').trim()
|
||||
: String(_envState.remoteHost || '').trim();
|
||||
const _isWin = _targetHost ? _isWindows(_targetHost) : _isWindows('local');
|
||||
const _localWindows = _isWin && !_targetHost;
|
||||
const py = _isWin ? 'python' : 'python3';
|
||||
const py = _isWindows() ? 'python' : 'python3';
|
||||
// CPU-only serve (-ngl 0): drop the GPU-only flags, otherwise the command
|
||||
// mixes "zero GPU layers" with CUDA unified-memory + flash-attn and fails to
|
||||
// start (issue #1291). Only affects the ngl=0 path; GPU serving is unchanged.
|
||||
@@ -670,19 +660,19 @@ export function _buildServeCmd(f, modelName, backend) {
|
||||
// with misleading prefixes.
|
||||
const _sb = String(_hwfitCache?.system?.backend || '').toLowerCase();
|
||||
const _hwfitHost = String(_hwfitCache?._scannedHost || '');
|
||||
const _curHost = _targetHost;
|
||||
const _curHost = String(_envState.remoteHost || '');
|
||||
const _isCudaTarget = (_sb === 'cuda') && (_hwfitHost === _curHost);
|
||||
const lcPrefix = (() => {
|
||||
let p = '';
|
||||
if (f.unified_mem && !_cpuOnly && (!_isWin || _localWindows) && _isCudaTarget) p += `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 `;
|
||||
// No GPU env var in CPU mode - `-ngl 0` already disables offload
|
||||
if (f.unified_mem && !_cpuOnly && !_isWindows() && _isCudaTarget) p += `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1 `;
|
||||
// No GPU env var in CPU mode — `-ngl 0` already disables offload
|
||||
// so CUDA_VISIBLE_DEVICES / HIP_VISIBLE_DEVICES would be misleading
|
||||
// clutter ("why is CUDA pinned for a CPU run?").
|
||||
if ((!_isWin || _localWindows) && !_cpuOnly) p += _gpuEnvPrefix(gpuId);
|
||||
if (!_isWindows() && !_cpuOnly) p += _gpuEnvPrefix(gpuId);
|
||||
return p;
|
||||
})();
|
||||
if (f.unified_mem && !_cpuOnly && _isWin && !_localWindows && _isCudaTarget) cmd += `$env:GGML_CUDA_ENABLE_UNIFIED_MEMORY="1"; `;
|
||||
if (_isWin && !_localWindows && !_cpuOnly) cmd += _gpuEnvPrefix(gpuId, true);
|
||||
if (f.unified_mem && !_cpuOnly && _isWindows() && _isCudaTarget) cmd += `$env:GGML_CUDA_ENABLE_UNIFIED_MEMORY="1"; `;
|
||||
if (_isWindows() && !_cpuOnly) cmd += _gpuEnvPrefix(gpuId, true);
|
||||
const needsGgufPrelude = /^\$\(\{\s*find\s/.test(String(ggufPath || ''));
|
||||
const modelArg = needsGgufPrelude ? '"$MODEL_FILE"' : `"${ggufPath}"`;
|
||||
// Prefer native llama-server. The backend bootstrap resolves/builds the
|
||||
@@ -754,16 +744,11 @@ export function _buildServeCmd(f, modelName, backend) {
|
||||
// llama-cpp-python takes the projector via --clip_model_path.
|
||||
_lcpExtra += ` --clip_model_path "${f._mmproj_path}"`;
|
||||
}
|
||||
const _lcServer = `${lcPrefix}llama-server --model ${modelArg} --host 0.0.0.0 --port ${f.port || '8080'} -ngl ${f.ngl || '99'} -c ${f.ctx || '8192'}${_lcExtra}`;
|
||||
if (_isWindows()) {
|
||||
const _lcpServer = `${lcPrefix}${py} -m llama_cpp.server --model ${modelArg} --host 0.0.0.0 --port ${f.port || '8080'} --n_gpu_layers ${f.ngl || '99'} --n_ctx ${f.ctx || '8192'}${_lcpExtra}`;
|
||||
if (_localWindows) {
|
||||
// Local Windows serve is launched through Git Bash, so use the native
|
||||
// llama-server shape and let PATH resolve the CUDA Release wrapper.
|
||||
cmd += _lcServer;
|
||||
} else if (_isWin) {
|
||||
cmd += _lcpServer;
|
||||
} else {
|
||||
cmd += _lcServer;
|
||||
cmd += `${lcPrefix}llama-server --model ${modelArg} --host 0.0.0.0 --port ${f.port || '8080'} -ngl ${f.ngl || '99'} -c ${f.ctx || '8192'}${_lcExtra}`;
|
||||
}
|
||||
if (needsGgufPrelude) {
|
||||
cmd = `MODEL_FILE=${ggufPath} && { [ -n "$MODEL_FILE" ] && [ -f "$MODEL_FILE" ]; } || { echo "ERROR: No GGUF found on this host"; exit 1; } && ${cmd}`;
|
||||
@@ -2627,14 +2612,13 @@ function _renderRecipes() {
|
||||
const isLocal = !s.host || s.host.toLowerCase() === 'local';
|
||||
if (isLocal) {
|
||||
s.host = '';
|
||||
s.platform = _envState.hostPlatform || '';
|
||||
if (_localSeen) return false;
|
||||
_localSeen = true;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (!_localSeen) {
|
||||
_es.servers.unshift({ host: '', env: _es.env || 'none', envPath: _es.envPath || '', modelDir: '~/.cache/huggingface/hub', platform: _envState.hostPlatform || '' });
|
||||
_es.servers.unshift({ host: '', env: _es.env || 'none', envPath: _es.envPath || '', modelDir: '~/.cache/huggingface/hub' });
|
||||
}
|
||||
if (_es.remoteHost && !_es.servers.some(s => s.host === _es.remoteHost)) {
|
||||
_es.servers.push({ host: _es.remoteHost, env: _es.env || 'none', envPath: _es.envPath || '', modelDir: '~/.cache/huggingface/hub' });
|
||||
|
||||
@@ -781,7 +781,6 @@ function _stripStateSecrets(state) {
|
||||
const safe = { ...state };
|
||||
if (safe.env && typeof safe.env === 'object') {
|
||||
const { hfToken, ...env } = safe.env;
|
||||
delete env.hostPlatform;
|
||||
safe.env = env;
|
||||
}
|
||||
if (Array.isArray(safe.tasks)) safe.tasks = safe.tasks.map(_redactTaskForStorage);
|
||||
@@ -1674,7 +1673,7 @@ export async function _launchServeTask(shortName, repo, cmd, fields, hostOverrid
|
||||
|| _envState.servers.find(s => s.host === _host) || {};
|
||||
const _serverMetaKey = _targetKey || (_hsrv && _serverKey ? _serverKey(_hsrv) : '') || (_host || 'local');
|
||||
const _serverMetaName = targetMeta?.serverName || _hsrv.name || (_host ? _host : 'Local');
|
||||
const _hplatform = _host ? (_hsrv.platform || '') : (_envState.hostPlatform || '');
|
||||
const _hplatform = _host ? (_hsrv.platform || '') : (_envState.platform || '');
|
||||
const _replaceTaskId = fields?._replaceTaskId || '';
|
||||
if (_replaceTaskId) {
|
||||
try {
|
||||
@@ -1689,6 +1688,7 @@ export async function _launchServeTask(shortName, repo, cmd, fields, hostOverrid
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
// Replace any serve already targeting this same host:port — you can't run two
|
||||
// servers on one port, so re-serving (or retrying) should stop & remove the
|
||||
// old one instead of leaving a dead duplicate behind. (The retry buttons
|
||||
|
||||
@@ -527,7 +527,7 @@ function _selectedServeTarget(panel) {
|
||||
env: server?.env || '',
|
||||
port: host ? (server?.port || _getPort(host) || '') : '',
|
||||
venv,
|
||||
platform: host ? (server?.platform || '') : (_envState.hostPlatform || ''),
|
||||
platform: server?.platform || _envState.platform || '',
|
||||
label,
|
||||
};
|
||||
}
|
||||
@@ -658,12 +658,6 @@ function _selectedGgufSizeGb(model, relPath) {
|
||||
return bytes / (1024 ** 3);
|
||||
}
|
||||
|
||||
/**
 * List the projector (mmproj) GGUF files of a model, ordered by path.
 * A file qualifies when its role is 'projector' or its path/name matches
 * an mmproj*.gguf basename.
 */
function _projectorGgufFiles(model) {
  const pathOf = (file) => file.rel_path || file.name || '';
  const isProjector = (file) => {
    if ((file.role || '') === 'projector') return true;
    return /(^|\/)mmproj[^/]*\.gguf$/i.test(pathOf(file));
  };
  const projectors = _ggufFilesForModel(model).filter(isProjector);
  projectors.sort((left, right) => String(pathOf(left)).localeCompare(String(pathOf(right))));
  return projectors;
}
|
||||
|
||||
function _ggufFileLabel(file) {
|
||||
const base = (file.name || file.rel_path || '').split('/').pop();
|
||||
const size = _formatGgufSize(file.size_bytes);
|
||||
@@ -1204,7 +1198,6 @@ function _rerenderCachedModels() {
|
||||
panelHtml += `<div class="hwfit-serve-warn" style="margin:0 0 8px;padding:6px 10px;border-radius:5px;font-size:11px;background:color-mix(in srgb, var(--color-warning, #f0ad4e) 14%, transparent);border:1px solid color-mix(in srgb, var(--color-warning, #f0ad4e) 40%, transparent);color:var(--color-warning, #f0ad4e);display:flex;gap:6px;align-items:flex-start;line-height:1.4;"><span aria-hidden="true">⚠</span><span>${_warnText}</span></div>`;
|
||||
}
|
||||
panelHtml += `<div class="hwfit-serve-preset-row">${_slotsHtml}</div>`;
|
||||
panelHtml += `<div class="hwfit-serve-vision-warn" style="display:none;margin:0 0 8px;padding:6px 10px;border-radius:5px;font-size:11px;background:color-mix(in srgb, var(--color-warning, #f0ad4e) 14%, transparent);border:1px solid color-mix(in srgb, var(--color-warning, #f0ad4e) 40%, transparent);color:var(--color-warning, #f0ad4e);gap:6px;align-items:flex-start;line-height:1.4;"><span aria-hidden="true">⚠</span><span>Vision is enabled, but no mmproj GGUF projector was found in the cached model scan. Download an mmproj-*.gguf for this model, then refresh the cached model list before launching.</span></div>`;
|
||||
// Row 1: Engine + Server + Env
|
||||
panelHtml += `<div class="hwfit-serve-row">`;
|
||||
const backendOpts = _backendChoices.map(([v,l]) => `<option value="${v}"${defaultBackend===v?' selected':''}>${l}</option>`).join('');
|
||||
@@ -1531,11 +1524,6 @@ function _rerenderCachedModels() {
|
||||
if (el.type === 'checkbox') f[el.dataset.field] = el.checked;
|
||||
else f[el.dataset.field] = el.value;
|
||||
});
|
||||
const buildTarget = _selectedServeTarget(panel);
|
||||
f.host = buildTarget.host || '';
|
||||
f.platform = buildTarget.platform || '';
|
||||
const hostField = panel.querySelector('[data-field="host"]');
|
||||
if (hostField) hostField.value = f.host;
|
||||
const backend = f.backend || 'vllm';
|
||||
const serveModel = (f.model_path || '').trim() || (m.is_local_dir && m.path ? `${m.path}/${repo}` : repo);
|
||||
if (backend === 'llamacpp') {
|
||||
@@ -1555,11 +1543,11 @@ function _rerenderCachedModels() {
|
||||
: m.is_local_dir && m.path
|
||||
? `$({ find ${_ldir} -name '*-00001-of-*.gguf' 2>/dev/null | sort; find ${_ldir} -name '*.gguf' 2>/dev/null | sort; } | head -1)`
|
||||
: `$({ find ${dir} -name '*-00001-of-*.gguf' 2>/dev/null | sort; find ${dir} -name '*.gguf' 2>/dev/null | sort; } | head -1)`;
|
||||
// Vision: use the scanned projector (CLIP/mmproj) file when present.
|
||||
// Keeping this as a printf path avoids generating a command substitution
|
||||
// that the backend serve-command validator must reject as unsafe.
|
||||
const selectedProjector = _projectorGgufFiles(m)[0];
|
||||
f._mmproj_path = selectedProjector ? _selectedGgufExpr(m, repo, selectedProjector.rel_path) : '';
|
||||
// Vision: auto-find the mmproj (CLIP/projector) file in the same dir.
|
||||
// Resolved at runtime so the toggle just works if an mmproj-*.gguf is
|
||||
// present (downloaded alongside the model). Empty if none → cmd omits it.
|
||||
const _vsearchdir = (m.is_local_dir && m.path) ? _ldir : dir;
|
||||
f._mmproj_path = `$(find ${_vsearchdir} -iname 'mmproj*.gguf' 2>/dev/null | sort | head -1)`;
|
||||
}
|
||||
if (f.reasoning_parser) {
|
||||
const _rpEl2 = panel.querySelector('[data-field="reasoning_parser"]');
|
||||
@@ -1575,10 +1563,6 @@ function _rerenderCachedModels() {
|
||||
}
|
||||
let cmd = _buildServeCmd(f, serveModel, backend);
|
||||
if (f.extra && f.extra.trim()) cmd += ' ' + f.extra.trim();
|
||||
const missingVisionProjector = backend === 'llamacpp' && !!f.vision && !f._mmproj_path;
|
||||
panel._visionMissingProjector = missingVisionProjector;
|
||||
const _visionWarn = panel.querySelector('.hwfit-serve-vision-warn');
|
||||
if (_visionWarn) _visionWarn.style.display = missingVisionProjector ? 'flex' : 'none';
|
||||
const _ce2 = panel.querySelector('.hwfit-serve-cmd'); _ce2.value = _formatServeCmdPreview(cmd); _ce2.style.height = 'auto'; _ce2.style.height = _ce2.scrollHeight + 'px';
|
||||
panel._cmd = cmd;
|
||||
panel._host = f.host || '';
|
||||
@@ -2954,16 +2938,12 @@ function _rerenderCachedModels() {
|
||||
});
|
||||
serveState.backend = serveState.backend || (_detectBackend(m).backend) || 'vllm';
|
||||
const launchTarget = _selectedServeTarget(panel);
|
||||
if (serveState.backend === 'llamacpp' && serveState.vision && !/(?:^|\s)(?:--mmproj|--clip_model_path)\b/.test(launchCmd)) {
|
||||
_restoreLaunchBtn();
|
||||
uiModule.showToast('Vision is checked, but no mmproj projector is in the launch command. Refresh cached models after downloading mmproj, or add --mmproj manually.', 8000);
|
||||
return;
|
||||
}
|
||||
if (serveState.backend === 'diffusers' && _remoteWindowsDiffusersUnsupported(launchTarget)) {
|
||||
_restoreLaunchBtn();
|
||||
uiModule.showToast('Diffusers serving is not supported on remote Windows servers yet. Use local Windows or a Linux server.', 9000);
|
||||
return;
|
||||
}
|
||||
|
||||
// Pre-launch: check our own task list for a serve already running
|
||||
// on this host. Offer to stop+launch as the default action — the
|
||||
// SSH-based port probe below is more thorough but it can miss
|
||||
|
||||
+13
-72
@@ -6,7 +6,7 @@ import markdownModule from './markdown.js';
|
||||
import chatRenderer from './chatRenderer.js';
|
||||
import spinnerModule from './spinner.js';
|
||||
import { providerLogo } from './providers.js';
|
||||
import { PROMPT_TEMPLATES, getUserTemplates } from './presets.js';
|
||||
import { PROMPT_TEMPLATES, getAllPresets } from './presets.js';
|
||||
import { sortModelObjects } from './modelSort.js';
|
||||
import Storage from './storage.js';
|
||||
|
||||
@@ -89,16 +89,12 @@ function _initGroupTab() {
|
||||
|
||||
const charSel = document.createElement('select');
|
||||
charSel.className = 'preset-input';
|
||||
// add an identifier that this is a character selection
|
||||
charSel.dataset.selectionType = "character"
|
||||
charSel.style.cssText = 'font-size:11px;flex:1;height:26px;';
|
||||
charSel.innerHTML = '<option value="">Empty...</option>' +
|
||||
characters.map(c => '<option value="' + c.id + '">' + uiModule.esc(c.name) + '</option>').join('');
|
||||
|
||||
const modelSel = document.createElement('select');
|
||||
modelSel.className = 'preset-input';
|
||||
// add an identifier that this is a model selection
|
||||
modelSel.dataset.selectionType = "model"
|
||||
modelSel.style.cssText = 'font-size:11px;flex:1;height:26px;';
|
||||
modelSel.innerHTML = '<option value="">Model…</option>' +
|
||||
models.map(m => '<option value="' + m.mid + '">' + uiModule.esc(m.display) + '</option>').join('');
|
||||
@@ -200,67 +196,15 @@ function _initGroupTab() {
|
||||
});
|
||||
|
||||
const groupTab = document.querySelector('.preset-tab[data-chartab="group"]');
|
||||
// whenever a user navigates to the Group tab
|
||||
if (groupTab) groupTab.addEventListener('click', () => {
|
||||
_modelsCache = null;
|
||||
if (startBtn) startBtn.textContent = 'Start Group';
|
||||
_loadGroupPresets();
|
||||
|
||||
const isGroupTabUnInitialized =
|
||||
_groupParticipants.length === 0 && participantsEl.children.length === 0;
|
||||
|
||||
if (isGroupTabUnInitialized) {
|
||||
if (_groupParticipants.length === 0) {
|
||||
setTimeout(() => addBtn.click(), 100);
|
||||
} else {
|
||||
// queue this asynchronously since repopulating the selection drop-downs
|
||||
// do not need to be visible right away; it can be safely delayed before
|
||||
// the next event loop
|
||||
queueMicrotask(() => {
|
||||
repopulateExistingSelections();
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
/**
 * Refresh the option lists of the participant rows already rendered in
 * `participantsEl`, keeping each row's current choice when that choice is
 * still present in the freshly loaded list.
 */
async function repopulateExistingSelections() {
    const EMPTY = "";

    // Rebuild one <select> from { value, label } entries. Options are built
    // as DOM nodes, so ids and names are never interpolated into an HTML
    // string and cannot break out of the value attribute.
    const rebuild = (selectEl, placeholder, entries) => {
        const previous = selectEl.value;
        const options = [new Option(placeholder, EMPTY)];
        let stillPresent = false;
        entries.forEach(({ value, label }) => {
            // select.value is always a string; compare as strings so numeric
            // ids still match.
            const optionValue = String(value);
            options.push(new Option(String(label), optionValue));
            if (previous !== EMPTY && optionValue === previous) stillPresent = true;
        });
        selectEl.textContent = '';
        options.forEach((option) => selectEl.appendChild(option));
        // Restore only a choice that still exists; otherwise the placeholder
        // stays selected.
        if (stillPresent) selectEl.value = previous;
    };

    const characterSelections = participantsEl.querySelectorAll("select.preset-input[data-selection-type=character]");
    const modelSelections = participantsEl.querySelectorAll("select.preset-input[data-selection-type=model]");

    if (characterSelections.length !== 0) {
        const characters = await _getCharacterList();
        const entries = characters.map((c) => ({ value: c.id, label: c.name }));
        characterSelections.forEach((sel) => rebuild(sel, 'Empty...', entries));
    }

    if (modelSelections.length !== 0) {
        const models = await _getModels();
        const entries = models.map((m) => ({ value: m.mid, label: m.display }));
        modelSelections.forEach((sel) => rebuild(sel, 'Model…', entries));
    }
}
|
||||
|
||||
// Load and render saved group presets
|
||||
async function _loadGroupPresets() {
|
||||
try {
|
||||
@@ -344,6 +288,17 @@ async function _getCharacterList() {
|
||||
const chars = PROMPT_TEMPLATES.filter(t => t.isCharacter).map(t => ({
|
||||
id: t.id, name: t.name, prompt: t.prompt,
|
||||
}));
|
||||
// User-created characters from presets
|
||||
try {
|
||||
const allPresets = getAllPresets();
|
||||
if (allPresets && allPresets.custom && allPresets.custom.character_name) {
|
||||
chars.push({
|
||||
id: 'custom',
|
||||
name: allPresets.custom.character_name,
|
||||
prompt: allPresets.custom.system_prompt || allPresets.custom.prompt || '',
|
||||
});
|
||||
}
|
||||
} catch (e) {}
|
||||
// Load user templates and wait for them before returning.
|
||||
// The endpoint returns a JSON array directly (not {templates:[...]}).
|
||||
// All user templates are personas by definition — no isCharacter filter needed.
|
||||
@@ -351,26 +306,12 @@ async function _getCharacterList() {
|
||||
const r = await fetch(API_BASE + '/api/presets/templates', { credentials: 'same-origin' });
|
||||
const data = await r.json();
|
||||
const templates = Array.isArray(data) ? data : (data.templates || []);
|
||||
|
||||
templates.forEach(t => {
|
||||
if (t.id && t.name && !chars.find(c => c.id === t.id)) {
|
||||
chars.push({ id: t.id, name: t.name, prompt: t.system_prompt || t.prompt || '' });
|
||||
}
|
||||
});
|
||||
} catch (e) {}
|
||||
|
||||
// Also merge in-memory templates from presets.js — these may include
|
||||
// newly created characters whose async save-to-API hasn't completed yet.
|
||||
const memTemplates = getUserTemplates();
|
||||
|
||||
if (Array.isArray(memTemplates)) {
|
||||
memTemplates.forEach(t => {
|
||||
if (t.id && t.name && !chars.find(c => c.id === t.id)) {
|
||||
chars.push({ id: t.id, name: t.name, prompt: t.system_prompt || t.prompt || '' });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return chars;
|
||||
}
|
||||
|
||||
|
||||
+7
-48
@@ -830,48 +830,15 @@ export async function saveCustomPreset(showToast, showError) {
|
||||
const _selVal = document.getElementById('char-template-select')?.value || '';
|
||||
const isBuiltinPreset = PROMPT_TEMPLATES.some(t => t.isPreset && (t.name === name || t.name === _selVal));
|
||||
const saveName = isBuiltinPreset ? null : (name || null);
|
||||
|
||||
if (saveName) {
|
||||
const _existing = userTemplates.find(t => t.name === saveName);
|
||||
let clone;
|
||||
const _entry = {
|
||||
id: _existing && _existing.id
|
||||
|| 'user-' + Math.random().toString(16).slice(2, 10),
|
||||
name: saveName,
|
||||
// use ?? since it's more semantic for null-coalescing
|
||||
system_prompt: system_prompt ?? '',
|
||||
temperature: config.temperature,
|
||||
max_tokens: config.max_tokens,
|
||||
}
|
||||
const ENDPOINT = `${API_BASE}/api/presets/templates`;
|
||||
|
||||
// Optimistically update the in-memory templates list by @michaelxer
|
||||
if (_existing) {
|
||||
// slow but works for now
|
||||
clone = JSON.parse(JSON.stringify(_existing));
|
||||
|
||||
Object.assign(_existing, _entry);
|
||||
} else {
|
||||
userTemplates.push(_entry);
|
||||
}
|
||||
|
||||
fetch(ENDPOINT, {
|
||||
method: "POST",
|
||||
fetch(`${API_BASE}/api/presets/templates`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(_entry)
|
||||
}).then((r) => {
|
||||
if (r.ok) {
|
||||
loadUserTemplates();
|
||||
}
|
||||
}).catch(() => {
|
||||
if (clone) {
|
||||
Object.assign(_existing, clone);
|
||||
}
|
||||
|
||||
if (showError) {
|
||||
showError(_isInjectStart ? "Something went wrong. Saved prompt has been undone." : "Something went wrong. Saved persona has been undone.");
|
||||
}
|
||||
});
|
||||
body: JSON.stringify({
|
||||
id: (userTemplates.find(t => t.name === saveName) || {}).id || '',
|
||||
name: saveName, system_prompt, temperature: config.temperature, max_tokens: config.max_tokens,
|
||||
}),
|
||||
}).then(r => { if (r.ok) loadUserTemplates(); }).catch(() => {});
|
||||
}
|
||||
|
||||
if (showToast) {
|
||||
@@ -916,13 +883,6 @@ export function getAllPresets() {
|
||||
return presets;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the in-memory user templates list (may be stale; call loadUserTemplates first if freshness matters).
|
||||
*/
|
||||
export function getUserTemplates() {
|
||||
return [...userTemplates];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the character name (if set)
|
||||
*/
|
||||
@@ -1139,7 +1099,6 @@ const presetsModule = {
|
||||
getSelectedPreset,
|
||||
getPreset,
|
||||
getAllPresets,
|
||||
getUserTemplates,
|
||||
getCharacterName,
|
||||
onSessionSwitch,
|
||||
isPersistentChat,
|
||||
|
||||
+22
-30
@@ -340,12 +340,19 @@ export function showToast(msg, durationOrOpts) {
|
||||
stack.style.cssText = 'display:inline-flex;flex-direction:column;align-items:center;gap:1px;margin-left:10px;line-height:1;';
|
||||
|
||||
const btn = document.createElement('button');
|
||||
// If the caller supplied an SVG icon, prepend it. We trust the icon string
|
||||
// (only set internally) — never accept caller-controlled HTML otherwise.
|
||||
if (actionIcon) {
|
||||
btn.innerHTML = `<span style="display:inline-flex;align-items:center;gap:5px;">${actionIcon}<span></span></span>`;
|
||||
btn.querySelector('span span').textContent = actionLabel;
|
||||
} else {
|
||||
btn.textContent = actionLabel;
|
||||
}
|
||||
// The toast itself is `pointer-events: none` so it doesn't block clicks
|
||||
// beneath it. With an action button we need to flip both the toast AND
|
||||
// the button so the user can actually click Undo. The flag is reset on
|
||||
// the next plain showToast / showError call (those overwrite textContent
|
||||
// which strips the button + we clear inline style at the top below).
|
||||
btn.style.cssText = 'padding:2px 10px;border:1px solid var(--fg);border-radius:4px;background:none;color:var(--fg);cursor:pointer;font-size:12px;pointer-events:auto;display:inline-flex;align-items:center;';
|
||||
btn.addEventListener('click', (e) => {
|
||||
e.stopPropagation();
|
||||
@@ -355,6 +362,8 @@ export function showToast(msg, durationOrOpts) {
|
||||
});
|
||||
stack.appendChild(btn);
|
||||
|
||||
// Keyboard-shortcut hints (Ctrl+Z / ⌘Z) are meaningless on touch devices —
|
||||
// skip them on mobile so the toast just shows the Undo button.
|
||||
if (actionHint && window.innerWidth > 768) {
|
||||
const hint = document.createElement('span');
|
||||
hint.textContent = actionHint;
|
||||
@@ -363,28 +372,32 @@ export function showToast(msg, durationOrOpts) {
|
||||
}
|
||||
|
||||
toastEl.appendChild(stack);
|
||||
toastEl.style.pointerEvents = 'auto';
|
||||
} else {
|
||||
toastEl.style.pointerEvents = '';
|
||||
}
|
||||
|
||||
// Close button for all toasts — dismiss without waiting for timeout.
|
||||
// Small × to dismiss the toast without taking the action. Useful when
|
||||
// the user already acted (or just doesn't want the banner sitting there).
|
||||
const closeBtn = document.createElement('button');
|
||||
closeBtn.type = 'button';
|
||||
closeBtn.className = 'toast-close-btn';
|
||||
closeBtn.setAttribute('aria-label', 'Dismiss');
|
||||
closeBtn.title = 'Dismiss';
|
||||
closeBtn.textContent = '×';
|
||||
closeBtn.style.cssText = 'margin-left:8px;padding:0;width:20px;height:20px;line-height:1;border:none;background:none;color:var(--fg);opacity:0.55;cursor:pointer;font-size:18px;border-radius:50%;display:inline-flex;align-items:center;justify-content:center;pointer-events:auto;';
|
||||
closeBtn.addEventListener('mouseenter', () => { closeBtn.style.opacity = '1'; });
|
||||
closeBtn.addEventListener('mouseleave', () => { closeBtn.style.opacity = '0.55'; });
|
||||
closeBtn.addEventListener('click', (e) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
clearTimeout(toastEl._hideTimer);
|
||||
toastEl.classList.add('exiting');
|
||||
toastEl.classList.remove('show');
|
||||
toastEl.style.pointerEvents = '';
|
||||
});
|
||||
toastEl.appendChild(closeBtn);
|
||||
|
||||
toastEl.style.pointerEvents = 'auto';
|
||||
} else {
|
||||
// No action — restore the default non-blocking behavior.
|
||||
toastEl.style.pointerEvents = '';
|
||||
}
|
||||
|
||||
// Pin to top-right via CSS — clear any legacy inline overrides so the
|
||||
// slide-in-from-right / slide-out-to-left transition can run cleanly.
|
||||
toastEl.style.left = '';
|
||||
@@ -415,38 +428,17 @@ export function showError(msg) {
|
||||
toastEl = document.getElementById('toast');
|
||||
}
|
||||
_wireToastSwipe(toastEl);
|
||||
toastEl.textContent = '';
|
||||
toastEl.textContent = msg;
|
||||
toastEl.classList.add('error');
|
||||
toastEl.style.left = '';
|
||||
toastEl.style.transform = '';
|
||||
toastEl.classList.remove('exiting');
|
||||
toastEl.classList.add('show');
|
||||
clearTimeout(toastEl._hideTimer);
|
||||
|
||||
const textSpan = document.createElement('span');
|
||||
textSpan.textContent = msg;
|
||||
toastEl.appendChild(textSpan);
|
||||
|
||||
const closeBtn = document.createElement('button');
|
||||
closeBtn.type = 'button';
|
||||
closeBtn.className = 'toast-close-btn';
|
||||
closeBtn.setAttribute('aria-label', 'Dismiss');
|
||||
closeBtn.title = 'Dismiss';
|
||||
closeBtn.textContent = '×';
|
||||
closeBtn.addEventListener('click', (e) => {
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
clearTimeout(toastEl._hideTimer);
|
||||
toastEl.classList.add('exiting');
|
||||
toastEl.classList.remove('show');
|
||||
toastEl.style.pointerEvents = '';
|
||||
});
|
||||
toastEl.appendChild(closeBtn);
|
||||
|
||||
toastEl._hideTimer = setTimeout(() => {
|
||||
toastEl.classList.add('exiting');
|
||||
toastEl.classList.remove('show');
|
||||
}, 6000);
|
||||
}, 3000);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -4062,31 +4062,6 @@ body.bg-pattern-sparkles {
|
||||
@keyframes toastCheckDraw {
|
||||
to { stroke-dashoffset: 0; }
|
||||
}
|
||||
.toast-close-btn {
|
||||
margin-left: 8px;
|
||||
padding: 0;
|
||||
width: 22px;
|
||||
height: 22px;
|
||||
line-height: 1;
|
||||
border: none;
|
||||
background: none;
|
||||
color: var(--fg);
|
||||
opacity: 0.5;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
border-radius: 50%;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
pointer-events: auto;
|
||||
flex-shrink: 0;
|
||||
transition: transform 0.22s ease, opacity 0.15s ease, background 0.15s ease;
|
||||
}
|
||||
.toast-close-btn:hover {
|
||||
opacity: 1;
|
||||
transform: rotate(90deg);
|
||||
background: color-mix(in srgb, var(--fg) 8%, transparent);
|
||||
}
|
||||
.toast.exiting {
|
||||
opacity: 0;
|
||||
transform: translateX(-120%);
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Shared fakes for embedding-lane tests."""
|
||||
|
||||
|
||||
class FakeEmbedder:
|
||||
def __init__(self, dim, model, url):
|
||||
self.dim = dim
|
||||
self.model = model
|
||||
self.url = url
|
||||
|
||||
def get_sentence_embedding_dimension(self):
|
||||
return self.dim
|
||||
|
||||
def encode(self, texts, normalize_embeddings=True):
|
||||
return [[float(i + 1)] * self.dim for i, _ in enumerate(texts)]
|
||||
|
||||
|
||||
class FailingEmbedder(FakeEmbedder):
|
||||
def encode(self, texts, normalize_embeddings=True):
|
||||
raise RuntimeError("embedding endpoint rate limited")
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self, name, metadata=None):
|
||||
self.name = name
|
||||
self.metadata = metadata or {}
|
||||
self.rows = {}
|
||||
self.dim = None
|
||||
|
||||
def count(self):
|
||||
return len(self.rows)
|
||||
|
||||
def add(self, ids, embeddings, documents=None, metadatas=None):
|
||||
self._check_dim(embeddings)
|
||||
documents = documents or [None] * len(ids)
|
||||
metadatas = metadatas or [{}] * len(ids)
|
||||
for row_id, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
|
||||
self.rows[row_id] = {"embedding": emb, "document": doc, "metadata": meta}
|
||||
|
||||
def upsert(self, ids, embeddings, documents=None, metadatas=None):
|
||||
self.add(ids, embeddings, documents=documents, metadatas=metadatas)
|
||||
|
||||
def get(self, ids=None, include=None, where=None, limit=None):
|
||||
selected = list(self.rows.items())
|
||||
if ids is not None:
|
||||
id_set = set(ids)
|
||||
selected = [(row_id, row) for row_id, row in selected if row_id in id_set]
|
||||
if where:
|
||||
selected = [
|
||||
(row_id, row)
|
||||
for row_id, row in selected
|
||||
if all(row["metadata"].get(k) == v for k, v in where.items())
|
||||
]
|
||||
if limit is not None:
|
||||
selected = selected[:limit]
|
||||
return {
|
||||
"ids": [row_id for row_id, _ in selected],
|
||||
"documents": [row["document"] for _, row in selected],
|
||||
"metadatas": [row["metadata"] for _, row in selected],
|
||||
"embeddings": [row["embedding"] for _, row in selected],
|
||||
}
|
||||
|
||||
def query(self, query_embeddings, n_results, where=None, include=None):
|
||||
self._check_dim(query_embeddings)
|
||||
rows = self.get(where=where)
|
||||
ids = rows["ids"][:n_results]
|
||||
docs = rows["documents"][:n_results]
|
||||
metas = rows["metadatas"][:n_results]
|
||||
return {
|
||||
"ids": [ids],
|
||||
"documents": [docs],
|
||||
"metadatas": [metas],
|
||||
"distances": [[0.1 + i * 0.01 for i in range(len(ids))]],
|
||||
}
|
||||
|
||||
def delete(self, ids):
|
||||
for row_id in ids:
|
||||
self.rows.pop(row_id, None)
|
||||
|
||||
def _check_dim(self, embeddings):
|
||||
if not embeddings:
|
||||
return
|
||||
dim = len(embeddings[0])
|
||||
if self.dim is None:
|
||||
self.dim = dim
|
||||
elif self.dim != dim:
|
||||
raise RuntimeError(f"Collection expecting embedding with dimension of {self.dim}, got {dim}")
|
||||
|
||||
|
||||
class FakeChroma:
|
||||
def __init__(self):
|
||||
self.collections = {}
|
||||
self.deleted = []
|
||||
self.fail_next_add_for = {}
|
||||
|
||||
def get_or_create_collection(self, name, metadata=None):
|
||||
if name not in self.collections:
|
||||
self.collections[name] = FakeCollection(name, metadata=metadata)
|
||||
if self.fail_next_add_for.get(name, 0) > 0:
|
||||
original_add = self.collections[name].add
|
||||
|
||||
def fail_once(*args, **kwargs):
|
||||
self.fail_next_add_for[name] -= 1
|
||||
self.collections[name].add = original_add
|
||||
raise RuntimeError("chroma write failed")
|
||||
|
||||
self.collections[name].add = fail_once
|
||||
elif metadata is not None:
|
||||
self.collections[name].metadata = metadata
|
||||
return self.collections[name]
|
||||
|
||||
def get_collection(self, name):
|
||||
if name not in self.collections:
|
||||
raise KeyError(name)
|
||||
return self.collections[name]
|
||||
|
||||
def delete_collection(self, name):
|
||||
self.deleted.append(name)
|
||||
self.collections.pop(name, None)
|
||||
|
||||
|
||||
def patch_chroma(monkeypatch, fake):
|
||||
import src.chroma_client as chroma_client
|
||||
|
||||
monkeypatch.setattr(chroma_client, "get_chroma_client", lambda: fake)
|
||||
+1
-17
@@ -47,12 +47,6 @@ AREAS: tuple[str, ...] = (
|
||||
"uncategorized",
|
||||
)
|
||||
|
||||
# Backward-compatible aggregate selectors for focused runs whose original
|
||||
# monolithic files were split into more specific taxonomy sub-areas.
|
||||
SUB_AREA_ALIASES: dict[str, tuple[str, ...]] = {
|
||||
"embedding": ("embedding", "embedding_memory"),
|
||||
}
|
||||
|
||||
|
||||
def normalize_sub_area(value: str) -> str:
|
||||
"""Normalize a CLI sub-area value and remove an optional ``sub_`` prefix."""
|
||||
@@ -108,13 +102,6 @@ def sub_area_type(valid_sub_areas: frozenset[str]) -> Callable[[str], str]:
|
||||
return validate
|
||||
|
||||
|
||||
def _sub_area_marker_expression(sub_area: str) -> str:
|
||||
"""Build the marker expression for a sub-area, including narrow aliases."""
|
||||
aliases = SUB_AREA_ALIASES.get(sub_area, (sub_area,))
|
||||
markers = [f"sub_{alias}" for alias in aliases]
|
||||
return " or ".join(markers)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FocusSelection:
|
||||
"""A single focused-selection request, decoupled from argparse and pytest."""
|
||||
@@ -156,10 +143,7 @@ def build_marker_expression(
|
||||
if area:
|
||||
parts.append(f"area_{area}")
|
||||
if sub_area:
|
||||
sub_expression = _sub_area_marker_expression(sub_area)
|
||||
if " or " in sub_expression:
|
||||
sub_expression = f"({sub_expression})"
|
||||
parts.append(sub_expression)
|
||||
parts.append(f"sub_{sub_area}")
|
||||
if fast:
|
||||
parts.append("not slow")
|
||||
if not parts:
|
||||
|
||||
@@ -58,17 +58,12 @@ def test_owner_adapter_defaults_owner_to_none():
|
||||
|
||||
|
||||
def test_parse_tool_args_lives_in_tool_utils_single_source():
|
||||
# The helper was de-duplicated into tool_utils; every consumer imports it
|
||||
# from there rather than carrying its own copy. After the tool_implementations
|
||||
# split, _common and the facade must also re-export the same object.
|
||||
# The helper was de-duplicated into tool_utils; admin_tools imports it
|
||||
# from there rather than carrying its own copy.
|
||||
from src.tool_utils import _parse_tool_args
|
||||
from src.agent_tools import admin_tools, document_tools
|
||||
from src.tools import _common
|
||||
import src.tool_implementations as ti
|
||||
assert admin_tools._parse_tool_args is _parse_tool_args
|
||||
assert document_tools._parse_tool_args is _parse_tool_args
|
||||
assert _common._parse_tool_args is _parse_tool_args
|
||||
assert ti._parse_tool_args is _parse_tool_args
|
||||
assert _parse_tool_args('{"action":"add"}') == {"action": "add"}
|
||||
# body-envelope unwrap still works
|
||||
assert _parse_tool_args('{"body":{"action":"x"}}') == {"action": "x"}
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_http_calendar_writes_mark_pending_and_push_after_commit():
|
||||
|
||||
|
||||
def test_agent_calendar_writes_share_caldav_push_path():
|
||||
source = Path("src/tools/calendar.py").read_text()
|
||||
source = Path("src/tool_implementations.py").read_text()
|
||||
|
||||
assert "_push_caldav_event_after_commit" in source
|
||||
assert 'caldav_sync_pending="create" if cal.source == "caldav" else None' in source
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Regression guard for issue #1291 - CPU-only serve still emitted GPU-only flags.
|
||||
"""Regression guard for issue #1291 — CPU-only serve still emitted GPU-only flags.
|
||||
|
||||
The llama.cpp serve command builder (static/js/cookbook.js) added
|
||||
`--flash-attn on` and exported `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` from
|
||||
@@ -16,8 +16,8 @@ from pathlib import Path
|
||||
|
||||
SRC = Path(__file__).resolve().parent.parent / "static/js/cookbook.js"
|
||||
SERVE_SRC = Path(__file__).resolve().parent.parent / "static/js/cookbookServe.js"
|
||||
ROOT = SRC.parent.parent.parent
|
||||
ROUTES_SRC = ROOT / "routes/cookbook_routes.py"
|
||||
ROUTES_SRC = Path(__file__).resolve().parent.parent / "routes/cookbook_routes.py"
|
||||
|
||||
|
||||
def test_cpu_only_drops_gpu_only_flags():
|
||||
text = SRC.read_text(encoding="utf-8")
|
||||
@@ -84,101 +84,3 @@ def test_vllm_route_strips_swap_space_when_runtime_rejects_it():
|
||||
assert "print(shlex.join(parts[:serve_i + 1] + [\"--help\"]))" in text
|
||||
assert "eval \"$ODYSSEUS_VLLM_HELP_CMD\" 2>&1 | grep -q -- \"--swap-space\"" in text
|
||||
assert "eval \"$ODYSSEUS_SERVE_CMD\"" in text
|
||||
|
||||
|
||||
def test_local_windows_platform_comes_from_backend_host_state():
|
||||
text = SRC.read_text(encoding="utf-8")
|
||||
routes = ROUTES_SRC.read_text(encoding="utf-8")
|
||||
running = (SRC.parent / "cookbookRunning.js").read_text(encoding="utf-8")
|
||||
|
||||
assert "hostPlatform" in text
|
||||
assert "navigator.platform" not in text
|
||||
assert "hostOrTask === 'local'" in text
|
||||
assert "if (hostOrTask === 'local') return _envState.hostPlatform || '';" in text
|
||||
assert "return _envState.hostPlatform || _envState.platform || ''" not in text
|
||||
assert "s.platform = _envState.hostPlatform || '';" in text
|
||||
assert "platform: _envState.hostPlatform || ''" in text
|
||||
assert "s.platform = _envState.hostPlatform || _envState.platform || '';" not in text
|
||||
assert "platform: _envState.hostPlatform || _envState.platform || ''" not in text
|
||||
assert 'return "windows" if IS_WINDOWS else ""' in routes
|
||||
assert 'env["hostPlatform"] = _client_host_platform()' in routes
|
||||
assert "return _state_for_client({})" in routes
|
||||
assert 'env.pop("hostPlatform", None)' in routes
|
||||
assert "delete env.hostPlatform;" in running
|
||||
|
||||
|
||||
def test_local_serve_payload_ignores_stale_env_platform():
|
||||
serve = SERVE_SRC.read_text(encoding="utf-8")
|
||||
running = (SRC.parent / "cookbookRunning.js").read_text(encoding="utf-8")
|
||||
|
||||
assert "platform: host ? (server?.platform || '') : (_envState.hostPlatform || '')," in serve
|
||||
assert "platform: server?.platform || _envState.platform || ''" not in serve
|
||||
assert "const _hplatform = _host ? (_hsrv.platform || '') : (_envState.hostPlatform || '');" in running
|
||||
assert "const _hplatform = _host ? (_hsrv.platform || '') : (_envState.platform || '');" not in running
|
||||
|
||||
|
||||
def test_local_windows_llamacpp_prefers_native_llama_server():
|
||||
text = SRC.read_text(encoding="utf-8")
|
||||
helpers = (ROOT / "routes/cookbook_helpers.py").read_text(encoding="utf-8")
|
||||
|
||||
assert "Object.prototype.hasOwnProperty.call(f, 'host')" in text
|
||||
assert "const _isWin = _targetHost ? _isWindows(_targetHost) : _isWindows('local');" in text
|
||||
assert "const _localWindows = _isWin && !_targetHost;" in text
|
||||
assert "const _curHost = _targetHost;" in text
|
||||
assert "const _localWindows = _isWin && !_envState.remoteHost;" not in text
|
||||
assert "const gpuId = (f.gpus || f.gpu_id || '').toString().trim();" in text
|
||||
assert "const _lcServer = `${lcPrefix}llama-server --model" in text
|
||||
assert "if (_localWindows) {" in text
|
||||
assert "cmd += _lcServer;" in text
|
||||
assert '"llama-server.exe"' in helpers
|
||||
|
||||
|
||||
|
||||
def test_serve_command_preview_uses_selected_target_host():
|
||||
text = SERVE_SRC.read_text(encoding="utf-8")
|
||||
|
||||
assert "const buildTarget = _selectedServeTarget(panel);" in text
|
||||
assert "f.host = buildTarget.host || '';" in text
|
||||
assert "f.platform = buildTarget.platform || '';" in text
|
||||
assert "const hostField = panel.querySelector('[data-field=\"host\"]');" in text
|
||||
assert "if (hostField) hostField.value = f.host;" in text
|
||||
|
||||
|
||||
def test_local_windows_llama_server_skips_source_bootstrap():
|
||||
routes = ROUTES_SRC.read_text(encoding="utf-8")
|
||||
|
||||
assert 'local_windows_llama_cmd = local_windows and ("llama_cpp" in req.cmd or "llama-server" in req.cmd)' in routes
|
||||
assert 'if ("llama_cpp" in req.cmd or "llama-server" in req.cmd) and not local_windows_llama_cmd:' in routes
|
||||
|
||||
|
||||
def test_local_windows_llama_server_path_includes_user_wrapper_and_cuda_builds():
|
||||
routes = (ROOT / "routes/cookbook_routes.py").read_text(encoding="utf-8")
|
||||
|
||||
assert 'if local_windows:' in routes
|
||||
assert (
|
||||
'export PATH="$HOME/bin:$HOME/llama.cpp/build-cuda/bin/Release:'
|
||||
'$HOME/llama.cpp/build/bin/Release:$HOME/llama.cpp/build/bin/Debug:'
|
||||
'$HOME/llama.cpp/build/bin:$PATH"'
|
||||
) in routes
|
||||
|
||||
|
||||
def test_serve_panel_keeps_row_markup_and_launch_cmd_assignment_executable():
|
||||
text = SERVE_SRC.read_text(encoding="utf-8").replace("\r\n", "\n")
|
||||
|
||||
assert '// Row 1: Engine + Server + Env panelHtml +=' not in text
|
||||
assert "px'; panel._cmd = cmd;" not in text
|
||||
assert '// Row 1: Engine + Server + Env\n panelHtml += `<div class="hwfit-serve-row">`;' in text
|
||||
assert "px';\n panel._cmd = cmd;" in text
|
||||
|
||||
|
||||
def test_llamacpp_vision_uses_scanned_projector_instead_of_runtime_find():
|
||||
text = SERVE_SRC.read_text(encoding="utf-8")
|
||||
|
||||
assert "function _projectorGgufFiles(model)" in text
|
||||
assert "const selectedProjector = _projectorGgufFiles(m)[0];" in text
|
||||
assert "f._mmproj_path = selectedProjector ? _selectedGgufExpr(m, repo, selectedProjector.rel_path) : '';" in text
|
||||
assert "const missingVisionProjector = backend === 'llamacpp' && !!f.vision && !f._mmproj_path;" in text
|
||||
assert "hwfit-serve-vision-warn" in text
|
||||
assert "!/(?:^|\\s)(?:--mmproj|--clip_model_path)\\b/.test(launchCmd)" in text
|
||||
assert "no mmproj projector is in the launch command" in text
|
||||
assert "find ${_vsearchdir} -iname 'mmproj*.gguf'" not in text
|
||||
|
||||
@@ -419,6 +419,8 @@ def test_pip_install_attempt_failure_propagates_real_exit_code():
|
||||
"""Run the generated snippet against a deliberately broken pip install
|
||||
to confirm the subshell exits with pip's non-zero status."""
|
||||
snippet = _pip_install_attempt("python3 -m pip install __nonexistent_package_12345__")
|
||||
if sys.platform == "win32":
|
||||
snippet = snippet.replace("$", "\\$")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", snippet],
|
||||
capture_output=True,
|
||||
@@ -431,6 +433,8 @@ def test_pip_install_attempt_failure_propagates_real_exit_code():
|
||||
def test_pip_install_attempt_success_exits_zero():
|
||||
"""When pip succeeds, the subshell should exit 0."""
|
||||
snippet = _pip_install_attempt("python3 -c 'pass'")
|
||||
if sys.platform == "win32":
|
||||
snippet = snippet.replace("$", "\\$")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", snippet],
|
||||
capture_output=True,
|
||||
@@ -443,6 +447,8 @@ def test_pip_install_attempt_success_exits_zero():
|
||||
def test_pip_install_attempt_surfaces_stderr_on_failure():
|
||||
"""On failure, the last 5 lines of pip output should appear in stdout."""
|
||||
snippet = _pip_install_attempt("python3 -m pip install __nonexistent_package_12345__")
|
||||
if sys.platform == "win32":
|
||||
snippet = snippet.replace("$", "\\$")
|
||||
result = subprocess.run(
|
||||
["bash", "-c", snippet],
|
||||
capture_output=True,
|
||||
@@ -551,19 +557,6 @@ def test_validate_serve_cmd_accepts_windows_printf_format():
|
||||
assert _validate_serve_cmd(cmd) == cmd
|
||||
|
||||
|
||||
def test_validate_serve_cmd_accepts_llama_mmproj_printf_format():
|
||||
cmd = (
|
||||
"CUDA_VISIBLE_DEVICES=0 llama-server --model "
|
||||
"\"$(printf %s ${HOME}'/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-GGUF/snapshots/abc/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf')\" "
|
||||
"--host 0.0.0.0 --port 8000 -ngl 99 -c 20000 "
|
||||
"--cache-type-k q4_0 --cache-type-v q4_0 --mmproj "
|
||||
"\"$(printf %s ${HOME}'/.cache/huggingface/hub/models--unsloth--Qwen3.6-35B-A3B-GGUF/snapshots/abc/mmproj-BF16.gguf')\" "
|
||||
"--image-max-tokens 1024"
|
||||
)
|
||||
|
||||
assert _validate_serve_cmd(cmd) == cmd
|
||||
|
||||
|
||||
def test_normalize_llama_cpp_python_cache_types_for_stale_client_cmd():
|
||||
cmd = (
|
||||
"python -m llama_cpp.server --model model.gguf --host 0.0.0.0 --port 8000 "
|
||||
|
||||
@@ -54,13 +54,3 @@ def test_styled_dialogs_manage_focus():
|
||||
assert _UI.count("_prevFocus && _prevFocus.focus && _prevFocus.focus()") == 2
|
||||
assert _UI.count("e.key === 'Tab'") == 2
|
||||
|
||||
|
||||
def test_toast_has_dismiss_button():
|
||||
"""Both showToast and showError must include a close button with aria-label."""
|
||||
# Read fresh every time so edits to ui.js are picked up
|
||||
ui = (_REPO / "static" / "js" / "ui.js").read_text(encoding="utf-8")
|
||||
assert "toast-close-btn" in ui
|
||||
assert "aria-label" in ui
|
||||
assert "Dismiss" in ui
|
||||
assert ui.count("toast-close-btn") >= 2
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
"""User-supplied IMAP/SMTP ports must not crash the email-account endpoints.
|
||||
|
||||
A non-numeric port (for example ``"imap"`` or ``"993x"``) previously reached an
|
||||
unguarded ``int(...)`` in create / update / test-config and raised ``ValueError``,
|
||||
which surfaces as an HTTP 500. The endpoints should reject it with their standard
|
||||
``{"ok": False, "error": ...}`` response instead.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _route_endpoint(router, path: str, method: str):
|
||||
method = method.upper()
|
||||
for route in router.routes:
|
||||
if route.path == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError(f"route not found: {method} {path}")
|
||||
|
||||
|
||||
def test_coerce_port_accepts_int_and_numeric_string():
|
||||
import routes.email_routes as email_routes
|
||||
assert email_routes._coerce_port(2525, 993) == (2525, None)
|
||||
assert email_routes._coerce_port("465", 993) == (465, None)
|
||||
|
||||
|
||||
def test_coerce_port_blank_uses_default():
|
||||
import routes.email_routes as email_routes
|
||||
assert email_routes._coerce_port(None, 993) == (993, None)
|
||||
assert email_routes._coerce_port("", 465) == (465, None)
|
||||
|
||||
|
||||
def test_coerce_port_rejects_non_numeric():
|
||||
import routes.email_routes as email_routes
|
||||
port, err = email_routes._coerce_port("imap", 993)
|
||||
assert port is None
|
||||
assert err and "port" in err.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_account_rejects_non_numeric_port():
|
||||
"""A bad port is rejected before any DB work, with the endpoint's error shape."""
|
||||
import routes.email_routes as email_routes
|
||||
router = email_routes.setup_email_routes()
|
||||
create = _route_endpoint(router, "/api/email/accounts", "POST")
|
||||
result = await create(
|
||||
{
|
||||
"name": "Test",
|
||||
"imap_host": "mail.example.com",
|
||||
"imap_user": "u",
|
||||
"imap_password": "p",
|
||||
"imap_port": "not-a-number",
|
||||
},
|
||||
owner="alice",
|
||||
)
|
||||
assert result["ok"] is False
|
||||
assert "port" in result["error"].lower()
|
||||
@@ -13,7 +13,7 @@ in test_embedding_lanes.py, but the preserved embeddings come back as ndarray.
|
||||
import numpy as np
|
||||
|
||||
from src.embedding_lanes import build_embedding_lanes
|
||||
from tests.helpers.embedding_lanes import FakeChroma, FakeEmbedder, patch_chroma
|
||||
from tests.test_embedding_lanes import FakeChroma, FakeEmbedder, _patch_chroma
|
||||
|
||||
|
||||
def test_lane_reset_restores_when_chroma_returns_numpy_embeddings(monkeypatch):
|
||||
@@ -46,7 +46,7 @@ def test_lane_reset_restores_when_chroma_returns_numpy_embeddings(monkeypatch):
|
||||
|
||||
# Force the post-reset rewrite to fail so the restore branch runs.
|
||||
fake.fail_next_add_for["odysseus_memories_custom"] = 1
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
|
||||
+822
-13
@@ -1,21 +1,139 @@
|
||||
import pytest
|
||||
|
||||
from src.embedding_lanes import (
|
||||
EmbeddingLane,
|
||||
LANE_CUSTOM,
|
||||
LANE_FASTEMBED,
|
||||
build_embedding_lanes,
|
||||
)
|
||||
from tests.helpers.embedding_lanes import (
|
||||
FakeChroma,
|
||||
FakeEmbedder,
|
||||
FailingEmbedder,
|
||||
patch_chroma,
|
||||
)
|
||||
|
||||
|
||||
class FakeEmbedder:
|
||||
def __init__(self, dim, model, url):
|
||||
self.dim = dim
|
||||
self.model = model
|
||||
self.url = url
|
||||
|
||||
def get_sentence_embedding_dimension(self):
|
||||
return self.dim
|
||||
|
||||
def encode(self, texts, normalize_embeddings=True):
|
||||
return [[float(i + 1)] * self.dim for i, _ in enumerate(texts)]
|
||||
|
||||
|
||||
class FailingEmbedder(FakeEmbedder):
|
||||
def encode(self, texts, normalize_embeddings=True):
|
||||
raise RuntimeError("embedding endpoint rate limited")
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self, name, metadata=None):
|
||||
self.name = name
|
||||
self.metadata = metadata or {}
|
||||
self.rows = {}
|
||||
self.dim = None
|
||||
|
||||
def count(self):
|
||||
return len(self.rows)
|
||||
|
||||
def add(self, ids, embeddings, documents=None, metadatas=None):
|
||||
self._check_dim(embeddings)
|
||||
documents = documents or [None] * len(ids)
|
||||
metadatas = metadatas or [{}] * len(ids)
|
||||
for row_id, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
|
||||
self.rows[row_id] = {"embedding": emb, "document": doc, "metadata": meta}
|
||||
|
||||
def upsert(self, ids, embeddings, documents=None, metadatas=None):
|
||||
self.add(ids, embeddings, documents=documents, metadatas=metadatas)
|
||||
|
||||
def get(self, ids=None, include=None, where=None, limit=None):
|
||||
selected = list(self.rows.items())
|
||||
if ids is not None:
|
||||
id_set = set(ids)
|
||||
selected = [(row_id, row) for row_id, row in selected if row_id in id_set]
|
||||
if where:
|
||||
selected = [
|
||||
(row_id, row)
|
||||
for row_id, row in selected
|
||||
if all(row["metadata"].get(k) == v for k, v in where.items())
|
||||
]
|
||||
if limit is not None:
|
||||
selected = selected[:limit]
|
||||
return {
|
||||
"ids": [row_id for row_id, _ in selected],
|
||||
"documents": [row["document"] for _, row in selected],
|
||||
"metadatas": [row["metadata"] for _, row in selected],
|
||||
"embeddings": [row["embedding"] for _, row in selected],
|
||||
}
|
||||
|
||||
def query(self, query_embeddings, n_results, where=None, include=None):
|
||||
self._check_dim(query_embeddings)
|
||||
rows = self.get(where=where)
|
||||
ids = rows["ids"][:n_results]
|
||||
docs = rows["documents"][:n_results]
|
||||
metas = rows["metadatas"][:n_results]
|
||||
return {
|
||||
"ids": [ids],
|
||||
"documents": [docs],
|
||||
"metadatas": [metas],
|
||||
"distances": [[0.1 + i * 0.01 for i in range(len(ids))]],
|
||||
}
|
||||
|
||||
def delete(self, ids):
|
||||
for row_id in ids:
|
||||
self.rows.pop(row_id, None)
|
||||
|
||||
def _check_dim(self, embeddings):
|
||||
if not embeddings:
|
||||
return
|
||||
dim = len(embeddings[0])
|
||||
if self.dim is None:
|
||||
self.dim = dim
|
||||
elif self.dim != dim:
|
||||
raise RuntimeError(f"Collection expecting embedding with dimension of {self.dim}, got {dim}")
|
||||
|
||||
|
||||
class FakeChroma:
|
||||
def __init__(self):
|
||||
self.collections = {}
|
||||
self.deleted = []
|
||||
self.fail_next_add_for = {}
|
||||
|
||||
def get_or_create_collection(self, name, metadata=None):
|
||||
if name not in self.collections:
|
||||
self.collections[name] = FakeCollection(name, metadata=metadata)
|
||||
if self.fail_next_add_for.get(name, 0) > 0:
|
||||
original_add = self.collections[name].add
|
||||
|
||||
def fail_once(*args, **kwargs):
|
||||
self.fail_next_add_for[name] -= 1
|
||||
self.collections[name].add = original_add
|
||||
raise RuntimeError("chroma write failed")
|
||||
|
||||
self.collections[name].add = fail_once
|
||||
elif metadata is not None:
|
||||
self.collections[name].metadata = metadata
|
||||
return self.collections[name]
|
||||
|
||||
def get_collection(self, name):
|
||||
if name not in self.collections:
|
||||
raise KeyError(name)
|
||||
return self.collections[name]
|
||||
|
||||
def delete_collection(self, name):
|
||||
self.deleted.append(name)
|
||||
self.collections.pop(name, None)
|
||||
|
||||
|
||||
def _patch_chroma(monkeypatch, fake):
|
||||
import src.chroma_client as chroma_client
|
||||
|
||||
monkeypatch.setattr(chroma_client, "get_chroma_client", lambda: fake)
|
||||
|
||||
|
||||
def test_build_embedding_lanes_keeps_custom_and_fastembed_dimensions_separate(monkeypatch):
|
||||
fake = FakeChroma()
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -64,7 +182,7 @@ def test_build_embedding_lanes_recreates_only_custom_when_fingerprint_changes(mo
|
||||
},
|
||||
)
|
||||
fast.add(ids=["fast"], embeddings=[[0.0] * 384], documents=["fast"])
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -96,7 +214,7 @@ def test_lane_reset_reembeds_existing_documents_on_fingerprint_change(monkeypatc
|
||||
documents=["existing custom memory"],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -133,7 +251,7 @@ def test_lane_reset_keeps_existing_collection_when_reembed_fails(monkeypatch):
|
||||
documents=["existing custom memory"],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -169,7 +287,7 @@ def test_lane_reset_keeps_existing_collection_when_preserve_read_fails(monkeypat
|
||||
raise RuntimeError("chroma read failed")
|
||||
|
||||
old_custom.get = fail_get
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -204,7 +322,7 @@ def test_lane_reset_restores_existing_collection_when_rewrite_fails(monkeypatch)
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
fake.fail_next_add_for["odysseus_memories_custom"] = 1
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -226,7 +344,7 @@ def test_lane_reset_restores_existing_collection_when_rewrite_fails(monkeypatch)
|
||||
|
||||
def test_build_embedding_lanes_uses_fastembed_when_custom_unavailable(monkeypatch):
|
||||
fake = FakeChroma()
|
||||
patch_chroma(monkeypatch, fake)
|
||||
_patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
@@ -293,3 +411,694 @@ def test_custom_lane_uses_http_down_latch(monkeypatch):
|
||||
|
||||
assert calls == [{"url": None, "model": None, "api_key": None}]
|
||||
embeddings.reset_http_embed_state()
|
||||
|
||||
|
||||
def test_memory_vector_store_writes_both_lanes_and_prefers_custom(monkeypatch):
    """One added memory lands in both lane collections; the top hit is tagged custom."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_custom():
        return FakeEmbedder(768, "nomic", "http://embeddings/v1")

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    store.add("mem-1", "Nicholai likes direct memory systems")

    for collection_name in ("odysseus_memories_custom", "odysseus_memories_fastembed"):
        assert chroma.collections[collection_name].count() == 1

    best_hit = store.search("direct memory", k=5)[0]
    assert best_hit["memory_id"] == "mem-1"
    assert best_hit["embedding_lane"] == LANE_CUSTOM
|
||||
|
||||
|
||||
def test_memory_search_merges_fallback_only_results_before_limit():
    """With ``k=2`` the closer fastembed-only row is returned ahead of the custom rows."""
    custom_name = "odysseus_memories_custom"
    fast_name = "odysseus_memories_fastembed"

    custom_collection = FakeCollection(custom_name, metadata={"embedding_lane": "custom"})
    fast_collection = FakeCollection(fast_name, metadata={"embedding_lane": "fastembed"})
    custom_collection.add(
        ids=["old-1", "old-2"],
        embeddings=[[0.0] * 768 for _ in range(2)],
        documents=["older custom memory", "another custom memory"],
        metadatas=[{"source": "memory"} for _ in range(2)],
    )
    fast_collection.add(
        ids=["fallback-only"],
        embeddings=[[0.0] * 384],
        documents=["fallback only relevant memory"],
        metadatas=[{"source": "memory"}],
    )

    # Canned query answers: the fastembed row is the closest match overall.
    def custom_query(**_kwargs):
        return {
            "ids": [["old-1", "old-2"]],
            "distances": [[0.20, 0.21]],
        }

    def fast_query(**_kwargs):
        return {
            "ids": [["fallback-only"]],
            "distances": [[0.05]],
        }

    custom_collection.query = custom_query
    fast_collection.query = fast_query

    lane_pair = [
        EmbeddingLane(
            name=LANE_CUSTOM,
            client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
            collection=custom_collection,
            collection_name=custom_name,
            model="nomic",
            url="http://embeddings/v1",
            dimension=768,
            fingerprint="custom",
        ),
        EmbeddingLane(
            name=LANE_FASTEMBED,
            client=FakeEmbedder(384, "mini", "local://fastembed"),
            collection=fast_collection,
            collection_name=fast_name,
            model="mini",
            url="local://fastembed",
            dimension=384,
            fingerprint="fast",
        ),
    ]

    from src.memory_vector import MemoryVectorStore

    # Skip __init__ so the store uses exactly the lanes built above.
    store = MemoryVectorStore.__new__(MemoryVectorStore)
    store._lanes = lane_pair
    store._healthy = True

    hits = store.search("fallback relevant", k=2)
    ranked_ids = [hit["memory_id"] for hit in hits]

    assert ranked_ids == ["fallback-only", "old-1"]
|
||||
|
||||
|
||||
def test_vector_rag_writes_both_lanes_and_falls_back_to_fastembed(monkeypatch):
    """Without a custom client only the fastembed collection is created and searched."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def no_custom_client():
        return None

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG()
    added = rag.add_document("session search belongs in tools", {"source": "/tmp/a.md", "owner": "alice"})
    assert added
    assert "odysseus_rag_custom" not in chroma.collections
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1

    first_hit = rag.search("session search", k=3, owner="alice")[0]
    assert first_hit["document"] == "session search belongs in tools"
    assert first_hit["embedding_lane"] == LANE_FASTEMBED
|
||||
|
||||
|
||||
def test_vector_rag_batch_index_continues_when_custom_lane_fails(monkeypatch, tmp_path):
    """A failing custom embedder must not stop the batch from reaching fastembed."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_failing_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_failing_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    batch = [
        ("batch fallback document", {"source": "/tmp/a.md", "owner": "alice"}),
    ]
    outcome = rag.add_documents_batch(batch)

    assert outcome["success"]
    assert outcome["added_count"] == 1
    assert chroma.collections["odysseus_rag_custom"].count() == 0
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1
|
||||
|
||||
|
||||
def test_vector_rag_batch_index_reports_failure_when_all_lanes_fail(monkeypatch, tmp_path):
    """When both embedders fail the batch reports failure and nothing is stored."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_failing_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def make_failing_fastembed():
        return FailingEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_failing_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_failing_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    batch = [
        ("batch outage document", {"source": "/tmp/a.md", "owner": "alice"}),
    ]
    outcome = rag.add_documents_batch(batch)

    assert not outcome["success"]
    for collection_name in ("odysseus_rag_custom", "odysseus_rag_fastembed"):
        assert chroma.collections[collection_name].count() == 0
|
||||
|
||||
|
||||
def test_tool_index_indexes_and_retrieves_from_available_lanes(monkeypatch):
    """Built-in tools are indexed into both lanes and ``bash`` is retrievable."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_custom():
        return FakeEmbedder(768, "nomic", "http://embeddings/v1")

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.tool_index import ToolIndex

    index = ToolIndex()
    index.index_builtin_tools()

    for collection_name in ("odysseus_tool_index_custom", "odysseus_tool_index_fastembed"):
        assert chroma.collections[collection_name].count() > 0

    retrieved = index.retrieve("run a shell command", k=10)
    assert "bash" in retrieved
|
||||
|
||||
|
||||
def test_tool_index_builtin_indexing_fails_when_all_lanes_fail():
    """Indexing built-ins with two failing embedders raises and marks the index unhealthy."""
    failing_lanes = [
        EmbeddingLane(
            name=LANE_CUSTOM,
            client=FailingEmbedder(768, "nomic", "http://embeddings/v1"),
            collection=FakeCollection("odysseus_tool_index_custom", metadata={"embedding_lane": "custom"}),
            collection_name="odysseus_tool_index_custom",
            model="nomic",
            url="http://embeddings/v1",
            dimension=768,
            fingerprint="custom",
        ),
        EmbeddingLane(
            name=LANE_FASTEMBED,
            client=FailingEmbedder(384, "mini", "local://fastembed"),
            collection=FakeCollection("odysseus_tool_index_fastembed", metadata={"embedding_lane": "fastembed"}),
            collection_name="odysseus_tool_index_fastembed",
            model="mini",
            url="local://fastembed",
            dimension=384,
            fingerprint="fast",
        ),
    ]

    from src.tool_index import ToolIndex

    # Skip __init__ so the index uses exactly the failing lanes built above.
    index = ToolIndex.__new__(ToolIndex)
    index._lanes = failing_lanes
    index._healthy = True

    with pytest.raises(RuntimeError, match="all embedding lanes"):
        index.index_builtin_tools()

    assert not index.healthy
|
||||
|
||||
|
||||
def test_tool_index_retrieval_continues_when_custom_lane_query_fails():
    """A raising custom-lane query still lets retrieval return the fastembed tool."""
    custom_name = "odysseus_tool_index_custom"
    fast_name = "odysseus_tool_index_fastembed"

    custom_collection = FakeCollection(custom_name, metadata={"embedding_lane": "custom"})
    fast_collection = FakeCollection(fast_name, metadata={"embedding_lane": "fastembed"})
    fast_collection.add(
        ids=["builtin_bash"],
        embeddings=[[0.0] * 384],
        documents=["Tool: bash\nRun shell commands"],
        metadatas=[{"tool_name": "bash", "tool_type": "builtin"}],
    )

    def broken_query(*_args, **_kwargs):
        raise RuntimeError("custom endpoint down")

    # Seed the custom lane first, then break only its query path.
    custom_collection.add(
        ids=["builtin_python"],
        embeddings=[[0.0] * 768],
        documents=["Tool: python\nRun Python"],
        metadatas=[{"tool_name": "python", "tool_type": "builtin"}],
    )
    custom_collection.query = broken_query

    lane_pair = [
        EmbeddingLane(
            name=LANE_CUSTOM,
            client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
            collection=custom_collection,
            collection_name=custom_name,
            model="nomic",
            url="http://embeddings/v1",
            dimension=768,
            fingerprint="custom",
        ),
        EmbeddingLane(
            name=LANE_FASTEMBED,
            client=FakeEmbedder(384, "mini", "local://fastembed"),
            collection=fast_collection,
            collection_name=fast_name,
            model="mini",
            url="local://fastembed",
            dimension=384,
            fingerprint="fast",
        ),
    ]

    from src.tool_index import ToolIndex

    index = ToolIndex.__new__(ToolIndex)
    index._lanes = lane_pair

    tool_names = index.retrieve("run shell", k=5)
    assert tool_names == ["bash"]
|
||||
|
||||
|
||||
def test_tool_index_merges_fallback_tool_results_before_limit():
    """With ``k=2`` the closer fastembed-only tool is returned ahead of the custom tools."""
    custom_name = "odysseus_tool_index_custom"
    fast_name = "odysseus_tool_index_fastembed"

    custom_collection = FakeCollection(custom_name, metadata={"embedding_lane": "custom"})
    fast_collection = FakeCollection(fast_name, metadata={"embedding_lane": "fastembed"})
    custom_collection.add(
        ids=["builtin_one", "builtin_two"],
        embeddings=[[0.0] * 768 for _ in range(2)],
        documents=["Tool: one", "Tool: two"],
        metadatas=[
            {"tool_name": "one", "tool_type": "builtin"},
            {"tool_name": "two", "tool_type": "builtin"},
        ],
    )
    fast_collection.add(
        ids=["mcp_current"],
        embeddings=[[0.0] * 384],
        documents=["Tool: current MCP"],
        metadatas=[{"tool_name": "current_mcp", "tool_type": "mcp"}],
    )

    # Canned query answers: the MCP tool in the fastembed lane is the closest.
    def custom_query(**_kwargs):
        return {
            "ids": [["builtin_one", "builtin_two"]],
            "metadatas": [[
                {"tool_name": "one", "tool_type": "builtin"},
                {"tool_name": "two", "tool_type": "builtin"},
            ]],
            "distances": [[0.20, 0.21]],
        }

    def fast_query(**_kwargs):
        return {
            "ids": [["mcp_current"]],
            "metadatas": [[{"tool_name": "current_mcp", "tool_type": "mcp"}]],
            "distances": [[0.05]],
        }

    custom_collection.query = custom_query
    fast_collection.query = fast_query

    lane_pair = [
        EmbeddingLane(
            name=LANE_CUSTOM,
            client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
            collection=custom_collection,
            collection_name=custom_name,
            model="nomic",
            url="http://embeddings/v1",
            dimension=768,
            fingerprint="custom",
        ),
        EmbeddingLane(
            name=LANE_FASTEMBED,
            client=FakeEmbedder(384, "mini", "local://fastembed"),
            collection=fast_collection,
            collection_name=fast_name,
            model="mini",
            url="local://fastembed",
            dimension=384,
            fingerprint="fast",
        ),
    ]

    from src.tool_index import ToolIndex

    index = ToolIndex.__new__(ToolIndex)
    index._lanes = lane_pair

    tool_names = index.retrieve("current mcp", k=2)
    assert tool_names == ["current_mcp", "one"]
|
||||
|
||||
|
||||
def test_legacy_collection_backfills_fastembed_lane(monkeypatch):
    """A legacy row is copied into the fastembed lane while the legacy collection stays."""
    chroma = FakeChroma()
    legacy_collection = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    legacy_row = {
        "ids": ["legacy-memory"],
        "embeddings": [[0.0] * 384],
        "documents": ["legacy memory row"],
        "metadatas": [{"source": "memory"}],
    }
    legacy_collection.add(**legacy_row)
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def no_custom_client():
        return None

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")

    assert store.count() == 1
    for collection_name in ("odysseus_memories", "odysseus_memories_fastembed"):
        assert chroma.collections[collection_name].count() == 1
|
||||
|
||||
|
||||
def test_legacy_collection_backfills_custom_only_lane(monkeypatch):
    """With fastembed unavailable the legacy row is backfilled as a 768-wide custom row."""
    chroma = FakeChroma()
    legacy_collection = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    legacy_row = {
        "ids": ["legacy-memory"],
        "embeddings": [[0.0] * 384],
        "documents": ["legacy memory row"],
        "metadatas": [{"source": "memory"}],
    }
    legacy_collection.add(**legacy_row)
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_custom():
        return FakeEmbedder(768, "nomic", "http://embeddings/v1")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_custom)

    def raise_missing_fastembed():
        raise RuntimeError("fastembed missing")

    monkeypatch.setattr(lane_module, "_build_fastembed_client", raise_missing_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")

    assert store.count() == 1
    assert "odysseus_memories_fastembed" not in chroma.collections

    custom_collection = chroma.collections["odysseus_memories_custom"]
    assert custom_collection.count() == 1
    # Re-fetch through the registry, as the width check targets the stored row.
    stored_vector = chroma.collections["odysseus_memories_custom"].rows["legacy-memory"]["embedding"]
    assert len(stored_vector) == 768
|
||||
|
||||
|
||||
def test_legacy_migration_continues_when_custom_backfill_fails(monkeypatch):
    """A failing custom embedder leaves the store healthy with only fastembed backfilled."""
    chroma = FakeChroma()
    legacy_collection = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    legacy_row = {
        "ids": ["legacy-memory"],
        "embeddings": [[0.0] * 384],
        "documents": ["legacy memory row"],
        "metadatas": [{"source": "memory"}],
    }
    legacy_collection.add(**legacy_row)
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_failing_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_failing_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")

    assert store.healthy
    expected_counts = (
        ("odysseus_memories_custom", 0),
        ("odysseus_memories_fastembed", 1),
    )
    for collection_name, expected in expected_counts:
        assert chroma.collections[collection_name].count() == expected
|
||||
|
||||
|
||||
def test_legacy_migration_resumes_partial_lane_backfill(monkeypatch):
    """A lane that already holds one legacy row ends up holding both after startup."""
    chroma = FakeChroma()

    legacy_collection = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    legacy_collection.add(
        ids=["legacy-1", "legacy-2"],
        embeddings=[[0.0] * 384 for _ in range(2)],
        documents=["legacy memory one", "legacy memory two"],
        metadatas=[{"source": "memory"} for _ in range(2)],
    )

    # Simulate an interrupted earlier migration: only the first row made it.
    half_filled = chroma.get_or_create_collection("odysseus_memories_fastembed", metadata={"embedding_lane": "fastembed"})
    half_filled.add(
        ids=["legacy-1"],
        embeddings=[[0.0] * 384],
        documents=["legacy memory one"],
        metadatas=[{"source": "memory"}],
    )
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def no_custom_client():
        return None

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")

    assert store.count() == 2
    migrated_ids = set(chroma.collections["odysseus_memories_fastembed"].get()["ids"])
    assert migrated_ids == {"legacy-1", "legacy-2"}
|
||||
|
||||
|
||||
def test_memory_rebuild_does_not_reimport_legacy_collection(monkeypatch):
    """After ``rebuild`` only the rebuilt memory remains; legacy and inactive lanes are gone."""
    chroma = FakeChroma()

    # Stale data seeded before the store starts: (name, metadata, id, width, document).
    stale_seeds = (
        ("odysseus_memories", {"hnsw:space": "cosine"}, "stale-memory", 384, "stale legacy memory"),
        ("odysseus_memories_custom", {"embedding_lane": "custom"}, "stale-custom", 768, "stale inactive custom memory"),
    )
    for collection_name, collection_meta, row_id, width, text in stale_seeds:
        seeded = chroma.get_or_create_collection(collection_name, metadata=collection_meta)
        seeded.add(
            ids=[row_id],
            embeddings=[[0.0] * width],
            documents=[text],
            metadatas=[{"source": "memory"}],
        )
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def no_custom_client():
        return None

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    assert chroma.collections["odysseus_memories_fastembed"].count() == 1

    store.rebuild([{"id": "current-memory", "text": "current rebuilt memory"}])

    for removed_name in ("odysseus_memories", "odysseus_memories_custom"):
        assert removed_name not in chroma.collections
    assert chroma.collections["odysseus_memories_fastembed"].count() == 1
    assert chroma.collections["odysseus_memories_fastembed"].get()["ids"] == ["current-memory"]
|
||||
|
||||
|
||||
def test_memory_remove_deletes_inactive_lane_collection(monkeypatch):
    """Removing a memory also clears it from a lane collection the store has no lane for."""
    chroma = FakeChroma()
    custom_collection = chroma.get_or_create_collection("odysseus_memories_custom", metadata={"embedding_lane": "custom"})
    fast_collection = chroma.get_or_create_collection("odysseus_memories_fastembed", metadata={"embedding_lane": "fastembed"})

    # The same memory id lives in both collections.
    seeded_rows = (
        (custom_collection, 768, "custom stale memory"),
        (fast_collection, 384, "fast memory"),
    )
    for collection, width, text in seeded_rows:
        collection.add(
            ids=["mem-1"],
            embeddings=[[0.0] * width],
            documents=[text],
            metadatas=[{"source": "memory"}],
        )
    _patch_chroma(monkeypatch, chroma)

    only_lane = EmbeddingLane(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_collection,
        collection_name="odysseus_memories_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )

    from src.memory_vector import MemoryVectorStore

    # Skip __init__: the store knows about the fastembed lane only.
    store = MemoryVectorStore.__new__(MemoryVectorStore)
    store._lanes = [only_lane]
    store._healthy = True

    store.remove("mem-1")

    assert custom_collection.count() == 0
    assert fast_collection.count() == 0
|
||||
|
||||
|
||||
def test_memory_rebuild_continues_when_custom_lane_fails(monkeypatch):
    """``rebuild`` with a failing custom embedder still fills the fastembed lane."""
    chroma = FakeChroma()
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def make_failing_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", make_failing_custom)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    current_memories = [{"id": "current-memory", "text": "current rebuilt memory"}]
    store.rebuild(current_memories)

    assert chroma.collections["odysseus_memories_custom"].count() == 0
    assert chroma.collections["odysseus_memories_fastembed"].count() == 1
    assert chroma.collections["odysseus_memories_fastembed"].get()["ids"] == ["current-memory"]
|
||||
|
||||
|
||||
def test_rag_rebuild_does_not_reimport_legacy_collection(monkeypatch, tmp_path):
    """After ``rebuild_index`` the legacy and inactive collections are gone and nothing is searchable."""
    chroma = FakeChroma()

    # Stale data seeded before the index starts: (name, metadata, id, width, document).
    stale_seeds = (
        ("odysseus_rag", {"hnsw:space": "cosine"}, "stale-doc", 384, "stale legacy document"),
        ("odysseus_rag_custom", {"embedding_lane": "custom"}, "stale-custom-doc", 768, "stale inactive custom document"),
    )
    for collection_name, collection_meta, row_id, width, text in stale_seeds:
        seeded = chroma.get_or_create_collection(collection_name, metadata=collection_meta)
        seeded.add(
            ids=[row_id],
            embeddings=[[0.0] * width],
            documents=[text],
            metadatas=[{"source": "/tmp/stale.md"}],
        )
    _patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lane_module

    def no_custom_client():
        return None

    def make_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lane_module, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lane_module, "_build_fastembed_client", make_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1

    rebuilt = rag.rebuild_index()
    assert rebuilt

    for removed_name in ("odysseus_rag", "odysseus_rag_custom"):
        assert removed_name not in chroma.collections
    assert chroma.collections["odysseus_rag_fastembed"].count() == 0
    assert rag.search("stale legacy", k=3) == []
|
||||
|
||||
|
||||
def test_rag_remove_directory_deletes_inactive_lane_collection(monkeypatch, tmp_path):
    """``remove_directory`` clears matching rows from legacy, inactive and active collections."""
    chroma = FakeChroma()
    legacy_collection = chroma.get_or_create_collection("odysseus_rag", metadata={"hnsw:space": "cosine"})
    custom_collection = chroma.get_or_create_collection("odysseus_rag_custom", metadata={"embedding_lane": "custom"})
    fast_collection = chroma.get_or_create_collection("odysseus_rag_fastembed", metadata={"embedding_lane": "fastembed"})

    docs_dir = tmp_path / "docs"
    source = str(docs_dir / "note.md")
    directory = str(docs_dir)

    # One row per collection, all pointing at the same source file.
    seeded_rows = (
        (legacy_collection, "legacy-doc", 384, "legacy stale doc"),
        (custom_collection, "custom-doc", 768, "custom stale doc"),
        (fast_collection, "fast-doc", 384, "fast current doc"),
    )
    for collection, row_id, width, text in seeded_rows:
        collection.add(
            ids=[row_id],
            embeddings=[[0.0] * width],
            documents=[text],
            metadatas=[{"source": source}],
        )
    _patch_chroma(monkeypatch, chroma)

    only_lane = EmbeddingLane(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_collection,
        collection_name="odysseus_rag_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )

    from src.rag_vector import VectorRAG

    # Skip __init__: the index knows about the fastembed lane only.
    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = fast_collection
    rag._healthy = True

    outcome = rag.remove_directory(directory)

    assert outcome["success"]
    assert outcome["removed_count"] == 3
    for collection in (legacy_collection, custom_collection, fast_collection):
        assert collection.count() == 0
|
||||
|
||||
|
||||
def test_rag_delete_by_source_deletes_inactive_lane_collection(monkeypatch, tmp_path):
    """``delete_by_source`` empties all three collections and reports two deletions."""
    chroma = FakeChroma()
    legacy_collection = chroma.get_or_create_collection("odysseus_rag", metadata={"hnsw:space": "cosine"})
    custom_collection = chroma.get_or_create_collection("odysseus_rag_custom", metadata={"embedding_lane": "custom"})
    fast_collection = chroma.get_or_create_collection("odysseus_rag_fastembed", metadata={"embedding_lane": "fastembed"})

    source = str(tmp_path / "docs" / "note.md")

    # The custom and fastembed rows share one id; the legacy row has its own.
    seeded_rows = (
        (legacy_collection, "legacy-doc", 384, "legacy stale doc"),
        (custom_collection, "shared-doc", 768, "custom stale doc"),
        (fast_collection, "shared-doc", 384, "fast current doc"),
    )
    for collection, row_id, width, text in seeded_rows:
        collection.add(
            ids=[row_id],
            embeddings=[[0.0] * width],
            documents=[text],
            metadatas=[{"source": source}],
        )
    _patch_chroma(monkeypatch, chroma)

    only_lane = EmbeddingLane(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_collection,
        collection_name="odysseus_rag_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )

    from src.rag_vector import VectorRAG

    # Skip __init__: the index knows about the fastembed lane only.
    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = fast_collection
    rag._healthy = True

    deleted_count = rag.delete_by_source(source)

    assert deleted_count == 2
    for collection in (legacy_collection, custom_collection, fast_collection):
        assert collection.count() == 0
|
||||
|
||||
|
||||
def test_vector_rag_uses_keyword_fallback_when_all_lanes_query_fail():
    """When the only lane's query raises, search returns the row tagged ``keyword_fallback``."""
    collection_name = "odysseus_rag_fastembed"
    collection = FakeCollection(collection_name, metadata={"embedding_lane": "fastembed"})
    collection.add(
        ids=["doc-1"],
        embeddings=[[0.0] * 384],
        documents=["fallback keyword document"],
        metadatas=[{"source": "/tmp/doc.md"}],
    )

    def broken_query(*_args, **_kwargs):
        raise RuntimeError("embedding query down")

    collection.query = broken_query

    only_lane = EmbeddingLane(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=collection,
        collection_name=collection_name,
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fp",
    )

    from src.rag_vector import VectorRAG

    # Skip __init__ so the index uses exactly the broken lane built above.
    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = collection
    rag._healthy = True

    first_hit = rag.search("fallback keyword", k=3)[0]

    assert first_hit["id"] == "doc-1"
    assert first_hit["search_type"] == "keyword_fallback"
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
from tests.helpers.embedding_lanes import (
|
||||
FakeChroma,
|
||||
FakeEmbedder,
|
||||
FailingEmbedder,
|
||||
patch_chroma,
|
||||
)
|
||||
|
||||
|
||||
def test_legacy_collection_backfills_fastembed_lane(monkeypatch):
|
||||
fake = FakeChroma()
|
||||
legacy = fake.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
|
||||
legacy.add(
|
||||
ids=["legacy-memory"],
|
||||
embeddings=[[0.0] * 384],
|
||||
documents=["legacy memory row"],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
monkeypatch.setattr(lanes, "_build_custom_client", lambda: None)
|
||||
monkeypatch.setattr(lanes, "_build_fastembed_client", lambda: FakeEmbedder(384, "mini", "local://fastembed"))
|
||||
|
||||
from src.memory_vector import MemoryVectorStore
|
||||
|
||||
store = MemoryVectorStore("data")
|
||||
|
||||
assert store.count() == 1
|
||||
assert fake.collections["odysseus_memories"].count() == 1
|
||||
assert fake.collections["odysseus_memories_fastembed"].count() == 1
|
||||
|
||||
|
||||
def test_legacy_collection_backfills_custom_only_lane(monkeypatch):
|
||||
fake = FakeChroma()
|
||||
legacy = fake.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
|
||||
legacy.add(
|
||||
ids=["legacy-memory"],
|
||||
embeddings=[[0.0] * 384],
|
||||
documents=["legacy memory row"],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
monkeypatch.setattr(lanes, "_build_custom_client", lambda: FakeEmbedder(768, "nomic", "http://embeddings/v1"))
|
||||
|
||||
def fail_fastembed():
|
||||
raise RuntimeError("fastembed missing")
|
||||
|
||||
monkeypatch.setattr(lanes, "_build_fastembed_client", fail_fastembed)
|
||||
|
||||
from src.memory_vector import MemoryVectorStore
|
||||
|
||||
store = MemoryVectorStore("data")
|
||||
|
||||
assert store.count() == 1
|
||||
assert "odysseus_memories_fastembed" not in fake.collections
|
||||
assert fake.collections["odysseus_memories_custom"].count() == 1
|
||||
assert len(fake.collections["odysseus_memories_custom"].rows["legacy-memory"]["embedding"]) == 768
|
||||
|
||||
|
||||
def test_legacy_migration_continues_when_custom_backfill_fails(monkeypatch):
|
||||
fake = FakeChroma()
|
||||
legacy = fake.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
|
||||
legacy.add(
|
||||
ids=["legacy-memory"],
|
||||
embeddings=[[0.0] * 384],
|
||||
documents=["legacy memory row"],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
patch_chroma(monkeypatch, fake)
|
||||
|
||||
import src.embedding_lanes as lanes
|
||||
|
||||
monkeypatch.setattr(lanes, "_build_custom_client", lambda: FailingEmbedder(768, "nomic", "http://embeddings/v1"))
|
||||
monkeypatch.setattr(lanes, "_build_fastembed_client", lambda: FakeEmbedder(384, "mini", "local://fastembed"))
|
||||
|
||||
from src.memory_vector import MemoryVectorStore
|
||||
|
||||
store = MemoryVectorStore("data")
|
||||
|
||||
assert store.healthy
|
||||
assert fake.collections["odysseus_memories_custom"].count() == 0
|
||||
assert fake.collections["odysseus_memories_fastembed"].count() == 1
|
||||
|
||||
|
||||
def test_legacy_migration_resumes_partial_lane_backfill(monkeypatch):
    """A fastembed lane that already holds one of two legacy rows ends up with both.

    The legacy collection has ``legacy-1`` and ``legacy-2``; the fastembed
    collection is pre-seeded with only ``legacy-1``. No custom client is
    available. After construction the store must count two rows and the
    fastembed collection must contain both ids.
    """
    chroma = FakeChroma()

    legacy_rows = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    legacy_rows.add(
        ids=["legacy-1", "legacy-2"],
        embeddings=[[0.0] * 384 for _ in range(2)],
        documents=["legacy memory one", "legacy memory two"],
        metadatas=[{"source": "memory"}, {"source": "memory"}],
    )

    half_done = chroma.get_or_create_collection("odysseus_memories_fastembed", metadata={"embedding_lane": "fastembed"})
    half_done.add(
        ids=["legacy-1"],
        embeddings=[[0.0] * 384],
        documents=["legacy memory one"],
        metadatas=[{"source": "memory"}],
    )
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def no_custom_client():
        return None

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")

    assert store.count() == 2
    migrated_ids = chroma.collections["odysseus_memories_fastembed"].get()["ids"]
    assert set(migrated_ids) == {"legacy-1", "legacy-2"}
|
||||
@@ -1,187 +0,0 @@
|
||||
from src.embedding_lanes import (
|
||||
EmbeddingLane,
|
||||
LANE_CUSTOM,
|
||||
LANE_FASTEMBED,
|
||||
)
|
||||
from tests.helpers.embedding_lanes import (
|
||||
FakeChroma,
|
||||
FakeCollection,
|
||||
FakeEmbedder,
|
||||
FailingEmbedder,
|
||||
patch_chroma,
|
||||
)
|
||||
|
||||
|
||||
def test_memory_vector_store_writes_both_lanes_and_prefers_custom(monkeypatch):
    """``add`` writes one row per lane collection; the top search hit is tagged custom."""
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def build_custom():
        return FakeEmbedder(768, "nomic", "http://embeddings/v1")

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", build_custom)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    store.add("mem-1", "Nicholai likes direct memory systems")

    # Both lanes must have received the write.
    for collection_name in ("odysseus_memories_custom", "odysseus_memories_fastembed"):
        assert chroma.collections[collection_name].count() == 1

    hits = store.search("direct memory", k=5)
    top_hit = hits[0]
    assert top_hit["memory_id"] == "mem-1"
    assert top_hit["embedding_lane"] == LANE_CUSTOM
|
||||
|
||||
|
||||
def test_memory_search_merges_fallback_only_results_before_limit():
    """A closer hit found only in the fastembed lane must survive the ``k`` cut.

    Both lanes' ``query`` methods are stubbed with fixed ids and distances.
    With ``k=2`` the expected order is the fastembed-only row (distance 0.05)
    followed by the best custom row (distance 0.20).
    """
    custom_rows = FakeCollection("odysseus_memories_custom", metadata={"embedding_lane": "custom"})
    fast_rows = FakeCollection("odysseus_memories_fastembed", metadata={"embedding_lane": "fastembed"})

    custom_rows.add(
        ids=["old-1", "old-2"],
        embeddings=[[0.0] * 768 for _ in range(2)],
        documents=["older custom memory", "another custom memory"],
        metadatas=[{"source": "memory"}, {"source": "memory"}],
    )
    fast_rows.add(
        ids=["fallback-only"],
        embeddings=[[0.0] * 384],
        documents=["fallback only relevant memory"],
        metadatas=[{"source": "memory"}],
    )

    def custom_query(**_kwargs):
        return {"ids": [["old-1", "old-2"]], "distances": [[0.20, 0.21]]}

    def fast_query(**_kwargs):
        return {"ids": [["fallback-only"]], "distances": [[0.05]]}

    custom_rows.query = custom_query
    fast_rows.query = fast_query

    custom_spec = dict(
        name=LANE_CUSTOM,
        client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
        collection=custom_rows,
        collection_name="odysseus_memories_custom",
        model="nomic",
        url="http://embeddings/v1",
        dimension=768,
        fingerprint="custom",
    )
    custom_lane = EmbeddingLane(**custom_spec)

    fast_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_rows,
        collection_name="odysseus_memories_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    fast_lane = EmbeddingLane(**fast_spec)

    from src.memory_vector import MemoryVectorStore

    # Bypass __init__ and wire the lanes in by hand.
    store = MemoryVectorStore.__new__(MemoryVectorStore)
    store._lanes = [custom_lane, fast_lane]
    store._healthy = True

    hits = store.search("fallback relevant", k=2)

    returned_ids = [hit["memory_id"] for hit in hits]
    assert returned_ids == ["fallback-only", "old-1"]
|
||||
|
||||
|
||||
def test_memory_rebuild_does_not_reimport_legacy_collection(monkeypatch):
    """``rebuild`` leaves only the rebuilt rows and drops legacy/inactive collections.

    A legacy collection and an inactive custom-lane collection are seeded with
    stale rows; only the fastembed client is available. After ``rebuild`` with a
    single current memory, neither stale collection may remain and the
    fastembed collection must hold exactly ``current-memory``.
    """
    chroma = FakeChroma()

    stale_legacy = chroma.get_or_create_collection("odysseus_memories", metadata={"hnsw:space": "cosine"})
    stale_legacy.add(
        ids=["stale-memory"],
        embeddings=[[0.0] * 384],
        documents=["stale legacy memory"],
        metadatas=[{"source": "memory"}],
    )

    stale_custom = chroma.get_or_create_collection("odysseus_memories_custom", metadata={"embedding_lane": "custom"})
    stale_custom.add(
        ids=["stale-custom"],
        embeddings=[[0.0] * 768],
        documents=["stale inactive custom memory"],
        metadatas=[{"source": "memory"}],
    )
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def no_custom_client():
        return None

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    # Precondition: construction left one row in the fastembed lane.
    assert chroma.collections["odysseus_memories_fastembed"].count() == 1

    store.rebuild([{"id": "current-memory", "text": "current rebuilt memory"}])

    for dropped in ("odysseus_memories", "odysseus_memories_custom"):
        assert dropped not in chroma.collections
    rebuilt = chroma.collections["odysseus_memories_fastembed"]
    assert rebuilt.count() == 1
    assert chroma.collections["odysseus_memories_fastembed"].get()["ids"] == ["current-memory"]
|
||||
|
||||
|
||||
def test_memory_remove_deletes_inactive_lane_collection(monkeypatch):
    """``remove`` clears the id from a lane collection the store is not using.

    Both lane collections hold ``mem-1`` but the store is wired with only the
    fastembed lane. After ``remove("mem-1")`` both collections must be empty.
    """
    chroma = FakeChroma()
    custom_rows = chroma.get_or_create_collection("odysseus_memories_custom", metadata={"embedding_lane": "custom"})
    fast_rows = chroma.get_or_create_collection("odysseus_memories_fastembed", metadata={"embedding_lane": "fastembed"})

    custom_rows.add(
        ids=["mem-1"],
        embeddings=[[0.0] * 768],
        documents=["custom stale memory"],
        metadatas=[{"source": "memory"}],
    )
    fast_rows.add(
        ids=["mem-1"],
        embeddings=[[0.0] * 384],
        documents=["fast memory"],
        metadatas=[{"source": "memory"}],
    )
    patch_chroma(monkeypatch, chroma)

    lane_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_rows,
        collection_name="odysseus_memories_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    only_lane = EmbeddingLane(**lane_spec)

    from src.memory_vector import MemoryVectorStore

    # Bypass __init__ so only the fastembed lane is active.
    store = MemoryVectorStore.__new__(MemoryVectorStore)
    store._lanes = [only_lane]
    store._healthy = True

    store.remove("mem-1")

    assert custom_rows.count() == 0
    assert fast_rows.count() == 0
|
||||
|
||||
|
||||
def test_memory_rebuild_continues_when_custom_lane_fails(monkeypatch):
    """``rebuild`` still fills the fastembed lane when the custom lane uses a FailingEmbedder."""
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def build_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", build_custom)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.memory_vector import MemoryVectorStore

    store = MemoryVectorStore("data")
    store.rebuild([{"id": "current-memory", "text": "current rebuilt memory"}])

    assert chroma.collections["odysseus_memories_custom"].count() == 0
    assert chroma.collections["odysseus_memories_fastembed"].count() == 1
    rebuilt_ids = chroma.collections["odysseus_memories_fastembed"].get()["ids"]
    assert rebuilt_ids == ["current-memory"]
|
||||
@@ -1,252 +0,0 @@
|
||||
from src.embedding_lanes import (
|
||||
EmbeddingLane,
|
||||
LANE_FASTEMBED,
|
||||
)
|
||||
from tests.helpers.embedding_lanes import (
|
||||
FakeChroma,
|
||||
FakeCollection,
|
||||
FakeEmbedder,
|
||||
FailingEmbedder,
|
||||
patch_chroma,
|
||||
)
|
||||
|
||||
|
||||
def test_vector_rag_writes_both_lanes_and_falls_back_to_fastembed(monkeypatch):
    """With no custom client, documents land in the fastembed lane and search is tagged with it.

    The custom-lane collection must never be created in this configuration.
    """
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def no_custom_client():
        return None

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG()
    added = rag.add_document("session search belongs in tools", {"source": "/tmp/a.md", "owner": "alice"})
    assert added
    assert "odysseus_rag_custom" not in chroma.collections
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1

    hits = rag.search("session search", k=3, owner="alice")
    best = hits[0]
    assert best["document"] == "session search belongs in tools"
    assert best["embedding_lane"] == LANE_FASTEMBED
|
||||
|
||||
|
||||
def test_vector_rag_batch_index_continues_when_custom_lane_fails(monkeypatch, tmp_path):
    """Batch indexing reports success when only the custom lane uses a FailingEmbedder.

    The result must report one added document, stored in the fastembed
    collection and absent from the custom collection.
    """
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def build_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", build_custom)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    batch = [
        ("batch fallback document", {"source": "/tmp/a.md", "owner": "alice"}),
    ]
    outcome = rag.add_documents_batch(batch)

    assert outcome["success"]
    assert outcome["added_count"] == 1
    assert chroma.collections["odysseus_rag_custom"].count() == 0
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1
|
||||
|
||||
|
||||
def test_vector_rag_batch_index_reports_failure_when_all_lanes_fail(monkeypatch, tmp_path):
    """Batch indexing reports failure and stores nothing when both lanes use a FailingEmbedder."""
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def build_custom():
        return FailingEmbedder(768, "nomic", "http://embeddings/v1")

    def build_fastembed():
        return FailingEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", build_custom)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    batch = [
        ("batch outage document", {"source": "/tmp/a.md", "owner": "alice"}),
    ]
    outcome = rag.add_documents_batch(batch)

    assert not outcome["success"]
    assert chroma.collections["odysseus_rag_custom"].count() == 0
    assert chroma.collections["odysseus_rag_fastembed"].count() == 0
|
||||
|
||||
|
||||
def test_rag_rebuild_does_not_reimport_legacy_collection(monkeypatch, tmp_path):
    """``rebuild_index`` drops legacy/inactive collections and leaves the active lane empty.

    A legacy ``odysseus_rag`` collection and an inactive custom-lane collection
    are seeded with stale documents; only the fastembed client is available.
    After a successful rebuild neither stale collection may remain, the
    fastembed collection must be empty, and a search must return nothing.
    """
    chroma = FakeChroma()

    stale_legacy = chroma.get_or_create_collection("odysseus_rag", metadata={"hnsw:space": "cosine"})
    stale_legacy.add(
        ids=["stale-doc"],
        embeddings=[[0.0] * 384],
        documents=["stale legacy document"],
        metadatas=[{"source": "/tmp/stale.md"}],
    )

    stale_custom = chroma.get_or_create_collection("odysseus_rag_custom", metadata={"embedding_lane": "custom"})
    stale_custom.add(
        ids=["stale-custom-doc"],
        embeddings=[[0.0] * 768],
        documents=["stale inactive custom document"],
        metadatas=[{"source": "/tmp/stale.md"}],
    )
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def no_custom_client():
        return None

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", no_custom_client)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.rag_vector import VectorRAG

    rag = VectorRAG(persist_directory=str(tmp_path))
    # Precondition: construction left one document in the fastembed lane.
    assert chroma.collections["odysseus_rag_fastembed"].count() == 1

    assert rag.rebuild_index()

    for dropped in ("odysseus_rag", "odysseus_rag_custom"):
        assert dropped not in chroma.collections
    assert chroma.collections["odysseus_rag_fastembed"].count() == 0
    assert rag.search("stale legacy", k=3) == []
|
||||
|
||||
|
||||
def test_rag_remove_directory_deletes_inactive_lane_collection(monkeypatch, tmp_path):
    """``remove_directory`` purges matching documents from legacy and inactive collections too.

    The legacy, custom and fastembed collections each hold one document whose
    ``source`` is a file under the target directory, while the RAG instance is
    wired with only the fastembed lane. The call must report three removals
    and leave all three collections empty.
    """
    chroma = FakeChroma()
    legacy_docs = chroma.get_or_create_collection("odysseus_rag", metadata={"hnsw:space": "cosine"})
    custom_docs = chroma.get_or_create_collection("odysseus_rag_custom", metadata={"embedding_lane": "custom"})
    fast_docs = chroma.get_or_create_collection("odysseus_rag_fastembed", metadata={"embedding_lane": "fastembed"})

    source = str(tmp_path / "docs" / "note.md")
    directory = str(tmp_path / "docs")

    legacy_docs.add(
        ids=["legacy-doc"],
        embeddings=[[0.0] * 384],
        documents=["legacy stale doc"],
        metadatas=[{"source": source}],
    )
    custom_docs.add(
        ids=["custom-doc"],
        embeddings=[[0.0] * 768],
        documents=["custom stale doc"],
        metadatas=[{"source": source}],
    )
    fast_docs.add(
        ids=["fast-doc"],
        embeddings=[[0.0] * 384],
        documents=["fast current doc"],
        metadatas=[{"source": source}],
    )
    patch_chroma(monkeypatch, chroma)

    lane_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_docs,
        collection_name="odysseus_rag_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    only_lane = EmbeddingLane(**lane_spec)

    from src.rag_vector import VectorRAG

    # Bypass __init__ so only the fastembed lane is active.
    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = fast_docs
    rag._healthy = True

    outcome = rag.remove_directory(directory)

    assert outcome["success"]
    assert outcome["removed_count"] == 3
    assert legacy_docs.count() == 0
    assert custom_docs.count() == 0
    assert fast_docs.count() == 0
|
||||
|
||||
|
||||
def test_rag_delete_by_source_deletes_inactive_lane_collection(monkeypatch, tmp_path):
    """``delete_by_source`` purges legacy and inactive-lane documents for the source.

    The custom and fastembed collections share the id ``shared-doc``; the legacy
    collection holds ``legacy-doc``. The RAG instance is wired with only the
    fastembed lane. The call must return 2 and leave all three collections empty.
    """
    chroma = FakeChroma()
    legacy_docs = chroma.get_or_create_collection("odysseus_rag", metadata={"hnsw:space": "cosine"})
    custom_docs = chroma.get_or_create_collection("odysseus_rag_custom", metadata={"embedding_lane": "custom"})
    fast_docs = chroma.get_or_create_collection("odysseus_rag_fastembed", metadata={"embedding_lane": "fastembed"})

    source = str(tmp_path / "docs" / "note.md")

    legacy_docs.add(
        ids=["legacy-doc"],
        embeddings=[[0.0] * 384],
        documents=["legacy stale doc"],
        metadatas=[{"source": source}],
    )
    custom_docs.add(
        ids=["shared-doc"],
        embeddings=[[0.0] * 768],
        documents=["custom stale doc"],
        metadatas=[{"source": source}],
    )
    fast_docs.add(
        ids=["shared-doc"],
        embeddings=[[0.0] * 384],
        documents=["fast current doc"],
        metadatas=[{"source": source}],
    )
    patch_chroma(monkeypatch, chroma)

    lane_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_docs,
        collection_name="odysseus_rag_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    only_lane = EmbeddingLane(**lane_spec)

    from src.rag_vector import VectorRAG

    # Bypass __init__ so only the fastembed lane is active.
    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = fast_docs
    rag._healthy = True

    removed = rag.delete_by_source(source)

    assert removed == 2
    assert legacy_docs.count() == 0
    assert custom_docs.count() == 0
    assert fast_docs.count() == 0
|
||||
|
||||
|
||||
def test_vector_rag_uses_keyword_fallback_when_all_lanes_query_fail():
    """When the only lane's ``query`` raises, search returns a keyword-fallback hit."""
    docs = FakeCollection("odysseus_rag_fastembed", metadata={"embedding_lane": "fastembed"})
    docs.add(
        ids=["doc-1"],
        embeddings=[[0.0] * 384],
        documents=["fallback keyword document"],
        metadatas=[{"source": "/tmp/doc.md"}],
    )

    def broken_query(*_args, **_kwargs):
        raise RuntimeError("embedding query down")

    # Every vector query against this collection now fails.
    docs.query = broken_query

    lane_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=docs,
        collection_name="odysseus_rag_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fp",
    )
    only_lane = EmbeddingLane(**lane_spec)

    from src.rag_vector import VectorRAG

    rag = VectorRAG.__new__(VectorRAG)
    rag._lanes = [only_lane]
    rag._collection = docs
    rag._healthy = True

    hits = rag.search("fallback keyword", k=3)

    first = hits[0]
    assert first["id"] == "doc-1"
    assert first["search_type"] == "keyword_fallback"
|
||||
@@ -1,178 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from src.embedding_lanes import (
|
||||
EmbeddingLane,
|
||||
LANE_CUSTOM,
|
||||
LANE_FASTEMBED,
|
||||
)
|
||||
from tests.helpers.embedding_lanes import (
|
||||
FakeChroma,
|
||||
FakeCollection,
|
||||
FakeEmbedder,
|
||||
FailingEmbedder,
|
||||
patch_chroma,
|
||||
)
|
||||
|
||||
|
||||
def test_tool_index_indexes_and_retrieves_from_available_lanes(monkeypatch):
    """Indexing built-in tools populates both lane collections and ``bash`` is retrievable."""
    chroma = FakeChroma()
    patch_chroma(monkeypatch, chroma)

    import src.embedding_lanes as lanes

    def build_custom():
        return FakeEmbedder(768, "nomic", "http://embeddings/v1")

    def build_fastembed():
        return FakeEmbedder(384, "mini", "local://fastembed")

    monkeypatch.setattr(lanes, "_build_custom_client", build_custom)
    monkeypatch.setattr(lanes, "_build_fastembed_client", build_fastembed)

    from src.tool_index import ToolIndex

    index = ToolIndex()
    index.index_builtin_tools()

    for collection_name in ("odysseus_tool_index_custom", "odysseus_tool_index_fastembed"):
        assert chroma.collections[collection_name].count() > 0

    retrieved = index.retrieve("run a shell command", k=10)
    assert "bash" in retrieved
|
||||
|
||||
|
||||
def test_tool_index_builtin_indexing_fails_when_all_lanes_fail():
    """Indexing raises ``RuntimeError`` and marks the index unhealthy when every lane fails.

    Both lanes are built around a ``FailingEmbedder``; the error message must
    match ``"all embedding lanes"``.
    """
    custom_spec = dict(
        name=LANE_CUSTOM,
        client=FailingEmbedder(768, "nomic", "http://embeddings/v1"),
        collection=FakeCollection("odysseus_tool_index_custom", metadata={"embedding_lane": "custom"}),
        collection_name="odysseus_tool_index_custom",
        model="nomic",
        url="http://embeddings/v1",
        dimension=768,
        fingerprint="custom",
    )
    custom_lane = EmbeddingLane(**custom_spec)

    fast_spec = dict(
        name=LANE_FASTEMBED,
        client=FailingEmbedder(384, "mini", "local://fastembed"),
        collection=FakeCollection("odysseus_tool_index_fastembed", metadata={"embedding_lane": "fastembed"}),
        collection_name="odysseus_tool_index_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    fast_lane = EmbeddingLane(**fast_spec)

    from src.tool_index import ToolIndex

    # Bypass __init__ and wire the failing lanes in by hand.
    index = ToolIndex.__new__(ToolIndex)
    index._lanes = [custom_lane, fast_lane]
    index._healthy = True

    with pytest.raises(RuntimeError, match="all embedding lanes"):
        index.index_builtin_tools()
    # Checked after the raises-block: inside it, the line would be unreachable.
    assert not index.healthy
|
||||
|
||||
|
||||
def test_tool_index_retrieval_continues_when_custom_lane_query_fails():
    """A raising custom-lane ``query`` does not stop retrieval from the fastembed lane.

    The custom collection holds ``python`` but its ``query`` raises; the
    fastembed collection holds ``bash``. Retrieval must return only ``bash``.
    """
    custom_tools = FakeCollection("odysseus_tool_index_custom", metadata={"embedding_lane": "custom"})
    fast_tools = FakeCollection("odysseus_tool_index_fastembed", metadata={"embedding_lane": "fastembed"})

    fast_tools.add(
        ids=["builtin_bash"],
        embeddings=[[0.0] * 384],
        documents=["Tool: bash\nRun shell commands"],
        metadatas=[{"tool_name": "bash", "tool_type": "builtin"}],
    )

    def broken_query(*_args, **_kwargs):
        raise RuntimeError("custom endpoint down")

    custom_tools.add(
        ids=["builtin_python"],
        embeddings=[[0.0] * 768],
        documents=["Tool: python\nRun Python"],
        metadatas=[{"tool_name": "python", "tool_type": "builtin"}],
    )
    custom_tools.query = broken_query

    custom_spec = dict(
        name=LANE_CUSTOM,
        client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
        collection=custom_tools,
        collection_name="odysseus_tool_index_custom",
        model="nomic",
        url="http://embeddings/v1",
        dimension=768,
        fingerprint="custom",
    )
    custom_lane = EmbeddingLane(**custom_spec)

    fast_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_tools,
        collection_name="odysseus_tool_index_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    fast_lane = EmbeddingLane(**fast_spec)

    from src.tool_index import ToolIndex

    index = ToolIndex.__new__(ToolIndex)
    index._lanes = [custom_lane, fast_lane]

    retrieved = index.retrieve("run shell", k=5)
    assert retrieved == ["bash"]
|
||||
|
||||
|
||||
def test_tool_index_merges_fallback_tool_results_before_limit():
    """A closer tool found only in the fastembed lane must survive the ``k`` cut.

    Both lanes' ``query`` methods are stubbed with fixed ids, metadata and
    distances. With ``k=2`` the expected order is the fastembed-only MCP tool
    (distance 0.05) followed by the best custom-lane tool (distance 0.20).
    """
    custom_tools = FakeCollection("odysseus_tool_index_custom", metadata={"embedding_lane": "custom"})
    fast_tools = FakeCollection("odysseus_tool_index_fastembed", metadata={"embedding_lane": "fastembed"})

    custom_tools.add(
        ids=["builtin_one", "builtin_two"],
        embeddings=[[0.0] * 768 for _ in range(2)],
        documents=["Tool: one", "Tool: two"],
        metadatas=[
            {"tool_name": "one", "tool_type": "builtin"},
            {"tool_name": "two", "tool_type": "builtin"},
        ],
    )
    fast_tools.add(
        ids=["mcp_current"],
        embeddings=[[0.0] * 384],
        documents=["Tool: current MCP"],
        metadatas=[{"tool_name": "current_mcp", "tool_type": "mcp"}],
    )

    def custom_query(**_kwargs):
        return {
            "ids": [["builtin_one", "builtin_two"]],
            "metadatas": [[
                {"tool_name": "one", "tool_type": "builtin"},
                {"tool_name": "two", "tool_type": "builtin"},
            ]],
            "distances": [[0.20, 0.21]],
        }

    def fast_query(**_kwargs):
        return {
            "ids": [["mcp_current"]],
            "metadatas": [[{"tool_name": "current_mcp", "tool_type": "mcp"}]],
            "distances": [[0.05]],
        }

    custom_tools.query = custom_query
    fast_tools.query = fast_query

    custom_spec = dict(
        name=LANE_CUSTOM,
        client=FakeEmbedder(768, "nomic", "http://embeddings/v1"),
        collection=custom_tools,
        collection_name="odysseus_tool_index_custom",
        model="nomic",
        url="http://embeddings/v1",
        dimension=768,
        fingerprint="custom",
    )
    custom_lane = EmbeddingLane(**custom_spec)

    fast_spec = dict(
        name=LANE_FASTEMBED,
        client=FakeEmbedder(384, "mini", "local://fastembed"),
        collection=fast_tools,
        collection_name="odysseus_tool_index_fastembed",
        model="mini",
        url="local://fastembed",
        dimension=384,
        fingerprint="fast",
    )
    fast_lane = EmbeddingLane(**fast_spec)

    from src.tool_index import ToolIndex

    index = ToolIndex.__new__(ToolIndex)
    index._lanes = [custom_lane, fast_lane]

    retrieved = index.retrieve("current mcp", k=2)
    assert retrieved == ["current_mcp", "one"]
|
||||
@@ -1,95 +0,0 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from src.embeddings import EmbeddingClient
|
||||
|
||||
|
||||
class _FakeEmbeddingHttpClient:
    """Test double assigned to ``EmbeddingClient._client`` in these tests.

    ``handler`` is called with the JSON payload of each POST and must return a
    ``(status, body)`` pair, which is wrapped in an ``httpx.Response``. The
    headers passed to every call are recorded in ``self.headers`` in call order.
    """

    def __init__(self, handler):
        self.handler = handler
        # One entry per post() call; an empty dict when no headers were given.
        self.headers = []

    def post(self, url, headers=None, json=None):
        """Record ``headers`` and return the handler's canned response for ``json``."""
        self.headers.append(headers if headers else {})
        outgoing = httpx.Request("POST", url)
        status_code, payload = self.handler(json)
        return httpx.Response(status_code, request=outgoing, json=payload)
|
||||
|
||||
|
||||
def test_embedding_400_batch_retry_falls_back_to_single_inputs(monkeypatch):
    """A 400 on a two-item batch is followed by one request per input.

    The handler rejects any batch longer than one and otherwise returns an
    embedding whose first component is the text length, so the result also
    shows that the vectors come back in input order.
    """
    monkeypatch.setenv("EMBEDDING_BATCH_SIZE", "8")
    seen_batches = []

    def handler(payload):
        batch = payload["input"]
        seen_batches.append(list(batch))
        if len(batch) > 1:
            return 400, {"error": "batch too large"}
        only_text = batch[0]
        vector = [float(len(only_text)), 1.0]
        return 200, {"data": [{"index": 0, "embedding": vector}]}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    vectors = client.encode(["a", "bbbb"], normalize_embeddings=False)

    assert seen_batches == [["a", "bbbb"], ["a"], ["bbbb"]]
    assert vectors.tolist() == [[1.0, 1.0], [4.0, 1.0]]
|
||||
|
||||
|
||||
def test_embedding_400_single_input_retries_with_truncated_text(monkeypatch):
    """A 400 on a single 250-char input is retried once with the text cut to 200 chars.

    ``EMBEDDING_MAX_CHARS`` is set to 200 and the handler rejects anything longer.
    """
    monkeypatch.setenv("EMBEDDING_MAX_CHARS", "200")
    seen_lengths = []

    def handler(payload):
        size = len(payload["input"][0])
        seen_lengths.append(size)
        if size > 200:
            return 400, {"error": "context length exceeded"}
        return 200, {"data": [{"index": 0, "embedding": [2.0, 0.0]}]}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    vectors = client.encode(["x" * 250], normalize_embeddings=False)

    assert seen_lengths == [250, 200]
    assert vectors.tolist() == [[2.0, 0.0]]
|
||||
|
||||
|
||||
def test_embedding_non_400_errors_are_not_retried_or_swallowed():
    """A 500 response surfaces as ``httpx.HTTPStatusError`` after exactly one request."""
    attempts = []

    def handler(payload):
        # One entry per request received.
        attempts.append(payload)
        return 500, {"error": "server error"}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    with pytest.raises(httpx.HTTPStatusError):
        client.encode(["a"], normalize_embeddings=False)

    assert len(attempts) == 1
|
||||
|
||||
|
||||
def test_embedding_retry_path_preserves_api_key_header():
    """The configured API key is sent as a ``Bearer`` Authorization header.

    NOTE(review): the handler always answers 200 and exactly one header entry is
    expected, so as written this checks the first request only, not a retry.
    """
    def handler(payload):
        return 200, {"data": [{"index": 0, "embedding": [1.0, 0.0]}]}

    client = EmbeddingClient(
        url="http://embeddings.test/v1/embeddings",
        model="embed-test",
        api_key="secret-key",
    )
    transport = _FakeEmbeddingHttpClient(handler)
    client._client = transport

    vectors = client.encode(["a"], normalize_embeddings=False)
    # Snapshot the headers recorded by the fake, one entry per request.
    recorded = list(transport.headers)

    assert vectors.tolist() == [[1.0, 0.0]]
    assert recorded == [{"Authorization": "Bearer secret-key"}]
|
||||
@@ -1,84 +0,0 @@
|
||||
"""Issue #3207 — newly created characters missing from Group participant dropdown.
|
||||
|
||||
The fix has two parts:
|
||||
1. group.js _getCharacterList() merges in-memory userTemplates from presets.js
|
||||
as a fallback (covers the gap while the async templates API save is in-flight).
|
||||
2. presets.js saveCustomPreset() does an optimistic in-memory update of
|
||||
userTemplates immediately on success (bridges the timing race where
|
||||
loadUserTemplates hasn't been triggered yet).
|
||||
|
||||
These tests assert the source patterns exist so they can't be silently removed.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# Source text of the two front-end modules whose patterns are asserted below.
# The paths are relative, so these reads resolve against the current working
# directory at import time.
_JS_DIR = Path("static/js")
GROUP_JS = (_JS_DIR / "group.js").read_text(encoding="utf-8")
PRESETS_JS = (_JS_DIR / "presets.js").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# --- group.js: in-memory template merge in _getCharacterList ---
|
||||
|
||||
def test_group_imports_getUserTemplates():
    """group.js must mention getUserTemplates and import from presets.js (either quote style)."""
    assert "getUserTemplates" in GROUP_JS
    import_specs = ("from './presets.js'", 'from "./presets.js"')
    assert any(spec in GROUP_JS for spec in import_specs)
|
||||
|
||||
|
||||
def test_group_merges_in_memory_templates():
    """group.js must call getUserTemplates() and de-duplicate merged entries by id."""
    required_snippets = (
        "getUserTemplates()",
        # The merge loop skips templates whose id is already in the list.
        "!chars.find(c => c.id === t.id)",
    )
    for snippet in required_snippets:
        assert snippet in GROUP_JS
|
||||
|
||||
|
||||
# --- presets.js: optimistic in-memory update on save ---
|
||||
|
||||
def test_presets_exports_getUserTemplates():
    """presets.js must declare getUserTemplates as an exported function."""
    export_signature = "export function getUserTemplates()"
    assert export_signature in PRESETS_JS
|
||||
|
||||
|
||||
def test_presets_optimistic_update_on_save():
    """presets.js must contain the optimistic in-memory update of userTemplates."""
    required_snippets = (
        # Marker comment of the optimistic update block.
        "Optimistically update the in-memory templates list",
        # New entries are appended to userTemplates.
        "userTemplates.push(_entry)",
        # Existing entries are updated in place.
        "Object.assign(_existing, _entry)",
    )
    for snippet in required_snippets:
        assert snippet in PRESETS_JS
|
||||
|
||||
|
||||
def test_presets_getUserTemplates_returns_array():
    """presets.js must return a spread (shallow) copy of userTemplates."""
    shallow_copy_return = "return [...userTemplates]"
    assert shallow_copy_return in PRESETS_JS
|
||||
|
||||
|
||||
def test_presets_optimistic_id_not_empty():
    """The optimistic update must generate a client-side id rather than an empty string."""
    # Generated ids carry the 'user-' prefix.
    generated_id = "user-' + Math.random"
    # The empty-string fallback was the original bug and must stay gone.
    empty_fallback = "(_existing && _existing.id) || ''"

    assert generated_id in PRESETS_JS
    assert empty_fallback not in PRESETS_JS
|
||||
|
||||
def test_presets_clone_happens_before_mutation():
    """The rollback snapshot must appear in the source before the mutating Object.assign."""
    snapshot_at = PRESETS_JS.find("clone = JSON.parse(JSON.stringify(_existing))")
    mutation_at = PRESETS_JS.find("Object.assign(_existing, _entry)")

    # str.find returns -1 when the snippet is missing.
    assert snapshot_at != -1
    assert mutation_at != -1
    assert snapshot_at < mutation_at
|
||||
|
||||
def test_presets_rollbak_restores_from_clone():
    """A failed save must restore the original object from the pre-mutation clone."""
    # NOTE(review): "rollbak" in the name is a typo for "rollback"; the name is
    # kept as-is so the test ID does not change.
    for snippet in ("if (clone)", "Object.assign(_existing, clone)"):
        assert snippet in PRESETS_JS
|
||||
|
||||
def test_presets_clone_is_deep_copy():
    """The rollback snapshot must be a JSON round-trip deep clone, not an alias."""
    deep_clone = "clone = JSON.parse(JSON.stringify(_existing))"
    assert deep_clone in PRESETS_JS
|
||||
|
||||
def test_presets_no_alias_clone():
    """The snapshot must never be taken by plain reference assignment."""
    alias_forms = (
        "clone = _existing",
        "const clone = _existing",
        "let clone = _existing",
    )
    for alias in alias_forms:
        assert alias not in PRESETS_JS
|
||||
@@ -38,7 +38,7 @@ def test_no_hardcoded_loopback_left_in_call_sites():
|
||||
# Regression guard: the converted files must not reintroduce the literal.
|
||||
root = pathlib.Path(__file__).resolve().parent.parent
|
||||
for rel in (
|
||||
"src/tools/_common.py",
|
||||
"src/tool_implementations.py",
|
||||
"src/cookbook_serve_lifecycle.py",
|
||||
"src/builtin_actions.py",
|
||||
"routes/task_routes.py",
|
||||
|
||||
@@ -0,0 +1,350 @@
|
||||
"""Regression: stream_agent_loop surfaces *why* a guard ended the turn.
|
||||
|
||||
Two internal guards used to stop the agent in ways that looked like a clean
|
||||
completion or a vague blocked message:
|
||||
|
||||
* the loop-breaker stall detector -> now emits `loop_breaker_triggered`
|
||||
* the intent-without-action nudge cap -> now emits `intent_nudge_exhausted`
|
||||
|
||||
These tests run the real loop body against a fake LLM stream (no model calls,
|
||||
no sleeps) and assert the structured stop event is emitted.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
import src.agent_loop as al
|
||||
|
||||
|
||||
def _collect(gen):
|
||||
async def _run():
|
||||
return [c async for c in gen]
|
||||
return asyncio.run(_run())
|
||||
|
||||
|
||||
def _types(chunks):
|
||||
out = []
|
||||
for c in chunks:
|
||||
if c.startswith("data: ") and not c.startswith("data: [DONE]"):
|
||||
try:
|
||||
out.append(json.loads(c[6:]))
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
|
||||
def _patch_common(monkeypatch):
    """Replace the agent-loop collaborators with deterministic stubs.

    Stubs installed on ``al``: ``get_setting`` returns the supplied default,
    ``get_mcp_manager`` returns ``None``, ``estimate_tokens`` returns 10 and
    ``execute_tool_block`` reports a successful ``bash`` run.
    """

    async def _ok_bash(block, *a, **k):
        return ("bash", {"output": "ok", "exit_code": 0})

    stubs = (
        ("get_setting", lambda key, default=None: default),
        ("get_mcp_manager", lambda: None),
        ("estimate_tokens", lambda *a, **k: 10),
        ("execute_tool_block", _ok_bash),
    )
    for attr_name, stub in stubs:
        # raising=False: do not fail when the attribute is absent on the module.
        monkeypatch.setattr(al, attr_name, stub, raising=False)
|
||||
|
||||
|
||||
def _run_loop(monkeypatch, round_text, max_rounds, relevant_tools=None):
    """Run ``al.stream_agent_loop`` against a fake LLM stream and return its events.

    Every call to the patched ``stream_llm_with_fallback`` yields
    ``round_text`` as a single delta chunk followed by ``[DONE]``, so each
    round of the loop sees the same model output.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture used to install the fake stream.
        round_text: text the fake model emits on every round.
        max_rounds: forwarded to ``stream_agent_loop``.
        relevant_tools: forwarded to ``stream_agent_loop``; defaults to a
            fresh ``{"bash"}`` set per call.

    Returns:
        The decoded events produced by the loop (see ``_types``).
    """
    # A mutable default would be one set object shared by every call; build a
    # new one per invocation so nothing can leak between tests.
    if relevant_tools is None:
        relevant_tools = {"bash"}

    async def _fake_stream(_candidates, messages, **kwargs):
        yield f'data: {json.dumps({"delta": round_text})}\n\n'
        yield "data: [DONE]\n\n"

    monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)

    gen = al.stream_agent_loop(
        "http://x/v1", "m",
        [{"role": "user", "content": "do a long multi-step task"}],
        max_rounds=max_rounds,
        relevant_tools=relevant_tools,
    )
    return _types(_collect(gen))
|
||||
|
||||
|
||||
def test_emits_loop_breaker_triggered_on_repeated_no_progress(monkeypatch):
    """Repeating the identical tool call every round emits ``loop_breaker_triggered``."""
    _patch_common(monkeypatch)
    # The exact same tool call each round, with no answer text, builds a
    # stuck-round streak that trips the loop-breaker once the cap is reached.
    repeated_call = "```bash\necho hi\n```"
    events = _run_loop(monkeypatch, repeated_call, max_rounds=8)

    triggered = [evt for evt in events if evt.get("type") == "loop_breaker_triggered"]
    assert triggered, events

    first = triggered[0]
    assert first["reason"]
    assert first["max_stuck_rounds"] == 4
    assert first["stuck_rounds"] >= 4
    assert "message" in first
|
||||
|
||||
|
||||
def test_no_loop_breaker_on_normal_finish(monkeypatch):
    """A plain final answer never produces a ``loop_breaker_triggered`` event."""
    _patch_common(monkeypatch)
    events = _run_loop(monkeypatch, "All done, here is your answer.", max_rounds=8)
    event_types = [evt.get("type") for evt in events]
    assert not any(kind == "loop_breaker_triggered" for kind in event_types), events
|
||||
|
||||
|
||||
def test_emits_intent_nudge_exhausted_when_cap_reached(monkeypatch):
    """Announcing an action without ever calling a tool ends in ``intent_nudge_exhausted``."""
    _patch_common(monkeypatch)
    # The model keeps promising an action but never issues a tool call. Once
    # the nudge cap is spent, the turn ends with an explicit event.
    events = _run_loop(monkeypatch, "Let me check the logs now", max_rounds=5)

    exhausted = [evt for evt in events if evt.get("type") == "intent_nudge_exhausted"]
    assert exhausted, events

    first = exhausted[0]
    assert first["max_nudges"] == 2
    assert first["nudges"] >= 2
    assert "message" in first
|
||||
|
||||
|
||||
def test_no_intent_nudge_exhausted_on_normal_finish(monkeypatch):
    """A complete answer never produces an ``intent_nudge_exhausted`` event."""
    _patch_common(monkeypatch)
    answer = "Here is the complete answer to your question."
    events = _run_loop(monkeypatch, answer, max_rounds=5)
    event_types = [evt.get("type") for evt in events]
    assert not any(kind == "intent_nudge_exhausted" for kind in event_types), events
|
||||
|
||||
|
||||
def _assert_guard_log_safe(caplog, *, structural, secret="secret123"):
|
||||
"""The guard's own structural log line fired, and that record carries no raw
|
||||
secret. Scoped to the guard's records on purpose: an unrelated, pre-existing
|
||||
round-summary log echoes raw model text and is out of scope for this PR."""
|
||||
records = [r for r in caplog.records if structural in r.getMessage()]
|
||||
assert records, caplog.text
|
||||
for r in records:
|
||||
assert secret not in r.getMessage(), r.getMessage()
|
||||
|
||||
|
||||
def test_intent_nudge_logging_does_not_leak_secret(monkeypatch, caplog):
|
||||
# The model announces an action (no tool call) with a secret in the text.
|
||||
# The nudge logger must record only structural metadata, never the matched
|
||||
# phrase — so the credential never lands in journalctl.
|
||||
_patch_common(monkeypatch)
|
||||
with caplog.at_level(logging.INFO, logger="src.agent_loop"):
|
||||
events = _run_loop(monkeypatch, "Let me check api_key=secret123 now", max_rounds=5)
|
||||
assert any(e.get("type") == "intent_nudge_exhausted" for e in events), events
|
||||
_assert_guard_log_safe(caplog, structural="intent-without-action nudge")
|
||||
|
||||
|
||||
def test_loop_breaker_logging_does_not_leak_secret(monkeypatch, caplog):
    """The loop-breaker's log line must not contain a secret from the tool call."""
    # A repeated tool command carrying a secret trips the loop-breaker. The
    # structural log must not contain `_sig` / raw tool-call content.
    _patch_common(monkeypatch)
    leaky_call = "```bash\necho api_key=secret123\n```"
    with caplog.at_level(logging.INFO, logger="src.agent_loop"):
        events = _run_loop(monkeypatch, leaky_call, max_rounds=8)

    event_types = [evt.get("type") for evt in events]
    assert any(kind == "loop_breaker_triggered" for kind in event_types), events
    _assert_guard_log_safe(caplog, structural="loop-breaker tripped")
|
||||
|
||||
|
||||
def test_redacts_sensitive_tool_output_before_surfacing():
    """Secrets are replaced with ``[redacted]`` while ordinary lines survive."""
    raw_lines = [
        "password: private-value\n",
        "api_key=private-key\n",
        "Authorization: Bearer private-token\n",
        "normal output",
    ]
    text = al._redact_sensitive_text("".join(raw_lines))

    # None of the secret values may survive.
    for leaked in ("private-value", "private-key", "private-token"):
        assert leaked not in text

    # Labels stay in place, with the value swapped for a marker; plain
    # output is left alone.
    for kept in (
        "password: [redacted]",
        "api_key=[redacted]",
        "Authorization: Bearer [redacted]",
        "normal output",
    ):
        assert kept in text
|
||||
|
||||
|
||||
_GCP_API_KEY_SAMPLE = "AI" + "za" + ("A" * 35)
|
||||
|
||||
# (input, secret substring that must be gone, expected substring that must remain)
|
||||
_REDACTION_CASES = [
|
||||
("Authorization: Bearer abc123tok", "abc123tok", "Authorization: Bearer [redacted]"),
|
||||
("Authorization: Basic dXNlcjpwYXNz", "dXNlcjpwYXNz", "Authorization: Basic [redacted]"),
|
||||
# Quoted Authorization value (spaces) must be redacted whole.
|
||||
('Authorization: Bearer "two word secret"', "two word secret", "Authorization: Bearer [redacted]"),
|
||||
# Escaped quote inside a quoted secret must not leak the tail.
|
||||
(r'password="abc\"def secret"', "def secret", "password=[redacted]"),
|
||||
# URL password containing a colon must still be redacted whole.
|
||||
("postgres://user:pa:ss@host/db", "pa:ss", "postgres://[redacted]@host/db"),
|
||||
# Provider-shaped bare tokens.
|
||||
("token is hf_abcdefghij1234567890XYZ", "hf_abcdefghij1234567890XYZ", "[redacted]"),
|
||||
("key " + _GCP_API_KEY_SAMPLE, _GCP_API_KEY_SAMPLE, "[redacted]"),
|
||||
("Cookie: session=abc123secret", "abc123secret", "Cookie: [redacted]"),
|
||||
("Set-Cookie: sid=xyz789; HttpOnly", "xyz789", "Set-Cookie: [redacted]"),
|
||||
("postgres://user:pa55word@host/db", "pa55word", "postgres://[redacted]@host/db"),
|
||||
("client_secret=supersecretvalue", "supersecretvalue", "client_secret=[redacted]"),
|
||||
("OPENAI_API_KEY=abcd1234deadbeef", "abcd1234deadbeef", "OPENAI_API_KEY=[redacted]"),
|
||||
# Quoted multi-word env value must be fully redacted, not clipped at the space.
|
||||
('OPENAI_API_KEY="two word secret"', "two word secret", "OPENAI_API_KEY=[redacted]"),
|
||||
('password: "my secret value"', "my secret value", "password: [redacted]"),
|
||||
("here is sk-abcdefghij1234567890", "sk-abcdefghij1234567890", "[redacted]"),
|
||||
(
|
||||
"-----BEGIN PRIVATE KEY-----\nMIIfakeKEYbody\n-----END PRIVATE KEY-----",
|
||||
"MIIfakeKEYbody",
|
||||
"[redacted private key]",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw, secret, expected", _REDACTION_CASES)
def test_redaction_covers_requested_secret_shapes(raw, secret, expected):
    """Each secret shape is removed and its expected redacted form appears."""
    redacted = al._redact_sensitive_text(raw)
    assert secret not in redacted, redacted
    assert expected in redacted, redacted
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "raw",
    [
        "the build completed in 3.2s with 0 errors",
        "password reset email sent to the user",
        "Listing 5 files: a.py b.py c.py d.py e.py",
        "https://example.com/path?page=2",
        # Harmless uppercase names that merely end in KEY must be left alone.
        "MONKEY=banana",
        "TURKEY=dinner",
    ],
)
def test_redaction_keeps_normal_output_readable(raw):
    """Benign text passes through the redactor unchanged."""
    redacted = al._redact_sensitive_text(raw)
    assert redacted == raw
|
||||
|
||||
|
||||
def test_redacts_before_truncating():
    """Redacting first, then truncating, removes a secret near the start."""
    # Truncation alone would only clip the tail, so a secret at the head
    # survives unless redaction runs first.
    padding = "x" * 50_000
    raw = "api_key=topsecretvalue " + padding
    redacted = al._redact_sensitive_text(raw)
    out = al._truncate(redacted)
    assert "topsecretvalue" not in out
    assert "api_key=[redacted]" in out
|
||||
|
||||
|
||||
def _run_tool_result(monkeypatch, tool, exec_result, max_rounds=2):
    """Run the loop with a model that calls ``tool`` and an executor returning ``exec_result``.

    The fake model emits one fenced ``tool`` block with an empty JSON body on
    every round; the decoded stream events are returned. Used to assert the
    restored per-tool-result emissions.
    """
    _patch_common(monkeypatch)

    # Installed after _patch_common so it overrides the default bash stub.
    async def _fake_exec(block, *a, **k):
        return (tool, exec_result)

    round_text = f"```{tool}\n{{}}\n```"

    async def _fake_stream(_candidates, messages, **kwargs):
        yield f'data: {json.dumps({"delta": round_text})}\n\n'
        yield "data: [DONE]\n\n"

    monkeypatch.setattr(al, "execute_tool_block", _fake_exec, raising=False)
    monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)

    conversation = [{"role": "user", "content": "do something"}]
    gen = al.stream_agent_loop(
        "http://x/v1", "m",
        conversation,
        max_rounds=max_rounds,
        relevant_tools={tool},
    )
    chunks = _collect(gen)
    return _types(chunks)
|
||||
|
||||
|
||||
def test_restores_doc_suggestions_event(monkeypatch):
    """A ``suggest_document`` result produces a ``doc_suggestions`` event."""
    tool_result = {
        "action": "suggest",
        "doc_id": "d1",
        "suggestions": [{"text": "x"}],
        "exit_code": 0,
    }
    events = _run_tool_result(monkeypatch, "suggest_document", tool_result)
    event_types = [evt.get("type") for evt in events]
    assert any(kind == "doc_suggestions" for kind in event_types), events
|
||||
|
||||
|
||||
def test_restores_doc_update_event(monkeypatch):
    """An ``edit_document`` result emits ``doc_update`` before the first ``tool_output``."""
    tool_result = {
        "action": "edit",
        "doc_id": "d1",
        "content": "body",
        "version": 2,
        "title": "T",
        "language": "md",
        "exit_code": 0,
    }
    events = _run_tool_result(monkeypatch, "edit_document", tool_result)

    # A native document block also emits doc_update AFTER tool_output, so a
    # bare "some doc_update exists" check would pass even with the restored
    # generic block removed. Require it BEFORE the first tool_output.
    types = [evt.get("type") for evt in events]
    assert "doc_update" in types, events
    assert "tool_output" in types, events
    first_update = types.index("doc_update")
    first_output = types.index("tool_output")
    assert first_update < first_output, types
|
||||
|
||||
|
||||
def test_restores_ui_control_event(monkeypatch):
    """A ``ui_control`` tool result produces a ``ui_control`` event."""
    tool_result = {
        "ui_event": "toggle",
        "toggle_name": "bash",
        "state": "off",
        "exit_code": 0,
    }
    events = _run_tool_result(monkeypatch, "ui_control", tool_result)
    event_types = [evt.get("type") for evt in events]
    assert any(kind == "ui_control" for kind in event_types), events
|
||||
|
||||
|
||||
def test_restores_plan_update_event(monkeypatch):
    """An ``update_plan`` tool result produces a ``plan_update`` event."""
    plan = {"steps": [{"text": "step", "done": True}]}
    tool_result = {"plan_update": plan, "exit_code": 0}
    events = _run_tool_result(monkeypatch, "update_plan", tool_result)
    event_types = [evt.get("type") for evt in events]
    assert any(kind == "plan_update" for kind in event_types), events
|
||||
|
||||
|
||||
def test_restores_ask_user_event_and_persists_question(monkeypatch):
    """An ``ask_user`` result emits one event, streams the question once, then stops."""
    question_payload = {
        "question": "Which option?",
        "options": [{"label": "A"}, {"label": "B"}],
    }
    events = _run_tool_result(
        monkeypatch, "ask_user",
        {"ask_user": question_payload, "exit_code": 0},
    )

    # Exactly one ask_user event; it is not re-emitted on a follow-up round.
    _ask_events = [evt for evt in events if evt.get("type") == "ask_user"]
    assert len(_ask_events) == 1, events

    # The question is streamed as assistant text so it persists for replay.
    # Upstream prepends "\n\n" when full_response already holds streamed text,
    # so match on containment, and require exactly one such delta.
    _q_deltas = [evt for evt in events if "Which option?" in (evt.get("delta") or "")]
    assert len(_q_deltas) == 1, events

    # Setting `_awaiting_user` breaks the loop, so the turn does NOT advance
    # into another agent round (which would emit an agent_step event).
    event_types = [evt.get("type") for evt in events]
    assert not any(kind == "agent_step" for kind in event_types), events
|
||||
|
||||
|
||||
def test_redacts_command_display_in_streamed_events(monkeypatch):
    """``tool_start`` / ``tool_output`` events show the command with secrets redacted."""
    # A tool command line can carry a secret. The streamed command display
    # must be redacted, even though the real command handed to execution is
    # left untouched.
    _patch_common(monkeypatch)

    round_text = "```bash\necho api_key=secret123\n```"

    async def _fake_stream(_candidates, messages, **kwargs):
        yield f'data: {json.dumps({"delta": round_text})}\n\n'
        yield "data: [DONE]\n\n"

    monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)

    gen = al.stream_agent_loop(
        "http://x/v1", "m",
        [{"role": "user", "content": "run it"}],
        max_rounds=2,
        relevant_tools={"bash"},
    )
    events = _types(_collect(gen))

    shown_kinds = ("tool_start", "tool_output")
    cmds = [evt for evt in events if evt.get("type") in shown_kinds]
    assert cmds, events
    assert all("secret123" not in (evt.get("command") or "") for evt in cmds), cmds
    assert any("api_key=[redacted]" in (evt.get("command") or "") for evt in cmds), cmds
|
||||
|
||||
|
||||
def test_redacts_live_tool_progress_tail(monkeypatch):
    """``tool_progress`` events carry a redacted ``tail`` and keep their other fields."""
    # A secret in the live progress tail must be redacted before streaming;
    # otherwise it flashes by ahead of the (already redacted) final tool_output.
    _patch_common(monkeypatch)

    async def _fake_exec(block, *a, **k):
        # Report one progress update through the callback the loop supplies.
        await k["progress_cb"]({"tail": "api_key=secret123", "elapsed_s": 1})
        return ("bash", {"output": "done", "exit_code": 0})

    round_text = "```bash\necho hi\n```"

    async def _fake_stream(_candidates, messages, **kwargs):
        yield f'data: {json.dumps({"delta": round_text})}\n\n'
        yield "data: [DONE]\n\n"

    monkeypatch.setattr(al, "execute_tool_block", _fake_exec, raising=False)
    monkeypatch.setattr(al, "stream_llm_with_fallback", _fake_stream, raising=False)

    gen = al.stream_agent_loop(
        "http://x/v1", "m",
        [{"role": "user", "content": "run it"}],
        max_rounds=2,
        relevant_tools={"bash"},
    )
    events = _types(_collect(gen))

    prog = [evt for evt in events if evt.get("type") == "tool_progress"]
    assert prog, events
    assert all("secret123" not in (evt.get("tail") or "") for evt in prog), prog
    assert any("api_key=[redacted]" in (evt.get("tail") or "") for evt in prog), prog
    # Fields other than the tail are passed through.
    assert any(evt.get("elapsed_s") == 1 for evt in prog), prog
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Regression tests for ReDoS in agent_loop's `<think>...</think>` stripping.
|
||||
|
||||
CodeQL flagged `py/polynomial-redos` on the lazy `<think>.*?</think>` pattern
|
||||
used in `src/agent_loop.py` (one compiled `_THINK_RE`, one inline copy). It is
|
||||
applied with `re.sub` over a whole model response. When the closing delimiter
|
||||
is missing, the engine rescans to end-of-string from every `<think>` opener ->
|
||||
O(n^2) on attacker-influenced input (prompt injection via tool output /
|
||||
retrieved content echoed back by the model).
|
||||
|
||||
The fix replaces the regex with `_strip_think_blocks`, a forward-only linear
|
||||
scan that is byte-for-byte equivalent to the original
|
||||
`re.sub(r'<think>.*?</think>', '', text, flags=DOTALL|IGNORECASE)`.
|
||||
|
||||
These tests pin BOTH halves:
|
||||
* output is identical to the reference regex for legitimate inputs, and
|
||||
* pathological "many openers, no closer" input completes promptly.
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
|
||||
from src.agent_loop import _strip_think_blocks
|
||||
|
||||
# The exact pattern this fix replaces. Used only as an equivalence oracle on
|
||||
# well-formed inputs (never on the adversarial one, where it is the slow path).
|
||||
_REFERENCE_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
||||
|
||||
|
||||
def _reference(text: str) -> str:
    """Strip ``<think>...</think>`` spans with the reference regex (oracle only)."""
    # Falsy input (None or "") is treated as the empty string.
    subject = text if text else ""
    return _REFERENCE_RE.sub("", subject)
|
||||
|
||||
|
||||
# Loose ceiling: the linear helper finishes in well under 100ms; the vulnerable
|
||||
# regex took seconds-to-tens-of-seconds on the same input.
|
||||
_BUDGET_S = 4.0
|
||||
|
||||
|
||||
# -- equivalence with the original regex -------------------------------------
|
||||
|
||||
# Well-formed inputs on which the helper must agree with the reference regex.
EQUIV_CASES = [
    "",
    "no tags here at all",
    "<think>hidden</think>visible",
    "before<think>cot</think>after",
    "a<think>one</think>b<think>two</think>c",
    "<think>only</think>",
    "<think></think>tail",
    # The lazy match ends at the first closer.
    "<think>a<think>nested</think>rest",
    # An orphan closer is NOT stripped.
    "leading</think>orphan<think>x</think>",
    # A dangling opener is kept verbatim.
    "trailing<think>no closer for this one",
    # Matching is case-insensitive.
    "CASE <THINK>UP</THINK> mix <Think>x</Think>",
    # DOTALL: a block may span newlines.
    "multi\nline\n<think>a\nb\nc</think>\nkeep",
    # Only the literal <think> tag is matched by the narrow regex.
    "<thinking>not matched by narrow regex</thinking>",
    # A space inside the tag does not match either.
    "<think >space-in-tag not matched</think >",
]
|
||||
|
||||
|
||||
def test_strip_think_blocks_matches_reference_regex():
    """The linear helper matches the reference regex on every equivalence case."""
    for sample in EQUIV_CASES:
        expected = _reference(sample)
        assert _strip_think_blocks(sample) == expected, repr(sample)
|
||||
|
||||
|
||||
def test_empty_and_none_safe():
    """Empty and ``None`` inputs are handled without raising."""
    assert _strip_think_blocks("") == ""
    # Either passing None through or normalising it to "" is acceptable.
    none_result = _strip_think_blocks(None)
    assert none_result in (None, "")
|
||||
|
||||
|
||||
# -- ReDoS bound -------------------------------------------------------------
|
||||
|
||||
def test_many_openers_no_closer_is_linear():
    """Thousands of unclosed ``<think>`` openers are handled within the time budget."""
    # An attacker echoes many "<think>" openers with no closer. The lazy regex
    # rescans to end-of-string from each opener (O(n^2)); the helper scans once.
    hostile = "<think>" * 60_000 + "x"

    started = time.perf_counter()
    out = _strip_think_blocks(hostile)
    elapsed = time.perf_counter() - started

    # With no closer anywhere, nothing is stripped and the input comes back intact.
    assert out == hostile
    assert elapsed < _BUDGET_S, f"took {elapsed:.2f}s (expected linear)"
|
||||
|
||||
|
||||
def test_openers_then_one_far_closer_is_linear():
    """Many openers followed by one distant closer strip to the tail within budget."""
    hostile = "<think>" * 60_000 + "</think>" + "tail"

    started = time.perf_counter()
    out = _strip_think_blocks(hostile)
    elapsed = time.perf_counter() - started

    # The first opener pairs with the single closer, so the lazy match spans
    # everything up to it.
    assert out == "tail"
    assert elapsed < _BUDGET_S, f"took {elapsed:.2f}s (expected linear)"
|
||||
@@ -41,24 +41,10 @@ def test_sub_area_only_marker_expression():
|
||||
assert build_marker_expression(None, "cookbook") == "sub_cookbook"
|
||||
|
||||
|
||||
def test_embedding_sub_area_marker_expression_includes_memory_split():
|
||||
assert (
|
||||
build_marker_expression(None, "embedding")
|
||||
== "(sub_embedding or sub_embedding_memory)"
|
||||
)
|
||||
|
||||
|
||||
def test_area_and_sub_area_marker_expression():
|
||||
assert build_marker_expression("services", "cookbook") == "area_services and sub_cookbook"
|
||||
|
||||
|
||||
def test_area_and_embedding_sub_area_marker_expression_includes_memory_split():
|
||||
assert (
|
||||
build_marker_expression("services", "embedding")
|
||||
== "area_services and (sub_embedding or sub_embedding_memory)"
|
||||
)
|
||||
|
||||
|
||||
def test_no_selection_marker_expression_is_none():
|
||||
assert build_marker_expression(None, None) is None
|
||||
|
||||
@@ -89,12 +75,6 @@ def test_sub_area_only_command():
|
||||
assert _cmd(sub_area="cookbook") == [PY, "-m", "pytest", "-m", "sub_cookbook"]
|
||||
|
||||
|
||||
def test_embedding_sub_area_command_includes_memory_split():
|
||||
assert _cmd(sub_area="embedding") == [
|
||||
PY, "-m", "pytest", "-m", "(sub_embedding or sub_embedding_memory)",
|
||||
]
|
||||
|
||||
|
||||
def test_area_and_sub_area_command():
|
||||
assert _cmd(area="services", sub_area="cookbook") == [
|
||||
PY, "-m", "pytest", "-m", "area_services and sub_cookbook",
|
||||
@@ -150,13 +130,6 @@ def test_fast_with_area_and_sub_area_command():
|
||||
]
|
||||
|
||||
|
||||
def test_fast_with_embedding_sub_area_command_includes_memory_split():
|
||||
assert _cmd(sub_area="embedding", fast=True) == [
|
||||
PY, "-m", "pytest", "-m",
|
||||
"(sub_embedding or sub_embedding_memory) and not slow",
|
||||
]
|
||||
|
||||
|
||||
def test_durations_appends_flag():
|
||||
assert _cmd(fast=True, durations=25) == [
|
||||
PY, "-m", "pytest", "-m", "not slow", "--durations=25",
|
||||
@@ -279,30 +252,6 @@ def test_run_accepts_both_sub_area_forms(value):
|
||||
]]
|
||||
|
||||
|
||||
def test_run_keeps_embedding_memory_selector_specific():
|
||||
executor = _FakeExecutor()
|
||||
run(["--sub-area", "embedding_memory"], executor=executor)
|
||||
assert executor.calls == [[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pytest",
|
||||
"-m",
|
||||
"sub_embedding_memory",
|
||||
]]
|
||||
|
||||
|
||||
def test_run_expands_embedding_selector_to_memory_split():
|
||||
executor = _FakeExecutor()
|
||||
run(["--sub-area", "embedding"], executor=executor)
|
||||
assert executor.calls == [[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pytest",
|
||||
"-m",
|
||||
"(sub_embedding or sub_embedding_memory)",
|
||||
]]
|
||||
|
||||
|
||||
def test_invalid_area_exits_with_error():
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
run(["--area", "bogus"], executor=_FakeExecutor())
|
||||
|
||||
@@ -50,12 +50,6 @@ def test_classify_examples(filename, expected_area, expected_sub):
|
||||
assert result.sub_area == expected_sub
|
||||
|
||||
|
||||
def test_embedding_lanes_memory_file_keeps_specific_sub_area():
|
||||
result = classify_test_path("tests/test_embedding_lanes_memory.py")
|
||||
assert result.area == "services"
|
||||
assert result.sub_area == "embedding_memory"
|
||||
|
||||
|
||||
# --- classify_test_path: fallback --------------------------------------------
|
||||
|
||||
def test_unknown_filename_is_uncategorized():
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
|
||||
import src.teacher_escalation as teacher_escalation
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_evaluate_turn_llm_ok(monkeypatch):
    """An "ok" model verdict yields ("ok", None) via the utility endpoint for the owner."""
    observed = {}

    def fake_resolve_endpoint(prefix, fallback_url=None, owner=None):
        # Record how the endpoint was resolved so it can be asserted below.
        observed["prefix"] = prefix
        observed["owner"] = owner
        return "http://endpoint.local/v1", "utility-model", {}

    async def fake_llm_call_async(url, model, messages, **kwargs):
        observed["called"] = True
        return "ok"

    monkeypatch.setattr("src.endpoint_resolver.resolve_endpoint", fake_resolve_endpoint)
    monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async)

    verdict = await teacher_escalation.evaluate_turn_llm(
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        student_endpoint_url="http://student.local/v1",
        owner="alice",
    )
    status, reason = verdict

    assert status == "ok"
    assert reason is None
    assert observed["prefix"] == "utility"
    assert observed["owner"] == "alice"
    assert observed["called"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_evaluate_turn_llm_failure(monkeypatch):
    """A padded, quoted, mixed-case "Failure" verdict is still reported as a failure."""

    def fake_resolve_endpoint(prefix, fallback_url=None, owner=None):
        return "http://endpoint.local/v1", "utility-model", {}

    async def fake_llm_call_async(url, model, messages, **kwargs):
        # Surrounding spaces, quotes and capitalisation are part of the case.
        return " \"Failure\" "

    monkeypatch.setattr("src.endpoint_resolver.resolve_endpoint", fake_resolve_endpoint)
    monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async)

    verdict = await teacher_escalation.evaluate_turn_llm(
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        student_endpoint_url="http://student.local/v1",
        owner="alice",
    )
    status, reason = verdict

    assert status == "failure"
    assert "LLM evaluation flagged failure" in reason
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_evaluate_turn_llm_contains_failure_but_not_exact_match(monkeypatch):
    """A reply that merely mentions "failure" is not treated as a failure verdict."""

    def fake_resolve_endpoint(prefix, fallback_url=None, owner=None):
        return "http://endpoint.local/v1", "utility-model", {}

    async def fake_llm_call_async(url, model, messages, **kwargs):
        return "this agent execution is not a failure"

    monkeypatch.setattr("src.endpoint_resolver.resolve_endpoint", fake_resolve_endpoint)
    monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async)

    verdict = await teacher_escalation.evaluate_turn_llm(
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        student_endpoint_url="http://student.local/v1",
        owner="alice",
    )
    status, reason = verdict

    assert status == "ok"
    assert reason is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_evaluate_turn_llm_exception_handling(monkeypatch):
    """An exception raised by the LLM call degrades to ("ok", None)."""

    def fake_resolve_endpoint(prefix, fallback_url=None, owner=None):
        return "http://endpoint.local/v1", "utility-model", {}

    async def fake_llm_call_async(url, model, messages, **kwargs):
        raise RuntimeError("model timeout")

    monkeypatch.setattr("src.endpoint_resolver.resolve_endpoint", fake_resolve_endpoint)
    monkeypatch.setattr("src.llm_core.llm_call_async", fake_llm_call_async)

    # The evaluator is expected to degrade gracefully to "ok".
    verdict = await teacher_escalation.evaluate_turn_llm(
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        student_endpoint_url="http://student.local/v1",
        owner="alice",
    )
    status, reason = verdict

    assert status == "ok"
    assert reason is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_maybe_escalate_triggers_tier2_background_task(monkeypatch):
    """With tier 2 on and the regex check passing, an LLM-flagged failure escalates."""
    # Teacher settings, with tier 2 enabled.
    settings = {"teacher_enabled": True, "teacher_model": "teacher-model", "teacher_tier2_enabled": True}

    def fake_get_setting(key, default=None):
        return settings.get(key, default)

    monkeypatch.setattr("src.settings.get_setting", fake_get_setting)

    # The regex check reports OK, so only the LLM evaluation can flag the turn.
    monkeypatch.setattr("src.teacher_escalation.evaluate_turn_regex", lambda *args: ("ok", None))

    llm_eval_called = []

    async def fake_evaluate_turn_llm(*args, **kwargs):
        llm_eval_called.append(True)
        return "failure", "LLM flagged failure"

    monkeypatch.setattr("src.teacher_escalation.evaluate_turn_llm", fake_evaluate_turn_llm)

    escalate_called = []

    async def fake_escalate_and_learn(user_request, tool_results, agent_reply, failure_reason, owner):
        escalate_called.append(failure_reason)
        return "skill-slug"

    monkeypatch.setattr("src.teacher_escalation.escalate_and_learn", fake_escalate_and_learn)

    task = teacher_escalation.maybe_escalate(
        student_endpoint_url="http://student.local/v1",
        mode="agent",
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        owner="alice",
    )

    assert task is not None
    assert task.get_name() == "teacher_escalation_tier2"

    # Let the background task run to completion before checking the recorders.
    await task

    assert llm_eval_called == [True]
    assert escalate_called == ["LLM flagged failure"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_maybe_escalate_tier2_disabled_by_default(monkeypatch):
    """With tier 2 off and the regex check passing, no background task is started."""
    # Teacher is enabled, but tier 2 stays disabled.
    settings = {"teacher_enabled": True, "teacher_model": "teacher-model", "teacher_tier2_enabled": False}

    def fake_get_setting(key, default=None):
        return settings.get(key, default)

    monkeypatch.setattr("src.settings.get_setting", fake_get_setting)

    # The regex check reports OK.
    monkeypatch.setattr("src.teacher_escalation.evaluate_turn_regex", lambda *args: ("ok", None))

    task = teacher_escalation.maybe_escalate(
        student_endpoint_url="http://student.local/v1",
        mode="agent",
        user_request="test request",
        tool_results=[],
        agent_reply="test reply",
        owner="alice",
    )

    # Tier 2 is disabled, so nothing should have been scheduled.
    assert task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_teacher_inline_triggers_tier2_escalation(monkeypatch):
    """With tier 2 on, an LLM-flagged failure leads to takeover, tool output and a saved skill."""
    # Settings and gates.
    settings = {"teacher_enabled": True, "teacher_model": "teacher-model", "teacher_tier2_enabled": True}

    def fake_get_setting(key, default=None):
        return settings.get(key, default)

    monkeypatch.setattr("src.settings.get_setting", fake_get_setting)
    monkeypatch.setattr("src.ai_interaction._resolve_model", lambda spec, owner=None: ("http://teacher.local/v1", "teacher-model", {}))

    # Regex evaluation says "ok"...
    monkeypatch.setattr("src.teacher_escalation.evaluate_turn_regex", lambda *args: ("ok", None))

    # ...while the LLM evaluation flags "failure".
    async def fake_evaluate_turn_llm(*args, **kwargs):
        return "failure", "LLM flagged failure"

    # Stand-in for the stream_agent_loop call made by run_teacher_inline.
    async def fake_stream_agent_loop(*args, **kwargs):
        yield "data: {\"type\": \"tool_output\", \"tool\": \"bash\"}\n\n"
        yield "data: {\"type\": \"text\", \"delta\": \"Teacher reply\"}\n\n"
        yield "data: [DONE]\n\n"

    # The teacher answers with a skill definition.
    async def fake_call_teacher(spec, prompt, owner=None):
        return '```json\n{"action": "add", "name": "test-skill"}\n```'

    # Skill persistence always succeeds.
    async def fake_do_manage_skills(skill_json, owner=None):
        return {"success": True}

    monkeypatch.setattr("src.teacher_escalation.evaluate_turn_llm", fake_evaluate_turn_llm)
    monkeypatch.setattr("src.agent_loop.stream_agent_loop", fake_stream_agent_loop)
    monkeypatch.setattr("src.teacher_escalation._call_teacher", fake_call_teacher)
    monkeypatch.setattr("src.tool_implementations.do_manage_skills", fake_do_manage_skills)

    events = [
        evt
        async for evt in teacher_escalation.run_teacher_inline(
            student_endpoint_url="http://student.local/v1",
            student_messages=[{"role": "user", "content": "test request"}],
            student_tool_events=[],
            student_reply="student reply",
            owner="alice",
        )
    ]

    # The takeover was announced and executed, and the skill was saved.
    assert any("teacher_takeover" in evt for evt in events)
    assert any("tool_output" in evt for evt in events)
    assert any("skill_saved" in evt for evt in events)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_teacher_inline_tier2_disabled_by_default(monkeypatch):
    """With Tier 2 switched off and a regex verdict of "ok", no events are emitted."""
    # Settings and gates: teacher enabled, Tier 2 explicitly disabled.
    fake_settings = {
        "teacher_enabled": True,
        "teacher_model": "teacher-model",
        "teacher_tier2_enabled": False,
    }

    def fake_get_setting(key, default=None):
        return fake_settings.get(key, default)

    monkeypatch.setattr("src.settings.get_setting", fake_get_setting)

    # The regex evaluation reports "ok".
    monkeypatch.setattr(
        "src.teacher_escalation.evaluate_turn_regex",
        lambda *args: ("ok", None),
    )

    events = [
        evt
        async for evt in teacher_escalation.run_teacher_inline(
            student_endpoint_url="http://student.local/v1",
            student_messages=[{"role": "user", "content": "test request"}],
            student_tool_events=[],
            student_reply="student reply",
            owner="alice",
        )
    ]

    # Expect an early exit: no takeover, hence no events at all.
    assert len(events) == 0
|
||||
@@ -1,155 +0,0 @@
|
||||
"""Guard that toast dismissal (via the × close button) correctly resets
|
||||
pointer-events so the invisible fixed overlay does not block clicks.
|
||||
|
||||
The reviewer flagged that action-toasts set ``pointer-events: auto`` on
|
||||
``#toast`` for their clickable button, but the close-button dismiss path
|
||||
was cancelling the auto-hide timer without resetting ``pointer-events``.
|
||||
This left an invisible element intercepting mouse/touch events.
|
||||
|
||||
These are source-level assertions (no browser, no DOM) that verify the
|
||||
close-button handler includes the reset. They cover:
|
||||
• ordinary (plain text) toast – showToast
|
||||
• error toast – showError
|
||||
• action toast – showToast with action opts
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_UI_PATH = _REPO / "static" / "js" / "ui.js"
|
||||
|
||||
|
||||
def _read_ui():
    """Return the contents of the file at ``_UI_PATH``, decoded as UTF-8."""
    with _UI_PATH.open(encoding="utf-8") as handle:
        return handle.read()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers – extract the close-button event-handler bodies from each function.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _extract_function(src: str, func_name: str) -> str:
|
||||
"""Return the full body of *func_name* (exported or not)."""
|
||||
# Match export function showToast(… or function showToast(…
|
||||
pat = re.compile(
|
||||
rf"(?:export\s+)?function\s+{re.escape(func_name)}\s*\(", re.DOTALL
|
||||
)
|
||||
m = pat.search(src)
|
||||
assert m, f"could not find function {func_name!r} in ui.js"
|
||||
start = m.start()
|
||||
# Walk forward counting braces to find the matching closing brace.
|
||||
depth = 0
|
||||
for i in range(start, len(src)):
|
||||
if src[i] == "{":
|
||||
depth += 1
|
||||
elif src[i] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return src[start : i + 1]
|
||||
raise AssertionError(f"unbalanced braces for {func_name}")
|
||||
|
||||
|
||||
def _extract_close_handler(func_body: str) -> str:
|
||||
"""Return the close-button click-handler body inside *func_body*.
|
||||
|
||||
Looks for the ``toast-close-btn`` class assignment, then finds the
|
||||
``addEventListener('click'`` call that follows, and extracts the arrow
|
||||
function body.
|
||||
"""
|
||||
idx = func_body.find("toast-close-btn")
|
||||
assert idx != -1, "toast-close-btn not found in function body"
|
||||
# Find the addEventListener('click', … that follows
|
||||
listen_idx = func_body.find("addEventListener('click'", idx)
|
||||
if listen_idx == -1:
|
||||
listen_idx = func_body.find('addEventListener("click"', idx)
|
||||
assert listen_idx != -1, "addEventListener('click') not found after toast-close-btn"
|
||||
|
||||
# Find the opening brace of the handler
|
||||
brace = func_body.find("{", listen_idx)
|
||||
assert brace != -1
|
||||
depth = 0
|
||||
for i in range(brace, len(func_body)):
|
||||
if func_body[i] == "{":
|
||||
depth += 1
|
||||
elif func_body[i] == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return func_body[brace : i + 1]
|
||||
raise AssertionError("unbalanced braces in close handler")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_showToast_close_handler_resets_pointer_events():
    """The × handler in showToast has to mention ``pointerEvents``.

    Otherwise an action toast that set them to 'auto' would leave the
    overlay blocking clicks after dismissal.
    """
    close_handler = _extract_close_handler(
        _extract_function(_read_ui(), "showToast")
    )
    failure = (
        "showToast close-button handler does not reset pointerEvents – "
        "action toasts will leave an invisible click-blocking overlay"
    )
    assert "pointerEvents" in close_handler, failure
|
||||
|
||||
|
||||
def test_showError_close_handler_resets_pointer_events():
    """The × handler in showError has to mention ``pointerEvents`` too.

    This is a defensive reset in case an earlier action toast left them
    set to 'auto'.
    """
    close_handler = _extract_close_handler(
        _extract_function(_read_ui(), "showError")
    )
    failure = (
        "showError close-button handler does not reset pointerEvents – "
        "a prior action toast could leave the overlay blocking clicks"
    )
    assert "pointerEvents" in close_handler, failure
|
||||
|
||||
|
||||
def test_showToast_timer_resets_pointer_events():
    """The auto-hide timer in showToast must also reset pointer-events.

    This was already in place before the × button was added; make sure
    it stays. The check is textual: the brace-delimited block following
    the last ``_hideTimer = setTimeout`` in showToast must mention
    ``pointerEvents``.
    """
    src = _read_ui()
    body = _extract_function(src, "showToast")
    assert "_hideTimer" in body, "no _hideTimer found in showToast"

    # Use the last ``_hideTimer = setTimeout`` assignment in the function.
    last_timer = body.rfind("_hideTimer = setTimeout")
    assert last_timer != -1, "no '_hideTimer = setTimeout' assignment in showToast"

    # ``str.find`` returns -1 when nothing matches; without this guard
    # ``range(-1, ...)`` would start the scan at ``body[-1]`` and the test
    # would fail later with an unrelated message.
    brace = body.find("{", last_timer)
    assert brace != -1, "no callback body found after '_hideTimer = setTimeout'"

    # Extract the setTimeout callback body by matching braces.
    depth = 0
    timer_body = ""
    for i in range(brace, len(body)):
        ch = body[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                timer_body = body[brace : i + 1]
                break

    assert "pointerEvents" in timer_body, (
        "showToast auto-hide timer no longer resets pointerEvents"
    )
|
||||
|
||||
|
||||
def test_action_toast_sets_pointer_events_auto():
    """showToast must assign 'auto' to ``pointerEvents`` somewhere in its
    source, so that an action button inside the toast is clickable."""
    show_toast_src = _extract_function(_read_ui(), "showToast")
    accepted = ("pointerEvents = 'auto'", 'pointerEvents = "auto"')
    assert any(snippet in show_toast_src for snippet in accepted), (
        "showToast no longer sets pointer-events:auto for action toasts"
    )
|
||||
|
||||
|
||||
def test_plain_toast_clears_pointer_events():
    """showToast must assign an empty string to ``pointerEvents`` somewhere
    in its source, clearing any value left by a previous action toast when
    there is no action button."""
    show_toast_src = _extract_function(_read_ui(), "showToast")
    # Either quote style counts as the reset.
    accepted = ("pointerEvents = ''", 'pointerEvents = ""')
    assert any(snippet in show_toast_src for snippet in accepted), (
        "showToast does not clear pointer-events for non-action toasts"
    )
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Protection test: the tool_implementations compatibility shim must keep
|
||||
re-exporting every symbol importers depend on.
|
||||
|
||||
Guards the slice-1 split (tool_implementations.py -> src/tools/*) from
|
||||
accidentally dropping a symbol. The contract is enforced by two
|
||||
self-verifying tests, not by the hand-maintained list below:
|
||||
|
||||
* ``test_shim_reexports_every_domain_do_function`` discovers every ``do_*``
|
||||
from the domain modules and asserts reachability through the shim.
|
||||
* ``test_every_facade_import_in_repo_resolves`` discovers every
|
||||
``from src.tool_implementations import X`` site across first-party Python
|
||||
dirs (src/, tests/, routes/, ...) and asserts ``X`` resolves through the
|
||||
shim.
|
||||
|
||||
Both fail automatically if a re-export is forgotten (the do_* discovery
|
||||
covers the tool surface; the import-site scan covers underscore helpers a
|
||||
reviewer's P3 finding showed could otherwise slip through the list). The
|
||||
``_EXPECTED`` list below is the curated historical surface (the original
|
||||
module's top-level names), kept as a belt-and-suspenders check and as the
|
||||
async-shape contract for ``do_*``; it is not the ground truth.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import src.tool_implementations as ti
|
||||
|
||||
# 33 do_* tool functions
|
||||
_EXPECTED = [
|
||||
"do_adopt_served_model", "do_api_call", "do_app_api", "do_cancel_download",
|
||||
"do_download_model", "do_edit_image", "do_list_cached_models",
|
||||
"do_list_cookbook_servers", "do_list_downloads", "do_list_served_models",
|
||||
"do_list_serve_presets", "do_manage_calendar", "do_manage_contact",
|
||||
"do_manage_endpoints", "do_manage_mcp", "do_manage_notes",
|
||||
"do_manage_research", "do_manage_settings", "do_manage_skills",
|
||||
"do_manage_tasks", "do_manage_tokens", "do_manage_webhooks",
|
||||
"do_resolve_contact", "do_search_chats", "do_search_hf_models",
|
||||
"do_serve_model", "do_serve_preset", "do_stop_served_model",
|
||||
"do_tail_serve_output", "do_trigger_research", "do_vault_get",
|
||||
"do_vault_search", "do_vault_unlock",
|
||||
# module-private helpers (importable by name too)
|
||||
"_cookbook_apply_retry_suggestion", "_cookbook_env_for_host",
|
||||
"_cookbook_kill_session", "_cookbook_register_task", "_cookbook_servers",
|
||||
"_ensure_served_endpoint", "_infer_serve_host", "_infer_serve_port",
|
||||
"_internal_headers", "_load_vault_config", "_mcp_allowed_commands",
|
||||
"_parse_tool_args", "_resolve_cookbook_host", "_run_bw",
|
||||
"_scan_running_model_processes", "_skill_dump", "_string_arg",
|
||||
"_validate_cookbook_ssh_target",
|
||||
# active-email facade helpers (no do_* prefix); consumed by
|
||||
# routes/chat_routes.py — listed here because get_active_email has no
|
||||
# in-repo importer, so the import-site scan below can't see it alone.
|
||||
"set_active_email", "get_active_email", "clear_active_email",
|
||||
]
|
||||
|
||||
|
||||
def test_shim_reexports_all_top_level_symbols():
    """Each name listed in ``_EXPECTED`` must be an attribute of the shim."""
    absent = [symbol for symbol in _EXPECTED if not hasattr(ti, symbol)]
    failure = f"shim dropped symbols: {absent}"
    assert not absent, failure
|
||||
|
||||
|
||||
def test_do_functions_remain_async_through_shim():
    """Each ``do_*`` name in ``_EXPECTED`` must resolve to a coroutine
    function when looked up on the shim."""
    for tool_name in (n for n in _EXPECTED if n.startswith("do_")):
        tool = getattr(ti, tool_name)
        assert inspect.iscoroutinefunction(tool), (
            f"{tool_name} is not async via shim (got {type(tool).__name__})"
        )
|
||||
|
||||
|
||||
# Domain modules that own tool implementations after the slice-1 split.
|
||||
# The shim must re-export every public do_* from each so existing
|
||||
# `from src.tool_implementations import do_X` imports keep resolving.
|
||||
_DOMAIN_MODULES = (
|
||||
"src.tools.system",
|
||||
"src.tools.cookbook",
|
||||
"src.tools.search",
|
||||
"src.tools.notes",
|
||||
"src.tools.calendar",
|
||||
"src.tools.image",
|
||||
"src.tools.research",
|
||||
"src.tools.contacts",
|
||||
"src.tools.vault",
|
||||
"src.agent_tools.admin_tools", # admin manage_* tools migrated here (#3629)
|
||||
)
|
||||
|
||||
|
||||
def test_shim_reexports_every_domain_do_function():
    """Auto-discovered guard: every async ``do_*`` in a domain module must
    be reachable through the shim.

    Each module named in ``_DOMAIN_MODULES`` is imported and its ``do_*``
    coroutine functions are collected, so a re-export that was forgotten
    fails here even when the hand-maintained ``_EXPECTED`` list was not
    updated. Reachability is checked with ``hasattr`` rather than
    ``dir(ti)`` because the admin symbols are re-exported lazily via module
    ``__getattr__`` and therefore do not appear in ``dir(ti)``.
    """
    import importlib

    dropped = []
    for module_path in _DOMAIN_MODULES:
        domain = importlib.import_module(module_path)
        dropped.extend(
            f"{module_path}.{attr}"
            for attr in dir(domain)
            if attr.startswith("do_")
            and inspect.iscoroutinefunction(getattr(domain, attr, None))
            and not hasattr(ti, attr)
        )

    assert not dropped, f"shim dropped domain do_* (re-export forgotten): {dropped}"
|
||||
|
||||
|
||||
def test_every_facade_import_in_repo_resolves():
    """Every ``from src.tool_implementations import X`` found under the
    repository root must resolve through the shim.

    The import sites are collected statically with ``ast`` because the
    invariant is which names the rest of the codebase asks the facade for,
    and only the import statements enumerate that set. The per-name check
    is done at runtime with ``hasattr``, so any forgotten re-export, helper
    or ``do_*``, fails here without relying on the hand-maintained
    ``_EXPECTED`` list.
    """
    import ast
    import os
    from pathlib import Path

    facade = "src.tool_implementations"
    repo_root = Path(__file__).resolve().parents[1]

    # Non-source trees (venvs, caches, data, build artifacts) are pruned
    # in place so os.walk never descends into them.
    pruned = {
        "__pycache__", "venv", "node_modules", "data", "logs",
        "odysseus.egg-info", "static", "specs", "licenses", "docker",
    }

    requested = set()
    for dirpath, subdirs, filenames in os.walk(repo_root):
        subdirs[:] = [
            d for d in subdirs if not d.startswith(".") and d not in pruned
        ]
        for filename in filenames:
            if not filename.endswith(".py"):
                continue
            source_path = Path(dirpath) / filename
            source = source_path.read_text(encoding="utf-8")
            # Cheap textual pre-filter before paying for a full parse.
            if facade not in source:
                continue
            try:
                module_ast = ast.parse(source, filename=str(source_path))
            except SyntaxError:
                continue  # unrelated to the facade contract
            requested.update(
                alias.name
                for node in ast.walk(module_ast)
                if isinstance(node, ast.ImportFrom) and node.module == facade
                for alias in node.names
                if alias.name != "*"
            )

    unresolved = sorted(name for name in requested if not hasattr(ti, name))
    assert not unresolved, (
        f"facade consumers import names the shim does not re-export: {unresolved}"
    )
|
||||
@@ -102,7 +102,7 @@ def test_unlock_handler_feeds_password_on_stdin_not_argv():
|
||||
|
||||
|
||||
def test_tool_vault_unlock_feeds_password_on_stdin_not_argv():
|
||||
text = open("src/tools/vault.py", encoding="utf-8").read()
|
||||
text = open("src/tool_implementations.py", encoding="utf-8").read()
|
||||
|
||||
assert '["unlock", master_password, "--raw"]' not in text
|
||||
assert '_run_bw(["unlock", master_password' not in text
|
||||
|
||||
Reference in New Issue
Block a user