mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 09:45:24 -04:00
Compare commits
199 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 94d2754f41 | |||
| a240f28af9 | |||
| e7c1d75884 | |||
| f7ae85590b | |||
| 62ffcb6236 | |||
| 85c6056c87 | |||
| 049833e309 | |||
| 4e497f4878 | |||
| 5462030cde | |||
| 0a324f20d2 | |||
| 8e494cc1c4 | |||
| 932b7f2446 | |||
| a58f526992 | |||
| ed6cc88974 | |||
| 5198516979 | |||
| 095c74b985 | |||
| 34a3f8637a | |||
| 8449baea80 | |||
| d58202d10e | |||
| 1209f258d7 | |||
| d71284194b | |||
| d458cade98 | |||
| fe19d072e3 | |||
| 09565acc1e | |||
| d6882a895e | |||
| 4a9085d252 | |||
| aab203cf51 | |||
| ab2f7cffca | |||
| fe8d8cd020 | |||
| 233390546c | |||
| 1e0d9b92af | |||
| ac94885c84 | |||
| adc6ac9394 | |||
| fa7c4f8ea9 | |||
| 77b75ca97e | |||
| 505d8bae5a | |||
| 9c90f62657 | |||
| 73315e6ddc | |||
| 7b68413433 | |||
| 3557a3f495 | |||
| c46ea44f43 | |||
| a017108d41 | |||
| 9ad6a2809e | |||
| 92300b5d67 | |||
| dd4cdaf251 | |||
| 1a0e1c5d69 | |||
| 76c1f42ab0 | |||
| d85c5e335e | |||
| f939cb65ce | |||
| 865e61450e | |||
| 8746c9c0df | |||
| f7c0b3f23b | |||
| e3e37ce526 | |||
| a8859bb25c | |||
| 6c9a16a7a8 | |||
| accdc4fc53 | |||
| 3a91c11ff8 | |||
| 00e8084969 | |||
| c9198baa2e | |||
| 55343e89fb | |||
| 681a2a3f2a | |||
| d7ece5b4a9 | |||
| 5dff35ba03 | |||
| b22c2b280c | |||
| a6bc1addd2 | |||
| 2a422c00ec | |||
| 8cfc5bb28f | |||
| 8d9d4ec9c6 | |||
| 8f2c8d2dc8 | |||
| 613bbb0dba | |||
| 8f5b7210cc | |||
| 2a6921a455 | |||
| b8463e3ac2 | |||
| 92ef01d4fa | |||
| c5ac89f01f | |||
| b9a96bca1a | |||
| 706ea6a7b7 | |||
| 12cb39cbd9 | |||
| 43c16fc7e4 | |||
| c75d3e1975 | |||
| 3c924b8dee | |||
| adbcb3763f | |||
| bdf4ec8b24 | |||
| 5d3e3c7053 | |||
| 04d6a5ccaa | |||
| a3784da172 | |||
| cbbb41dfb1 | |||
| 83b0ab7cd3 | |||
| 12a7e741d0 | |||
| 573d431399 | |||
| 2149f0fb67 | |||
| 83fca6ac62 | |||
| 000932a6d9 | |||
| 299538ea4e | |||
| 67aeea4f8b | |||
| f2a79aaf5c | |||
| a6490ffb1b | |||
| 06d28e23ac | |||
| 7b4e6c4c1b | |||
| 3cff06781e | |||
| ff4508d396 | |||
| c11ce66e0e | |||
| 34bd8f0491 | |||
| f78539ba15 | |||
| 95c2dca4b5 | |||
| 3940297655 | |||
| a3cb15d0a1 | |||
| 108ee1e32b | |||
| b03d934ec6 | |||
| eb840459f5 | |||
| 6ccd4500d7 | |||
| 2e37d72155 | |||
| fb9c7cf3da | |||
| 33edc40eae | |||
| e87a1ad8d2 | |||
| 893cb8254f | |||
| 870ae2823f | |||
| 86abcb75d0 | |||
| 463713c2c6 | |||
| c2017fa089 | |||
| 53fd856ea8 | |||
| 66599b02a2 | |||
| fb3e89b011 | |||
| f72e1bd412 | |||
| 2bdf43b74d | |||
| c8b4cd24e0 | |||
| f4aa661502 | |||
| 5911b8c0dc | |||
| 08e543d1ff | |||
| 47a47bf71d | |||
| 71dda5b106 | |||
| ad82ee1c83 | |||
| 545e692565 | |||
| fa9f62b44c | |||
| b448119919 | |||
| 977daf0643 | |||
| 8ce945d338 | |||
| 2e207fc315 | |||
| 01f1278811 | |||
| 4bfe0c690a | |||
| c9d0c6db18 | |||
| 6973c5427c | |||
| 8354948a1c | |||
| 8159733c6c | |||
| 05f047b188 | |||
| e9ff6cde77 | |||
| 747d005645 | |||
| bec594904d | |||
| ec8fbf5d8f | |||
| b5c45326e4 | |||
| 452a94fb1b | |||
| 301d1109b5 | |||
| 370ae5d451 | |||
| 6d64055328 | |||
| 0b0d747f1c | |||
| 688194113b | |||
| 2a1febdeef | |||
| 0f8d12363a | |||
| 201e207b56 | |||
| 65231f2ba1 | |||
| 4f0133b8c3 | |||
| f9e1d38cc2 | |||
| 0a2adc9c96 | |||
| 621885ac06 | |||
| 30173f3909 | |||
| f5d834b0c5 | |||
| 367858a587 | |||
| b19e5693af | |||
| 11ba46505b | |||
| d4d168f972 | |||
| 194985b5e1 | |||
| 0dc051dea3 | |||
| 8b386a172e | |||
| 2cae5a681d | |||
| 46f128b9df | |||
| 4df4cfeaff | |||
| e0e250d023 | |||
| ec7691956b | |||
| 04df7255fb | |||
| 3ef73013eb | |||
| 17b62a3dba | |||
| e0097c9c48 | |||
| 9ffa87e394 | |||
| cfb2d17a2d | |||
| 5271d529d6 | |||
| a9c1c698b0 | |||
| 88c9f1fa74 | |||
| 2ba77e3aa3 | |||
| fbd34334a5 | |||
| e2f449f4ef | |||
| 43a101d305 | |||
| f8aaeab245 | |||
| f19ac6ed03 | |||
| a260e0abd4 | |||
| b98ee04e2f | |||
| 4ed48baf68 | |||
| a19b6d2d4d | |||
| 9112861d8e | |||
| 911fd61100 |
@@ -56,6 +56,13 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# SQLite database path (default: sqlite:///./data/app.db)
|
# SQLite database path (default: sqlite:///./data/app.db)
|
||||||
# DATABASE_URL=sqlite:///./data/app.db
|
# DATABASE_URL=sqlite:///./data/app.db
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Data directory
|
||||||
|
# ============================================================
|
||||||
|
# Move everything that lives under data/ - settings, sessions, database, auth,
|
||||||
|
# cache, uploads, etc. - to another path:
|
||||||
|
# ODYSSEUS_DATA_DIR=C:\path\to\dir
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Auth & Security
|
# Auth & Security
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -112,6 +119,9 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
||||||
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
||||||
|
|
||||||
|
# Embedding API key (if there's one)
|
||||||
|
# EMBEDDING_API_KEY=embedding_api_key_here
|
||||||
|
|
||||||
# Embedding model name (must be available at the endpoint above)
|
# Embedding model name (must be available at the endpoint above)
|
||||||
# EMBEDDING_MODEL=all-minilm:l6-v2
|
# EMBEDDING_MODEL=all-minilm:l6-v2
|
||||||
|
|
||||||
@@ -144,6 +154,21 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# if you intentionally want scheduled scripts to run remotely.
|
# if you intentionally want scheduled scripts to run remotely.
|
||||||
# ODYSSEUS_SCRIPT_HOST=localhost
|
# ODYSSEUS_SCRIPT_HOST=localhost
|
||||||
|
|
||||||
|
# Chat / agent attachment size cap in bytes (default: 10 MB).
|
||||||
|
# Raise this for local installs that need larger PDFs or text documents.
|
||||||
|
# Example: 52428800 = 50 MB.
|
||||||
|
# ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=10485760
|
||||||
|
|
||||||
|
# Other per-feature upload size caps in bytes. All are validated and optional;
|
||||||
|
# defaults shown. An invalid value (non-integer or < 1) fails fast at startup.
|
||||||
|
# ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES=104857600 # gallery image upload (100 MB)
|
||||||
|
# ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES=26214400 # gallery transform input (25 MB)
|
||||||
|
# ODYSSEUS_MEMORY_IMPORT_MAX_BYTES=10485760 # memory import file (10 MB)
|
||||||
|
# ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES=26214400 # personal document upload (25 MB)
|
||||||
|
# ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES=26214400 # email compose attachment (25 MB)
|
||||||
|
# ODYSSEUS_STT_MAX_AUDIO_BYTES=26214400 # speech-to-text audio (25 MB)
|
||||||
|
# ODYSSEUS_ICS_MAX_BYTES=10485760 # calendar .ics import (10 MB)
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# GPU support (Docker Compose)
|
# GPU support (Docker Compose)
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ body:
|
|||||||
required: true
|
required: true
|
||||||
- label: This is **not** a security vulnerability. (Vulnerabilities go to [GitHub Security Advisories](https://github.com/pewdiepie-archdaemon/odysseus/security/advisories/new) — see [SECURITY.md](https://github.com/pewdiepie-archdaemon/odysseus/blob/main/SECURITY.md).)
|
- label: This is **not** a security vulnerability. (Vulnerabilities go to [GitHub Security Advisories](https://github.com/pewdiepie-archdaemon/odysseus/security/advisories/new) — see [SECURITY.md](https://github.com/pewdiepie-archdaemon/odysseus/blob/main/SECURITY.md).)
|
||||||
required: true
|
required: true
|
||||||
- label: I am running the latest code from `main`.
|
- label: I am running the latest code from the `dev` branch (the default branch you get on clone, where fixes land first) and the bug still reproduces there. Please `git pull` the latest `dev` before filing.
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
|
|||||||
@@ -103,14 +103,21 @@ module.exports = async ({ github, context, core }) => {
|
|||||||
|
|
||||||
async function swapLabel(num, add, remove) {
|
async function swapLabel(num, add, remove) {
|
||||||
if (await labelExists(add)) {
|
if (await labelExists(add)) {
|
||||||
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
try {
|
||||||
|
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
||||||
|
} catch (e) {
|
||||||
|
// Fail soft on a token that can't write labels so a label permission
|
||||||
|
// problem never masks the actual description verdict.
|
||||||
|
if (e.status !== 403) throw e;
|
||||||
|
core.warning(`Could not add "${add}" — token lacks label write here; skipping.`);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
core.warning(`Label "${add}" does not exist in the repo — skipping. Create it once to enable labelling.`);
|
core.warning(`Label "${add}" does not exist in the repo — skipping. Create it once to enable labelling.`);
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name: remove });
|
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name: remove });
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
if (e.status !== 404 && e.status !== 410) throw e;
|
if (e.status !== 404 && e.status !== 410 && e.status !== 403) throw e;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
@@ -31,6 +33,8 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||||
with:
|
with:
|
||||||
node-version: "20"
|
node-version: "20"
|
||||||
@@ -51,10 +55,40 @@ jobs:
|
|||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
# Detect whether this PR only touches documentation files.
|
||||||
|
# If so, skip the expensive pytest run while still reporting a passing check.
|
||||||
|
- name: Check for docs-only changes
|
||||||
|
id: docs-check
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||||
|
BASE="${{ github.event.pull_request.base.sha }}"
|
||||||
|
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||||
|
else
|
||||||
|
BASE="${{ github.event.before }}"
|
||||||
|
HEAD="${{ github.sha }}"
|
||||||
|
fi
|
||||||
|
# List all changed files; if every file matches docs/markdown patterns, skip pytest.
|
||||||
|
changed=$(git diff --name-only "$BASE" "$HEAD" 2>/dev/null || git diff --name-only HEAD~1 HEAD)
|
||||||
|
non_docs=$(echo "$changed" | grep -Ev '^(docs/|.*\.md$|\.github/[^/]+\.md$)' || true)
|
||||||
|
if [ -z "$non_docs" ]; then
|
||||||
|
echo "docs_only=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "Docs-only change detected — skipping pytest."
|
||||||
|
else
|
||||||
|
echo "docs_only=false" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: pip
|
cache: pip
|
||||||
- run: pip install -r requirements.txt
|
- run: pip install -r requirements.txt
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
- run: mkdir -p data # sqlite DB lives at ./data/app.db
|
- run: mkdir -p data # sqlite DB lives at ./data/app.db
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
- run: python -m pytest -q
|
- run: python -m pytest -q
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
name: ci / docker publish
|
||||||
|
|
||||||
|
# Build the Odysseus image and publish to GHCR.
|
||||||
|
# push to main -> :latest, :X.Y.Z (curated release; main is fast-forwarded at releases)
|
||||||
|
# push to dev -> :dev, :X.Y.Z-dev.<sha> (rolling dev + an immutable, traceable pin)
|
||||||
|
# Multi-arch (linux/amd64 + linux/arm64): each arch builds on its own native
|
||||||
|
# runner and pushes by digest, then a merge job stitches the digests into one
|
||||||
|
# manifest list and applies the tags (faster + cleaner than QEMU emulation).
|
||||||
|
# Registry: ghcr.io/<owner>/<repo>.
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [dev, main]
|
||||||
|
paths-ignore:
|
||||||
|
- '**.md'
|
||||||
|
- 'docs/**'
|
||||||
|
- '.github/ISSUE_TEMPLATE/**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: docker-publish-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
env:
|
||||||
|
REGISTRY: ghcr.io
|
||||||
|
IMAGE_NAME: ${{ github.repository }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: build (${{ matrix.arch }})
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- platform: linux/amd64
|
||||||
|
arch: amd64
|
||||||
|
runner: ubuntu-latest
|
||||||
|
- platform: linux/arm64
|
||||||
|
arch: arm64
|
||||||
|
runner: ubuntu-24.04-arm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
- name: Set up Buildx
|
||||||
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||||
|
- name: Log in to GHCR
|
||||||
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Build and push by digest
|
||||||
|
id: build
|
||||||
|
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7.2.0
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: ${{ matrix.platform }}
|
||||||
|
outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||||
|
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||||
|
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||||
|
- name: Export digest
|
||||||
|
run: |
|
||||||
|
mkdir -p /tmp/digests
|
||||||
|
digest="${{ steps.build.outputs.digest }}"
|
||||||
|
touch "/tmp/digests/${digest#sha256:}"
|
||||||
|
- name: Upload digest
|
||||||
|
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||||
|
with:
|
||||||
|
name: digest-${{ matrix.arch }}
|
||||||
|
path: /tmp/digests/*
|
||||||
|
if-no-files-found: error
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
merge:
|
||||||
|
name: merge manifest + tag
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: build
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
- name: Read APP_VERSION + short sha
|
||||||
|
id: ver
|
||||||
|
run: |
|
||||||
|
v=$(grep -E '^APP_VERSION' src/constants.py | head -1 | sed -E 's/.*"([^"]+)".*/\1/')
|
||||||
|
[ -n "$v" ] || { echo "APP_VERSION not found"; exit 1; }
|
||||||
|
echo "version=$v" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "short=${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT"
|
||||||
|
- name: Download digests
|
||||||
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
|
with:
|
||||||
|
path: /tmp/digests
|
||||||
|
pattern: digest-*
|
||||||
|
merge-multiple: true
|
||||||
|
- name: Set up Buildx
|
||||||
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||||
|
- name: Log in to GHCR
|
||||||
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Compute tags
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6.1.0
|
||||||
|
with:
|
||||||
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
tags: |
|
||||||
|
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||||
|
type=raw,value=${{ steps.ver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||||
|
type=raw,value=dev,enable=${{ github.ref == 'refs/heads/dev' }}
|
||||||
|
type=raw,value=${{ steps.ver.outputs.version }}-dev.${{ steps.ver.outputs.short }},enable=${{ github.ref == 'refs/heads/dev' }}
|
||||||
|
- name: Create manifest list + push tags
|
||||||
|
working-directory: /tmp/digests
|
||||||
|
run: |
|
||||||
|
tags=$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
|
digests=$(printf "${REGISTRY}/${IMAGE_NAME}@sha256:%s " *)
|
||||||
|
# word-splitting is intended: $tags and $digests each expand to multiple args
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
docker buildx imagetools create $tags $digests
|
||||||
|
env:
|
||||||
|
REGISTRY: ${{ env.REGISTRY }}
|
||||||
|
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||||
|
- name: Inspect
|
||||||
|
run: |
|
||||||
|
if [ "$GITHUB_REF" = "refs/heads/main" ]; then ref=latest; else ref=dev; fi
|
||||||
|
docker buildx imagetools inspect "${REGISTRY}/${IMAGE_NAME}:${ref}"
|
||||||
|
env:
|
||||||
|
REGISTRY: ${{ env.REGISTRY }}
|
||||||
|
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||||
@@ -14,10 +14,11 @@ jobs:
|
|||||||
# Skip bots (Dependabot, release-drafter, etc.)
|
# Skip bots (Dependabot, release-drafter, etc.)
|
||||||
if: ${{ github.event.issue.user.type != 'Bot' }}
|
if: ${{ github.event.issue.user.type != 'Bot' }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
with:
|
with:
|
||||||
sparse-checkout: .github/scripts
|
sparse-checkout: .github/scripts
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
- uses: actions/github-script@v7
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
with:
|
with:
|
||||||
script: return require('./.github/scripts/check-issue-description.js')({github, context, core})
|
script: return require('./.github/scripts/check-issue-description.js')({github, context, core})
|
||||||
|
|||||||
@@ -1,28 +1,109 @@
|
|||||||
name: ci / PR description check
|
name: ci / PR checks
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
# pull_request_target runs in the base-repo context (has secrets) so the check
|
||||||
types: [opened, edited, synchronize, reopened]
|
# works on fork PRs. Safe here: the checkout pins to the base branch (no fork
|
||||||
|
# code runs) and the scripts only read context.payload and call the GitHub API.
|
||||||
|
pull_request_target: # zizmor: ignore[dangerous-triggers]
|
||||||
|
types: [opened, edited, synchronize, reopened, ready_for_review]
|
||||||
|
|
||||||
# pull_request_target runs in the base-repo context (has secrets).
|
# Default-deny at the workflow level; each job opts into only the scopes it needs.
|
||||||
# The checkout below pins to the base branch so no fork code is executed.
|
# Note: modifying a PR's labels/comments needs pull-requests:write even though the
|
||||||
# The script only reads context.payload and calls the GitHub API.
|
# REST path is under /issues/{n}/...; issues:write alone returns 403 on PRs.
|
||||||
permissions:
|
permissions: {}
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-description:
|
check-description:
|
||||||
name: Check PR description
|
name: Check PR description
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Skip bots — they open PRs programmatically and have their own process.
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
issues: write
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
if: github.event.pull_request.user.type != 'Bot'
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.base_ref }}
|
ref: ${{ github.base_ref }}
|
||||||
sparse-checkout: .github/scripts
|
sparse-checkout: .github/scripts
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
- uses: actions/github-script@v7
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
with:
|
with:
|
||||||
script: return require('./.github/scripts/check-pr-description.js')({github, context, core})
|
script: return require('./.github/scripts/check-pr-description.js')({github, context, core})
|
||||||
|
|
||||||
|
check-title:
|
||||||
|
name: Check PR title (Conventional Commits)
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions: {}
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
|
steps:
|
||||||
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title || "";
|
||||||
|
// Conventional Commits: type(optional-scope)(optional !): summary
|
||||||
|
const re = /^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\([\w .\/-]+\))?!?: .+/;
|
||||||
|
if (!re.test(title)) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title is not in Conventional Commits format:\n "${title}"\n\n` +
|
||||||
|
`Expected: type(scope): summary\n` +
|
||||||
|
`Example: fix(search): handle empty query\n` +
|
||||||
|
`Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.`
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
core.info(`PR title OK: ${title}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
check-mergeable:
|
||||||
|
name: Flag unmergeable PRs
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
pull-requests: write
|
||||||
|
issues: write
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
|
steps:
|
||||||
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const repo = { owner: context.repo.owner, repo: context.repo.repo };
|
||||||
|
const number = context.payload.pull_request.number;
|
||||||
|
const READY = "ready for review";
|
||||||
|
const CONFLICT = "merge conflict";
|
||||||
|
|
||||||
|
// Ensure the conflict label exists (red). Ignore if already present.
|
||||||
|
try {
|
||||||
|
await github.rest.issues.getLabel({ ...repo, name: CONFLICT });
|
||||||
|
} catch {
|
||||||
|
await github.rest.issues.createLabel({
|
||||||
|
...repo, name: CONFLICT, color: "B60205",
|
||||||
|
description: "Conflicts with the base branch; needs a rebase before review.",
|
||||||
|
}).catch(() => {});
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeable is computed asynchronously and is often null right after
|
||||||
|
// an event, so poll a few times until GitHub has resolved it.
|
||||||
|
let pr = null;
|
||||||
|
for (let i = 0; i < 5; i++) {
|
||||||
|
const { data } = await github.rest.pulls.get({ ...repo, pull_number: number });
|
||||||
|
if (data.mergeable !== null) { pr = data; break; }
|
||||||
|
await new Promise(r => setTimeout(r, 3000));
|
||||||
|
}
|
||||||
|
if (!pr || pr.draft) return;
|
||||||
|
const labels = pr.labels.map(l => l.name);
|
||||||
|
|
||||||
|
if (pr.mergeable === false) {
|
||||||
|
if (labels.includes(READY)) {
|
||||||
|
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: READY }).catch(() => {});
|
||||||
|
}
|
||||||
|
if (!labels.includes(CONFLICT)) {
|
||||||
|
await github.rest.issues.addLabels({ ...repo, issue_number: number, labels: [CONFLICT] });
|
||||||
|
}
|
||||||
|
} else if (pr.mergeable === true) {
|
||||||
|
if (labels.includes(CONFLICT)) {
|
||||||
|
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: CONFLICT }).catch(() => {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -94,6 +94,18 @@ Before submitting any change that affects what the app looks like — buttons, i
|
|||||||
|
|
||||||
If you are unsure whether a change is "visual," it is. Default to attaching a screenshot.
|
If you are unsure whether a change is "visual," it is. Default to attaching a screenshot.
|
||||||
|
|
||||||
|
## Code conventions
|
||||||
|
|
||||||
|
Don't hardcode values that the project already exposes through a constant or a helper. Hardcoded literals drift out of sync, break on non-default deployments, and reintroduce bugs we've already fixed.
|
||||||
|
|
||||||
|
- **Filesystem paths:** never build writable paths from `Path(__file__)...` into the source tree, hardcode `/app/...`, or use a relative `"data/..."` string. Every persisted file and directory has a named constant in `src/constants.py` (for example `AUTH_FILE`, `USER_PREFS_FILE`, `SETTINGS_FILE`, `TTS_CACHE_DIR`, `CHROMA_DIR`). Import and use that named constant; do not re-derive the path locally with `os.path.join(DATA_DIR, "x.json")` or `DATA_DIR / "x.json"`. `DATA_DIR` is the single place that reads `ODYSSEUS_DATA_DIR`, so use it directly only for dynamic paths that have no fixed name (for example per-owner files). If a data file or directory has no constant yet, add one to `src/constants.py`. The source tree is read-only in Docker and `/app/...` does not exist on native runs; guard directory creation so an unwritable path degrades gracefully instead of crashing at import.
|
||||||
|
- **Internal API / loopback URLs:** don't hardcode `http://localhost:7000`. Use `internal_api_base()` from `src.constants` (it honors `ODYSSEUS_INTERNAL_BASE` / `APP_PORT`).
|
||||||
|
- **Ports, limits, model lists, and similar:** reuse the existing constant if one exists; if it doesn't and the value is used in more than one place, add a constant rather than copying the literal.
|
||||||
|
|
||||||
|
If you need a value that has no constant or helper yet, add it to `src/constants.py` (the single source of truth for paths and config; `core/constants.py` only re-exports it for backward compatibility) and import it, rather than repeating a literal across files.
|
||||||
|
|
||||||
|
**Commits:** use [Conventional Commits](https://www.conventionalcommits.org), `type(scope): summary` (e.g. `fix(search): ...`, `feat(notes): ...`, `docs(contributing): ...`). Common types: `fix`, `feat`, `refactor`, `docs`, `test`, `chore`, `ci`. Keep the subject short and imperative; put the "why" in the body when it isn't obvious.
|
||||||
|
|
||||||
## Issue Reports
|
## Issue Reports
|
||||||
|
|
||||||
For bugs, include:
|
For bugs, include:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
# Odysseus
|
# Odysseus
|
||||||
|
|
||||||
|
> **Branch note:** `dev` is the default branch and contains the latest development changes, but it may be unstable. For the more stable curated branch, use [`main`](https://github.com/pewdiepie-archdaemon/odysseus/tree/main).
|
||||||
|
|
||||||
```
|
```
|
||||||
───────────────────────────────────────────────
|
───────────────────────────────────────────────
|
||||||
⊹ ࣪ ˖ ૮( ˶ᵔ ᵕ ᵔ˶ )っ Odysseus vers. 1.0
|
⊹ ࣪ ˖ ૮( ˶ᵔ ᵕ ᵔ˶ )っ Odysseus vers. 1.0
|
||||||
@@ -331,6 +333,12 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
|||||||
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
||||||
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
||||||
|
|
||||||
|
### Outlook / Office 365 email
|
||||||
|
Odysseus email accounts currently use IMAP/SMTP username-password auth. Outlook
|
||||||
|
and Microsoft 365 generally require OAuth instead, so normal Microsoft mailbox
|
||||||
|
passwords will fail. See [docs/email-outlook.md](docs/email-outlook.md) for the
|
||||||
|
current limitation and the planned integration direction.
|
||||||
|
|
||||||
## Security Notes
|
## Security Notes
|
||||||
Odysseus is a self-hosted workspace with powerful local tools: shell access, file uploads, model downloads, web research, email/calendar integrations, and API tokens. Treat it like an admin console.
|
Odysseus is a self-hosted workspace with powerful local tools: shell access, file uploads, model downloads, web research, email/calendar integrations, and API tokens. Treat it like an admin console.
|
||||||
|
|
||||||
@@ -394,6 +402,16 @@ Key settings:
|
|||||||
| `CHROMADB_HOST` | `localhost` | ChromaDB host for vector memory. Docker overrides this to `chromadb`. |
|
| `CHROMADB_HOST` | `localhost` | ChromaDB host for vector memory. Docker overrides this to `chromadb`. |
|
||||||
| `CHROMADB_PORT` | `8100` | ChromaDB port for manual host runs. Docker overrides this to `8000`. |
|
| `CHROMADB_PORT` | `8100` | ChromaDB port for manual host runs. Docker overrides this to `8000`. |
|
||||||
| `EMBEDDING_URL` | -- | OpenAI-compatible embeddings endpoint |
|
| `EMBEDDING_URL` | -- | OpenAI-compatible embeddings endpoint |
|
||||||
|
| `ODYSSEUS_CHAT_UPLOAD_MAX_BYTES` | `10485760` | Chat/agent attachment cap in bytes. Raise for larger local PDFs or text documents. |
|
||||||
|
| `ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES` | `104857600` | Gallery image upload cap in bytes (100 MB). |
|
||||||
|
| `ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES` | `26214400` | Gallery transform input cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_MEMORY_IMPORT_MAX_BYTES` | `10485760` | Memory import file cap in bytes (10 MB). |
|
||||||
|
| `ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES` | `26214400` | Personal document upload cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES` | `26214400` | Email compose attachment cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_STT_MAX_AUDIO_BYTES` | `26214400` | Speech-to-text audio cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_ICS_MAX_BYTES` | `10485760` | Calendar `.ics` import cap in bytes (10 MB). |
|
||||||
|
|
||||||
|
All upload-limit vars are validated (must be a positive integer) and optional; an invalid value fails fast at startup.
|
||||||
|
|
||||||
### Built-in MCP servers (optional setup)
|
### Built-in MCP servers (optional setup)
|
||||||
|
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
# Core imports
|
# Core imports
|
||||||
from core.constants import (
|
from core.constants import (
|
||||||
BASE_DIR, STATIC_DIR, SESSIONS_FILE,
|
BASE_DIR, STATIC_DIR, SESSIONS_FILE,
|
||||||
REQUEST_TIMEOUT, OPENAI_API_KEY,
|
REQUEST_TIMEOUT, OPENAI_API_KEY, AUTH_FILE,
|
||||||
)
|
)
|
||||||
from core.database import SessionLocal, ApiToken
|
from core.database import SessionLocal, ApiToken
|
||||||
from core.middleware import SecurityHeadersMiddleware
|
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
from core.exceptions import (
|
from core.exceptions import (
|
||||||
SessionNotFoundError, InvalidFileUploadError,
|
SessionNotFoundError, InvalidFileUploadError,
|
||||||
@@ -64,6 +64,7 @@ from core.exceptions import (
|
|||||||
import bcrypt as _bcrypt
|
import bcrypt as _bcrypt
|
||||||
|
|
||||||
from src.app_helpers import abs_join
|
from src.app_helpers import abs_join
|
||||||
|
from src.generated_images import GENERATED_IMAGE_HEADERS, resolve_generated_image_path
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
# ========= LOGGING =========
|
# ========= LOGGING =========
|
||||||
@@ -252,6 +253,15 @@ if AUTH_ENABLED:
|
|||||||
class AuthMiddleware(BaseHTTPMiddleware):
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
# A genuine CORS preflight (OPTIONS + Access-Control-Request-Method)
|
||||||
|
# carries no credentials by design and must reach CORSMiddleware to be
|
||||||
|
# answered. AuthMiddleware is the outermost middleware, so gating the
|
||||||
|
# preflight on auth 401s it before CORS can respond -- which blocks
|
||||||
|
# every cross-origin browser/WebView client before the real request
|
||||||
|
# is sent. Let real preflights through (only OPTIONS w/ the ACRM
|
||||||
|
# header; never a credentialed request).
|
||||||
|
if is_cors_preflight(request.method, request.headers):
|
||||||
|
return await call_next(request)
|
||||||
if _is_auth_exempt(path):
|
if _is_auth_exempt(path):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
# In-process internal-tool token bypass. Used by the agent
|
# In-process internal-tool token bypass. Used by the agent
|
||||||
@@ -387,13 +397,7 @@ app.mount("/static", _RevalidatingStatic(directory="static"), name="static")
|
|||||||
@app.get("/api/generated-image/{filename}")
|
@app.get("/api/generated-image/{filename}")
|
||||||
async def serve_generated_image(filename: str, request: Request):
|
async def serve_generated_image(filename: str, request: Request):
|
||||||
"""Serve generated images from the data directory."""
|
"""Serve generated images from the data directory."""
|
||||||
from pathlib import Path
|
img_path = resolve_generated_image_path(filename)
|
||||||
import re
|
|
||||||
if not re.match(r'^[a-f0-9]{8,64}\.(png|jpg|jpeg|webp|gif|mp4|mov|webm|mkv|m4v)$', filename):
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
||||||
img_path = Path("data/generated_images") / filename
|
|
||||||
if not img_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
|
||||||
# SECURITY: filename is the only key, so anyone who knows / guesses a
|
# SECURITY: filename is the only key, so anyone who knows / guesses a
|
||||||
# 12-hex content hash could pull another user's image bytes. Require
|
# 12-hex content hash could pull another user's image bytes. Require
|
||||||
# auth and verify ownership via the gallery row (when one exists).
|
# auth and verify ownership via the gallery row (when one exists).
|
||||||
@@ -429,7 +433,7 @@ async def serve_generated_image(filename: str, request: Request):
|
|||||||
return FileResponse(
|
return FileResponse(
|
||||||
str(img_path),
|
str(img_path),
|
||||||
media_type=mime,
|
media_type=mime,
|
||||||
headers={"Cache-Control": "public, max-age=31536000, immutable"},
|
headers=GENERATED_IMAGE_HEADERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========= YOUTUBE INIT =========
|
# ========= YOUTUBE INIT =========
|
||||||
@@ -594,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
|
|||||||
from routes.copilot_routes import setup_copilot_routes
|
from routes.copilot_routes import setup_copilot_routes
|
||||||
app.include_router(setup_copilot_routes())
|
app.include_router(setup_copilot_routes())
|
||||||
|
|
||||||
|
# ChatGPT Subscription device-flow login
|
||||||
|
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
|
||||||
|
app.include_router(setup_chatgpt_subscription_routes())
|
||||||
|
|
||||||
# TTS
|
# TTS
|
||||||
from routes.tts_routes import setup_tts_routes
|
from routes.tts_routes import setup_tts_routes
|
||||||
app.include_router(setup_tts_routes(tts_service))
|
app.include_router(setup_tts_routes(tts_service))
|
||||||
@@ -789,6 +797,8 @@ async def serve_backgrounds(request: Request):
|
|||||||
|
|
||||||
@app.get("/login")
|
@app.get("/login")
|
||||||
async def serve_login(request: Request):
|
async def serve_login(request: Request):
|
||||||
|
if not AUTH_ENABLED:
|
||||||
|
return RedirectResponse(url="/", status_code=302)
|
||||||
return _serve_html_with_nonce(request, abs_join(BASE_DIR, "static/login.html"))
|
return _serve_html_with_nonce(request, abs_join(BASE_DIR, "static/login.html"))
|
||||||
|
|
||||||
@app.get("/api/version")
|
@app.get("/api/version")
|
||||||
@@ -948,7 +958,7 @@ async def _startup_event():
|
|||||||
owners = set()
|
owners = set()
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
auth_path = "data/auth.json"
|
auth_path = AUTH_FILE
|
||||||
with open(auth_path, encoding="utf-8") as f:
|
with open(auth_path, encoding="utf-8") as f:
|
||||||
users = _json.load(f).get("users", {})
|
users = _json.load(f).get("users", {})
|
||||||
owners.update(users.keys())
|
owners.update(users.keys())
|
||||||
@@ -995,7 +1005,7 @@ async def _startup_event():
|
|||||||
# does not make an existing library look empty after auth/account changes.
|
# does not make an existing library look empty after auth/account changes.
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
auth_path = "data/auth.json"
|
auth_path = AUTH_FILE
|
||||||
with open(auth_path, encoding="utf-8") as f:
|
with open(auth_path, encoding="utf-8") as f:
|
||||||
users = _json.load(f).get("users", {})
|
users = _json.load(f).get("users", {})
|
||||||
primary_owner = None
|
primary_owner = None
|
||||||
@@ -1067,6 +1077,16 @@ async def _startup_event():
|
|||||||
logger.warning(f"Nightly skill audit failed: {e}")
|
logger.warning(f"Nightly skill audit failed: {e}")
|
||||||
|
|
||||||
_startup_tasks.append(asyncio.create_task(_skill_audit_nightly_loop()))
|
_startup_tasks.append(asyncio.create_task(_skill_audit_nightly_loop()))
|
||||||
|
|
||||||
|
# Cookbook serve lifecycle — kills scheduler-launched serves whose
|
||||||
|
# window-end has passed. Paired with the cookbook_serve builtin
|
||||||
|
# action; both are no-ops unless a scheduled task actually launches
|
||||||
|
# something with end_after_min set. Removing this line + the
|
||||||
|
# cookbook_serve entry in BUILTIN_ACTIONS + src/cookbook_serve_lifecycle.py
|
||||||
|
# removes the feature.
|
||||||
|
from src.cookbook_serve_lifecycle import cookbook_serve_lifecycle_loop
|
||||||
|
_startup_tasks.append(asyncio.create_task(cookbook_serve_lifecycle_loop()))
|
||||||
|
|
||||||
logger.info("Application startup complete")
|
logger.info("Application startup complete")
|
||||||
|
|
||||||
async def _shutdown_event():
|
async def _shutdown_event():
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import uuid
|
|||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
|
from src.constants import AUTH_FILE
|
||||||
|
|
||||||
PAIRING_VERSION = 1
|
PAIRING_VERSION = 1
|
||||||
COMPANION_SCOPE = "chat"
|
COMPANION_SCOPE = "chat"
|
||||||
|
|
||||||
@@ -61,7 +63,7 @@ def lan_ip_candidates() -> list[str]:
|
|||||||
def find_admin_user() -> str | None:
|
def find_admin_user() -> str | None:
|
||||||
"""Resolve an admin username from data/auth.json (schema uses is_admin),
|
"""Resolve an admin username from data/auth.json (schema uses is_admin),
|
||||||
falling back to the first user."""
|
falling back to the first user."""
|
||||||
auth_path = os.path.join("data", "auth.json")
|
auth_path = AUTH_FILE
|
||||||
try:
|
try:
|
||||||
with open(auth_path, "r", encoding="utf-8") as f:
|
with open(auth_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|||||||
+92
-62
@@ -30,14 +30,24 @@ DEFAULT_PRIVILEGES = {
|
|||||||
"can_manage_memory": True,
|
"can_manage_memory": True,
|
||||||
"max_messages_per_day": 0,
|
"max_messages_per_day": 0,
|
||||||
"allowed_models": [],
|
"allowed_models": [],
|
||||||
|
"allowed_models_restricted": False,
|
||||||
|
# Explicit "block every model" sentinel. An empty `allowed_models` list is
|
||||||
|
# ambiguous — it's also what gets sent when the admin clicks "[All]" — so
|
||||||
|
# we need a dedicated flag to express "this user may use no models at all"
|
||||||
|
# distinctly from "this user has no restriction".
|
||||||
|
"block_all_models": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Admins get everything
|
# Admins get everything
|
||||||
ADMIN_PRIVILEGES = {k: (True if isinstance(v, bool) else (0 if isinstance(v, int) else [])) for k, v in DEFAULT_PRIVILEGES.items()}
|
ADMIN_PRIVILEGES = {k: (True if isinstance(v, bool) else (0 if isinstance(v, int) else [])) for k, v in DEFAULT_PRIVILEGES.items()}
|
||||||
|
ADMIN_PRIVILEGES["allowed_models_restricted"] = False
|
||||||
|
# Admins must never be blocked from using models — the generic dict
|
||||||
|
# comprehension above flips every boolean default to True, which would be
|
||||||
|
# backwards for this sentinel.
|
||||||
|
ADMIN_PRIVILEGES["block_all_models"] = False
|
||||||
|
|
||||||
DEFAULT_AUTH_PATH = os.path.join(
|
from src.constants import AUTH_FILE
|
||||||
Path(__file__).parent.parent, "data", "auth.json"
|
DEFAULT_AUTH_PATH = AUTH_FILE
|
||||||
)
|
|
||||||
TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
||||||
|
|
||||||
# Usernames the auth + middleware layer reserve as internal "synthetic owner"
|
# Usernames the auth + middleware layer reserve as internal "synthetic owner"
|
||||||
@@ -76,6 +86,10 @@ class AuthManager:
|
|||||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||||
self._sessions_lock = threading.RLock()
|
self._sessions_lock = threading.RLock()
|
||||||
|
# Guards all mutations of self._config and the on-disk auth.json so
|
||||||
|
# concurrent create/delete/rename/privilege operations don't interleave
|
||||||
|
# and corrupt the user database.
|
||||||
|
self._config_lock = threading.Lock()
|
||||||
# Guards the first-run setup check-and-write so concurrent requests
|
# Guards the first-run setup check-and-write so concurrent requests
|
||||||
# cannot both observe is_configured==False and both create admin accounts.
|
# cannot both observe is_configured==False and both create admin accounts.
|
||||||
self._setup_lock = threading.Lock()
|
self._setup_lock = threading.Lock()
|
||||||
@@ -172,8 +186,9 @@ class AuthManager:
|
|||||||
|
|
||||||
@signup_enabled.setter
|
@signup_enabled.setter
|
||||||
def signup_enabled(self, value: bool):
|
def signup_enabled(self, value: bool):
|
||||||
self._config["signup_enabled"] = value
|
with self._config_lock:
|
||||||
self._save()
|
self._config["signup_enabled"] = value
|
||||||
|
self._save()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_configured(self) -> bool:
|
def is_configured(self) -> bool:
|
||||||
@@ -198,17 +213,18 @@ class AuthManager:
|
|||||||
if username in RESERVED_USERNAMES:
|
if username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to create reserved username '%s'", username)
|
logger.warning("Refused to create reserved username '%s'", username)
|
||||||
return False
|
return False
|
||||||
if username in self.users:
|
with self._config_lock:
|
||||||
return False
|
if username in self.users:
|
||||||
if "users" not in self._config:
|
return False
|
||||||
self._config["users"] = {}
|
if "users" not in self._config:
|
||||||
self._config["users"][username] = {
|
self._config["users"] = {}
|
||||||
"password_hash": _hash_password(password),
|
self._config["users"][username] = {
|
||||||
"created": time.time(),
|
"password_hash": _hash_password(password),
|
||||||
"is_admin": is_admin,
|
"created": time.time(),
|
||||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
"is_admin": is_admin,
|
||||||
}
|
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||||
self._save()
|
}
|
||||||
|
self._save()
|
||||||
logger.info(f"Created user '{username}' (admin={is_admin})")
|
logger.info(f"Created user '{username}' (admin={is_admin})")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -221,14 +237,15 @@ class AuthManager:
|
|||||||
their cookie expired naturally (default ~30 days).
|
their cookie expired naturally (default ~30 days).
|
||||||
"""
|
"""
|
||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if username not in self.users:
|
with self._config_lock:
|
||||||
return False
|
if username not in self.users:
|
||||||
if username == requesting_user:
|
return False
|
||||||
return False
|
if username == requesting_user:
|
||||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
return False
|
||||||
return False
|
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||||
del self._config["users"][username]
|
return False
|
||||||
self._save()
|
del self._config["users"][username]
|
||||||
|
self._save()
|
||||||
# Purge all sessions belonging to this user. validate_token doesn't
|
# Purge all sessions belonging to this user. validate_token doesn't
|
||||||
# cross-check `self.users`, so without this step a deleted user's
|
# cross-check `self.users`, so without this step a deleted user's
|
||||||
# cookie keeps authenticating.
|
# cookie keeps authenticating.
|
||||||
@@ -266,14 +283,15 @@ class AuthManager:
|
|||||||
if new_username in RESERVED_USERNAMES:
|
if new_username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||||
return False
|
return False
|
||||||
if old_username not in self.users:
|
with self._config_lock:
|
||||||
return False
|
if old_username not in self.users:
|
||||||
if new_username in self.users:
|
return False
|
||||||
return False
|
if new_username in self.users:
|
||||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
return False
|
||||||
return False
|
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
return False
|
||||||
self._save()
|
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||||
|
self._save()
|
||||||
|
|
||||||
renamed_sessions = 0
|
renamed_sessions = 0
|
||||||
with self._sessions_lock:
|
with self._sessions_lock:
|
||||||
@@ -311,17 +329,18 @@ class AuthManager:
|
|||||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
||||||
"""Update privileges for a user. Can't modify admin privileges."""
|
"""Update privileges for a user. Can't modify admin privileges."""
|
||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if username not in self.users:
|
with self._config_lock:
|
||||||
return False
|
if username not in self.users:
|
||||||
if self.users[username].get("is_admin"):
|
return False
|
||||||
return False # admins always have full access
|
if self.users[username].get("is_admin"):
|
||||||
# Only allow known privilege keys
|
return False # admins always have full access
|
||||||
current = self.get_privileges(username)
|
# Only allow known privilege keys
|
||||||
for k, v in privileges.items():
|
current = self.get_privileges(username)
|
||||||
if k in DEFAULT_PRIVILEGES:
|
for k, v in privileges.items():
|
||||||
current[k] = v
|
if k in DEFAULT_PRIVILEGES:
|
||||||
self._config["users"][username]["privileges"] = current
|
current[k] = v
|
||||||
self._save()
|
self._config["users"][username]["privileges"] = current
|
||||||
|
self._save()
|
||||||
logger.info(f"Updated privileges for '{username}': {current}")
|
logger.info(f"Updated privileges for '{username}': {current}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -331,8 +350,9 @@ class AuthManager:
|
|||||||
return False
|
return False
|
||||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||||
return False
|
return False
|
||||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
with self._config_lock:
|
||||||
self._save()
|
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||||
|
self._save()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -350,8 +370,9 @@ class AuthManager:
|
|||||||
if username not in self.users:
|
if username not in self.users:
|
||||||
return None
|
return None
|
||||||
secret = pyotp.random_base32()
|
secret = pyotp.random_base32()
|
||||||
self._config["users"][username]["totp_secret_pending"] = secret
|
with self._config_lock:
|
||||||
self._save()
|
self._config["users"][username]["totp_secret_pending"] = secret
|
||||||
|
self._save()
|
||||||
return secret
|
return secret
|
||||||
|
|
||||||
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
||||||
@@ -370,13 +391,14 @@ class AuthManager:
|
|||||||
if not totp.verify(code, valid_window=1):
|
if not totp.verify(code, valid_window=1):
|
||||||
return False
|
return False
|
||||||
# Enable 2FA
|
# Enable 2FA
|
||||||
self._config["users"][username]["totp_secret"] = secret
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_enabled"] = True
|
self._config["users"][username]["totp_secret"] = secret
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username]["totp_enabled"] = True
|
||||||
# Generate backup codes
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
# Generate backup codes
|
||||||
self._config["users"][username]["totp_backup_codes"] = backup
|
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||||
self._save()
|
self._config["users"][username]["totp_backup_codes"] = backup
|
||||||
|
self._save()
|
||||||
logger.info(f"2FA enabled for '{username}'")
|
logger.info(f"2FA enabled for '{username}'")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -395,9 +417,10 @@ class AuthManager:
|
|||||||
# Check backup codes first
|
# Check backup codes first
|
||||||
backup = user.get("totp_backup_codes", [])
|
backup = user.get("totp_backup_codes", [])
|
||||||
if code in backup:
|
if code in backup:
|
||||||
backup.remove(code)
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_backup_codes"] = backup
|
backup.remove(code)
|
||||||
self._save()
|
self._config["users"][username]["totp_backup_codes"] = backup
|
||||||
|
self._save()
|
||||||
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
||||||
return True
|
return True
|
||||||
totp = pyotp.TOTP(secret)
|
totp = pyotp.TOTP(secret)
|
||||||
@@ -408,11 +431,12 @@ class AuthManager:
|
|||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if not self.verify_password(username, password):
|
if not self.verify_password(username, password):
|
||||||
return False
|
return False
|
||||||
self._config["users"][username].pop("totp_secret", None)
|
with self._config_lock:
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username].pop("totp_secret", None)
|
||||||
self._config["users"][username].pop("totp_backup_codes", None)
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
self._config["users"][username]["totp_enabled"] = False
|
self._config["users"][username].pop("totp_backup_codes", None)
|
||||||
self._save()
|
self._config["users"][username]["totp_enabled"] = False
|
||||||
|
self._save()
|
||||||
logger.info(f"2FA disabled for '{username}'")
|
logger.info(f"2FA disabled for '{username}'")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -431,6 +455,12 @@ class AuthManager:
|
|||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if not self.verify_password(username, password):
|
if not self.verify_password(username, password):
|
||||||
return None
|
return None
|
||||||
|
return self.create_session_trusted(username)
|
||||||
|
|
||||||
|
def create_session_trusted(self, username: str) -> str:
|
||||||
|
"""Issue a session token for an already-verified user.
|
||||||
|
Call only after verify_password (and TOTP if enabled) have passed."""
|
||||||
|
username = username.strip().lower()
|
||||||
token = secrets.token_hex(32)
|
token = secrets.token_hex(32)
|
||||||
with self._sessions_lock:
|
with self._sessions_lock:
|
||||||
self._sessions[token] = {
|
self._sessions[token] = {
|
||||||
|
|||||||
+11
-39
@@ -1,40 +1,12 @@
|
|||||||
# src/constants.py
|
# core/constants.py
|
||||||
"""Application-wide constants and configuration values."""
|
"""Backward-compatible shim — the single source of truth is src/constants.py.
|
||||||
import os
|
|
||||||
|
|
||||||
APP_VERSION = "0.9.1"
|
Historically there were two copies of this module (this one lagged behind at
|
||||||
|
APP_VERSION 0.9.1 and was missing the consolidated tool-output constants). To
|
||||||
# Base paths
|
kill the drift, this now simply re-exports everything from src.constants so
|
||||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
there is exactly one place that defines paths and reads ODYSSEUS_DATA_DIR.
|
||||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
internal_api_base() also lives in src.constants now and is re-exported here so
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
existing `from core.constants import internal_api_base` callers keep working.
|
||||||
|
"""
|
||||||
# Data file paths
|
from src.constants import * # noqa: F401,F403
|
||||||
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
from src.constants import internal_api_base # noqa: F401 (explicit: functions aren't covered by some linters' * checks)
|
||||||
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
|
||||||
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
|
||||||
PERSONAL_DIR = os.path.join(DATA_DIR, "personal_docs")
|
|
||||||
RUNBOOK_DIR = os.path.join(PERSONAL_DIR, "runbook")
|
|
||||||
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
|
||||||
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
|
||||||
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
|
||||||
|
|
||||||
# API Configuration
|
|
||||||
MAX_CONTEXT_MESSAGES = 90
|
|
||||||
REQUEST_TIMEOUT = 20
|
|
||||||
OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
|
||||||
|
|
||||||
# Environment variables with defaults
|
|
||||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
|
||||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
||||||
SEARXNG_INSTANCE = os.getenv('SEARXNG_INSTANCE', 'http://localhost:8080')
|
|
||||||
|
|
||||||
|
|
||||||
# Cleanup configuration
|
|
||||||
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "True").lower() == "true"
|
|
||||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "24"))
|
|
||||||
|
|
||||||
# Default parameters
|
|
||||||
DEFAULT_TEMPERATURE = 1.0
|
|
||||||
DEFAULT_MAX_TOKENS = 0
|
|
||||||
|
|||||||
+168
-7
@@ -29,8 +29,9 @@ class TimestampMixin:
|
|||||||
def updated_at(cls):
|
def updated_at(cls):
|
||||||
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
||||||
|
|
||||||
# Get database URL from environment, default to SQLite
|
# Get database URL from environment, default to SQLite in DATA_DIR
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./data/app.db")
|
from src.constants import DATA_DIR, AUTH_FILE, MEMORY_FILE, USER_PREFS_FILE, SETTINGS_FILE
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite:///{DATA_DIR}/app.db")
|
||||||
|
|
||||||
# Create engine
|
# Create engine
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@@ -360,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
|
|||||||
# is the historical default. When non-null, the model picker only shows
|
# is the historical default. When non-null, the model picker only shows
|
||||||
# the endpoint to that user (admins always see everything).
|
# the endpoint to that user (admins always see everything).
|
||||||
owner = Column(String, nullable=True, index=True)
|
owner = Column(String, nullable=True, index=True)
|
||||||
|
# Optional OAuth/session-backed credential row. Used by subscription-backed
|
||||||
|
# providers that need refresh tokens instead of a static API key.
|
||||||
|
provider_auth_id = Column(String, nullable=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderAuthSession(TimestampMixin, Base):
|
||||||
|
"""Encrypted OAuth/session credentials for refresh-aware model providers."""
|
||||||
|
__tablename__ = "provider_auth_sessions"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, index=True)
|
||||||
|
provider = Column(String, nullable=False, index=True)
|
||||||
|
owner = Column(String, nullable=True, index=True)
|
||||||
|
label = Column(String, nullable=True)
|
||||||
|
base_url = Column(String, nullable=False)
|
||||||
|
access_token = Column(EncryptedText, nullable=True)
|
||||||
|
refresh_token = Column(EncryptedText, nullable=True)
|
||||||
|
last_refresh = Column(DateTime, nullable=True)
|
||||||
|
auth_mode = Column(String, nullable=True)
|
||||||
|
|
||||||
class McpServer(TimestampMixin, Base):
|
class McpServer(TimestampMixin, Base):
|
||||||
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
||||||
@@ -800,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
|
|||||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_add_provider_auth_id_column():
|
||||||
|
"""Add provider_auth_id column to model_endpoints if it doesn't exist."""
|
||||||
|
import sqlite3
|
||||||
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
if columns and "provider_auth_id" not in columns:
|
||||||
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
|
||||||
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
||||||
|
conn.commit()
|
||||||
|
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
||||||
|
conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_model_type_column():
|
def _migrate_add_model_type_column():
|
||||||
"""Add model_type column to model_endpoints if it doesn't exist."""
|
"""Add model_type column to model_endpoints if it doesn't exist."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@@ -1065,7 +1104,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
# fell through to "first user" every time.
|
# fell through to "first user" every time.
|
||||||
auth_path = os.path.join(os.path.dirname(DATABASE_URL.replace("sqlite:///", "")), "auth.json")
|
auth_path = os.path.join(os.path.dirname(DATABASE_URL.replace("sqlite:///", "")), "auth.json")
|
||||||
if not os.path.isabs(auth_path):
|
if not os.path.isabs(auth_path):
|
||||||
auth_path = os.path.join("data", "auth.json")
|
auth_path = AUTH_FILE
|
||||||
admin_user = None
|
admin_user = None
|
||||||
try:
|
try:
|
||||||
with open(auth_path, "r", encoding="utf-8") as f:
|
with open(auth_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1118,7 +1157,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
logger.warning(f"Legacy owner migration failed: {e}")
|
logger.warning(f"Legacy owner migration failed: {e}")
|
||||||
|
|
||||||
# Also migrate memory.json
|
# Also migrate memory.json
|
||||||
mem_path = os.path.join("data", "memory.json")
|
mem_path = MEMORY_FILE
|
||||||
try:
|
try:
|
||||||
if os.path.exists(mem_path):
|
if os.path.exists(mem_path):
|
||||||
with open(mem_path, "r", encoding="utf-8") as f:
|
with open(mem_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1136,7 +1175,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
logger.warning(f"memory.json legacy migration failed: {e}")
|
logger.warning(f"memory.json legacy migration failed: {e}")
|
||||||
|
|
||||||
# Also migrate user_prefs.json to per-user format
|
# Also migrate user_prefs.json to per-user format
|
||||||
prefs_path = os.path.join("data", "user_prefs.json")
|
prefs_path = USER_PREFS_FILE
|
||||||
try:
|
try:
|
||||||
if os.path.exists(prefs_path):
|
if os.path.exists(prefs_path):
|
||||||
with open(prefs_path, "r", encoding="utf-8") as f:
|
with open(prefs_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1458,7 +1497,11 @@ class CalendarCal(TimestampMixin, Base):
|
|||||||
owner = Column(String, nullable=True, index=True)
|
owner = Column(String, nullable=True, index=True)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
color = Column(String, default="#5b8abf")
|
color = Column(String, default="#5b8abf")
|
||||||
source = Column(String, default="local") # "local" or "timetree"
|
source = Column(String, default="local") # "local" or "caldav"
|
||||||
|
# UUID of the CalDAV account in user prefs that owns this calendar.
|
||||||
|
# NULL for local calendars and for CalDAV calendars created before
|
||||||
|
# multi-account support was added (treated as "use any configured account").
|
||||||
|
account_id = Column(String, nullable=True, index=True)
|
||||||
|
|
||||||
events = relationship("CalendarEvent", back_populates="calendar", cascade="all, delete-orphan")
|
events = relationship("CalendarEvent", back_populates="calendar", cascade="all, delete-orphan")
|
||||||
|
|
||||||
@@ -1526,7 +1569,7 @@ def _migrate_seed_email_account():
|
|||||||
import json as _json
|
import json as _json
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
settings_file = Path("data/settings.json")
|
settings_file = Path(SETTINGS_FILE)
|
||||||
if not settings_file.exists():
|
if not settings_file.exists():
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -1594,6 +1637,7 @@ def init_db():
|
|||||||
_migrate_add_model_type_column()
|
_migrate_add_model_type_column()
|
||||||
_migrate_add_model_endpoint_refresh_columns()
|
_migrate_add_model_endpoint_refresh_columns()
|
||||||
_migrate_add_model_endpoint_owner_column()
|
_migrate_add_model_endpoint_owner_column()
|
||||||
|
_migrate_add_provider_auth_id_column()
|
||||||
_migrate_add_supports_tools_column()
|
_migrate_add_supports_tools_column()
|
||||||
_migrate_add_task_run_model_column()
|
_migrate_add_task_run_model_column()
|
||||||
_migrate_add_owner_column()
|
_migrate_add_owner_column()
|
||||||
@@ -1622,9 +1666,105 @@ def init_db():
|
|||||||
_migrate_add_calendar_metadata()
|
_migrate_add_calendar_metadata()
|
||||||
_migrate_add_calendar_is_utc()
|
_migrate_add_calendar_is_utc()
|
||||||
_migrate_add_calendar_origin()
|
_migrate_add_calendar_origin()
|
||||||
|
_migrate_add_calendar_account_id()
|
||||||
|
_migrate_chat_messages_fts()
|
||||||
_migrate_encrypt_email_passwords()
|
_migrate_encrypt_email_passwords()
|
||||||
_migrate_encrypt_signatures()
|
_migrate_encrypt_signatures()
|
||||||
_migrate_encrypt_endpoint_keys()
|
_migrate_encrypt_endpoint_keys()
|
||||||
|
_migrate_backfill_task_folders()
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_backfill_task_folders():
|
||||||
|
"""Backfill folder='Tasks' on pre-existing task/research sessions.
|
||||||
|
|
||||||
|
Sessions created by the task scheduler (LLM tasks, action tasks, research
|
||||||
|
runs) now set folder='Tasks' at creation time. This migration tags any
|
||||||
|
older sessions that predate that assignment. Idempotent — only touches
|
||||||
|
rows where folder is NULL or empty and the title matches known prefixes.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with engine.connect() as conn:
|
||||||
|
cols = [r[1] for r in conn.execute(text("PRAGMA table_info(sessions)"))]
|
||||||
|
if "folder" not in cols:
|
||||||
|
return
|
||||||
|
res = conn.execute(text(
|
||||||
|
"UPDATE sessions SET folder = 'Tasks' "
|
||||||
|
"WHERE (folder IS NULL OR folder = '') "
|
||||||
|
"AND (name LIKE '[Task] %' OR name LIKE '[Research] %')"
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
if res.rowcount:
|
||||||
|
logging.getLogger(__name__).info(
|
||||||
|
f"Backfilled folder='Tasks' on {res.rowcount} task/research sessions")
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"task folder backfill: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_chat_messages_fts():
|
||||||
|
"""Create and backfill the session transcript FTS index for SQLite."""
|
||||||
|
if not DATABASE_URL.startswith("sqlite"):
|
||||||
|
return
|
||||||
|
|
||||||
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
|
if db_path == ":memory:":
|
||||||
|
return
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
try:
|
||||||
|
conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS temp._odysseus_fts5_probe USING fts5(content)")
|
||||||
|
conn.execute("DROP TABLE IF EXISTS temp._odysseus_fts5_probe")
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"chat_messages FTS migration skipped; FTS5 unavailable: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
conn.executescript(
|
||||||
|
"""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS chat_messages_fts USING fts5(
|
||||||
|
content,
|
||||||
|
message_id UNINDEXED,
|
||||||
|
session_id UNINDEXED,
|
||||||
|
role UNINDEXED
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ai
|
||||||
|
AFTER INSERT ON chat_messages BEGIN
|
||||||
|
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||||
|
VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ad
|
||||||
|
AFTER DELETE ON chat_messages BEGIN
|
||||||
|
DELETE FROM chat_messages_fts WHERE message_id = old.id;
|
||||||
|
END;
|
||||||
|
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_au
|
||||||
|
AFTER UPDATE ON chat_messages BEGIN
|
||||||
|
DELETE FROM chat_messages_fts WHERE message_id = old.id;
|
||||||
|
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||||
|
VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||||
|
SELECT COALESCE(cm.content, ''), cm.id, cm.session_id, cm.role
|
||||||
|
FROM chat_messages cm
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1 FROM chat_messages_fts fts
|
||||||
|
WHERE fts.message_id = cm.id
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"chat_messages FTS migration failed: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_email_smtp_security():
|
def _migrate_add_email_smtp_security():
|
||||||
@@ -1786,6 +1926,27 @@ def _migrate_add_calendar_origin():
|
|||||||
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_add_calendar_account_id():
|
||||||
|
"""Add `account_id` to calendars so each CalDAV-backed calendar knows which
|
||||||
|
credential set (from caldav_accounts in user prefs) owns it. Idempotent."""
|
||||||
|
import sqlite3
|
||||||
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute("PRAGMA table_info(calendars)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
if columns and "account_id" not in columns:
|
||||||
|
conn.execute("ALTER TABLE calendars ADD COLUMN account_id TEXT")
|
||||||
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
|
||||||
|
conn.commit()
|
||||||
|
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
|
||||||
|
conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_calendar_metadata():
|
def _migrate_add_calendar_metadata():
|
||||||
"""Add importance/event_type/last_pinged columns to calendar_events table."""
|
"""Add importance/event_type/last_pinged columns to calendar_events table."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|||||||
@@ -17,6 +17,15 @@ INTERNAL_TOOL_TOKEN = os.environ.get("ODYSSEUS_INTERNAL_TOKEN") or secrets.token
|
|||||||
INTERNAL_TOOL_HEADER = "X-Odysseus-Internal-Token"
|
INTERNAL_TOOL_HEADER = "X-Odysseus-Internal-Token"
|
||||||
|
|
||||||
|
|
||||||
|
def is_cors_preflight(method: str, headers) -> bool:
|
||||||
|
"""True for a genuine CORS preflight: an OPTIONS request carrying the
|
||||||
|
Access-Control-Request-Method header. Such requests are credential-less by
|
||||||
|
design and must reach CORSMiddleware to be answered -- gating them on auth
|
||||||
|
401s the preflight and breaks every cross-origin browser/WebView client.
|
||||||
|
Pure so it can be unit-tested without standing up the app."""
|
||||||
|
return method == "OPTIONS" and "access-control-request-method" in headers
|
||||||
|
|
||||||
|
|
||||||
def require_admin(request: Request):
|
def require_admin(request: Request):
|
||||||
"""Raise 403 if the current user isn't an admin.
|
"""Raise 403 if the current user isn't an admin.
|
||||||
Allows access when auth is explicitly disabled, or when the request carries
|
Allows access when auth is explicitly disabled, or when the request carries
|
||||||
@@ -58,11 +67,22 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Tool render endpoints are served inside iframes — allow framing by self
|
# Tool render endpoints are served inside iframes — allow framing by self
|
||||||
is_tool_render = path.startswith("/api/tools/") and path.endswith("/render")
|
is_tool_render = path.startswith("/api/tools/") and path.endswith("/render")
|
||||||
|
# PDF previews are embedded by the in-app document library. Keep the
|
||||||
|
# exception route-scoped so normal app pages remain unframeable.
|
||||||
|
is_document_pdf_preview = path.startswith("/api/document/") and path.endswith("/render-pdf")
|
||||||
# Visual report pages are self-contained HTML — need inline scripts + external images
|
# Visual report pages are self-contained HTML — need inline scripts + external images
|
||||||
is_report = path.startswith("/api/research/report/")
|
is_report = path.startswith("/api/research/report/")
|
||||||
|
|
||||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
response.headers["Referrer-Policy"] = "no-referrer"
|
response.headers["Referrer-Policy"] = "no-referrer"
|
||||||
|
response.headers["Permissions-Policy"] = "camera=(), microphone=(self), geolocation=()"
|
||||||
|
|
||||||
|
is_https = (
|
||||||
|
request.url.scheme == "https"
|
||||||
|
or request.headers.get("X-Forwarded-Proto") == "https"
|
||||||
|
)
|
||||||
|
if is_https:
|
||||||
|
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||||
|
|
||||||
if is_report:
|
if is_report:
|
||||||
response.headers["Content-Security-Policy"] = (
|
response.headers["Content-Security-Policy"] = (
|
||||||
@@ -79,6 +99,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
# sandbox="allow-scripts" attribute provides isolation.
|
# sandbox="allow-scripts" attribute provides isolation.
|
||||||
# Don't overwrite the route's own restrictive CSP either.
|
# Don't overwrite the route's own restrictive CSP either.
|
||||||
pass
|
pass
|
||||||
|
elif is_document_pdf_preview:
|
||||||
|
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||||
|
response.headers["Content-Security-Policy"] = (
|
||||||
|
"default-src 'none'; "
|
||||||
|
"frame-ancestors 'self'"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response.headers["X-Frame-Options"] = "DENY"
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
# NOTE: `style-src 'unsafe-inline'` is intentionally retained.
|
# NOTE: `style-src 'unsafe-inline'` is intentionally retained.
|
||||||
|
|||||||
+205
-3
@@ -18,10 +18,22 @@ import ntpath
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import platform
|
||||||
|
|
||||||
IS_WINDOWS = os.name == "nt"
|
IS_WINDOWS = os.name == "nt"
|
||||||
IS_POSIX = not IS_WINDOWS
|
IS_POSIX = not IS_WINDOWS
|
||||||
|
# Allows APFEL support and ARM-native binary recommendations on Apple Silicon Macs.
|
||||||
|
IS_APPLE_SILICON = (
|
||||||
|
IS_POSIX
|
||||||
|
and platform.system() == "Darwin"
|
||||||
|
and platform.machine().lower()
|
||||||
|
in {
|
||||||
|
"arm64",
|
||||||
|
"aarch64",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── File permissions ────────────────────────────────────────────────────────
|
# ── File permissions ────────────────────────────────────────────────────────
|
||||||
@@ -53,9 +65,8 @@ def detached_popen_kwargs() -> dict:
|
|||||||
and is detached from any console.
|
and is detached from any console.
|
||||||
"""
|
"""
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
flags = (
|
flags = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200) | getattr(
|
||||||
getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200)
|
subprocess, "DETACHED_PROCESS", 0x00000008
|
||||||
| getattr(subprocess, "DETACHED_PROCESS", 0x00000008)
|
|
||||||
)
|
)
|
||||||
return {"creationflags": flags}
|
return {"creationflags": flags}
|
||||||
return {"start_new_session": True}
|
return {"start_new_session": True}
|
||||||
@@ -150,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
|
|||||||
("usr", "bin", "bash.exe"),
|
("usr", "bin", "bash.exe"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Paths to add to the remote SSH probe command to find tools like nvidia-smi that may not be on PATH.
|
||||||
|
_SSH_PATH_MEMBERS = (
|
||||||
|
"/usr/bin",
|
||||||
|
"/usr/local/bin",
|
||||||
|
"/usr/local/cuda/bin",
|
||||||
|
"/usr/lib/wsl/lib"
|
||||||
|
)
|
||||||
|
# Fallback locations for nvidia-smi on WSL and other Linux distros where it may not be on PATH.
|
||||||
|
NVIDIA_PATH_CANDIDATES = (
|
||||||
|
"/usr/bin/nvidia-smi",
|
||||||
|
"/usr/local/bin/nvidia-smi",
|
||||||
|
"/usr/local/cuda/bin/nvidia-smi",
|
||||||
|
"/usr/lib/wsl/lib/nvidia-smi",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_path_override() -> str:
|
||||||
|
"""Build the PATH export snippet used for remote SSH shell probes."""
|
||||||
|
return f"export PATH=\"$PATH:{':'.join(_SSH_PATH_MEMBERS)}\"; "
|
||||||
|
|
||||||
|
|
||||||
|
SSH_PATH_OVERRIDE = _ssh_path_override()
|
||||||
|
|
||||||
|
|
||||||
def _windows_bash_fallbacks() -> List[str]:
|
def _windows_bash_fallbacks() -> List[str]:
|
||||||
roots: List[str] = []
|
roots: List[str] = []
|
||||||
@@ -180,6 +214,21 @@ def _is_windows_bash_stub(path: str) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def git_bash_path(path: str | Path) -> str:
|
||||||
|
"""Convert a path to POSIX style suitable for Git Bash on Windows.
|
||||||
|
|
||||||
|
Transforms drive letters (e.g., 'C:\\path') to POSIX '/c/path',
|
||||||
|
and uses forward slashes.
|
||||||
|
"""
|
||||||
|
p = Path(path)
|
||||||
|
p_str = p.as_posix()
|
||||||
|
if IS_WINDOWS and len(p_str) >= 2 and p_str[1] == ":":
|
||||||
|
drive = p_str[0].lower()
|
||||||
|
return f"/{drive}{p_str[2:]}"
|
||||||
|
return p_str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_bash() -> Optional[str]:
|
def find_bash() -> Optional[str]:
|
||||||
"""Locate a real ``bash`` interpreter, or None.
|
"""Locate a real ``bash`` interpreter, or None.
|
||||||
|
|
||||||
@@ -242,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
|
|||||||
comspec = os.environ.get("ComSpec", "cmd.exe")
|
comspec = os.environ.get("ComSpec", "cmd.exe")
|
||||||
return [comspec, "/c", str(script_path)]
|
return [comspec, "/c", str(script_path)]
|
||||||
return ["sh", str(script_path)]
|
return ["sh", str(script_path)]
|
||||||
|
|
||||||
|
|
||||||
|
def is_wsl() -> bool:
|
||||||
|
"""True if running inside Windows Subsystem for Linux (WSL)."""
|
||||||
|
import sys
|
||||||
|
if sys.platform.startswith("linux") or os.name == "posix":
|
||||||
|
try:
|
||||||
|
with open("/proc/version", "r") as f:
|
||||||
|
if "microsoft" in f.read().lower():
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def translate_path(path_str: str) -> str:
|
||||||
|
"""Translate a path (possibly a Windows path) to the current OS format.
|
||||||
|
|
||||||
|
Particularly handles Windows paths (e.g. C:\\foo or C:/foo) when running
|
||||||
|
under WSL, translating them to /mnt/c/foo.
|
||||||
|
Also handles standard path normalization to avoid string breakages.
|
||||||
|
"""
|
||||||
|
if not path_str:
|
||||||
|
return path_str
|
||||||
|
|
||||||
|
if is_wsl():
|
||||||
|
path_str = path_str.replace("\\", "/")
|
||||||
|
import re
|
||||||
|
m = re.match(r"^([a-zA-Z]):(.*)", path_str)
|
||||||
|
if m:
|
||||||
|
drive = m.group(1).lower()
|
||||||
|
rest = m.group(2)
|
||||||
|
if not rest.startswith("/"):
|
||||||
|
rest = "/" + rest
|
||||||
|
return f"/mnt/{drive}{rest}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return str(Path(path_str).resolve())
|
||||||
|
except Exception:
|
||||||
|
return path_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_wsl_windows_user_profile() -> Optional[str]:
|
||||||
|
"""Retrieve the Windows host User Profile path from inside WSL."""
|
||||||
|
if not is_wsl():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
r = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
|
||||||
|
if r.returncode == 0 and r.stdout.strip():
|
||||||
|
return translate_path(r.stdout.strip())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
users_dir = "/mnt/c/Users"
|
||||||
|
if os.path.isdir(users_dir):
|
||||||
|
for entry in os.listdir(users_dir):
|
||||||
|
if entry not in ("All Users", "Default", "Default User", "desktop.ini", "Public"):
|
||||||
|
path = os.path.join(users_dir, entry)
|
||||||
|
if os.path.isdir(path):
|
||||||
|
return path
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_exec_argv(
|
||||||
|
remote: str,
|
||||||
|
ssh_port: str | None,
|
||||||
|
*,
|
||||||
|
remote_cmd: str | None = None,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
strict_host_key_checking: bool | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build a consistent ssh argv for remote command execution."""
|
||||||
|
argv = ["ssh"]
|
||||||
|
if connect_timeout is not None:
|
||||||
|
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||||
|
if strict_host_key_checking is not None:
|
||||||
|
argv.extend(
|
||||||
|
[
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=yes"
|
||||||
|
if strict_host_key_checking
|
||||||
|
else "StrictHostKeyChecking=no",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if ssh_port and ssh_port != "22":
|
||||||
|
argv.extend(["-p", str(ssh_port)])
|
||||||
|
argv.append(remote)
|
||||||
|
if remote_cmd is not None:
|
||||||
|
argv.append(remote_cmd)
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_ssh_command(
|
||||||
|
remote: str,
|
||||||
|
ssh_port: str | None,
|
||||||
|
remote_cmd: str,
|
||||||
|
*,
|
||||||
|
timeout: float,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
strict_host_key_checking: bool | None = None,
|
||||||
|
text: bool = True,
|
||||||
|
) -> subprocess.CompletedProcess:
|
||||||
|
"""Run an ssh command with centralized timeout and stderr/stdout capture."""
|
||||||
|
return subprocess.run(
|
||||||
|
_ssh_exec_argv(
|
||||||
|
remote,
|
||||||
|
ssh_port,
|
||||||
|
remote_cmd=remote_cmd,
|
||||||
|
connect_timeout=connect_timeout,
|
||||||
|
strict_host_key_checking=strict_host_key_checking,
|
||||||
|
),
|
||||||
|
timeout=timeout,
|
||||||
|
capture_output=True,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _windows_powershell_argv(
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
no_profile: bool = True,
|
||||||
|
non_interactive: bool = True,
|
||||||
|
) -> List[str]:
|
||||||
|
argv: List[str] = ["powershell.exe"]
|
||||||
|
if no_profile:
|
||||||
|
argv.append("-NoProfile")
|
||||||
|
if non_interactive:
|
||||||
|
argv.append("-NonInteractive")
|
||||||
|
argv.extend(["-Command", command])
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_wsl_windows_powershell(
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
timeout: float = 5,
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
"""Run a PowerShell command on the Windows host from WSL.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` when called outside WSL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not is_wsl():
|
||||||
|
raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
|
||||||
|
return subprocess.run(
|
||||||
|
_windows_powershell_argv(command),
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import logging
|
|||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal
|
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal, utcnow_naive
|
||||||
from .models import Session, ChatMessage
|
from .models import Session, ChatMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -619,7 +619,7 @@ class SessionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
all_sessions = db.query(DbSession).all()
|
all_sessions = db.query(DbSession).all()
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=auto_archive_days)
|
cutoff_date = utcnow_naive() - timedelta(days=auto_archive_days)
|
||||||
|
|
||||||
for db_session in all_sessions:
|
for db_session in all_sessions:
|
||||||
stats['total_checked'] += 1
|
stats['total_checked'] += 1
|
||||||
|
|||||||
@@ -52,12 +52,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -51,12 +51,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -40,12 +40,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Outlook / Office 365 email accounts
|
||||||
|
|
||||||
|
Odysseus email accounts currently use IMAP and SMTP with username/password
|
||||||
|
authentication. That works for providers that still allow app passwords or
|
||||||
|
mailbox passwords for IMAP/SMTP.
|
||||||
|
|
||||||
|
Microsoft disables basic authentication for Outlook and Microsoft 365 in most
|
||||||
|
modern accounts and tenants. If you try to add an Outlook account with a normal
|
||||||
|
password, Microsoft may return errors such as:
|
||||||
|
|
||||||
|
- `IMAP: AUTHENTICATE failed`
|
||||||
|
- `SMTP: 535 5.7.139 Authentication unsuccessful, basic authentication is disabled`
|
||||||
|
|
||||||
|
This is expected. Odysseus does not support Microsoft OAuth or Graph Mail yet,
|
||||||
|
so Outlook / Office 365 accounts cannot currently be added through the password
|
||||||
|
form. Use another email provider with app-password support, or track the future
|
||||||
|
Microsoft Graph OAuth integration.
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
---
|
---
|
||||||
name: odysseus
|
name: odysseus
|
||||||
description: Use when the user asks Claude Code to read or write Odysseus data (todos, email, calendar, memory, documents) through the scoped Claude Agent API. Requires ODYSSEUS_URL and ODYSSEUS_API_TOKEN.
|
description: Use when the user asks Claude Code to read or write Odysseus data (todos, email, calendar, memory, documents) or to launch/monitor/stop a Cookbook model-serve task through the scoped Claude Agent API. Requires ODYSSEUS_URL and ODYSSEUS_API_TOKEN.
|
||||||
---
|
---
|
||||||
|
|
||||||
# Odysseus
|
# Odysseus
|
||||||
@@ -105,6 +105,49 @@ python3 ~/.claude/skills/odysseus/scripts/odysseus_api.py POST /api/codex/memory
|
|||||||
- `POST /api/codex/emails/draft` — body matches `SendEmailRequest` (`to`, `cc`, `bcc`, `subject`, `body`, `body_html`, `attachments`, `account_id`, `in_reply_to`, `references`). Requires `email:draft` (or `email:send`).
|
- `POST /api/codex/emails/draft` — body matches `SendEmailRequest` (`to`, `cc`, `bcc`, `subject`, `body`, `body_html`, `attachments`, `account_id`, `in_reply_to`, `references`). Requires `email:draft` (or `email:send`).
|
||||||
- `POST /api/codex/emails/send` — same body. Requires `email:send`. Never send without explicit user instruction.
|
- `POST /api/codex/emails/send` — same body. Requires `email:send`. Never send without explicit user instruction.
|
||||||
|
|
||||||
|
## Cookbook serve (debug a failing model launch)
|
||||||
|
|
||||||
|
The Cookbook surface lets you reproduce what a human would do in Odysseus → Cookbook: read which serves are running, tail their tmux output to see why they crashed, edit the launch command, relaunch, kill a stuck one. Use this when the user is debugging a model server that won't come up (compute-capability errors, OOM, missing kernels, wrong attention backend, etc.).
|
||||||
|
|
||||||
|
- `GET /api/codex/cookbook/tasks` — list active serve/download/install tasks (sessionId, type, status, repo_id, remoteHost, payload._cmd). Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/servers` — list configured servers (name, host, port, env type + path, model dirs). Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/cached?host=<NAME>` — list models already cached on the named server (HF cache + Ollama + extra modelDirs). Call BEFORE `serve` to see what's already on disk. Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/presets` — list saved serve presets (model + host + port + cmd). The user's saved preset usually has a working cmd — try `preset NAME` before composing your own. Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/output/{session_id}?tail=400` — read the last N lines of the task's persistent log file (preferred) or tmux pane (fallback). The log file persists across vllm crashes, so this returns the actual Python traceback even after the bash prompt + neofetch banner overwrites the pane. Default tail=400. Requires `cookbook:read`.
|
||||||
|
- `POST /api/codex/cookbook/serve` — launch a serve task. Body matches `ServeRequest`: `{ repo_id, cmd, remote_host?, ssh_port?, env_prefix?, gpus?, platform? }`. The `cmd` is validated: leading binary must be `vllm`/`python3`/`sglang`/`llama-server`/`ollama`/`node`/`npx`. NEVER prefix with `cd …`, `source …`, or chain with `&&`/`||`/`;`/`$(...)` — the validator rejects shell metacharacters. The venv activation (`env_prefix`) is added automatically from the host's saved settings, so pass the bare binary + args. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/preset/{name}` — launch a saved preset by name. Reuses the working cmd + host the user already saved. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/adopt` — register an externally-launched tmux session into cookbook tracking. Body: `{ tmux_session, model, host?, port? }`. Use this when serve_model rejected a cmd and you fell back to direct ssh+tmux — without adoption, the session is invisible to the UI. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/stop/{session_id}` — kill the tmux session for that task. Requires `cookbook:launch`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Survey what's running
|
||||||
|
python3 ~/.claude/skills/odysseus/scripts/odysseus_api.py cookbook tasks
|
||||||
|
|
||||||
|
# Tail the failing one (sessionId from `cookbook tasks`)
|
||||||
|
python3 ~/.claude/skills/odysseus/scripts/odysseus_api.py cookbook output serve-abc12345 400
|
||||||
|
|
||||||
|
# Stop the previous attempt before you try a new flag set
|
||||||
|
python3 ~/.claude/skills/odysseus/scripts/odysseus_api.py cookbook stop serve-abc12345
|
||||||
|
|
||||||
|
# Relaunch with new flags. cmd MUST begin with one of the allowlisted binaries.
|
||||||
|
python3 ~/.claude/skills/odysseus/scripts/odysseus_api.py cookbook serve \
|
||||||
|
/mnt/HADES/models/Qwen3.5-397B-A17B-AWQ \
|
||||||
|
"vllm serve /mnt/HADES/models/Qwen3.5-397B-A17B-AWQ --host 0.0.0.0 --port 8001 --tensor-parallel-size 8 --max-model-len 262144 --gpu-memory-utilization 0.90 --dtype auto --max-num-seqs 8 --trust-remote-code --enable-expert-parallel --enable-auto-tool-choice --tool-call-parser qwen3_coder --reasoning-parser qwen3" \
|
||||||
|
pewds@192.168.1.12
|
||||||
|
```
|
||||||
|
|
||||||
|
**Debug loop pattern:** when a serve is failing, the productive sequence is
|
||||||
|
|
||||||
|
1. `cookbook tasks` → find the failing sessionId.
|
||||||
|
2. `cookbook output SID 600` → read the last 600 lines, find the actual root-cause line (often above the visible tail because tmux scrollback rolled — request a larger `tail` if the error references "above").
|
||||||
|
3. `cookbook stop SID` — kill the previous attempt before relaunching; two serves on the same `--port` collide.
|
||||||
|
4. `cookbook serve repo "new cmd"` — try the next variation. Wait ~20s, then `cookbook output` on the new sessionId.
|
||||||
|
|
||||||
|
**Hard limits this surface enforces:**
|
||||||
|
- `cookbook serve` cmd allowlist + shell-metacharacter rejection — you cannot run arbitrary shell, only model-server binaries.
|
||||||
|
- `cookbook stop` only targets task sessionIds matching `[a-zA-Z0-9_-]+`.
|
||||||
|
- The agent CAN spawn GPU-pinning long-lived processes — always `cookbook stop` your previous attempt before relaunching, and check `cookbook tasks` for collisions on the same `--port` before launching.
|
||||||
|
|
||||||
## Forbidden Bypass Pattern
|
## Forbidden Bypass Pattern
|
||||||
|
|
||||||
If you are about to reach the Odysseus host/container, import app internals, query the database, or call MCP helper modules directly, stop. Those paths bypass Odysseus Settings and token scopes. Ask the user to enable the relevant Claude Agent tool toggle instead.
|
If you are about to reach the Odysseus host/container, import app internals, query the database, or call MCP helper modules directly, stop. Those paths bypass Odysseus Settings and token scopes. Ask the user to enable the relevant Claude Agent tool toggle instead.
|
||||||
|
|||||||
@@ -17,6 +17,15 @@ def _usage() -> int:
|
|||||||
print(" odysseus_api.py todos add TITLE", file=sys.stderr)
|
print(" odysseus_api.py todos add TITLE", file=sys.stderr)
|
||||||
print(" odysseus_api.py emails list [limit]", file=sys.stderr)
|
print(" odysseus_api.py emails list [limit]", file=sys.stderr)
|
||||||
print(" odysseus_api.py emails read UID", file=sys.stderr)
|
print(" odysseus_api.py emails read UID", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook tasks", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook servers", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook cached [HOST]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook presets", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook output SESSION_ID [tail]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook serve REPO_ID 'CMD' [REMOTE_HOST]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook preset NAME", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook adopt SESSION_ID MODEL [HOST] [PORT]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook stop SESSION_ID", file=sys.stderr)
|
||||||
print(" odysseus_api.py METHOD /api/codex/path [json-body]", file=sys.stderr)
|
print(" odysseus_api.py METHOD /api/codex/path [json-body]", file=sys.stderr)
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
@@ -72,6 +81,61 @@ def main() -> int:
|
|||||||
body = None
|
body = None
|
||||||
else:
|
else:
|
||||||
return _usage()
|
return _usage()
|
||||||
|
elif command == "cookbook":
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
return _usage()
|
||||||
|
action = sys.argv[2].lower()
|
||||||
|
if action == "tasks":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/tasks"
|
||||||
|
body = None
|
||||||
|
elif action == "servers":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/servers"
|
||||||
|
body = None
|
||||||
|
elif action == "output" and len(sys.argv) >= 4:
|
||||||
|
method = "GET"
|
||||||
|
sid = sys.argv[3]
|
||||||
|
tail = sys.argv[4] if len(sys.argv) >= 5 else "400"
|
||||||
|
path = f"/api/codex/cookbook/output/{sid}?tail={tail}"
|
||||||
|
body = None
|
||||||
|
elif action == "cached":
|
||||||
|
method = "GET"
|
||||||
|
if len(sys.argv) >= 4:
|
||||||
|
from urllib.parse import quote
|
||||||
|
path = f"/api/codex/cookbook/cached?host={quote(sys.argv[3])}"
|
||||||
|
else:
|
||||||
|
path = "/api/codex/cookbook/cached"
|
||||||
|
body = None
|
||||||
|
elif action == "presets":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/presets"
|
||||||
|
body = None
|
||||||
|
elif action == "preset" and len(sys.argv) >= 4:
|
||||||
|
from urllib.parse import quote
|
||||||
|
method = "POST"
|
||||||
|
path = f"/api/codex/cookbook/preset/{quote(sys.argv[3])}"
|
||||||
|
body = None
|
||||||
|
elif action == "adopt" and len(sys.argv) >= 5:
|
||||||
|
method = "POST"
|
||||||
|
path = "/api/codex/cookbook/adopt"
|
||||||
|
payload = {"tmux_session": sys.argv[3], "model": sys.argv[4]}
|
||||||
|
if len(sys.argv) >= 6: payload["host"] = sys.argv[5]
|
||||||
|
if len(sys.argv) >= 7: payload["port"] = int(sys.argv[6])
|
||||||
|
body = json.dumps(payload)
|
||||||
|
elif action == "serve" and len(sys.argv) >= 5:
|
||||||
|
method = "POST"
|
||||||
|
path = "/api/codex/cookbook/serve"
|
||||||
|
payload = {"repo_id": sys.argv[3], "cmd": sys.argv[4]}
|
||||||
|
if len(sys.argv) >= 6:
|
||||||
|
payload["remote_host"] = sys.argv[5]
|
||||||
|
body = json.dumps(payload)
|
||||||
|
elif action == "stop" and len(sys.argv) >= 4:
|
||||||
|
method = "POST"
|
||||||
|
path = f"/api/codex/cookbook/stop/{sys.argv[3]}"
|
||||||
|
body = None
|
||||||
|
else:
|
||||||
|
return _usage()
|
||||||
else:
|
else:
|
||||||
if len(sys.argv) < 3:
|
if len(sys.argv) < 3:
|
||||||
return _usage()
|
return _usage()
|
||||||
|
|||||||
@@ -17,6 +17,15 @@ def _usage() -> int:
|
|||||||
print(" odysseus_api.py todos add TITLE", file=sys.stderr)
|
print(" odysseus_api.py todos add TITLE", file=sys.stderr)
|
||||||
print(" odysseus_api.py emails list [limit]", file=sys.stderr)
|
print(" odysseus_api.py emails list [limit]", file=sys.stderr)
|
||||||
print(" odysseus_api.py emails read UID", file=sys.stderr)
|
print(" odysseus_api.py emails read UID", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook tasks", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook servers", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook cached [HOST]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook presets", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook output SESSION_ID [tail]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook serve REPO_ID 'CMD' [REMOTE_HOST]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook preset NAME", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook adopt SESSION_ID MODEL [HOST] [PORT]", file=sys.stderr)
|
||||||
|
print(" odysseus_api.py cookbook stop SESSION_ID", file=sys.stderr)
|
||||||
print(" odysseus_api.py METHOD /api/codex/path [json-body]", file=sys.stderr)
|
print(" odysseus_api.py METHOD /api/codex/path [json-body]", file=sys.stderr)
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
@@ -72,6 +81,61 @@ def main() -> int:
|
|||||||
body = None
|
body = None
|
||||||
else:
|
else:
|
||||||
return _usage()
|
return _usage()
|
||||||
|
elif command == "cookbook":
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
return _usage()
|
||||||
|
action = sys.argv[2].lower()
|
||||||
|
if action == "tasks":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/tasks"
|
||||||
|
body = None
|
||||||
|
elif action == "servers":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/servers"
|
||||||
|
body = None
|
||||||
|
elif action == "output" and len(sys.argv) >= 4:
|
||||||
|
method = "GET"
|
||||||
|
sid = sys.argv[3]
|
||||||
|
tail = sys.argv[4] if len(sys.argv) >= 5 else "400"
|
||||||
|
path = f"/api/codex/cookbook/output/{sid}?tail={tail}"
|
||||||
|
body = None
|
||||||
|
elif action == "cached":
|
||||||
|
method = "GET"
|
||||||
|
if len(sys.argv) >= 4:
|
||||||
|
from urllib.parse import quote
|
||||||
|
path = f"/api/codex/cookbook/cached?host={quote(sys.argv[3])}"
|
||||||
|
else:
|
||||||
|
path = "/api/codex/cookbook/cached"
|
||||||
|
body = None
|
||||||
|
elif action == "presets":
|
||||||
|
method = "GET"
|
||||||
|
path = "/api/codex/cookbook/presets"
|
||||||
|
body = None
|
||||||
|
elif action == "preset" and len(sys.argv) >= 4:
|
||||||
|
from urllib.parse import quote
|
||||||
|
method = "POST"
|
||||||
|
path = f"/api/codex/cookbook/preset/{quote(sys.argv[3])}"
|
||||||
|
body = None
|
||||||
|
elif action == "adopt" and len(sys.argv) >= 5:
|
||||||
|
method = "POST"
|
||||||
|
path = "/api/codex/cookbook/adopt"
|
||||||
|
payload = {"tmux_session": sys.argv[3], "model": sys.argv[4]}
|
||||||
|
if len(sys.argv) >= 6: payload["host"] = sys.argv[5]
|
||||||
|
if len(sys.argv) >= 7: payload["port"] = int(sys.argv[6])
|
||||||
|
body = json.dumps(payload)
|
||||||
|
elif action == "serve" and len(sys.argv) >= 5:
|
||||||
|
method = "POST"
|
||||||
|
path = "/api/codex/cookbook/serve"
|
||||||
|
payload = {"repo_id": sys.argv[3], "cmd": sys.argv[4]}
|
||||||
|
if len(sys.argv) >= 6:
|
||||||
|
payload["remote_host"] = sys.argv[5]
|
||||||
|
body = json.dumps(payload)
|
||||||
|
elif action == "stop" and len(sys.argv) >= 4:
|
||||||
|
method = "POST"
|
||||||
|
path = f"/api/codex/cookbook/stop/{sys.argv[3]}"
|
||||||
|
body = None
|
||||||
|
else:
|
||||||
|
return _usage()
|
||||||
else:
|
else:
|
||||||
if len(sys.argv) < 3:
|
if len(sys.argv) < 3:
|
||||||
return _usage()
|
return _usage()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
---
|
---
|
||||||
name: odysseus
|
name: odysseus
|
||||||
description: Use when the user asks Codex to read or write Odysseus data from a terminal Codex session through the scoped Codex Agent API. Requires ODYSSEUS_URL and ODYSSEUS_API_TOKEN.
|
description: Use when the user asks Codex to read or write Odysseus data (todos, email, calendar, memory, documents) or to launch/monitor/stop a Cookbook model-serve task through the scoped Codex Agent API. Requires ODYSSEUS_URL and ODYSSEUS_API_TOKEN.
|
||||||
---
|
---
|
||||||
|
|
||||||
# Odysseus
|
# Odysseus
|
||||||
@@ -105,6 +105,37 @@ python3 integrations/codex/scripts/odysseus_api.py POST /api/codex/memory '{"tex
|
|||||||
- `POST /api/codex/emails/draft` — body matches `SendEmailRequest` (`to`, `cc`, `bcc`, `subject`, `body`, `body_html`, `attachments`, `account_id`, `in_reply_to`, `references`). Requires `email:draft` (or `email:send`).
|
- `POST /api/codex/emails/draft` — body matches `SendEmailRequest` (`to`, `cc`, `bcc`, `subject`, `body`, `body_html`, `attachments`, `account_id`, `in_reply_to`, `references`). Requires `email:draft` (or `email:send`).
|
||||||
- `POST /api/codex/emails/send` — same body. Requires `email:send`. Never send without explicit user instruction.
|
- `POST /api/codex/emails/send` — same body. Requires `email:send`. Never send without explicit user instruction.
|
||||||
|
|
||||||
|
## Cookbook serve (debug a failing model launch)
|
||||||
|
|
||||||
|
The Cookbook surface lets you reproduce what a human would do in Odysseus → Cookbook: read which serves are running, tail their tmux output to see why they crashed, edit the launch command, relaunch, kill a stuck one. Use this when the user is debugging a model server that won't come up (compute-capability errors, OOM, missing kernels, wrong attention backend, etc.).
|
||||||
|
|
||||||
|
- `GET /api/codex/cookbook/tasks` — list active serve/download/install tasks (sessionId, type, status, repo_id, remoteHost, payload._cmd). Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/servers` — list configured servers (name, host, port, env type + path, model dirs). Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/cached?host=<NAME>` — list models already cached on the named server (HF cache + Ollama + extra modelDirs). Call BEFORE `serve` to see what's already on disk. Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/presets` — list saved serve presets (model + host + port + cmd). The user's saved preset usually has a working cmd — try `preset NAME` before composing your own. Requires `cookbook:read`.
|
||||||
|
- `GET /api/codex/cookbook/output/{session_id}?tail=400` — read the last N lines of the task's persistent log file (preferred) or tmux pane (fallback). The log file persists across vllm crashes, so this returns the actual Python traceback even after the bash prompt + neofetch banner overwrites the pane. Default tail=400. Requires `cookbook:read`.
|
||||||
|
- `POST /api/codex/cookbook/serve` — launch a serve task. Body matches `ServeRequest`: `{ repo_id, cmd, remote_host?, ssh_port?, env_prefix?, gpus?, platform? }`. The `cmd` is validated: leading binary must be `vllm`/`python3`/`sglang`/`llama-server`/`ollama`/`node`/`npx`. NEVER prefix with `cd …`, `source …`, or chain with `&&`/`||`/`;`/`$(...)` — the validator rejects shell metacharacters. The venv activation (`env_prefix`) is added automatically from the host's saved settings, so pass the bare binary + args. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/preset/{name}` — launch a saved preset by name. Reuses the working cmd + host the user already saved. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/adopt` — register an externally-launched tmux session into cookbook tracking. Body: `{ tmux_session, model, host?, port? }`. Use this when serve_model rejected a cmd and you fell back to direct ssh+tmux — without adoption, the session is invisible to the UI. Requires `cookbook:launch`.
|
||||||
|
- `POST /api/codex/cookbook/stop/{session_id}` — kill the tmux session. Requires `cookbook:launch`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 ~/plugins/odysseus/scripts/odysseus_api.py cookbook tasks
|
||||||
|
python3 ~/plugins/odysseus/scripts/odysseus_api.py cookbook output serve-abc12345 400
|
||||||
|
python3 ~/plugins/odysseus/scripts/odysseus_api.py cookbook stop serve-abc12345
|
||||||
|
python3 ~/plugins/odysseus/scripts/odysseus_api.py cookbook serve \
|
||||||
|
/mnt/HADES/models/Qwen3.5-397B-A17B-AWQ \
|
||||||
|
"vllm serve /mnt/HADES/models/Qwen3.5-397B-A17B-AWQ --host 0.0.0.0 --port 8001 --tensor-parallel-size 8 --max-model-len 262144 --gpu-memory-utilization 0.90 --dtype auto --max-num-seqs 8 --trust-remote-code --enable-expert-parallel --enable-auto-tool-choice --tool-call-parser qwen3_coder --reasoning-parser qwen3" \
|
||||||
|
pewds@192.168.1.12
|
||||||
|
```
|
||||||
|
|
||||||
|
**Debug loop pattern:** `tasks` → `output SID 600` (find root cause; request larger `tail` if it references "above") → `stop SID` → `serve repo "new cmd"` → wait ~20s → `output` on the new sessionId.
|
||||||
|
|
||||||
|
**Hard limits this surface enforces:**
|
||||||
|
- `cookbook serve` cmd allowlist + shell-metacharacter rejection.
|
||||||
|
- `cookbook stop` requires sessionIds matching `[a-zA-Z0-9_-]+`.
|
||||||
|
- Agent CAN spawn GPU-pinning long-lived processes — always `cookbook stop` your previous attempt before relaunching.
|
||||||
|
|
||||||
## Forbidden Bypass Pattern
|
## Forbidden Bypass Pattern
|
||||||
|
|
||||||
If you are about to reach the Odysseus host/container, import app internals, query the database, or call MCP helper modules directly, stop. Those paths bypass Odysseus Settings and token scopes. Ask the user to enable the relevant Codex Agent tool toggle instead.
|
If you are about to reach the Odysseus host/container, import app internals, query the database, or call MCP helper modules directly, stop. Those paths bypass Odysseus Settings and token scopes. Ask the user to enable the relevant Codex Agent tool toggle instead.
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
"""
|
|
||||||
_common.py
|
|
||||||
|
|
||||||
Shared constants and helpers for built-in MCP servers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
MAX_OUTPUT_CHARS = 10_000
|
|
||||||
MAX_READ_CHARS = 20_000
|
|
||||||
SHELL_TIMEOUT = 60
|
|
||||||
PYTHON_TIMEOUT = 30
|
|
||||||
SEARCH_TIMEOUT = 30
|
|
||||||
|
|
||||||
|
|
||||||
def truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
|
||||||
"""Truncate text to *limit* characters with a suffix note."""
|
|
||||||
if not isinstance(text, str):
|
|
||||||
# Tool output is occasionally None or a non-string; len(None) would
|
|
||||||
# raise. Coerce so this shared helper never crashes a tool response.
|
|
||||||
text = "" if text is None else str(text)
|
|
||||||
if len(text) > limit:
|
|
||||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
|
||||||
return text
|
|
||||||
+182
-125
@@ -31,13 +31,19 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|||||||
|
|
||||||
server = Server("email")
|
server = Server("email")
|
||||||
EMAIL_SOCKET_TIMEOUT = float(os.environ.get("EMAIL_SOCKET_TIMEOUT", "20"))
|
EMAIL_SOCKET_TIMEOUT = float(os.environ.get("EMAIL_SOCKET_TIMEOUT", "20"))
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, APP_DB, EMAIL_CACHE_DB, SETTINGS_FILE as _SETTINGS_FILE, MAIL_ATTACHMENTS_DIR
|
||||||
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
|
|
||||||
|
|
||||||
def _b(value) -> bytes:
|
def _b(value) -> bytes:
|
||||||
return str(value).encode()
|
return str(value).encode()
|
||||||
|
|
||||||
|
|
||||||
|
def _q(name: str) -> str:
|
||||||
|
"""Quote an IMAP mailbox name for commands that take mailbox args."""
|
||||||
|
return '"' + (name or "").replace("\\", "\\\\").replace('"', '\\"') + '"'
|
||||||
|
|
||||||
|
|
||||||
def _uid_fetch_rows(data) -> list:
|
def _uid_fetch_rows(data) -> list:
|
||||||
return [d for d in (data or []) if isinstance(d, bytes) and b"UID " in d]
|
return [d for d in (data or []) if isinstance(d, bytes) and b"UID " in d]
|
||||||
|
|
||||||
@@ -58,7 +64,7 @@ def _clean_header_value(value) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _db_path() -> Path:
|
def _db_path() -> Path:
|
||||||
return DATA_DIR / "app.db"
|
return Path(APP_DB)
|
||||||
|
|
||||||
|
|
||||||
def _list_accounts_raw() -> list:
|
def _list_accounts_raw() -> list:
|
||||||
@@ -157,7 +163,7 @@ def _load_config(account: str | None = None) -> dict:
|
|||||||
"trash_folder": os.environ.get("TRASH_FOLDER", "Trash"),
|
"trash_folder": os.environ.get("TRASH_FOLDER", "Trash"),
|
||||||
"cache_db": os.environ.get(
|
"cache_db": os.environ.get(
|
||||||
"EMAIL_CACHE_DB",
|
"EMAIL_CACHE_DB",
|
||||||
str(DATA_DIR / "email_cache.db"),
|
EMAIL_CACHE_DB,
|
||||||
),
|
),
|
||||||
"account_id": None,
|
"account_id": None,
|
||||||
"account_name": None,
|
"account_name": None,
|
||||||
@@ -199,7 +205,7 @@ def _load_config(account: str | None = None) -> dict:
|
|||||||
else:
|
else:
|
||||||
# Legacy fallback: settings.json flat keys
|
# Legacy fallback: settings.json flat keys
|
||||||
try:
|
try:
|
||||||
settings_path = Path(__file__).resolve().parent.parent / "data" / "settings.json"
|
settings_path = Path(_SETTINGS_FILE)
|
||||||
if settings_path.exists():
|
if settings_path.exists():
|
||||||
settings = json.loads(settings_path.read_text(encoding="utf-8"))
|
settings = json.loads(settings_path.read_text(encoding="utf-8"))
|
||||||
for key in (
|
for key in (
|
||||||
@@ -239,10 +245,27 @@ def _imap_connect(account: str | None = None):
|
|||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
if cfg["imap_starttls"]:
|
if cfg["imap_starttls"]:
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket on a rejected STARTTLS. (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
if getattr(conn, "sock", None):
|
if getattr(conn, "sock", None):
|
||||||
conn.sock.settimeout(EMAIL_SOCKET_TIMEOUT)
|
conn.sock.settimeout(EMAIL_SOCKET_TIMEOUT)
|
||||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
try:
|
||||||
|
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed login otherwise orphans the connected socket; close it
|
||||||
|
# before propagating (shutdown() is the pre-auth low-level close). (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -418,68 +441,71 @@ def _list_emails(folder="INBOX", max_results=20, unresponded_only=False,
|
|||||||
Pass unread_only=True and/or unresponded_only=True for attention scans.
|
Pass unread_only=True and/or unresponded_only=True for attention scans.
|
||||||
account selects mailbox (None = default).
|
account selects mailbox (None = default).
|
||||||
"""
|
"""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
select_status, _ = conn.select(folder, readonly=True)
|
try:
|
||||||
if select_status != "OK":
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
select_status, _ = conn.select(_q(folder), readonly=True)
|
||||||
raise ValueError(f"IMAP folder not found: {folder}")
|
if select_status != "OK":
|
||||||
|
raise ValueError(f"IMAP folder not found: {folder}")
|
||||||
|
|
||||||
if unread_only and unresponded_only:
|
if unread_only and unresponded_only:
|
||||||
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
||||||
elif unread_only:
|
elif unread_only:
|
||||||
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
||||||
elif unresponded_only:
|
elif unresponded_only:
|
||||||
# Was missing — unresponded_only=True (without unread_only) fell through
|
# Was missing — unresponded_only=True (without unread_only) fell through
|
||||||
# to "ALL" and returned answered mail too, despite the documented
|
# to "ALL" and returned answered mail too, despite the documented
|
||||||
# "emails without replies" behaviour.
|
# "emails without replies" behaviour.
|
||||||
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
||||||
else:
|
else:
|
||||||
# Include read too — IMAP search "ALL" returns the entire folder
|
# Include read too — IMAP search "ALL" returns the entire folder
|
||||||
status, data = conn.uid("SEARCH", None, "ALL")
|
status, data = conn.uid("SEARCH", None, "ALL")
|
||||||
|
|
||||||
if status != "OK" or not data[0]:
|
if status != "OK" or not data[0]:
|
||||||
conn.logout()
|
return []
|
||||||
return []
|
|
||||||
|
|
||||||
uid_list = list(reversed(data[0].split()))[:max_results]
|
uid_list = list(reversed(data[0].split()))[:max_results]
|
||||||
cache = _get_cached_summaries()
|
cache = _get_cached_summaries()
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for uid in uid_list:
|
for uid in uid_list:
|
||||||
try:
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
|
continue
|
||||||
|
raw_header = msg_data[0][1]
|
||||||
|
msg = email.message_from_bytes(raw_header)
|
||||||
|
|
||||||
|
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||||
|
sender = _decode_header(msg.get("From", "unknown"))
|
||||||
|
date_str = msg.get("Date", "")
|
||||||
|
message_id = msg.get("Message-ID", "")
|
||||||
|
|
||||||
|
# Parse sender name
|
||||||
|
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||||
|
sender_display = sender_name or sender_addr
|
||||||
|
|
||||||
|
# Check cache for summary
|
||||||
|
cached = cache.get(subject, {})
|
||||||
|
summary = cached.get("summary", "")
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"uid": uid.decode(),
|
||||||
|
"message_id": message_id,
|
||||||
|
"subject": subject,
|
||||||
|
"from": sender_display,
|
||||||
|
"from_address": sender_addr,
|
||||||
|
"date": date_str,
|
||||||
|
"summary": summary,
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
continue
|
continue
|
||||||
raw_header = msg_data[0][1]
|
|
||||||
msg = email.message_from_bytes(raw_header)
|
|
||||||
|
|
||||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
return results
|
||||||
sender = _decode_header(msg.get("From", "unknown"))
|
finally:
|
||||||
date_str = msg.get("Date", "")
|
if conn:
|
||||||
message_id = msg.get("Message-ID", "")
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
# Parse sender name
|
|
||||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
|
||||||
sender_display = sender_name or sender_addr
|
|
||||||
|
|
||||||
# Check cache for summary
|
|
||||||
cached = cache.get(subject, {})
|
|
||||||
summary = cached.get("summary", "")
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"uid": uid.decode(),
|
|
||||||
"message_id": message_id,
|
|
||||||
"subject": subject,
|
|
||||||
"from": sender_display,
|
|
||||||
"from_address": sender_addr,
|
|
||||||
"date": date_str,
|
|
||||||
"summary": summary,
|
|
||||||
})
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
conn.logout()
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _result_sort_time(result: dict) -> datetime:
|
def _result_sort_time(result: dict) -> datetime:
|
||||||
@@ -542,7 +568,7 @@ def _search_emails(query, folders=None, max_results=20, account=None):
|
|||||||
try:
|
try:
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
try:
|
try:
|
||||||
status, _ = conn.select(folder, readonly=True)
|
status, _ = conn.select(_q(folder), readonly=True)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
continue
|
continue
|
||||||
status, data = conn.uid("SEARCH", None, search_cmd)
|
status, data = conn.uid("SEARCH", None, search_cmd)
|
||||||
@@ -652,54 +678,55 @@ def _extract_attachment_to_disk(msg, index, target_dir):
|
|||||||
def _read_email(uid=None, message_id=None, folder="INBOX", account=None):
|
def _read_email(uid=None, message_id=None, folder="INBOX", account=None):
|
||||||
"""Read full email content by UID or message-ID. account = mailbox selector."""
|
"""Read full email content by UID or message-ID. account = mailbox selector."""
|
||||||
cfg = _load_config(account)
|
cfg = _load_config(account)
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
|
conn = _imap_connect(account)
|
||||||
|
conn.select(_q(folder), readonly=True)
|
||||||
|
|
||||||
if message_id and not uid:
|
if message_id and not uid:
|
||||||
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
||||||
if status != "OK" or not data[0]:
|
if status != "OK" or not data[0]:
|
||||||
conn.logout()
|
return {"error": f"Email not found with Message-ID: {message_id}"}
|
||||||
return {"error": f"Email not found with Message-ID: {message_id}"}
|
uid = data[0].split()[-1]
|
||||||
uid = data[0].split()[-1]
|
|
||||||
|
|
||||||
if not uid:
|
if not uid:
|
||||||
conn.logout()
|
return {"error": "No UID or Message-ID provided"}
|
||||||
return {"error": "No UID or Message-ID provided"}
|
|
||||||
|
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
conn.logout()
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
||||||
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
return {"error": f"Email not found with UID {uid}"}
|
||||||
conn.logout()
|
|
||||||
return {"error": f"Email not found with UID {uid}"}
|
|
||||||
|
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
msg = email.message_from_bytes(raw)
|
msg = email.message_from_bytes(raw)
|
||||||
|
|
||||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||||
sender = _decode_header(msg.get("From", "unknown"))
|
sender = _decode_header(msg.get("From", "unknown"))
|
||||||
date_str = msg.get("Date", "")
|
date_str = msg.get("Date", "")
|
||||||
message_id_header = msg.get("Message-ID", "")
|
message_id_header = msg.get("Message-ID", "")
|
||||||
body = _extract_text(msg)
|
body = _extract_text(msg)
|
||||||
attachments = _list_attachments_from_msg(msg)
|
attachments = _list_attachments_from_msg(msg)
|
||||||
|
|
||||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||||
|
|
||||||
conn.logout()
|
return {
|
||||||
return {
|
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||||
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
||||||
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
||||||
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
"account_id": cfg.get("account_id"),
|
||||||
"account_id": cfg.get("account_id"),
|
"message_id": message_id_header,
|
||||||
"message_id": message_id_header,
|
"subject": subject,
|
||||||
"subject": subject,
|
"from": sender_name or sender_addr,
|
||||||
"from": sender_name or sender_addr,
|
"from_address": sender_addr,
|
||||||
"from_address": sender_addr,
|
"date": date_str,
|
||||||
"date": date_str,
|
"body": body[:8000],
|
||||||
"body": body[:8000],
|
"attachments": attachments,
|
||||||
"attachments": attachments,
|
}
|
||||||
}
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
|
||||||
def _read_email_across_accounts(uid=None, message_id=None, folder="INBOX"):
|
def _read_email_across_accounts(uid=None, message_id=None, folder="INBOX"):
|
||||||
@@ -768,7 +795,16 @@ def _smtp_connect(account=None, cfg=None):
|
|||||||
port,
|
port,
|
||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket on a rejected STARTTLS. SMTP has
|
||||||
|
# no shutdown(); close() is the low-level socket close (no QUIT). (#3174)
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
elif security == "ssl":
|
elif security == "ssl":
|
||||||
conn = smtplib.SMTP_SSL(
|
conn = smtplib.SMTP_SSL(
|
||||||
cfg["smtp_host"],
|
cfg["smtp_host"],
|
||||||
@@ -782,7 +818,16 @@ def _smtp_connect(account=None, cfg=None):
|
|||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
if cfg["smtp_user"] and cfg["smtp_password"]:
|
if cfg["smtp_user"] and cfg["smtp_password"]:
|
||||||
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
try:
|
||||||
|
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed login otherwise orphans the connected socket; close it
|
||||||
|
# before propagating (SMTP has no shutdown(); close() = socket close). (#3174)
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -827,7 +872,7 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
|||||||
imap = _imap_connect(send_account)
|
imap = _imap_connect(send_account)
|
||||||
try:
|
try:
|
||||||
sent_folder = _detect_sent_folder(imap)
|
sent_folder = _detect_sent_folder(imap)
|
||||||
append_st, append_data = imap.append(sent_folder, "\\Seen", None, msg.as_bytes())
|
append_st, append_data = imap.append(_q(sent_folder), "\\Seen", None, msg.as_bytes())
|
||||||
if append_st == "OK" and append_data:
|
if append_st == "OK" and append_data:
|
||||||
m = re.search(rb"APPENDUID\s+\d+\s+(\d+)", append_data[0] or b"")
|
m = re.search(rb"APPENDUID\s+\d+\s+(\d+)", append_data[0] or b"")
|
||||||
if m:
|
if m:
|
||||||
@@ -853,10 +898,15 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
|||||||
|
|
||||||
def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
||||||
"""Reply to an existing email by UID. Threads via In-Reply-To/References."""
|
"""Reply to an existing email by UID. Threads via In-Reply-To/References."""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
conn.select(_q(folder), readonly=True)
|
||||||
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
if status != "OK" or not msg_data or not msg_data[0]:
|
if status != "OK" or not msg_data or not msg_data[0]:
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
@@ -896,7 +946,7 @@ def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
|||||||
def _set_flag(uid, folder, flag, add=True, account=None):
|
def _set_flag(uid, folder, flag, add=True, account=None):
|
||||||
"""Add or remove an IMAP flag (e.g. \\Seen, \\Answered, \\Deleted)."""
|
"""Add or remove an IMAP flag (e.g. \\Seen, \\Answered, \\Deleted)."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
conn.select(folder)
|
conn.select(_q(folder))
|
||||||
op = "+FLAGS" if add else "-FLAGS"
|
op = "+FLAGS" if add else "-FLAGS"
|
||||||
try:
|
try:
|
||||||
status, data = conn.uid("STORE", _b(uid), op, flag)
|
status, data = conn.uid("STORE", _b(uid), op, flag)
|
||||||
@@ -918,7 +968,7 @@ def _bulk_set_flag(uids, folder, flag, add=True, account=None):
|
|||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
touched = []
|
touched = []
|
||||||
try:
|
try:
|
||||||
conn.select(folder)
|
conn.select(_q(folder))
|
||||||
op = "+FLAGS" if add else "-FLAGS"
|
op = "+FLAGS" if add else "-FLAGS"
|
||||||
msg_set = ",".join(str(u) for u in uids)
|
msg_set = ",".join(str(u) for u in uids)
|
||||||
try:
|
try:
|
||||||
@@ -945,7 +995,7 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
|||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
moved = 0
|
moved = 0
|
||||||
try:
|
try:
|
||||||
conn.select(source_folder)
|
conn.select(_q(source_folder))
|
||||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||||
msg_set = ",".join(str(u) for u in uids)
|
msg_set = ",".join(str(u) for u in uids)
|
||||||
try:
|
try:
|
||||||
@@ -956,10 +1006,11 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
|||||||
if not existing:
|
if not existing:
|
||||||
return 0
|
return 0
|
||||||
moved = len(existing)
|
moved = len(existing)
|
||||||
status, _ = conn.uid("MOVE", _b(msg_set), dest_folder)
|
dest_arg = _q(dest_folder)
|
||||||
|
status, _ = conn.uid("MOVE", _b(msg_set), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
# Fallback: UID copy + flag-delete + expunge
|
# Fallback: UID copy + flag-delete + expunge
|
||||||
status, _ = conn.uid("COPY", _b(msg_set), dest_folder)
|
status, _ = conn.uid("COPY", _b(msg_set), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return 0
|
return 0
|
||||||
status, _ = conn.uid("STORE", _b(msg_set), "+FLAGS", "\\Deleted")
|
status, _ = conn.uid("STORE", _b(msg_set), "+FLAGS", "\\Deleted")
|
||||||
@@ -976,7 +1027,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
|||||||
ALL, ANSWERED). Used to resolve selectors like all_unread → uids."""
|
ALL, ANSWERED). Used to resolve selectors like all_unread → uids."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
try:
|
try:
|
||||||
conn.select(folder, readonly=True)
|
conn.select(_q(folder), readonly=True)
|
||||||
status, data = conn.uid("SEARCH", None, criteria)
|
status, data = conn.uid("SEARCH", None, criteria)
|
||||||
if status != "OK" or not data or not data[0]:
|
if status != "OK" or not data or not data[0]:
|
||||||
return []
|
return []
|
||||||
@@ -988,7 +1039,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
|||||||
def _move_message(uid, source_folder, dest_folder, account=None, role: str = ""):
|
def _move_message(uid, source_folder, dest_folder, account=None, role: str = ""):
|
||||||
"""Move a message between folders. Tries IMAP MOVE, falls back to copy+delete."""
|
"""Move a message between folders. Tries IMAP MOVE, falls back to copy+delete."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
conn.select(source_folder)
|
conn.select(_q(source_folder))
|
||||||
try:
|
try:
|
||||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||||
try:
|
try:
|
||||||
@@ -998,11 +1049,12 @@ def _move_message(uid, source_folder, dest_folder, account=None, role: str = "")
|
|||||||
existing = _uid_fetch_rows(data)
|
existing = _uid_fetch_rows(data)
|
||||||
if status != "OK" or not existing:
|
if status != "OK" or not existing:
|
||||||
return False
|
return False
|
||||||
status, _ = conn.uid("MOVE", _b(uid), dest_folder)
|
dest_arg = _q(dest_folder)
|
||||||
|
status, _ = conn.uid("MOVE", _b(uid), dest_arg)
|
||||||
if status == "OK":
|
if status == "OK":
|
||||||
return True
|
return True
|
||||||
# Fallback: UID copy + delete
|
# Fallback: UID copy + delete
|
||||||
status, _ = conn.uid("COPY", _b(uid), dest_folder)
|
status, _ = conn.uid("COPY", _b(uid), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return False
|
return False
|
||||||
status, _ = conn.uid("STORE", _b(uid), "+FLAGS", "\\Deleted")
|
status, _ = conn.uid("STORE", _b(uid), "+FLAGS", "\\Deleted")
|
||||||
@@ -1031,16 +1083,21 @@ def _archive_email(uid, folder="INBOX", account=None):
|
|||||||
|
|
||||||
def _download_attachment(uid, index, folder="INBOX", account=None):
|
def _download_attachment(uid, index, folder="INBOX", account=None):
|
||||||
"""Extract a specific attachment to disk and return its local path."""
|
"""Extract a specific attachment to disk and return its local path."""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
conn.select(_q(folder), readonly=True)
|
||||||
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
msg = email.message_from_bytes(raw)
|
msg = email.message_from_bytes(raw)
|
||||||
|
|
||||||
target_dir = DATA_DIR / "mail-attachments" / f"{folder}_{uid}"
|
target_dir = Path(MAIL_ATTACHMENTS_DIR) / f"{folder}_{uid}"
|
||||||
filepath = _extract_attachment_to_disk(msg, index, target_dir)
|
filepath = _extract_attachment_to_disk(msg, index, target_dir)
|
||||||
if not filepath:
|
if not filepath:
|
||||||
return {"error": f"Attachment index {index} not found"}
|
return {"error": f"Attachment index {index} not found"}
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from mcp.types import Tool, TextContent
|
|||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
server = Server("image_gen")
|
server = Server("image_gen")
|
||||||
|
|
||||||
|
|
||||||
@@ -115,14 +117,18 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
|
|
||||||
img = images[0]
|
img = images[0]
|
||||||
image_url = None
|
image_url = None
|
||||||
|
# Prefix the instance's public base URL (existing app_public_url setting) so the
|
||||||
|
# link is fully-qualified and clickable when the model echoes it. Empty = relative
|
||||||
|
# same-origin path (unchanged default).
|
||||||
|
_pub_base = (get_setting("app_public_url", "") or "").rstrip("/")
|
||||||
|
|
||||||
if img.get("b64_json"):
|
if img.get("b64_json"):
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
img_path.write_bytes(base64.b64decode(img["b64_json"]))
|
img_path.write_bytes(base64.b64decode(img["b64_json"]))
|
||||||
image_url = f"/api/generated-image/{filename}"
|
image_url = f"{_pub_base}/api/generated-image/{filename}"
|
||||||
|
|
||||||
# Save to gallery
|
# Save to gallery
|
||||||
try:
|
try:
|
||||||
@@ -146,7 +152,13 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
else:
|
else:
|
||||||
return [TextContent(type="text", text="Error: Unexpected image API response format")]
|
return [TextContent(type="text", text="Error: Unexpected image API response format")]
|
||||||
|
|
||||||
result = f"Generated image for: {prompt[:100]}\nimage_url: {image_url}\nmodel: {model_id}\nsize: {size}"
|
# "Direct link:" rather than an "image_url:" label — small models copied the
|
||||||
|
# label token ("image_url") into the link href, producing a broken link.
|
||||||
|
result = (
|
||||||
|
f"Generated image for: {prompt[:100]}\n"
|
||||||
|
f"Direct link: {image_url}\n"
|
||||||
|
f"model: {model_id}\nsize: {size}"
|
||||||
|
)
|
||||||
return [TextContent(type="text", text=result)]
|
return [TextContent(type="text", text=result)]
|
||||||
|
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
|
|||||||
Generated
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "odysseus-ui",
|
"name": "odysseus",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
|
|||||||
@@ -1,3 +1,18 @@
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
# Test-taxonomy markers added at collection time by tests/conftest.py. The
|
||||||
|
# stable area_* markers are declared here; the dynamic sub_<filename-token>
|
||||||
|
# markers are registered before collection by pytest_configure in
|
||||||
|
# tests/conftest.py, so unknown-mark warnings still flag genuine typos outside
|
||||||
|
# the taxonomy. See tests/_taxonomy.py and tests/README.md.
|
||||||
|
markers = [
|
||||||
|
"area_security: tests covering auth, owner-scope, SSRF, XSS, confinement, redaction",
|
||||||
|
"area_routes: tests covering HTTP route / API behavior",
|
||||||
|
"area_services: tests covering service-layer behavior (llm, cookbook, email, calendar, ...)",
|
||||||
|
"area_cli: tests covering CLI / script behavior",
|
||||||
|
"area_js: JavaScript / Node-backed tests",
|
||||||
|
"area_helpers: self-tests for the shared test helpers in tests/helpers/",
|
||||||
|
"area_unit: pure parser / utility tests that do not clearly belong elsewhere",
|
||||||
|
"area_uncategorized: tests not yet matched by the taxonomy (fallback)",
|
||||||
|
]
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from core.database import (
|
|||||||
CalendarEvent,
|
CalendarEvent,
|
||||||
CalendarCal,
|
CalendarCal,
|
||||||
)
|
)
|
||||||
from src.constants import DATA_DIR
|
from src.constants import DATA_DIR, SKILLS_DIR, SKILLS_FILE, GALLERY_DIR, GALLERY_UPLOADS_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
# Skills live as SKILL.md files under data/skills/. Drop
|
# Skills live as SKILL.md files under data/skills/. Drop
|
||||||
# the entire directory; the SkillsManager re-creates the
|
# the entire directory; the SkillsManager re-creates the
|
||||||
# tree on next write.
|
# tree on next write.
|
||||||
skills_dir = os.path.join(DATA_DIR, "skills")
|
skills_dir = SKILLS_DIR
|
||||||
count = 0
|
count = 0
|
||||||
if os.path.isdir(skills_dir):
|
if os.path.isdir(skills_dir):
|
||||||
# Count SKILL.md files for the response — quick walk.
|
# Count SKILL.md files for the response — quick walk.
|
||||||
@@ -115,7 +115,7 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
count += sum(1 for f in files if f == "SKILL.md")
|
count += sum(1 for f in files if f == "SKILL.md")
|
||||||
_rmtree_quiet(skills_dir)
|
_rmtree_quiet(skills_dir)
|
||||||
# Legacy fallback file
|
# Legacy fallback file
|
||||||
legacy = os.path.join(DATA_DIR, "skills.json")
|
legacy = SKILLS_FILE
|
||||||
if os.path.exists(legacy):
|
if os.path.exists(legacy):
|
||||||
try:
|
try:
|
||||||
os.remove(legacy)
|
os.remove(legacy)
|
||||||
@@ -151,8 +151,8 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
db.query(GalleryAlbum).delete()
|
db.query(GalleryAlbum).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
# Also drop the upload dir so disk doesn't keep orphans.
|
# Also drop the upload dir so disk doesn't keep orphans.
|
||||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
_rmtree_quiet(GALLERY_DIR)
|
||||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
|
_rmtree_quiet(GALLERY_UPLOADS_DIR)
|
||||||
return {"status": "deleted", "kind": kind, "count": count}
|
return {"status": "deleted", "kind": kind, "count": count}
|
||||||
|
|
||||||
if kind == "calendar":
|
if kind == "calendar":
|
||||||
|
|||||||
@@ -155,22 +155,30 @@ def setup_api_token_routes() -> APIRouter:
|
|||||||
payload = await request.json()
|
payload = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
payload = {}
|
payload = {}
|
||||||
scope_list = _normalize_scopes(payload.get("scopes"))
|
|
||||||
scopes_value = ",".join(scope_list)
|
|
||||||
with get_db_session() as db:
|
with get_db_session() as db:
|
||||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(404, "Token not found")
|
raise HTTPException(404, "Token not found")
|
||||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||||
token.scopes = scopes_value
|
# Only touch scopes when the caller actually sent them. A partial
|
||||||
|
# update such as a rename ({"name": ...} with no "scopes" key) must
|
||||||
|
# not silently reset the token to the default scope — that dropped
|
||||||
|
# every previously granted scope.
|
||||||
|
if "scopes" in payload:
|
||||||
|
token.scopes = ",".join(_normalize_scopes(payload.get("scopes")))
|
||||||
db.add(token)
|
db.add(token)
|
||||||
|
current_scopes = [
|
||||||
|
s.strip()
|
||||||
|
for s in (getattr(token, "scopes", "") or DEFAULT_SCOPES).split(",")
|
||||||
|
if s.strip()
|
||||||
|
]
|
||||||
response = {
|
response = {
|
||||||
"id": token_id,
|
"id": token_id,
|
||||||
"name": getattr(token, "name", ""),
|
"name": getattr(token, "name", ""),
|
||||||
"owner": getattr(token, "owner", None),
|
"owner": getattr(token, "owner", None),
|
||||||
"token_prefix": getattr(token, "token_prefix", ""),
|
"token_prefix": getattr(token, "token_prefix", ""),
|
||||||
"scopes": scope_list,
|
"scopes": current_scopes,
|
||||||
}
|
}
|
||||||
_invalidate_cache(request)
|
_invalidate_cache(request)
|
||||||
return response
|
return response
|
||||||
|
|||||||
+23
-4
@@ -131,10 +131,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
return {"ok": False, "requires_totp": True, "username": username}
|
return {"ok": False, "requires_totp": True, "username": username}
|
||||||
if not auth_manager.totp_verify(username, body.totp_code):
|
if not auth_manager.totp_verify(username, body.totp_code):
|
||||||
raise HTTPException(401, "Invalid 2FA code")
|
raise HTTPException(401, "Invalid 2FA code")
|
||||||
# All checks passed — create session
|
# All checks passed — create session (password already verified above)
|
||||||
token = await asyncio.to_thread(auth_manager.create_session, username, body.password)
|
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
|
||||||
if not token:
|
|
||||||
raise HTTPException(401, "Invalid credentials")
|
|
||||||
cookie_kwargs = dict(
|
cookie_kwargs = dict(
|
||||||
key=SESSION_COOKIE,
|
key=SESSION_COOKIE,
|
||||||
value=token,
|
value=token,
|
||||||
@@ -585,6 +583,27 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
||||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
||||||
|
|
||||||
|
if preset == "discord_webhook":
|
||||||
|
import httpx
|
||||||
|
webhook_url = (integ.get("base_url") or "").strip()
|
||||||
|
if not webhook_url:
|
||||||
|
return {"ok": False, "message": "No webhook URL set — paste the full Discord webhook URL into the Base URL field."}
|
||||||
|
payload = {
|
||||||
|
"embeds": [{
|
||||||
|
"title": "Odysseus connectivity test",
|
||||||
|
"description": "If you see this, your Discord Webhook integration is wired up correctly.",
|
||||||
|
"color": 5793266,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||||
|
r = await client.post(webhook_url, json=payload)
|
||||||
|
if r.is_success:
|
||||||
|
return {"ok": True, "message": "Test embed sent — check your Discord channel to confirm it arrived."}
|
||||||
|
return {"ok": False, "message": f"Discord returned HTTP {r.status_code}: {r.text[:200]}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"ok": False, "message": f"Request failed: {e}"[:400]}
|
||||||
|
|
||||||
# All other presets: GET against a known health endpoint.
|
# All other presets: GET against a known health endpoint.
|
||||||
# Fall back to detecting from name if preset is missing.
|
# Fall back to detecting from name if preset is missing.
|
||||||
health_paths = {
|
health_paths = {
|
||||||
|
|||||||
+56
-12
@@ -101,24 +101,68 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
|||||||
# ── Skills ──
|
# ── Skills ──
|
||||||
if "skills" in body and isinstance(body["skills"], list):
|
if "skills" in body and isinstance(body["skills"], list):
|
||||||
existing = skills_manager.load_all()
|
existing = skills_manager.load_all()
|
||||||
existing_ids = {s.get("id") for s in existing}
|
existing_names = {s.get("name") for s in existing if s.get("name")}
|
||||||
existing_titles = {s.get("title", "").strip().lower() for s in existing}
|
existing_ids = {s.get("id") for s in existing if s.get("id")}
|
||||||
|
existing_titles = {
|
||||||
|
(s.get("title") or s.get("description") or "").strip().lower()
|
||||||
|
for s in existing
|
||||||
|
}
|
||||||
added = 0
|
added = 0
|
||||||
for skill in body["skills"]:
|
for skill in body["skills"]:
|
||||||
if not isinstance(skill, dict) or not skill.get("title"):
|
if not isinstance(skill, dict):
|
||||||
continue
|
continue
|
||||||
# Skip if same id or same title already exists
|
title = (
|
||||||
if skill.get("id") in existing_ids:
|
skill.get("title") or skill.get("description")
|
||||||
|
or skill.get("name") or ""
|
||||||
|
).strip()
|
||||||
|
if not title:
|
||||||
continue
|
continue
|
||||||
if skill["title"].strip().lower() in existing_titles:
|
sid = skill.get("id") or skill.get("name")
|
||||||
|
if sid and sid in existing_ids:
|
||||||
continue
|
continue
|
||||||
if user and not skill.get("owner"):
|
nm = skill.get("name")
|
||||||
skill["owner"] = user
|
if nm and nm in existing_names:
|
||||||
existing.append(skill)
|
continue
|
||||||
existing_ids.add(skill.get("id"))
|
if title.lower() in existing_titles:
|
||||||
existing_titles.add(skill["title"].strip().lower())
|
continue
|
||||||
|
owner = skill.get("owner")
|
||||||
|
if user and not owner:
|
||||||
|
owner = user
|
||||||
|
# Skills live on disk as SKILL.md files; the old JSON-era
|
||||||
|
# skills_manager.save() no longer exists. Write each new skill
|
||||||
|
# via add_skill (source="user" skips auto-dedup — this is an
|
||||||
|
# explicit backup restore).
|
||||||
|
result = skills_manager.add_skill(
|
||||||
|
title=title,
|
||||||
|
name=skill.get("name"),
|
||||||
|
description=skill.get("description"),
|
||||||
|
problem=skill.get("problem", ""),
|
||||||
|
solution=skill.get("solution", ""),
|
||||||
|
steps=skill.get("steps"),
|
||||||
|
tags=skill.get("tags"),
|
||||||
|
source="user",
|
||||||
|
teacher_model=skill.get("teacher_model"),
|
||||||
|
confidence=skill.get("confidence", 0.8),
|
||||||
|
owner=owner,
|
||||||
|
category=skill.get("category", "general"),
|
||||||
|
when_to_use=skill.get("when_to_use"),
|
||||||
|
procedure=skill.get("procedure"),
|
||||||
|
pitfalls=skill.get("pitfalls"),
|
||||||
|
verification=skill.get("verification"),
|
||||||
|
platforms=skill.get("platforms"),
|
||||||
|
requires_toolsets=skill.get("requires_toolsets"),
|
||||||
|
fallback_for_toolsets=skill.get("fallback_for_toolsets"),
|
||||||
|
status=skill.get("status", "draft"),
|
||||||
|
version=skill.get("version", "1.0.0"),
|
||||||
|
)
|
||||||
|
if result.get("_deduped"):
|
||||||
|
continue
|
||||||
|
if result.get("name"):
|
||||||
|
existing_names.add(result["name"])
|
||||||
|
if result.get("id"):
|
||||||
|
existing_ids.add(result["id"])
|
||||||
|
existing_titles.add(title.lower())
|
||||||
added += 1
|
added += 1
|
||||||
skills_manager.save(existing)
|
|
||||||
imported.append(f"{added} skills")
|
imported.append(f"{added} skills")
|
||||||
|
|
||||||
# ── Presets ──
|
# ── Presets ──
|
||||||
|
|||||||
+254
-69
@@ -1,6 +1,7 @@
|
|||||||
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, date, timedelta
|
from datetime import datetime, date, timedelta
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
@@ -12,7 +13,7 @@ from dateutil.rrule import rrulestr
|
|||||||
|
|
||||||
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_user
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, ICS_MAX_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,6 +101,15 @@ def _ics_escape(text: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_ics_filename(name: str) -> str:
|
||||||
|
"""Return a conservative .ics filename safe for Content-Disposition."""
|
||||||
|
stem = name if isinstance(name, str) else ""
|
||||||
|
stem = re.sub(r"[^A-Za-z0-9._-]", "_", stem).strip("._-")
|
||||||
|
if not stem:
|
||||||
|
stem = "calendar"
|
||||||
|
return f"{stem[:128]}.ics"
|
||||||
|
|
||||||
|
|
||||||
def _resolve_base_uid(uid: str) -> str:
|
def _resolve_base_uid(uid: str) -> str:
|
||||||
"""Extract the base series UID from a compound occurrence UID.
|
"""Extract the base series UID from a compound occurrence UID.
|
||||||
|
|
||||||
@@ -248,6 +258,17 @@ def parse_due_for_user(s: str) -> str:
|
|||||||
if t is not None:
|
if t is not None:
|
||||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||||
|
|
||||||
|
# Time-first: "3pm today", "11pm today", "9am tomorrow"
|
||||||
|
m = _re.match(r'^(.+?)\s+(today|tonight|tomorrow|tmrw|yesterday)$', lower)
|
||||||
|
if m:
|
||||||
|
time_part, word = m.group(1).strip(), m.group(2)
|
||||||
|
base = today
|
||||||
|
if word in ("tomorrow", "tmrw"): base = today + _td(days=1)
|
||||||
|
elif word == "yesterday": base = today - _td(days=1)
|
||||||
|
t = _parse_time(time_part)
|
||||||
|
if t is not None:
|
||||||
|
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||||
|
|
||||||
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
||||||
if m:
|
if m:
|
||||||
n = int(m.group(1)); unit = m.group(2)
|
n = int(m.group(1)); unit = m.group(2)
|
||||||
@@ -399,7 +420,17 @@ def _parse_dt(s: str) -> datetime:
|
|||||||
# Last resort: dateutil's fuzzy parser
|
# Last resort: dateutil's fuzzy parser
|
||||||
try:
|
try:
|
||||||
from dateutil import parser as _du
|
from dateutil import parser as _du
|
||||||
return _du.parse(s)
|
parsed = _du.parse(s)
|
||||||
|
# Strip tz like every other return path above — this function's
|
||||||
|
# contract is naive datetimes (CalendarEvent.dtstart is naive). An
|
||||||
|
# offset-bearing non-ISO input (e.g. RFC-2822 "Mon, 05 Jan 2026
|
||||||
|
# 14:00:00 +0900") otherwise leaked tz-aware into the naive column and
|
||||||
|
# crashed read-back comparisons in _expand_rrule with "can't compare
|
||||||
|
# offset-naive and offset-aware datetimes".
|
||||||
|
if parsed.tzinfo is not None:
|
||||||
|
from datetime import timezone as _tz
|
||||||
|
return parsed.astimezone(_tz.utc).replace(tzinfo=None)
|
||||||
|
return parsed
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"could not parse datetime: {s!r}")
|
raise ValueError(f"could not parse datetime: {s!r}")
|
||||||
|
|
||||||
@@ -440,6 +471,9 @@ def _event_to_dict(ev: CalendarEvent) -> dict:
|
|||||||
|
|
||||||
# ── Recurrence expansion ──
|
# ── Recurrence expansion ──
|
||||||
|
|
||||||
|
_RRULE_EXPANSION_LIMIT = 1000
|
||||||
|
|
||||||
|
|
||||||
def _expand_rrule(
|
def _expand_rrule(
|
||||||
ev: CalendarEvent, start: datetime, end: datetime
|
ev: CalendarEvent, start: datetime, end: datetime
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
@@ -462,6 +496,7 @@ def _expand_rrule(
|
|||||||
d = _event_to_dict(ev)
|
d = _event_to_dict(ev)
|
||||||
d["is_recurrence"] = False
|
d["is_recurrence"] = False
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
|
d["truncated"] = False
|
||||||
return [d]
|
return [d]
|
||||||
|
|
||||||
# Parse the rrule, applying it to the base dtstart.
|
# Parse the rrule, applying it to the base dtstart.
|
||||||
@@ -487,6 +522,7 @@ def _expand_rrule(
|
|||||||
d = _event_to_dict(ev)
|
d = _event_to_dict(ev)
|
||||||
d["is_recurrence"] = False
|
d["is_recurrence"] = False
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
|
d["truncated"] = False
|
||||||
# Malformed RRULE rows are fetched by the recurring SQL branch
|
# Malformed RRULE rows are fetched by the recurring SQL branch
|
||||||
# with only dtstart < end_dt — the base event may not actually
|
# with only dtstart < end_dt — the base event may not actually
|
||||||
# overlap the window. Only return if it does.
|
# overlap the window. Only return if it does.
|
||||||
@@ -499,22 +535,26 @@ def _expand_rrule(
|
|||||||
# (matching non-recurring overlap semantics: dtstart < end AND
|
# (matching non-recurring overlap semantics: dtstart < end AND
|
||||||
# dtend > start).
|
# dtend > start).
|
||||||
expand_start = start - duration
|
expand_start = start - duration
|
||||||
occurrences = rule.between(expand_start, end, inc=True)
|
|
||||||
if not occurrences:
|
|
||||||
return []
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
truncated = False
|
||||||
base = _event_to_dict(ev)
|
base = _event_to_dict(ev)
|
||||||
|
|
||||||
for occ_start in occurrences:
|
for occ_start in rule.xafter(expand_start, inc=True):
|
||||||
|
if occ_start >= end:
|
||||||
|
break
|
||||||
|
|
||||||
occ_end = occ_start + duration
|
occ_end = occ_start + duration
|
||||||
|
|
||||||
# Overlap filter: occurrence must intersect [start, end).
|
# Overlap filter: occurrence must intersect [start, end).
|
||||||
# This enforces exclusive-end semantics (occ_start >= end is
|
# This enforces exclusive-end semantics (occ_start >= end is
|
||||||
# excluded) and includes multi-day crossings (occ_end > start).
|
# excluded) and includes multi-day crossings (occ_end > start).
|
||||||
if occ_start >= end or occ_end <= start:
|
if occ_end <= start:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if len(results) >= _RRULE_EXPANSION_LIMIT:
|
||||||
|
truncated = True
|
||||||
|
break
|
||||||
|
|
||||||
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
||||||
if ev.all_day:
|
if ev.all_day:
|
||||||
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
||||||
@@ -525,6 +565,7 @@ def _expand_rrule(
|
|||||||
d["uid"] = occ_uid
|
d["uid"] = occ_uid
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
d["is_recurrence"] = True
|
d["is_recurrence"] = True
|
||||||
|
d["truncated"] = False
|
||||||
|
|
||||||
if ev.all_day:
|
if ev.all_day:
|
||||||
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
||||||
@@ -537,6 +578,10 @@ def _expand_rrule(
|
|||||||
|
|
||||||
results.append(d)
|
results.append(d)
|
||||||
|
|
||||||
|
if truncated:
|
||||||
|
for d in results:
|
||||||
|
d["truncated"] = True
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -545,72 +590,178 @@ def _expand_rrule(
|
|||||||
def setup_calendar_routes() -> APIRouter:
|
def setup_calendar_routes() -> APIRouter:
|
||||||
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
||||||
|
|
||||||
# CalDAV connect form (Integrations → Calendar). Storage is local
|
# ── CalDAV multi-account helpers ─────────────────────────────────────────
|
||||||
# SQLite; sync (src/caldav_sync.py) pulls remote events into it on
|
|
||||||
# calendar open and periodically via the scheduler.
|
def _get_caldav_accounts(owner: str) -> list:
|
||||||
|
from src.caldav_sync import _load_caldav_accounts
|
||||||
|
return _load_caldav_accounts(owner)
|
||||||
|
|
||||||
|
def _save_caldav_accounts(owner: str, accounts: list) -> None:
|
||||||
|
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||||
|
prefs = _load_for_user(owner) or {}
|
||||||
|
prefs["caldav_accounts"] = accounts
|
||||||
|
prefs.pop("caldav", None)
|
||||||
|
_save_for_user(owner, prefs)
|
||||||
|
|
||||||
|
# ── CalDAV config routes (backward-compat single-account API) ────────────
|
||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
async def get_config(request: Request):
|
async def get_config(request: Request):
|
||||||
|
"""Legacy single-account endpoint — returns the first configured account."""
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
from routes.prefs_routes import _load_for_user
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
if not accounts:
|
||||||
caldav_password = cfg.get("password") or ""
|
return {"url": "", "username": "", "password": "", "has_password": False, "local": True}
|
||||||
if caldav_password:
|
first = accounts[0]
|
||||||
|
pw = first.get("password") or ""
|
||||||
|
has_pw = False
|
||||||
|
if pw:
|
||||||
try:
|
try:
|
||||||
from src.secret_storage import decrypt
|
from src.secret_storage import decrypt
|
||||||
caldav_password = decrypt(caldav_password)
|
has_pw = bool(decrypt(pw))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
has_pw = bool(pw)
|
||||||
# Surface url+username but never hand the password back to the
|
|
||||||
# client — saved-state UI shouldn't leak the credential.
|
|
||||||
return {
|
return {
|
||||||
"url": cfg.get("url", "") or "",
|
"url": first.get("url", "") or "",
|
||||||
"username": cfg.get("username", "") or "",
|
"username": first.get("username", "") or "",
|
||||||
"password": "",
|
"password": "",
|
||||||
"has_password": bool(caldav_password),
|
"has_password": has_pw,
|
||||||
"local": not bool(cfg.get("url")),
|
"local": not bool(first.get("url")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/config")
|
@router.post("/config")
|
||||||
async def save_config(request: Request):
|
async def save_config(request: Request):
|
||||||
|
"""Legacy single-account endpoint — upserts the first account."""
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
body = {}
|
body = {}
|
||||||
prefs = _load_for_user(owner) or {}
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = dict(prefs.get("caldav") or {})
|
|
||||||
# Empty url => clear the whole entry (treat as "remove integration").
|
|
||||||
if not (body.get("url") or "").strip():
|
if not (body.get("url") or "").strip():
|
||||||
prefs.pop("caldav", None)
|
_save_caldav_accounts(owner, [])
|
||||||
_save_for_user(owner, prefs)
|
|
||||||
return {"ok": True, "cleared": True}
|
return {"ok": True, "cleared": True}
|
||||||
from src.caldav_sync import validate_caldav_url
|
from src.caldav_sync import validate_caldav_url
|
||||||
try:
|
try:
|
||||||
cfg["url"] = validate_caldav_url(body.get("url", ""))
|
validated_url = validate_caldav_url(body.get("url", ""))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(400, str(e))
|
raise HTTPException(400, str(e))
|
||||||
cfg["username"] = (body.get("username") or "").strip()
|
if accounts:
|
||||||
# Preserve the stored password when the client sends an empty
|
acc = dict(accounts[0])
|
||||||
# one (edit form re-submitted without re-typing the password).
|
else:
|
||||||
# cfg already holds the existing (already-encrypted) password from
|
import uuid as _uuid
|
||||||
# prefs, so we only touch it when a new password is supplied —
|
acc = {"id": str(_uuid.uuid4()), "label": "CalDAV"}
|
||||||
# re-encrypting the stored value would double-encrypt it.
|
acc["url"] = validated_url
|
||||||
|
acc["username"] = (body.get("username") or "").strip()
|
||||||
if body.get("password"):
|
if body.get("password"):
|
||||||
from src.secret_storage import encrypt
|
from src.secret_storage import encrypt
|
||||||
cfg["password"] = encrypt(body["password"])
|
acc["password"] = encrypt(body["password"])
|
||||||
prefs["caldav"] = cfg
|
new_accounts = [acc] + (accounts[1:] if len(accounts) > 1 else [])
|
||||||
_save_for_user(owner, prefs)
|
_save_caldav_accounts(owner, new_accounts)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
# ── CalDAV multi-account CRUD ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/config/accounts")
|
||||||
|
async def list_caldav_accounts(request: Request):
|
||||||
|
"""Return all configured CalDAV accounts (passwords never returned)."""
|
||||||
|
owner = _require_user(request)
|
||||||
|
accounts = _get_caldav_accounts(owner)
|
||||||
|
safe = []
|
||||||
|
for acc in accounts:
|
||||||
|
pw = acc.get("password") or ""
|
||||||
|
has_pw = False
|
||||||
|
if pw:
|
||||||
|
try:
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
has_pw = bool(decrypt(pw))
|
||||||
|
except Exception:
|
||||||
|
has_pw = bool(pw)
|
||||||
|
safe.append({
|
||||||
|
"id": acc.get("id", ""),
|
||||||
|
"label": acc.get("label", "") or acc.get("url", ""),
|
||||||
|
"url": acc.get("url", "") or "",
|
||||||
|
"username": acc.get("username", "") or "",
|
||||||
|
"has_password": has_pw,
|
||||||
|
})
|
||||||
|
return {"accounts": safe}
|
||||||
|
|
||||||
|
@router.post("/config/accounts")
|
||||||
|
async def add_caldav_account(request: Request):
|
||||||
|
"""Add a new CalDAV account."""
|
||||||
|
import uuid as _uuid
|
||||||
|
owner = _require_user(request)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
from src.caldav_sync import validate_caldav_url
|
||||||
|
try:
|
||||||
|
url = validate_caldav_url(body.get("url", ""))
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
if not body.get("password"):
|
||||||
|
raise HTTPException(400, "Password is required")
|
||||||
|
from src.secret_storage import encrypt
|
||||||
|
new_acc = {
|
||||||
|
"id": str(_uuid.uuid4()),
|
||||||
|
"label": (body.get("label") or "").strip() or "CalDAV",
|
||||||
|
"url": url,
|
||||||
|
"username": (body.get("username") or "").strip(),
|
||||||
|
"password": encrypt(body["password"]),
|
||||||
|
}
|
||||||
|
accounts = _get_caldav_accounts(owner)
|
||||||
|
accounts.append(new_acc)
|
||||||
|
_save_caldav_accounts(owner, accounts)
|
||||||
|
return {"ok": True, "id": new_acc["id"]}
|
||||||
|
|
||||||
|
@router.put("/config/accounts/{account_id}")
|
||||||
|
async def update_caldav_account(account_id: str, request: Request):
|
||||||
|
"""Update an existing CalDAV account by id."""
|
||||||
|
owner = _require_user(request)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
accounts = _get_caldav_accounts(owner)
|
||||||
|
idx = next((i for i, a in enumerate(accounts) if a.get("id") == account_id), None)
|
||||||
|
if idx is None:
|
||||||
|
raise HTTPException(404, "Account not found")
|
||||||
|
acc = dict(accounts[idx])
|
||||||
|
if body.get("url"):
|
||||||
|
from src.caldav_sync import validate_caldav_url
|
||||||
|
try:
|
||||||
|
acc["url"] = validate_caldav_url(body["url"])
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
if body.get("label") is not None:
|
||||||
|
acc["label"] = (body.get("label") or "").strip() or "CalDAV"
|
||||||
|
if body.get("username") is not None:
|
||||||
|
acc["username"] = (body.get("username") or "").strip()
|
||||||
|
if body.get("password"):
|
||||||
|
from src.secret_storage import encrypt
|
||||||
|
acc["password"] = encrypt(body["password"])
|
||||||
|
accounts[idx] = acc
|
||||||
|
_save_caldav_accounts(owner, accounts)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@router.delete("/config/accounts/{account_id}")
|
||||||
|
async def delete_caldav_account(account_id: str, request: Request):
|
||||||
|
"""Remove a CalDAV account by id."""
|
||||||
|
owner = _require_user(request)
|
||||||
|
accounts = _get_caldav_accounts(owner)
|
||||||
|
new_accounts = [a for a in accounts if a.get("id") != account_id]
|
||||||
|
if len(new_accounts) == len(accounts):
|
||||||
|
raise HTTPException(404, "Account not found")
|
||||||
|
_save_caldav_accounts(owner, new_accounts)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
@router.post("/test")
|
@router.post("/test")
|
||||||
async def test_connection(request: Request):
|
async def test_connection(request: Request):
|
||||||
"""Actually probe the configured CalDAV server with a PROPFIND
|
"""Probe a CalDAV server with a PROPFIND. Accepts an optional body:
|
||||||
request (the same handshake every CalDAV client uses). Accepts
|
{url, username, password} to test before saving, or {account_id} to
|
||||||
an optional {url, username, password} body so the user can test
|
test an already-saved account. Falls back to the first saved account
|
||||||
a configuration BEFORE saving it; falls back to the stored
|
when nothing is provided."""
|
||||||
creds otherwise. Returns {ok, error?} with a useful message on
|
|
||||||
failure (status code, auth issue, network error)."""
|
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
@@ -620,19 +771,24 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
user = (body.get("username") or "").strip()
|
user = (body.get("username") or "").strip()
|
||||||
pw = body.get("password") or ""
|
pw = body.get("password") or ""
|
||||||
if not (url and user and pw):
|
if not (url and user and pw):
|
||||||
# Fall back to saved settings for this user.
|
# Look up a saved account: by id if supplied, else first account.
|
||||||
from routes.prefs_routes import _load_for_user
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
acc = None
|
||||||
url = url or (cfg.get("url") or "")
|
if body.get("account_id"):
|
||||||
user = user or (cfg.get("username") or "")
|
acc = next((a for a in accounts if a.get("id") == body["account_id"]), None)
|
||||||
if not pw:
|
if acc is None and accounts:
|
||||||
pw = cfg.get("password") or ""
|
acc = accounts[0]
|
||||||
if pw:
|
if acc:
|
||||||
try:
|
url = url or (acc.get("url") or "")
|
||||||
from src.secret_storage import decrypt
|
user = user or (acc.get("username") or "")
|
||||||
pw = decrypt(pw)
|
if not pw:
|
||||||
except Exception:
|
pw = acc.get("password") or ""
|
||||||
pass
|
if pw:
|
||||||
|
try:
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
pw = decrypt(pw)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if not (url and user and pw):
|
if not (url and user and pw):
|
||||||
return {"ok": False, "error": "Missing URL, username, or password"}
|
return {"ok": False, "error": "Missing URL, username, or password"}
|
||||||
from src.caldav_sync import validate_caldav_url
|
from src.caldav_sync import validate_caldav_url
|
||||||
@@ -695,6 +851,28 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
from src.caldav_sync import sync_caldav
|
from src.caldav_sync import sync_caldav
|
||||||
return await sync_caldav(owner)
|
return await sync_caldav(owner)
|
||||||
|
|
||||||
|
@router.delete("/calendars/{cal_id}")
|
||||||
|
async def delete_calendar(cal_id: str, request: Request):
|
||||||
|
owner = _require_user(request)
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
cal = db.query(CalendarCal).filter(
|
||||||
|
CalendarCal.id == cal_id,
|
||||||
|
CalendarCal.owner == owner,
|
||||||
|
).first()
|
||||||
|
if not cal:
|
||||||
|
raise HTTPException(404, "Calendar not found")
|
||||||
|
db.delete(cal)
|
||||||
|
db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to delete calendar %s: %s", cal_id, e)
|
||||||
|
raise HTTPException(500, "Failed to delete calendar")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
@router.get("/calendars")
|
@router.get("/calendars")
|
||||||
async def list_calendars(request: Request):
|
async def list_calendars(request: Request):
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
@@ -703,7 +881,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
_ensure_default_calendar(db, owner)
|
_ensure_default_calendar(db, owner)
|
||||||
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
||||||
return {"calendars": [
|
return {"calendars": [
|
||||||
{"name": c.name, "href": c.id, "color": c.color}
|
{"name": c.name, "href": c.id, "color": c.color, "source": c.source}
|
||||||
for c in cals
|
for c in cals
|
||||||
]}
|
]}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -766,8 +944,12 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
||||||
|
|
||||||
# Sort by occurrence start time for consistent frontend ordering.
|
# Sort by occurrence start time for consistent frontend ordering.
|
||||||
|
truncated = any(e.get("truncated") for e in expanded)
|
||||||
expanded.sort(key=lambda d: d["dtstart"])
|
expanded.sort(key=lambda d: d["dtstart"])
|
||||||
return {"events": expanded}
|
response: dict = {"events": expanded}
|
||||||
|
if truncated:
|
||||||
|
response["truncated"] = True
|
||||||
|
return response
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -988,9 +1170,9 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# 10 MB hard cap on ICS upload. Loading the whole file into memory is
|
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
|
||||||
# unavoidable with python-icalendar, so an unbounded upload would OOM.
|
# file into memory is unavoidable with python-icalendar, so an unbounded
|
||||||
_ICS_MAX_BYTES = 10 * 1024 * 1024
|
# upload would OOM.
|
||||||
|
|
||||||
@router.post("/import")
|
@router.post("/import")
|
||||||
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
||||||
@@ -1000,7 +1182,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
content = await read_upload_limited(file, _ICS_MAX_BYTES, "ICS file")
|
content = await read_upload_limited(file, ICS_MAX_BYTES, "ICS file")
|
||||||
try:
|
try:
|
||||||
cal_data = iCal.from_ical(content)
|
cal_data = iCal.from_ical(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1168,11 +1350,14 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
lines.append("END:VCALENDAR")
|
lines.append("END:VCALENDAR")
|
||||||
|
|
||||||
ics_data = "\r\n".join(lines)
|
ics_data = "\r\n".join(lines)
|
||||||
safe_name = cal.name.replace(" ", "_").replace("/", "_")
|
download_name = _safe_ics_filename(cal.name)
|
||||||
return Response(
|
return Response(
|
||||||
content=ics_data,
|
content=ics_data,
|
||||||
media_type="text/calendar",
|
media_type="text/calendar",
|
||||||
headers={"Content-Disposition": f'attachment; filename="{safe_name}.ics"'},
|
headers={
|
||||||
|
"Content-Disposition": f'attachment; filename="{download_name}"',
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -1194,7 +1379,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
||||||
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
||||||
"""
|
"""
|
||||||
_require_user(request)
|
owner = _require_user(request)
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from src.text_helpers import strip_think
|
from src.text_helpers import strip_think
|
||||||
@@ -1220,9 +1405,9 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
if tz_hint:
|
if tz_hint:
|
||||||
set_user_tz_name(tz_hint)
|
set_user_tz_name(tz_hint)
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return {"ok": False, "error": "No LLM endpoint configured"}
|
return {"ok": False, "error": "No LLM endpoint configured"}
|
||||||
|
|
||||||
|
|||||||
+130
-38
@@ -75,7 +75,7 @@ def _enforce_chat_privileges(request, sess) -> None:
|
|||||||
allowlist, or HTTPException(429) if the user has hit their daily message
|
allowlist, or HTTPException(429) if the user has hit their daily message
|
||||||
cap. No-op for unauthenticated callers or when auth_manager is absent
|
cap. No-op for unauthenticated callers or when auth_manager is absent
|
||||||
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
||||||
which means empty allowed_models / zero cap → no-op for them.
|
which means unrestricted allowed_models / zero cap -> no-op for them.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
@@ -88,8 +88,18 @@ def _enforce_chat_privileges(request, sess) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
privs = auth_manager.get_privileges(user) or {}
|
privs = auth_manager.get_privileges(user) or {}
|
||||||
allowed = privs.get("allowed_models") or []
|
|
||||||
if allowed and sess.model and sess.model not in allowed:
|
# Explicit "block everything" sentinel takes precedence over the
|
||||||
|
# allowlist — it's the only way to distinguish "user clicked [None]"
|
||||||
|
# (block all) from "user clicked [All]" (no restriction), since both
|
||||||
|
# otherwise produce an empty `allowed_models` list.
|
||||||
|
if privs.get("block_all_models"):
|
||||||
|
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||||
|
|
||||||
|
allowed_raw = privs.get("allowed_models")
|
||||||
|
allowed = allowed_raw if isinstance(allowed_raw, list) else []
|
||||||
|
restricted = bool(privs.get("allowed_models_restricted")) or bool(allowed)
|
||||||
|
if restricted and sess.model and sess.model not in allowed:
|
||||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||||
|
|
||||||
cap = int(privs.get("max_messages_per_day") or 0)
|
cap = int(privs.get("max_messages_per_day") or 0)
|
||||||
@@ -194,14 +204,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||||
"""
|
"""
|
||||||
import requests as _req
|
import requests as _req
|
||||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
build_models_url,
|
||||||
|
normalize_base,
|
||||||
|
resolve_endpoint_runtime,
|
||||||
|
)
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
|
||||||
current_url = sess.endpoint_url or ""
|
current_url = sess.endpoint_url or ""
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(
|
q = db.query(ModelEndpoint).filter(
|
||||||
ModelEndpoint.is_enabled == True
|
ModelEndpoint.is_enabled == True
|
||||||
).all()
|
)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -210,26 +232,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
# Skip current endpoint
|
# Skip current endpoint
|
||||||
if current_url and base in current_url:
|
if current_url and base in current_url:
|
||||||
continue
|
continue
|
||||||
# Quick ping
|
|
||||||
ping_url = build_models_url(base)
|
|
||||||
headers = build_headers(ep.api_key, base)
|
|
||||||
try:
|
try:
|
||||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
r.raise_for_status()
|
except Exception:
|
||||||
data = r.json()
|
continue
|
||||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
ping_url = build_models_url(base)
|
||||||
if not models:
|
headers = build_headers(api_key, base)
|
||||||
models = [
|
try:
|
||||||
m.get("name") or m.get("model")
|
if ping_url:
|
||||||
for m in (data.get("models") or [])
|
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||||
if m.get("name") or m.get("model")
|
r.raise_for_status()
|
||||||
]
|
data = r.json()
|
||||||
|
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
|
if not models:
|
||||||
|
models = [
|
||||||
|
m.get("name") or m.get("model")
|
||||||
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
models = json.loads(ep.cached_models or "[]")
|
||||||
if not models:
|
if not models:
|
||||||
continue
|
continue
|
||||||
# Found a working endpoint — update session
|
# Found a working endpoint — update session
|
||||||
new_model = models[0]
|
new_model = models[0]
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
new_headers = build_headers(ep.api_key, base)
|
new_headers = build_headers(api_key, base)
|
||||||
|
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||||
|
|
||||||
sess.model = new_model
|
sess.model = new_model
|
||||||
sess.endpoint_url = chat_url
|
sess.endpoint_url = chat_url
|
||||||
@@ -241,7 +270,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||||
"model": new_model,
|
"model": new_model,
|
||||||
"endpoint_url": chat_url,
|
"endpoint_url": chat_url,
|
||||||
"headers": json.dumps(new_headers),
|
"headers": persisted_headers,
|
||||||
})
|
})
|
||||||
_db.commit()
|
_db.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -275,11 +304,16 @@ def extract_preset(chat_handler, preset_id) -> PresetInfo:
|
|||||||
async def preprocess(
|
async def preprocess(
|
||||||
chat_handler, message, att_ids, sess,
|
chat_handler, message, att_ids, sess,
|
||||||
auto_opened_docs: Optional[list] = None,
|
auto_opened_docs: Optional[list] = None,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> PreprocessedMessage:
|
) -> PreprocessedMessage:
|
||||||
"""Run chat_handler.preprocess_message and wrap the result."""
|
"""Run chat_handler.preprocess_message and wrap the result."""
|
||||||
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
||||||
await chat_handler.preprocess_message(
|
await chat_handler.preprocess_message(
|
||||||
message, att_ids, sess, auto_opened_docs=auto_opened_docs
|
message,
|
||||||
|
att_ids,
|
||||||
|
sess,
|
||||||
|
auto_opened_docs=auto_opened_docs,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return PreprocessedMessage(
|
return PreprocessedMessage(
|
||||||
@@ -329,16 +363,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _has_auth_keys(headers) -> bool:
|
||||||
|
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||||
|
return isinstance(headers, dict) and any(
|
||||||
|
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
try:
|
||||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
)
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||||
if has_auth:
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
has_auth = _has_auth_keys(sess.headers)
|
||||||
|
if has_auth and not is_chatgpt_subscription:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import build_headers, normalize_base
|
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||||
@@ -354,10 +398,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
|||||||
for ep in q.all():
|
for ep in q.all():
|
||||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||||
continue
|
continue
|
||||||
if not ep.api_key:
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||||
|
return
|
||||||
|
if not api_key:
|
||||||
|
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||||
|
return
|
||||||
|
sess.headers = build_headers(api_key, base)
|
||||||
|
if is_chatgpt_subscription:
|
||||||
|
# The bearer is short-lived and re-resolved per request, so it
|
||||||
|
# stays request-local and is never written to the plaintext
|
||||||
|
# sessions.headers column. Proactively strip any bearer an
|
||||||
|
# older code path may have persisted so it does not linger.
|
||||||
|
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||||
|
stored = stale_q.first()
|
||||||
|
if stored is not None and _has_auth_keys(stored.headers):
|
||||||
|
stale_q.update({"headers": {}})
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||||
|
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||||
return
|
return
|
||||||
base = normalize_base(ep.base_url or "")
|
|
||||||
sess.headers = build_headers(ep.api_key, base)
|
|
||||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
if owner:
|
if owner:
|
||||||
update_q = update_q.filter(DBSession.owner == owner)
|
update_q = update_q.filter(DBSession.owner == owner)
|
||||||
@@ -401,7 +465,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
try:
|
try:
|
||||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||||
@@ -448,6 +517,7 @@ async def build_chat_context(
|
|||||||
webhook_manager=None,
|
webhook_manager=None,
|
||||||
use_enhanced_message: bool = False,
|
use_enhanced_message: bool = False,
|
||||||
agent_mode: bool = False,
|
agent_mode: bool = False,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> ChatContext:
|
) -> ChatContext:
|
||||||
"""Build the full context (preface + messages) for an LLM call.
|
"""Build the full context (preface + messages) for an LLM call.
|
||||||
|
|
||||||
@@ -465,6 +535,7 @@ async def build_chat_context(
|
|||||||
preprocessed = await preprocess(
|
preprocessed = await preprocess(
|
||||||
chat_handler, message, att_ids or [], sess,
|
chat_handler, message, att_ids or [], sess,
|
||||||
auto_opened_docs=auto_opened_docs,
|
auto_opened_docs=auto_opened_docs,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add user message to history
|
# Add user message to history
|
||||||
@@ -483,6 +554,9 @@ async def build_chat_context(
|
|||||||
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
||||||
# When off, the "Available skills" index is not added to the prompt.
|
# When off, the "Available skills" index is not added to the prompt.
|
||||||
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
||||||
|
if not allow_tool_preprocessing:
|
||||||
|
mem_enabled = False
|
||||||
|
skills_enabled = False
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
||||||
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
||||||
@@ -490,11 +564,11 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Use RAG?
|
# Use RAG?
|
||||||
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
||||||
if incognito:
|
if incognito or not allow_tool_preprocessing:
|
||||||
use_rag_val = False
|
use_rag_val = False
|
||||||
|
|
||||||
# If pre-fetched search context was provided (compare mode), skip live web search
|
# If pre-fetched search context was provided (compare mode), skip live web search
|
||||||
skip_web = bool(search_context)
|
skip_web = bool(search_context) or not allow_tool_preprocessing
|
||||||
|
|
||||||
# Build context preface
|
# Build context preface
|
||||||
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
||||||
@@ -521,7 +595,7 @@ async def build_chat_context(
|
|||||||
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
||||||
|
|
||||||
# Inject pre-fetched search context (compare mode)
|
# Inject pre-fetched search context (compare mode)
|
||||||
if search_context:
|
if search_context and allow_tool_preprocessing:
|
||||||
preface.append(untrusted_context_message("prefetched search context", search_context))
|
preface.append(untrusted_context_message("prefetched search context", search_context))
|
||||||
|
|
||||||
# YouTube transcripts
|
# YouTube transcripts
|
||||||
@@ -530,7 +604,11 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||||
# re-hit slow local /models endpoints on every participant turn.
|
# re-hit slow local /models endpoints on every participant turn.
|
||||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||||
|
sess.endpoint_url,
|
||||||
|
sess.model,
|
||||||
|
owner=getattr(sess, "owner", None),
|
||||||
|
)
|
||||||
if norm:
|
if norm:
|
||||||
sess.model = norm
|
sess.model = norm
|
||||||
|
|
||||||
@@ -539,7 +617,7 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Auto-compact
|
# Auto-compact
|
||||||
messages, context_length, was_compacted = await maybe_compact(
|
messages, context_length, was_compacted = await maybe_compact(
|
||||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||||
)
|
)
|
||||||
messages = trim_for_context(messages, context_length)
|
messages = trim_for_context(messages, context_length)
|
||||||
|
|
||||||
@@ -772,7 +850,19 @@ def save_assistant_response(
|
|||||||
):
|
):
|
||||||
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
||||||
md = dict(last_metrics) if last_metrics else {}
|
md = dict(last_metrics) if last_metrics else {}
|
||||||
md["model"] = sess.model
|
def _model_value(value) -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
value = str(value)
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
requested_model = _model_value(md.get("requested_model") or md.get("selected_model") or getattr(sess, "model", ""))
|
||||||
|
actual_model = _model_value(md.get("model") or md.get("actual_model") or requested_model)
|
||||||
|
if requested_model:
|
||||||
|
md["requested_model"] = requested_model
|
||||||
|
if actual_model:
|
||||||
|
md["model"] = actual_model
|
||||||
if character_name:
|
if character_name:
|
||||||
md["character_name"] = character_name
|
md["character_name"] = character_name
|
||||||
if web_sources:
|
if web_sources:
|
||||||
@@ -841,12 +931,13 @@ def run_post_response_tasks(
|
|||||||
skills_manager=None,
|
skills_manager=None,
|
||||||
owner: str = None,
|
owner: str = None,
|
||||||
extract_skills: bool = True,
|
extract_skills: bool = True,
|
||||||
|
allow_background_extraction: bool = True,
|
||||||
):
|
):
|
||||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||||
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
if allow_background_extraction and not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||||
from services.memory.memory_extractor import extract_and_store
|
from services.memory.memory_extractor import extract_and_store
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||||
@@ -873,6 +964,7 @@ def run_post_response_tasks(
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
extract_skills
|
extract_skills
|
||||||
|
and allow_background_extraction
|
||||||
and auto_skills_enabled
|
and auto_skills_enabled
|
||||||
and not incognito
|
and not incognito
|
||||||
and not compare_mode
|
and not compare_mode
|
||||||
|
|||||||
+206
-72
@@ -20,6 +20,7 @@ from src import agent_runs
|
|||||||
from src.model_context import estimate_tokens
|
from src.model_context import estimate_tokens
|
||||||
from src.chat_helpers import coerce_message_and_session
|
from src.chat_helpers import coerce_message_and_session
|
||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||||
|
from src.session_search import search_session_messages
|
||||||
from src.prompt_security import untrusted_context_message
|
from src.prompt_security import untrusted_context_message
|
||||||
from core.exceptions import SessionNotFoundError
|
from core.exceptions import SessionNotFoundError
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
@@ -39,6 +40,7 @@ from routes.chat_helpers import (
|
|||||||
_enforce_chat_privileges,
|
_enforce_chat_privileges,
|
||||||
)
|
)
|
||||||
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
||||||
|
from src.tool_policy import build_effective_tool_policy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -167,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
Covers the window between endpoint setup and the first chat send: the
|
Covers the window between endpoint setup and the first chat send: the
|
||||||
picker showed a model in the dropdown but the session record never got
|
picker showed a model in the dropdown but the session record never got
|
||||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||||
Without this, we'd POST the upstream with model="" and get a generic
|
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||||
401/503 instead of using the model the user already picked.
|
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||||
|
|
||||||
Returns True iff sess.model was repaired.
|
|
||||||
"""
|
"""
|
||||||
if getattr(sess, "model", None):
|
current_model = (getattr(sess, "model", "") or "").strip()
|
||||||
return False
|
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
if current_model:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Prefer the endpoint whose base URL matches the session — we know the
|
# Prefer the endpoint whose base URL matches the session — we know the
|
||||||
@@ -192,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
break
|
break
|
||||||
if not ep:
|
if not ep:
|
||||||
return False
|
return False
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||||
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
try:
|
try:
|
||||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||||
except Exception:
|
except Exception:
|
||||||
cached = []
|
cached = []
|
||||||
if not cached:
|
if not cached:
|
||||||
|
visible = []
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
return False
|
return False
|
||||||
try:
|
if is_chatgpt_subscription:
|
||||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
live_models = []
|
||||||
except Exception:
|
if getattr(ep, "provider_auth_id", None):
|
||||||
visible = cached
|
try:
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
if api_key:
|
||||||
|
live_models = fetch_available_models(api_key)
|
||||||
|
if live_models:
|
||||||
|
ep.cached_models = json.dumps(live_models)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
live_models = []
|
||||||
|
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||||
|
# Cached rows are only trusted above to avoid revalidating a model
|
||||||
|
# that is already present in the visible picker list.
|
||||||
|
cached = live_models
|
||||||
|
if not cached:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
|
return False
|
||||||
if not visible:
|
if not visible:
|
||||||
return False
|
return False
|
||||||
model = visible[0]
|
model = visible[0]
|
||||||
@@ -211,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
# Persist so the next request, websocket reconnect, or page reload
|
# Persist so the next request, websocket reconnect, or page reload
|
||||||
# picks up the same model (we'd otherwise re-pick on every send
|
# picks up the same model (we'd otherwise re-pick on every send
|
||||||
# and silently switch on the user if the cached order shifts).
|
# and silently switch on the user if the cached order shifts).
|
||||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||||
|
db_session = db_session_q.first()
|
||||||
if db_session:
|
if db_session:
|
||||||
db_session.model = model
|
db_session.model = model
|
||||||
db_session.updated_at = datetime.utcnow()
|
db_session.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
sess.model = model
|
sess.model = model
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
"Recovered session model for %s — picked %r from endpoint %s",
|
||||||
session_id, model, ep.id,
|
session_id, model, ep.id,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
@@ -304,8 +351,13 @@ def setup_chat_routes(
|
|||||||
# non-streaming path can't be used to bypass).
|
# non-streaming path can't be used to bypass).
|
||||||
_enforce_chat_privileges(request, sess)
|
_enforce_chat_privileges(request, sess)
|
||||||
|
|
||||||
|
tool_policy = build_effective_tool_policy(last_user_message=message)
|
||||||
|
allow_tool_preprocessing = not tool_policy.block_all_tool_calls
|
||||||
|
|
||||||
# Inline memory command
|
# Inline memory command
|
||||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
memory_response = None
|
||||||
|
if not tool_policy.blocks("manage_memory"):
|
||||||
|
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||||
if memory_response:
|
if memory_response:
|
||||||
return {"response": memory_response}
|
return {"response": memory_response}
|
||||||
|
|
||||||
@@ -319,10 +371,15 @@ def setup_chat_routes(
|
|||||||
use_web=use_web,
|
use_web=use_web,
|
||||||
time_filter=time_filter,
|
time_filter=time_filter,
|
||||||
webhook_manager=webhook_manager,
|
webhook_manager=webhook_manager,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Research injection
|
# Research injection
|
||||||
if use_research:
|
research_blocked_by_policy = (
|
||||||
|
tool_policy.blocks("trigger_research")
|
||||||
|
or tool_policy.blocks("manage_research")
|
||||||
|
)
|
||||||
|
if use_research and not research_blocked_by_policy:
|
||||||
try:
|
try:
|
||||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||||
research_ctx = await research_handler.call_research_service(
|
research_ctx = await research_handler.call_research_service(
|
||||||
@@ -357,6 +414,7 @@ def setup_chat_routes(
|
|||||||
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||||
character_name=ctx.preset.character_name,
|
character_name=ctx.preset.character_name,
|
||||||
owner=ctx.user,
|
owner=ctx.user,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"response": reply}
|
return {"response": reply}
|
||||||
@@ -394,6 +452,7 @@ def setup_chat_routes(
|
|||||||
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
||||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||||
|
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
|
||||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||||
# it's a real directory; ignore (no confinement) otherwise.
|
# it's a real directory; ignore (no confinement) otherwise.
|
||||||
@@ -401,6 +460,17 @@ def setup_chat_routes(
|
|||||||
if workspace:
|
if workspace:
|
||||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||||
|
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||||
|
if plan_mode:
|
||||||
|
chat_mode = "agent"
|
||||||
|
# An approved plan being EXECUTED: the frontend sends the checklist back
|
||||||
|
# on each turn so we can pin it in context. This way a long plan on a
|
||||||
|
# weak model survives history truncation — the agent can always re-read
|
||||||
|
# the plan. Ignored while still proposing (plan_mode on). Capped so a
|
||||||
|
# huge plan can't blow the prompt.
|
||||||
|
approved_plan = ""
|
||||||
|
if not plan_mode:
|
||||||
|
approved_plan = (form_data.get("approved_plan") or "").strip()[:8192]
|
||||||
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
||||||
# below). Skill extraction should only learn from real agent sessions,
|
# below). Skill extraction should only learn from real agent sessions,
|
||||||
# not chats we quietly promoted for a notes/calendar intent.
|
# not chats we quietly promoted for a notes/calendar intent.
|
||||||
@@ -479,11 +549,6 @@ def setup_chat_routes(
|
|||||||
do_research = True
|
do_research = True
|
||||||
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
||||||
|
|
||||||
# Persist session mode (research > agent > chat)
|
|
||||||
_effective_mode = 'research' if do_research else (chat_mode or 'chat')
|
|
||||||
if _effective_mode in ('agent', 'research', 'chat'):
|
|
||||||
set_session_mode(session, _effective_mode)
|
|
||||||
|
|
||||||
att_ids = []
|
att_ids = []
|
||||||
if body and isinstance(body.get("attachments"), list):
|
if body and isinstance(body.get("attachments"), list):
|
||||||
att_ids = [str(x) for x in body["attachments"]]
|
att_ids = [str(x) for x in body["attachments"]]
|
||||||
@@ -494,6 +559,10 @@ def setup_chat_routes(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
||||||
|
pre_context_tool_policy = build_effective_tool_policy(
|
||||||
|
last_user_message=message,
|
||||||
|
)
|
||||||
|
allow_tool_preprocessing = not pre_context_tool_policy.block_all_tool_calls
|
||||||
|
|
||||||
# Build shared context (stream path uses enhanced_message for context preface)
|
# Build shared context (stream path uses enhanced_message for context preface)
|
||||||
ctx = await build_chat_context(
|
ctx = await build_chat_context(
|
||||||
@@ -515,6 +584,7 @@ def setup_chat_routes(
|
|||||||
# manage_skills (agent mode). In plain chat or incognito the
|
# manage_skills (agent mode). In plain chat or incognito the
|
||||||
# index would be useless / unwanted noise.
|
# index would be useless / unwanted noise.
|
||||||
agent_mode=(chat_mode == "agent"),
|
agent_mode=(chat_mode == "agent"),
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
_research_flags = {"do": do_research} # Mutable container for generator scope
|
_research_flags = {"do": do_research} # Mutable container for generator scope
|
||||||
@@ -659,6 +729,32 @@ def setup_chat_routes(
|
|||||||
if chat_mode == 'chat':
|
if chat_mode == 'chat':
|
||||||
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
||||||
|
|
||||||
|
# Plan mode: investigate read-only, propose a plan, don't mutate. Block
|
||||||
|
# every tool not on the read-only allowlist. (stream_agent_loop enforces
|
||||||
|
# this again + drops MCP, so this is belt-and-suspenders.)
|
||||||
|
if plan_mode:
|
||||||
|
from src.tool_security import plan_mode_disabled_tools
|
||||||
|
disabled_tools.update(plan_mode_disabled_tools())
|
||||||
|
|
||||||
|
tool_policy = build_effective_tool_policy(
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
last_user_message=message,
|
||||||
|
)
|
||||||
|
disabled_tools = tool_policy.all_disabled_names()
|
||||||
|
research_blocked_by_policy = bool(
|
||||||
|
tool_policy.blocks("trigger_research")
|
||||||
|
or tool_policy.blocks("manage_research")
|
||||||
|
)
|
||||||
|
effective_do_research = bool(
|
||||||
|
do_research and _research_flags["do"] and not research_blocked_by_policy
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist session mode after policy/privilege gates so blocked research
|
||||||
|
# turns remain ordinary chat/agent streams and saved messages.
|
||||||
|
_effective_mode = 'research' if effective_do_research else (chat_mode or 'chat')
|
||||||
|
if _effective_mode in ('agent', 'research', 'chat'):
|
||||||
|
set_session_mode(session, _effective_mode)
|
||||||
|
|
||||||
async def stream_with_save() -> AsyncGenerator[str, None]:
|
async def stream_with_save() -> AsyncGenerator[str, None]:
|
||||||
# _effective_mode is read-only here; closure captures it from
|
# _effective_mode is read-only here; closure captures it from
|
||||||
# the outer scope. (Was `nonlocal` but never reassigned.)
|
# the outer scope. (Was `nonlocal` but never reassigned.)
|
||||||
@@ -666,7 +762,7 @@ def setup_chat_routes(
|
|||||||
web_sources = ctx.web_sources
|
web_sources = ctx.web_sources
|
||||||
|
|
||||||
# Register active stream for partial-save safety net
|
# Register active stream for partial-save safety net
|
||||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": do_research, "mode": _effective_mode}
|
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||||
|
|
||||||
if ctx.preprocessed.attachment_meta:
|
if ctx.preprocessed.attachment_meta:
|
||||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||||
@@ -690,7 +786,7 @@ def setup_chat_routes(
|
|||||||
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
||||||
|
|
||||||
# Run research as a background task (survives page refresh)
|
# Run research as a background task (survives page refresh)
|
||||||
if do_research and _research_flags["do"]:
|
if effective_do_research:
|
||||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||||
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
||||||
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
||||||
@@ -829,7 +925,7 @@ def setup_chat_routes(
|
|||||||
_fallback_candidates = []
|
_fallback_candidates = []
|
||||||
|
|
||||||
# Send model name early so the frontend can show it during streaming
|
# Send model name early so the frontend can show it during streaming
|
||||||
_model_suffix = "Research" if do_research else None
|
_model_suffix = "Research" if effective_do_research else None
|
||||||
_model_info = {"type": "model_info", "model": sess.model}
|
_model_info = {"type": "model_info", "model": sess.model}
|
||||||
if _model_suffix:
|
if _model_suffix:
|
||||||
_model_info["suffix"] = _model_suffix
|
_model_info["suffix"] = _model_suffix
|
||||||
@@ -839,6 +935,12 @@ def setup_chat_routes(
|
|||||||
|
|
||||||
if _is_image_generation_session(sess, owner=_user):
|
if _is_image_generation_session(sess, owner=_user):
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
|
if tool_policy.blocks("generate_image"):
|
||||||
|
_blocked_msg = tool_policy.reason_for("generate_image")
|
||||||
|
yield f'data: {json.dumps({"delta": _blocked_msg})}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
_active_streams.pop(session, None)
|
||||||
|
return
|
||||||
if not get_setting("image_gen_enabled", True):
|
if not get_setting("image_gen_enabled", True):
|
||||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
@@ -873,6 +975,8 @@ def setup_chat_routes(
|
|||||||
elif chat_mode == "chat":
|
elif chat_mode == "chat":
|
||||||
_chat_start = time.time()
|
_chat_start = time.time()
|
||||||
_answered_by = None # set if the selected model failed and a fallback answered
|
_answered_by = None # set if the selected model failed and a fallback answered
|
||||||
|
_requested_model = sess.model
|
||||||
|
_actual_model = None
|
||||||
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
||||||
try:
|
try:
|
||||||
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
||||||
@@ -905,10 +1009,18 @@ def setup_chat_routes(
|
|||||||
# Selected model failed; a fallback answered.
|
# Selected model failed; a fallback answered.
|
||||||
# Forward the notice and remember the real model.
|
# Forward the notice and remember the real model.
|
||||||
_answered_by = data.get("answered_by") or _answered_by
|
_answered_by = data.get("answered_by") or _answered_by
|
||||||
|
_actual_model = _actual_model or _answered_by
|
||||||
|
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.get("type") == "model_actual":
|
||||||
|
_actual_model = data.get("model") or _actual_model
|
||||||
|
data["requested_model"] = _requested_model
|
||||||
|
yield f'data: {json.dumps(data)}\n\n'
|
||||||
elif data.get("type") == "usage":
|
elif data.get("type") == "usage":
|
||||||
last_metrics = data.get("data", {})
|
last_metrics = data.get("data", {})
|
||||||
last_metrics["model"] = _answered_by or sess.model
|
_reported_model = last_metrics.get("model")
|
||||||
|
last_metrics["requested_model"] = _requested_model
|
||||||
|
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||||
if ctx.context_length and last_metrics.get("input_tokens"):
|
if ctx.context_length and last_metrics.get("input_tokens"):
|
||||||
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
||||||
last_metrics["context_percent"] = pct
|
last_metrics["context_percent"] = pct
|
||||||
@@ -945,7 +1057,8 @@ def setup_chat_routes(
|
|||||||
"tokens_per_second": _tps,
|
"tokens_per_second": _tps,
|
||||||
"context_percent": _ctx_pct,
|
"context_percent": _ctx_pct,
|
||||||
"context_length": ctx.context_length,
|
"context_length": ctx.context_length,
|
||||||
"model": sess.model,
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
"usage_source": "estimated",
|
"usage_source": "estimated",
|
||||||
}
|
}
|
||||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||||
@@ -957,7 +1070,7 @@ def setup_chat_routes(
|
|||||||
rag_sources=ctx.rag_sources,
|
rag_sources=ctx.rag_sources,
|
||||||
research_sources=research_sources,
|
research_sources=research_sources,
|
||||||
used_memories=ctx.used_memories,
|
used_memories=ctx.used_memories,
|
||||||
do_research=do_research,
|
do_research=effective_do_research,
|
||||||
incognito=incognito,
|
incognito=incognito,
|
||||||
)
|
)
|
||||||
if _saved_id:
|
if _saved_id:
|
||||||
@@ -967,14 +1080,22 @@ def setup_chat_routes(
|
|||||||
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||||
incognito=incognito, compare_mode=compare_mode,
|
incognito=incognito, compare_mode=compare_mode,
|
||||||
character_name=ctx.preset.character_name,
|
character_name=ctx.preset.character_name,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
_stream_set(session, status="done")
|
_stream_set(session, status="done")
|
||||||
yield chunk
|
yield chunk
|
||||||
except (asyncio.CancelledError, GeneratorExit):
|
except (asyncio.CancelledError, GeneratorExit):
|
||||||
if full_response:
|
if full_response:
|
||||||
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
||||||
_stopped_content, _stopped_md = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
_stopped_content, _stopped_md = clean_thinking_for_save(
|
||||||
|
full_response,
|
||||||
|
{
|
||||||
|
"stopped": True,
|
||||||
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
|
},
|
||||||
|
)
|
||||||
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
||||||
if not incognito:
|
if not incognito:
|
||||||
session_manager.save_sessions()
|
session_manager.save_sessions()
|
||||||
@@ -986,6 +1107,8 @@ def setup_chat_routes(
|
|||||||
_agent_rounds = 0
|
_agent_rounds = 0
|
||||||
_agent_tool_calls = 0
|
_agent_tool_calls = 0
|
||||||
_answered_by = None # set if the selected model failed and a fallback answered
|
_answered_by = None # set if the selected model failed and a fallback answered
|
||||||
|
_requested_model = sess.model
|
||||||
|
_actual_model = None
|
||||||
try:
|
try:
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
||||||
@@ -1012,9 +1135,12 @@ def setup_chat_routes(
|
|||||||
active_document=active_doc,
|
active_document=active_doc,
|
||||||
session_id=session,
|
session_id=session,
|
||||||
disabled_tools=disabled_tools if disabled_tools else None,
|
disabled_tools=disabled_tools if disabled_tools else None,
|
||||||
|
tool_policy=tool_policy,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
fallbacks=_fallback_candidates,
|
fallbacks=_fallback_candidates,
|
||||||
workspace=workspace or None,
|
workspace=workspace or None,
|
||||||
|
plan_mode=plan_mode,
|
||||||
|
approved_plan=approved_plan or None,
|
||||||
):
|
):
|
||||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||||
try:
|
try:
|
||||||
@@ -1035,6 +1161,8 @@ def setup_chat_routes(
|
|||||||
"doc_stream_open", "doc_stream_delta",
|
"doc_stream_open", "doc_stream_delta",
|
||||||
"doc_update", "doc_suggestions", "ui_control",
|
"doc_update", "doc_suggestions", "ui_control",
|
||||||
"rounds_exhausted",
|
"rounds_exhausted",
|
||||||
|
"ask_user",
|
||||||
|
"plan_update",
|
||||||
):
|
):
|
||||||
if data.get("type") == "agent_step":
|
if data.get("type") == "agent_step":
|
||||||
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
||||||
@@ -1047,10 +1175,18 @@ def setup_chat_routes(
|
|||||||
# model so metrics reflect it, not the masked
|
# model so metrics reflect it, not the masked
|
||||||
# selected model.
|
# selected model.
|
||||||
_answered_by = data.get("answered_by") or _answered_by
|
_answered_by = data.get("answered_by") or _answered_by
|
||||||
|
_actual_model = _actual_model or _answered_by
|
||||||
|
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.get("type") == "model_actual":
|
||||||
|
_actual_model = data.get("model") or _actual_model
|
||||||
|
data["requested_model"] = _requested_model
|
||||||
|
yield f'data: {json.dumps(data)}\n\n'
|
||||||
elif data.get("type") == "metrics":
|
elif data.get("type") == "metrics":
|
||||||
last_metrics = data.get("data", {})
|
last_metrics = data.get("data", {})
|
||||||
last_metrics["model"] = _answered_by or sess.model
|
_reported_model = last_metrics.get("model")
|
||||||
|
last_metrics["requested_model"] = last_metrics.get("requested_model") or _requested_model
|
||||||
|
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -1078,6 +1214,7 @@ def setup_chat_routes(
|
|||||||
skills_manager=skills_manager,
|
skills_manager=skills_manager,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
extract_skills=user_requested_agent,
|
extract_skills=user_requested_agent,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
_stream_set(session, status="done")
|
_stream_set(session, status="done")
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -1091,7 +1228,14 @@ def setup_chat_routes(
|
|||||||
try:
|
try:
|
||||||
if full_response:
|
if full_response:
|
||||||
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
||||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
_stopped_content2, _stopped_md2 = clean_thinking_for_save(
|
||||||
|
full_response,
|
||||||
|
{
|
||||||
|
"stopped": True,
|
||||||
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
|
},
|
||||||
|
)
|
||||||
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
||||||
if not incognito:
|
if not incognito:
|
||||||
session_manager.save_sessions()
|
session_manager.save_sessions()
|
||||||
@@ -1110,11 +1254,30 @@ def setup_chat_routes(
|
|||||||
finally:
|
finally:
|
||||||
_active_streams.pop(session, None)
|
_active_streams.pop(session, None)
|
||||||
|
|
||||||
# Run the stream as a DETACHED background task so it survives the client
|
# Compare panes are short-lived, single-shot generations whose sessions
|
||||||
# closing the tab / navigating away (true terminal-agent behavior). The
|
# exist only to drive that one pane — there's nothing to "resume" and
|
||||||
# SSE response just subscribes (replay buffered output + live); dropping
|
# the user expects the pane's Stop button (which aborts the fetch,
|
||||||
# the SSE only removes a subscriber — the run keeps going and saves the
|
# closing this SSE) to promptly cancel the upstream LLM call. Detaching
|
||||||
# assistant message on completion regardless. Reconnect via /api/chat/resume.
|
# them would keep burning upstream tokens/compute after the pane is
|
||||||
|
# stopped or the comparison is abandoned, and would surface a stale
|
||||||
|
# "still streaming" /resume target for a session nobody will revisit.
|
||||||
|
#
|
||||||
|
# So: stream them directly (no agent_runs wrapping). Starlette cancels
|
||||||
|
# the underlying async generator (raising CancelledError/GeneratorExit
|
||||||
|
# inside it) as soon as it notices the client disconnected — which the
|
||||||
|
# mode-specific except blocks above already handle by saving the
|
||||||
|
# partial response exactly once. This stops the upstream call promptly
|
||||||
|
# without waiting on the next streamed chunk.
|
||||||
|
#
|
||||||
|
# Normal chat/agent streams keep the DETACHED behavior below: they
|
||||||
|
# survive the client closing the tab / navigating away (true
|
||||||
|
# terminal-agent semantics). The SSE response just subscribes (replay
|
||||||
|
# buffered output + live); dropping the SSE only removes a subscriber —
|
||||||
|
# the run keeps going and saves the assistant message on completion
|
||||||
|
# regardless. Reconnect via /api/chat/resume.
|
||||||
|
if compare_mode:
|
||||||
|
return StreamingResponse(_safe_stream(), media_type="text/event-stream")
|
||||||
|
|
||||||
agent_runs.start(session, _safe_stream())
|
agent_runs.start(session, _safe_stream())
|
||||||
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
||||||
|
|
||||||
@@ -1185,45 +1348,16 @@ def setup_chat_routes(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
_user = get_current_user(request)
|
_user = get_current_user(request)
|
||||||
query_term = q.strip()
|
return [
|
||||||
db = SessionLocal()
|
result.to_dict()
|
||||||
try:
|
for result in search_session_messages(
|
||||||
base_q = (
|
q,
|
||||||
db.query(DBChatMessage, DBSession.name)
|
limit=limit,
|
||||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
owner=_user,
|
||||||
.filter(
|
restrict_owner=_user is not None,
|
||||||
DBSession.archived == False,
|
include_legacy_owner=False,
|
||||||
DBChatMessage.content.ilike(f"%{query_term}%"),
|
|
||||||
DBChatMessage.role.in_(["user", "assistant"]),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if _user:
|
]
|
||||||
base_q = base_q.filter(DBSession.owner == _user)
|
|
||||||
rows = base_q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for msg, session_name in rows:
|
|
||||||
content = msg.content or ""
|
|
||||||
lower_content = content.lower()
|
|
||||||
idx = lower_content.find(query_term.lower())
|
|
||||||
if idx == -1:
|
|
||||||
snippet = content[:120]
|
|
||||||
else:
|
|
||||||
start = max(0, idx - 50)
|
|
||||||
end = min(len(content), idx + len(query_term) + 50)
|
|
||||||
snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"session_id": msg.session_id,
|
|
||||||
"session_name": session_name or "Untitled",
|
|
||||||
"role": msg.role,
|
|
||||||
"content_snippet": snippet,
|
|
||||||
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
||||||
|
|||||||
@@ -0,0 +1,170 @@
|
|||||||
|
"""ChatGPT Subscription device-flow setup routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
|
from src.auth_helpers import get_current_user
|
||||||
|
from src import chatgpt_subscription
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
|
|
||||||
|
|
||||||
|
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
|
||||||
|
access_token = tokens.get("access_token")
|
||||||
|
refresh_token = tokens.get("refresh_token")
|
||||||
|
if not access_token or not refresh_token:
|
||||||
|
raise ValueError("ChatGPT token response was missing access_token or refresh_token")
|
||||||
|
|
||||||
|
base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
|
||||||
|
models = chatgpt_subscription.fetch_available_models(access_token)
|
||||||
|
if not models:
|
||||||
|
raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
auth = (
|
||||||
|
db.query(ProviderAuthSession)
|
||||||
|
.filter(
|
||||||
|
ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
ProviderAuthSession.owner == owner,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if auth is None:
|
||||||
|
auth = ProviderAuthSession(
|
||||||
|
id=str(uuid.uuid4())[:8],
|
||||||
|
provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
owner=owner,
|
||||||
|
label="ChatGPT Subscription",
|
||||||
|
base_url=base,
|
||||||
|
auth_mode="chatgpt",
|
||||||
|
)
|
||||||
|
db.add(auth)
|
||||||
|
auth.base_url = base
|
||||||
|
auth.access_token = access_token
|
||||||
|
auth.refresh_token = refresh_token
|
||||||
|
auth.last_refresh = utcnow_naive()
|
||||||
|
auth.auth_mode = "chatgpt"
|
||||||
|
|
||||||
|
ep = (
|
||||||
|
db.query(ModelEndpoint)
|
||||||
|
.filter(
|
||||||
|
ModelEndpoint.base_url == base,
|
||||||
|
ModelEndpoint.provider_auth_id == auth.id,
|
||||||
|
ModelEndpoint.owner == owner,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if ep is None:
|
||||||
|
ep = ModelEndpoint(
|
||||||
|
id=str(uuid.uuid4())[:8],
|
||||||
|
name="ChatGPT Subscription",
|
||||||
|
base_url=base,
|
||||||
|
model_type="llm",
|
||||||
|
endpoint_kind="api",
|
||||||
|
owner=owner,
|
||||||
|
)
|
||||||
|
db.add(ep)
|
||||||
|
ep.name = "ChatGPT Subscription"
|
||||||
|
ep.base_url = base
|
||||||
|
ep.api_key = None
|
||||||
|
ep.provider_auth_id = auth.id
|
||||||
|
ep.is_enabled = True
|
||||||
|
ep.supports_tools = False
|
||||||
|
ep.model_type = "llm"
|
||||||
|
ep.endpoint_kind = "api"
|
||||||
|
ep.model_refresh_mode = "manual"
|
||||||
|
ep.cached_models = json.dumps(models)
|
||||||
|
db.commit()
|
||||||
|
result = {
|
||||||
|
"id": ep.id,
|
||||||
|
"name": ep.name,
|
||||||
|
"base_url": ep.base_url,
|
||||||
|
"models": models,
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from routes.model_routes import _invalidate_models_cache
|
||||||
|
|
||||||
|
_invalidate_models_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
|
||||||
|
try:
|
||||||
|
data = chatgpt_subscription.request_device_code()
|
||||||
|
except Exception as exc:
|
||||||
|
raise chatgpt_subscription.to_http_exception(exc)
|
||||||
|
|
||||||
|
device_auth_id = data.get("device_auth_id")
|
||||||
|
user_code = data.get("user_code")
|
||||||
|
if not device_auth_id or not user_code:
|
||||||
|
raise HTTPException(502, "ChatGPT did not return a complete device code")
|
||||||
|
verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"
|
||||||
|
return DeviceFlowStart(
|
||||||
|
pending={
|
||||||
|
"device_auth_id": device_auth_id,
|
||||||
|
"user_code": user_code,
|
||||||
|
"owner": get_current_user(request) or None,
|
||||||
|
},
|
||||||
|
response={
|
||||||
|
"user_code": user_code,
|
||||||
|
"verification_uri": verification_uri,
|
||||||
|
},
|
||||||
|
interval=int(data.get("interval") or 5),
|
||||||
|
expires_in=int(data.get("expires_in") or 900),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||||
|
try:
|
||||||
|
data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("ChatGPT device poll failed: %s", exc)
|
||||||
|
return DeviceFlowPoll.pending(str(exc))
|
||||||
|
|
||||||
|
authorization_code = data.get("authorization_code")
|
||||||
|
code_verifier = data.get("code_verifier")
|
||||||
|
if authorization_code and code_verifier:
|
||||||
|
try:
|
||||||
|
tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
|
||||||
|
result = _provision_endpoint(tokens, pending["owner"])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("ChatGPT Subscription endpoint provisioning failed")
|
||||||
|
raise chatgpt_subscription.to_http_exception(exc)
|
||||||
|
return DeviceFlowPoll.authorized(result)
|
||||||
|
|
||||||
|
err = data.get("error") or data.get("status")
|
||||||
|
if err in ("authorization_pending", "pending", None):
|
||||||
|
return DeviceFlowPoll.pending()
|
||||||
|
if err == "slow_down":
|
||||||
|
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||||
|
if err in ("expired_token", "access_denied", "denied"):
|
||||||
|
return DeviceFlowPoll.failed(err)
|
||||||
|
return DeviceFlowPoll.pending(err or "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_chatgpt_subscription_routes():
|
||||||
|
return create_device_flow_router(
|
||||||
|
prefix="/api/chatgpt-subscription",
|
||||||
|
tags=["chatgpt-subscription"],
|
||||||
|
store=_DEVICE_FLOW_STORE,
|
||||||
|
start_flow=_start_device_flow,
|
||||||
|
poll_flow=_poll_device_flow,
|
||||||
|
)
|
||||||
+388
-3
@@ -15,10 +15,13 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_authenticated_request, require_user
|
||||||
from src.tool_implementations import do_manage_notes
|
from src.tool_implementations import do_manage_notes
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
|
|
||||||
|
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
|
||||||
|
COOKBOOK_LAUNCH_SCOPES = {"cookbook:launch"}
|
||||||
TODO_READ_SCOPES = {"todos:read", "todos:write"}
|
TODO_READ_SCOPES = {"todos:read", "todos:write"}
|
||||||
TODO_WRITE_SCOPES = {"todos:write"}
|
TODO_WRITE_SCOPES = {"todos:write"}
|
||||||
EMAIL_READ_SCOPES = {"email:read", "email:draft", "email:send"}
|
EMAIL_READ_SCOPES = {"email:read", "email:draft", "email:send"}
|
||||||
@@ -39,7 +42,9 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
|||||||
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
||||||
Restores the original value when done. Works for sync and async handlers."""
|
Restores the original value when done. Works for sync and async handlers."""
|
||||||
orig = getattr(request.state, "current_user", None)
|
orig = getattr(request.state, "current_user", None)
|
||||||
|
orig_api_token = getattr(request.state, "api_token", None)
|
||||||
request.state.current_user = owner
|
request.state.current_user = owner
|
||||||
|
request.state.api_token = False
|
||||||
try:
|
try:
|
||||||
result = fn(*args, **kwargs)
|
result = fn(*args, **kwargs)
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
@@ -47,6 +52,13 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
|||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
request.state.current_user = orig
|
request.state.current_user = orig
|
||||||
|
if orig_api_token is None:
|
||||||
|
try:
|
||||||
|
delattr(request.state, "api_token")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
request.state.api_token = orig_api_token
|
||||||
|
|
||||||
|
|
||||||
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
||||||
@@ -130,6 +142,11 @@ def setup_codex_routes(
|
|||||||
"actions": ["library", "read", "create", "delete"],
|
"actions": ["library", "read", "create", "delete"],
|
||||||
"available": documents_library_endpoint is not None,
|
"available": documents_library_endpoint is not None,
|
||||||
},
|
},
|
||||||
|
"cookbook": {
|
||||||
|
"read": scoped(COOKBOOK_READ_SCOPES),
|
||||||
|
"launch": scoped(COOKBOOK_LAUNCH_SCOPES),
|
||||||
|
"actions": ["tasks", "servers", "output", "serve", "stop"],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"safety": {
|
"safety": {
|
||||||
"email_send_requires_confirmation": True,
|
"email_send_requires_confirmation": True,
|
||||||
@@ -139,7 +156,7 @@ def setup_codex_routes(
|
|||||||
|
|
||||||
@router.get("/plugin.zip")
|
@router.get("/plugin.zip")
|
||||||
def plugin_zip(request: Request):
|
def plugin_zip(request: Request):
|
||||||
require_user(request)
|
require_authenticated_request(request)
|
||||||
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
||||||
if not root.exists():
|
if not root.exists():
|
||||||
raise HTTPException(404, "Codex plugin bundle not found")
|
raise HTTPException(404, "Codex plugin bundle not found")
|
||||||
@@ -373,6 +390,374 @@ def setup_codex_routes(
|
|||||||
raise HTTPException(400, f"Invalid document payload: {exc}")
|
raise HTTPException(400, f"Invalid document payload: {exc}")
|
||||||
return await _as_owner(request, owner, documents_create_endpoint, request, req)
|
return await _as_owner(request, owner, documents_create_endpoint, request, req)
|
||||||
|
|
||||||
|
# ── Cookbook surface ──
|
||||||
|
# Lets the agent run the same launch / monitor / kill loop the user
|
||||||
|
# would do by hand in the Cookbook UI: read the current task list +
|
||||||
|
# tmux output, launch a serve task, stop one. Two scopes:
|
||||||
|
# cookbook:read — list tasks + tail output + list servers
|
||||||
|
# cookbook:launch — also start/stop serves (host shell exec)
|
||||||
|
# `cookbook:launch` is genuinely powerful: /api/model/serve runs SSH'd
|
||||||
|
# commands on the user's hosts. The existing _validate_serve_cmd
|
||||||
|
# allowlist (vllm/python3/sglang/llama-server/etc., no shell metachars)
|
||||||
|
# keeps the agent inside the same sandbox the UI uses.
|
||||||
|
|
||||||
|
async def _run_shell(cmd: str, timeout: float = 15.0) -> dict:
|
||||||
|
"""Run a shell command, return {exit_code, stdout, stderr}."""
|
||||||
|
import asyncio as _asyncio
|
||||||
|
try:
|
||||||
|
proc = await _asyncio.create_subprocess_shell(
|
||||||
|
cmd,
|
||||||
|
stdout=_asyncio.subprocess.PIPE,
|
||||||
|
stderr=_asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
stdout_b, stderr_b = await _asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||||
|
except _asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
return {"exit_code": -1, "stdout": "", "stderr": "timed out"}
|
||||||
|
return {
|
||||||
|
"exit_code": proc.returncode,
|
||||||
|
"stdout": stdout_b.decode(errors="replace"),
|
||||||
|
"stderr": stderr_b.decode(errors="replace"),
|
||||||
|
}
|
||||||
|
except Exception as exc:
|
||||||
|
return {"exit_code": -1, "stdout": "", "stderr": str(exc)}
|
||||||
|
|
||||||
|
def _read_cookbook_state() -> dict:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
import json as _json
|
||||||
|
p = _Path(COOKBOOK_STATE_FILE)
|
||||||
|
if not p.exists():
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
return _json.loads(p.read_text(encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _redact_task(t: dict) -> dict:
|
||||||
|
"""Strip secrets before returning to the agent."""
|
||||||
|
clean = {k: v for k, v in t.items() if k not in ("hf_token", "_secrets")}
|
||||||
|
if isinstance(clean.get("payload"), dict):
|
||||||
|
pl = clean["payload"]
|
||||||
|
clean["payload"] = {k: v for k, v in pl.items()
|
||||||
|
if k not in ("hf_token", "_secrets")}
|
||||||
|
return clean
|
||||||
|
|
||||||
|
@router.get("/cookbook/tasks")
|
||||||
|
async def codex_cookbook_tasks(request: Request):
|
||||||
|
_scope_owner(request, COOKBOOK_READ_SCOPES)
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
tasks = state.get("tasks") or []
|
||||||
|
return {"tasks": [_redact_task(t) for t in tasks]}
|
||||||
|
|
||||||
|
@router.get("/cookbook/servers")
|
||||||
|
async def codex_cookbook_servers(request: Request):
|
||||||
|
_scope_owner(request, COOKBOOK_READ_SCOPES)
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
servers = state.get("env", {}).get("servers") or []
|
||||||
|
# Strip ssh creds / passwords; keep only what's needed to pick a host.
|
||||||
|
cleaned = []
|
||||||
|
for s in servers:
|
||||||
|
cleaned.append({
|
||||||
|
"name": s.get("name"),
|
||||||
|
"host": s.get("host"),
|
||||||
|
"port": s.get("port"),
|
||||||
|
"env": s.get("env"),
|
||||||
|
"envPath": s.get("envPath"),
|
||||||
|
"platform": s.get("platform"),
|
||||||
|
"modelDirs": s.get("modelDirs"),
|
||||||
|
})
|
||||||
|
return {"servers": cleaned}
|
||||||
|
|
||||||
|
@router.get("/cookbook/output/{session_id}")
|
||||||
|
async def codex_cookbook_output(request: Request, session_id: str, tail: int = 400):
|
||||||
|
_scope_owner(request, COOKBOOK_READ_SCOPES)
|
||||||
|
# Defensive: session_id must be the tmux-style id we issue
|
||||||
|
# (`serve-XXXX` / `cookbook-XXXX` / `queue-XXXX`); anything else
|
||||||
|
# would let the agent run arbitrary `tmux capture-pane` targets.
|
||||||
|
import re as _re
|
||||||
|
if not _re.fullmatch(r"[a-zA-Z0-9_-]+", session_id):
|
||||||
|
raise HTTPException(400, "Invalid session id")
|
||||||
|
tail = max(20, min(int(tail or 400), 4000))
|
||||||
|
# Resolve the task's host (if any) from cookbook state so we can
|
||||||
|
# ssh to the right box, exactly as the UI does in _reconnectTask.
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
tasks = state.get("tasks") or []
|
||||||
|
task = next((t for t in tasks if t.get("sessionId") == session_id), None)
|
||||||
|
if task is None:
|
||||||
|
raise HTTPException(404, "task not found")
|
||||||
|
host = (task.get("remoteHost") or "").strip()
|
||||||
|
ssh_port = (task.get("sshPort") or "").strip()
|
||||||
|
# Prefer the persisted log file over the tmux pane. The pane gets
|
||||||
|
# overwritten by the post-crash neofetch banner + bash prompt the
|
||||||
|
# moment vllm exits; the log file is the raw stdout/stderr and
|
||||||
|
# survives unchanged. Falls back to pane for older tasks predating
|
||||||
|
# the tee-to-log runner change.
|
||||||
|
log_path = f"/tmp/odysseus-tmux/{session_id}.log"
|
||||||
|
inner = (
|
||||||
|
f"if [ -s {log_path} ]; then tail -n {tail} {log_path}; "
|
||||||
|
f"else tmux capture-pane -t {session_id} -p -S -{tail}; fi"
|
||||||
|
)
|
||||||
|
if host:
|
||||||
|
port_flag = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||||
|
import shlex
|
||||||
|
cmd = f"ssh {port_flag}{host} {shlex.quote(inner)}"
|
||||||
|
else:
|
||||||
|
cmd = inner
|
||||||
|
result = await _run_shell(cmd, timeout=15)
|
||||||
|
return {
|
||||||
|
"session_id": session_id,
|
||||||
|
"host": host or "local",
|
||||||
|
"exit_code": result.get("exit_code"),
|
||||||
|
"output": result.get("stdout", ""),
|
||||||
|
"task": _redact_task(task),
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.post("/cookbook/serve")
|
||||||
|
async def codex_cookbook_serve(request: Request, body: dict[str, Any] = Body(default_factory=dict)):
|
||||||
|
_scope_owner(request, COOKBOOK_LAUNCH_SCOPES)
|
||||||
|
# Wraps /api/model/serve with the SAME validation the UI uses.
|
||||||
|
# _validate_serve_cmd (called inside model_serve) rejects shell
|
||||||
|
# metachars and requires the leading binary to be in the
|
||||||
|
# cookbook allowlist (vllm / python3 / sglang / llama-server / ...).
|
||||||
|
from routes.cookbook_helpers import ServeRequest
|
||||||
|
# Accept friendly aliases agents naturally reach for. Without these,
|
||||||
|
# passing `host` silently maps to nothing and the serve runs LOCAL
|
||||||
|
# instead of on the intended remote — exactly the bug an agent
|
||||||
|
# would never debug on its own.
|
||||||
|
norm = dict(body or {})
|
||||||
|
if "host" in norm and "remote_host" not in norm:
|
||||||
|
norm["remote_host"] = norm.pop("host")
|
||||||
|
if "model" in norm and "repo_id" not in norm:
|
||||||
|
norm["repo_id"] = norm.pop("model")
|
||||||
|
if "ssh_port" not in norm and "port" in norm and (str(norm.get("port") or "").isdigit() and int(norm["port"]) >= 1000):
|
||||||
|
# Heuristic: if `port` looks like an SSH port (≥1000) and there's
|
||||||
|
# no explicit ssh_port, treat it as such. UI ports (8000, 8001,
|
||||||
|
# 30000) belong inside the cmd string, not here.
|
||||||
|
pass # leave as-is — user's `port` here is ambiguous; skip remap.
|
||||||
|
try:
|
||||||
|
req = ServeRequest(**norm)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(400, f"Invalid serve payload: {exc}")
|
||||||
|
serve_endpoint = _find_endpoint(None, "POST", "/api/model/serve")
|
||||||
|
# Fall back to importing from the cookbook router registered on app.
|
||||||
|
if serve_endpoint is None:
|
||||||
|
from fastapi import FastAPI
|
||||||
|
app: FastAPI = request.app
|
||||||
|
for route in app.routes:
|
||||||
|
if getattr(route, "path", None) == "/api/model/serve" and "POST" in getattr(route, "methods", set()):
|
||||||
|
serve_endpoint = route.endpoint
|
||||||
|
break
|
||||||
|
if serve_endpoint is None:
|
||||||
|
raise HTTPException(503, "model serve endpoint unavailable")
|
||||||
|
return await serve_endpoint(request, req)
|
||||||
|
|
||||||
|
@router.post("/cookbook/stop/{session_id}")
|
||||||
|
async def codex_cookbook_stop(request: Request, session_id: str):
|
||||||
|
_scope_owner(request, COOKBOOK_LAUNCH_SCOPES)
|
||||||
|
import re as _re
|
||||||
|
if not _re.fullmatch(r"[a-zA-Z0-9_-]+", session_id):
|
||||||
|
raise HTTPException(400, "Invalid session id")
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
tasks = state.get("tasks") or []
|
||||||
|
task = next((t for t in tasks if t.get("sessionId") == session_id), None)
|
||||||
|
host = ((task or {}).get("remoteHost") or "").strip()
|
||||||
|
ssh_port = ((task or {}).get("sshPort") or "").strip()
|
||||||
|
if host:
|
||||||
|
port_flag = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||||
|
cmd = f"ssh {port_flag}{host} \"tmux kill-session -t {session_id}\""
|
||||||
|
else:
|
||||||
|
cmd = f"tmux kill-session -t {session_id}"
|
||||||
|
result = await _run_shell(cmd, timeout=10)
|
||||||
|
return {"session_id": session_id, "exit_code": result.get("exit_code"), "host": host or "local"}
|
||||||
|
|
||||||
|
@router.get("/cookbook/cached")
|
||||||
|
async def codex_cookbook_cached(request: Request, host: str | None = None):
|
||||||
|
"""List cached models on a configured server (or local if host is omitted).
|
||||||
|
Mirrors `list_cached_models` from the chat agent so external agents have
|
||||||
|
the same inventory view before deciding what to serve/download."""
|
||||||
|
_scope_owner(request, COOKBOOK_READ_SCOPES)
|
||||||
|
# Hit /api/model/cached internally, with the same modelDirs the chat
|
||||||
|
# agent's list_cached_models would resolve from cookbook state.
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
env = state.get("env") if isinstance(state, dict) else {}
|
||||||
|
servers = (env.get("servers") if isinstance(env, dict) else None) or []
|
||||||
|
HF_DEFAULTS = {"~/.cache/huggingface/hub", "~/.cache/huggingface"}
|
||||||
|
def _dirs_for(srv: dict) -> str:
|
||||||
|
mds = srv.get("modelDirs") if isinstance(srv, dict) else None
|
||||||
|
if isinstance(mds, list):
|
||||||
|
extras = [d for d in mds if isinstance(d, str) and d.strip() and d.strip() not in HF_DEFAULTS]
|
||||||
|
return ",".join(extras)
|
||||||
|
if isinstance(mds, str) and mds.strip() not in HF_DEFAULTS:
|
||||||
|
return mds
|
||||||
|
return ""
|
||||||
|
# Resolve friendly host name → real host (matches list_cached_models flow).
|
||||||
|
resolved_host = host or ""
|
||||||
|
srv: dict[str, Any] = {}
|
||||||
|
if host:
|
||||||
|
srv = next(
|
||||||
|
(s for s in servers if isinstance(s, dict)
|
||||||
|
and (s.get("name") == host or s.get("host") == host)),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
if srv and srv.get("host"):
|
||||||
|
resolved_host = srv["host"]
|
||||||
|
else:
|
||||||
|
srv = next((s for s in servers if isinstance(s, dict) and not (s.get("host") or "").strip()), {})
|
||||||
|
params: dict[str, str] = {}
|
||||||
|
if resolved_host:
|
||||||
|
params["host"] = resolved_host
|
||||||
|
md = _dirs_for(srv)
|
||||||
|
if md:
|
||||||
|
params["model_dir"] = md
|
||||||
|
if srv.get("port"):
|
||||||
|
params["ssh_port"] = str(srv["port"])
|
||||||
|
if srv.get("platform"):
|
||||||
|
params["platform"] = srv["platform"]
|
||||||
|
cached_endpoint = _find_endpoint(None, "GET", "/api/model/cached")
|
||||||
|
if cached_endpoint is None:
|
||||||
|
from fastapi import FastAPI
|
||||||
|
app: FastAPI = request.app
|
||||||
|
for route in app.routes:
|
||||||
|
if getattr(route, "path", None) == "/api/model/cached" and "GET" in getattr(route, "methods", set()):
|
||||||
|
cached_endpoint = route.endpoint
|
||||||
|
break
|
||||||
|
if cached_endpoint is None:
|
||||||
|
raise HTTPException(503, "model cached endpoint unavailable")
|
||||||
|
# The endpoint reads host/model_dir/ssh_port/platform as kwargs.
|
||||||
|
return await cached_endpoint(
|
||||||
|
request,
|
||||||
|
host=params.get("host") or None,
|
||||||
|
model_dir=params.get("model_dir") or None,
|
||||||
|
ssh_port=params.get("ssh_port") or None,
|
||||||
|
platform=params.get("platform") or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/cookbook/presets")
|
||||||
|
async def codex_cookbook_presets(request: Request):
|
||||||
|
"""List saved serve presets (model + host + port + launch cmd).
|
||||||
|
Counterpart to `list_serve_presets`. Use BEFORE composing a `serve`
|
||||||
|
body — the user's saved preset usually has the working cmd already."""
|
||||||
|
_scope_owner(request, COOKBOOK_READ_SCOPES)
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
presets = state.get("presets") or []
|
||||||
|
out = []
|
||||||
|
for p in presets:
|
||||||
|
if not isinstance(p, dict):
|
||||||
|
continue
|
||||||
|
out.append({
|
||||||
|
"name": p.get("name"),
|
||||||
|
"model": p.get("model") or p.get("modelId"),
|
||||||
|
"host": p.get("host") or p.get("remoteHost"),
|
||||||
|
"port": p.get("port"),
|
||||||
|
"cmd": p.get("cmd"),
|
||||||
|
})
|
||||||
|
return {"presets": out, "default_host": (state.get("env") or {}).get("defaultServer", "")}
|
||||||
|
|
||||||
|
@router.post("/cookbook/preset/{name}")
|
||||||
|
async def codex_cookbook_serve_preset(request: Request, name: str):
|
||||||
|
"""Launch a saved preset by name. Reuses the working cmd + host the
|
||||||
|
user already saved, avoiding the cmd-allowlist trial-and-error loop."""
|
||||||
|
_scope_owner(request, COOKBOOK_LAUNCH_SCOPES)
|
||||||
|
import re as _re
|
||||||
|
if not _re.fullmatch(r"[A-Za-z0-9 _.:@\-]+", name):
|
||||||
|
raise HTTPException(400, "Invalid preset name")
|
||||||
|
state = _read_cookbook_state()
|
||||||
|
presets = state.get("presets") or []
|
||||||
|
lname = name.lower().strip()
|
||||||
|
chosen = next(
|
||||||
|
(p for p in presets if isinstance(p, dict) and (p.get("name") or "").lower() == lname),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if chosen is None:
|
||||||
|
chosen = next(
|
||||||
|
(p for p in presets if isinstance(p, dict) and lname in (p.get("name") or "").lower()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if chosen is None:
|
||||||
|
raise HTTPException(404, f"No preset matching {name!r}")
|
||||||
|
repo_id = chosen.get("model") or chosen.get("modelId") or ""
|
||||||
|
cmd = (chosen.get("cmd") or "").strip()
|
||||||
|
host = chosen.get("host") or chosen.get("remoteHost") or ""
|
||||||
|
if not repo_id or not cmd or cmd.startswith("(adopted"):
|
||||||
|
raise HTTPException(400, f"Preset {chosen.get('name')!r} has no launchable cmd "
|
||||||
|
"(adopted from external launch). Use POST /cookbook/serve "
|
||||||
|
"with the actual cmd instead.")
|
||||||
|
# Reuse the serve handler we already validated.
|
||||||
|
from routes.cookbook_helpers import ServeRequest
|
||||||
|
body = {"repo_id": repo_id, "cmd": cmd}
|
||||||
|
if host:
|
||||||
|
body["remote_host"] = host
|
||||||
|
try:
|
||||||
|
req = ServeRequest(**body)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(400, f"Preset payload invalid: {exc}")
|
||||||
|
serve_endpoint = _find_endpoint(None, "POST", "/api/model/serve")
|
||||||
|
if serve_endpoint is None:
|
||||||
|
from fastapi import FastAPI
|
||||||
|
app: FastAPI = request.app
|
||||||
|
for route in app.routes:
|
||||||
|
if getattr(route, "path", None) == "/api/model/serve" and "POST" in getattr(route, "methods", set()):
|
||||||
|
serve_endpoint = route.endpoint
|
||||||
|
break
|
||||||
|
if serve_endpoint is None:
|
||||||
|
raise HTTPException(503, "model serve endpoint unavailable")
|
||||||
|
return await serve_endpoint(request, req)
|
||||||
|
|
||||||
|
@router.post("/cookbook/adopt")
|
||||||
|
async def codex_cookbook_adopt(request: Request, body: dict[str, Any] = Body(default_factory=dict)):
|
||||||
|
"""Adopt an existing tmux session (one started via raw ssh+tmux) into
|
||||||
|
cookbook tracking. Needed when serve_model rejects a cmd and the
|
||||||
|
agent falls back to direct ssh — without adoption the session is
|
||||||
|
invisible to the UI. Body: {tmux_session, model, host?, port?}."""
|
||||||
|
_scope_owner(request, COOKBOOK_LAUNCH_SCOPES)
|
||||||
|
norm = dict(body or {})
|
||||||
|
sess = (norm.get("tmux_session") or norm.get("session_id") or "").strip()
|
||||||
|
model = (norm.get("model") or norm.get("repo_id") or "").strip()
|
||||||
|
host = (norm.get("host") or norm.get("remote_host") or "").strip()
|
||||||
|
port = norm.get("port") or 8000
|
||||||
|
import re as _re
|
||||||
|
if not sess or not _re.fullmatch(r"[a-zA-Z0-9_-]+", sess):
|
||||||
|
raise HTTPException(400, "tmux_session required, [a-zA-Z0-9_-]+ only")
|
||||||
|
if not model:
|
||||||
|
raise HTTPException(400, "model required")
|
||||||
|
# Verify the tmux session exists on the target host before adopting.
|
||||||
|
import shlex
|
||||||
|
if host:
|
||||||
|
check = f"ssh {shlex.quote(host)} 'tmux has-session -t {shlex.quote(sess)}'"
|
||||||
|
else:
|
||||||
|
check = f"tmux has-session -t {shlex.quote(sess)}"
|
||||||
|
chk = await _run_shell(check, timeout=8)
|
||||||
|
if chk.get("exit_code") not in (0, None):
|
||||||
|
raise HTTPException(404, f"tmux session {sess!r} not found on {host or 'local'}")
|
||||||
|
# Write into cookbook_state.json.
|
||||||
|
import time as _t, json as _json
|
||||||
|
from core.atomic_io import atomic_write_json
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
cookbook_state_path = _Path(COOKBOOK_STATE_FILE)
|
||||||
|
try:
|
||||||
|
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
state = {}
|
||||||
|
tasks = state.setdefault("tasks", [])
|
||||||
|
if any(isinstance(t, dict) and t.get("sessionId") == sess for t in tasks):
|
||||||
|
return {"ok": True, "already_tracked": True, "session_id": sess}
|
||||||
|
tasks.append({
|
||||||
|
"id": sess, "sessionId": sess,
|
||||||
|
"name": model.split("/")[-1] if "/" in model else model,
|
||||||
|
"type": "serve", "status": "running",
|
||||||
|
"output": f"Adopted externally-launched session {sess!r} on {host or 'local'}.",
|
||||||
|
"ts": int(_t.time() * 1000),
|
||||||
|
"payload": {"repo_id": model, "remote_host": host, "_cmd": "(adopted — launched outside cookbook)", "port": int(port)},
|
||||||
|
"remoteHost": host, "sshPort": "", "platform": "linux",
|
||||||
|
"_serveReady": False, "_endpointAdded": False, "_adoptedExternally": True,
|
||||||
|
})
|
||||||
|
try:
|
||||||
|
atomic_write_json(cookbook_state_path, state)
|
||||||
|
except Exception as exc:
|
||||||
|
raise HTTPException(500, f"state write failed: {exc}")
|
||||||
|
return {"ok": True, "session_id": sess, "host": host or "local"}
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|
||||||
@@ -387,7 +772,7 @@ def setup_claude_routes() -> APIRouter:
|
|||||||
|
|
||||||
@router.get("/plugin.zip")
|
@router.get("/plugin.zip")
|
||||||
def plugin_zip(request: Request):
|
def plugin_zip(request: Request):
|
||||||
require_user(request)
|
require_authenticated_request(request)
|
||||||
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
||||||
# README.md or other bundle metadata into the user's claude config dir.
|
# README.md or other bundle metadata into the user's claude config dir.
|
||||||
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
||||||
|
|||||||
+110
-22
@@ -12,6 +12,7 @@ import logging
|
|||||||
from core.database import Comparison, SessionLocal
|
from core.database import Comparison, SessionLocal
|
||||||
from core.session_manager import SessionManager
|
from core.session_manager import SessionManager
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner):
|
|||||||
return owner_filter(q, ModelEndpoint, owner).first()
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _owned_endpoint_by_id(db, endpoint_id, owner):
|
||||||
|
"""ModelEndpoint whose id == `endpoint_id` and is VISIBLE to `owner` (their
|
||||||
|
own rows + legacy null-owner "shared" rows); None otherwise.
|
||||||
|
|
||||||
|
Preferred over _owned_endpoint_by_url for credential resolution: two visible
|
||||||
|
endpoints can share the same base_url but hold DIFFERENT api_keys (e.g. two
|
||||||
|
accounts on the same provider). A base_url-only match returns whichever row
|
||||||
|
sorts first, so it can copy the WRONG owner-scoped key into the [CMP] session.
|
||||||
|
An id pins the exact registered endpoint, so /api/compare/start prefers it and
|
||||||
|
only falls back to URL matching for legacy / admin raw-URL callers. Owner
|
||||||
|
scoping is identical to _owned_endpoint_by_url (a null/empty owner is a no-op).
|
||||||
|
"""
|
||||||
|
from core.database import ModelEndpoint
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id)
|
||||||
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
class RecordVoteRequest(BaseModel):
|
class RecordVoteRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
models: List[str]
|
models: List[str]
|
||||||
@@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
prompt: str = Form(...),
|
prompt: str = Form(...),
|
||||||
model_a: str = Form(...),
|
model_a: str = Form(...),
|
||||||
model_b: str = Form(...),
|
model_b: str = Form(...),
|
||||||
endpoint_a: str = Form(...),
|
endpoint_a: str = Form(""),
|
||||||
endpoint_b: str = Form(...),
|
endpoint_b: str = Form(""),
|
||||||
|
endpoint_a_id: str = Form(""),
|
||||||
|
endpoint_b_id: str = Form(""),
|
||||||
is_blind: str = Form("true"),
|
is_blind: str = Form("true"),
|
||||||
):
|
):
|
||||||
"""Create two ephemeral sessions and a comparison record.
|
"""Create two ephemeral sessions and a comparison record.
|
||||||
@@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
Returns the comparison ID and the two session IDs so the client
|
Returns the comparison ID and the two session IDs so the client
|
||||||
can fire two independent SSE streams to /api/chat_stream.
|
can fire two independent SSE streams to /api/chat_stream.
|
||||||
"""
|
"""
|
||||||
|
user = getattr(request.state, 'current_user', None)
|
||||||
comp_id = str(uuid.uuid4())
|
comp_id = str(uuid.uuid4())
|
||||||
sid_a = str(uuid.uuid4())
|
sid_a = str(uuid.uuid4())
|
||||||
sid_b = str(uuid.uuid4())
|
sid_b = str(uuid.uuid4())
|
||||||
user = getattr(request.state, 'current_user', None)
|
|
||||||
|
|
||||||
# Blind mapping: randomly assign left/right
|
# Blind mapping: randomly assign left/right
|
||||||
blind = str(is_blind).lower() == "true"
|
blind = str(is_blind).lower() == "true"
|
||||||
@@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
# de-anonymizing the comparison before the user votes (issue #1285).
|
# de-anonymizing the comparison before the user votes (issue #1285).
|
||||||
slot_name = {session_left: "Model A", session_right: "Model B"}
|
slot_name = {session_left: "Model A", session_right: "Model B"}
|
||||||
|
|
||||||
# Create ephemeral sessions (prefixed [CMP])
|
# SECURITY: resolve and validate BOTH endpoints before creating any
|
||||||
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
|
# session. Compare copies a registered endpoint's Authorization header
|
||||||
|
# into the [CMP] session, so validating one endpoint while creating its
|
||||||
|
# session, then rejecting the other, would leave a partial compare
|
||||||
|
# session behind with that header attached. Doing all the owner-scope
|
||||||
|
# resolution + raw-URL rejection up front means a 403 on either endpoint
|
||||||
|
# aborts the whole request with nothing created and no header copied.
|
||||||
|
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
|
||||||
|
resolved = []
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
for sid, model, endpoint, endpoint_id in [
|
||||||
|
(sid_a, model_a, endpoint_a, endpoint_a_id),
|
||||||
|
(sid_b, model_b, endpoint_b, endpoint_b_id),
|
||||||
|
]:
|
||||||
|
# Prefer an explicit endpoint id: it pins the EXACT registered
|
||||||
|
# endpoint (and its api_key), even when two endpoints visible to
|
||||||
|
# the caller share a base_url with different keys — a URL-only
|
||||||
|
# match would copy whichever row sorts first, i.e. possibly the
|
||||||
|
# wrong key. Fall back to URL resolution only for legacy / admin
|
||||||
|
# raw-URL callers that don't send an id.
|
||||||
|
eid = endpoint_id.strip() if isinstance(endpoint_id, str) else ""
|
||||||
|
if eid:
|
||||||
|
ep = _owned_endpoint_by_id(db, eid, user)
|
||||||
|
if ep is None:
|
||||||
|
# An id the caller can't see (wrong owner / deleted) must
|
||||||
|
# NOT silently fall back to a same-URL row with a different
|
||||||
|
# key — that's exactly the mix-up ids exist to prevent.
|
||||||
|
raise HTTPException(404, "Model endpoint not found")
|
||||||
|
# The id already resolved the endpoint; ignore any raw URL the
|
||||||
|
# caller also sent and dial the stored config instead.
|
||||||
|
endpoint = ep.base_url
|
||||||
|
elif not endpoint:
|
||||||
|
raise HTTPException(
|
||||||
|
422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Resolve the supplied URL to a ModelEndpoint the caller owns
|
||||||
|
# (their own rows + legacy null-owner shared rows), scoped so a
|
||||||
|
# comparison can't borrow another user's private endpoint key.
|
||||||
|
base = normalize_base(endpoint)
|
||||||
|
ep = _owned_endpoint_by_url(db, base, user)
|
||||||
|
# Reject *unregistered* raw URLs for signed-in non-admins; a
|
||||||
|
# matched registered endpoint supplies an id so the caller can
|
||||||
|
# still compare endpoints they own. Blanket-rejecting here (the
|
||||||
|
# earlier `endpoint_id=None` call) locked non-admins out of
|
||||||
|
# compare entirely, since compare resolves endpoints by URL with
|
||||||
|
# no endpoint_id. Mirrors the gallery inpaint/harmonize checks.
|
||||||
|
# Raised here (phase 1), before any session exists.
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(
|
||||||
|
request, user, str(ep.id) if ep is not None else None, endpoint
|
||||||
|
)
|
||||||
|
# Bind the [CMP] session to the RESOLVED endpoint, not the raw
|
||||||
|
# caller-supplied string. When the URL matches a registered
|
||||||
|
# endpoint visible to the caller, use that row's own normalized
|
||||||
|
# base URL (the same value owner scoping + endpoint validation
|
||||||
|
# already vetted) so the session dials exactly where the stored
|
||||||
|
# config points. The raw `endpoint` only survives for callers
|
||||||
|
# allowed to pass one — admins / single-user mode, where
|
||||||
|
# `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep`
|
||||||
|
# is None. Mirrors the registered-endpoint path in session_routes.
|
||||||
|
session_endpoint_url = (
|
||||||
|
build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint
|
||||||
|
)
|
||||||
|
# Headers come only from a matched endpoint's key; None when
|
||||||
|
# `ep` is None (raw admin URL or no match), so a comparison can
|
||||||
|
# never inherit another user's key/headers.
|
||||||
|
headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None
|
||||||
|
resolved.append((sid, model, session_endpoint_url, headers))
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Both endpoints validated — only now create the ephemeral [CMP]
|
||||||
|
# sessions and copy any resolved headers.
|
||||||
|
for sid, model, session_endpoint_url, headers in resolved:
|
||||||
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
||||||
session_manager.create_session(
|
session_manager.create_session(
|
||||||
session_id=sid,
|
session_id=sid,
|
||||||
name=name,
|
name=name,
|
||||||
endpoint_url=endpoint,
|
endpoint_url=session_endpoint_url,
|
||||||
model=model,
|
model=model,
|
||||||
rag=False,
|
rag=False,
|
||||||
owner=user,
|
owner=user,
|
||||||
)
|
)
|
||||||
# Copy API key from endpoint config
|
if headers:
|
||||||
db = SessionLocal()
|
s = session_manager.sessions.get(sid)
|
||||||
try:
|
if s:
|
||||||
from src.endpoint_resolver import build_headers, normalize_base
|
s.headers = headers
|
||||||
# Find matching endpoint by URL, scoped to the caller so a
|
|
||||||
# comparison can't borrow another user's private endpoint key.
|
|
||||||
base = normalize_base(endpoint)
|
|
||||||
ep = _owned_endpoint_by_url(db, base, user)
|
|
||||||
if ep and ep.api_key:
|
|
||||||
s = session_manager.sessions.get(sid)
|
|
||||||
if s:
|
|
||||||
s.headers = build_headers(ep.api_key, ep.base_url)
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Store comparison record
|
# Store comparison record
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
@@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model_a=model_a,
|
model_a=model_a,
|
||||||
model_b=model_b,
|
model_b=model_b,
|
||||||
endpoint_a=endpoint_a,
|
# Record the URL the session actually dials. For URL callers this
|
||||||
endpoint_b=endpoint_b,
|
# is their raw input; for id-only callers (empty endpoint_a/_b)
|
||||||
|
# fall back to the resolved endpoint URL so the column stays
|
||||||
|
# meaningful and non-null. resolved is in [a, b] order.
|
||||||
|
endpoint_a=endpoint_a or resolved[0][2],
|
||||||
|
endpoint_b=endpoint_b or resolved[1][2],
|
||||||
is_blind=blind,
|
is_blind=blind,
|
||||||
blind_mapping=json.dumps(mapping),
|
blind_mapping=json.dumps(mapping),
|
||||||
owner=user,
|
owner=user,
|
||||||
|
|||||||
+53
-18
@@ -11,20 +11,24 @@ import uuid
|
|||||||
import json
|
import json
|
||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, Query, Depends, Response
|
from urllib.parse import urljoin, urlparse, urlunparse
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, Depends, Response, HTTPException
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
from src.url_safety import check_outbound_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, SETTINGS_FILE as _SETTINGS_FILE, CONTACTS_FILE as _CONTACTS_FILE
|
||||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
|
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||||
|
LOCAL_CONTACTS_FILE = Path(_CONTACTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _load_settings():
|
def _load_settings():
|
||||||
@@ -53,6 +57,21 @@ def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
|
|||||||
return bool((cfg.get("url") or "").strip())
|
return bool((cfg.get("url") or "").strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_carddav_url(url: str) -> str:
|
||||||
|
cleaned = (url if isinstance(url, str) else "").strip().rstrip("/")
|
||||||
|
ok, reason = check_outbound_url(
|
||||||
|
cleaned,
|
||||||
|
block_private=os.getenv("CARDDAV_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise ValueError(f"Rejected CardDAV URL: {reason}")
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def _carddav_base_url(cfg: Dict) -> str:
|
||||||
|
return _validate_carddav_url(cfg.get("url") or "")
|
||||||
|
|
||||||
|
|
||||||
def _normalize_contact(contact: Dict) -> Dict:
|
def _normalize_contact(contact: Dict) -> Dict:
|
||||||
emails = []
|
emails = []
|
||||||
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
||||||
@@ -219,14 +238,18 @@ _contact_cache = {"contacts": [], "fetched_at": None}
|
|||||||
def _abs_url(href: str) -> str:
|
def _abs_url(href: str) -> str:
|
||||||
"""Combine a multistatus <href> (an absolute path like
|
"""Combine a multistatus <href> (an absolute path like
|
||||||
/user/contacts/x.vcf) with the configured CardDAV server origin so we
|
/user/contacts/x.vcf) with the configured CardDAV server origin so we
|
||||||
get a fully-qualified URL to PUT/DELETE. If href is already absolute
|
get a fully-qualified URL to PUT/DELETE. Absolute hrefs are accepted only
|
||||||
(http...), return it as-is."""
|
for the configured origin; a cross-origin href is treated as a path on the
|
||||||
from urllib.parse import urlparse, urlunparse
|
configured server so a malicious CardDAV response cannot redirect later
|
||||||
if href.startswith("http://") or href.startswith("https://"):
|
writes/deletes to cloud metadata or another host."""
|
||||||
return href
|
|
||||||
cfg = _get_carddav_config()
|
cfg = _get_carddav_config()
|
||||||
p = urlparse(cfg["url"])
|
base = _carddav_base_url(cfg)
|
||||||
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
|
base_p = urlparse(base)
|
||||||
|
joined = urljoin(base.rstrip("/") + "/", href or "")
|
||||||
|
joined_p = urlparse(joined)
|
||||||
|
if (joined_p.scheme, joined_p.netloc) != (base_p.scheme, base_p.netloc):
|
||||||
|
joined = urlunparse((base_p.scheme, base_p.netloc, joined_p.path or "/", "", joined_p.query, ""))
|
||||||
|
return _validate_carddav_url(joined)
|
||||||
|
|
||||||
|
|
||||||
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
||||||
@@ -297,6 +320,7 @@ def _fetch_contacts(force=False):
|
|||||||
return contacts
|
return contacts
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
cfg["url"] = _carddav_base_url(cfg)
|
||||||
auth = None
|
auth = None
|
||||||
if cfg["username"]:
|
if cfg["username"]:
|
||||||
auth = (cfg["username"], cfg["password"])
|
auth = (cfg["username"], cfg["password"])
|
||||||
@@ -353,8 +377,8 @@ def _create_contact(name: str, email: str) -> bool:
|
|||||||
|
|
||||||
contact_uid = str(uuid.uuid4())
|
contact_uid = str(uuid.uuid4())
|
||||||
vcard = _build_vcard(name, email, contact_uid)
|
vcard = _build_vcard(name, email, contact_uid)
|
||||||
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
|
|
||||||
try:
|
try:
|
||||||
|
url = _carddav_base_url(cfg) + "/" + contact_uid + ".vcf"
|
||||||
auth = None
|
auth = None
|
||||||
if cfg["username"]:
|
if cfg["username"]:
|
||||||
auth = (cfg["username"], cfg["password"])
|
auth = (cfg["username"], cfg["password"])
|
||||||
@@ -382,7 +406,7 @@ def _vcard_url(uid: str) -> str:
|
|||||||
escape the collection and target an arbitrary CardDAV resource."""
|
escape the collection and target an arbitrary CardDAV resource."""
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
cfg = _get_carddav_config()
|
cfg = _get_carddav_config()
|
||||||
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
return _carddav_base_url(cfg) + "/" + quote(uid, safe="") + ".vcf"
|
||||||
|
|
||||||
|
|
||||||
def _import_vcards(text: str) -> Dict:
|
def _import_vcards(text: str) -> Dict:
|
||||||
@@ -413,6 +437,11 @@ def _import_vcards(text: str) -> Dict:
|
|||||||
if imported:
|
if imported:
|
||||||
_save_local_contacts(contacts)
|
_save_local_contacts(contacts)
|
||||||
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
||||||
|
try:
|
||||||
|
base_url = _carddav_base_url(cfg)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("CardDAV import URL rejected: %s", e)
|
||||||
|
return {"imported": 0, "failed": 0, "total": 0, "error": str(e)}
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
# Split into individual cards. re.split drops the BEGIN line, so we
|
# Split into individual cards. re.split drops the BEGIN line, so we
|
||||||
# re-add it. Normalize CRLF.
|
# re-add it. Normalize CRLF.
|
||||||
@@ -441,7 +470,7 @@ def _import_vcards(text: str) -> Dict:
|
|||||||
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
||||||
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
||||||
vcard = block.replace("\n", "\r\n") + "\r\n"
|
vcard = block.replace("\n", "\r\n") + "\r\n"
|
||||||
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
url = base_url + "/" + quote(uid, safe="") + ".vcf"
|
||||||
try:
|
try:
|
||||||
r = httpx.put(
|
r = httpx.put(
|
||||||
url, data=vcard.encode("utf-8"),
|
url, data=vcard.encode("utf-8"),
|
||||||
@@ -601,8 +630,8 @@ def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -
|
|||||||
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
||||||
# Use the real resource href (handles externally-created contacts whose
|
# Use the real resource href (handles externally-created contacts whose
|
||||||
# filename != UID); falls back to the <uid>.vcf guess.
|
# filename != UID); falls back to the <uid>.vcf guess.
|
||||||
url = _resolve_resource_url(uid)
|
|
||||||
try:
|
try:
|
||||||
|
url = _resolve_resource_url(uid)
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
r = httpx.put(
|
r = httpx.put(
|
||||||
url,
|
url,
|
||||||
@@ -630,8 +659,8 @@ def _delete_contact(uid: str) -> bool:
|
|||||||
_save_local_contacts(remaining)
|
_save_local_contacts(remaining)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
url = _resolve_resource_url(uid)
|
|
||||||
try:
|
try:
|
||||||
|
url = _resolve_resource_url(uid)
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
r = httpx.delete(url, auth=auth, timeout=10)
|
r = httpx.delete(url, auth=auth, timeout=10)
|
||||||
if r.status_code in (200, 204):
|
if r.status_code in (200, 204):
|
||||||
@@ -747,7 +776,13 @@ def setup_contacts_routes():
|
|||||||
settings = _load_settings()
|
settings = _load_settings()
|
||||||
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
||||||
if key in data:
|
if key in data:
|
||||||
settings[key] = data[key]
|
if key == "carddav_url" and str(data[key] or "").strip():
|
||||||
|
try:
|
||||||
|
settings[key] = _validate_carddav_url(data[key])
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
else:
|
||||||
|
settings[key] = data[key]
|
||||||
_save_settings(settings)
|
_save_settings(settings)
|
||||||
# Force re-fetch
|
# Force re-fetch
|
||||||
_contact_cache["fetched_at"] = None
|
_contact_cache["fetched_at"] = None
|
||||||
|
|||||||
+324
-15
@@ -11,6 +11,8 @@ import shlex
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.platform_compat import _ssh_exec_argv
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -195,6 +197,20 @@ def _pip_install_attempt(pip_cmd: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_command(python_cmd: str) -> str:
|
||||||
|
"""Return a pip command for either a pip executable or a Python executable."""
|
||||||
|
cmd = python_cmd.strip()
|
||||||
|
if " -m pip" in cmd or cmd in {"pip", "pip3"}:
|
||||||
|
return python_cmd
|
||||||
|
if cmd in {"python", "python3", "python.exe"} or cmd.endswith(("/python", "/python3", "\\python.exe")):
|
||||||
|
return f"{python_cmd} -m pip"
|
||||||
|
return python_cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_break_system_packages_check(pip_cmd: str) -> str:
|
||||||
|
return f"{pip_cmd} install --help 2>/dev/null | grep -q -- --break-system-packages"
|
||||||
|
|
||||||
|
|
||||||
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
||||||
"""Build a bash pip install fallback chain that surfaces errors.
|
"""Build a bash pip install fallback chain that surfaces errors.
|
||||||
|
|
||||||
@@ -206,33 +222,44 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
|
|||||||
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
||||||
pip output appear in the Cookbook log on failure.
|
pip output appear in the Cookbook log on failure.
|
||||||
"""
|
"""
|
||||||
|
from core.platform_compat import IS_WINDOWS
|
||||||
upgrade_flag = " -U" if upgrade else ""
|
upgrade_flag = " -U" if upgrade else ""
|
||||||
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
||||||
# contains brackets that bash would treat as a glob, so it must be quoted
|
# contains brackets that bash would treat as a glob, so it must be quoted
|
||||||
# before being embedded in the install command. Plain names (e.g.
|
# before being embedded in the install command. Plain names (e.g.
|
||||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||||
pkg = shlex.quote(package)
|
pkg = shlex.quote(package)
|
||||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
# llama-cpp-python source builds are brittle on older distro pip/packaging
|
||||||
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
|
||||||
|
# this package is requested so dependency-install tasks are reliable.
|
||||||
|
if "llama-cpp-python" in package:
|
||||||
|
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
|
|
||||||
|
pip_cmd = _pip_command(python_cmd)
|
||||||
|
base = _pip_install_attempt(f"{pip_cmd} install -q{upgrade_flag} {pkg}")
|
||||||
|
user = _pip_install_attempt(f"{pip_cmd} install --user -q{upgrade_flag} {pkg}")
|
||||||
|
user_break_system = _pip_install_attempt(f"{pip_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||||
|
user_fallback = f"( {user} || {{ {_pip_break_system_packages_check(pip_cmd)} && {user_break_system}; }} )"
|
||||||
# Derive the python executable for the venv detection check.
|
# Derive the python executable for the venv detection check.
|
||||||
# Must use the same interpreter that pip belongs to; hardcoding
|
# Must use the same interpreter that pip belongs to; hardcoding
|
||||||
# python3 breaks when pip lives in a venv that only has "python".
|
# python3 breaks when pip lives in a venv that only has "python".
|
||||||
if " -m pip" in python_cmd:
|
if " -m pip" in pip_cmd:
|
||||||
python_exe = python_cmd.replace(" -m pip", "")
|
python_exe = pip_cmd.replace(" -m pip", "")
|
||||||
elif python_cmd.strip() == "pip":
|
elif pip_cmd.strip() == "pip":
|
||||||
python_exe = "python"
|
python_exe = "python"
|
||||||
elif python_cmd.strip() == "pip3":
|
elif pip_cmd.strip() == "pip3":
|
||||||
python_exe = "python3"
|
python_exe = "python3"
|
||||||
else:
|
else:
|
||||||
python_exe = "python3"
|
python_exe = "python3"
|
||||||
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
||||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv → `&&` tries
|
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv -> `&&` tries
|
||||||
# --user. When IN a venv `! venv_check` fails → `&&` skips --user and the
|
# --user. When IN a venv `! venv_check` fails -> `&&` skips --user and the
|
||||||
# group exits non-zero, propagating the base-install failure instead of
|
# group exits non-zero, propagating the base-install failure instead of
|
||||||
# masking it as success (the `|| { venv_check || … }` shape from #903
|
# masking it as success (the `|| { venv_check || … }` shape from #903
|
||||||
# swallowed the exit code because venv_check's exit-0 became the group's
|
# swallowed the exit code because venv_check's exit-0 became the group's
|
||||||
# result).
|
# result). `--break-system-packages` is only attempted when the active pip
|
||||||
return f"{base} || {{ ! {venv_check} && {user}; }}"
|
# supports it; older pip versions abort with "no such option" otherwise.
|
||||||
|
return f"{base} || {{ ! {venv_check} && {user_fallback}; }}"
|
||||||
|
|
||||||
|
|
||||||
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
||||||
@@ -263,6 +290,55 @@ def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) ->
|
|||||||
return shlex.join(stripped)
|
return shlex.join(stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_install_command_without_break_system_packages(cmd: str) -> str:
|
||||||
|
try:
|
||||||
|
parts = shlex.split(cmd)
|
||||||
|
except ValueError:
|
||||||
|
return cmd
|
||||||
|
stripped = [part for part in parts if part != "--break-system-packages"]
|
||||||
|
return shlex.join(stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_install_help_check_from_cmd(cmd: str) -> str | None:
|
||||||
|
try:
|
||||||
|
parts = shlex.split(cmd)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
install_index = parts.index("install")
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if install_index <= 0:
|
||||||
|
return None
|
||||||
|
pip_prefix = parts[:install_index]
|
||||||
|
return f"{shlex.join(pip_prefix + ['install', '--help'])} 2>/dev/null | grep -q -- --break-system-packages"
|
||||||
|
|
||||||
|
|
||||||
|
def _append_pip_install_runner_lines(runner_lines: list[str], cmd: str) -> None:
|
||||||
|
"""Append a pip install command, guarding --break-system-packages support.
|
||||||
|
|
||||||
|
The Dependencies UI may submit ``python3 -m pip install --user
|
||||||
|
--break-system-packages ...`` for non-venv installs. That flag is useful on
|
||||||
|
PEP-668-locked distros, but older pip (including Ubuntu 22.04's apt pip in
|
||||||
|
the NVIDIA CUDA base image) aborts with "no such option". Branch at runner
|
||||||
|
time so stale browser JS and remote targets are handled by the server too.
|
||||||
|
"""
|
||||||
|
if "--break-system-packages" not in (cmd or ""):
|
||||||
|
runner_lines.append(cmd)
|
||||||
|
return
|
||||||
|
help_check = _pip_install_help_check_from_cmd(cmd)
|
||||||
|
without_break = _pip_install_command_without_break_system_packages(cmd)
|
||||||
|
if not help_check or without_break == cmd:
|
||||||
|
runner_lines.append(cmd)
|
||||||
|
return
|
||||||
|
runner_lines.append(f"if {help_check}; then")
|
||||||
|
runner_lines.append(f" {cmd}")
|
||||||
|
runner_lines.append("else")
|
||||||
|
runner_lines.append(' echo "[odysseus] pip does not support --break-system-packages; installing without it."')
|
||||||
|
runner_lines.append(f" {without_break}")
|
||||||
|
runner_lines.append("fi")
|
||||||
|
|
||||||
|
|
||||||
def _user_shell_path_bootstrap() -> list[str]:
|
def _user_shell_path_bootstrap() -> list[str]:
|
||||||
return [
|
return [
|
||||||
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
||||||
@@ -271,11 +347,14 @@ def _user_shell_path_bootstrap() -> list[str]:
|
|||||||
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
||||||
'fi',
|
'fi',
|
||||||
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
||||||
|
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
|
||||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
"""Build the standalone Python scanner used by /api/model/cached.
|
||||||
|
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
|
||||||
|
"""
|
||||||
lines = [
|
lines = [
|
||||||
"import json, os, re, shutil, subprocess, urllib.request",
|
"import json, os, re, shutil, subprocess, urllib.request",
|
||||||
"models = []",
|
"models = []",
|
||||||
@@ -338,6 +417,15 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||||
" if f.name.endswith('.incomplete'): ic = True",
|
" if f.name.endswith('.incomplete'): ic = True",
|
||||||
" snap = os.path.join(cache, d, 'snapshots')",
|
" snap = os.path.join(cache, d, 'snapshots')",
|
||||||
|
" # Windows HF cache stores files directly in snapshots/; blobs/ may be empty.",
|
||||||
|
" # Fallback: scan snapshots for real files when blobs yielded nothing.",
|
||||||
|
" if sz == 0 and os.path.isdir(snap):",
|
||||||
|
" for sd in os.listdir(snap):",
|
||||||
|
" sf = os.path.join(snap, sd)",
|
||||||
|
" if not os.path.isdir(sf): continue",
|
||||||
|
" for f in os.scandir(sf):",
|
||||||
|
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||||
|
" if f.name.endswith('.incomplete'): ic = True",
|
||||||
" is_diffusion = False; gguf_files = []",
|
" is_diffusion = False; gguf_files = []",
|
||||||
" if os.path.isdir(snap):",
|
" if os.path.isdir(snap):",
|
||||||
" for sd in os.listdir(snap):",
|
" for sd in os.listdir(snap):",
|
||||||
@@ -346,6 +434,21 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
||||||
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
||||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||||
|
"def hf_cache_paths():",
|
||||||
|
" candidates = []",
|
||||||
|
" def add(p):",
|
||||||
|
" if not p: return",
|
||||||
|
" p = os.path.expanduser(p)",
|
||||||
|
" if p not in candidates: candidates.append(p)",
|
||||||
|
" add(os.environ.get('HUGGINGFACE_HUB_CACHE'))",
|
||||||
|
" hf_home = os.environ.get('HF_HOME')",
|
||||||
|
" if hf_home: add(os.path.join(hf_home, 'hub'))",
|
||||||
|
" add('~/.cache/huggingface/hub')",
|
||||||
|
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
||||||
|
" # When HOME is /root, expanduser() misses that persisted cache.",
|
||||||
|
" add('/app/.cache/huggingface/hub')",
|
||||||
|
f" add({add_hf_cache!r})" if add_hf_cache else "",
|
||||||
|
" return candidates",
|
||||||
"def scan_dir(p):",
|
"def scan_dir(p):",
|
||||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||||
" for d in sorted(os.listdir(p)):",
|
" for d in sorted(os.listdir(p)):",
|
||||||
@@ -409,7 +512,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" seen.add(name)",
|
" seen.add(name)",
|
||||||
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
||||||
" return",
|
" return",
|
||||||
"scan_hf(os.path.expanduser('~/.cache/huggingface/hub'))",
|
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
|
||||||
"scan_ollama()",
|
"scan_ollama()",
|
||||||
"scan_ollama_api()",
|
"scan_ollama_api()",
|
||||||
]
|
]
|
||||||
@@ -525,6 +628,7 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
|||||||
# Backticks and raw newlines are never legitimate here.
|
# Backticks and raw newlines are never legitimate here.
|
||||||
if any(c in v for c in ("`", "\n", "\r")):
|
if any(c in v for c in ("`", "\n", "\r")):
|
||||||
raise HTTPException(400, "Invalid characters in cmd")
|
raise HTTPException(400, "Invalid characters in cmd")
|
||||||
|
|
||||||
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
||||||
m = _GGUF_PRELUDE_RE.match(v)
|
m = _GGUF_PRELUDE_RE.match(v)
|
||||||
if m:
|
if m:
|
||||||
@@ -533,9 +637,19 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
|||||||
for part in rest.split("||"):
|
for part in rest.split("||"):
|
||||||
_check_serve_binary(part.strip())
|
_check_serve_binary(part.strip())
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Otherwise: a single invocation — no shell metacharacters allowed.
|
# Otherwise: a single invocation — no shell metacharacters allowed.
|
||||||
|
# Temporarily replace safe $(printf %s ...) expressions with a placeholder
|
||||||
|
# to avoid triggering the metacharacter/command-injection checks.
|
||||||
|
cleaned_v = v
|
||||||
|
printf_matches = list(re.finditer(r"\$\(\s*printf\s+%s\s+([^\n()]*?)\)", v))
|
||||||
|
for match in printf_matches:
|
||||||
|
inner = match.group(1)
|
||||||
|
if not any(c in inner for c in (";", "&&", "||", "$(", "`")):
|
||||||
|
cleaned_v = cleaned_v.replace(match.group(0), "/placeholder/safe/path.gguf")
|
||||||
|
|
||||||
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
||||||
if any(c in v for c in (";", "&&", "||", "$(")):
|
if any(c in cleaned_v for c in (";", "&&", "||", "$(")):
|
||||||
raise HTTPException(400, "Invalid characters in cmd")
|
raise HTTPException(400, "Invalid characters in cmd")
|
||||||
_check_serve_binary(v)
|
_check_serve_binary(v)
|
||||||
return v
|
return v
|
||||||
@@ -546,12 +660,34 @@ def _append_serve_preflight_exit_lines(runner_lines: list[str], *, keep_shell_op
|
|||||||
runner_lines.append('if [ -n "$ODYSSEUS_PREFLIGHT_EXIT" ]; then')
|
runner_lines.append('if [ -n "$ODYSSEUS_PREFLIGHT_EXIT" ]; then')
|
||||||
runner_lines.append(' echo ""; echo "=== Process exited with code $ODYSSEUS_PREFLIGHT_EXIT ==="')
|
runner_lines.append(' echo ""; echo "=== Process exited with code $ODYSSEUS_PREFLIGHT_EXIT ==="')
|
||||||
if keep_shell_open:
|
if keep_shell_open:
|
||||||
|
# Decouple the post-crash interactive shell from the persistent log
|
||||||
|
# file. fds 3/4 were saved BEFORE the tee redirect at the top of
|
||||||
|
# the runner; restoring them here means the neofetch banner the
|
||||||
|
# user's .zshrc prints lands on the tmux pane only, not in the
|
||||||
|
# log file the agent's tail_serve_output reads.
|
||||||
|
runner_lines.append(' exec 1>&3 2>&4 3>&- 4>&- 2>/dev/null || true')
|
||||||
|
runner_lines.append(' sleep 0.2 # let tee child flush + exit')
|
||||||
runner_lines.append(' exec "${SHELL:-/bin/bash}"')
|
runner_lines.append(' exec "${SHELL:-/bin/bash}"')
|
||||||
else:
|
else:
|
||||||
runner_lines.append(' exit "$ODYSSEUS_PREFLIGHT_EXIT"')
|
runner_lines.append(' exit "$ODYSSEUS_PREFLIGHT_EXIT"')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
|
|
||||||
|
|
||||||
|
def _append_vllm_linux_preflight_lines(runner_lines: list[str]) -> None:
|
||||||
|
"""Append Linux vLLM readiness lines that identify the runtime being used."""
|
||||||
|
# Keep the user install bin visible for Odysseus-managed `pip install --user`
|
||||||
|
# installs, but then report the actual CLI path so external runtimes are clear.
|
||||||
|
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||||
|
runner_lines.append('ODYSSEUS_VLLM_BIN="$(command -v vllm 2>/dev/null || true)"')
|
||||||
|
runner_lines.append('if [ -z "$ODYSSEUS_VLLM_BIN" ]; then')
|
||||||
|
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||||
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
|
runner_lines.append('else')
|
||||||
|
runner_lines.append(' echo "[odysseus] vLLM CLI: $ODYSSEUS_VLLM_BIN"')
|
||||||
|
runner_lines.append(' ODYSSEUS_VLLM_VERSION="$("$ODYSSEUS_VLLM_BIN" --version 2>&1 | head -n 1 || true)"')
|
||||||
|
runner_lines.append(' if [ -n "$ODYSSEUS_VLLM_VERSION" ]; then echo "[odysseus] vLLM version: $ODYSSEUS_VLLM_VERSION"; fi')
|
||||||
|
runner_lines.append('fi')
|
||||||
|
|
||||||
def _append_serve_exit_code_lines(
|
def _append_serve_exit_code_lines(
|
||||||
runner_lines: list[str],
|
runner_lines: list[str],
|
||||||
*,
|
*,
|
||||||
@@ -563,7 +699,11 @@ def _append_serve_exit_code_lines(
|
|||||||
if is_pip_install:
|
if is_pip_install:
|
||||||
runner_lines.append('if [ $ODYSSEUS_CMD_EXIT -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; fi')
|
runner_lines.append('if [ $ODYSSEUS_CMD_EXIT -eq 0 ]; then echo ""; echo "DOWNLOAD_OK"; fi')
|
||||||
if keep_shell_open:
|
if keep_shell_open:
|
||||||
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="; exec "${SHELL:-/bin/bash}"')
|
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="')
|
||||||
|
# See preflight branch above for the rationale on restoring fds 3/4.
|
||||||
|
runner_lines.append('exec 1>&3 2>&4 3>&- 4>&- 2>/dev/null || true')
|
||||||
|
runner_lines.append('sleep 0.2 # let tee child flush + exit')
|
||||||
|
runner_lines.append('exec "${SHELL:-/bin/bash}"')
|
||||||
else:
|
else:
|
||||||
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="')
|
runner_lines.append('echo ""; echo "=== Process exited with code $ODYSSEUS_CMD_EXIT ==="')
|
||||||
runner_lines.append('exit "$ODYSSEUS_CMD_EXIT"')
|
runner_lines.append('exit "$ODYSSEUS_CMD_EXIT"')
|
||||||
@@ -793,3 +933,172 @@ def _ssh_ps(host, script_path, port=None):
|
|||||||
|
|
||||||
# Windows session dir — stored in user's temp on the remote
|
# Windows session dir — stored in user's temp on the remote
|
||||||
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
||||||
|
|
||||||
|
|
||||||
|
def _diagnose_serve_output(text: str) -> dict | None:
|
||||||
|
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||||
|
|
||||||
|
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||||
|
the agent/tool path the same structured signal so it can retry with an
|
||||||
|
adjusted command instead of guessing from raw tmux output.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
tail = text[-6000:]
|
||||||
|
patterns = [
|
||||||
|
(
|
||||||
|
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||||
|
"No GPU memory left for KV cache after loading model.",
|
||||||
|
[
|
||||||
|
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||||
|
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||||
|
"GPU ran out of memory during startup or warmup.",
|
||||||
|
[
|
||||||
|
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||||
|
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||||
|
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"not divisib|must be divisible|attention heads.*divisible",
|
||||||
|
"Tensor parallel size is incompatible with the model.",
|
||||||
|
[
|
||||||
|
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||||
|
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||||
|
"Context length is too large for available GPU memory.",
|
||||||
|
[
|
||||||
|
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||||
|
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||||
|
"Auto tool choice requires an explicit tool call parser.",
|
||||||
|
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||||
|
"Model requires custom code or newer model support.",
|
||||||
|
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"There is no module or parameter named ['\"]lm_head\.input_scale['\"]|lm_head\.input_scale|weight_scale_2",
|
||||||
|
"vLLM cannot load this ModelOpt LM-head quantized checkpoint with the current runtime.",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"label": "upgrade vLLM through the environment that provides this CLI, or use a compatible checkpoint",
|
||||||
|
"op": "manual",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||||
|
"vLLM/Transformers kernel package mismatch.",
|
||||||
|
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Address already in use|bind.*address.*in use",
|
||||||
|
"Port is already in use.",
|
||||||
|
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||||
|
"No GPUs are visible to the serve process.",
|
||||||
|
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||||
|
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||||
|
"This machine may have integrated or unsupported graphics only.",
|
||||||
|
[
|
||||||
|
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||||
|
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||||
|
"vLLM is not installed or not in PATH on this server.",
|
||||||
|
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||||
|
"SGLang is not installed or not in PATH on this server.",
|
||||||
|
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||||
|
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||||
|
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||||
|
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||||
|
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||||
|
"Diffusion serving requires PyTorch and diffusers.",
|
||||||
|
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||||
|
"Model access is gated or unauthorized.",
|
||||||
|
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for pattern, message, suggestions in patterns:
|
||||||
|
if re.search(pattern, tail, re.I):
|
||||||
|
return {"message": message, "suggestions": suggestions}
|
||||||
|
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||||
|
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"message": "Python traceback detected during serve startup.",
|
||||||
|
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def run_ssh_command_async(
|
||||||
|
remote: str,
|
||||||
|
ssh_port: str | None,
|
||||||
|
remote_cmd: str,
|
||||||
|
*,
|
||||||
|
timeout: float,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
strict_host_key_checking: bool | None = None,
|
||||||
|
stdin_data: bytes | None = None,
|
||||||
|
) -> tuple[int, bytes, bytes]:
|
||||||
|
"""Run an ssh command with centralized timeout and stderr/stdout capture.
|
||||||
|
Async version of core.platform_compat.run_ssh_command_sync.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
proc = await asyncio.create_subprocess_exec(
|
||||||
|
*_ssh_exec_argv(
|
||||||
|
remote,
|
||||||
|
ssh_port,
|
||||||
|
remote_cmd=remote_cmd,
|
||||||
|
connect_timeout=connect_timeout,
|
||||||
|
strict_host_key_checking=strict_host_key_checking,
|
||||||
|
),
|
||||||
|
stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(
|
||||||
|
proc.communicate(input=stdin_data), timeout=timeout
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.communicate()
|
||||||
|
raise
|
||||||
|
return proc.returncode or 0, stdout, stderr
|
||||||
|
|||||||
+464
-234
@@ -15,19 +15,26 @@ from pathlib import Path
|
|||||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_user
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
pid_alive,
|
pid_alive,
|
||||||
safe_chmod,
|
safe_chmod,
|
||||||
which_tool,
|
which_tool,
|
||||||
|
translate_path,
|
||||||
|
get_wsl_windows_user_profile,
|
||||||
)
|
)
|
||||||
from routes.shell_routes import TMUX_LOG_DIR
|
from routes.shell_routes import TMUX_LOG_DIR
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,8 +45,10 @@ from routes.cookbook_helpers import (
|
|||||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||||
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache,
|
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
||||||
_user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||||
|
_append_pip_install_runner_lines,
|
||||||
|
_diagnose_serve_output, run_ssh_command_async,
|
||||||
ModelDownloadRequest, ServeRequest,
|
ModelDownloadRequest, ServeRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +63,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
|
|||||||
|
|
||||||
def setup_cookbook_routes() -> APIRouter:
|
def setup_cookbook_routes() -> APIRouter:
|
||||||
router = APIRouter(tags=["cookbook"])
|
router = APIRouter(tags=["cookbook"])
|
||||||
_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
|
||||||
|
|
||||||
def _mask_secret(value: str) -> str:
|
def _mask_secret(value: str) -> str:
|
||||||
if not value:
|
if not value:
|
||||||
@@ -81,127 +90,6 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
task["payload"].pop("hf_token", None)
|
task["payload"].pop("hf_token", None)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _diagnose_serve_output(text: str) -> dict | None:
|
|
||||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
|
||||||
|
|
||||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
|
||||||
the agent/tool path the same structured signal so it can retry with an
|
|
||||||
adjusted command instead of guessing from raw tmux output.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
tail = text[-6000:]
|
|
||||||
patterns = [
|
|
||||||
(
|
|
||||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
|
||||||
"No GPU memory left for KV cache after loading model.",
|
|
||||||
[
|
|
||||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
|
||||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
|
||||||
"GPU ran out of memory during startup or warmup.",
|
|
||||||
[
|
|
||||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
|
||||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
|
||||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"not divisib|must be divisible|attention heads.*divisible",
|
|
||||||
"Tensor parallel size is incompatible with the model.",
|
|
||||||
[
|
|
||||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
|
||||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
|
||||||
"Context length is too large for available GPU memory.",
|
|
||||||
[
|
|
||||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
|
||||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
|
||||||
"Auto tool choice requires an explicit tool call parser.",
|
|
||||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
|
||||||
"Model requires custom code or newer model support.",
|
|
||||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
|
||||||
"vLLM/Transformers kernel package mismatch.",
|
|
||||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Address already in use|bind.*address.*in use",
|
|
||||||
"Port is already in use.",
|
|
||||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
|
||||||
"No GPUs are visible to the serve process.",
|
|
||||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
|
||||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
|
||||||
"This machine may have integrated or unsupported graphics only.",
|
|
||||||
[
|
|
||||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
|
||||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
|
||||||
"vLLM is not installed or not in PATH on this server.",
|
|
||||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
|
||||||
"SGLang is not installed or not in PATH on this server.",
|
|
||||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
|
||||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
|
||||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
|
||||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
|
||||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
|
||||||
"Diffusion serving requires PyTorch and diffusers.",
|
|
||||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
|
||||||
"Model access is gated or unauthorized.",
|
|
||||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for pattern, message, suggestions in patterns:
|
|
||||||
if re.search(pattern, tail, re.I):
|
|
||||||
return {"message": message, "suggestions": suggestions}
|
|
||||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
|
||||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
|
||||||
):
|
|
||||||
return {
|
|
||||||
"message": "Python traceback detected during serve startup.",
|
|
||||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _state_for_client(state):
|
def _state_for_client(state):
|
||||||
"""Return cookbook state without raw secrets for browser clients."""
|
"""Return cookbook state without raw secrets for browser clients."""
|
||||||
_strip_task_secrets(state)
|
_strip_task_secrets(state)
|
||||||
@@ -295,6 +183,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
||||||
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
||||||
|
|
||||||
|
|
||||||
def _needs_binary(cmd: str, binary: str) -> bool:
|
def _needs_binary(cmd: str, binary: str) -> bool:
|
||||||
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
||||||
|
|
||||||
@@ -355,8 +244,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# POSIX form + shell-quoting so drive paths / spaces survive.
|
# POSIX form + shell-quoting so drive paths / spaces survive.
|
||||||
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
||||||
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
||||||
lp = shlex.quote(log_path.as_posix())
|
lp = shlex.quote(git_bash_path(log_path))
|
||||||
ip = shlex.quote(inner.as_posix())
|
ip = shlex.quote(git_bash_path(inner))
|
||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"bash {ip} > {lp} 2>&1\n",
|
f"bash {ip} > {lp} 2>&1\n",
|
||||||
@@ -472,6 +361,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines = []
|
ps_lines = []
|
||||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||||
|
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||||
|
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||||
if req.hf_token:
|
if req.hf_token:
|
||||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||||
if req.env_prefix:
|
if req.env_prefix:
|
||||||
@@ -545,7 +436,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
||||||
# hf_transfer because the Rust parallel path is fast but has been
|
# hf_transfer because the Rust parallel path is fast but has been
|
||||||
# flaky near the end of very large multi-file downloads.
|
# flaky near the end of very large multi-file downloads.
|
||||||
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
|
# The helper tries active pip first, then guarded user-site fallbacks.
|
||||||
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
||||||
if req.disable_hf_transfer:
|
if req.disable_hf_transfer:
|
||||||
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||||
@@ -673,24 +564,35 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
for d in model_dir.split(','):
|
for d in model_dir.split(','):
|
||||||
d = d.strip()
|
d = d.strip()
|
||||||
if d:
|
if d:
|
||||||
model_dirs.append(d)
|
translated_d = translate_path(d) if not host else d
|
||||||
paths_code = _cached_model_scan_script(model_dirs)
|
model_dirs.append(translated_d)
|
||||||
|
win_hf_hub = None
|
||||||
|
if not host:
|
||||||
|
win_profile = get_wsl_windows_user_profile()
|
||||||
|
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
|
||||||
|
|
||||||
|
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
|
||||||
|
|
||||||
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
||||||
scan_py.write_text(paths_code, encoding="utf-8")
|
scan_py.write_text(paths_code, encoding="utf-8")
|
||||||
|
scan_payload = scan_py.read_bytes()
|
||||||
|
|
||||||
if host:
|
if host:
|
||||||
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
|
||||||
if platform == "windows":
|
if platform == "windows":
|
||||||
# Windows: use 'python' and pipe via stdin with double-quote wrapping
|
remote_cmd = "python -"
|
||||||
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
|
|
||||||
else:
|
else:
|
||||||
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
|
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
|
||||||
proc = await asyncio.create_subprocess_shell(
|
remote_cmd = (
|
||||||
cmd,
|
"if command -v python3 >/dev/null 2>&1; then python3 -; "
|
||||||
stdout=asyncio.subprocess.PIPE,
|
"elif command -v python >/dev/null 2>&1; then python -; "
|
||||||
stderr=asyncio.subprocess.PIPE,
|
"else echo \"python3/python not found\" >&2; exit 127; fi"
|
||||||
cwd=str(Path.home()),
|
)
|
||||||
|
rc, stdout_b, stderr_b = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
remote_cmd,
|
||||||
|
timeout=60,
|
||||||
|
stdin_data=scan_payload,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||||
@@ -710,7 +612,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=str(Path.home()),
|
cwd=str(Path.home()),
|
||||||
)
|
)
|
||||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
try:
|
try:
|
||||||
@@ -801,6 +703,55 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
def _pick_free_port_for_ollama(
|
||||||
|
remote: str | None, ssh_port: str | None, start_port: int, max_offset: int
|
||||||
|
) -> int | None:
|
||||||
|
"""Return the first free port in [start_port, start_port+max_offset] on
|
||||||
|
the target host. Used to pick a real bind for `ollama serve` so we
|
||||||
|
don't reattach to an external systemd ollama (or other listener) the
|
||||||
|
Cookbook Stop button can't kill."""
|
||||||
|
import socket
|
||||||
|
if remote:
|
||||||
|
# Probe over SSH. Bash's /dev/tcp gives a portable "is anything
|
||||||
|
# listening" check without requiring ss/netstat/nmap.
|
||||||
|
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||||
|
if ssh_port and str(ssh_port) != "22":
|
||||||
|
if not _SSH_PORT_RE.match(str(ssh_port)):
|
||||||
|
return None
|
||||||
|
ssh_base.extend(["-p", str(ssh_port)])
|
||||||
|
host_arg = remote
|
||||||
|
if not _REMOTE_HOST_RE.match(host_arg):
|
||||||
|
return None
|
||||||
|
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
|
||||||
|
script = (
|
||||||
|
f"for p in {probe_ports}; do "
|
||||||
|
"if ! (exec 3<>/dev/tcp/127.0.0.1/$p) 2>/dev/null; then "
|
||||||
|
"echo $p; exit 0; fi; exec 3<&-; exec 3>&-; done; exit 1"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
r = subprocess.run(
|
||||||
|
ssh_base + [host_arg, script],
|
||||||
|
capture_output=True, text=True, timeout=8,
|
||||||
|
)
|
||||||
|
if r.returncode == 0:
|
||||||
|
out = (r.stdout or "").strip().splitlines()
|
||||||
|
if out and out[0].isdigit():
|
||||||
|
return int(out[0])
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
# Local: just try to connect.
|
||||||
|
for off in range(max_offset + 1):
|
||||||
|
p = start_port + off
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.settimeout(0.25)
|
||||||
|
try:
|
||||||
|
s.connect(("127.0.0.1", p))
|
||||||
|
except (ConnectionRefusedError, socket.timeout, OSError):
|
||||||
|
return p
|
||||||
|
return None
|
||||||
|
|
||||||
def _auto_register_llm_endpoint(req: ServeRequest, remote: str | None) -> str | None:
|
def _auto_register_llm_endpoint(req: ServeRequest, remote: str | None) -> str | None:
|
||||||
"""Register a freshly-served LLM as a model endpoint so it appears in the
|
"""Register a freshly-served LLM as a model endpoint so it appears in the
|
||||||
model picker without a manual /setup step — the text-model sibling of
|
model picker without a manual /setup step — the text-model sibling of
|
||||||
@@ -815,21 +766,37 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
import re
|
import re
|
||||||
from core.database import SessionLocal, ModelEndpoint
|
from core.database import SessionLocal, ModelEndpoint
|
||||||
|
|
||||||
# Port: an explicit --port wins. Otherwise fall back by backend — Ollama
|
# Port: ordered fallbacks so we match whatever the user actually
|
||||||
# is the only server in our generated commands that omits --port.
|
# asked for, not a hardcoded default:
|
||||||
|
# 1. explicit `--port N` (vllm / sglang / llama-server)
|
||||||
|
# 2. `OLLAMA_HOST=host:port` (the way Ollama specifies its bind)
|
||||||
|
# 3. fallback by backend (11434 ollama / 8080 llama.cpp)
|
||||||
|
# Previously the OLLAMA_HOST form was silently ignored and we
|
||||||
|
# registered every Ollama endpoint at 11434 — even if the user
|
||||||
|
# set OLLAMA_HOST=0.0.0.0:11435 to avoid colliding with an
|
||||||
|
# existing systemd Ollama, the registered endpoint pointed at
|
||||||
|
# the OLD port and showed as offline.
|
||||||
port_match = re.search(r'--port\s+(\d+)', req.cmd)
|
port_match = re.search(r'--port\s+(\d+)', req.cmd)
|
||||||
|
ollama_host_match = re.search(r'OLLAMA_HOST=[^\s]*?:(\d+)', req.cmd)
|
||||||
if port_match:
|
if port_match:
|
||||||
port = int(port_match.group(1))
|
port = int(port_match.group(1))
|
||||||
|
elif ollama_host_match:
|
||||||
|
port = int(ollama_host_match.group(1))
|
||||||
elif "ollama" in req.cmd:
|
elif "ollama" in req.cmd:
|
||||||
port = 11434
|
port = 11434
|
||||||
else:
|
else:
|
||||||
port = 8080 # llama.cpp's llama-server default — the Apple Silicon path
|
port = 8080 # llama.cpp's llama-server default — the Apple Silicon path
|
||||||
|
|
||||||
# Determine host (mirrors the image path: SSH alias for remote serves).
|
# Determine host (mirrors the image path: SSH alias for remote serves).
|
||||||
|
# For local serves while Odysseus runs inside Docker, "localhost"
|
||||||
|
# resolves to the container itself — useless. Use host.docker.internal
|
||||||
|
# which compose maps to the actual host, matching what /setup adds
|
||||||
|
# for Ollama by hand.
|
||||||
if remote:
|
if remote:
|
||||||
host = remote.split("@")[-1] if "@" in remote else remote
|
host = remote.split("@")[-1] if "@" in remote else remote
|
||||||
else:
|
else:
|
||||||
host = "localhost"
|
from routes.model_routes import _docker_host_gateway_reachable
|
||||||
|
host = "host.docker.internal" if _docker_host_gateway_reachable() else "localhost"
|
||||||
|
|
||||||
base_url = f"http://{host}:{port}/v1"
|
base_url = f"http://{host}:{port}/v1"
|
||||||
|
|
||||||
@@ -850,6 +817,10 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
existing.name = display_name
|
existing.name = display_name
|
||||||
if supports_tools is not None:
|
if supports_tools is not None:
|
||||||
existing.supports_tools = supports_tools
|
existing.supports_tools = supports_tools
|
||||||
|
# Wipe stale model lists so the picker re-probes and discovers
|
||||||
|
# the newly-served model instead of showing the old one.
|
||||||
|
existing.cached_models = None
|
||||||
|
existing.hidden_models = None
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Updated existing local model endpoint: {base_url}")
|
logger.info(f"Updated existing local model endpoint: {base_url}")
|
||||||
return existing.id
|
return existing.id
|
||||||
@@ -906,11 +877,27 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
in_venv=sys.prefix != sys.base_prefix,
|
in_venv=sys.prefix != sys.base_prefix,
|
||||||
)
|
)
|
||||||
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
||||||
|
remote = req.remote_host
|
||||||
|
is_windows = req.platform == "windows"
|
||||||
|
local_windows = IS_WINDOWS and not remote
|
||||||
|
if is_windows or local_windows:
|
||||||
|
if req.cmd.startswith("python3 "):
|
||||||
|
req.cmd = "python " + req.cmd[len("python3 "):]
|
||||||
|
if is_pip_install and ("llama-cpp-python" in req.cmd or "llama_cpp" in req.cmd) and (is_windows or local_windows):
|
||||||
|
if "--extra-index-url" not in req.cmd:
|
||||||
|
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
|
|
||||||
if is_pip_install:
|
if is_pip_install:
|
||||||
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
||||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||||
# and leave the dep installed-but-unusable (#1459).
|
# and leave the dep installed-but-unusable (#1459).
|
||||||
req.cmd = _pip_install_no_cache(req.cmd)
|
req.cmd = _pip_install_no_cache(req.cmd)
|
||||||
|
# Accept common aliases and enforce server extras for llama-cpp so
|
||||||
|
# `python -m llama_cpp.server` has all runtime dependencies.
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
|
||||||
|
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
|
||||||
|
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||||
# v2 review HIGH-14: tightened from the previous regex which
|
# v2 review HIGH-14: tightened from the previous regex which
|
||||||
@@ -927,6 +914,19 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
session_id = f"serve-{uuid.uuid4().hex[:8]}"
|
session_id = f"serve-{uuid.uuid4().hex[:8]}"
|
||||||
remote = req.remote_host
|
remote = req.remote_host
|
||||||
is_windows = req.platform == "windows"
|
is_windows = req.platform == "windows"
|
||||||
|
|
||||||
|
# Ollama: if the user didn't pin a port, resolve the actual port we'll
|
||||||
|
# bind to here (before runner construction) by probing the target host.
|
||||||
|
# Otherwise the runner script picks one at runtime and `_auto_register`
|
||||||
|
# below still registers the stale 11434 default — which on a host with
|
||||||
|
# a systemd ollama lands on the wrong (unreachable-from-docker) service.
|
||||||
|
if "ollama" in req.cmd and "OLLAMA_HOST=" not in req.cmd:
|
||||||
|
_ollama_bind_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||||
|
_ollama_chosen_port = _pick_free_port_for_ollama(
|
||||||
|
remote, req.ssh_port, start_port=11434, max_offset=10,
|
||||||
|
)
|
||||||
|
if _ollama_chosen_port:
|
||||||
|
req.cmd = f"OLLAMA_HOST={_ollama_bind_host}:{_ollama_chosen_port} {req.cmd}"
|
||||||
# LOCAL execution on a native-Windows host never uses tmux (detached
|
# LOCAL execution on a native-Windows host never uses tmux (detached
|
||||||
# process path below), regardless of the UI-supplied platform.
|
# process path below), regardless of the UI-supplied platform.
|
||||||
local_windows = IS_WINDOWS and not remote
|
local_windows = IS_WINDOWS and not remote
|
||||||
@@ -950,6 +950,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines = []
|
ps_lines = []
|
||||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||||
|
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||||
|
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||||
if req.hf_token:
|
if req.hf_token:
|
||||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||||
if req.gpus:
|
if req.gpus:
|
||||||
@@ -968,7 +970,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
||||||
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
||||||
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
||||||
ps_lines.append(' python -m pip install llama-cpp-python[server]')
|
ps_lines.append(' python -m pip install llama-cpp-python[server] --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu')
|
||||||
ps_lines.append('}')
|
ps_lines.append('}')
|
||||||
elif "vllm" in req.cmd:
|
elif "vllm" in req.cmd:
|
||||||
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
||||||
@@ -998,6 +1000,21 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
else:
|
else:
|
||||||
# ── Linux/Termux: bash + tmux (existing flow) ──
|
# ── Linux/Termux: bash + tmux (existing flow) ──
|
||||||
runner_lines = ["#!/bin/bash"]
|
runner_lines = ["#!/bin/bash"]
|
||||||
|
# Mirror every line of stdout+stderr into a persistent log file
|
||||||
|
# on the host running the serve. This is the file tail_serve_output
|
||||||
|
# reads when the tmux pane has been overwritten by the post-crash
|
||||||
|
# bash prompt — without it, the agent's diagnostic tool sees the
|
||||||
|
# neofetch banner instead of the actual Python traceback.
|
||||||
|
# We save the original fds to 3/4 so we can RESTORE them before
|
||||||
|
# `exec ${SHELL}` at the end of the script. Without that restore,
|
||||||
|
# the post-crash interactive shell's neofetch banner ALSO gets
|
||||||
|
# teed into the log file and `tail -N` returns ONLY the banner —
|
||||||
|
# the actual traceback ends up earlier than the tail window.
|
||||||
|
runner_lines.append("mkdir -p /tmp/odysseus-tmux 2>/dev/null || true")
|
||||||
|
runner_lines.append("exec 3>&1 4>&2")
|
||||||
|
runner_lines.append(
|
||||||
|
f"exec > >(tee -a /tmp/odysseus-tmux/{session_id}.log) 2>&1"
|
||||||
|
)
|
||||||
runner_lines.extend(_user_shell_path_bootstrap())
|
runner_lines.extend(_user_shell_path_bootstrap())
|
||||||
runner_lines.append('ODYSSEUS_PREFLIGHT_EXIT=""')
|
runner_lines.append('ODYSSEUS_PREFLIGHT_EXIT=""')
|
||||||
# Put Odysseus's own venv bin on PATH (local runs only) so the serve
|
# Put Odysseus's own venv bin on PATH (local runs only) so the serve
|
||||||
@@ -1028,45 +1045,57 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# ollama is found (otherwise macOS falls back to a slow source build).
|
# ollama is found (otherwise macOS falls back to a slow source build).
|
||||||
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
||||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
if local_windows:
|
||||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
# LOCAL Windows: no native source compilation (no cmake/compiler on Git Bash).
|
||||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
# Just check python bindings (using native `python` binary) and fall back to pip install.
|
||||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
runner_lines.append(' echo "llama-server not found — installing Python bindings..."')
|
||||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='python')} || true")
|
||||||
runner_lines.append(' fi')
|
runner_lines.append('fi')
|
||||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install attempts."')
|
||||||
runner_lines.append(' mkdir -p ~/bin')
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
runner_lines.append('fi')
|
||||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
else:
|
||||||
# enables it automatically, no flag), CUDA on Linux when present,
|
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
# a prebuilt llama-server and skips this whole source build.)
|
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
runner_lines.append(' fi')
|
||||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
runner_lines.append(' mkdir -p ~/bin')
|
||||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
# enables it automatically, no flag), CUDA on Linux when present,
|
||||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||||
runner_lines.append(' else')
|
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
# a prebuilt llama-server and skips this whole source build.)
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||||
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
|
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||||
runner_lines.append(' fi')
|
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||||
runner_lines.append('fi')
|
runner_lines.append(' else')
|
||||||
|
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
# If the native build failed, fall back to the Python bindings.
|
||||||
|
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
|
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||||
|
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
|
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||||
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
runner_lines.append('fi')
|
||||||
elif "ollama" in req.cmd:
|
elif "ollama" in req.cmd:
|
||||||
handled_ollama_serve = True
|
handled_ollama_serve = True
|
||||||
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||||
@@ -1074,41 +1103,37 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
req.cmd,
|
req.cmd,
|
||||||
default_host=_ollama_default_host,
|
default_host=_ollama_default_host,
|
||||||
)
|
)
|
||||||
# Ollama can be a host binary, a system service, or a Docker
|
# Always launch a fresh ollama under tmux so Stop reliably
|
||||||
# container. If the HTTP API is already reachable, the model is
|
# kills it. If the requested port is busy (e.g. a systemd
|
||||||
# already served and we should not require a host `ollama` CLI.
|
# ollama on 11434), scan upward for a free one rather than
|
||||||
|
# silently reattaching to an external service that Stop
|
||||||
|
# can't reach.
|
||||||
runner_lines.append(f'ODYSSEUS_OLLAMA_HOST={_bash_squote(_ollama_host)}')
|
runner_lines.append(f'ODYSSEUS_OLLAMA_HOST={_bash_squote(_ollama_host)}')
|
||||||
runner_lines.append(f'ODYSSEUS_OLLAMA_PORT="{_ollama_port}"')
|
runner_lines.append(f'ODYSSEUS_OLLAMA_PORT="{_ollama_port}"')
|
||||||
runner_lines.append('ODYSSEUS_OLLAMA_URL=""')
|
runner_lines.append('for _ody_off in 0 1 2 3 4 5 6 7 8 9; do')
|
||||||
runner_lines.append('for _ody_ollama_try in $(seq 1 20); do')
|
runner_lines.append(' _ody_try_port=$((ODYSSEUS_OLLAMA_PORT + _ody_off))')
|
||||||
runner_lines.append(' for _ody_ollama_port in "$ODYSSEUS_OLLAMA_PORT" 11434; do')
|
runner_lines.append(' if ! (exec 3<>/dev/tcp/127.0.0.1/$_ody_try_port) 2>/dev/null; then')
|
||||||
runner_lines.append(' [ -z "$_ody_ollama_port" ] && continue')
|
runner_lines.append(' exec 3<&-; exec 3>&-')
|
||||||
runner_lines.append(' for _ody_ollama_host in 127.0.0.1 localhost host.docker.internal; do')
|
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
|
||||||
runner_lines.append(' _ody_ollama_url="http://${_ody_ollama_host}:${_ody_ollama_port}"')
|
runner_lines.append(' break')
|
||||||
runner_lines.append(' if curl -sf "$_ody_ollama_url/api/tags" >/dev/null 2>&1; then')
|
|
||||||
runner_lines.append(' ODYSSEUS_OLLAMA_URL="$_ody_ollama_url"')
|
|
||||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_ollama_port"')
|
|
||||||
runner_lines.append(' break 3')
|
|
||||||
runner_lines.append(' fi')
|
|
||||||
runner_lines.append(' done')
|
|
||||||
runner_lines.append(' done')
|
|
||||||
runner_lines.append(' [ "$_ody_ollama_try" -eq 1 ] && echo "[odysseus] Waiting for an existing Ollama API on ports ${ODYSSEUS_OLLAMA_PORT}/11434..."')
|
|
||||||
runner_lines.append(' sleep 1')
|
|
||||||
runner_lines.append('done')
|
|
||||||
runner_lines.append('if [ -n "$ODYSSEUS_OLLAMA_URL" ]; then')
|
|
||||||
runner_lines.append(' if [ "$ODYSSEUS_OLLAMA_PORT" != "' + _ollama_port + '" ]; then')
|
|
||||||
runner_lines.append(' echo "[odysseus] Selected Ollama port ' + _ollama_port + ' was not reachable; using running Ollama on port ${ODYSSEUS_OLLAMA_PORT}."')
|
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' fi')
|
||||||
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
|
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
|
||||||
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
|
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
|
||||||
runner_lines.append(' exec bash -i')
|
if local_windows:
|
||||||
|
# Windows detached process has no TTY; exec bash -i crashes.
|
||||||
|
# Keep the monitoring task alive with a sleep loop.
|
||||||
|
runner_lines.append(' while true; do sleep 60; done')
|
||||||
|
else:
|
||||||
|
runner_lines.append(' exec bash -i')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
||||||
runner_lines.append(' echo "ERROR: Ollama not found and no Ollama API is reachable on 127.0.0.1, localhost, or host.docker.internal (ports ${ODYSSEUS_OLLAMA_PORT}/11434)."')
|
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
|
||||||
runner_lines.append(' echo "Install Ollama, start an Ollama service/container on this server, or pick the port where it is already listening."')
|
|
||||||
runner_lines.append(' echo')
|
runner_lines.append(' echo')
|
||||||
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
||||||
runner_lines.append(' exec bash -i')
|
if local_windows:
|
||||||
|
runner_lines.append(' exit 127')
|
||||||
|
else:
|
||||||
|
runner_lines.append(' exec bash -i')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
||||||
if remote and _ollama_host in ("0.0.0.0", "::"):
|
if remote and _ollama_host in ("0.0.0.0", "::"):
|
||||||
@@ -1116,24 +1141,20 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
||||||
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||||
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||||
runner_lines.append('_ody_exit=$?')
|
if local_windows:
|
||||||
runner_lines.append('echo')
|
_append_serve_exit_code_lines(runner_lines, keep_shell_open=False)
|
||||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
else:
|
||||||
runner_lines.append('exec bash -i')
|
runner_lines.append('_ody_exit=$?')
|
||||||
|
runner_lines.append('echo')
|
||||||
|
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||||
|
runner_lines.append('exec bash -i')
|
||||||
elif "vllm serve" in req.cmd:
|
elif "vllm serve" in req.cmd:
|
||||||
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
||||||
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
||||||
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
# Put ~/.local/bin on PATH first — without a venv, vllm installs
|
_append_vllm_linux_preflight_lines(runner_lines)
|
||||||
# there via --user and the non-login serve shell otherwise can't
|
|
||||||
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
|
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
|
||||||
runner_lines.append('if ! command -v vllm &>/dev/null; then')
|
|
||||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
|
||||||
runner_lines.append('fi')
|
|
||||||
elif "sglang.launch_server" in req.cmd:
|
elif "sglang.launch_server" in req.cmd:
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||||
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
||||||
@@ -1157,7 +1178,10 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
runner_lines,
|
runner_lines,
|
||||||
keep_shell_open=not local_windows,
|
keep_shell_open=not local_windows,
|
||||||
)
|
)
|
||||||
runner_lines.append(req.cmd)
|
if is_pip_install:
|
||||||
|
_append_pip_install_runner_lines(runner_lines, req.cmd)
|
||||||
|
else:
|
||||||
|
runner_lines.append(req.cmd)
|
||||||
if local_windows:
|
if local_windows:
|
||||||
# Detached background process — no interactive shell to keep open.
|
# Detached background process — no interactive shell to keep open.
|
||||||
# Print the exit marker the status poller looks for, then stop.
|
# Print the exit marker the status poller looks for, then stop.
|
||||||
@@ -1318,8 +1342,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||||
else:
|
else:
|
||||||
# Linux: auto-install tmux (via whichever package manager is available)
|
# Linux: auto-install tmux (via whichever package manager is available)
|
||||||
# and huggingface_hub + hf_transfer (falling back to --user/--break-system-packages
|
# and huggingface_hub + hf_transfer (falling back to --user, then
|
||||||
# on PEP-668 locked distros like Arch / newer Debian).
|
# guarded --break-system-packages on PEP-668 locked distros).
|
||||||
setup_script = (
|
setup_script = (
|
||||||
# Install tmux if missing — try common package managers; skip if no sudo
|
# Install tmux if missing — try common package managers; skip if no sudo
|
||||||
"if ! command -v tmux >/dev/null 2>&1; then "
|
"if ! command -v tmux >/dev/null 2>&1; then "
|
||||||
@@ -1331,10 +1355,15 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
" fi; "
|
" fi; "
|
||||||
"fi; "
|
"fi; "
|
||||||
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
||||||
# Install Python bits. Try system install first; fall back to --user --break-system-packages on PEP 668 systems.
|
# Install Python bits. Try system install first; fall back to --user,
|
||||||
|
# then use --break-system-packages only when pip supports it.
|
||||||
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null || "
|
"pip install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null; "
|
"( pip install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||||
|
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ) || "
|
||||||
|
"pip3 install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
|
"( pip3 install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||||
|
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ); "
|
||||||
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
||||||
)
|
)
|
||||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||||
@@ -1357,11 +1386,38 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
||||||
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
||||||
if host:
|
if host:
|
||||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
candidates = [query]
|
||||||
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
|
stripped = query.strip()
|
||||||
proc = await asyncio.create_subprocess_shell(
|
if stripped.startswith("nvidia-smi "):
|
||||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
args = stripped[len("nvidia-smi "):]
|
||||||
)
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ shlex.quote(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args}")
|
||||||
|
|
||||||
|
last_err = "nvidia-smi failed"
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
rc, stdout, stderr = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
candidate,
|
||||||
|
connect_timeout=5,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return None, "nvidia-smi timed out"
|
||||||
|
if rc == 0:
|
||||||
|
return stdout.decode("utf-8", errors="replace"), None
|
||||||
|
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
|
||||||
|
if err:
|
||||||
|
last_err = err
|
||||||
|
return None, last_err
|
||||||
else:
|
else:
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*shlex.split(query),
|
*shlex.split(query),
|
||||||
@@ -1940,6 +1996,153 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
|
|
||||||
return {"models": out}
|
return {"models": out}
|
||||||
|
|
||||||
|
# Rate-limit for the orphan-tmux adoption sweep. The UI polls
|
||||||
|
# tasks/status every ~3s; we don't want to SSH every host on every
|
||||||
|
# poll. 20s is fast enough that a model the agent launched in the
|
||||||
|
# background shows up "almost immediately" in the UI without being
|
||||||
|
# wasteful.
|
||||||
|
_last_orphan_sweep_ts = [0.0]
|
||||||
|
_ORPHAN_SWEEP_MIN_INTERVAL_S = 20.0
|
||||||
|
|
||||||
|
def _maybe_sweep_orphans(tasks: list, state: dict) -> None:
|
||||||
|
"""Scan each configured cookbook server for `serve-*` tmux sessions
|
||||||
|
the cookbook doesn't know about and adopt them into state.tasks.
|
||||||
|
|
||||||
|
Writes are conditional: if no orphans are found, nothing is touched.
|
||||||
|
Rate-limited so polling UIs don't trigger SSH on every refresh.
|
||||||
|
"""
|
||||||
|
import time as _time
|
||||||
|
import subprocess
|
||||||
|
logger.info(f"_maybe_sweep_orphans: entered, last_ts={_last_orphan_sweep_ts[0]}")
|
||||||
|
now = _time.monotonic()
|
||||||
|
if now - _last_orphan_sweep_ts[0] < _ORPHAN_SWEEP_MIN_INTERVAL_S:
|
||||||
|
logger.info(f"_maybe_sweep_orphans: rate-limited, {now - _last_orphan_sweep_ts[0]:.1f}s since last")
|
||||||
|
return
|
||||||
|
_last_orphan_sweep_ts[0] = now
|
||||||
|
|
||||||
|
env = state.get("env") if isinstance(state, dict) else {}
|
||||||
|
servers = env.get("servers") if isinstance(env, dict) else []
|
||||||
|
logger.info(f"orphan sweep starting: {len(servers) if isinstance(servers, list) else 0} server(s), known_sids={len([t for t in tasks if isinstance(t, dict) and t.get('sessionId')])}")
|
||||||
|
if not isinstance(servers, list):
|
||||||
|
return
|
||||||
|
|
||||||
|
known_sids = {
|
||||||
|
t.get("sessionId") for t in tasks
|
||||||
|
if isinstance(t, dict) and t.get("sessionId")
|
||||||
|
}
|
||||||
|
|
||||||
|
adopted_any = False
|
||||||
|
for srv in servers:
|
||||||
|
if not isinstance(srv, dict):
|
||||||
|
continue
|
||||||
|
host = (srv.get("host") or "").strip()
|
||||||
|
if not host:
|
||||||
|
continue # local-only entry; the /proc scan handles it
|
||||||
|
if not _REMOTE_HOST_RE.match(host):
|
||||||
|
continue
|
||||||
|
sport = str(srv.get("port") or "").strip()
|
||||||
|
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||||
|
if sport and sport != "22":
|
||||||
|
if not _SSH_PORT_RE.match(sport):
|
||||||
|
continue
|
||||||
|
ssh_base.extend(["-p", sport])
|
||||||
|
|
||||||
|
try:
|
||||||
|
ls = subprocess.run(
|
||||||
|
ssh_base + [host, "tmux ls 2>/dev/null"],
|
||||||
|
timeout=6, capture_output=True, text=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
for line in (ls.stdout or "").splitlines():
|
||||||
|
sid = line.split(":", 1)[0].strip()
|
||||||
|
if not sid or not _SESSION_ID_RE.match(sid):
|
||||||
|
continue
|
||||||
|
if sid in known_sids:
|
||||||
|
continue
|
||||||
|
# Adopt any session whose pane is currently running a
|
||||||
|
# known model-server process (checked below). The earlier
|
||||||
|
# prefix gate (serve-/cookbook-) dropped legitimate
|
||||||
|
# serves whenever tmux fell back to numeric IDs, leaving
|
||||||
|
# them invisible in the Cookbook UI — so the user could
|
||||||
|
# neither see nor stop them.
|
||||||
|
# Skip zombie / idle-shell sessions. A tmux session left
|
||||||
|
# over from a crashed vllm just shows a bash prompt —
|
||||||
|
# adopting it would pollute the UI with "running" tasks
|
||||||
|
# that aren't actually serving anything. pane_current_command
|
||||||
|
# is the foreground process in the pane right now; only
|
||||||
|
# real model serves leave a python/vllm/etc. process there.
|
||||||
|
try:
|
||||||
|
pc = subprocess.run(
|
||||||
|
ssh_base + [host, "tmux", "list-panes", "-t", sid,
|
||||||
|
"-F", "#{pane_current_command}"],
|
||||||
|
timeout=4, capture_output=True, text=True,
|
||||||
|
)
|
||||||
|
cur = (pc.stdout or "").strip().splitlines()
|
||||||
|
except Exception:
|
||||||
|
cur = []
|
||||||
|
LIVE_PROCS = {"python", "python3", "vllm", "llama-server",
|
||||||
|
"llama_cpp_main", "sglang", "lmdeploy",
|
||||||
|
"ollama", "node", "uvicorn"}
|
||||||
|
if not any(c in LIVE_PROCS for c in cur):
|
||||||
|
continue
|
||||||
|
# Try to recover a plausible repo_id + port from the
|
||||||
|
# pane buffer. Cheap heuristic — if we can't, register
|
||||||
|
# with placeholder fields; the UI still shows it.
|
||||||
|
try:
|
||||||
|
cap = subprocess.run(
|
||||||
|
ssh_base + [host, "tmux", "capture-pane", "-t", sid, "-p", "-S", "-300"],
|
||||||
|
timeout=6, capture_output=True, text=True,
|
||||||
|
)
|
||||||
|
pane = cap.stdout or ""
|
||||||
|
except Exception:
|
||||||
|
pane = ""
|
||||||
|
import re as _re_orphan
|
||||||
|
# vLLM banner: "model /path/...". Falls back to the
|
||||||
|
# raw vllm-serve command if the banner already scrolled.
|
||||||
|
m_model = _re_orphan.search(r"model\s+(\S+)", pane)
|
||||||
|
model = m_model.group(1) if m_model else ""
|
||||||
|
if not model:
|
||||||
|
m_serve = _re_orphan.search(r"vllm\s+serve\s+(\S+)", pane)
|
||||||
|
model = m_serve.group(1) if m_serve else f"adopted:{sid}"
|
||||||
|
m_port = _re_orphan.search(r"--port\s+(\d+)", pane)
|
||||||
|
port = int(m_port.group(1)) if m_port else 0
|
||||||
|
|
||||||
|
import time as _t2
|
||||||
|
tasks.append({
|
||||||
|
"id": sid,
|
||||||
|
"sessionId": sid,
|
||||||
|
"name": model.split("/")[-1] if "/" in model else model,
|
||||||
|
"type": "serve",
|
||||||
|
"status": "running",
|
||||||
|
"output": f"Auto-adopted from orphan tmux session on {host}. "
|
||||||
|
"Open the task to see live output.",
|
||||||
|
"ts": int(_t2.time() * 1000),
|
||||||
|
"payload": {
|
||||||
|
"repo_id": model,
|
||||||
|
"remote_host": host,
|
||||||
|
"_cmd": "(orphan tmux session — original launch cmd unknown)",
|
||||||
|
"port": port,
|
||||||
|
},
|
||||||
|
"remoteHost": host,
|
||||||
|
"sshPort": sport,
|
||||||
|
"platform": "linux",
|
||||||
|
"_serveReady": False,
|
||||||
|
"_endpointAdded": False,
|
||||||
|
"_adoptedExternally": True,
|
||||||
|
})
|
||||||
|
known_sids.add(sid)
|
||||||
|
adopted_any = True
|
||||||
|
logger.info(f"auto-adopted orphan tmux session {sid!r} on {host}")
|
||||||
|
|
||||||
|
if adopted_any:
|
||||||
|
try:
|
||||||
|
from core.atomic_io import atomic_write_json
|
||||||
|
state["tasks"] = tasks
|
||||||
|
atomic_write_json(_cookbook_state_path, state)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"orphan sweep: state write failed: {e}")
|
||||||
|
|
||||||
@router.get("/api/cookbook/tasks/status")
|
@router.get("/api/cookbook/tasks/status")
|
||||||
async def cookbook_tasks_status(request: Request):
|
async def cookbook_tasks_status(request: Request):
|
||||||
"""Check status of all active cookbook tmux sessions.
|
"""Check status of all active cookbook tmux sessions.
|
||||||
@@ -1977,7 +2180,13 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
||||||
"sys.exit(0 if ok and not inc else 1)"
|
"sys.exit(0 if ok and not inc else 1)"
|
||||||
)
|
)
|
||||||
cmd = ["python3", "-c", py, repo_id]
|
if remote_host:
|
||||||
|
cmd = ["python3", "-c", py, repo_id]
|
||||||
|
else:
|
||||||
|
# Local Windows: python3 can hit the Microsoft Store stub. Use the
|
||||||
|
# real Python Odysseus is running under (guaranteed to exist).
|
||||||
|
import sys as _sys_local
|
||||||
|
cmd = [_sys_local.executable, "-c", py, repo_id]
|
||||||
try:
|
try:
|
||||||
if remote_host:
|
if remote_host:
|
||||||
ssh_base = ["ssh"]
|
ssh_base = ["ssh"]
|
||||||
@@ -1993,6 +2202,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
|
|
||||||
# Load saved tasks from cookbook state
|
# Load saved tasks from cookbook state
|
||||||
tasks = []
|
tasks = []
|
||||||
|
state = {}
|
||||||
if _cookbook_state_path.exists():
|
if _cookbook_state_path.exists():
|
||||||
try:
|
try:
|
||||||
state = json.loads(_cookbook_state_path.read_text(encoding="utf-8"))
|
state = json.loads(_cookbook_state_path.read_text(encoding="utf-8"))
|
||||||
@@ -2004,6 +2214,21 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Orphan-tmux auto-adoption sweep. When the agent (or anyone)
|
||||||
|
# SSH-launches a `serve-*` tmux session — usually because
|
||||||
|
# serve_model rejected `source ... && vllm ...` or because of a
|
||||||
|
# manual relaunch via tmux send-keys — that session is invisible
|
||||||
|
# to the cookbook UI even though it's a live model server. The
|
||||||
|
# sweep finds those orphans on each configured remote host and
|
||||||
|
# writes them into state.tasks with _adoptedExternally=True, so
|
||||||
|
# they show up in the UI on the next poll without anyone having
|
||||||
|
# to remember to call adopt_served_model. Rate-limited via the
|
||||||
|
# module-level _last_orphan_sweep so we don't SSH every 3s.
|
||||||
|
try:
|
||||||
|
_maybe_sweep_orphans(tasks, state)
|
||||||
|
except Exception as _sweep_e:
|
||||||
|
logger.warning(f"orphan sweep failed (non-fatal): {_sweep_e!r}")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
session_id = task.get("sessionId", "")
|
session_id = task.get("sessionId", "")
|
||||||
@@ -2063,7 +2288,12 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
if _tport and _tport != "22":
|
if _tport and _tport != "22":
|
||||||
ssh_base.extend(["-p", str(_tport)])
|
ssh_base.extend(["-p", str(_tport)])
|
||||||
check_cmd = ssh_base + [remote, "tmux", "has-session", "-t", session_id]
|
check_cmd = ssh_base + [remote, "tmux", "has-session", "-t", session_id]
|
||||||
capture_cmd = ssh_base + [remote, "tmux", "capture-pane", "-t", session_id, "-p", "-S", "-50"]
|
# Capture 500 lines (was 50) so a Python traceback survives
|
||||||
|
# the post-crash neofetch banner + bash prompt that otherwise
|
||||||
|
# fills the visible tail. Without this, output_tail ends up
|
||||||
|
# as just "Locale: C / Ubuntu_Odysseus ❯" and the agent
|
||||||
|
# can't diagnose the actual error.
|
||||||
|
capture_cmd = ssh_base + [remote, "tmux", "capture-pane", "-t", session_id, "-p", "-S", "-500"]
|
||||||
elif IS_WINDOWS:
|
elif IS_WINDOWS:
|
||||||
# LOCAL Windows task: launched as a detached process (no tmux).
|
# LOCAL Windows task: launched as a detached process (no tmux).
|
||||||
# Liveness comes from the <session>.pid file, output from the
|
# Liveness comes from the <session>.pid file, output from the
|
||||||
@@ -2072,7 +2302,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
capture_cmd = None
|
capture_cmd = None
|
||||||
else:
|
else:
|
||||||
check_cmd = ["tmux", "has-session", "-t", session_id]
|
check_cmd = ["tmux", "has-session", "-t", session_id]
|
||||||
capture_cmd = ["tmux", "capture-pane", "-t", session_id, "-p", "-S", "-50"]
|
capture_cmd = ["tmux", "capture-pane", "-t", session_id, "-p", "-S", "-500"]
|
||||||
|
|
||||||
local_win_task = (not remote) and IS_WINDOWS
|
local_win_task = (not remote) and IS_WINDOWS
|
||||||
|
|
||||||
|
|||||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Request, Form, HTTPException
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from core.database import SessionLocal, ModelEndpoint
|
from core.database import SessionLocal, ModelEndpoint
|
||||||
from core.middleware import require_admin
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
from src import copilot
|
from src import copilot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
|
||||||
# browser. Entries expire with the GitHub device code.
|
|
||||||
#
|
|
||||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
|
||||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
|
||||||
# on a worker that never saw the start, returning "Unknown or expired login
|
|
||||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
|
||||||
_PENDING: Dict[str, Dict] = {}
|
|
||||||
_PENDING_LOCK = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_expired() -> None:
|
|
||||||
now = time.time()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
|
||||||
_PENDING.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def setup_copilot_routes() -> APIRouter:
|
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
host = copilot.GITHUB_HOST
|
||||||
|
ent = str(form.get("enterprise_url") or "").strip()
|
||||||
|
if ent:
|
||||||
|
host = copilot.normalize_domain(ent)
|
||||||
|
try:
|
||||||
|
data = copilot.request_device_code(host)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
status = e.response.status_code if e.response is not None else "unknown"
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||||
|
|
||||||
@router.post("/device/start")
|
device_code = data.get("device_code")
|
||||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
if not device_code:
|
||||||
require_admin(request)
|
raise HTTPException(502, "GitHub did not return a device code")
|
||||||
_prune_expired()
|
|
||||||
host = copilot.GITHUB_HOST
|
|
||||||
ent = (enterprise_url or "").strip()
|
|
||||||
if ent:
|
|
||||||
host = copilot.normalize_domain(ent)
|
|
||||||
try:
|
|
||||||
data = copilot.request_device_code(host)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
status = e.response.status_code if e.response is not None else "unknown"
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
|
||||||
|
|
||||||
device_code = data.get("device_code")
|
# verification_uri_complete embeds the user code, so the browser tab we
|
||||||
if not device_code:
|
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||||
raise HTTPException(502, "GitHub did not return a device code")
|
# code pre-filled — one click, no manual code entry.
|
||||||
interval = int(data.get("interval") or 5)
|
return DeviceFlowStart(
|
||||||
expires_in = int(data.get("expires_in") or 900)
|
pending={
|
||||||
poll_id = uuid.uuid4().hex
|
"device_code": device_code,
|
||||||
with _PENDING_LOCK:
|
"host": host,
|
||||||
_PENDING[poll_id] = {
|
"enterprise_url": ent,
|
||||||
"device_code": device_code,
|
"owner": get_current_user(request) or None,
|
||||||
"host": host,
|
},
|
||||||
"enterprise_url": ent,
|
response={
|
||||||
"interval": interval,
|
|
||||||
"owner": get_current_user(request) or None,
|
|
||||||
"expires_at": time.time() + expires_in,
|
|
||||||
"next_poll_at": 0.0,
|
|
||||||
}
|
|
||||||
# verification_uri_complete embeds the user code, so the browser tab we
|
|
||||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
|
||||||
# code pre-filled — one click, no manual code entry.
|
|
||||||
return {
|
|
||||||
"poll_id": poll_id,
|
|
||||||
"user_code": data.get("user_code"),
|
"user_code": data.get("user_code"),
|
||||||
"verification_uri": data.get("verification_uri"),
|
"verification_uri": data.get("verification_uri"),
|
||||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||||
"interval": interval,
|
},
|
||||||
"expires_in": expires_in,
|
interval=int(data.get("interval") or 5),
|
||||||
}
|
expires_in=int(data.get("expires_in") or 900),
|
||||||
|
)
|
||||||
|
|
||||||
@router.post("/device/poll")
|
|
||||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
|
||||||
require_admin(request)
|
|
||||||
_prune_expired()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
pending = _PENDING.get(poll_id)
|
|
||||||
if not pending:
|
|
||||||
raise HTTPException(404, "Unknown or expired login session")
|
|
||||||
|
|
||||||
# Enforce GitHub's polling interval server-side so a chatty client
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||||
# can't trip slow_down.
|
try:
|
||||||
now = time.time()
|
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||||
if now < pending.get("next_poll_at", 0):
|
except Exception as e:
|
||||||
return {"status": "pending"}
|
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||||
|
|
||||||
|
token = data.get("access_token")
|
||||||
|
if token:
|
||||||
|
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||||
try:
|
try:
|
||||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
result = _provision_endpoint(token, base, pending["owner"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
logger.exception("Copilot endpoint provisioning failed")
|
||||||
|
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||||
|
return DeviceFlowPoll.authorized(result)
|
||||||
|
|
||||||
token = data.get("access_token")
|
err = data.get("error")
|
||||||
if token:
|
if err == "authorization_pending":
|
||||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
return DeviceFlowPoll.pending()
|
||||||
try:
|
if err == "slow_down":
|
||||||
result = _provision_endpoint(token, base, pending["owner"])
|
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||||
except Exception as e:
|
if err in ("expired_token", "access_denied"):
|
||||||
logger.exception("Copilot endpoint provisioning failed")
|
return DeviceFlowPoll.failed(err)
|
||||||
with _PENDING_LOCK:
|
# Unknown error — surface but keep the session for another try.
|
||||||
_PENDING.pop(poll_id, None)
|
return DeviceFlowPoll.pending(err or "unknown")
|
||||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "authorized", "endpoint": result}
|
|
||||||
|
|
||||||
err = data.get("error")
|
|
||||||
if err == "authorization_pending":
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err == "slow_down":
|
|
||||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["interval"] = new_interval
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err in ("expired_token", "access_denied"):
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "failed", "error": err}
|
|
||||||
# Unknown error — surface but keep the session for another try.
|
|
||||||
return {"status": "pending", "detail": err or "unknown"}
|
|
||||||
|
|
||||||
@router.post("/device/cancel")
|
def setup_copilot_routes():
|
||||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
return create_device_flow_router(
|
||||||
require_admin(request)
|
prefix="/api/copilot",
|
||||||
with _PENDING_LOCK:
|
tags=["copilot"],
|
||||||
_PENDING.pop(poll_id, None)
|
store=_DEVICE_FLOW_STORE,
|
||||||
return {"status": "cancelled"}
|
start_flow=_start_device_flow,
|
||||||
|
poll_flow=_poll_device_flow,
|
||||||
return router
|
)
|
||||||
|
|||||||
@@ -0,0 +1,193 @@
|
|||||||
|
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
|
|
||||||
|
from core.middleware import require_admin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DeviceFlowStart:
|
||||||
|
"""Provider-specific start result consumed by the shared route wrapper."""
|
||||||
|
|
||||||
|
pending: Mapping[str, Any]
|
||||||
|
response: Mapping[str, Any]
|
||||||
|
interval: int = 5
|
||||||
|
expires_in: int = 900
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DeviceFlowPoll:
|
||||||
|
"""Normalized provider poll outcome."""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
endpoint: Optional[Mapping[str, Any]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
detail: Optional[str] = None
|
||||||
|
interval: Optional[int] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="pending", detail=detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="slow_down", interval=interval, detail=detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="authorized", endpoint=endpoint)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def failed(cls, error: str) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="failed", error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDeviceFlowStore:
|
||||||
|
"""Thread-safe in-memory pending device-flow store.
|
||||||
|
|
||||||
|
Device codes and provider-side secrets stay inside this process. Each entry
|
||||||
|
stores provider payload separately from poll metadata so provider callbacks
|
||||||
|
only receive the fields they created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, time_func: Callable[[], float] = time.time):
|
||||||
|
self._pending: dict[str, dict[str, Any]] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._time = time_func
|
||||||
|
|
||||||
|
def _now(self) -> float:
|
||||||
|
return float(self._time())
|
||||||
|
|
||||||
|
def prune_expired(self) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
for key in [k for k, v in self._pending.items() if v.get("expires_at", 0) < now]:
|
||||||
|
self._pending.pop(key, None)
|
||||||
|
|
||||||
|
def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
|
||||||
|
self.prune_expired()
|
||||||
|
poll_id = uuid.uuid4().hex
|
||||||
|
with self._lock:
|
||||||
|
self._pending[poll_id] = {
|
||||||
|
"payload": dict(payload),
|
||||||
|
"interval": max(int(interval or 5), 1),
|
||||||
|
"expires_at": self._now() + max(int(expires_in or 900), 1),
|
||||||
|
"next_poll_at": 0.0,
|
||||||
|
}
|
||||||
|
return poll_id
|
||||||
|
|
||||||
|
def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
|
||||||
|
self.prune_expired()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
return dict(entry.get("payload") or {})
|
||||||
|
|
||||||
|
def is_throttled(self, poll_id: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
return bool(entry and self._now() < float(entry.get("next_poll_at") or 0))
|
||||||
|
|
||||||
|
def schedule_next(self, poll_id: str) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is not None:
|
||||||
|
entry["next_poll_at"] = now + int(entry.get("interval") or 5)
|
||||||
|
|
||||||
|
def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is not None:
|
||||||
|
new_interval = int(interval or (int(entry.get("interval") or 5) + 5))
|
||||||
|
entry["interval"] = max(new_interval, 1)
|
||||||
|
entry["next_poll_at"] = now + entry["interval"]
|
||||||
|
|
||||||
|
def pop(self, poll_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._pending.pop(poll_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_await(value: Any) -> Any:
|
||||||
|
if inspect.isawaitable(value):
|
||||||
|
return await value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
|
||||||
|
response: dict[str, Any] = {"status": "pending"}
|
||||||
|
if detail:
|
||||||
|
response["detail"] = detail
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_device_flow_router(
|
||||||
|
*,
|
||||||
|
prefix: str,
|
||||||
|
tags: Iterable[str],
|
||||||
|
store: PendingDeviceFlowStore,
|
||||||
|
start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
|
||||||
|
poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
|
||||||
|
) -> APIRouter:
|
||||||
|
"""Create standard `/device/start|poll|cancel` routes for a provider."""
|
||||||
|
|
||||||
|
router = APIRouter(prefix=prefix, tags=list(tags))
|
||||||
|
|
||||||
|
@router.post("/device/start")
|
||||||
|
async def device_start(request: Request):
|
||||||
|
require_admin(request)
|
||||||
|
form = await request.form()
|
||||||
|
start = await _maybe_await(start_flow(request, form))
|
||||||
|
interval = int(start.interval or 5)
|
||||||
|
expires_in = int(start.expires_in or 900)
|
||||||
|
poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
|
||||||
|
response = dict(start.response)
|
||||||
|
response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in})
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.post("/device/poll")
|
||||||
|
async def device_poll(request: Request, poll_id: str = Form(...)):
|
||||||
|
require_admin(request)
|
||||||
|
payload = store.get_payload(poll_id)
|
||||||
|
if payload is None:
|
||||||
|
raise HTTPException(404, "Unknown or expired login session")
|
||||||
|
if store.is_throttled(poll_id):
|
||||||
|
return {"status": "pending"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
outcome = await _maybe_await(poll_flow(request, payload))
|
||||||
|
except Exception:
|
||||||
|
store.pop(poll_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if outcome.status == "authorized":
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
|
||||||
|
if outcome.status == "failed":
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "failed", "error": outcome.error or "denied"}
|
||||||
|
if outcome.status == "slow_down":
|
||||||
|
store.slow_down(poll_id, outcome.interval)
|
||||||
|
return _pending_response(outcome.detail)
|
||||||
|
|
||||||
|
store.schedule_next(poll_id)
|
||||||
|
return _pending_response(outcome.detail)
|
||||||
|
|
||||||
|
@router.post("/device/cancel")
|
||||||
|
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||||
|
require_admin(request)
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "cancelled"}
|
||||||
|
|
||||||
|
return router
|
||||||
+72
-40
@@ -7,14 +7,24 @@ from typing import Dict, Any, List, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import case, func, or_
|
||||||
from core.database import SessionLocal, Document, DocumentVersion
|
from core.database import SessionLocal, Document, DocumentVersion
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import MAIL_ATTACHMENTS_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_or_404(db, session_id: str, user: Optional[str]):
|
||||||
|
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||||
|
if not session:
|
||||||
|
raise HTTPException(404, "Session not found")
|
||||||
|
if user and session.owner != user:
|
||||||
|
raise HTTPException(404, "Session not found")
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
def _aggregate_language_facets(lang_rows):
|
def _aggregate_language_facets(lang_rows):
|
||||||
"""Sum document counts per display language for the library facet.
|
"""Sum document counts per display language for the library facet.
|
||||||
|
|
||||||
@@ -30,6 +40,19 @@ def _aggregate_language_facets(lang_rows):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _library_language_for_document(doc: Document) -> str:
|
||||||
|
"""Return the display language used by the document library.
|
||||||
|
|
||||||
|
PDF documents are stored as markdown wrappers so the editor can preserve
|
||||||
|
extracted text, form fields, and annotations. The library should still
|
||||||
|
identify them as PDFs instead of exposing that internal wrapper format.
|
||||||
|
"""
|
||||||
|
from src.pdf_form_doc import find_source_upload_id
|
||||||
|
|
||||||
|
if find_source_upload_id(doc.current_content or ""):
|
||||||
|
return "pdf"
|
||||||
|
return doc.language or "text"
|
||||||
|
|
||||||
|
|
||||||
from routes.document_helpers import (
|
from routes.document_helpers import (
|
||||||
DocumentCreate, DocumentUpdate, DocumentPatch,
|
DocumentCreate, DocumentUpdate, DocumentPatch,
|
||||||
@@ -69,17 +92,12 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# the doc is owner-stamped, so it lives in the library on its own.
|
# the doc is owner-stamped, so it lives in the library on its own.
|
||||||
session = None
|
session = None
|
||||||
if req.session_id:
|
if req.session_id:
|
||||||
session = db.query(DbSession).filter(DbSession.id == req.session_id).first()
|
|
||||||
if not session:
|
|
||||||
raise HTTPException(404, "Session not found")
|
|
||||||
# Match the lenient ownership model the rest of the app uses
|
# Match the lenient ownership model the rest of the app uses
|
||||||
# (see _owner_filter): only block when an AUTHENTICATED user is
|
# (see _owner_filter): only block when an AUTHENTICATED user is
|
||||||
# writing into a DIFFERENT user's session. In single-user /
|
# writing into a DIFFERENT user's session. In single-user /
|
||||||
# unconfigured / localhost-bypass mode the middleware leaves
|
# unconfigured / localhost-bypass mode, falsey users preserve
|
||||||
# current_user unset (None), and those sessions are already
|
# the existing lenient path.
|
||||||
# served freely everywhere else.
|
session = _get_session_or_404(db, req.session_id, user)
|
||||||
if user and session.owner and session.owner != user:
|
|
||||||
raise HTTPException(403, "Cannot create document in another user's session")
|
|
||||||
|
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
ver_id = str(uuid.uuid4())
|
ver_id = str(uuid.uuid4())
|
||||||
@@ -171,11 +189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
if session_id:
|
if session_id:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
_get_session_or_404(db, session_id, user)
|
||||||
if not sess:
|
|
||||||
raise HTTPException(404, "Session not found")
|
|
||||||
if user and sess.owner and sess.owner != user:
|
|
||||||
raise HTTPException(403, "Cannot import into another user's session")
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -198,7 +212,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
|
|
||||||
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
||||||
try:
|
try:
|
||||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||||
except Exception:
|
except Exception:
|
||||||
body_text = None
|
body_text = None
|
||||||
|
|
||||||
@@ -260,18 +274,29 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
pdf_marker_cond = or_(
|
||||||
|
Document.current_content.like('%<!-- pdf_source upload_id="%'),
|
||||||
|
Document.current_content.like('%<!-- pdf_form_source upload_id="%'),
|
||||||
|
)
|
||||||
|
library_language_expr = case(
|
||||||
|
(pdf_marker_cond, "pdf"),
|
||||||
|
(Document.language.is_(None), "text"),
|
||||||
|
else_=Document.language,
|
||||||
|
)
|
||||||
# Archived view shows ONLY archived docs; the default view excludes
|
# Archived view shows ONLY archived docs; the default view excludes
|
||||||
# them (NULL = legacy rows that predate the column = not archived).
|
# them (NULL = legacy rows that predate the column = not archived).
|
||||||
_arch_cond = (Document.archived == True) if archived else or_(
|
_arch_cond = (Document.archived == True) if archived else or_(
|
||||||
Document.archived == False, Document.archived.is_(None))
|
Document.archived == False, Document.archived.is_(None))
|
||||||
# Language facet counts (owner-filtered)
|
# Language facet counts (owner-filtered). PDF documents are stored
|
||||||
|
# as markdown wrappers, so group by the library display language
|
||||||
|
# instead of the raw stored language.
|
||||||
lang_q = (
|
lang_q = (
|
||||||
db.query(Document.language, func.count(Document.id))
|
db.query(library_language_expr, func.count(Document.id))
|
||||||
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
||||||
.filter(Document.is_active == True).filter(_arch_cond)
|
.filter(Document.is_active == True).filter(_arch_cond)
|
||||||
)
|
)
|
||||||
lang_q = _owner_session_filter(lang_q, user)
|
lang_q = _owner_session_filter(lang_q, user)
|
||||||
lang_rows = lang_q.group_by(Document.language).all()
|
lang_rows = lang_q.group_by(library_language_expr).all()
|
||||||
languages = _aggregate_language_facets(lang_rows)
|
languages = _aggregate_language_facets(lang_rows)
|
||||||
|
|
||||||
# Session count (owner-filtered)
|
# Session count (owner-filtered)
|
||||||
@@ -303,12 +328,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
Document.title.ilike(term) | Document.current_content.ilike(term)
|
Document.title.ilike(term) | Document.current_content.ilike(term)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Language filter
|
# Language filter. "pdf" is a display language derived from the
|
||||||
|
# source marker; "markdown" excludes those wrappers.
|
||||||
if language:
|
if language:
|
||||||
if language == "text":
|
if language == "text":
|
||||||
q = q.filter((Document.language == None) | (Document.language == "text"))
|
q = q.filter((Document.language == None) | (Document.language == "text"))
|
||||||
|
elif language == "pdf":
|
||||||
|
q = q.filter(pdf_marker_cond)
|
||||||
else:
|
else:
|
||||||
q = q.filter(Document.language == language)
|
q = q.filter(Document.language == language)
|
||||||
|
if language == "markdown":
|
||||||
|
q = q.filter(~pdf_marker_cond)
|
||||||
|
|
||||||
# Total before pagination
|
# Total before pagination
|
||||||
total = q.count()
|
total = q.count()
|
||||||
@@ -332,7 +362,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
"session_id": doc.session_id,
|
"session_id": doc.session_id,
|
||||||
"session_name": session_name,
|
"session_name": session_name,
|
||||||
"title": doc.title,
|
"title": doc.title,
|
||||||
"language": doc.language or "text",
|
"language": _library_language_for_document(doc),
|
||||||
"preview": (doc.current_content or "")[:500],
|
"preview": (doc.current_content or "")[:500],
|
||||||
"version_count": doc.version_count,
|
"version_count": doc.version_count,
|
||||||
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
||||||
@@ -359,18 +389,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(403, "Authentication required")
|
raise HTTPException(403, "Authentication required")
|
||||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
|
||||||
# v2 review HIGH-9: raise 403 explicitly when the caller
|
# v2 review HIGH-9: raise 403 explicitly when the caller
|
||||||
# can't see this session, instead of returning [] which the
|
# can't see this session, instead of returning [] which the
|
||||||
# UI treats identically to "no docs" and silently masks
|
# UI treats identically to "no docs" and silently masks
|
||||||
# auth failures.
|
# auth failures.
|
||||||
if not session:
|
_get_session_or_404(db, session_id, user)
|
||||||
raise HTTPException(404, "Session not found")
|
q = db.query(Document).filter(
|
||||||
if user and session.owner and session.owner != user:
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
docs = db.query(Document).filter(
|
|
||||||
Document.session_id == session_id
|
Document.session_id == session_id
|
||||||
).order_by(Document.created_at.desc()).all()
|
)
|
||||||
|
if user:
|
||||||
|
q = q.filter(or_(Document.owner == user, Document.owner.is_(None)))
|
||||||
|
docs = q.order_by(Document.created_at.desc()).all()
|
||||||
return [_doc_to_dict(d) for d in docs]
|
return [_doc_to_dict(d) for d in docs]
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -437,7 +466,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
raise HTTPException(404, "Source PDF could not be located")
|
raise HTTPException(404, "Source PDF could not be located")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
||||||
raise HTTPException(500, f"Extraction failed: {e}")
|
raise HTTPException(500, f"Extraction failed: {e}")
|
||||||
@@ -606,6 +635,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
doc.language = req.language
|
doc.language = req.language
|
||||||
if req.session_id is not None:
|
if req.session_id is not None:
|
||||||
# Empty string = unlink from session
|
# Empty string = unlink from session
|
||||||
|
if req.session_id:
|
||||||
|
_get_session_or_404(db, req.session_id, user)
|
||||||
doc.session_id = req.session_id if req.session_id else None
|
doc.session_id = req.session_id if req.session_id else None
|
||||||
if not req.session_id:
|
if not req.session_id:
|
||||||
# Tab closed / doc detached from its session — drop the
|
# Tab closed / doc detached from its session — drop the
|
||||||
@@ -663,8 +694,9 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
# Verify ownership before listing versions
|
# Verify ownership before listing versions
|
||||||
doc = db.query(Document).filter(Document.id == doc_id).first()
|
doc = db.query(Document).filter(Document.id == doc_id).first()
|
||||||
if doc:
|
if not doc:
|
||||||
_verify_doc_owner(db, doc, user)
|
raise HTTPException(404, "Document not found")
|
||||||
|
_verify_doc_owner(db, doc, user)
|
||||||
versions = db.query(DocumentVersion).filter(
|
versions = db.query(DocumentVersion).filter(
|
||||||
DocumentVersion.document_id == doc_id
|
DocumentVersion.document_id == doc_id
|
||||||
).order_by(DocumentVersion.version_number.desc()).all()
|
).order_by(DocumentVersion.version_number.desc()).all()
|
||||||
@@ -687,8 +719,9 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
# Verify ownership
|
# Verify ownership
|
||||||
doc = db.query(Document).filter(Document.id == doc_id).first()
|
doc = db.query(Document).filter(Document.id == doc_id).first()
|
||||||
if doc:
|
if not doc:
|
||||||
_verify_doc_owner(db, doc, user)
|
raise HTTPException(404, "Document not found")
|
||||||
|
_verify_doc_owner(db, doc, user)
|
||||||
ver = db.query(DocumentVersion).filter(
|
ver = db.query(DocumentVersion).filter(
|
||||||
DocumentVersion.document_id == doc_id,
|
DocumentVersion.document_id == doc_id,
|
||||||
DocumentVersion.version_number == num,
|
DocumentVersion.version_number == num,
|
||||||
@@ -853,10 +886,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
url, model, headers = resolve_task_endpoint()
|
url, model, headers = resolve_task_endpoint(owner=user or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
# Fall back to default endpoint
|
# Fall back to default endpoint
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
raise HTTPException(500, "No endpoint configured for AI tidy")
|
raise HTTPException(500, "No endpoint configured for AI tidy")
|
||||||
|
|
||||||
@@ -1156,7 +1189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
settings = _load_vl_settings()
|
settings = _load_vl_settings()
|
||||||
vl_model = settings.get("vision_model", "")
|
vl_model = settings.get("vision_model", "")
|
||||||
try:
|
try:
|
||||||
url, model_id, headers = _resolve_vl_model(vl_model)
|
url, model_id, headers = _resolve_vl_model(vl_model, owner=user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(503, f"No vision model available: {e}")
|
raise HTTPException(503, f"No vision model available: {e}")
|
||||||
|
|
||||||
@@ -1510,10 +1543,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# don't import from a routes file (cycle-prone). Same env override
|
# don't import from a routes file (cycle-prone). Same env override
|
||||||
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
import os as _os
|
_COMPOSE_DIR = _Path(MAIL_ATTACHMENTS_DIR) / "_compose"
|
||||||
_DATA_DIR = _Path(__file__).resolve().parent.parent / "data"
|
|
||||||
_BASE = _os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(_DATA_DIR / "mail-attachments"))
|
|
||||||
_COMPOSE_DIR = _Path(_BASE) / "_compose"
|
|
||||||
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
@@ -1629,9 +1659,11 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# context (To/Subject/In-Reply-To/References).
|
# context (To/Subject/In-Reply-To/References).
|
||||||
try:
|
try:
|
||||||
from routes.email_routes import _imap, _decode_header
|
from routes.email_routes import _imap, _decode_header
|
||||||
|
from routes.email_helpers import _q
|
||||||
except Exception:
|
except Exception:
|
||||||
_imap = None
|
_imap = None
|
||||||
_decode_header = lambda x: x or ""
|
_decode_header = lambda x: x or ""
|
||||||
|
_q = lambda x: x or ""
|
||||||
|
|
||||||
to_addr = ""
|
to_addr = ""
|
||||||
from_name = ""
|
from_name = ""
|
||||||
@@ -1641,7 +1673,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
if _imap:
|
if _imap:
|
||||||
try:
|
try:
|
||||||
with _imap(doc.source_email_account_id or None) as conn:
|
with _imap(doc.source_email_account_id or None) as conn:
|
||||||
conn.select(doc.source_email_folder, readonly=True)
|
conn.select(_q(doc.source_email_folder), readonly=True)
|
||||||
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
||||||
if status == "OK" and data and data[0]:
|
if status == "OK" and data and data[0]:
|
||||||
raw_hdr = data[0][1]
|
raw_hdr = data[0][1]
|
||||||
|
|||||||
+98
-33
@@ -71,6 +71,38 @@ def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message
|
|||||||
smtp.sendmail(from_addr, recipients, message)
|
smtp.sendmail(from_addr, recipients, message)
|
||||||
|
|
||||||
|
|
||||||
|
def _friendly_email_auth_error(protocol: str, host: str, error: object) -> str:
|
||||||
|
"""Return a clearer setup error for known provider auth policies."""
|
||||||
|
raw = str(error or "")
|
||||||
|
lower = raw.lower()
|
||||||
|
host_lower = (host or "").lower()
|
||||||
|
microsoft_host = any(
|
||||||
|
marker in host_lower
|
||||||
|
for marker in (
|
||||||
|
"outlook.office365.com",
|
||||||
|
"smtp.office365.com",
|
||||||
|
"office365.com",
|
||||||
|
"outlook.com",
|
||||||
|
"hotmail.com",
|
||||||
|
"live.com",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
microsoft_basic_auth_failure = (
|
||||||
|
"5.7.139" in lower
|
||||||
|
or "basic authentication is disabled" in lower
|
||||||
|
or ("authenticate failed" in lower and microsoft_host)
|
||||||
|
or ("authentication unsuccessful" in lower and microsoft_host)
|
||||||
|
)
|
||||||
|
if microsoft_basic_auth_failure:
|
||||||
|
return (
|
||||||
|
"Microsoft no longer accepts normal mailbox passwords for "
|
||||||
|
"Outlook/Office 365 IMAP/SMTP in most accounts. Odysseus "
|
||||||
|
"does not support Microsoft OAuth/Graph mail yet, so Outlook "
|
||||||
|
"accounts cannot be added with this password form."
|
||||||
|
)
|
||||||
|
return raw[:200]
|
||||||
|
|
||||||
|
|
||||||
def _strip_think(text: str) -> str:
|
def _strip_think(text: str) -> str:
|
||||||
"""Email-flavored think strip — thin wrapper over the central helper.
|
"""Email-flavored think strip — thin wrapper over the central helper.
|
||||||
|
|
||||||
@@ -254,16 +286,17 @@ def _cleanup_compose_uploads(tokens) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, MAIL_ATTACHMENTS_DIR, SETTINGS_FILE as _SETTINGS_FILE, SCHEDULED_EMAILS_DB
|
||||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
|
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||||
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
||||||
# subdir of the install's data/ tree so the app works out-of-the-box without
|
# subdir of the install's data/ tree so the app works out-of-the-box without
|
||||||
# a hardcoded /home/<user>/ path.
|
# a hardcoded /home/<user>/ path.
|
||||||
ATTACHMENTS_DIR = Path(os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(DATA_DIR / "mail-attachments")))
|
ATTACHMENTS_DIR = Path(MAIL_ATTACHMENTS_DIR)
|
||||||
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
||||||
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
|
SCHEDULED_DB = Path(SCHEDULED_EMAILS_DB)
|
||||||
|
|
||||||
|
|
||||||
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||||
@@ -705,7 +738,16 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
|||||||
port = int(port or 993)
|
port = int(port or 993)
|
||||||
if starttls:
|
if starttls:
|
||||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket if the STARTTLS upgrade is
|
||||||
|
# rejected; close it before propagating. (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
elif port == 993:
|
elif port == 993:
|
||||||
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
||||||
else:
|
else:
|
||||||
@@ -714,6 +756,10 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
|||||||
conn.sock.settimeout(timeout)
|
conn.sock.settimeout(timeout)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# Raise the IMAP line-length limit from the default 1 MB to 50 MB so that
|
||||||
|
# large mailboxes (tens of thousands of messages) don't crash with
|
||||||
|
# "got more than 1000000 bytes" on UID SEARCH ALL. (#2883)
|
||||||
|
imaplib._MAXLINE = 50_000_000
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||||
@@ -734,7 +780,18 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
|||||||
starttls=bool(cfg.get("imap_starttls")),
|
starttls=bool(cfg.get("imap_starttls")),
|
||||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
try:
|
||||||
|
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed AUTHENTICATE (e.g. an Office 365 app password on an
|
||||||
|
# MFA-enabled tenant, #3174) otherwise orphans the already-connected
|
||||||
|
# socket; close it before propagating so a misconfigured account
|
||||||
|
# can't leak one descriptor per retry / background poller pass.
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -798,20 +855,28 @@ def _imap(account_id: str | None = None, owner: str = ""):
|
|||||||
def _decode_header(raw):
|
def _decode_header(raw):
|
||||||
if not raw:
|
if not raw:
|
||||||
return ""
|
return ""
|
||||||
parts = email.header.decode_header(raw)
|
try:
|
||||||
decoded = []
|
# make_header concatenates per RFC 2047: no spurious space between an
|
||||||
for data, charset in parts:
|
# encoded-word and adjacent plain text (plain runs keep their own
|
||||||
if isinstance(data, bytes):
|
# whitespace), and the whitespace between two adjacent encoded-words is
|
||||||
try:
|
# dropped. The old " ".join produced "Re: Jose"-style double spaces on
|
||||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
# every non-ASCII subject or sender.
|
||||||
except (LookupError, ValueError):
|
return str(email.header.make_header(email.header.decode_header(raw)))
|
||||||
# Unknown/invalid MIME charset (e.g. a malformed or spam header
|
except Exception:
|
||||||
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
|
# Malformed header or unknown/invalid MIME charset (e.g. a spam header
|
||||||
# byte-decode errors, not codec lookup, so fall back to utf-8.
|
# like =?x-unknown-charset?B?...?=) makes make_header raise LookupError;
|
||||||
decoded.append(data.decode("utf-8", errors="replace"))
|
# fall back to a lossy per-part decode. errors="replace" only covers
|
||||||
else:
|
# byte-decode errors, not codec lookup, hence the explicit utf-8 retry.
|
||||||
decoded.append(data)
|
decoded = []
|
||||||
return " ".join(decoded)
|
for data, charset in email.header.decode_header(raw):
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
try:
|
||||||
|
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||||
|
except (LookupError, ValueError):
|
||||||
|
decoded.append(data.decode("utf-8", errors="replace"))
|
||||||
|
else:
|
||||||
|
decoded.append(data)
|
||||||
|
return "".join(decoded)
|
||||||
|
|
||||||
|
|
||||||
def _detect_sent_folder(conn):
|
def _detect_sent_folder(conn):
|
||||||
@@ -1136,13 +1201,9 @@ def _fetch_sender_thread_context(sender_addr: str,
|
|||||||
if exclude_uid:
|
if exclude_uid:
|
||||||
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
||||||
|
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = _imap_connect(account_id, owner=owner)
|
conn = _imap_connect(account_id, owner=owner)
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"sender-thread-context: imap connect failed: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||||
if len(blocks) >= limit:
|
if len(blocks) >= limit:
|
||||||
break
|
break
|
||||||
@@ -1209,11 +1270,14 @@ def _fetch_sender_thread_context(sender_addr: str,
|
|||||||
if atts_text:
|
if atts_text:
|
||||||
lines.append(atts_text)
|
lines.append(atts_text)
|
||||||
blocks.append("\n".join(lines))
|
blocks.append("\n".join(lines))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"sender-thread-context: imap failed: {e}")
|
||||||
finally:
|
finally:
|
||||||
try: conn.close()
|
if conn:
|
||||||
except Exception: pass
|
try: conn.close()
|
||||||
try: conn.logout()
|
except Exception: pass
|
||||||
except Exception: pass
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return ""
|
return ""
|
||||||
@@ -1316,6 +1380,7 @@ def _pre_retrieve_context(
|
|||||||
if not terms_list:
|
if not terms_list:
|
||||||
return context_snippets, terms_list
|
return context_snippets, terms_list
|
||||||
|
|
||||||
|
ctx_conn = None
|
||||||
try:
|
try:
|
||||||
ctx_conn = _imap_connect(account_id, owner=owner)
|
ctx_conn = _imap_connect(account_id, owner=owner)
|
||||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||||
@@ -1352,12 +1417,12 @@ def _pre_retrieve_context(
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
||||||
continue
|
continue
|
||||||
try:
|
|
||||||
ctx_conn.logout()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.warning(f"IMAP context search failed: {_e}")
|
logger.warning(f"IMAP context search failed: {_e}")
|
||||||
|
finally:
|
||||||
|
if ctx_conn:
|
||||||
|
try: ctx_conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from routes.contacts_routes import _fetch_contacts
|
from routes.contacts_routes import _fetch_contacts
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
|||||||
if auto_cal:
|
if auto_cal:
|
||||||
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
||||||
try:
|
try:
|
||||||
st, _ = conn.select(sent_name, readonly=True)
|
st, _ = conn.select(_q(sent_name), readonly=True)
|
||||||
if st == "OK":
|
if st == "OK":
|
||||||
folders_to_scan.append(sent_name)
|
folders_to_scan.append(sent_name)
|
||||||
break
|
break
|
||||||
@@ -1046,7 +1046,7 @@ def _scheduled_poll_once() -> dict:
|
|||||||
try:
|
try:
|
||||||
with _imap(row_account_id, owner=row_owner) as imap:
|
with _imap(row_account_id, owner=row_owner) as imap:
|
||||||
sent_folder = _detect_sent_folder(imap)
|
sent_folder = _detect_sent_folder(imap)
|
||||||
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
|
imap.append(_q(sent_folder), "\\Seen", None, outer.as_bytes())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ from email.mime.multipart import MIMEMultipart
|
|||||||
|
|
||||||
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
from src.constants import DATA_DIR
|
||||||
|
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, EMAIL_COMPOSE_UPLOAD_MAX_BYTES
|
||||||
|
|
||||||
from routes.email_helpers import (
|
from routes.email_helpers import (
|
||||||
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
||||||
@@ -47,6 +48,7 @@ from routes.email_helpers import (
|
|||||||
_extract_attachment_to_disk, _extract_html, _extract_text,
|
_extract_attachment_to_disk, _extract_html, _extract_text,
|
||||||
_fetch_sender_thread_context, _pre_retrieve_context,
|
_fetch_sender_thread_context, _pre_retrieve_context,
|
||||||
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
||||||
|
_friendly_email_auth_error,
|
||||||
SendEmailRequest, ExtractStyleRequest,
|
SendEmailRequest, ExtractStyleRequest,
|
||||||
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
||||||
attachment_extract_dir, _email_cache_owner_clause,
|
attachment_extract_dir, _email_cache_owner_clause,
|
||||||
@@ -56,7 +58,6 @@ from routes.email_pollers import _start_poller
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
||||||
EMAIL_COMPOSE_UPLOAD_MAX_BYTES = 25 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
||||||
@@ -2904,7 +2905,7 @@ def setup_email_routes():
|
|||||||
from pathlib import Path as _P
|
from pathlib import Path as _P
|
||||||
import json as _json
|
import json as _json
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
path = _P(f"data/email_urgency_state_{_slug}.json")
|
path = _P(DATA_DIR) / f"email_urgency_state_{_slug}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
||||||
try:
|
try:
|
||||||
@@ -3162,7 +3163,7 @@ def setup_email_routes():
|
|||||||
try: conn.logout()
|
try: conn.logout()
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
imap_result = {"ok": False, "error": str(e)[:200]}
|
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
|
||||||
|
|
||||||
smtp_host = (body.get("smtp_host") or "").strip()
|
smtp_host = (body.get("smtp_host") or "").strip()
|
||||||
if smtp_host:
|
if smtp_host:
|
||||||
@@ -3184,7 +3185,7 @@ def setup_email_routes():
|
|||||||
try: smtp.quit()
|
try: smtp.quit()
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
smtp_result = {"ok": False, "error": str(e)[:200]}
|
smtp_result = {"ok": False, "error": _friendly_email_auth_error("SMTP", smtp_host, e)}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
||||||
|
|||||||
+65
-22
@@ -7,12 +7,12 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi import APIRouter, HTTPException, Form, Depends
|
from fastapi import APIRouter, HTTPException, Form, Depends
|
||||||
from core.constants import BASE_DIR
|
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
|
_ENDPOINT_FILE = EMBEDDING_ENDPOINT_FILE
|
||||||
|
|
||||||
# Track in-progress downloads
|
# Track in-progress downloads
|
||||||
_downloading: dict = {}
|
_downloading: dict = {}
|
||||||
@@ -35,13 +35,7 @@ def _cache_dir() -> str:
|
|||||||
default lived in /tmp, which many systems wipe on reboot — forcing a
|
default lived in /tmp, which many systems wipe on reboot — forcing a
|
||||||
full re-download of the embedding model after every restart.
|
full re-download of the embedding model after every restart.
|
||||||
"""
|
"""
|
||||||
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
return FASTEMBED_CACHE_DIR
|
||||||
if env:
|
|
||||||
return env
|
|
||||||
return os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
|
||||||
"data", "fastembed_cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _model_cache_name(hf_source: str) -> str:
|
def _model_cache_name(hf_source: str) -> str:
|
||||||
@@ -49,19 +43,35 @@ def _model_cache_name(hf_source: str) -> str:
|
|||||||
return "models--" + hf_source.replace("/", "--")
|
return "models--" + hf_source.replace("/", "--")
|
||||||
|
|
||||||
|
|
||||||
|
def _model_cache_path(hf_source: str) -> Path:
|
||||||
|
"""Return a confined cache path for a fastembed HF source."""
|
||||||
|
root = Path(_cache_dir()).expanduser().resolve()
|
||||||
|
raw_path = root / _model_cache_name(hf_source)
|
||||||
|
if raw_path.is_symlink():
|
||||||
|
raise ValueError("Model cache path must not be a symlink")
|
||||||
|
path = raw_path.resolve(strict=False)
|
||||||
|
try:
|
||||||
|
path.relative_to(root)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Model cache path escapes cache root")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
def _is_downloaded(hf_source: str) -> bool:
|
def _is_downloaded(hf_source: str) -> bool:
|
||||||
"""Check if a model is already cached."""
|
"""Check if a model is already cached."""
|
||||||
cache = _cache_dir()
|
try:
|
||||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
model_dir = _model_cache_path(hf_source)
|
||||||
if not os.path.isdir(model_dir):
|
except ValueError:
|
||||||
|
return False
|
||||||
|
if not model_dir.is_dir():
|
||||||
return False
|
return False
|
||||||
# Check for actual model files (not just empty dir)
|
# Check for actual model files (not just empty dir)
|
||||||
snapshots = os.path.join(model_dir, "snapshots")
|
snapshots = model_dir / "snapshots"
|
||||||
if os.path.isdir(snapshots):
|
if snapshots.is_dir():
|
||||||
return any(os.listdir(snapshots))
|
return any(snapshots.iterdir())
|
||||||
# Also check for blobs (older cache format)
|
# Also check for blobs (older cache format)
|
||||||
blobs = os.path.join(model_dir, "blobs")
|
blobs = model_dir / "blobs"
|
||||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
return blobs.is_dir() and any(blobs.iterdir())
|
||||||
|
|
||||||
|
|
||||||
def _active_model() -> str:
|
def _active_model() -> str:
|
||||||
@@ -119,8 +129,10 @@ def setup_embedding_routes():
|
|||||||
|
|
||||||
cached_size = None
|
cached_size = None
|
||||||
if downloaded and hf_src:
|
if downloaded and hf_src:
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
cached_size = _dir_size_mb(model_path)
|
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
|
||||||
|
except ValueError:
|
||||||
|
cached_size = None
|
||||||
|
|
||||||
result.append({
|
result.append({
|
||||||
"model": m["model"],
|
"model": m["model"],
|
||||||
@@ -217,8 +229,11 @@ def setup_embedding_routes():
|
|||||||
if not hf_src:
|
if not hf_src:
|
||||||
raise HTTPException(400, "No cache source for this model")
|
raise HTTPException(400, "No cache source for this model")
|
||||||
|
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
if not os.path.isdir(model_path):
|
model_path = _model_cache_path(hf_src)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
if not model_path.is_dir():
|
||||||
return {"deleted": False, "message": "Model not cached"}
|
return {"deleted": False, "message": "Model not cached"}
|
||||||
|
|
||||||
shutil.rmtree(model_path)
|
shutil.rmtree(model_path)
|
||||||
@@ -237,7 +252,7 @@ def setup_embedding_routes():
|
|||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/endpoint")
|
@router.post("/endpoint")
|
||||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||||
"""Save a custom embedding endpoint URL."""
|
"""Save a custom embedding endpoint URL."""
|
||||||
url = url.strip()
|
url = url.strip()
|
||||||
if not url:
|
if not url:
|
||||||
@@ -261,6 +276,7 @@ def setup_embedding_routes():
|
|||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
url,
|
url,
|
||||||
json={"input": ["test"], "model": model or "test"},
|
json={"input": ["test"], "model": model or "test"},
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
@@ -271,10 +287,16 @@ def setup_embedding_routes():
|
|||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
if model:
|
if model:
|
||||||
data["model"] = model
|
data["model"] = model
|
||||||
|
if api_key:
|
||||||
|
from src.secret_storage import encrypt
|
||||||
|
data["api_key"] = encrypt(api_key)
|
||||||
|
|
||||||
_save_custom_endpoint(data)
|
_save_custom_endpoint(data)
|
||||||
os.environ["EMBEDDING_URL"] = url
|
os.environ["EMBEDDING_URL"] = url
|
||||||
if model:
|
if model:
|
||||||
os.environ["EMBEDDING_MODEL"] = model
|
os.environ["EMBEDDING_MODEL"] = model
|
||||||
|
if api_key:
|
||||||
|
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||||
|
|
||||||
# Reset the RAG singleton so it picks up the new endpoint
|
# Reset the RAG singleton so it picks up the new endpoint
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
@@ -288,6 +310,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||||
try:
|
try:
|
||||||
@@ -308,6 +340,7 @@ def setup_embedding_routes():
|
|||||||
# Remove from environment
|
# Remove from environment
|
||||||
os.environ.pop("EMBEDDING_URL", None)
|
os.environ.pop("EMBEDDING_URL", None)
|
||||||
os.environ.pop("EMBEDDING_MODEL", None)
|
os.environ.pop("EMBEDDING_MODEL", None)
|
||||||
|
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||||
|
|
||||||
# Reset the RAG singleton so it falls back to fastembed
|
# Reset the RAG singleton so it falls back to fastembed
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
@@ -318,6 +351,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client
|
# Reset ChromaDB client
|
||||||
try:
|
try:
|
||||||
|
|||||||
+45
-6
@@ -16,22 +16,54 @@ from pathlib import Path
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from src.constants import EMOJI_CACHE_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
|
_CACHE_DIR = Path(EMOJI_CACHE_DIR)
|
||||||
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
||||||
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
||||||
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
||||||
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
||||||
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
||||||
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
|
_MAX_SVG_BYTES = 256 * 1024
|
||||||
|
_BLOCKED_SVG_RE = re.compile(
|
||||||
|
br"<\s*(?:script|foreignObject|iframe|object|embed|image)\b|"
|
||||||
|
br"\bon[a-z0-9_-]+\s*=",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
_EXTERNAL_REF_RE = re.compile(
|
||||||
|
br"\b(?:href|xlink:href)\s*=\s*['\"](?:https?:|//|data:|javascript:)",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
_SVG_SECURITY_HEADERS = {
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Content-Security-Policy": "sandbox",
|
||||||
|
"Cross-Origin-Resource-Policy": "same-origin",
|
||||||
|
}
|
||||||
|
_SVG_HEADERS = {
|
||||||
|
"Cache-Control": "public, max-age=31536000, immutable",
|
||||||
|
**_SVG_SECURITY_HEADERS,
|
||||||
|
}
|
||||||
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
||||||
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
||||||
# request can still pick up the real glyph once the CDN is reachable.
|
# request can still pick up the real glyph once the CDN is reachable.
|
||||||
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
||||||
_BLANK_HEADERS = {"Cache-Control": "no-store"}
|
_BLANK_HEADERS = {"Cache-Control": "no-store", **_SVG_SECURITY_HEADERS}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_svg(content: bytes) -> bool:
|
||||||
|
if not isinstance(content, bytes) or not content:
|
||||||
|
return False
|
||||||
|
if len(content) > _MAX_SVG_BYTES:
|
||||||
|
return False
|
||||||
|
if b"<svg" not in content[:256].lower():
|
||||||
|
return False
|
||||||
|
if _BLOCKED_SVG_RE.search(content) or _EXTERNAL_REF_RE.search(content):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def setup_emoji_routes() -> APIRouter:
|
def setup_emoji_routes() -> APIRouter:
|
||||||
@@ -49,14 +81,21 @@ def setup_emoji_routes() -> APIRouter:
|
|||||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
fp = _CACHE_DIR / f"{code}.svg"
|
fp = _CACHE_DIR / f"{code}.svg"
|
||||||
if fp.exists():
|
if fp.exists():
|
||||||
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
try:
|
||||||
|
content = fp.read_bytes()
|
||||||
|
if _is_safe_svg(content):
|
||||||
|
return Response(content, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||||
|
fp.unlink(missing_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("emoji cache read %s failed: %s", code, e)
|
||||||
|
return _blank()
|
||||||
|
|
||||||
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
||||||
# it. OpenMoji filenames are the codepoints uppercased.
|
# it. OpenMoji filenames are the codepoints uppercased.
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||||
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
||||||
if r.status_code == 200 and b"<svg" in r.content[:256]:
|
if r.status_code == 200 and _is_safe_svg(r.content):
|
||||||
try:
|
try:
|
||||||
fp.write_bytes(r.content)
|
fp.write_bytes(r.content)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
+145
-54
@@ -12,8 +12,13 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
|||||||
|
|
||||||
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
from src.auth_helpers import get_current_user, require_privilege
|
from src.auth_helpers import get_current_user, owner_filter, require_privilege
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import (
|
||||||
|
read_upload_limited,
|
||||||
|
GALLERY_UPLOAD_MAX_BYTES,
|
||||||
|
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
||||||
|
)
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
from routes.gallery_helpers import (
|
from routes.gallery_helpers import (
|
||||||
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
||||||
@@ -21,17 +26,88 @@ from routes.gallery_helpers import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", str(100 * 1024 * 1024)))
|
|
||||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024)))
|
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||||
|
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||||
|
if not callable(is_admin):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(is_admin(user))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_gallery_filename(filename: str) -> str:
|
def _sanitize_gallery_filename(filename: str) -> str:
|
||||||
"""Return a local filename safe to join under generated_images."""
|
"""Return a local filename safe to join under generated_images."""
|
||||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
|
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128]
|
||||||
if not safe_name or safe_name in {".", ".."}:
|
if not safe_name or safe_name in {".", ".."}:
|
||||||
safe_name = uuid.uuid4().hex[:12]
|
safe_name = uuid.uuid4().hex[:12]
|
||||||
return safe_name
|
return safe_name
|
||||||
|
|
||||||
|
|
||||||
|
GALLERY_IMAGE_DIR = Path(GENERATED_IMAGES_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def _gallery_image_path(filename: str) -> Path:
|
||||||
|
"""Resolve a stored gallery filename without leaving generated_images."""
|
||||||
|
if not isinstance(filename, str):
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
safe_name = _sanitize_gallery_filename(filename)
|
||||||
|
original = str(filename or "")
|
||||||
|
root = GALLERY_IMAGE_DIR.resolve()
|
||||||
|
path = (GALLERY_IMAGE_DIR / safe_name).resolve()
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([str(root), str(path)]) != str(root):
|
||||||
|
raise ValueError
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
if safe_name != original:
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_image_endpoint_base(url: str) -> str:
|
||||||
|
base = (url or "").strip().rstrip("/")
|
||||||
|
if base.endswith("/v1"):
|
||||||
|
base = base[:-3].rstrip("/")
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def _visible_image_endpoint_query(db, owner: str | None):
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = db.query(ModelEndpoint).filter(
|
||||||
|
ModelEndpoint.model_type == "image",
|
||||||
|
ModelEndpoint.is_enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
return owner_filter(q, ModelEndpoint, owner)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_visible_image_endpoint(db, owner: str | None):
|
||||||
|
endpoints = _visible_image_endpoint_query(db, owner).all()
|
||||||
|
if owner:
|
||||||
|
for ep in endpoints:
|
||||||
|
if getattr(ep, "owner", None) == owner:
|
||||||
|
return ep
|
||||||
|
return endpoints[0] if endpoints else None
|
||||||
|
|
||||||
|
|
||||||
|
def _visible_image_endpoint_for_base(db, base: str, owner: str | None):
|
||||||
|
target = _normalize_image_endpoint_base(base)
|
||||||
|
if not target:
|
||||||
|
return None
|
||||||
|
fallback = None
|
||||||
|
for ep in _visible_image_endpoint_query(db, owner).all():
|
||||||
|
if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == target:
|
||||||
|
if owner and getattr(ep, "owner", None) == owner:
|
||||||
|
return ep
|
||||||
|
if fallback is None:
|
||||||
|
fallback = ep
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
def setup_gallery_routes() -> APIRouter:
|
def setup_gallery_routes() -> APIRouter:
|
||||||
router = APIRouter(tags=["gallery"])
|
router = APIRouter(tags=["gallery"])
|
||||||
|
|
||||||
@@ -55,6 +131,9 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
file_hash = hashlib.sha256(content).hexdigest()
|
file_hash = hashlib.sha256(content).hexdigest()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
if album_id and user is not None:
|
||||||
|
_get_or_404_album(db, album_id, user)
|
||||||
|
|
||||||
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
||||||
# caller can probe whether someone else uploaded the same
|
# caller can probe whether someone else uploaded the same
|
||||||
# file (the response leaks the existing row's id+filename).
|
# file (the response leaks the existing row's id+filename).
|
||||||
@@ -69,7 +148,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
||||||
"id": existing.id, "message": "Duplicate photo skipped"}
|
"id": existing.id, "message": "Duplicate photo skipped"}
|
||||||
|
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
||||||
@@ -135,7 +214,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
raise HTTPException(400, "No image provided")
|
raise HTTPException(400, "No image provided")
|
||||||
|
|
||||||
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
||||||
img_path.write_bytes(content)
|
img_path.write_bytes(content)
|
||||||
@@ -211,7 +290,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not user or img.owner != user:
|
if not user or img.owner != user:
|
||||||
raise HTTPException(403, "Not your image")
|
raise HTTPException(403, "Not your image")
|
||||||
|
|
||||||
img_path = Path("data/generated_images") / img.filename
|
img_path = _gallery_image_path(img.filename)
|
||||||
if not img_path.exists():
|
if not img_path.exists():
|
||||||
raise HTTPException(404, "Image file not found")
|
raise HTTPException(404, "Image file not found")
|
||||||
|
|
||||||
@@ -248,7 +327,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
"""AI upscale using img2img with the diffusion server."""
|
"""AI upscale using img2img with the diffusion server."""
|
||||||
import base64, httpx
|
import base64, httpx
|
||||||
|
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
file = form.get("image")
|
file = form.get("image")
|
||||||
if not file: raise HTTPException(400, "No image")
|
if not file: raise HTTPException(400, "No image")
|
||||||
@@ -260,7 +339,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
# Find image endpoint
|
# Find image endpoint
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -291,7 +370,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
"""Style transfer using img2img with the diffusion server."""
|
"""Style transfer using img2img with the diffusion server."""
|
||||||
import base64, httpx
|
import base64, httpx
|
||||||
|
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
file = form.get("image")
|
file = form.get("image")
|
||||||
prompt = form.get("prompt", "")
|
prompt = form.get("prompt", "")
|
||||||
@@ -303,7 +382,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -505,18 +584,24 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
||||||
result = []
|
result = []
|
||||||
for a in albums:
|
for a in albums:
|
||||||
count = db.query(GalleryImage).filter(
|
_count_q = db.query(GalleryImage).filter(
|
||||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||||
).count()
|
)
|
||||||
|
if user:
|
||||||
|
_count_q = _count_q.filter(GalleryImage.owner == user)
|
||||||
|
count = _count_q.count()
|
||||||
cover_url = None
|
cover_url = None
|
||||||
if a.cover_id:
|
if a.cover_id:
|
||||||
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
||||||
if cover:
|
if cover:
|
||||||
cover_url = f"/api/generated-image/{cover.filename}"
|
cover_url = f"/api/generated-image/{cover.filename}"
|
||||||
elif count > 0:
|
elif count > 0:
|
||||||
first = db.query(GalleryImage).filter(
|
_cover_q = db.query(GalleryImage).filter(
|
||||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||||
).order_by(GalleryImage.created_at.desc()).first()
|
)
|
||||||
|
if user:
|
||||||
|
_cover_q = _cover_q.filter(GalleryImage.owner == user)
|
||||||
|
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
|
||||||
if first:
|
if first:
|
||||||
cover_url = f"/api/generated-image/{first.filename}"
|
cover_url = f"/api/generated-image/{first.filename}"
|
||||||
result.append({
|
result.append({
|
||||||
@@ -649,7 +734,14 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if req.favorite is not None:
|
if req.favorite is not None:
|
||||||
img.favorite = req.favorite
|
img.favorite = req.favorite
|
||||||
if req.album_id is not None:
|
if req.album_id is not None:
|
||||||
img.album_id = req.album_id if req.album_id else None
|
if req.album_id:
|
||||||
|
# Validate the target album belongs to the caller before
|
||||||
|
# moving the image into it — mirrors add_to_album, so you
|
||||||
|
# cannot file your image into another user's album.
|
||||||
|
_get_or_404_album(db, req.album_id, user)
|
||||||
|
img.album_id = req.album_id
|
||||||
|
else:
|
||||||
|
img.album_id = None
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(img)
|
db.refresh(img)
|
||||||
return _image_to_dict(img)
|
return _image_to_dict(img)
|
||||||
@@ -692,11 +784,11 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
used = set()
|
used = set()
|
||||||
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
for img in imgs:
|
for img in imgs:
|
||||||
src = os.path.join("data", "generated_images", img.filename)
|
src = _gallery_image_path(img.filename)
|
||||||
if not os.path.exists(src):
|
if not src.exists():
|
||||||
continue
|
continue
|
||||||
ext = os.path.splitext(img.filename)[1] or ".png"
|
ext = src.suffix or ".png"
|
||||||
base = (img.prompt or "").strip() or os.path.splitext(img.filename)[0]
|
base = (img.prompt or "").strip() or src.stem
|
||||||
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
||||||
name = f"{base}{ext}"
|
name = f"{base}{ext}"
|
||||||
i = 1
|
i = 1
|
||||||
@@ -818,9 +910,9 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
|
|
||||||
img_filename = img.filename
|
img_filename = img.filename
|
||||||
# Remove the file from disk
|
# Remove the file from disk
|
||||||
img_path = os.path.join("data", "generated_images", img_filename)
|
img_path = _gallery_image_path(img_filename)
|
||||||
if os.path.exists(img_path):
|
if img_path.exists():
|
||||||
os.remove(img_path)
|
img_path.unlink()
|
||||||
|
|
||||||
# Soft-delete the record
|
# Soft-delete the record
|
||||||
img.is_active = False
|
img.is_active = False
|
||||||
@@ -923,7 +1015,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
||||||
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
||||||
import httpx
|
import httpx
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
||||||
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
||||||
@@ -942,14 +1034,11 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not base:
|
if not base:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
eps = db.query(ModelEndpoint).filter(
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
ModelEndpoint.is_enabled == True,
|
if not ep:
|
||||||
ModelEndpoint.model_type == "image",
|
|
||||||
).all()
|
|
||||||
if not eps:
|
|
||||||
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
||||||
base = eps[0].base_url.rstrip("/")
|
base = ep.base_url.rstrip("/")
|
||||||
api_key = eps[0].api_key
|
api_key = ep.api_key
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
@@ -966,10 +1055,12 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
_target = _norm_url(base)
|
_target = _norm_url(base)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
for ep in db.query(ModelEndpoint).all():
|
ep = _visible_image_endpoint_for_base(db, _target, user)
|
||||||
if _norm_url(ep.base_url) == _target:
|
if ep:
|
||||||
api_key = ep.api_key
|
base = (ep.base_url or base).rstrip("/")
|
||||||
break
|
api_key = ep.api_key
|
||||||
|
elif user and not _current_user_is_admin(request, user):
|
||||||
|
raise HTTPException(403, "Choose a registered image endpoint")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1121,7 +1212,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
you get edge blending + lighting unification while keeping the
|
you get edge blending + lighting unification while keeping the
|
||||||
composition recognisable."""
|
composition recognisable."""
|
||||||
import httpx, base64 as _b64
|
import httpx, base64 as _b64
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
|
|
||||||
image_b64 = body.get("image")
|
image_b64 = body.get("image")
|
||||||
@@ -1148,23 +1239,22 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not base:
|
if not base:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
eps = db.query(ModelEndpoint).filter(
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
ModelEndpoint.is_enabled == True,
|
if not ep:
|
||||||
ModelEndpoint.model_type == "image",
|
|
||||||
).all()
|
|
||||||
if not eps:
|
|
||||||
raise HTTPException(400, "No image generation endpoint configured.")
|
raise HTTPException(400, "No image generation endpoint configured.")
|
||||||
base = eps[0].base_url.rstrip("/")
|
base = ep.base_url.rstrip("/")
|
||||||
api_key = eps[0].api_key
|
api_key = ep.api_key
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
for ep in db.query(ModelEndpoint).all():
|
ep = _visible_image_endpoint_for_base(db, base, user)
|
||||||
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
|
if ep:
|
||||||
api_key = ep.api_key
|
base = (ep.base_url or base).rstrip("/")
|
||||||
break
|
api_key = ep.api_key
|
||||||
|
elif user and not _current_user_is_admin(request, user):
|
||||||
|
raise HTTPException(403, "Choose a registered image endpoint")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1316,6 +1406,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
@router.post("/api/image/sharpen")
|
@router.post("/api/image/sharpen")
|
||||||
async def sharpen_image(request: Request):
|
async def sharpen_image(request: Request):
|
||||||
"""Apply unsharp-mask sharpening to an image."""
|
"""Apply unsharp-mask sharpening to an image."""
|
||||||
|
require_privilege(request, "can_generate_images")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
image_b64 = body.get("image")
|
image_b64 = body.get("image")
|
||||||
amount = body.get("amount", 50) / 100.0
|
amount = body.get("amount", 50) / 100.0
|
||||||
@@ -1635,9 +1726,10 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
album = _get_or_404_album(db, album_id, user)
|
album = _get_or_404_album(db, album_id, user)
|
||||||
db.query(GalleryImage).filter(GalleryImage.album_id == album_id).update(
|
q = db.query(GalleryImage).filter(GalleryImage.album_id == album_id)
|
||||||
{"album_id": None}, synchronize_session=False
|
if user is not None:
|
||||||
)
|
q = q.filter(GalleryImage.owner == user)
|
||||||
|
q.update({"album_id": None}, synchronize_session=False)
|
||||||
db.delete(album)
|
db.delete(album)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
@@ -1708,7 +1800,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
img = _get_or_404_image(db, image_id, user)
|
img = _get_or_404_image(db, image_id, user)
|
||||||
|
|
||||||
img_path = Path("data/generated_images") / img.filename
|
img_path = _gallery_image_path(img.filename)
|
||||||
if not img_path.exists():
|
if not img_path.exists():
|
||||||
raise HTTPException(404, "Image file not found")
|
raise HTTPException(404, "Image file not found")
|
||||||
|
|
||||||
@@ -1726,7 +1818,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
||||||
configured = vl_settings.get("vision_model", "")
|
configured = vl_settings.get("vision_model", "")
|
||||||
try:
|
try:
|
||||||
chat_url, model_name, headers = _resolve_vl_model(configured)
|
chat_url, model_name, headers = _resolve_vl_model(configured, owner=user)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return {"error": "No vision model configured — set one in Settings → Vision"}
|
return {"error": "No vision model configured — set one in Settings → Vision"}
|
||||||
if not chat_url:
|
if not chat_url:
|
||||||
@@ -1807,4 +1899,3 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|||||||
@@ -490,7 +490,13 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
# Copy messages up to keep_count
|
# Copy messages up to keep_count
|
||||||
msgs_to_copy = source.history[:keep_count]
|
msgs_to_copy = source.history[:keep_count]
|
||||||
for msg in msgs_to_copy:
|
for msg in msgs_to_copy:
|
||||||
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
|
# Copy the metadata dict. Sharing it would let the fork's
|
||||||
|
# persistence (add_message -> _persist_message stamps
|
||||||
|
# _db_id/timestamp onto the dict) mutate the SOURCE session's
|
||||||
|
# in-memory messages, corrupting their _db_id and breaking
|
||||||
|
# edit/delete-by-id on the original conversation.
|
||||||
|
meta = dict(msg.metadata) if isinstance(msg.metadata, dict) else None
|
||||||
|
new_session.add_message(ChatMessage(msg.role, msg.content, meta))
|
||||||
try:
|
try:
|
||||||
from src.event_bus import fire_event
|
from src.event_bus import fire_event
|
||||||
fire_event("session_created", getattr(source, 'owner', None))
|
fire_event("session_created", getattr(source, 'owner', None))
|
||||||
@@ -522,6 +528,8 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
async def compact_session(request: Request, session_id: str):
|
async def compact_session(request: Request, session_id: str):
|
||||||
"""Manually trigger context compaction for a session."""
|
"""Manually trigger context compaction for a session."""
|
||||||
_verify_session_owner(request, session_id)
|
_verify_session_owner(request, session_id)
|
||||||
|
from src.auth_helpers import effective_user
|
||||||
|
owner = effective_user(request)
|
||||||
try:
|
try:
|
||||||
session = session_manager.get_session(session_id)
|
session = session_manager.get_session(session_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -555,7 +563,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Use utility model if available
|
# Use utility model if available
|
||||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
compact_url = util_url or session.endpoint_url
|
compact_url = util_url or session.endpoint_url
|
||||||
compact_model = util_model or session.model
|
compact_model = util_model or session.model
|
||||||
compact_headers = util_headers if util_url else session.headers
|
compact_headers = util_headers if util_url else session.headers
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import httpx
|
|||||||
|
|
||||||
from core.database import McpServer, SessionLocal
|
from core.database import McpServer, SessionLocal
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from src.constants import DATA_DIR
|
from src.constants import DATA_DIR, MCP_OAUTH_DIR
|
||||||
from src.mcp_manager import McpManager
|
from src.mcp_manager import McpManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
|||||||
|
|
||||||
def _mcp_oauth_base_dir() -> Path:
|
def _mcp_oauth_base_dir() -> Path:
|
||||||
"""Directory that may contain OAuth files managed by Odysseus."""
|
"""Directory that may contain OAuth files managed by Odysseus."""
|
||||||
return (Path(DATA_DIR) / "mcp_oauth").resolve(strict=False)
|
return Path(MCP_OAUTH_DIR).resolve(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
||||||
|
|||||||
@@ -29,11 +29,10 @@ from src.llm_core import llm_call_async
|
|||||||
from services.memory.memory_extractor import audit_memories
|
from services.memory.memory_extractor import audit_memories
|
||||||
from src.auth_helpers import get_current_user, require_user
|
from src.auth_helpers import get_current_user, require_user
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MEMORY_IMPORT_MAX_BYTES = int(os.getenv("ODYSSEUS_MEMORY_IMPORT_MAX_BYTES", str(10 * 1024 * 1024)))
|
|
||||||
|
|
||||||
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
||||||
"""Set up memory-related routes."""
|
"""Set up memory-related routes."""
|
||||||
@@ -371,7 +370,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
|||||||
tmp.write(content)
|
tmp.write(content)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
try:
|
try:
|
||||||
text = _process_pdf(tmp_path)
|
text = _process_pdf(tmp_path, owner=_owner(request))
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
else:
|
else:
|
||||||
|
|||||||
+188
-26
@@ -5,6 +5,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import socket
|
import socket
|
||||||
|
import hashlib
|
||||||
import time as _time
|
import time as _time
|
||||||
import logging
|
import logging
|
||||||
import httpx
|
import httpx
|
||||||
@@ -282,8 +283,11 @@ _HOST_TO_CURATED = (
|
|||||||
("fireworks.ai", "fireworks"),
|
("fireworks.ai", "fireworks"),
|
||||||
("googleapis.com", "google"),
|
("googleapis.com", "google"),
|
||||||
("x.ai", "xai"),
|
("x.ai", "xai"),
|
||||||
|
|
||||||
("openrouter.ai", "openrouter"),
|
("openrouter.ai", "openrouter"),
|
||||||
("ollama.com", "ollama"),
|
("ollama.com", "ollama"),
|
||||||
|
("opencode.ai/zen/go", "opencode-go"),
|
||||||
|
("opencode.ai/zen", "opencode-zen"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -490,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
|||||||
def _is_chat_model(model_id: str) -> bool:
|
def _is_chat_model(model_id: str) -> bool:
|
||||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||||
mid = model_id.lower()
|
mid = model_id.lower()
|
||||||
|
if mid in {"gpt-5.1-codex"}:
|
||||||
|
return True
|
||||||
for prefix in _NON_CHAT_PREFIXES:
|
for prefix in _NON_CHAT_PREFIXES:
|
||||||
if mid.startswith(prefix):
|
if mid.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
@@ -502,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||||
|
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||||
|
|
||||||
|
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||||
|
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||||
|
endpoint backed by that auth row is removed, the stored credentials should
|
||||||
|
be cleared instead of lingering. Returns True if a row was deleted.
|
||||||
|
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||||
|
reference count so it does not keep its own auth alive.
|
||||||
|
"""
|
||||||
|
if not auth_id:
|
||||||
|
return False
|
||||||
|
from core.database import ProviderAuthSession
|
||||||
|
still_referenced = db.query(ModelEndpoint.id).filter(
|
||||||
|
ModelEndpoint.provider_auth_id == auth_id,
|
||||||
|
ModelEndpoint.id != exclude_ep_id,
|
||||||
|
).first()
|
||||||
|
if still_referenced is not None:
|
||||||
|
return False
|
||||||
|
auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
|
||||||
|
if auth_row is None:
|
||||||
|
return False
|
||||||
|
db.delete(auth_row)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_discovery_only_provider(provider: str) -> bool:
|
||||||
|
"""Provider that only supports model discovery, not live probing.
|
||||||
|
|
||||||
|
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||||
|
chat-completions or general health endpoint, so completion probes and
|
||||||
|
reachability pings are skipped — status is derived from cached models.
|
||||||
|
"""
|
||||||
|
return provider == "chatgpt-subscription"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_probe_key(ep) -> Optional[str]:
|
||||||
|
"""API key/bearer to probe an endpoint with.
|
||||||
|
|
||||||
|
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||||
|
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||||
|
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||||
|
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||||
|
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||||
|
return key
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
|
if _is_discovery_only_provider(provider):
|
||||||
|
# Responses/Codex API, not chat-completions: a completion probe would
|
||||||
|
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||||
|
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Say OK"},
|
{"role": "user", "content": "Say OK"},
|
||||||
@@ -618,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
base = resolve_url(_normalize_base(base_url))
|
base = resolve_url(_normalize_base(base_url))
|
||||||
|
if _detect_provider(base) == "chatgpt-subscription":
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
if api_key:
|
||||||
|
return fetch_available_models(api_key, timeout=timeout)
|
||||||
|
return []
|
||||||
if _detect_provider(base) == "anthropic":
|
if _detect_provider(base) == "anthropic":
|
||||||
# Try Anthropic's /v1/models endpoint first
|
# Try Anthropic's /v1/models endpoint first
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
@@ -644,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||||
return list(ANTHROPIC_MODELS)
|
return list(ANTHROPIC_MODELS)
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
|
if not url:
|
||||||
|
curated_key = _match_provider_curated(base, None)
|
||||||
|
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||||
|
return list(fallback or [])
|
||||||
headers = build_headers(api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
try:
|
try:
|
||||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
@@ -697,7 +770,6 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
return list(fallback)
|
return list(fallback)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||||
"""Reachability probe that does not require installed/listed models."""
|
"""Reachability probe that does not require installed/listed models."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
@@ -713,6 +785,10 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
or "ollama" in (parsed_base.hostname or "").lower()
|
or "ollama" in (parsed_base.hostname or "").lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# APFEL-specific detection
|
||||||
|
host = (parsed_base.hostname or "").lower()
|
||||||
|
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
|
||||||
|
|
||||||
def _result_from_response(r) -> Dict[str, Any]:
|
def _result_from_response(r) -> Dict[str, Any]:
|
||||||
if 300 <= r.status_code < 400:
|
if 300 <= r.status_code < 400:
|
||||||
loc = r.headers.get("location", "")
|
loc = r.headers.get("location", "")
|
||||||
@@ -734,7 +810,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
last_error: Optional[str] = None
|
last_error: Optional[str] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if looks_like_ollama:
|
# APFEL does not behave like Ollama; use its health endpoint.
|
||||||
|
if looks_like_apfel:
|
||||||
|
root = base
|
||||||
|
for suffix in ("/v1", "/api"):
|
||||||
|
if root.endswith(suffix):
|
||||||
|
root = root[: -len(suffix)].rstrip("/")
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
|
||||||
|
result = _result_from_response(r)
|
||||||
|
if result["reachable"]:
|
||||||
|
return result
|
||||||
|
last_error = result.get("error")
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)[:120]
|
||||||
|
|
||||||
|
elif looks_like_ollama:
|
||||||
root = base
|
root = base
|
||||||
for suffix in ("/v1", "/api"):
|
for suffix in ("/v1", "/api"):
|
||||||
if root.endswith(suffix):
|
if root.endswith(suffix):
|
||||||
@@ -754,14 +846,31 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
return _result_from_response(r)
|
result = _result_from_response(r)
|
||||||
|
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
|
||||||
|
# as a fallback. OpenAI-compatible servers like llama-swap return 404
|
||||||
|
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
|
||||||
|
# are definitive — probing /models would just repeat the same rejection.
|
||||||
|
if (
|
||||||
|
not result["reachable"]
|
||||||
|
and result.get("status_code") is not None
|
||||||
|
and 400 <= result["status_code"] < 500
|
||||||
|
and result["status_code"] not in (401, 403)
|
||||||
|
):
|
||||||
|
models_url = build_models_url(base)
|
||||||
|
try:
|
||||||
|
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
|
result2 = _result_from_response(r2)
|
||||||
|
if result2["reachable"]:
|
||||||
|
return result2
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = str(e)[:120]
|
last_error = str(e)[:120]
|
||||||
|
|
||||||
return {"reachable": False, "status_code": None, "error": last_error}
|
return {"reachable": False, "status_code": None, "error": last_error}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
||||||
"""Return a provider-aware error message for failed endpoint probes."""
|
"""Return a provider-aware error message for failed endpoint probes."""
|
||||||
ping = ping or {}
|
ping = ping or {}
|
||||||
@@ -850,6 +959,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
|
|||||||
return [m for m in merged if m not in hidden]
|
return [m for m in merged if m not in hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_fingerprint(api_key: Optional[str]) -> str:
|
||||||
|
"""Stable, non-secret label for distinguishing same-URL credentials."""
|
||||||
|
key = (api_key or "").strip()
|
||||||
|
if not key:
|
||||||
|
return ""
|
||||||
|
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
|
||||||
|
|
||||||
|
|
||||||
def setup_model_routes(model_discovery):
|
def setup_model_routes(model_discovery):
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
@@ -951,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
|||||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||||
if not ok:
|
if not ok:
|
||||||
continue
|
continue
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||||
|
ep,
|
||||||
|
owner=getattr(ep, "owner", None),
|
||||||
|
)
|
||||||
|
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||||
|
continue
|
||||||
groups.setdefault(info["key"], {
|
groups.setdefault(info["key"], {
|
||||||
"base": info["base"],
|
"base": info["base"],
|
||||||
"api_key": info["api_key"],
|
"api_key": info["api_key"],
|
||||||
@@ -1104,8 +1232,9 @@ def setup_model_routes(model_discovery):
|
|||||||
raise HTTPException(401, "Not authenticated")
|
raise HTTPException(401, "Not authenticated")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
|
||||||
|
raise HTTPException(status_code=500, detail='Internal error')
|
||||||
# Admins see every endpoint (they manage the global pool); regular
|
# Admins see every endpoint (they manage the global pool); regular
|
||||||
# users get the owner-scoped view.
|
# users get the owner-scoped view.
|
||||||
_is_admin = False
|
_is_admin = False
|
||||||
@@ -1219,12 +1348,20 @@ def setup_model_routes(model_discovery):
|
|||||||
"endpoint_kind": kind,
|
"endpoint_kind": kind,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
t0 = _time.time()
|
if _is_discovery_only_provider(provider):
|
||||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
# No general health endpoint — an unauthenticated GET just
|
||||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
# 401s. Report status from cached models instead of pinging.
|
||||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
entry["latency_ms"] = None
|
||||||
entry["error"] = ping.get("error")
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
entry["error"] = None
|
||||||
|
entry["model_count"] = cached_count
|
||||||
|
else:
|
||||||
|
t0 = _time.time()
|
||||||
|
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||||
|
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||||
|
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||||
|
entry["error"] = ping.get("error")
|
||||||
|
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
entry["latency_ms"] = None
|
entry["latency_ms"] = None
|
||||||
entry["status"] = "online" if cached_count else "offline"
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
@@ -1257,7 +1394,7 @@ def setup_model_routes(model_discovery):
|
|||||||
if ep_id and ep_id not in endpoints_cache:
|
if ep_id and ep_id not in endpoints_cache:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if ep:
|
if ep:
|
||||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
ep_data = endpoints_cache.get(ep_id)
|
ep_data = endpoints_cache.get(ep_id)
|
||||||
if not ep_data:
|
if not ep_data:
|
||||||
# Try to find by base_url from the model's endpoint field
|
# Try to find by base_url from the model's endpoint field
|
||||||
@@ -1296,7 +1433,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep.id,
|
"id": ep.id,
|
||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
"api_key": ep.api_key,
|
"api_key": _resolve_probe_key(ep),
|
||||||
})
|
})
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -1385,18 +1522,21 @@ def setup_model_routes(model_discovery):
|
|||||||
# Endpoint counts as reachable if it has any model — including
|
# Endpoint counts as reachable if it has any model — including
|
||||||
# admin-pinned IDs that a probe would never surface.
|
# admin-pinned IDs that a probe would never surface.
|
||||||
status = "online" if (all_models or pinned) else "offline"
|
status = "online" if (all_models or pinned) else "offline"
|
||||||
|
base = _normalize_base(r.base_url)
|
||||||
ping = None
|
ping = None
|
||||||
if not all_models and not pinned and r.is_enabled:
|
# Discovery-only providers have no health endpoint — an
|
||||||
|
# unauthenticated ping just 401s, so don't bother.
|
||||||
|
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||||
if ping.get("reachable"):
|
if ping.get("reachable"):
|
||||||
status = "empty"
|
status = "empty"
|
||||||
base = _normalize_base(r.base_url)
|
|
||||||
kind = _effective_endpoint_kind(r, base)
|
kind = _effective_endpoint_kind(r, base)
|
||||||
results.append({
|
results.append({
|
||||||
"id": r.id,
|
"id": r.id,
|
||||||
"name": r.name,
|
"name": r.name,
|
||||||
"base_url": r.base_url,
|
"base_url": r.base_url,
|
||||||
"has_key": bool(r.api_key),
|
"has_key": bool(r.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
|
||||||
"is_enabled": r.is_enabled,
|
"is_enabled": r.is_enabled,
|
||||||
"models": visible,
|
"models": visible,
|
||||||
"pinned_models": pinned,
|
"pinned_models": pinned,
|
||||||
@@ -1463,21 +1603,34 @@ def setup_model_routes(model_discovery):
|
|||||||
)
|
)
|
||||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||||
|
|
||||||
# Dedupe: if an endpoint with the same base_url already exists and
|
# Dedupe: if an endpoint with the same base_url and compatible
|
||||||
# is reachable by the caller (shared or owned by them), return it
|
# credentials already exists and is reachable by the caller (shared or
|
||||||
# instead of creating a duplicate row. Fixes "Scan for Servers"
|
# owned by them), return it instead of creating a duplicate row. Keep
|
||||||
# re-adding manually-added endpoints under their host:port name.
|
# same-url/different-key rows distinct so users can group the same
|
||||||
|
# provider URL under multiple credentials.
|
||||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||||
_caller = _gcu_dedup(request) or None
|
_caller = _gcu_dedup(request) or None
|
||||||
|
_incoming_api_key = api_key.strip()
|
||||||
_db_dedup = SessionLocal()
|
_db_dedup = SessionLocal()
|
||||||
try:
|
try:
|
||||||
existing = (
|
_same_url_rows = (
|
||||||
_db_dedup.query(ModelEndpoint)
|
_db_dedup.query(ModelEndpoint)
|
||||||
.filter(ModelEndpoint.base_url == base_url)
|
.filter(ModelEndpoint.base_url == base_url)
|
||||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
||||||
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
||||||
.first()
|
.all()
|
||||||
)
|
)
|
||||||
|
existing = None
|
||||||
|
_empty_key_existing = None
|
||||||
|
for _candidate in _same_url_rows:
|
||||||
|
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
|
||||||
|
if _candidate_key == _incoming_api_key:
|
||||||
|
existing = _candidate
|
||||||
|
break
|
||||||
|
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
|
||||||
|
_empty_key_existing = _candidate
|
||||||
|
if existing is None and _incoming_api_key and _empty_key_existing is not None:
|
||||||
|
existing = _empty_key_existing
|
||||||
if existing:
|
if existing:
|
||||||
changed = False
|
changed = False
|
||||||
# Persist any incoming pinned IDs onto the existing row. An
|
# Persist any incoming pinned IDs onto the existing row. An
|
||||||
@@ -1526,6 +1679,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": existing.id,
|
"id": existing.id,
|
||||||
"name": existing.name,
|
"name": existing.name,
|
||||||
"base_url": existing.base_url,
|
"base_url": existing.base_url,
|
||||||
|
"has_key": bool(existing.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
|
||||||
"models": _visible_models(
|
"models": _visible_models(
|
||||||
existing_models,
|
existing_models,
|
||||||
getattr(existing, "hidden_models", None),
|
getattr(existing, "hidden_models", None),
|
||||||
@@ -1599,6 +1754,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep_id,
|
"id": ep_id,
|
||||||
"name": name.strip(),
|
"name": name.strip(),
|
||||||
"base_url": base_url,
|
"base_url": base_url,
|
||||||
|
"has_key": bool(api_key.strip()),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(api_key),
|
||||||
"models": _merge_model_ids(model_ids, _pinned),
|
"models": _merge_model_ids(model_ids, _pinned),
|
||||||
"pinned_models": _pinned,
|
"pinned_models": _pinned,
|
||||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||||
@@ -1648,7 +1805,7 @@ def setup_model_routes(model_discovery):
|
|||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found")
|
raise HTTPException(404, "Endpoint not found")
|
||||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1712,7 +1869,7 @@ def setup_model_routes(model_discovery):
|
|||||||
category = _classify_endpoint(base, kind)
|
category = _classify_endpoint(base, kind)
|
||||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||||
try:
|
try:
|
||||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||||
probed = []
|
probed = []
|
||||||
@@ -1948,6 +2105,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"model_type": ep.model_type,
|
"model_type": ep.model_type,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
|
"has_key": bool(ep.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
|
||||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||||
@@ -2049,7 +2208,9 @@ def setup_model_routes(model_discovery):
|
|||||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||||
|
auth_id = getattr(ep, "provider_auth_id", None)
|
||||||
db.delete(ep)
|
db.delete(ep)
|
||||||
|
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
_invalidate_models_cache()
|
_invalidate_models_cache()
|
||||||
_local_probe_cache["data"] = None
|
_local_probe_cache["data"] = None
|
||||||
@@ -2059,6 +2220,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"cleared_user_preferences": cleared_user_preferences,
|
"cleared_user_preferences": cleared_user_preferences,
|
||||||
"cleared_sessions": cleared_sessions,
|
"cleared_sessions": cleared_sessions,
|
||||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||||
|
"cleared_provider_auth": cleared_provider_auth,
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
+161
-16
@@ -11,6 +11,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.database import SessionLocal, Note
|
from core.database import SessionLocal, Note
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import DATA_DIR
|
||||||
from sqlalchemy.orm.attributes import flag_modified
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -95,6 +96,32 @@ def _note_to_dict(note: Note) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _reminder_text_from_note(note: Note) -> tuple[str, str]:
|
||||||
|
"""Return the reminder title/body from a stored note row."""
|
||||||
|
title = (note.title or "Note reminder").strip() or "Note reminder"
|
||||||
|
if note.items:
|
||||||
|
try:
|
||||||
|
items = json.loads(note.items)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
items = None
|
||||||
|
if isinstance(items, list):
|
||||||
|
pending: list[str] = []
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
if item.get("done") or item.get("checked"):
|
||||||
|
continue
|
||||||
|
text = str(item.get("text") or "").strip()
|
||||||
|
if text:
|
||||||
|
pending.append(text)
|
||||||
|
if pending:
|
||||||
|
shown = "\n".join(f"- {text}" for text in pending[:8])
|
||||||
|
extra = f"\n...and {len(pending) - 8} more" if len(pending) > 8 else ""
|
||||||
|
return title, f"Pending ({len(pending)}):\n{shown}{extra}"
|
||||||
|
return title, f"{len(items)} item{'s' if len(items) != 1 else ''}"
|
||||||
|
return title, (note.content or "").strip()[:400]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Reminder dispatch — module-level so background tasks (built-in actions)
|
# Reminder dispatch — module-level so background tasks (built-in actions)
|
||||||
@@ -114,8 +141,9 @@ async def dispatch_reminder(
|
|||||||
note_id: str,
|
note_id: str,
|
||||||
owner: str = "",
|
owner: str = "",
|
||||||
queue_browser: bool = True,
|
queue_browser: bool = True,
|
||||||
|
settings_override: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Fire a reminder via the configured channel (browser/email/ntfy).
|
"""Fire a reminder via the configured channel (browser/email/ntfy/webhook).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
title: short headline shown to the user
|
title: short headline shown to the user
|
||||||
@@ -129,7 +157,7 @@ async def dispatch_reminder(
|
|||||||
nothing is "sent" synchronously for it — the channel just routes there.
|
nothing is "sent" synchronously for it — the channel just routes there.
|
||||||
"""
|
"""
|
||||||
from src.settings import load_settings
|
from src.settings import load_settings
|
||||||
settings = load_settings()
|
settings = {**load_settings(), **(settings_override or {})}
|
||||||
channel = settings.get("reminder_channel", "browser")
|
channel = settings.get("reminder_channel", "browser")
|
||||||
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
||||||
title = (title or "").strip()
|
title = (title or "").strip()
|
||||||
@@ -143,7 +171,7 @@ async def dispatch_reminder(
|
|||||||
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
||||||
from pathlib import Path as _P
|
from pathlib import Path as _P
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
cache_path = _P(f"data/note_pings_{_slug}.json")
|
cache_path = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
||||||
last = cache.get(cache_key)
|
last = cache.get(cache_key)
|
||||||
@@ -160,13 +188,14 @@ async def dispatch_reminder(
|
|||||||
# Treat those as browser-only dedupe so email reminders can be
|
# Treat those as browser-only dedupe so email reminders can be
|
||||||
# retried by the backend scanner after a failed frontend path.
|
# retried by the backend scanner after a failed frontend path.
|
||||||
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
||||||
if should_skip and channel in ("email", "ntfy"):
|
if should_skip and channel in ("email", "ntfy", "webhook"):
|
||||||
should_skip = last_channel == channel
|
should_skip = last_channel == channel
|
||||||
if should_skip:
|
if should_skip:
|
||||||
return {
|
return {
|
||||||
"synthesis": None,
|
"synthesis": None,
|
||||||
"email_sent": False,
|
"email_sent": False,
|
||||||
"ntfy_sent": False,
|
"ntfy_sent": False,
|
||||||
|
"webhook_sent": False,
|
||||||
"browser_sent": True,
|
"browser_sent": True,
|
||||||
"skipped": True,
|
"skipped": True,
|
||||||
}
|
}
|
||||||
@@ -179,9 +208,9 @@ async def dispatch_reminder(
|
|||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||||
if url and model:
|
if url and model:
|
||||||
raw = await llm_call_async(
|
raw = await llm_call_async(
|
||||||
url=url, model=model,
|
url=url, model=model,
|
||||||
@@ -360,6 +389,76 @@ async def dispatch_reminder(
|
|||||||
email_error = str(e) or e.__class__.__name__
|
email_error = str(e) or e.__class__.__name__
|
||||||
logger.warning(f"Reminder email send failed: {e}")
|
logger.warning(f"Reminder email send failed: {e}")
|
||||||
|
|
||||||
|
webhook_sent = False
|
||||||
|
webhook_error = ""
|
||||||
|
if channel == "webhook":
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
import json as _wjson
|
||||||
|
from src.integrations import load_integrations
|
||||||
|
# Built-in payload defaults for known presets so users don't have
|
||||||
|
# to configure a template just to use a standard service.
|
||||||
|
_PRESET_TEMPLATE_DEFAULTS = {
|
||||||
|
"discord_webhook": '{"embeds": [{"title": "{{title}}", "description": "{{message}}", "color": 5793266}]}',
|
||||||
|
}
|
||||||
|
intg_id = settings.get("reminder_webhook_integration_id", "").strip()
|
||||||
|
template = settings.get("reminder_webhook_payload_template", "").strip()
|
||||||
|
if not intg_id:
|
||||||
|
webhook_error = "No webhook integration selected"
|
||||||
|
else:
|
||||||
|
intg = next(
|
||||||
|
(i for i in load_integrations()
|
||||||
|
if i.get("id") == intg_id and i.get("base_url")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not intg:
|
||||||
|
webhook_error = f"Integration {intg_id!r} not found or missing base URL"
|
||||||
|
else:
|
||||||
|
# Fall back to a built-in default for known presets so
|
||||||
|
# users don't have to configure a template for standard
|
||||||
|
# services like Discord.
|
||||||
|
if not template:
|
||||||
|
template = _PRESET_TEMPLATE_DEFAULTS.get(intg.get("preset", ""), "")
|
||||||
|
if not template:
|
||||||
|
webhook_error = "No payload template configured"
|
||||||
|
else:
|
||||||
|
# Render template: JSON-escape the values so the result
|
||||||
|
# is always valid JSON regardless of special characters.
|
||||||
|
# dumps() returns `"value"` — strip outer quotes.
|
||||||
|
msg = (synthesis or note_body or title or "Reminder")[:4000]
|
||||||
|
_t = _wjson.dumps(title or "Reminder")[1:-1]
|
||||||
|
_m = _wjson.dumps(msg)[1:-1]
|
||||||
|
rendered = template.replace("{{title}}", _t).replace("{{message}}", _m)
|
||||||
|
hdrs = {"Content-Type": "application/json"}
|
||||||
|
api_key = intg.get("api_key", "")
|
||||||
|
auth_type = (intg.get("auth_type") or "none").lower()
|
||||||
|
if api_key:
|
||||||
|
if auth_type == "bearer":
|
||||||
|
hdrs["Authorization"] = f"Bearer {api_key}"
|
||||||
|
elif auth_type == "header":
|
||||||
|
hdrs[intg.get("auth_header") or "Authorization"] = api_key
|
||||||
|
url = intg["base_url"].rstrip("/")
|
||||||
|
# SSRF guard — matches the pattern used by webhook_routes,
|
||||||
|
# CalDAV, search, and embeddings. Blocks link-local / metadata
|
||||||
|
# addresses (169.254.x.x) by default; set
|
||||||
|
# REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS=true to also block
|
||||||
|
# RFC-1918 ranges for locked-down deployments.
|
||||||
|
import os as _os
|
||||||
|
from src.url_safety import check_outbound_url as _chk
|
||||||
|
_block = _os.getenv("REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS", "false").lower() == "true"
|
||||||
|
_ok, _reason = _chk(url, block_private=_block)
|
||||||
|
if not _ok:
|
||||||
|
webhook_error = f"Webhook URL rejected: {_reason}"
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(url, content=rendered.encode(), headers=hdrs)
|
||||||
|
webhook_sent = resp.is_success
|
||||||
|
if not webhook_sent:
|
||||||
|
webhook_error = f"Webhook returned HTTP {resp.status_code}"
|
||||||
|
except Exception as e:
|
||||||
|
webhook_error = str(e) or e.__class__.__name__
|
||||||
|
logger.warning(f"Reminder webhook send failed: {e}")
|
||||||
|
|
||||||
ntfy_sent = False
|
ntfy_sent = False
|
||||||
ntfy_error = ""
|
ntfy_error = ""
|
||||||
if channel == "ntfy":
|
if channel == "ntfy":
|
||||||
@@ -415,7 +514,7 @@ async def dispatch_reminder(
|
|||||||
# second send for the same note within 25 min. Without this, a note
|
# second send for the same note within 25 min. Without this, a note
|
||||||
# whose due_date fires while the user has the app open got TWO emails
|
# whose due_date fires while the user has the app open got TWO emails
|
||||||
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
||||||
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
|
if (email_sent or ntfy_sent or webhook_sent or browser_sent or local_browser_sent) and note_id:
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
from datetime import datetime as _dt, timezone as _tz
|
from datetime import datetime as _dt, timezone as _tz
|
||||||
@@ -425,13 +524,13 @@ async def dispatch_reminder(
|
|||||||
_STATE = cache_path
|
_STATE = cache_path
|
||||||
if _STATE is None:
|
if _STATE is None:
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
_STATE = _P(f"data/note_pings_{_slug}.json")
|
_STATE = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||||
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
||||||
except Exception:
|
except Exception:
|
||||||
_cache = {}
|
_cache = {}
|
||||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
|
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "webhook" if webhook_sent else "browser"
|
||||||
_cache[cache_key or str(note_id)] = {
|
_cache[cache_key or str(note_id)] = {
|
||||||
"at": _dt.now(_tz.utc).isoformat(),
|
"at": _dt.now(_tz.utc).isoformat(),
|
||||||
"channel": sent_channel,
|
"channel": sent_channel,
|
||||||
@@ -441,11 +540,14 @@ async def dispatch_reminder(
|
|||||||
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"channel": channel,
|
||||||
"synthesis": synthesis,
|
"synthesis": synthesis,
|
||||||
"email_sent": email_sent,
|
"email_sent": email_sent,
|
||||||
"email_error": email_error,
|
"email_error": email_error,
|
||||||
"ntfy_sent": ntfy_sent,
|
"ntfy_sent": ntfy_sent,
|
||||||
"ntfy_error": ntfy_error,
|
"ntfy_error": ntfy_error,
|
||||||
|
"webhook_sent": webhook_sent,
|
||||||
|
"webhook_error": webhook_error,
|
||||||
"browser_sent": browser_sent or local_browser_sent,
|
"browser_sent": browser_sent or local_browser_sent,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,6 +569,23 @@ def setup_note_routes(task_scheduler=None):
|
|||||||
def _owner(request: Request) -> Optional[str]:
|
def _owner(request: Request) -> Optional[str]:
|
||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
|
|
||||||
|
def _is_admin_or_single_user(request: Request, user: str | None) -> bool:
|
||||||
|
if user == "internal-tool":
|
||||||
|
return True
|
||||||
|
if not user:
|
||||||
|
# require_user() already admitted this request, which only happens
|
||||||
|
# for auth-disabled, loopback-bypass, or unconfigured single-user
|
||||||
|
# modes. There is no separate non-admin account boundary there.
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
from core.auth import AuthManager
|
||||||
|
auth_mgr = getattr(request.app.state, "auth_manager", None) or AuthManager()
|
||||||
|
if not getattr(auth_mgr, "is_configured", True):
|
||||||
|
return True
|
||||||
|
return bool(auth_mgr.is_admin(user))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
# --- LIST ---
|
# --- LIST ---
|
||||||
@router.get("")
|
@router.get("")
|
||||||
def list_notes(
|
def list_notes(
|
||||||
@@ -684,20 +803,46 @@ def setup_note_routes(task_scheduler=None):
|
|||||||
"""
|
"""
|
||||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||||
from src.auth_helpers import require_user as _ru
|
from src.auth_helpers import require_user as _ru
|
||||||
_ru(request)
|
user = _ru(request)
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
note_id = body.get("note_id")
|
note_id = str(body.get("note_id") or "").strip()
|
||||||
title = (body.get("title") or "").strip()
|
|
||||||
note_body = (body.get("body") or "").strip()
|
|
||||||
if not note_id:
|
if not note_id:
|
||||||
raise HTTPException(400, "note_id required")
|
raise HTTPException(400, "note_id required")
|
||||||
|
|
||||||
# Delegate to the module-level helper so background tasks can reuse
|
caller = _owner(request)
|
||||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
is_test = note_id.startswith("test-")
|
||||||
|
is_admin = _is_admin_or_single_user(request, user or caller)
|
||||||
|
_override: dict = {}
|
||||||
|
if is_test:
|
||||||
|
if not is_admin:
|
||||||
|
raise HTTPException(403, "Admin only")
|
||||||
|
title = (body.get("title") or "Test Reminder").strip() or "Test Reminder"
|
||||||
|
note_body = (body.get("body") or "").strip()
|
||||||
|
# Optional overrides let the admin settings test button pass the
|
||||||
|
# current UI values directly so it never races a pending save.
|
||||||
|
if body.get("channel"):
|
||||||
|
_override["reminder_channel"] = body["channel"]
|
||||||
|
if body.get("webhook_integration_id"):
|
||||||
|
_override["reminder_webhook_integration_id"] = body["webhook_integration_id"]
|
||||||
|
if body.get("webhook_payload_template"):
|
||||||
|
_override["reminder_webhook_payload_template"] = body["webhook_payload_template"]
|
||||||
|
else:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
note = db.query(Note).filter(Note.id == note_id).first()
|
||||||
|
if not note:
|
||||||
|
raise HTTPException(404, "Note not found")
|
||||||
|
if caller is not None and note.owner != caller:
|
||||||
|
raise HTTPException(404, "Note not found")
|
||||||
|
title, note_body = _reminder_text_from_note(note)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
return await dispatch_reminder(
|
return await dispatch_reminder(
|
||||||
title=title, note_body=note_body, note_id=note_id,
|
title=title, note_body=note_body, note_id=note_id,
|
||||||
owner=_owner(request) or "",
|
owner=caller or "",
|
||||||
queue_browser=False,
|
queue_browser=False,
|
||||||
|
settings_override=_override or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- REORDER NOTES ---
|
# --- REORDER NOTES ---
|
||||||
|
|||||||
+13
-12
@@ -6,16 +6,14 @@ import uuid
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||||
from src.request_models import DirectoryRequest
|
from src.request_models import DirectoryRequest
|
||||||
from core.constants import BASE_DIR, PERSONAL_DIR
|
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
|
||||||
from src.rag_singleton import get_rag_manager
|
from src.rag_singleton import get_rag_manager
|
||||||
from src.auth_helpers import get_current_user, require_user
|
from src.auth_helpers import require_privilege, require_user
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from src.upload_handler import secure_filename
|
from src.upload_handler import secure_filename
|
||||||
|
from src.upload_limits import PERSONAL_UPLOAD_MAX_BYTES
|
||||||
|
|
||||||
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
|
UPLOADS_DIR = PERSONAL_UPLOADS_DIR
|
||||||
MAX_PERSONAL_UPLOAD_BYTES = int(
|
|
||||||
os.getenv("ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -194,7 +192,7 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
@router.post("/upload")
|
@router.post("/upload")
|
||||||
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
||||||
"""Upload files directly into RAG. Supports text and PDF."""
|
"""Upload files directly into RAG. Supports text and PDF."""
|
||||||
user = get_current_user(request)
|
user = require_privilege(request, "can_use_documents")
|
||||||
rag = _rag()
|
rag = _rag()
|
||||||
if not rag:
|
if not rag:
|
||||||
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
||||||
@@ -208,8 +206,8 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
for upload in files:
|
for upload in files:
|
||||||
try:
|
try:
|
||||||
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
||||||
content_bytes = await upload.read(MAX_PERSONAL_UPLOAD_BYTES + 1)
|
content_bytes = await upload.read(PERSONAL_UPLOAD_MAX_BYTES + 1)
|
||||||
if len(content_bytes) > MAX_PERSONAL_UPLOAD_BYTES:
|
if len(content_bytes) > PERSONAL_UPLOAD_MAX_BYTES:
|
||||||
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
||||||
total_failed += 1
|
total_failed += 1
|
||||||
continue
|
continue
|
||||||
@@ -286,9 +284,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
# commonpath raises on mixed drives / non-comparable paths
|
# commonpath raises on mixed drives / non-comparable paths
|
||||||
in_uploads = False
|
in_uploads = False
|
||||||
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
|
if in_uploads and abs_target != base_abs:
|
||||||
os.remove(abs_target)
|
try:
|
||||||
deleted_from_disk = True
|
os.remove(abs_target)
|
||||||
|
deleted_from_disk = True
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # already gone — race with another request or cleanup
|
||||||
|
|
||||||
# Exclude the file from the listing (persists across restarts)
|
# Exclude the file from the listing (persists across restarts)
|
||||||
personal_docs_manager.exclude_file(filepath)
|
personal_docs_manager.exclude_file(filepath)
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import USER_PREFS_FILE
|
||||||
|
|
||||||
PREFS_FILE = os.path.join("data", "user_prefs.json")
|
PREFS_FILE = USER_PREFS_FILE
|
||||||
|
|
||||||
|
|
||||||
def _load():
|
def _load():
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from src.request_models import PresetUpdateRequest
|
from src.request_models import PresetUpdateRequest
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
from src.auth_helpers import effective_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
model_spec = data.get("model") or ""
|
model_spec = data.get("model") or ""
|
||||||
url, model, headers = _resolve_model(model_spec)
|
user = effective_user(request)
|
||||||
|
url, model, headers = _resolve_model(model_spec, owner=user)
|
||||||
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
||||||
return {"success": True, "prompt": result.strip()}
|
return {"success": True, "prompt": result.strip()}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+61
-46
@@ -14,6 +14,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.auth_helpers import _auth_disabled, get_current_user
|
from src.auth_helpers import _auth_disabled, get_current_user
|
||||||
|
from src.constants import DEEP_RESEARCH_DIR
|
||||||
|
|
||||||
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
||||||
|
|
||||||
@@ -37,13 +38,15 @@ def _first_chat_model(models) -> str:
|
|||||||
return (models[0] if models else "")
|
return (models[0] if models else "")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_research_endpoint(sess) -> tuple:
|
def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple:
|
||||||
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
||||||
|
owner = owner or getattr(sess, "owner", None) or None
|
||||||
url, model, headers = resolve_endpoint(
|
url, model, headers = resolve_endpoint(
|
||||||
"research",
|
"research",
|
||||||
fallback_url=sess.endpoint_url,
|
fallback_url=sess.endpoint_url,
|
||||||
fallback_model=sess.model,
|
fallback_model=sess.model,
|
||||||
fallback_headers=sess.headers,
|
fallback_headers=sess.headers,
|
||||||
|
owner=owner,
|
||||||
)
|
)
|
||||||
return url, model, headers
|
return url, model, headers
|
||||||
|
|
||||||
@@ -72,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
|||||||
return owner_filter(q, ModelEndpoint, owner).first()
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
|
||||||
|
"""Resolve a ModelEndpoint row into (chat_url, model, headers).
|
||||||
|
|
||||||
|
Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
|
||||||
|
panel-selected research endpoints. ChatGPT Subscription endpoints keep
|
||||||
|
OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.
|
||||||
|
"""
|
||||||
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
resolve_endpoint_runtime as resolve_model_endpoint_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not resolve endpoint credentials for research: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
ep_model = (model or "").strip()
|
||||||
|
if not ep_model:
|
||||||
|
try:
|
||||||
|
models = json.loads(ep.cached_models) if ep.cached_models else []
|
||||||
|
if models:
|
||||||
|
ep_model = _first_chat_model(models)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not ep_model:
|
||||||
|
return None
|
||||||
|
return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||||
|
|
||||||
|
|
||||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||||
router = APIRouter(tags=["research"])
|
router = APIRouter(tags=["research"])
|
||||||
|
|
||||||
@@ -98,7 +133,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
if entry is not None:
|
if entry is not None:
|
||||||
return entry.get("owner", "") == user
|
return entry.get("owner", "") == user
|
||||||
# Task no longer in memory — check the persisted JSON.
|
# Task no longer in memory — check the persisted JSON.
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
@@ -162,7 +197,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
def _assert_owns_research(session_id: str, user: str) -> None:
|
def _assert_owns_research(session_id: str, user: str) -> None:
|
||||||
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
||||||
Use BEFORE returning any data or mutating the file."""
|
Use BEFORE returning any data or mutating the file."""
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -225,7 +260,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
):
|
):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
"""List all completed research for the Library panel."""
|
"""List all completed research for the Library panel."""
|
||||||
data_dir = Path("data/deep_research")
|
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
items = []
|
items = []
|
||||||
for p in data_dir.glob("*.json"):
|
for p in data_dir.glob("*.json"):
|
||||||
try:
|
try:
|
||||||
@@ -275,7 +310,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
summary, stats — used by the Library preview panel."""
|
summary, stats — used by the Library preview panel."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -292,7 +327,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -312,7 +347,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
"""Delete a research result from disk."""
|
"""Delete a research result from disk."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
data_dir = Path("data/deep_research")
|
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
json_path = data_dir / f"{session_id}.json"
|
json_path = data_dir / f"{session_id}.json"
|
||||||
deleted = False
|
deleted = False
|
||||||
if json_path.exists():
|
if json_path.exists():
|
||||||
@@ -368,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
|
|
||||||
if body.endpoint_id:
|
if body.endpoint_id:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped: never resolve another user's private endpoint
|
# Owner-scoped: never resolve another user's private endpoint
|
||||||
@@ -377,35 +411,26 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found or disabled")
|
raise HTTPException(404, "Endpoint not found or disabled")
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||||
ep_url = build_chat_url(base)
|
if not resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||||
ep_model = body.model or ""
|
ep_url, ep_model, ep_headers = resolved
|
||||||
if not ep_model:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
|
||||||
# When neither research nor utility is configured, use the user's
|
# When neither research nor utility is configured, use the user's
|
||||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||||
@@ -414,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||||
ep = _owned_enabled_endpoint(db, user)
|
ep = _owned_enabled_endpoint(db, user)
|
||||||
if ep:
|
if ep:
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||||
ep_url = build_chat_url(base)
|
if resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
ep_url, ep_model, ep_headers = resolved
|
||||||
ep_model = ""
|
|
||||||
if ep.cached_models:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models)
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
@@ -494,7 +510,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
raise HTTPException(404, "No research found for this session")
|
raise HTTPException(404, "No research found for this session")
|
||||||
result = research_handler.get_result(session_id)
|
result = research_handler.get_result(session_id)
|
||||||
if result is None:
|
if result is None:
|
||||||
p = Path("data/deep_research") / f"{session_id}.json"
|
p = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if p.exists():
|
if p.exists():
|
||||||
d = json.loads(p.read_text(encoding="utf-8"))
|
d = json.loads(p.read_text(encoding="utf-8"))
|
||||||
return {
|
return {
|
||||||
@@ -534,7 +550,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
sources = research_handler.get_sources(session_id) or []
|
sources = research_handler.get_sources(session_id) or []
|
||||||
query = ""
|
query = ""
|
||||||
|
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if path.exists():
|
if path.exists():
|
||||||
try:
|
try:
|
||||||
disk = json.loads(path.read_text(encoding="utf-8"))
|
disk = json.loads(path.read_text(encoding="utf-8"))
|
||||||
@@ -572,19 +588,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep_headers = dict(r_headers)
|
ep_headers = dict(r_headers)
|
||||||
|
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("chat"))
|
_merge(*resolve_endpoint("chat", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("research"))
|
_merge(*resolve_endpoint("research", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("utility"))
|
_merge(*resolve_endpoint("utility", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
# Last resort: any enabled endpoint
|
# Last resort: this user's enabled endpoint, plus legacy shared rows.
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.database import ModelEndpoint
|
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
ep = _owned_enabled_endpoint(db, user)
|
||||||
if ep:
|
if ep:
|
||||||
base = normalize_base(ep.base_url)
|
base = normalize_base(ep.base_url)
|
||||||
fallback_url = build_chat_url(base)
|
fallback_url = build_chat_url(base)
|
||||||
@@ -594,7 +609,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
models = json.loads(ep.cached_models)
|
models = json.loads(ep.cached_models)
|
||||||
if models:
|
if models:
|
||||||
fallback_model = models[0]
|
fallback_model = _first_chat_model(models)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
_merge(fallback_url, fallback_model, fallback_headers)
|
_merge(fallback_url, fallback_model, fallback_headers)
|
||||||
|
|||||||
+43
-32
@@ -11,7 +11,7 @@ from core.session_manager import SessionManager
|
|||||||
from core.models import ChatMessage
|
from core.models import ChatMessage
|
||||||
from src.request_models import SessionResponse
|
from src.request_models import SessionResponse
|
||||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
||||||
from src.auth_helpers import get_current_user, effective_user
|
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_export_filename(name: str) -> str:
|
def _sanitize_export_filename(name: str) -> str:
|
||||||
@@ -92,35 +92,30 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||||
"""Verify the current user owns the session. Raises 404 if not.
|
"""Verify the current user owns the session, honoring single-user modes.
|
||||||
|
|
||||||
Ownership is checked against the DB row when one exists (unchanged). If
|
Authenticated requests must match the stored DB or in-memory owner. When
|
||||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
auth is disabled and no user is present, treat the app as single-user mode:
|
||||||
that lives only in ``session_manager`` because it was never persisted, or
|
verify that the session exists, but do not compare its stored owner. This
|
||||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||||
user can still manage and delete it. Without this fallback such sessions are
|
rows created while auth was previously enabled.
|
||||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
|
||||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
|
||||||
|
|
||||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
|
||||||
that only care about persisted sessions keep their exact prior behavior.
|
|
||||||
"""
|
"""
|
||||||
user = effective_user(request)
|
user = effective_user(request)
|
||||||
if not user:
|
if not user and not _auth_disabled():
|
||||||
raise HTTPException(403, "Authentication required")
|
raise HTTPException(401, "Authentication required")
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
if row.owner != user:
|
if user and row.owner != user:
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
return
|
return
|
||||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||||
if session_manager is not None:
|
if session_manager is not None:
|
||||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||||
return
|
return
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
|
|
||||||
@@ -262,7 +257,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
last_msg_map = {}
|
last_msg_map = {}
|
||||||
mode_map = {}
|
mode_map = {}
|
||||||
msg_count_map = {}
|
msg_count_map = {}
|
||||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False).all()
|
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
folder_map[row.id] = row.folder
|
folder_map[row.id] = row.folder
|
||||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||||
@@ -284,12 +279,14 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
r[0] for r in db.query(Document.session_id)
|
r[0] for r in db.query(Document.session_id)
|
||||||
.filter(Document.is_active == True,
|
.filter(Document.is_active == True,
|
||||||
Document.current_content != None,
|
Document.current_content != None,
|
||||||
func.trim(Document.current_content) != "")
|
func.trim(Document.current_content) != "",
|
||||||
|
Document.owner == user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
img_session_ids = set(
|
img_session_ids = set(
|
||||||
r[0] for r in db.query(GalleryImage.session_id)
|
r[0] for r in db.query(GalleryImage.session_id)
|
||||||
.filter(GalleryImage.session_id != None)
|
.filter(GalleryImage.session_id != None,
|
||||||
|
GalleryImage.owner == user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
@@ -370,8 +367,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
pass
|
pass
|
||||||
elif not model_to_use:
|
elif not model_to_use:
|
||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
ids = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
# Default to the first CHAT model — endpoints often list embedding/
|
# Default to the first CHAT model — endpoints often list embedding/
|
||||||
@@ -385,8 +387,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
import os as _os
|
import os as _os
|
||||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
avail = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not avail:
|
if not avail:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
if model_to_use not in avail:
|
if model_to_use not in avail:
|
||||||
@@ -543,22 +550,25 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
ids = body.get("ids", [])
|
ids = body.get("ids", [])
|
||||||
except Exception:
|
except Exception:
|
||||||
ids = []
|
ids = []
|
||||||
|
deleted_count = 0
|
||||||
for sid in ids:
|
for sid in ids:
|
||||||
try:
|
try:
|
||||||
_verify_session_owner(request, sid, session_manager)
|
_verify_session_owner(request, sid, session_manager)
|
||||||
session_manager.delete_session(sid)
|
|
||||||
|
# Enforce "starred" protection consistent with single-session delete
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
db.query(_CM).filter(_CM.session_id == sid).delete()
|
db_sess = db.query(DbSession).filter(DbSession.id == sid).first()
|
||||||
db.query(DbSession).filter(DbSession.id == sid).delete()
|
if db_sess and db_sess.is_important:
|
||||||
db.commit()
|
continue
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
if session_manager.delete_session(sid):
|
||||||
|
deleted_count += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return {"deleted": len(ids)}
|
return {"deleted": deleted_count}
|
||||||
|
|
||||||
@router.delete("/session/{sid}")
|
@router.delete("/session/{sid}")
|
||||||
def delete_session(request: Request, sid: str):
|
def delete_session(request: Request, sid: str):
|
||||||
@@ -924,7 +934,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
owner = getattr(session, "owner", None) or effective_user(request)
|
||||||
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
@@ -1006,7 +1017,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
}
|
}
|
||||||
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
||||||
try:
|
try:
|
||||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).all()
|
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
|
||||||
folder_map = {r.id: r.folder for r in rows}
|
folder_map = {r.id: r.folder for r in rows}
|
||||||
# Precompute per-session message counts in TWO aggregate queries
|
# Precompute per-session message counts in TWO aggregate queries
|
||||||
# instead of 1–3 queries PER session — with many chats the per-row
|
# instead of 1–3 queries PER session — with many chats the per-row
|
||||||
|
|||||||
+279
-58
@@ -13,6 +13,7 @@ import tempfile
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
||||||
|
|
||||||
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
||||||
# on Windows, so importing them unconditionally crashed app startup there
|
# on Windows, so importing them unconditionally crashed app startup there
|
||||||
@@ -37,6 +38,7 @@ from core.platform_compat import (
|
|||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -92,6 +94,7 @@ def _venv_activate_prefix(venv: str | None) -> str:
|
|||||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||||
return f". {act} && "
|
return f". {act} && "
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||||
@@ -169,7 +172,10 @@ def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
|||||||
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
||||||
)
|
)
|
||||||
if name == "hf_transfer":
|
if name == "hf_transfer":
|
||||||
return bool(dists.get("hf-transfer") or modules.get("hf_transfer", {}).get("real_module"))
|
return bool(
|
||||||
|
dists.get("hf-transfer")
|
||||||
|
or modules.get("hf_transfer", {}).get("real_module")
|
||||||
|
)
|
||||||
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
||||||
|
|
||||||
|
|
||||||
@@ -194,8 +200,14 @@ def _package_status_note(name: str, probe: dict) -> str:
|
|||||||
if binaries.get("llama-server"):
|
if binaries.get("llama-server"):
|
||||||
parts.append(f"native llama-server: {binaries['llama-server']}")
|
parts.append(f"native llama-server: {binaries['llama-server']}")
|
||||||
if dists.get("llama-cpp-python"):
|
if dists.get("llama-cpp-python"):
|
||||||
parts.append(f"python package: llama-cpp-python {dists['llama-cpp-python']}")
|
parts.append(
|
||||||
return "; ".join(parts) if parts else "No native llama-server or llama-cpp-python server package found."
|
f"python package: llama-cpp-python {dists['llama-cpp-python']}"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"; ".join(parts)
|
||||||
|
if parts
|
||||||
|
else "No native llama-server or llama-cpp-python server package found."
|
||||||
|
)
|
||||||
if name == "diffusers":
|
if name == "diffusers":
|
||||||
if _package_installed_from_probe(name, probe):
|
if _package_installed_from_probe(name, probe):
|
||||||
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
||||||
@@ -205,7 +217,9 @@ def _package_status_note(name: str, probe: dict) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
|
def _package_pip_update_status(
|
||||||
|
pkg: dict, probe: dict | None = None
|
||||||
|
) -> PackageUpdateStatus:
|
||||||
"""Return whether the Dependencies UI should offer a generic pip update.
|
"""Return whether the Dependencies UI should offer a generic pip update.
|
||||||
|
|
||||||
"Installed" means Cookbook can use the dependency. It does not always mean
|
"Installed" means Cookbook can use the dependency. It does not always mean
|
||||||
@@ -213,12 +227,28 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
|||||||
native llama-server can come from a package manager/source build, and a CLI
|
native llama-server can come from a package manager/source build, and a CLI
|
||||||
may be on PATH without matching Python package metadata.
|
may be on PATH without matching Python package metadata.
|
||||||
"""
|
"""
|
||||||
|
if pkg.get("name") == "APFEL":
|
||||||
|
return PackageUpdateStatus(
|
||||||
|
False,
|
||||||
|
"", # Note is empty because IT DOES allow for updates outside of PIP.
|
||||||
|
)
|
||||||
|
|
||||||
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
||||||
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
|
return PackageUpdateStatus(
|
||||||
|
False, "Update this system dependency outside Odysseus."
|
||||||
|
)
|
||||||
|
|
||||||
name = pkg.get("name")
|
name = pkg.get("name")
|
||||||
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
|
binaries = (
|
||||||
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
|
probe.get("binaries")
|
||||||
|
if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
dists = (
|
||||||
|
probe.get("dists")
|
||||||
|
if isinstance(probe, dict) and isinstance(probe.get("dists"), dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
if name == "llama_cpp" and binaries.get("llama-server"):
|
if name == "llama_cpp" and binaries.get("llama-server"):
|
||||||
return PackageUpdateStatus(
|
return PackageUpdateStatus(
|
||||||
@@ -231,7 +261,9 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
|||||||
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
|
return PackageUpdateStatus(
|
||||||
|
True, "Update uses pip in the selected Python environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _prepend_user_install_bins_to_path() -> None:
|
def _prepend_user_install_bins_to_path() -> None:
|
||||||
@@ -250,7 +282,9 @@ def _prepend_user_install_bins_to_path() -> None:
|
|||||||
candidates = []
|
candidates = []
|
||||||
candidates.append(os.path.expanduser("~/.local/bin"))
|
candidates.append(os.path.expanduser("~/.local/bin"))
|
||||||
|
|
||||||
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
parts = (
|
||||||
|
os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||||
|
)
|
||||||
changed = False
|
changed = False
|
||||||
for path in reversed([p for p in candidates if p]):
|
for path in reversed([p for p in candidates if p]):
|
||||||
if path not in parts:
|
if path not in parts:
|
||||||
@@ -357,9 +391,11 @@ PTY_UNSUPPORTED_ERROR = "pty_unsupported"
|
|||||||
|
|
||||||
class ShellExecRequest(BaseModel):
|
class ShellExecRequest(BaseModel):
|
||||||
command: str
|
command: str
|
||||||
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
|
timeout: int | None = (
|
||||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
None # optional override; 0 = no timeout (run until client disconnects)
|
||||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
)
|
||||||
|
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||||
|
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||||
|
|
||||||
|
|
||||||
async def _create_shell(command: str, **kwargs):
|
async def _create_shell(command: str, **kwargs):
|
||||||
@@ -368,8 +404,16 @@ async def _create_shell(command: str, **kwargs):
|
|||||||
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
||||||
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
||||||
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
||||||
|
Powershell commands are executed directly via cmd.exe /c to avoid quoting
|
||||||
|
and env variable expansion errors under Git Bash.
|
||||||
"""
|
"""
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
|
# PowerShell commands (used by the frontend for Windows log-file polling
|
||||||
|
# and session management) must run directly — passing them through
|
||||||
|
# bash -c mangles $env:VAR syntax and breaks the command.
|
||||||
|
cmd_trim = command.strip()
|
||||||
|
if cmd_trim.startswith("powershell") or cmd_trim.startswith("cmd "):
|
||||||
|
return await asyncio.create_subprocess_shell(command, **kwargs)
|
||||||
bash = find_bash()
|
bash = find_bash()
|
||||||
if bash:
|
if bash:
|
||||||
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
||||||
@@ -386,9 +430,7 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
|||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=str(Path.home()),
|
cwd=str(Path.home()),
|
||||||
)
|
)
|
||||||
stdout_b, stderr_b = await asyncio.wait_for(
|
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||||
proc.communicate(), timeout=timeout
|
|
||||||
)
|
|
||||||
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||||
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||||
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
||||||
@@ -399,7 +441,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
|||||||
await proc.wait()
|
await proc.wait()
|
||||||
except ProcessLookupError:
|
except ProcessLookupError:
|
||||||
pass
|
pass
|
||||||
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout}s",
|
||||||
|
"exit_code": -1,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
||||||
|
|
||||||
@@ -481,7 +527,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
|
|
||||||
@@ -503,7 +549,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
if buf:
|
if buf:
|
||||||
@@ -534,6 +580,7 @@ def _pty_read(fd: int) -> bytes | None:
|
|||||||
"""Blocking read from PTY fd. Called via run_in_executor.
|
"""Blocking read from PTY fd. Called via run_in_executor.
|
||||||
Returns bytes on data, None on timeout (no data yet)."""
|
Returns bytes on data, None on timeout (no data yet)."""
|
||||||
import select
|
import select
|
||||||
|
|
||||||
r, _, _ = select.select([fd], [], [], 1.0)
|
r, _, _ = select.select([fd], [], [], 1.0)
|
||||||
if r:
|
if r:
|
||||||
try:
|
try:
|
||||||
@@ -557,10 +604,10 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"#!/bin/bash\n"
|
f"#!/bin/bash\n"
|
||||||
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
|
f'ODYSSEUS_USER_SHELL="${{SHELL:-}}"\n'
|
||||||
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
|
f'if [ -n "$ODYSSEUS_USER_SHELL" ] && [ -x "$ODYSSEUS_USER_SHELL" ]; then\n'
|
||||||
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
|
f' ODYSSEUS_USER_PATH="$("$ODYSSEUS_USER_SHELL" -ic \'printf "__ODYSSEUS_PATH__%s\\n" "$PATH"\' 2>/dev/null | sed -n \'s/^__ODYSSEUS_PATH__//p\' | tail -n 1 || true)"\n'
|
||||||
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
|
f' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi\n'
|
||||||
f"fi\n"
|
f"fi\n"
|
||||||
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
||||||
f"EC=${{PIPESTATUS[0]}}\n"
|
f"EC=${{PIPESTATUS[0]}}\n"
|
||||||
@@ -570,7 +617,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
script_path.chmod(0o755)
|
script_path.chmod(0o755)
|
||||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
logger.info(
|
||||||
|
"tmux wrapper script created: session=%s path=%s", session_id, script_path
|
||||||
|
)
|
||||||
|
|
||||||
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
||||||
|
|
||||||
@@ -602,7 +651,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
# Read new lines from log
|
# Read new lines from log
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
new_lines = lines[lines_sent:]
|
new_lines = lines[lines_sent:]
|
||||||
for line in new_lines:
|
for line in new_lines:
|
||||||
if line.startswith(":::EXIT_CODE:::"):
|
if line.startswith(":::EXIT_CODE:::"):
|
||||||
@@ -630,7 +681,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
# Session ended — do one final read
|
# Session ended — do one final read
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
if line.startswith(":::EXIT_CODE:::"):
|
if line.startswith(":::EXIT_CODE:::"):
|
||||||
try:
|
try:
|
||||||
@@ -672,8 +725,8 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
if bash:
|
if bash:
|
||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"{cmd} > {shlex.quote(str(log_path))} 2>&1\n"
|
f"{cmd} > {shlex.quote(git_bash_path(log_path))} 2>&1\n"
|
||||||
f"echo $? > {shlex.quote(str(exit_path))}\n",
|
f"echo $? > {shlex.quote(git_bash_path(exit_path))}\n",
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
argv = [bash, str(script_path)]
|
argv = [bash, str(script_path)]
|
||||||
@@ -711,7 +764,9 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
lines_sent = len(lines)
|
lines_sent = len(lines)
|
||||||
@@ -723,11 +778,18 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
lines_sent = len(lines)
|
lines_sent = len(lines)
|
||||||
exit_code = int((exit_path.read_text(encoding="utf-8", errors="replace").strip() or "0"))
|
exit_code = int(
|
||||||
|
(
|
||||||
|
exit_path.read_text(encoding="utf-8", errors="replace").strip()
|
||||||
|
or "0"
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
exit_code = 0
|
exit_code = 0
|
||||||
break
|
break
|
||||||
@@ -753,7 +815,9 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
||||||
|
|
||||||
logger.info("User shell exec requested: length=%d", len(cmd))
|
logger.info("User shell exec requested: length=%d", len(cmd))
|
||||||
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
|
result = await _exec_shell(
|
||||||
|
cmd, timeout=req.timeout if req.timeout is not None else EXEC_TIMEOUT
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@router.post("/api/shell/stream")
|
@router.post("/api/shell/stream")
|
||||||
@@ -762,9 +826,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
cmd = req.command.strip()
|
cmd = req.command.strip()
|
||||||
if not cmd:
|
if not cmd:
|
||||||
|
|
||||||
async def empty():
|
async def empty():
|
||||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
||||||
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(empty(), media_type="text/event-stream")
|
return StreamingResponse(empty(), media_type="text/event-stream")
|
||||||
|
|
||||||
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
||||||
@@ -781,7 +847,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if use_tmux:
|
if use_tmux:
|
||||||
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
||||||
# that preserves the "survives disconnect" behaviour.
|
# that preserves the "survives disconnect" behaviour.
|
||||||
gen = _generate_win_detached(cmd, request) if IS_WINDOWS else _generate_tmux(cmd, request)
|
gen = (
|
||||||
|
_generate_win_detached(cmd, request)
|
||||||
|
if IS_WINDOWS
|
||||||
|
else _generate_tmux(cmd, request)
|
||||||
|
)
|
||||||
return StreamingResponse(gen, media_type="text/event-stream")
|
return StreamingResponse(gen, media_type="text/event-stream")
|
||||||
|
|
||||||
if use_pty and not IS_WINDOWS:
|
if use_pty and not IS_WINDOWS:
|
||||||
@@ -813,7 +883,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
chunk = await stream.read(4096)
|
chunk = await stream.read(4096)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
if buf:
|
if buf:
|
||||||
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
|
await q.put(
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
buf.decode(errors="replace").rstrip("\r\n"),
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
buf += chunk
|
buf += chunk
|
||||||
while True:
|
while True:
|
||||||
@@ -821,7 +896,7 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
await q.put((name, line))
|
await q.put((name, line))
|
||||||
finally:
|
finally:
|
||||||
@@ -880,7 +955,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||||
|
|
||||||
@router.get("/api/cookbook/packages")
|
@router.get("/api/cookbook/packages")
|
||||||
async def list_packages(request: Request, host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
|
async def list_packages(
|
||||||
|
request: Request,
|
||||||
|
host: str | None = None,
|
||||||
|
ssh_port: str | None = None,
|
||||||
|
venv: str | None = None,
|
||||||
|
):
|
||||||
"""Check which optional packages are installed.
|
"""Check which optional packages are installed.
|
||||||
|
|
||||||
Local-target packages are checked in-process. Remote-target packages
|
Local-target packages are checked in-process. Remote-target packages
|
||||||
@@ -890,7 +970,13 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
_reject_cross_site(request)
|
_reject_cross_site(request)
|
||||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
|
import importlib
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import shlex
|
||||||
|
import json as _json
|
||||||
|
import site
|
||||||
|
import sys
|
||||||
|
|
||||||
_prepend_user_install_bins_to_path()
|
_prepend_user_install_bins_to_path()
|
||||||
importlib.invalidate_caches()
|
importlib.invalidate_caches()
|
||||||
try:
|
try:
|
||||||
@@ -905,26 +991,115 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
raise HTTPException(400, "Invalid ssh_port")
|
raise HTTPException(400, "Invalid ssh_port")
|
||||||
packages = [
|
packages = [
|
||||||
# ── System ── OS binaries, not pip packages
|
# ── System ── OS binaries, not pip packages
|
||||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
{
|
||||||
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
|
"name": "tmux",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "Required for Linux/Termux Cookbook background downloads and serves",
|
||||||
|
"category": "System",
|
||||||
|
"target": "remote",
|
||||||
|
"kind": "system",
|
||||||
|
"install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "docker",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "Required only for Docker-backed launch commands",
|
||||||
|
"category": "System",
|
||||||
|
"target": "remote",
|
||||||
|
"kind": "system",
|
||||||
|
"install_hint": "Install Docker on the selected server and allow this user to run docker.",
|
||||||
|
},
|
||||||
# ── LLM ── installs on GPU servers for model serving/downloading
|
# ── LLM ── installs on GPU servers for model serving/downloading
|
||||||
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
|
{
|
||||||
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
|
"name": "hf_transfer",
|
||||||
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
|
"pip": "hf_transfer",
|
||||||
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
|
"desc": "Fast model downloads from HuggingFace",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "llama_cpp",
|
||||||
|
"pip": "llama-cpp-python[server]",
|
||||||
|
"desc": "Serve GGUF models via llama.cpp",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "sglang",
|
||||||
|
"pip": "sglang[all]",
|
||||||
|
"desc": "Serve HF safetensors models via SGLang",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "vllm",
|
||||||
|
"pip": "vllm",
|
||||||
|
"desc": "High-throughput LLM serving engine",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "APFEL",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "OpenAI-compatible API for Apple Foundational Models on Apple Silicon",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "local",
|
||||||
|
"kind": "system",
|
||||||
|
"install_cmd": "brew install apfel",
|
||||||
|
"update_cmd": "brew upgrade apfel",
|
||||||
|
"install_hint": "Requires a native Apple Silicon Mac with Apple Foundational Models support. Installable via Homebrew on supported Macs.",
|
||||||
|
},
|
||||||
# ── Image ── editor + diffusion model serving
|
# ── Image ── editor + diffusion model serving
|
||||||
{"name": "diffusers", "pip": "diffusers[torch]", "desc": "Image generation pipelines (SD, Flux) with PyTorch", "category": "Image", "target": "remote"},
|
{
|
||||||
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
|
"name": "diffusers",
|
||||||
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
|
"pip": "diffusers[torch]",
|
||||||
|
"desc": "Image generation pipelines (SD, Flux) with PyTorch",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "rembg",
|
||||||
|
"pip": "rembg[gpu]",
|
||||||
|
"desc": "AI background removal for image editor",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "realesrgan",
|
||||||
|
"pip": "realesrgan",
|
||||||
|
"desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
# ── Tools ──
|
# ── Tools ──
|
||||||
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
|
{
|
||||||
|
"name": "playwright",
|
||||||
|
"pip": "playwright",
|
||||||
|
"desc": "Browser automation for web tools",
|
||||||
|
"category": "Tools",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Most packages should not be installed through external means. Hence, set the default of the
|
||||||
|
# install_cmd and update_cmd to None, which indicates that the recommended way to install/update is through the Cookbook # server setup or pip. Only system packages, should have explicit install/update commands provided.
|
||||||
|
for pkg in packages:
|
||||||
|
pkg.setdefault("install_cmd", None)
|
||||||
|
pkg.setdefault("update_cmd", None)
|
||||||
# Remote check: for remote-target packages, probe the selected server's
|
# Remote check: for remote-target packages, probe the selected server's
|
||||||
# venv over SSH so a remote `pip install` actually reflects here.
|
# venv over SSH so a remote `pip install` actually reflects here.
|
||||||
remote_status: dict = {}
|
remote_status: dict = {}
|
||||||
remote_details: dict = {}
|
remote_details: dict = {}
|
||||||
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
|
remote_names = [
|
||||||
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
|
p["name"]
|
||||||
|
for p in packages
|
||||||
|
if p.get("target") == "remote" and p.get("kind") != "system"
|
||||||
|
]
|
||||||
|
remote_system_names = [
|
||||||
|
p["name"]
|
||||||
|
for p in packages
|
||||||
|
if p.get("target") == "remote" and p.get("kind") == "system"
|
||||||
|
]
|
||||||
if host and remote_names:
|
if host and remote_names:
|
||||||
try:
|
try:
|
||||||
py = _package_probe_script(remote_names)
|
py = _package_probe_script(remote_names)
|
||||||
@@ -934,7 +1109,9 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
*argv,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -958,11 +1135,15 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
checks = []
|
checks = []
|
||||||
for name in remote_system_names:
|
for name in remote_system_names:
|
||||||
qn = shlex.quote(name)
|
qn = shlex.quote(name)
|
||||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
checks.append(
|
||||||
|
f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi"
|
||||||
|
)
|
||||||
inner = " ; ".join(checks)
|
inner = " ; ".join(checks)
|
||||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
*argv,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -987,11 +1168,25 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if note:
|
if note:
|
||||||
pkg["status_note"] = note
|
pkg["status_note"] = note
|
||||||
elif pkg.get("kind") == "system":
|
elif pkg.get("kind") == "system":
|
||||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
if pkg["name"] == "APFEL":
|
||||||
|
pkg["applicable"] = IS_APPLE_SILICON
|
||||||
|
pkg["installed"] = which_tool("apfel") is not None
|
||||||
|
pkg["status_note"] = (
|
||||||
|
"Available on Apple Silicon (arm64) devices; exposed through a local OpenAI-compatible API."
|
||||||
|
if IS_APPLE_SILICON
|
||||||
|
else "Requires a native Apple Silicon Mac with Apple Foundational Models support."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||||
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||||
pkg["installed"] = True
|
pkg["installed"] = True
|
||||||
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
|
pkg["status_note"] = (
|
||||||
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
|
f"native llama-server: {shutil.which('llama-server')}"
|
||||||
|
)
|
||||||
|
probe = {
|
||||||
|
"binaries": {"llama-server": shutil.which("llama-server")},
|
||||||
|
"dists": {},
|
||||||
|
}
|
||||||
elif pkg["name"] == "vllm":
|
elif pkg["name"] == "vllm":
|
||||||
_vllm_cli = shutil.which("vllm")
|
_vllm_cli = shutil.which("vllm")
|
||||||
pkg["installed"] = _vllm_cli is not None
|
pkg["installed"] = _vllm_cli is not None
|
||||||
@@ -1014,6 +1209,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
pkg["installed"] = False
|
pkg["installed"] = False
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
pkg["installed"] = False
|
pkg["installed"] = False
|
||||||
|
except Exception:
|
||||||
|
# Installed but crashes on import — e.g. a CUDA build of
|
||||||
|
# llama-cpp-python raising FileNotFoundError when the CUDA
|
||||||
|
# toolkit dir is absent. One broken optional package must not
|
||||||
|
# 500 the entire packages panel; report it as not usable.
|
||||||
|
pkg["installed"] = False
|
||||||
|
|
||||||
if pkg.get("installed"):
|
if pkg.get("installed"):
|
||||||
update_status = _package_pip_update_status(pkg, probe)
|
update_status = _package_pip_update_status(pkg, probe)
|
||||||
@@ -1037,15 +1238,30 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
pip_name = body.get("pip")
|
pip_name = body.get("pip")
|
||||||
if not pip_name:
|
if not pip_name:
|
||||||
return {"ok": False, "error": "No package specified"}
|
return {"ok": False, "error": "No package specified"}
|
||||||
# Validate against known packages to prevent arbitrary pip install
|
# Validate against known packages to prevent arbitrary pip install
|
||||||
known = {
|
known = {
|
||||||
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers", "diffusers[torch]",
|
"rembg[gpu]",
|
||||||
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
|
"hf_transfer",
|
||||||
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan", "vllm",
|
"llama-cpp-python[server]",
|
||||||
|
"sglang[all]",
|
||||||
|
"diffusers",
|
||||||
|
"diffusers[torch]",
|
||||||
|
"TTS",
|
||||||
|
"bark",
|
||||||
|
"faster-whisper",
|
||||||
|
"playwright",
|
||||||
|
"realesrgan",
|
||||||
|
"gfpgan",
|
||||||
|
"insightface",
|
||||||
|
"onnxruntime-gpu",
|
||||||
|
"onnxruntime",
|
||||||
|
"hdbscan",
|
||||||
|
"vllm",
|
||||||
}
|
}
|
||||||
if pip_name not in known:
|
if pip_name not in known:
|
||||||
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
||||||
@@ -1071,6 +1287,7 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
||||||
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
engine = str(body.get("engine") or "llamacpp").strip()
|
engine = str(body.get("engine") or "llamacpp").strip()
|
||||||
if engine != "llamacpp":
|
if engine != "llamacpp":
|
||||||
@@ -1079,7 +1296,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
ssh_port = body.get("ssh_port")
|
ssh_port = body.get("ssh_port")
|
||||||
cmd = _llama_cpp_rebuild_cmd()
|
cmd = _llama_cpp_rebuild_cmd()
|
||||||
try:
|
try:
|
||||||
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
|
argv = (
|
||||||
|
(_ssh_base_argv(host, ssh_port) + [cmd])
|
||||||
|
if host
|
||||||
|
else ["bash", "-lc", cmd]
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(400, str(e))
|
raise HTTPException(400, str(e))
|
||||||
try:
|
try:
|
||||||
|
|||||||
+44
-16
@@ -21,10 +21,44 @@ from src.auth_helpers import get_current_user
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_DATA_URL_RE = re.compile(
|
_DATA_URL_RE = re.compile(r"^data:image/png;base64,(?P<data>.+)$", re.IGNORECASE | re.DOTALL)
|
||||||
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
|
_ANY_IMAGE_DATA_URL_RE = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
|
||||||
re.IGNORECASE | re.DOTALL,
|
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
|
||||||
)
|
_MAX_SIGNATURE_BYTES = 2 * 1024 * 1024
|
||||||
|
_MAX_SIGNATURE_B64 = ((_MAX_SIGNATURE_BYTES + 2) // 3) * 4
|
||||||
|
_MAX_SIGNATURE_DIMENSION = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_signature_png(raw: str) -> str:
|
||||||
|
raw = (raw or "").strip()
|
||||||
|
m = _DATA_URL_RE.match(raw)
|
||||||
|
if m:
|
||||||
|
b64 = m.group("data")
|
||||||
|
elif _ANY_IMAGE_DATA_URL_RE.match(raw):
|
||||||
|
raise HTTPException(400, "Signature data must be a PNG image")
|
||||||
|
else:
|
||||||
|
b64 = raw
|
||||||
|
if len(b64) > _MAX_SIGNATURE_B64:
|
||||||
|
raise HTTPException(400, "Signature PNG is too large")
|
||||||
|
try:
|
||||||
|
payload = base64.b64decode(b64, validate=True)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||||
|
if not payload:
|
||||||
|
raise HTTPException(400, "Signature PNG is empty")
|
||||||
|
if len(payload) > _MAX_SIGNATURE_BYTES:
|
||||||
|
raise HTTPException(400, "Signature PNG is too large")
|
||||||
|
if not payload.startswith(_PNG_MAGIC):
|
||||||
|
raise HTTPException(400, "Signature data must be a PNG image")
|
||||||
|
return base64.b64encode(payload).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def _signature_dimension(value: Optional[int]) -> Optional[int]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(value, int) or value < 1 or value > _MAX_SIGNATURE_DIMENSION:
|
||||||
|
raise HTTPException(400, "Signature dimensions are invalid")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class SignatureCreate(BaseModel):
|
class SignatureCreate(BaseModel):
|
||||||
@@ -67,24 +101,18 @@ def setup_signature_routes() -> APIRouter:
|
|||||||
@router.post("/api/signatures")
|
@router.post("/api/signatures")
|
||||||
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
raw = (req.data or "").strip()
|
b64 = _normalize_signature_png(req.data)
|
||||||
m = _DATA_URL_RE.match(raw)
|
width = _signature_dimension(req.width)
|
||||||
b64 = m.group("data") if m else raw
|
height = _signature_dimension(req.height)
|
||||||
try:
|
|
||||||
payload = base64.b64decode(b64, validate=True)
|
|
||||||
if not payload:
|
|
||||||
raise ValueError("empty payload")
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
|
||||||
|
|
||||||
sig = Signature(
|
sig = Signature(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
owner=user,
|
owner=user,
|
||||||
name=(req.name or "Signature").strip() or "Signature",
|
name=(req.name or "Signature").strip() or "Signature",
|
||||||
data_png=b64,
|
data_png=b64,
|
||||||
width=req.width,
|
width=width,
|
||||||
height=req.height,
|
height=height,
|
||||||
svg=req.svg,
|
svg=None,
|
||||||
)
|
)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
|||||||
+107
-1
@@ -11,6 +11,8 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -51,6 +53,10 @@ class SkillAddRequest(BaseModel):
|
|||||||
steps: List[str] = Field(default_factory=list)
|
steps: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillImportUrlRequest(BaseModel):
|
||||||
|
url: str = Field(..., min_length=8, max_length=2000)
|
||||||
|
|
||||||
|
|
||||||
class SkillUpdateRequest(BaseModel):
|
class SkillUpdateRequest(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@@ -1014,7 +1020,7 @@ def _resolve_audit_models(owner=None):
|
|||||||
spec = (get_setting("teacher_model", "") or "").strip()
|
spec = (get_setting("teacher_model", "") or "").strip()
|
||||||
if spec:
|
if spec:
|
||||||
from src.ai_interaction import _resolve_model
|
from src.ai_interaction import _resolve_model
|
||||||
t_url, t_model, t_headers = _resolve_model(spec)
|
t_url, t_model, t_headers = _resolve_model(spec, owner=owner)
|
||||||
if t_url and t_model:
|
if t_url and t_model:
|
||||||
teacher = (t_url, t_model, t_headers)
|
teacher = (t_url, t_model, t_headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1103,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
idx = skills_manager.index_for(owner=user)
|
idx = skills_manager.index_for(owner=user)
|
||||||
return {"index": idx, "count": len(idx)}
|
return {"index": idx, "count": len(idx)}
|
||||||
|
|
||||||
|
@router.get("/slash-catalog")
|
||||||
|
async def get_slash_catalog(request: Request):
|
||||||
|
"""Return skills that are available as slash commands.
|
||||||
|
|
||||||
|
Mirrors the agent prompt's published-skill index so the UI never offers
|
||||||
|
a slash command the model would not normally be allowed to discover.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)}
|
||||||
|
entries = []
|
||||||
|
for s in skills_manager.index_for(owner=user):
|
||||||
|
name = (s.get("name") or "").strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
full = all_skills.get(name) or {}
|
||||||
|
category = (s.get("category") or full.get("category") or "general").strip() or "general"
|
||||||
|
entries.append({
|
||||||
|
"type": "skill",
|
||||||
|
"token": f"/{name}",
|
||||||
|
"name": name,
|
||||||
|
"category": f"Skills / {category}",
|
||||||
|
"help": s.get("description") or full.get("description") or "",
|
||||||
|
"usage": f"/{name} <request>",
|
||||||
|
"uses": int(full.get("uses") or 0),
|
||||||
|
"last_used": full.get("last_used"),
|
||||||
|
})
|
||||||
|
entries.sort(key=lambda row: row["name"])
|
||||||
|
return {"skills": entries, "count": len(entries)}
|
||||||
|
|
||||||
@router.get("/builtin")
|
@router.get("/builtin")
|
||||||
async def list_builtin_skills(request: Request):
|
async def list_builtin_skills(request: Request):
|
||||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||||
@@ -1203,6 +1238,36 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
save_settings(settings)
|
save_settings(settings)
|
||||||
return {"ok": True, "name": name, "is_overridden": False}
|
return {"ok": True, "name": name, "is_overridden": False}
|
||||||
|
|
||||||
|
@router.post("/import-from-url")
|
||||||
|
async def import_skill_from_url(request: Request, body: SkillImportUrlRequest):
|
||||||
|
"""Install a SKILL.md bundle from a public GitHub URL (skills.sh links supported)."""
|
||||||
|
require_admin(request)
|
||||||
|
user = _owner(request)
|
||||||
|
from services.memory.skill_importer import (
|
||||||
|
SkillImportError,
|
||||||
|
fetch_skill_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
files, _src = fetch_skill_bundle(body.url.strip())
|
||||||
|
entry = skills_manager.import_bundle_from_files(
|
||||||
|
files,
|
||||||
|
owner=user,
|
||||||
|
source_url=body.url.strip(),
|
||||||
|
)
|
||||||
|
except SkillImportError as e:
|
||||||
|
raise HTTPException(400, str(e)) from e
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.warning("skill import fetch failed: %s", e)
|
||||||
|
detail = str(e).strip() or "Could not download skill from URL"
|
||||||
|
raise HTTPException(502, detail) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("skill import failed: %s", e)
|
||||||
|
raise HTTPException(500, "Skill import failed") from e
|
||||||
|
|
||||||
|
_fire_skill_added(user)
|
||||||
|
return {"ok": True, "skill": entry, "files": len(files)}
|
||||||
|
|
||||||
@router.post("/add")
|
@router.post("/add")
|
||||||
async def add_skill(request: Request, body: SkillAddRequest):
|
async def add_skill(request: Request, body: SkillAddRequest):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
@@ -1236,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
_fire_skill_added(user)
|
_fire_skill_added(user)
|
||||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||||
|
|
||||||
|
@router.post("/{skill_id}/invoke")
|
||||||
|
async def invoke_skill(request: Request, skill_id: str):
|
||||||
|
"""Build a skill-pinned prompt for slash-command invocation.
|
||||||
|
|
||||||
|
This is intentionally server-side so availability, ownership, and usage
|
||||||
|
accounting use the same rules as the SkillsManager.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
request_text = (body.get("request") or "").strip() if isinstance(body, dict) else ""
|
||||||
|
|
||||||
|
invokable = {
|
||||||
|
s.get("name"): s for s in skills_manager.index_for(owner=user)
|
||||||
|
if (s.get("name") or "").strip()
|
||||||
|
}
|
||||||
|
match = invokable.get(skill_id)
|
||||||
|
if not match:
|
||||||
|
raise HTTPException(404, "Skill is not available for slash invocation")
|
||||||
|
|
||||||
|
name = match.get("name")
|
||||||
|
md = skills_manager.read_skill_md(name, owner=user)
|
||||||
|
if md is None:
|
||||||
|
raise HTTPException(404, "Skill source unavailable")
|
||||||
|
|
||||||
|
skills_manager.record_use(name, owner=user)
|
||||||
|
message = (
|
||||||
|
"Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
|
||||||
|
f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
|
||||||
|
+ (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"type": "skill",
|
||||||
|
"name": name,
|
||||||
|
"command": f"/{name}",
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
|
||||||
@router.get("/{skill_id}")
|
@router.get("/{skill_id}")
|
||||||
async def get_skill(request: Request, skill_id: str):
|
async def get_skill(request: Request, skill_id: str):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
|
|||||||
@@ -4,12 +4,10 @@
|
|||||||
from fastapi import APIRouter, HTTPException, UploadFile, File
|
from fastapi import APIRouter, HTTPException, UploadFile, File
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, STT_MAX_AUDIO_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
STT_MAX_AUDIO_BYTES = 25 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def setup_stt_routes(stt_service):
|
def setup_stt_routes(stt_service):
|
||||||
"""Setup STT routes with the provided STT service"""
|
"""Setup STT routes with the provided STT service"""
|
||||||
|
|||||||
+153
-11
@@ -11,13 +11,128 @@ from fastapi import APIRouter, HTTPException, Request
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||||
|
from core.constants import internal_api_base
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import DATA_DIR, EMAIL_URGENCY_CACHE_DIR
|
||||||
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
||||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_cascade_calendar_event(task) -> None:
|
||||||
|
"""Delete the linked calendar event when a cookbook_serve task is
|
||||||
|
removed. Two lookup strategies:
|
||||||
|
|
||||||
|
1. PRIMARY — `cookbook_event_uid` marker stashed in task.prompt
|
||||||
|
by cookbookSchedule.js right after creating the event. Direct
|
||||||
|
UID match, no ambiguity.
|
||||||
|
|
||||||
|
2. FALLBACK — for tasks created before the marker was wired up
|
||||||
|
(or when the PATCH to add the marker failed silently), scan
|
||||||
|
the Cookbook calendar for events whose summary equals the
|
||||||
|
task name and delete the matches.
|
||||||
|
|
||||||
|
Best-effort throughout: errors are logged but never block the task
|
||||||
|
deletion itself."""
|
||||||
|
if not task or task.task_type != "action" or task.action != "cookbook_serve":
|
||||||
|
return
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from core.middleware import INTERNAL_TOOL_HEADER, INTERNAL_TOOL_TOKEN
|
||||||
|
headers = {INTERNAL_TOOL_HEADER: INTERNAL_TOOL_TOKEN}
|
||||||
|
if task.owner:
|
||||||
|
headers["X-Odysseus-Owner"] = task.owner
|
||||||
|
|
||||||
|
# Strategy 1: explicit UID marker in prompt.
|
||||||
|
event_uid = ""
|
||||||
|
if task.prompt:
|
||||||
|
try:
|
||||||
|
cfg = json.loads(task.prompt)
|
||||||
|
if isinstance(cfg, dict):
|
||||||
|
event_uid = (cfg.get("cookbook_event_uid") or "").strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _try_delete(uid: str) -> bool:
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=10) as client:
|
||||||
|
r = client.delete(
|
||||||
|
f"{internal_api_base()}/api/calendar/events/{uid}",
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if r.status_code >= 400:
|
||||||
|
logger.info(
|
||||||
|
f"task delete: cascade calendar event {uid} returned "
|
||||||
|
f"HTTP {r.status_code}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"task delete: cascade calendar event {uid} failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if event_uid:
|
||||||
|
_try_delete(event_uid)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Strategy 2: scan the Cookbook calendar for matching summaries.
|
||||||
|
# Only runs for tasks missing the marker (old tasks or PATCH failures).
|
||||||
|
if not task.name:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=10) as client:
|
||||||
|
# Find the Cookbook calendar.
|
||||||
|
cal_r = client.get(f"{internal_api_base()}/api/calendar/calendars", headers=headers)
|
||||||
|
if cal_r.status_code >= 400:
|
||||||
|
return
|
||||||
|
cals = (cal_r.json() or {}).get("calendars", [])
|
||||||
|
cookbook_cal = next(
|
||||||
|
(c for c in cals if (c.get("name") or "").lower() == "cookbook"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not cookbook_cal:
|
||||||
|
return
|
||||||
|
cal_href = cookbook_cal.get("href") or cookbook_cal.get("id") or ""
|
||||||
|
# List events in a wide window to catch recurring + upcoming.
|
||||||
|
from datetime import datetime as _dt, timedelta as _td, timezone as _tz
|
||||||
|
now = _dt.now(_tz.utc)
|
||||||
|
start = (now - _td(days=30)).isoformat()
|
||||||
|
end = (now + _td(days=365)).isoformat()
|
||||||
|
ev_r = client.get(
|
||||||
|
f"{internal_api_base()}/api/calendar/events",
|
||||||
|
params={"start": start, "end": end, "calendar": cal_href},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if ev_r.status_code >= 400:
|
||||||
|
return
|
||||||
|
events = (ev_r.json() or {}).get("events", [])
|
||||||
|
# Match by exact summary. Tasks named "Serve: <model>" are
|
||||||
|
# created from the schedule modal; the event's summary mirrors
|
||||||
|
# the task name 1:1 by design.
|
||||||
|
target = (task.name or "").strip()
|
||||||
|
uids_to_delete = set()
|
||||||
|
for ev in events:
|
||||||
|
if (ev.get("summary") or "").strip() != target:
|
||||||
|
continue
|
||||||
|
uid = ev.get("uid") or ev.get("id") or ""
|
||||||
|
# Strip the "::occurrence" suffix on recurring expansions —
|
||||||
|
# we want to delete the MASTER once, not each instance.
|
||||||
|
if "::" in uid:
|
||||||
|
uid = uid.split("::", 1)[0]
|
||||||
|
if uid:
|
||||||
|
uids_to_delete.add(uid)
|
||||||
|
for uid in uids_to_delete:
|
||||||
|
_try_delete(uid)
|
||||||
|
if uids_to_delete:
|
||||||
|
logger.info(
|
||||||
|
f"task delete: cascade matched {len(uids_to_delete)} calendar event(s) "
|
||||||
|
f"by summary fallback for task {task.id} ({target!r})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"task delete: cascade fallback scan failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
class TaskCreate(BaseModel):
|
class TaskCreate(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
prompt: Optional[str] = None
|
prompt: Optional[str] = None
|
||||||
@@ -178,20 +293,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
def _owner(request: Request):
|
def _owner(request: Request):
|
||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
|
|
||||||
async def _generate_task_name(prompt: str) -> str:
|
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
|
||||||
"""Use LLM to generate a short task name from the prompt."""
|
"""Use LLM to generate a short task name from the prompt."""
|
||||||
try:
|
try:
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
recent = db.query(DbSession).filter(
|
q = db.query(DbSession).filter(
|
||||||
DbSession.endpoint_url.isnot(None),
|
DbSession.endpoint_url.isnot(None),
|
||||||
DbSession.model.isnot(None),
|
DbSession.model.isnot(None),
|
||||||
).order_by(DbSession.created_at.desc()).first()
|
)
|
||||||
|
if owner:
|
||||||
|
q = q.filter(DbSession.owner == owner)
|
||||||
|
recent = q.order_by(DbSession.created_at.desc()).first()
|
||||||
if not recent:
|
if not recent:
|
||||||
return prompt[:50].strip()
|
return prompt[:50].strip()
|
||||||
url, model = recent.endpoint_url, recent.model
|
url, model = recent.endpoint_url, recent.model
|
||||||
|
headers = recent.headers or {}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -202,6 +321,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
{"role": "user", "content": prompt[:500]},
|
{"role": "user", "content": prompt[:500]},
|
||||||
],
|
],
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
|
headers=headers,
|
||||||
timeout=15,
|
timeout=15,
|
||||||
)
|
)
|
||||||
title = result.strip().strip('"\'').strip()
|
title = result.strip().strip('"\'').strip()
|
||||||
@@ -316,6 +436,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
|
||||||
|
target_id = (then_task_id or "").strip()
|
||||||
|
if not target_id:
|
||||||
|
return None
|
||||||
|
if current_task_id and target_id == current_task_id:
|
||||||
|
raise HTTPException(400, "Task cannot chain to itself")
|
||||||
|
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
|
||||||
|
if user:
|
||||||
|
q = q.filter(ScheduledTask.owner == user)
|
||||||
|
target = q.first()
|
||||||
|
if not target:
|
||||||
|
raise HTTPException(404, "Chained task not found")
|
||||||
|
return target.id
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def create_task(request: Request, req: TaskCreate):
|
async def create_task(request: Request, req: TaskCreate):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
@@ -352,7 +486,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||||
elif req.prompt:
|
elif req.prompt:
|
||||||
name = await _generate_task_name(req.prompt)
|
name = await _generate_task_name(req.prompt, owner=user)
|
||||||
else:
|
else:
|
||||||
name = "Untitled Task"
|
name = "Untitled Task"
|
||||||
|
|
||||||
@@ -379,6 +513,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
|
||||||
notifications_enabled = (
|
notifications_enabled = (
|
||||||
False if req.task_type == "action" and req.notifications_enabled is None
|
False if req.task_type == "action" and req.notifications_enabled is None
|
||||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||||
@@ -405,7 +540,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
output_target=req.output_target,
|
output_target=req.output_target,
|
||||||
model=req.model or None,
|
model=req.model or None,
|
||||||
endpoint_url=req.endpoint_url or None,
|
endpoint_url=req.endpoint_url or None,
|
||||||
then_task_id=req.then_task_id or None,
|
then_task_id=then_task_id,
|
||||||
webhook_token=webhook_token,
|
webhook_token=webhook_token,
|
||||||
notifications_enabled=notifications_enabled,
|
notifications_enabled=notifications_enabled,
|
||||||
)
|
)
|
||||||
@@ -487,7 +622,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
|
|
||||||
removed_files = 0
|
removed_files = 0
|
||||||
if action == "check_email_urgency":
|
if action == "check_email_urgency":
|
||||||
cache_dir = Path("data/email_urgency_cache")
|
cache_dir = Path(EMAIL_URGENCY_CACHE_DIR)
|
||||||
if cache_dir.exists():
|
if cache_dir.exists():
|
||||||
for child in cache_dir.glob("*.json"):
|
for child in cache_dir.glob("*.json"):
|
||||||
try:
|
try:
|
||||||
@@ -496,7 +631,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
||||||
for state_path in [Path(f"data/email_urgency_state_{owner_slug}.json")]:
|
for state_path in [Path(DATA_DIR) / f"email_urgency_state_{owner_slug}.json"]:
|
||||||
try:
|
try:
|
||||||
if state_path.exists():
|
if state_path.exists():
|
||||||
state_path.unlink()
|
state_path.unlink()
|
||||||
@@ -558,7 +693,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
if req.trigger_count is not None:
|
if req.trigger_count is not None:
|
||||||
task.trigger_count = req.trigger_count
|
task.trigger_count = req.trigger_count
|
||||||
if req.then_task_id is not None:
|
if req.then_task_id is not None:
|
||||||
task.then_task_id = req.then_task_id or None
|
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
|
||||||
if req.notifications_enabled is not None:
|
if req.notifications_enabled is not None:
|
||||||
task.notifications_enabled = bool(req.notifications_enabled)
|
task.notifications_enabled = bool(req.notifications_enabled)
|
||||||
if req.cron_expression is not None:
|
if req.cron_expression is not None:
|
||||||
@@ -616,6 +751,12 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
raise HTTPException(404, "Task not found")
|
raise HTTPException(404, "Task not found")
|
||||||
if user and task.owner != user:
|
if user and task.owner != user:
|
||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
|
# Cascade: cookbook_serve tasks may have a linked calendar
|
||||||
|
# event (created via the "Create event in calendar" toggle
|
||||||
|
# in the schedule modal). If so, delete the calendar event
|
||||||
|
# too so the calendar doesn't end up holding a phantom event
|
||||||
|
# for a task that no longer exists.
|
||||||
|
_maybe_cascade_calendar_event(task)
|
||||||
db.delete(task)
|
db.delete(task)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
@@ -833,7 +974,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
from src.agent_tools import get_mcp_manager
|
from src.tool_utils import get_mcp_manager
|
||||||
mcp = get_mcp_manager()
|
mcp = get_mcp_manager()
|
||||||
if mcp:
|
if mcp:
|
||||||
for tool in mcp.get_all_tools():
|
for tool in mcp.get_all_tools():
|
||||||
@@ -928,6 +1069,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
desc = (body.get("description") or "").strip()
|
desc = (body.get("description") or "").strip()
|
||||||
if not desc:
|
if not desc:
|
||||||
return {"success": False, "message": "Nothing to parse"}
|
return {"success": False, "message": "Nothing to parse"}
|
||||||
|
user = _owner(request)
|
||||||
|
|
||||||
now = _dt.now()
|
now = _dt.now()
|
||||||
# Give the model the current date/time + weekday so relative phrasing
|
# Give the model the current date/time + weekday so relative phrasing
|
||||||
@@ -954,9 +1096,9 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=user or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||||
if not (url and model):
|
if not (url and model):
|
||||||
return {"success": False, "message": "No model endpoint configured"}
|
return {"success": False, "message": "No model endpoint configured"}
|
||||||
raw = await llm_call_async(
|
raw = await llm_call_async(
|
||||||
|
|||||||
+51
-34
@@ -13,9 +13,43 @@ from src.upload_handler import count_recent_uploads
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
||||||
|
UPLOAD_RESPONSE_HEADERS = {"X-Content-Type-Options": "nosniff"}
|
||||||
|
|
||||||
def setup_upload_routes(upload_handler):
|
def setup_upload_routes(upload_handler):
|
||||||
"""Setup upload routes with the provided handler"""
|
"""Setup upload routes with the provided handler"""
|
||||||
|
|
||||||
|
def _upload_root() -> str:
|
||||||
|
from src.constants import UPLOAD_DIR
|
||||||
|
return os.path.realpath(getattr(upload_handler, "upload_dir", UPLOAD_DIR))
|
||||||
|
|
||||||
|
def _path_inside_upload_dir(path: str) -> bool:
|
||||||
|
try:
|
||||||
|
return os.path.commonpath([_upload_root(), os.path.realpath(path)]) == _upload_root()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _resolve_upload_path(file_id: str) -> str:
|
||||||
|
from src.constants import UPLOAD_DIR
|
||||||
|
upload_root = getattr(upload_handler, "upload_dir", UPLOAD_DIR)
|
||||||
|
direct = os.path.join(upload_root, file_id)
|
||||||
|
if os.path.lexists(direct):
|
||||||
|
if not _path_inside_upload_dir(direct):
|
||||||
|
raise HTTPException(403, "Access denied")
|
||||||
|
if os.path.isfile(direct):
|
||||||
|
return direct
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
|
for root, _dirs, files in os.walk(upload_root, followlinks=False):
|
||||||
|
if file_id not in files:
|
||||||
|
continue
|
||||||
|
path = os.path.join(root, file_id)
|
||||||
|
if not _path_inside_upload_dir(path):
|
||||||
|
raise HTTPException(403, "Access denied")
|
||||||
|
if os.path.isfile(path):
|
||||||
|
return path
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
||||||
@@ -91,23 +125,11 @@ def setup_upload_routes(upload_handler):
|
|||||||
client isn't downloading the full-resolution photo just to show it tiny."""
|
client isn't downloading the full-resolution photo just to show it tiny."""
|
||||||
if not upload_handler.validate_upload_id(file_id):
|
if not upload_handler.validate_upload_id(file_id):
|
||||||
raise HTTPException(400, "Invalid file ID")
|
raise HTTPException(400, "Invalid file ID")
|
||||||
# Search upload directories for the file
|
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
import mimetypes as _mt
|
import mimetypes as _mt
|
||||||
path = os.path.join(UPLOAD_DIR, file_id)
|
|
||||||
if not os.path.exists(path):
|
|
||||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
|
||||||
if file_id in files:
|
|
||||||
path = os.path.join(root, file_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise HTTPException(404, "File not found")
|
|
||||||
if not upload_handler.inside_base_dir(path):
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
# Look up original filename and owner from uploads.json
|
# Look up original filename and owner from uploads.json
|
||||||
original_name = file_id
|
original_name = file_id
|
||||||
info = None
|
info = None
|
||||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||||
if os.path.exists(uploads_db):
|
if os.path.exists(uploads_db):
|
||||||
with open(uploads_db, encoding="utf-8") as f:
|
with open(uploads_db, encoding="utf-8") as f:
|
||||||
db = json.load(f)
|
db = json.load(f)
|
||||||
@@ -123,13 +145,14 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
mime = _mt.guess_type(path)[0] or "application/octet-stream"
|
path = _resolve_upload_path(file_id)
|
||||||
|
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or "application/octet-stream"
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
# Downscaled thumbnail for image previews — generated once and cached.
|
# Downscaled thumbnail for image previews — generated once and cached.
|
||||||
if thumb and mime.startswith("image/"):
|
if thumb and mime.startswith("image/"):
|
||||||
try:
|
try:
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
|
thumb_dir = os.path.join(_upload_root(), ".thumbs")
|
||||||
os.makedirs(thumb_dir, exist_ok=True)
|
os.makedirs(thumb_dir, exist_ok=True)
|
||||||
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
||||||
if (not os.path.exists(thumb_path)
|
if (not os.path.exists(thumb_path)
|
||||||
@@ -145,17 +168,21 @@ def setup_upload_routes(upload_handler):
|
|||||||
if im.mode not in ("RGB", "L"):
|
if im.mode not in ("RGB", "L"):
|
||||||
im = im.convert("RGB")
|
im = im.convert("RGB")
|
||||||
im.save(thumb_path, "JPEG", quality=80)
|
im.save(thumb_path, "JPEG", quality=80)
|
||||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
return FileResponse(thumb_path, media_type="image/jpeg", headers=UPLOAD_RESPONSE_HEADERS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
||||||
# Fall through to the full image.
|
# Fall through to the full image.
|
||||||
return FileResponse(path, media_type=mime, filename=original_name)
|
return FileResponse(
|
||||||
|
path,
|
||||||
|
media_type=mime,
|
||||||
|
filename=original_name,
|
||||||
|
headers=UPLOAD_RESPONSE_HEADERS,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_upload_info(file_id: str):
|
def _load_upload_info(file_id: str):
|
||||||
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
info = None
|
info = None
|
||||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||||
if os.path.exists(uploads_db):
|
if os.path.exists(uploads_db):
|
||||||
with open(uploads_db, encoding="utf-8") as f:
|
with open(uploads_db, encoding="utf-8") as f:
|
||||||
db = json.load(f)
|
db = json.load(f)
|
||||||
@@ -163,8 +190,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
def _vision_cache_path(file_id: str) -> str:
|
def _vision_cache_path(file_id: str) -> str:
|
||||||
from src.constants import UPLOAD_DIR
|
cache_dir = os.path.join(_upload_root(), ".vision")
|
||||||
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
return os.path.join(cache_dir, file_id + ".txt")
|
return os.path.join(cache_dir, file_id + ".txt")
|
||||||
|
|
||||||
@@ -175,17 +201,6 @@ def setup_upload_routes(upload_handler):
|
|||||||
subsequent loads are instant. Pass force=1 to recompute."""
|
subsequent loads are instant. Pass force=1 to recompute."""
|
||||||
if not upload_handler.validate_upload_id(file_id):
|
if not upload_handler.validate_upload_id(file_id):
|
||||||
raise HTTPException(400, "Invalid file ID")
|
raise HTTPException(400, "Invalid file ID")
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
path = os.path.join(UPLOAD_DIR, file_id)
|
|
||||||
if not os.path.exists(path):
|
|
||||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
|
||||||
if file_id in files:
|
|
||||||
path = os.path.join(root, file_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise HTTPException(404, "File not found")
|
|
||||||
if not upload_handler.inside_base_dir(path):
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
info = _load_upload_info(file_id)
|
info = _load_upload_info(file_id)
|
||||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||||
@@ -196,8 +211,9 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
|
path = _resolve_upload_path(file_id)
|
||||||
import mimetypes as _mt
|
import mimetypes as _mt
|
||||||
mime = _mt.guess_type(path)[0] or ""
|
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or ""
|
||||||
if not mime.startswith("image/"):
|
if not mime.startswith("image/"):
|
||||||
raise HTTPException(400, "Not an image")
|
raise HTTPException(400, "Not an image")
|
||||||
cache_path = _vision_cache_path(file_id)
|
cache_path = _vision_cache_path(file_id)
|
||||||
@@ -209,7 +225,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
||||||
from src.document_processor import analyze_image_with_vl
|
from src.document_processor import analyze_image_with_vl
|
||||||
try:
|
try:
|
||||||
text = analyze_image_with_vl(path) or ""
|
text = analyze_image_with_vl(path, owner=current_user) or ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
||||||
raise HTTPException(500, f"Vision analysis failed: {e}")
|
raise HTTPException(500, f"Vision analysis failed: {e}")
|
||||||
@@ -238,6 +254,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
|
_resolve_upload_path(file_id)
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
text = (body or {}).get("text", "")
|
text = (body or {}).get("text", "")
|
||||||
if not isinstance(text, str):
|
if not isinstance(text, str):
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
||||||
|
from src.constants import VAULT_FILE as _VAULT_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VAULT_FILE = Path("data/vault.json")
|
VAULT_FILE = Path(_VAULT_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _find_bw() -> str:
|
def _find_bw() -> str:
|
||||||
|
|||||||
+23
-10
@@ -194,6 +194,8 @@ def setup_webhook_routes(
|
|||||||
"together": "https://api.together.xyz/v1",
|
"together": "https://api.together.xyz/v1",
|
||||||
"openrouter": "https://openrouter.ai/api/v1",
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
"ollama": "https://ollama.com/api",
|
"ollama": "https://ollama.com/api",
|
||||||
|
"opencode-zen": "https://opencode.ai/zen/v1",
|
||||||
|
"opencode-go": "https://opencode.ai/zen/go/v1",
|
||||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||||
"venice": "https://api.venice.ai/api/v1",
|
"venice": "https://api.venice.ai/api/v1",
|
||||||
}
|
}
|
||||||
@@ -323,22 +325,33 @@ def setup_webhook_routes(
|
|||||||
endpoint_url = build_chat_url(base_url)
|
endpoint_url = build_chat_url(base_url)
|
||||||
model = body.model or "auto"
|
model = body.model or "auto"
|
||||||
api_key = ep.api_key
|
api_key = ep.api_key
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||||
|
endpoint_url = build_chat_url(base_url)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||||
|
|
||||||
if model == "auto":
|
if model == "auto":
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=5) as client:
|
async with httpx.AsyncClient(timeout=5) as client:
|
||||||
models_url = build_models_url(base_url)
|
models_url = build_models_url(base_url)
|
||||||
hdrs = build_headers(api_key, base_url)
|
hdrs = build_headers(api_key, base_url)
|
||||||
resp = await client.get(models_url, headers=hdrs)
|
if models_url:
|
||||||
resp.raise_for_status()
|
resp = await client.get(models_url, headers=hdrs)
|
||||||
data = resp.json()
|
resp.raise_for_status()
|
||||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
data = resp.json()
|
||||||
if not ids:
|
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
ids = [
|
if not ids:
|
||||||
m.get("name") or m.get("model")
|
ids = [
|
||||||
for m in (data.get("models") or [])
|
m.get("name") or m.get("model")
|
||||||
if m.get("name") or m.get("model")
|
for m in (data.get("models") or [])
|
||||||
]
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
import json as _json
|
||||||
|
ids = _json.loads(ep.cached_models or "[]")
|
||||||
model = ids[0] if ids else "auto"
|
model = ids[0] if ids else "auto"
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(500, "Could not discover models from endpoint")
|
raise HTTPException(500, "Could not discover models from endpoint")
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import json
|
|||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from src.constants import MEMORY_FILE, SKILLS_FILE
|
||||||
|
|
||||||
|
|
||||||
def claim_json_entries(entries, owner):
|
def claim_json_entries(entries, owner):
|
||||||
count = 0
|
count = 0
|
||||||
@@ -35,8 +37,8 @@ def main():
|
|||||||
|
|
||||||
# 1. Memories (JSON files)
|
# 1. Memories (JSON files)
|
||||||
for label, path in [
|
for label, path in [
|
||||||
("memory.json", "data/memory.json"),
|
("memory.json", MEMORY_FILE),
|
||||||
("skills.json", "data/skills.json"),
|
("skills.json", SKILLS_FILE),
|
||||||
]:
|
]:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
print(f" {label}: not found, skipping")
|
print(f" {label}: not found, skipping")
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ import torch
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -52,7 +53,63 @@ async def lifespan(application):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
||||||
|
# Conservative defaults — server is designed for server-to-server use from
|
||||||
|
# the Odysseus backend. Wildcard CORS + the 127.0.0.1 default bind used to
|
||||||
|
# leave the server reachable via DNS-rebinding from any browser tab on the
|
||||||
|
# same host. The CLI flags below extend these allowlists for operators who
|
||||||
|
# need browser access; the safe defaults handle the common case.
|
||||||
|
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
|
||||||
|
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
|
||||||
|
"""Allowed Host header values: the bind address + loopback variants +
|
||||||
|
any operator-supplied --allowed-host values. Duplicates and empty
|
||||||
|
strings are dropped; order is stable for predictable middleware setup."""
|
||||||
|
seen = []
|
||||||
|
for h in (bind_host, *_DEFAULT_ALLOWED_HOSTS, *(extras or [])):
|
||||||
|
h = (h or "").strip()
|
||||||
|
if h and h not in seen:
|
||||||
|
seen.append(h)
|
||||||
|
return seen
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cors_origins(extras=None) -> list:
|
||||||
|
"""CORS allowlist: default-deny (empty), extended only by explicit
|
||||||
|
--allowed-origin values. Server-to-server callers don't set an Origin
|
||||||
|
header so they're unaffected; this only narrows browser access."""
|
||||||
|
seen = []
|
||||||
|
for o in (*_DEFAULT_CORS_ORIGINS, *(extras or [])):
|
||||||
|
o = (o or "").strip()
|
||||||
|
if o and o not in seen:
|
||||||
|
seen.append(o)
|
||||||
|
return seen
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
|
||||||
|
"""Replace `application`'s user middleware stack with the diffusion server
|
||||||
|
security middleware: the TrustedHost allowlist and, when origins are
|
||||||
|
supplied, CORS. Used at module load and by the __main__ CLI path before
|
||||||
|
serving starts. Raises before mutating if the middleware stack has already
|
||||||
|
been built. Order is preserved: TrustedHost first, then CORS (added last ->
|
||||||
|
outermost)."""
|
||||||
|
if application.middleware_stack is not None:
|
||||||
|
raise RuntimeError("security middleware must be configured before the app starts serving")
|
||||||
|
application.user_middleware.clear()
|
||||||
|
application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts))
|
||||||
|
if allowed_origins:
|
||||||
|
application.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=list(allowed_origins),
|
||||||
|
allow_methods=["GET", "POST", "OPTIONS"],
|
||||||
|
allow_headers=["Authorization", "Content-Type"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Install defaults at module load so importing the app for tests / direct
|
||||||
|
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||||
|
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
|
||||||
|
|
||||||
|
|
||||||
class ImageRequest(BaseModel):
|
class ImageRequest(BaseModel):
|
||||||
@@ -1089,7 +1146,25 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
||||||
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
||||||
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
||||||
|
parser.add_argument("--allowed-host", action="append", default=[],
|
||||||
|
help="Additional Host header value to accept (DNS-rebinding allowlist). "
|
||||||
|
"Can be repeated. Loopback values are always included.")
|
||||||
|
parser.add_argument("--allowed-origin", action="append", default=[],
|
||||||
|
help="Additional CORS origin to allow. Can be repeated. Defaults to "
|
||||||
|
"no cross-origin access — only pass this if you need a browser "
|
||||||
|
"on a specific origin to call the server.")
|
||||||
_args = parser.parse_args()
|
_args = parser.parse_args()
|
||||||
|
|
||||||
|
# Replace the module-load middleware stack with the CLI-configured one so
|
||||||
|
# operator-supplied --allowed-host / --allowed-origin values take effect
|
||||||
|
# before the first request is served. user_middleware is consulted lazily
|
||||||
|
# when the middleware stack is built on the first request, so mutating it
|
||||||
|
# here is safe.
|
||||||
|
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
|
||||||
|
final_origins = _compute_cors_origins(_args.allowed_origin)
|
||||||
|
_configure_security_middleware(app, final_hosts, final_origins)
|
||||||
|
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
|
||||||
|
final_hosts, final_origins or "(none — default-deny)")
|
||||||
|
|
||||||
app.state.model_path = _args.model
|
app.state.model_path = _args.model
|
||||||
uvicorn.run(app, host=_args.host, port=_args.port)
|
uvicorn.run(app, host=_args.host, port=_args.port)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from src.constants import PERSONAL_DIR
|
||||||
|
|
||||||
# Configure logging for the script
|
# Configure logging for the script
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@@ -45,7 +48,7 @@ def main():
|
|||||||
rag_manager = RAGManager()
|
rag_manager = RAGManager()
|
||||||
|
|
||||||
# Directory to scan
|
# Directory to scan
|
||||||
docs_directory = "data/personal_docs"
|
docs_directory = PERSONAL_DIR
|
||||||
directory_path = Path(docs_directory)
|
directory_path = Path(docs_directory)
|
||||||
|
|
||||||
# Check if directory exists
|
# Check if directory exists
|
||||||
|
|||||||
@@ -63,10 +63,10 @@ def migrate_memories():
|
|||||||
"""Migrate memory vectors from FAISS to ChromaDB."""
|
"""Migrate memory vectors from FAISS to ChromaDB."""
|
||||||
from src.chroma_client import get_chroma_client
|
from src.chroma_client import get_chroma_client
|
||||||
from src.embeddings import get_embedding_client
|
from src.embeddings import get_embedding_client
|
||||||
from src.constants import DATA_DIR
|
from src.constants import MEMORY_VECTORS_DIR, MEMORY_FILE
|
||||||
|
|
||||||
ids_path = os.path.join(DATA_DIR, "memory_vectors", "ids.json")
|
ids_path = os.path.join(MEMORY_VECTORS_DIR, "ids.json")
|
||||||
memory_path = os.path.join(DATA_DIR, "memory.json")
|
memory_path = MEMORY_FILE
|
||||||
|
|
||||||
if not os.path.exists(ids_path):
|
if not os.path.exists(ids_path):
|
||||||
logger.info("No memory FAISS index found, skipping memory migration")
|
logger.info("No memory FAISS index found, skipping memory migration")
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json"
|
|||||||
import tempfile
|
import tempfile
|
||||||
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
||||||
|
|
||||||
|
from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fail(msg: str, code: int = 1) -> None:
|
def fail(msg: str, code: int = 1) -> None:
|
||||||
sys.stderr.write(f"error: {msg}\n")
|
sys.stderr.write(f"error: {msg}\n")
|
||||||
@@ -160,7 +163,26 @@ def cmd_gpus(args) -> None:
|
|||||||
prefix = _ssh_prefix(args.host, args.ssh_port)
|
prefix = _ssh_prefix(args.host, args.ssh_port)
|
||||||
cmd = prefix + (query.split() if not prefix else [query])
|
cmd = prefix + (query.split() if not prefix else [query])
|
||||||
try:
|
try:
|
||||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
if prefix:
|
||||||
|
candidates = [query]
|
||||||
|
args_part = query[len("nvidia-smi "):]
|
||||||
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ repr(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args_part}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args_part}")
|
||||||
|
|
||||||
|
out = None
|
||||||
|
for candidate in candidates:
|
||||||
|
out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15)
|
||||||
|
if out.returncode == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# No nvidia-smi locally → try the Metal fallback before giving up.
|
# No nvidia-smi locally → try the Metal fallback before giving up.
|
||||||
if not prefix:
|
if not prefix:
|
||||||
|
|||||||
@@ -25,6 +25,24 @@ from pathlib import Path
|
|||||||
|
|
||||||
_DATA_DIR = _REPO_ROOT / "data" / "deep_research"
|
_DATA_DIR = _REPO_ROOT / "data" / "deep_research"
|
||||||
|
|
||||||
|
# The CLI's --status takes the user-facing label "complete", but the writer
|
||||||
|
# in services/research/research_handler.py stores `status="done"` when a run
|
||||||
|
# finishes (and the legacy src/research_handler.py does the same). Without
|
||||||
|
# this alias, --status complete filters every finished record out and the
|
||||||
|
# user sees an empty list. Map at filter time so the on-disk corpus is the
|
||||||
|
# source of truth and the CLI surface stays the friendlier word. The other
|
||||||
|
# choices ("running", "cancelled", "error") are stored verbatim, so they
|
||||||
|
# fall through unchanged.
|
||||||
|
_STATUS_CLI_TO_STORED = {"complete": "done"}
|
||||||
|
|
||||||
|
|
||||||
|
def _status_matches(stored, requested: str) -> bool:
|
||||||
|
stored = (stored or "")
|
||||||
|
if not isinstance(stored, str):
|
||||||
|
stored = ""
|
||||||
|
target = _STATUS_CLI_TO_STORED.get(requested, requested)
|
||||||
|
return stored == target
|
||||||
|
|
||||||
|
|
||||||
def _load_path(path: Path) -> dict | None:
|
def _load_path(path: Path) -> dict | None:
|
||||||
try:
|
try:
|
||||||
@@ -72,7 +90,7 @@ def cmd_list(args):
|
|||||||
data = _load_path(path)
|
data = _load_path(path)
|
||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
if args.status and (data.get("status") or "") != args.status:
|
if args.status and not _status_matches(data.get("status"), args.status):
|
||||||
continue
|
continue
|
||||||
out.append(_summarize(rp_id, data))
|
out.append(_summarize(rp_id, data))
|
||||||
out.sort(key=lambda r: r.get("started_at") or "", reverse=True)
|
out.sort(key=lambda r: r.get("started_at") or "", reverse=True)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
from src.rag_manager import RAGManager
|
from src.rag_manager import RAGManager
|
||||||
|
from src.constants import CHROMA_DIR
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -34,7 +35,7 @@ class DocsService:
|
|||||||
results = await service.query("what is async await?")
|
results = await service.query("what is async await?")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, persist_dir: str = "data/chroma"):
|
def __init__(self, persist_dir: str = CHROMA_DIR):
|
||||||
self.rag = RAGManager(persist_directory=persist_dir)
|
self.rag = RAGManager(persist_directory=persist_dir)
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
||||||
|
|||||||
+93
-48
@@ -4,6 +4,13 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from core.platform_compat import (
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
run_ssh_command,
|
||||||
|
)
|
||||||
|
|
||||||
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
||||||
# from 30 min so changing filters doesn't keep re-probing the rig every
|
# from 30 min so changing filters doesn't keep re-probing the rig every
|
||||||
@@ -21,16 +28,17 @@ def _run(cmd):
|
|||||||
if _remote_host:
|
if _remote_host:
|
||||||
# Run command on remote host via SSH
|
# Run command on remote host via SSH
|
||||||
if isinstance(cmd, list):
|
if isinstance(cmd, list):
|
||||||
cmd_str = " ".join(cmd)
|
cmd_str = shlex.join(str(c) for c in cmd)
|
||||||
else:
|
else:
|
||||||
cmd_str = cmd
|
cmd_str = cmd
|
||||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
r = run_ssh_command(
|
||||||
if _remote_port and _remote_port != "22":
|
_remote_host,
|
||||||
ssh_cmd += ["-p", _remote_port]
|
_remote_port,
|
||||||
ssh_cmd += [_remote_host, cmd_str]
|
cmd_str,
|
||||||
r = subprocess.run(
|
timeout=15,
|
||||||
ssh_cmd,
|
connect_timeout=5,
|
||||||
capture_output=True, text=True, timeout=15,
|
strict_host_key_checking=False,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||||
@@ -76,21 +84,29 @@ def _detect_nvidia():
|
|||||||
global _last_gpu_error
|
global _last_gpu_error
|
||||||
_last_gpu_error = None
|
_last_gpu_error = None
|
||||||
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||||
# Remote fallback: a non-interactive SSH shell often has a minimal PATH
|
# Fallback: a non-interactive shell (or WSL) often has a minimal PATH
|
||||||
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin), so the
|
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin,
|
||||||
# first call silently returns nothing → "No GPU" on hosts that DO have GPUs.
|
# /usr/lib/wsl/lib), so the first call silently returns nothing →
|
||||||
|
# "No GPU" on machines that DO have GPUs.
|
||||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||||
if not out and _remote_host:
|
if not out and _remote_host:
|
||||||
out = _run(
|
out = _run(
|
||||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin\"; "
|
f"bash -lc '{SSH_PATH_OVERRIDE}"
|
||||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||||
)
|
)
|
||||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||||
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
||||||
# above still comes back empty even though the binary is right there.
|
# above still comes back empty even though the binary is right there.
|
||||||
if not out and _remote_host:
|
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
||||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi"):
|
# that may not be in the server process's PATH.
|
||||||
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
if not out:
|
||||||
|
for _p in NVIDIA_PATH_CANDIDATES:
|
||||||
|
# Use list form so subprocess.run (local) resolves the absolute path
|
||||||
|
# correctly instead of treating the whole string as an executable name.
|
||||||
|
if _remote_host:
|
||||||
|
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
||||||
|
else:
|
||||||
|
out = _run([_p, "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||||
if out:
|
if out:
|
||||||
break
|
break
|
||||||
if not out:
|
if not out:
|
||||||
@@ -468,39 +484,55 @@ def _detect_windows():
|
|||||||
"""
|
"""
|
||||||
# Single PowerShell command that gathers all hardware info at once
|
# Single PowerShell command that gathers all hardware info at once
|
||||||
ps_cmd = (
|
ps_cmd = (
|
||||||
"$r = @{}; "
|
"""
|
||||||
"$os = Get-CimInstance Win32_OperatingSystem; "
|
$r = @{}
|
||||||
"$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1); "
|
$os = Get-CimInstance Win32_OperatingSystem
|
||||||
"$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1); "
|
$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1)
|
||||||
"$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1; "
|
$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1)
|
||||||
"$r.cpu_name = $cpu.Name; "
|
$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1
|
||||||
"$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum; "
|
$r.cpu_name = $cpu.Name
|
||||||
"$r.arch = $cpu.AddressWidth; "
|
$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum
|
||||||
|
$r.arch = $cpu.AddressWidth
|
||||||
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
||||||
"try { "
|
try {
|
||||||
" $nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null; "
|
$nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null
|
||||||
" if ($LASTEXITCODE -eq 0 -and $nv) { "
|
if ($LASTEXITCODE -eq 0 -and $nv) {
|
||||||
" $gpus = @(); "
|
$gpus = @()
|
||||||
" foreach ($line in $nv -split \"`n\") { "
|
foreach ($line in $nv -split "`n") {
|
||||||
" $p = $line -split ','; "
|
$p = $line -split ','
|
||||||
" if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name = $p[1].Trim(); vram_mb = [double]$p[0].Trim() } }
|
||||||
" }; "
|
}
|
||||||
" $r.gpu_name = $gpus[0].name; "
|
$r.gpu_name = $gpus[0].name
|
||||||
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
|
$r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1)
|
||||||
" $r.gpu_count = $gpus.Count; "
|
$r.gpu_count = $gpus.Count
|
||||||
" $r.gpu_backend = 'cuda'; "
|
$r.gpu_backend = 'cuda'
|
||||||
" } "
|
}
|
||||||
"} catch {}; "
|
}
|
||||||
"if (-not $r.gpu_name) { "
|
catch {}
|
||||||
" $wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1; "
|
if (-not $r.gpu_name) {
|
||||||
" if ($wmiGpu) { "
|
$wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1
|
||||||
" $r.gpu_name = $wmiGpu.Name; "
|
$GPUDriverKey = "HKLM:\\SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\0*"
|
||||||
" $r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1); "
|
$GPUDeviceID = $wmiGpu.PNPDeviceID.Split('&')[0..1] -join '&'
|
||||||
" $r.gpu_count = 1; "
|
$VRAMfromRegistry = Get-ItemProperty -Path $GPUDriverKey |
|
||||||
" $r.gpu_backend = 'cpu_x86'; " # WMI doesn't tell us CUDA/ROCm
|
Where-Object { $_.MatchingDeviceId -like "${GPUDeviceID}*" } |
|
||||||
" } "
|
# Sometimes there happen to be multiple driver classes for the same gpu.
|
||||||
"}; "
|
Select-Object -ExpandProperty HardwareInformation.qwMemorySize -ErrorAction SilentlyContinue -First 1
|
||||||
"$r | ConvertTo-Json -Compress"
|
if ($wmiGpu) {
|
||||||
|
$r.gpu_name = $wmiGpu.Name
|
||||||
|
# Edge case: driver is broken, otherwise $wmiGpu.AdapterRAM is redundant
|
||||||
|
if ($VRAMfromRegistry -ge $wmiGpu.AdapterRAM) {
|
||||||
|
$r.gpu_vram_gb = [math]::Round($VRAMfromRegistry / 1073741824, 1)
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
$r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1)
|
||||||
|
}
|
||||||
|
$r.gpu_count = 1
|
||||||
|
# WMI doesn't tell us CUDA/ROCm
|
||||||
|
$r.gpu_backend = 'cpu_x86';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$r | ConvertTo-Json -Compress
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
if _remote_host:
|
if _remote_host:
|
||||||
# Remote: ship a single command string over SSH. The remote shell parses
|
# Remote: ship a single command string over SSH. The remote shell parses
|
||||||
@@ -566,6 +598,19 @@ def _detect_windows():
|
|||||||
_cache_by_host = {} # host -> (timestamp, result)
|
_cache_by_host = {} # host -> (timestamp, result)
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(host: str, ssh_port: str, platform_name: str):
|
||||||
|
"""Build a stable cache key that isolates remote SSH context.
|
||||||
|
|
||||||
|
Same host aliases can have different hardware due to visibility, forwarding etc.
|
||||||
|
To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
host or "_local",
|
||||||
|
str(ssh_port or ""),
|
||||||
|
str(platform_name or "").lower(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||||
@@ -575,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
|
|||||||
"""
|
"""
|
||||||
global _remote_host, _remote_port, _remote_platform
|
global _remote_host, _remote_port, _remote_platform
|
||||||
|
|
||||||
cache_key = host or "_local"
|
cache_key = _cache_key(host, ssh_port, platform)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if not fresh and cache_key in _cache_by_host:
|
if not fresh and cache_key in _cache_by_host:
|
||||||
ts, cached = _cache_by_host[cache_key]
|
ts, cached = _cache_by_host[cache_key]
|
||||||
|
|||||||
@@ -192,11 +192,19 @@ def _fallback_memory_candidates(messages) -> list[dict]:
|
|||||||
if place:
|
if place:
|
||||||
add(f"User lives in {place}.", "identity")
|
add(f"User lives in {place}.", "identity")
|
||||||
|
|
||||||
m = re.search(r"\bi (?:prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
m = re.search(r"\bi (prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
||||||
if m:
|
if m:
|
||||||
preference = _clean_memory_value(m.group(1), 100)
|
preference = _clean_memory_value(m.group(2), 100)
|
||||||
if preference:
|
if preference:
|
||||||
add(f"User prefers {preference}.", "preference")
|
# The same pattern catches likes and dislikes; keep the stored
|
||||||
|
# sentiment faithful instead of recording every match as a
|
||||||
|
# preference ("I hate cilantro" must not become "User prefers
|
||||||
|
# cilantro").
|
||||||
|
verb = m.group(1).lower()
|
||||||
|
if verb in ("hate", "do not like", "don't like"):
|
||||||
|
add(f"User dislikes {preference}.", "preference")
|
||||||
|
else:
|
||||||
|
add(f"User prefers {preference}.", "preference")
|
||||||
|
|
||||||
m = re.search(
|
m = re.search(
|
||||||
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
||||||
@@ -228,6 +236,43 @@ def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) ->
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_extraction_json(raw: str) -> list:
|
||||||
|
"""Parse the extraction LLM's reply into a list of facts, tolerating
|
||||||
|
reasoning-model noise.
|
||||||
|
|
||||||
|
The model emits <think>…</think> (and sometimes a prose preamble or a
|
||||||
|
```json fence) AROUND the JSON array; without stripping it, json.loads
|
||||||
|
bombs and the run silently yields "0 candidates". Pure str -> list (no
|
||||||
|
LLM/network); returns [] on any parse failure instead of raising.
|
||||||
|
"""
|
||||||
|
text = (raw or "").strip()
|
||||||
|
try:
|
||||||
|
from src.text_helpers import strip_think as _strip_think
|
||||||
|
text = _strip_think(text, prose=True, prompt_echo=True).strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if text.startswith("```"):
|
||||||
|
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||||
|
# JSON may still be embedded in surrounding commentary (leading prose or
|
||||||
|
# trailing remarks like "[...] Done!") — slice from the first '[' to the
|
||||||
|
# last ']' whenever both exist. Slice unconditionally: a reply that starts
|
||||||
|
# with '[' can still carry trailing commentary that breaks json.loads.
|
||||||
|
_start = text.find("[")
|
||||||
|
_end = text.rfind("]")
|
||||||
|
if 0 <= _start < _end:
|
||||||
|
text = text[_start : _end + 1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
facts = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||||
|
return []
|
||||||
|
return facts if isinstance(facts, list) else []
|
||||||
|
|
||||||
|
|
||||||
async def extract_and_store(
|
async def extract_and_store(
|
||||||
session,
|
session,
|
||||||
memory_manager,
|
memory_manager,
|
||||||
@@ -276,9 +321,34 @@ async def extract_and_store(
|
|||||||
|
|
||||||
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
||||||
|
|
||||||
|
# Flatten the window into a SINGLE user message instead of appending the
|
||||||
|
# raw alternating role messages. Passed as raw chat messages, the model
|
||||||
|
# treats the window as a conversation to CONTINUE rather than a transcript
|
||||||
|
# to ANALYZE, so it reliably extracts nothing — typically returning `[]`
|
||||||
|
# (and, depending on the input, sometimes an empty or <think>-only
|
||||||
|
# completion when the window ends on an assistant turn). This was the real
|
||||||
|
# cause of auto-memory logging "0 candidates" on every run. Reframing it as
|
||||||
|
# one "analyze this transcript, return the JSON array" user message makes
|
||||||
|
# the model actually extract. Controlled repro on this model: 0/6 trials
|
||||||
|
# with the old structure vs 6/6 with this one. The skill extractor flattens
|
||||||
|
# for the same reason.
|
||||||
|
def _flatten_msg(m):
|
||||||
|
c = m.get("content", "")
|
||||||
|
if isinstance(c, list):
|
||||||
|
c = " ".join(
|
||||||
|
b.get("text", "") for b in c
|
||||||
|
if isinstance(b, dict) and b.get("type") == "text"
|
||||||
|
)
|
||||||
|
return f"{m.get('role', '?')}: {c}"
|
||||||
|
|
||||||
|
transcript = "\n\n".join(_flatten_msg(m) for m in stripped_recent)
|
||||||
extraction_messages = [
|
extraction_messages = [
|
||||||
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
||||||
] + stripped_recent
|
{"role": "user", "content": (
|
||||||
|
"Conversation to analyze:\n\n" + transcript
|
||||||
|
+ "\n\nReturn the JSON array of durable facts now (or [] if none)."
|
||||||
|
)},
|
||||||
|
]
|
||||||
|
|
||||||
facts = []
|
facts = []
|
||||||
try:
|
try:
|
||||||
@@ -287,19 +357,20 @@ async def extract_and_store(
|
|||||||
model,
|
model,
|
||||||
extraction_messages,
|
extraction_messages,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=500,
|
# A reasoning model spends most of its budget on <think> tokens
|
||||||
|
# BEFORE emitting the JSON, so the old 500 truncated the response
|
||||||
|
# before any JSON appeared → every run logged "0 candidates". The
|
||||||
|
# audit path hit the same wall and raised to 16384; extraction's
|
||||||
|
# output (a short facts list) is small, so an ample ceiling is
|
||||||
|
# enough once thinking has room.
|
||||||
|
max_tokens=4096,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse JSON from response (handle markdown fences if model wraps them)
|
# Parse JSON, tolerating reasoning-model noise (<think> blocks, a
|
||||||
text = raw.strip()
|
# ```json fence, and leading/trailing commentary). See
|
||||||
if text.startswith("```"):
|
# _parse_extraction_json — returns [] rather than raising.
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
facts = _parse_extraction_json(raw)
|
||||||
|
|
||||||
try:
|
|
||||||
facts = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.debug("Memory extraction returned non-JSON")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
from .memory import MemoryManager
|
from .memory import MemoryManager
|
||||||
from .memory_vector import MemoryVectorStore
|
from .memory_vector import MemoryVectorStore
|
||||||
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
||||||
|
from src.constants import DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -38,7 +39,7 @@ class MemoryService:
|
|||||||
results = await service.recall("preferences")
|
results = await service.recall("preferences")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir: str = "data"):
|
def __init__(self, data_dir: str = DATA_DIR):
|
||||||
self.manager = MemoryManager(data_dir)
|
self.manager = MemoryManager(data_dir)
|
||||||
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
||||||
os.path.join(data_dir, "memory_vectors")
|
os.path.join(data_dir, "memory_vectors")
|
||||||
|
|||||||
@@ -63,6 +63,46 @@ def _has_duplicate_title(skills, title: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> Optional[dict]:
|
||||||
|
"""Best-effort extraction of a JSON object from an LLM response.
|
||||||
|
|
||||||
|
The response may be wrapped in code fences or surrounded by prose, and some
|
||||||
|
models emit a stray brace in the prose before the real object
|
||||||
|
(e.g. "uses {placeholder} then {...}"). Slicing first-'{' .. last-'}' then
|
||||||
|
grabs an unparseable span and the skill is silently lost. Try the whole
|
||||||
|
string first, then each '{' start position in turn, returning the first
|
||||||
|
candidate that parses to a JSON object (dict). Returns None if none do.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
s = text.strip()
|
||||||
|
if s.startswith("```"):
|
||||||
|
s = s.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||||
|
end = s.rfind("}")
|
||||||
|
if end == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _as_dict(candidate):
|
||||||
|
try:
|
||||||
|
obj = json.loads(candidate)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return None
|
||||||
|
return obj if isinstance(obj, dict) else None
|
||||||
|
|
||||||
|
# The clean, common case: the whole (de-fenced) string is the object.
|
||||||
|
obj = _as_dict(s)
|
||||||
|
if obj is not None:
|
||||||
|
return obj
|
||||||
|
# Otherwise scan each '{' candidate up to the last '}'.
|
||||||
|
start = s.find("{")
|
||||||
|
while 0 <= start < end:
|
||||||
|
obj = _as_dict(s[start : end + 1])
|
||||||
|
if obj is not None:
|
||||||
|
return obj
|
||||||
|
start = s.find("{", start + 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def maybe_extract_skill(
|
async def maybe_extract_skill(
|
||||||
session,
|
session,
|
||||||
skills_manager,
|
skills_manager,
|
||||||
@@ -169,21 +209,14 @@ async def maybe_extract_skill(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Parse JSON
|
# Parse JSON. The object may be wrapped in code fences or surrounded by
|
||||||
text = response.strip()
|
# commentary (and may contain a stray/invalid brace fragment before
|
||||||
if text.startswith("```"):
|
# the real object — including one that makes the response itself look
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
# like it starts with '{'), so use a tolerant extractor that tries the
|
||||||
# After strip_think, the JSON may still be embedded inside surrounding
|
# whole string first and then each '{' candidate left-to-right.
|
||||||
# commentary — slice from the first '{' to the matching last '}'.
|
data = _extract_json_object(response)
|
||||||
if text and text[0] != "{":
|
if not data:
|
||||||
_start = text.find("{")
|
logger.debug("[skill-extract] no JSON object found in response, dropping")
|
||||||
_end = text.rfind("}")
|
|
||||||
if 0 <= _start < _end:
|
|
||||||
text = text[_start : _end + 1]
|
|
||||||
|
|
||||||
data = json.loads(text)
|
|
||||||
if not data or not isinstance(data, dict):
|
|
||||||
logger.debug("[skill-extract] parsed JSON not a dict, dropping")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
title = data.get("title", "").strip()
|
title = data.get("title", "").strip()
|
||||||
|
|||||||
@@ -0,0 +1,283 @@
|
|||||||
|
"""Import SKILL.md bundles from public GitHub (or skills.sh → GitHub) URLs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from src.url_safety import check_outbound_url
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_FILES = 64
|
||||||
|
MAX_TOTAL_BYTES = 2_000_000
|
||||||
|
MAX_FILE_BYTES = 400_000
|
||||||
|
ALLOWED_SUFFIXES = (
|
||||||
|
".md", ".txt", ".json", ".yaml", ".yml", ".py", ".sh", ".toml",
|
||||||
|
".js", ".ts", ".css", ".html", ".xml", ".csv",
|
||||||
|
)
|
||||||
|
TEXT_NAMES = {"skill.md", "license", "license.md", "readme.md"}
|
||||||
|
_GITHUB_HOSTS = frozenset({
|
||||||
|
"github.com", "www.github.com", "api.github.com", "raw.githubusercontent.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _github_host(url: str) -> str:
|
||||||
|
return (urlparse(str(url)).hostname or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_github_url(url: str, *, context: str = "URL") -> None:
|
||||||
|
host = _github_host(url)
|
||||||
|
if host not in _GITHUB_HOSTS:
|
||||||
|
raise SkillImportError(
|
||||||
|
f"{context} must stay on GitHub (got {host or 'unknown host'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResolvedSource:
|
||||||
|
owner: str
|
||||||
|
repo: str
|
||||||
|
ref: str
|
||||||
|
path: str # directory or file path inside repo (no leading slash)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillImportError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_relpath(rel: str) -> str:
|
||||||
|
rel = (rel or "").replace("\\", "/").strip().lstrip("/")
|
||||||
|
if not rel or rel.startswith("..") or "/../" in f"/{rel}/":
|
||||||
|
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||||
|
parts = [p for p in rel.split("/") if p and p != "."]
|
||||||
|
if any(p == ".." for p in parts):
|
||||||
|
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||||
|
return "/".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_text_file(name: str) -> bool:
|
||||||
|
low = name.lower()
|
||||||
|
if low in TEXT_NAMES:
|
||||||
|
return True
|
||||||
|
return any(low.endswith(s) for s in ALLOWED_SUFFIXES)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_skill_source(url: str) -> ResolvedSource:
|
||||||
|
"""Normalize skills.sh / GitHub web URLs into owner/repo/ref/path."""
|
||||||
|
raw = (url or "").strip()
|
||||||
|
if not raw:
|
||||||
|
raise SkillImportError("URL is required")
|
||||||
|
|
||||||
|
# skills.sh often links to GitHub; try to unwrap ?url= or redirect target later.
|
||||||
|
if "skills.sh" in raw and "github.com" not in raw:
|
||||||
|
ok, reason = check_outbound_url(raw)
|
||||||
|
if not ok:
|
||||||
|
raise SkillImportError(reason)
|
||||||
|
with httpx.Client(follow_redirects=True, timeout=20.0) as client:
|
||||||
|
r = client.get(raw)
|
||||||
|
if r.status_code >= 400:
|
||||||
|
raise _github_response_error(r)
|
||||||
|
final = str(r.url)
|
||||||
|
_assert_github_url(final, context="redirect target")
|
||||||
|
# Page may embed a github link; prefer final URL if redirected.
|
||||||
|
if "github.com" in final:
|
||||||
|
raw = final
|
||||||
|
else:
|
||||||
|
m = re.search(r"https?://github\.com/[^\s\"')]+", r.text or "")
|
||||||
|
if m:
|
||||||
|
raw = m.group(0).rstrip(".,)")
|
||||||
|
|
||||||
|
parsed = urlparse(raw)
|
||||||
|
host = _github_host(raw)
|
||||||
|
if host not in _GITHUB_HOSTS:
|
||||||
|
raise SkillImportError(
|
||||||
|
"Only GitHub URLs are supported (https://github.com/... or raw.githubusercontent.com/...)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if host == "raw.githubusercontent.com":
|
||||||
|
# /owner/repo/ref/path/to/file
|
||||||
|
bits = [p for p in parsed.path.split("/") if p]
|
||||||
|
if len(bits) < 4:
|
||||||
|
raise SkillImportError("Invalid raw GitHub URL")
|
||||||
|
owner, repo, ref = bits[0], bits[1], bits[2]
|
||||||
|
path = "/".join(bits[3:])
|
||||||
|
return ResolvedSource(owner=owner, repo=repo, ref=ref, path=path)
|
||||||
|
|
||||||
|
bits = [p for p in parsed.path.split("/") if p]
|
||||||
|
if len(bits) < 2:
|
||||||
|
raise SkillImportError("Invalid GitHub URL")
|
||||||
|
owner, repo = bits[0], bits[1]
|
||||||
|
ref = "main"
|
||||||
|
path = ""
|
||||||
|
|
||||||
|
if len(bits) >= 4 and bits[2] in ("tree", "blob"):
|
||||||
|
ref = bits[3]
|
||||||
|
path = "/".join(bits[4:])
|
||||||
|
elif len(bits) == 2:
|
||||||
|
path = ""
|
||||||
|
else:
|
||||||
|
raise SkillImportError("GitHub URL must include /tree/<branch>/... or /blob/<branch>/...")
|
||||||
|
|
||||||
|
return ResolvedSource(owner=owner, repo=repo, ref=ref, path=path)
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_url(src: ResolvedSource, rel_path: str) -> str:
|
||||||
|
rel = _safe_relpath(rel_path)
|
||||||
|
return f"https://raw.githubusercontent.com/{src.owner}/{src.repo}/{quote(src.ref, safe='')}/{quote(rel, safe='/')}"
|
||||||
|
|
||||||
|
|
||||||
|
def _api_contents_url(src: ResolvedSource, rel_path: str = "") -> str:
|
||||||
|
rel = _safe_relpath(rel_path) if rel_path else ""
|
||||||
|
base = f"https://api.github.com/repos/{src.owner}/{src.repo}/contents"
|
||||||
|
if rel:
|
||||||
|
base += f"/{quote(rel, safe='/')}"
|
||||||
|
return f"{base}?ref={quote(src.ref, safe='')}"
|
||||||
|
|
||||||
|
|
||||||
|
def _github_response_error(response: httpx.Response) -> SkillImportError:
|
||||||
|
"""Turn a failed GitHub HTTP response into a user-visible import error."""
|
||||||
|
status = response.status_code
|
||||||
|
detail = ""
|
||||||
|
try:
|
||||||
|
body = response.json()
|
||||||
|
if isinstance(body, dict):
|
||||||
|
detail = str(body.get("message") or "").strip()
|
||||||
|
except Exception:
|
||||||
|
detail = (response.text or "").strip()[:200]
|
||||||
|
|
||||||
|
low = detail.lower()
|
||||||
|
if status == 403 and "rate limit" in low:
|
||||||
|
return SkillImportError(
|
||||||
|
"GitHub API rate limit exceeded — try again in a bit"
|
||||||
|
+ (f" ({detail})" if detail else "")
|
||||||
|
)
|
||||||
|
if status == 404:
|
||||||
|
return SkillImportError("path not found on GitHub")
|
||||||
|
if detail:
|
||||||
|
return SkillImportError(f"GitHub request failed ({status}): {detail}")
|
||||||
|
return SkillImportError(f"GitHub request failed ({status})")
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_bytes(url: str) -> bytes:
|
||||||
|
ok, reason = check_outbound_url(url)
|
||||||
|
if not ok:
|
||||||
|
raise SkillImportError(reason)
|
||||||
|
with httpx.Client(follow_redirects=True, timeout=30.0) as client:
|
||||||
|
r = client.get(url, headers={"Accept": "application/vnd.github+json"})
|
||||||
|
if r.status_code >= 400:
|
||||||
|
raise _github_response_error(r)
|
||||||
|
_assert_github_url(str(r.url), context="redirect target")
|
||||||
|
if len(r.content) > MAX_FILE_BYTES:
|
||||||
|
raise SkillImportError(f"file too large: {url}")
|
||||||
|
return r.content
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_text(url: str) -> str:
|
||||||
|
data = _fetch_bytes(url)
|
||||||
|
try:
|
||||||
|
return data.decode("utf-8")
|
||||||
|
except UnicodeDecodeError as e:
|
||||||
|
raise SkillImportError(f"non-text file: {url}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _list_github_dir(src: ResolvedSource, rel_dir: str, out: Dict[str, str], *, depth: int = 0) -> None:
|
||||||
|
if depth > 4 or len(out) >= MAX_FILES:
|
||||||
|
return
|
||||||
|
url = _api_contents_url(src, rel_dir)
|
||||||
|
ok, reason = check_outbound_url(url)
|
||||||
|
if not ok:
|
||||||
|
raise SkillImportError(reason)
|
||||||
|
with httpx.Client(follow_redirects=True, timeout=30.0) as client:
|
||||||
|
r = client.get(url, headers={"Accept": "application/vnd.github+json"})
|
||||||
|
if r.status_code >= 400:
|
||||||
|
raise _github_response_error(r)
|
||||||
|
_assert_github_url(str(r.url), context="redirect target")
|
||||||
|
entries = r.json()
|
||||||
|
if not isinstance(entries, list):
|
||||||
|
raise SkillImportError("expected a directory on GitHub")
|
||||||
|
total = sum(len(v.encode("utf-8")) for v in out.values())
|
||||||
|
for ent in entries:
|
||||||
|
if len(out) >= MAX_FILES or total >= MAX_TOTAL_BYTES:
|
||||||
|
break
|
||||||
|
if not isinstance(ent, dict):
|
||||||
|
continue
|
||||||
|
name = ent.get("name") or ""
|
||||||
|
ent_type = ent.get("type")
|
||||||
|
rel = _safe_relpath(f"{rel_dir}/{name}" if rel_dir else name)
|
||||||
|
if ent_type == "dir":
|
||||||
|
_list_github_dir(src, rel, out, depth=depth + 1)
|
||||||
|
total = sum(len(v.encode("utf-8")) for v in out.values())
|
||||||
|
continue
|
||||||
|
if ent_type != "file" or not _is_text_file(name):
|
||||||
|
continue
|
||||||
|
dl = ent.get("download_url")
|
||||||
|
if not dl:
|
||||||
|
continue
|
||||||
|
_assert_github_url(dl, context="download URL")
|
||||||
|
text = _fetch_text(dl)
|
||||||
|
total += len(text.encode("utf-8"))
|
||||||
|
if total > MAX_TOTAL_BYTES:
|
||||||
|
raise SkillImportError("skill bundle exceeds size limit")
|
||||||
|
out[rel] = text
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_skill_bundle(url: str) -> Tuple[Dict[str, str], ResolvedSource]:
|
||||||
|
"""Download SKILL.md and sibling text assets. Returns relative_path → content."""
|
||||||
|
src = parse_skill_source(url)
|
||||||
|
files: Dict[str, str] = {}
|
||||||
|
|
||||||
|
path = _safe_relpath(src.path) if src.path else ""
|
||||||
|
if path.lower().endswith("skill.md"):
|
||||||
|
files[path] = _fetch_text(_raw_url(src, path))
|
||||||
|
parent = "/".join(path.split("/")[:-1])
|
||||||
|
if parent:
|
||||||
|
try:
|
||||||
|
_list_github_dir(src, parent, files)
|
||||||
|
except SkillImportError:
|
||||||
|
pass
|
||||||
|
return files, src
|
||||||
|
|
||||||
|
if path:
|
||||||
|
try:
|
||||||
|
_fetch_text(_raw_url(src, f"{path}/SKILL.md"))
|
||||||
|
_list_github_dir(src, path, files)
|
||||||
|
return files, src
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
text = _fetch_text(_raw_url(src, path))
|
||||||
|
if path.lower().endswith(".md"):
|
||||||
|
files[path] = text
|
||||||
|
return files, src
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_list_github_dir(src, path, files)
|
||||||
|
else:
|
||||||
|
_list_github_dir(src, "", files)
|
||||||
|
|
||||||
|
if not any(p.lower().endswith("skill.md") for p in files):
|
||||||
|
# Flat repo root with SKILL.md only
|
||||||
|
try:
|
||||||
|
files["SKILL.md"] = _fetch_text(_raw_url(src, "SKILL.md"))
|
||||||
|
except Exception as e:
|
||||||
|
raise SkillImportError(
|
||||||
|
"No SKILL.md found — link to a skill folder or SKILL.md on GitHub"
|
||||||
|
) from e
|
||||||
|
return files, src
|
||||||
|
|
||||||
|
|
||||||
|
def pick_skill_md(files: Dict[str, str]) -> Tuple[str, str]:
|
||||||
|
for rel, content in files.items():
|
||||||
|
if rel.lower().endswith("skill.md"):
|
||||||
|
return rel, content
|
||||||
|
raise SkillImportError("bundle has no SKILL.md")
|
||||||
|
|
||||||
|
|
||||||
|
def default_category_from_source(src: ResolvedSource) -> str:
|
||||||
|
return "imported"
|
||||||
@@ -381,6 +381,54 @@ class SkillsManager:
|
|||||||
|
|
||||||
return sk.to_dict()
|
return sk.to_dict()
|
||||||
|
|
||||||
|
def import_bundle_from_files(
|
||||||
|
self,
|
||||||
|
files: Dict[str, str],
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
source_url: str = "",
|
||||||
|
category: str = "imported",
|
||||||
|
) -> Dict:
|
||||||
|
"""Install a fetched skill bundle (relative path → text) under skills/."""
|
||||||
|
from .skill_importer import SkillImportError, pick_skill_md, _safe_relpath
|
||||||
|
from core.atomic_io import atomic_write_text
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
raise SkillImportError("empty bundle")
|
||||||
|
_rel, skill_md = pick_skill_md(files)
|
||||||
|
sk = Skill.from_markdown(skill_md)
|
||||||
|
nm = slugify(sk.name or _rel.split("/")[-2] or "skill")
|
||||||
|
cat = slugify(category or sk.category or "imported", fallback="imported")
|
||||||
|
|
||||||
|
existing = {s["name"] for s in self.load_all()}
|
||||||
|
base = nm
|
||||||
|
i = 2
|
||||||
|
while nm in existing:
|
||||||
|
nm = f"{base}-{i}"
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
skill_dir = self._skill_dir(cat, nm)
|
||||||
|
os.makedirs(skill_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Preserve bundle layout (templates/, references/, etc.) under the skill dir.
|
||||||
|
for rel, content in files.items():
|
||||||
|
safe = _safe_relpath(rel)
|
||||||
|
dest = os.path.join(skill_dir, safe)
|
||||||
|
os.makedirs(os.path.dirname(dest), exist_ok=True)
|
||||||
|
atomic_write_text(dest, content)
|
||||||
|
|
||||||
|
sk.name = nm
|
||||||
|
sk.category = cat
|
||||||
|
sk.owner = owner
|
||||||
|
sk.source = "imported"
|
||||||
|
if source_url:
|
||||||
|
extra = (sk.body_extra or "").strip()
|
||||||
|
note = f"Imported from {source_url}"
|
||||||
|
sk.body_extra = f"{extra}\n\n{note}".strip() if extra else note
|
||||||
|
atomic_write_text(self._skill_file(cat, nm), sk.to_markdown())
|
||||||
|
sk.path = self._skill_file(cat, nm)
|
||||||
|
return sk.to_dict()
|
||||||
|
|
||||||
def update_skill(self, skill_id: str, updates: Dict, owner: Optional[str] = None) -> bool:
|
def update_skill(self, skill_id: str, updates: Dict, owner: Optional[str] = None) -> bool:
|
||||||
"""`skill_id` is the slug name. Allows updating any field plus
|
"""`skill_id` is the slug name. Allows updating any field plus
|
||||||
renames if `name` changes (file is moved on disk).
|
renames if `name` changes (file is moved on disk).
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ from pathlib import Path
|
|||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
|
|
||||||
from src.research_utils import is_low_quality
|
from src.research_utils import is_low_quality
|
||||||
|
from src.constants import DEEP_RESEARCH_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RESEARCH_DATA_DIR = Path("data/deep_research")
|
RESEARCH_DATA_DIR = Path(DEEP_RESEARCH_DIR)
|
||||||
|
|
||||||
|
|
||||||
class ResearchHandler:
|
class ResearchHandler:
|
||||||
|
|||||||
@@ -6,21 +6,29 @@ from collections import Counter
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from core.constants import DATA_DIR
|
||||||
|
|
||||||
from .cache import cache_metrics
|
from .cache import cache_metrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Dedicated error logger with file handler
|
# Dedicated error logger — write to the data logs directory (writable on both
|
||||||
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
|
# native runs and Docker, where DATA_DIR resolves to the bind-mounted volume).
|
||||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
_log_dir = Path(DATA_DIR) / "logs"
|
||||||
_error_handler.setLevel(logging.WARNING)
|
_error_log_path = _log_dir / "search_engine_error.log"
|
||||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
|
||||||
error_logger = logging.getLogger("search_engine_error")
|
error_logger = logging.getLogger("search_engine_error")
|
||||||
error_logger.addHandler(_error_handler)
|
|
||||||
error_logger.propagate = False
|
error_logger.propagate = False
|
||||||
|
try:
|
||||||
|
_log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||||
|
_error_handler.setLevel(logging.WARNING)
|
||||||
|
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||||
|
error_logger.addHandler(_error_handler)
|
||||||
|
except Exception as _e:
|
||||||
|
logging.getLogger(__name__).warning("search_engine_error log handler unavailable: %s", _e)
|
||||||
|
|
||||||
# Analytics file
|
# Analytics file — also in the writable logs volume.
|
||||||
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
|
ANALYTICS_FILE = _log_dir / "search_analytics.json"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
|
|||||||
@@ -6,17 +6,23 @@ from datetime import datetime, timedelta
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from core.constants import DATA_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Cache directories
|
# Cache directories
|
||||||
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
|
CACHE_DIR = Path(DATA_DIR) / "cache"
|
||||||
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
||||||
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
||||||
CACHE_MAX_ENTRIES = 1000
|
CACHE_MAX_ENTRIES = 1000
|
||||||
|
|
||||||
# Create cache directories
|
# Create cache directories. Guarded so an unwritable path (e.g. a read-only
|
||||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
# mount) degrades to no-disk-cache instead of crashing module import.
|
||||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
try:
|
||||||
|
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
except OSError as _e:
|
||||||
|
logger.warning("Search cache directory unavailable (%s); disk cache disabled", _e)
|
||||||
|
|
||||||
# Track cache size for LRU eviction
|
# Track cache size for LRU eviction
|
||||||
search_cache_index: Dict[str, datetime] = {}
|
search_cache_index: Dict[str, datetime] = {}
|
||||||
|
|||||||
@@ -259,6 +259,9 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
|||||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_logger.warning(f"HTTP {e.response.status_code} fetching {url}: {e}")
|
||||||
|
return _empty_result(url, f"HTTP {e.response.status_code}: {e}")
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
||||||
return _empty_result(url, f"NetworkError: {e}")
|
return _empty_result(url, f"NetworkError: {e}")
|
||||||
|
|||||||
@@ -76,6 +76,19 @@ def _domain(url: str) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _has_word(text: str, term: str) -> bool:
|
||||||
|
"""True if ``term`` appears in ``text`` as a whole word.
|
||||||
|
|
||||||
|
Query terms are matched on word boundaries so a short term doesn't match
|
||||||
|
inside an unrelated word: "us" must not match "business"/"music", "port"
|
||||||
|
must not match "transport"/"support". This mirrors the tokenization used to
|
||||||
|
build ``query_terms`` (``\\b\\w+\\b``). #1473 converted the title and sports
|
||||||
|
checks to word boundaries; the snippet and subject-term checks below use
|
||||||
|
the same helper so the whole file stays consistent.
|
||||||
|
"""
|
||||||
|
return re.search(rf"\b{re.escape(term)}\b", text) is not None
|
||||||
|
|
||||||
|
|
||||||
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||||
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
||||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||||
@@ -87,14 +100,14 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
|||||||
if not title:
|
if not title:
|
||||||
return 0.0
|
return 0.0
|
||||||
title_lc = title.lower()
|
title_lc = title.lower()
|
||||||
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
|
matches = sum(1 for term in query_terms if _has_word(title_lc, term))
|
||||||
return matches / len(query_terms) if query_terms else 0.0
|
return matches / len(query_terms) if query_terms else 0.0
|
||||||
|
|
||||||
def snippet_score(snippet: str) -> float:
|
def snippet_score(snippet: str) -> float:
|
||||||
if not snippet:
|
if not snippet:
|
||||||
return 0.0
|
return 0.0
|
||||||
length_factor = min(len(snippet), 200) / 200
|
length_factor = min(len(snippet), 200) / 200
|
||||||
term_hits = sum(1 for term in query_terms if term in snippet.lower())
|
term_hits = sum(1 for term in query_terms if _has_word(snippet.lower(), term))
|
||||||
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
||||||
return (length_factor + term_factor) / 2
|
return (length_factor + term_factor) / 2
|
||||||
|
|
||||||
@@ -127,7 +140,7 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
|||||||
# A country/news query should not rank a page whose title/snippet barely
|
# A country/news query should not rank a page whose title/snippet barely
|
||||||
# mentions the country above actual news pages for that country.
|
# mentions the country above actual news pages for that country.
|
||||||
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
||||||
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
|
if subject_terms and not any(_has_word(text, t) or _has_word(netloc, t) for t in subject_terms):
|
||||||
adjustment -= 1.0
|
adjustment -= 1.0
|
||||||
return adjustment
|
return adjustment
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import httpx
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from src.constants import TTS_CACHE_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +37,7 @@ class TTSService:
|
|||||||
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cache_dir: str = "data/tts_cache"):
|
def __init__(self, cache_dir: str = TTS_CACHE_DIR):
|
||||||
self.cache_dir = Path(cache_dir)
|
self.cache_dir = Path(cache_dir)
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self._kokoro = None # lazy-init
|
self._kokoro = None # lazy-init
|
||||||
|
|||||||
@@ -6,23 +6,30 @@ initial admin user. Safe to re-run (skips what already exists).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
sys.path.insert(0, BASE_DIR)
|
||||||
|
from src.constants import (
|
||||||
|
DATA_DIR, AUTH_FILE, UPLOAD_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR,
|
||||||
|
TTS_CACHE_DIR, GENERATED_IMAGES_DIR, DEEP_RESEARCH_DIR, CHROMA_DIR,
|
||||||
|
RAG_DIR, MEMORY_VECTORS_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
DIRS = [
|
DIRS = [
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
os.path.join(DATA_DIR, "uploads"),
|
UPLOAD_DIR,
|
||||||
os.path.join(DATA_DIR, "personal_docs"),
|
PERSONAL_DIR,
|
||||||
os.path.join(DATA_DIR, "personal_uploads"),
|
PERSONAL_UPLOADS_DIR,
|
||||||
os.path.join(DATA_DIR, "tts_cache"),
|
TTS_CACHE_DIR,
|
||||||
os.path.join(DATA_DIR, "generated_images"),
|
GENERATED_IMAGES_DIR,
|
||||||
os.path.join(DATA_DIR, "deep_research"),
|
DEEP_RESEARCH_DIR,
|
||||||
os.path.join(DATA_DIR, "chroma"),
|
CHROMA_DIR,
|
||||||
os.path.join(DATA_DIR, "rag"),
|
RAG_DIR,
|
||||||
os.path.join(DATA_DIR, "memory_vectors"),
|
MEMORY_VECTORS_DIR,
|
||||||
os.path.join(BASE_DIR, "logs"),
|
os.path.join(BASE_DIR, "logs"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -72,7 +79,7 @@ def _prompt_admin_credentials():
|
|||||||
|
|
||||||
def create_default_admin():
|
def create_default_admin():
|
||||||
"""Create an initial admin user if none exists."""
|
"""Create an initial admin user if none exists."""
|
||||||
auth_path = os.path.join(DATA_DIR, "auth.json")
|
auth_path = AUTH_FILE
|
||||||
if os.path.exists(auth_path):
|
if os.path.exists(auth_path):
|
||||||
print(" [skip] auth.json already exists")
|
print(" [skip] auth.json already exists")
|
||||||
return "exists"
|
return "exists"
|
||||||
@@ -117,7 +124,16 @@ def create_default_admin():
|
|||||||
print(f" Temporary password: {password}")
|
print(f" Temporary password: {password}")
|
||||||
print(f" ** Change it after first login. Set ODYSSEUS_ADMIN_PASSWORD to choose your own. **")
|
print(f" ** Change it after first login. Set ODYSSEUS_ADMIN_PASSWORD to choose your own. **")
|
||||||
return "created"
|
return "created"
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
|
if "incompatible architecture" in str(e).lower():
|
||||||
|
# bcrypt is present but built for the wrong CPU architecture — the
|
||||||
|
# same Apple Silicon mismatch check_arch() guards against, caught here
|
||||||
|
# for the rarer case of an x86 wheel inside an arm64 venv.
|
||||||
|
print(" [error] bcrypt loaded with the wrong CPU architecture.")
|
||||||
|
print(" Rebuild the venv with an arm64 Python:")
|
||||||
|
print(" rm -rf venv && /opt/homebrew/bin/python3.11 -m venv venv")
|
||||||
|
print(" ./venv/bin/pip install -r requirements.txt")
|
||||||
|
return "skipped"
|
||||||
print(" [warn] bcrypt not installed — skipping admin user creation")
|
print(" [warn] bcrypt not installed — skipping admin user creation")
|
||||||
print(" Run: pip install bcrypt")
|
print(" Run: pip install bcrypt")
|
||||||
return "skipped"
|
return "skipped"
|
||||||
@@ -167,9 +183,52 @@ def check_deps():
|
|||||||
print(" [ok] tmux installed")
|
print(" [ok] tmux installed")
|
||||||
|
|
||||||
|
|
||||||
|
def check_arch():
|
||||||
|
"""Stop early, with guidance, if we're on Apple Silicon but running an
|
||||||
|
Intel (x86_64) Python through Rosetta.
|
||||||
|
|
||||||
|
A venv built with such an interpreter installs and loads compiled packages
|
||||||
|
(bcrypt, pydantic-core, onnxruntime, …) for the wrong CPU architecture, then
|
||||||
|
dies deep inside an import with a cryptic
|
||||||
|
"(mach-o file, but is an incompatible architecture)" error. Catching it here
|
||||||
|
turns that into one clear, actionable message.
|
||||||
|
"""
|
||||||
|
if sys.platform != "darwin" or platform.machine() == "arm64":
|
||||||
|
return # Not macOS, or already an arm64-native interpreter — nothing to do.
|
||||||
|
|
||||||
|
# platform.machine() == "x86_64": either a genuine Intel Mac (fine) or an x86
|
||||||
|
# interpreter running under Rosetta on Apple Silicon (the case we must catch).
|
||||||
|
try:
|
||||||
|
translated = subprocess.run(
|
||||||
|
["sysctl", "-n", "sysctl.proc_translated"],
|
||||||
|
capture_output=True, text=True, timeout=5,
|
||||||
|
).stdout.strip()
|
||||||
|
except Exception:
|
||||||
|
translated = ""
|
||||||
|
if translated != "1":
|
||||||
|
return # Genuine Intel Mac — carry on.
|
||||||
|
|
||||||
|
print("\n [error] This is an Apple Silicon Mac, but setup is running under an")
|
||||||
|
print(" Intel (x86_64) Python through Rosetta. Compiled packages would")
|
||||||
|
print(' load as the wrong architecture and crash with "incompatible')
|
||||||
|
print(' architecture" later on.')
|
||||||
|
print("\n Rebuild the environment with Homebrew's arm64 Python:")
|
||||||
|
print(" brew install python@3.11 # if you don't have it yet")
|
||||||
|
print(" rm -rf venv")
|
||||||
|
print(" /opt/homebrew/bin/python3.11 -m venv venv")
|
||||||
|
print(" ./venv/bin/pip install -r requirements.txt")
|
||||||
|
print(" ./venv/bin/python setup.py")
|
||||||
|
print("\n Tip: ./start-macos.sh does all of this with the right Python.\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("\n=== Odysseus Setup ===\n")
|
print("\n=== Odysseus Setup ===\n")
|
||||||
|
|
||||||
|
# Fail fast with a clear message if the CPU architecture is wrong (Apple
|
||||||
|
# Silicon under an x86/Rosetta Python) before importing anything native.
|
||||||
|
check_arch()
|
||||||
|
|
||||||
print("1. Creating directories...")
|
print("1. Creating directories...")
|
||||||
create_dirs()
|
create_dirs()
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ _CALENDAR_ACTION = (
|
|||||||
r"delete|deleting|remove|removing|cancel|cancelling|canceling)"
|
r"delete|deleting|remove|removing|cancel|cancelling|canceling)"
|
||||||
)
|
)
|
||||||
_CALENDAR_THING = r"(?:calendar|calendar\s+(?:entry|item)|event|meeting|appointment|entry|call)"
|
_CALENDAR_THING = r"(?:calendar|calendar\s+(?:entry|item)|event|meeting|appointment|entry|call)"
|
||||||
|
_CALENDAR_READ_THING = r"(?:calendar|schedule|events?|meetings?|appointments?|classes?)"
|
||||||
_EXPLANATORY_PREFIX = re.compile(
|
_EXPLANATORY_PREFIX = re.compile(
|
||||||
r"^\s*(?:how\s+(?:do|can)\s+i|can\s+you\s+explain|what\s+about|tell\s+me\s+how|show\s+me\s+how)\b",
|
r"^\s*(?:how\s+(?:do|can)\s+i|can\s+you\s+explain|what\s+about|tell\s+me\s+how|show\s+me\s+how)\b",
|
||||||
re.I,
|
re.I,
|
||||||
@@ -59,6 +60,14 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple(
|
|||||||
("calendar", "calendar target action request", rf"\b{_CALENDAR_ACTION}\b.{{0,120}}\b(?:to|on|in|into|for)\s+(?:my\s+|the\s+|this\s+)?calendar\b"),
|
("calendar", "calendar target action request", rf"\b{_CALENDAR_ACTION}\b.{{0,120}}\b(?:to|on|in|into|for)\s+(?:my\s+|the\s+|this\s+)?calendar\b"),
|
||||||
("calendar", "put item on calendar request", r"\bput\s+.+\bon\s+(?:my\s+)?calendar\b"),
|
("calendar", "put item on calendar request", r"\bput\s+.+\bon\s+(?:my\s+)?calendar\b"),
|
||||||
|
|
||||||
|
# Calendar/event lookup. A question such as "Do I have Taekwondo
|
||||||
|
# classes this week?" needs the calendar tool; plain chat cannot know.
|
||||||
|
("calendar", "calendar lookup request", rf"\b(?:list|show|check|find)\b.{{0,120}}\b(?:my\s+|the\s+)?(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar lookup question", rf"\b(?:what|which)\b.{{0,120}}\b(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar availability question", rf"\bdo\s+i\s+have\b.{{0,120}}\b(?:upcoming|next|today|tomorrow|this\s+week)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar agenda question", r"\bwhat(?:'s| is)\s+on\s+(?:my\s+)?calendar\b"),
|
||||||
|
("calendar", "next calendar item question", r"\bwhen\s+(?:is|are)\s+(?:my\s+)?next\s+(?:event|meeting|appointment|class)\b"),
|
||||||
|
|
||||||
# Notes, todos, checklists, and reminders.
|
# Notes, todos, checklists, and reminders.
|
||||||
("notes", "reminder request", r"\bremind\s+me\b"),
|
("notes", "reminder request", r"\bremind\s+me\b"),
|
||||||
("notes", "assistant note/todo action request", rf"{_ACTION_QUESTION}(?:add|create|make|take|jot|write\s+down|set)\b.{{0,120}}\b(?:note|todo|task|checklist|reminder)\b"),
|
("notes", "assistant note/todo action request", rf"{_ACTION_QUESTION}(?:add|create|make|take|jot|write\s+down|set)\b.{{0,120}}\b(?:note|todo|task|checklist|reminder)\b"),
|
||||||
|
|||||||
+487
-191
File diff suppressed because it is too large
Load Diff
+6
-32
@@ -14,16 +14,17 @@ Sub-modules:
|
|||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants (kept here — sub-modules import from here)
|
# Constants (re-exported for backward compatibility — single source of truth
|
||||||
|
# is src.constants; always prefer importing from there for new code)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
MAX_AGENT_ROUNDS = 20
|
MAX_AGENT_ROUNDS = 50
|
||||||
SHELL_TIMEOUT = 60
|
SHELL_TIMEOUT = 60
|
||||||
PYTHON_TIMEOUT = 30
|
PYTHON_TIMEOUT = 30
|
||||||
MAX_OUTPUT_CHARS = 10_000
|
|
||||||
MAX_READ_CHARS = 20_000
|
|
||||||
|
|
||||||
# Tool types that trigger execution
|
# Tool types that trigger execution
|
||||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||||
@@ -34,7 +35,7 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
|||||||
"send_to_session",
|
"send_to_session",
|
||||||
"pipeline",
|
"pipeline",
|
||||||
"manage_session", "manage_memory", "list_models",
|
"manage_session", "manage_memory", "list_models",
|
||||||
"ui_control", "generate_image",
|
"ui_control", "generate_image", "ask_user", "update_plan",
|
||||||
"manage_tasks", "api_call", "ask_teacher", "manage_skills",
|
"manage_tasks", "api_call", "ask_teacher", "manage_skills",
|
||||||
"suggest_document",
|
"suggest_document",
|
||||||
"manage_endpoints", "manage_mcp", "manage_webhooks",
|
"manage_endpoints", "manage_mcp", "manage_webhooks",
|
||||||
@@ -63,33 +64,6 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
|||||||
|
|
||||||
ToolBlock = namedtuple("ToolBlock", ["tool_type", "content"])
|
ToolBlock = namedtuple("ToolBlock", ["tool_type", "content"])
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MCP Manager (kept here — used by execution and agent_loop)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
_mcp_manager = None
|
|
||||||
|
|
||||||
def set_mcp_manager(manager):
|
|
||||||
"""Set the global MCP manager instance."""
|
|
||||||
global _mcp_manager
|
|
||||||
_mcp_manager = manager
|
|
||||||
|
|
||||||
def get_mcp_manager():
|
|
||||||
"""Get the global MCP manager instance."""
|
|
||||||
return _mcp_manager
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers (kept here — used by sub-modules)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
|
||||||
# Callers treat the result as text, so always return a string: coerce a
|
|
||||||
# non-string (None -> "", otherwise str(...)) instead of returning it raw,
|
|
||||||
# which would just move the crash downstream.
|
|
||||||
if not isinstance(text, str):
|
|
||||||
text = "" if text is None else str(text)
|
|
||||||
if len(text) > limit:
|
|
||||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
|
||||||
return text
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Re-exports from sub-modules
|
# Re-exports from sub-modules
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+43
-27
@@ -14,6 +14,8 @@ import uuid
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AI_CHAT_TIMEOUT = 120 # seconds for a single LLM call
|
AI_CHAT_TIMEOUT = 120 # seconds for a single LLM call
|
||||||
@@ -55,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
|||||||
# Model resolution
|
# Model resolution
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||||
@@ -96,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
# Anthropic: match against hardcoded model list
|
# Anthropic: match against hardcoded model list
|
||||||
@@ -112,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
else:
|
else:
|
||||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = []
|
model_ids = []
|
||||||
|
|
||||||
@@ -1119,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
|||||||
total_models = 0
|
total_models = 0
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
model_ids = []
|
model_ids = []
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
model_ids = list(ANTHROPIC_MODELS)
|
model_ids = list(ANTHROPIC_MODELS)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = ["(endpoint offline)"]
|
model_ids = ["(endpoint offline)"]
|
||||||
|
|
||||||
@@ -1715,7 +1731,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
|
|
||||||
# GPT image models always return b64_json; DALL-E may return url
|
# GPT image models always return b64_json; DALL-E may return url
|
||||||
if img.get("b64_json"):
|
if img.get("b64_json"):
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
@@ -1728,7 +1744,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
try:
|
try:
|
||||||
dl_resp = httpx.get(img["url"], timeout=60)
|
dl_resp = httpx.get(img["url"], timeout=60)
|
||||||
if dl_resp.status_code == 200:
|
if dl_resp.status_code == 200:
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
|
|||||||
+22
-1
@@ -10,7 +10,7 @@ def get_current_user(request: Request) -> Optional[str]:
|
|||||||
return getattr(request.state, 'current_user', None)
|
return getattr(request.state, 'current_user', None)
|
||||||
|
|
||||||
|
|
||||||
def effective_user(request: Request):
|
def effective_user(request: Request) -> Optional[str]:
|
||||||
"""The real human behind the request, for ownership/attribution.
|
"""The real human behind the request, for ownership/attribution.
|
||||||
|
|
||||||
Cookie sessions resolve to the logged-in username. Bearer ``ody_`` callers
|
Cookie sessions resolve to the logged-in username. Bearer ``ody_`` callers
|
||||||
@@ -34,6 +34,24 @@ def effective_user(request: Request):
|
|||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_api_token_request(request: Request) -> bool:
|
||||||
|
"""Return True when middleware authenticated a bearer API token."""
|
||||||
|
return bool(getattr(request.state, "api_token", False))
|
||||||
|
|
||||||
|
|
||||||
|
def require_authenticated_request(request: Request) -> str:
|
||||||
|
"""Allow either a browser session or a valid bearer API token.
|
||||||
|
|
||||||
|
This is intentionally narrower than :func:`require_user`: use it only for
|
||||||
|
routes that need authentication but do not read or mutate owner-scoped
|
||||||
|
user data. Owner-scoped routes should use ``require_user`` for browser
|
||||||
|
sessions or their own API-token scope/owner gate.
|
||||||
|
"""
|
||||||
|
if _is_api_token_request(request):
|
||||||
|
return effective_user(request) or ""
|
||||||
|
return require_user(request)
|
||||||
|
|
||||||
|
|
||||||
def _auth_disabled() -> bool:
|
def _auth_disabled() -> bool:
|
||||||
"""True when the operator has explicitly turned off auth via .env.
|
"""True when the operator has explicitly turned off auth via .env.
|
||||||
Mirrors the AUTH_ENABLED parse in app.py / core/middleware.py so the
|
Mirrors the AUTH_ENABLED parse in app.py / core/middleware.py so the
|
||||||
@@ -60,6 +78,9 @@ def require_user(request: Request) -> str:
|
|||||||
Use this on routes that touch user data so middleware misconfig can't
|
Use this on routes that touch user data so middleware misconfig can't
|
||||||
open them up.
|
open them up.
|
||||||
"""
|
"""
|
||||||
|
if _is_api_token_request(request):
|
||||||
|
raise HTTPException(403, "API tokens must use a scope-aware API route")
|
||||||
|
|
||||||
u = get_current_user(request)
|
u = get_current_user(request)
|
||||||
if u:
|
if u:
|
||||||
return u
|
return u
|
||||||
|
|||||||
+6
-4
@@ -33,13 +33,15 @@ from core.atomic_io import atomic_write_json
|
|||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
pid_alive,
|
pid_alive,
|
||||||
)
|
)
|
||||||
|
|
||||||
_DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
from src.constants import BG_JOBS_DIR, BG_JOBS_FILE
|
||||||
_JOBS_DIR = _DATA_DIR / "bg_jobs"
|
|
||||||
_STORE = _DATA_DIR / "bg_jobs.json"
|
_JOBS_DIR = Path(BG_JOBS_DIR)
|
||||||
|
_STORE = Path(BG_JOBS_FILE)
|
||||||
|
|
||||||
# A job that runs longer than this is presumed stuck and reaped (the agent
|
# A job that runs longer than this is presumed stuck and reaped (the agent
|
||||||
# still gets a "timed out" follow-up so nothing hangs forever).
|
# still gets a "timed out" follow-up so nothing hangs forever).
|
||||||
@@ -106,7 +108,7 @@ def launch(command: str, session_id: str, cwd: Optional[str] = None,
|
|||||||
# handles drive paths and spaces correctly.
|
# handles drive paths and spaces correctly.
|
||||||
cmd_path = _JOBS_DIR / f"{job_id}.cmd.sh"
|
cmd_path = _JOBS_DIR / f"{job_id}.cmd.sh"
|
||||||
cmd_path.write_text(command + "\n", encoding="utf-8")
|
cmd_path.write_text(command + "\n", encoding="utf-8")
|
||||||
lp, xp, cp = (shlex.quote(p.as_posix()) for p in (log_path, exit_path, cmd_path))
|
lp, xp, cp = (shlex.quote(git_bash_path(p)) for p in (log_path, exit_path, cmd_path))
|
||||||
script_path = _JOBS_DIR / f"{job_id}.sh"
|
script_path = _JOBS_DIR / f"{job_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"bash {cp} > {lp} 2>&1\n"
|
f"bash {cp} > {lp} 2>&1\n"
|
||||||
|
|||||||
+224
-21
@@ -12,6 +12,8 @@ from typing import Tuple
|
|||||||
|
|
||||||
from src.auth_helpers import owner_filter
|
from src.auth_helpers import owner_filter
|
||||||
from core.platform_compat import IS_WINDOWS, find_bash
|
from core.platform_compat import IS_WINDOWS, find_bash
|
||||||
|
from core.constants import internal_api_base
|
||||||
|
from src.constants import DATA_DIR, DEEP_RESEARCH_DIR, TIDY_CALENDAR_STATE_FILE, EMAIL_URGENCY_CACHE_DIR, COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -166,7 +168,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
drop_items = decision.get("drop") if isinstance(decision, dict) else None
|
drop_items = decision.get("drop") if isinstance(decision, dict) else None
|
||||||
if isinstance(keep_items, list) and isinstance(drop_items, list):
|
if isinstance(keep_items, list) and isinstance(drop_items, list):
|
||||||
by_id = {m.get("id"): m for m in group_memories if m.get("id")}
|
by_id = {m.get("id"): m for m in group_memories if m.get("id")}
|
||||||
keep_ids = set()
|
|
||||||
cleaned_by_id = {}
|
cleaned_by_id = {}
|
||||||
for item in keep_items:
|
for item in keep_items:
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
@@ -177,7 +178,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
text = (item.get("text") or "").strip()
|
text = (item.get("text") or "").strip()
|
||||||
if not text:
|
if not text:
|
||||||
continue
|
continue
|
||||||
keep_ids.add(mid)
|
|
||||||
cleaned = {
|
cleaned = {
|
||||||
"category": (item.get("category") or by_id[mid].get("category") or "fact").strip(),
|
"category": (item.get("category") or by_id[mid].get("category") or "fact").strip(),
|
||||||
}
|
}
|
||||||
@@ -186,11 +186,20 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
cleaned["text"] = text
|
cleaned["text"] = text
|
||||||
cleaned_by_id[mid] = cleaned
|
cleaned_by_id[mid] = cleaned
|
||||||
|
|
||||||
# If the model only saw a truncated memory, do not let
|
# Delete only memories the model EXPLICITLY dropped, never
|
||||||
# that partial view delete or rewrite the full memory.
|
# ones it merely omitted from `keep`. Treating the
|
||||||
keep_ids.update(mid for mid in truncated_ids if mid in by_id)
|
# complement of `keep` as deletions meant a model that
|
||||||
|
# forgot to re-list an id (common) silently destroyed that
|
||||||
|
# memory. Honor the explicit `drop` set instead.
|
||||||
|
drop_ids = {
|
||||||
|
d.get("id")
|
||||||
|
for d in drop_items
|
||||||
|
if isinstance(d, dict) and d.get("id") in by_id
|
||||||
|
}
|
||||||
|
# Never delete a memory the model only saw truncated.
|
||||||
|
drop_ids -= truncated_ids
|
||||||
|
|
||||||
if keep_ids:
|
if drop_ids or cleaned_by_id:
|
||||||
changed_text = 0
|
changed_text = 0
|
||||||
group_ref_ids = {id(m) for m in group_memories}
|
group_ref_ids = {id(m) for m in group_memories}
|
||||||
kept_all = []
|
kept_all = []
|
||||||
@@ -199,7 +208,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
kept_all.append(mem)
|
kept_all.append(mem)
|
||||||
continue
|
continue
|
||||||
mid = mem.get("id")
|
mid = mem.get("id")
|
||||||
if mid not in keep_ids:
|
if mid in drop_ids:
|
||||||
continue
|
continue
|
||||||
cleaned = cleaned_by_id.get(mid) or {}
|
cleaned = cleaned_by_id.get(mid) or {}
|
||||||
if mid in truncated_ids:
|
if mid in truncated_ids:
|
||||||
@@ -211,7 +220,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
mem["category"] = cleaned["category"]
|
mem["category"] = cleaned["category"]
|
||||||
kept_all.append(mem)
|
kept_all.append(mem)
|
||||||
|
|
||||||
removed = len(group_memories) - len(keep_ids)
|
removed = sum(1 for m in group_memories if m.get("id") in drop_ids)
|
||||||
total_scanned += len(group_memories)
|
total_scanned += len(group_memories)
|
||||||
if removed or changed_text:
|
if removed or changed_text:
|
||||||
all_memories = kept_all
|
all_memories = kept_all
|
||||||
@@ -348,7 +357,7 @@ async def action_tidy_research(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json as _json
|
import json as _json
|
||||||
research_dir = Path("data/deep_research")
|
research_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
if not research_dir.exists():
|
if not research_dir.exists():
|
||||||
raise TaskNoop("no research directory")
|
raise TaskNoop("no research directory")
|
||||||
files = list(research_dir.glob("*.json"))
|
files = list(research_dir.glob("*.json"))
|
||||||
@@ -386,7 +395,7 @@ async def action_tidy_calendar(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
from core.database import SessionLocal, CalendarEvent
|
from core.database import SessionLocal, CalendarEvent
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
STATE_FILE = Path("data/tidy_calendar_state.json")
|
STATE_FILE = Path(TIDY_CALENDAR_STATE_FILE)
|
||||||
last_watermark = None
|
last_watermark = None
|
||||||
try:
|
try:
|
||||||
if STATE_FILE.exists():
|
if STATE_FILE.exists():
|
||||||
@@ -593,9 +602,9 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
if not events:
|
if not events:
|
||||||
return "No upcoming events to classify", True
|
return "No upcoming events to classify", True
|
||||||
|
|
||||||
llm_url, llm_model, llm_headers = resolve_endpoint("utility")
|
llm_url, llm_model, llm_headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not llm_url:
|
if not llm_url:
|
||||||
llm_url, llm_model, llm_headers = resolve_endpoint("default")
|
llm_url, llm_model, llm_headers = resolve_endpoint("default", owner=owner)
|
||||||
llm_available = bool(llm_url and llm_model)
|
llm_available = bool(llm_url and llm_model)
|
||||||
|
|
||||||
# Pull user memories so the LLM has personal context (relationships,
|
# Pull user memories so the LLM has personal context (relationships,
|
||||||
@@ -867,9 +876,9 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
if not eligible:
|
if not eligible:
|
||||||
return "All sender sigs already cached (or no eligible senders)", True
|
return "All sender sigs already cached (or no eligible senders)", True
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return "No LLM endpoint available", False
|
return "No LLM endpoint available", False
|
||||||
|
|
||||||
@@ -1303,12 +1312,12 @@ async def action_ping_notes(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
# users' entries (review C4). Legacy path kept as fallback so a
|
# users' entries (review C4). Legacy path kept as fallback so a
|
||||||
# single-user install (empty owner) doesn't lose its history.
|
# single-user install (empty owner) doesn't lose its history.
|
||||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
STATE = _P(f"data/note_pings_{_owner_slug}.json")
|
STATE = _P(DATA_DIR) / f"note_pings_{_owner_slug}.json"
|
||||||
STATE.parent.mkdir(parents=True, exist_ok=True)
|
STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# One-time migration: if legacy global file exists and per-owner file
|
# One-time migration: if legacy global file exists and per-owner file
|
||||||
# doesn't, seed from global (entries for OTHER owners still get pruned
|
# doesn't, seed from global (entries for OTHER owners still get pruned
|
||||||
# on their first run — acceptable, prevents silent loss).
|
# on their first run — acceptable, prevents silent loss).
|
||||||
_legacy = _P("data/note_pings.json")
|
_legacy = _P(DATA_DIR) / "note_pings.json"
|
||||||
if _legacy.exists() and not STATE.exists():
|
if _legacy.exists() and not STATE.exists():
|
||||||
try:
|
try:
|
||||||
STATE.write_text(_legacy.read_text(encoding="utf-8"), encoding="utf-8")
|
STATE.write_text(_legacy.read_text(encoding="utf-8"), encoding="utf-8")
|
||||||
@@ -1465,8 +1474,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
# notified_uids / urgency counts. Empty owner falls back to a generic
|
# notified_uids / urgency counts. Empty owner falls back to a generic
|
||||||
# filename for single-user installs (matches prior behaviour).
|
# filename for single-user installs (matches prior behaviour).
|
||||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
STATE_PATH = _P(f"data/email_urgency_state_{_owner_slug}.json")
|
STATE_PATH = _P(DATA_DIR) / f"email_urgency_state_{_owner_slug}.json"
|
||||||
CACHE_DIR = _P("data/email_urgency_cache")
|
CACHE_DIR = _P(EMAIL_URGENCY_CACHE_DIR)
|
||||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
AGE_CUTOFF = _dt.utcnow() - _td(days=7)
|
AGE_CUTOFF = _dt.utcnow() - _td(days=7)
|
||||||
@@ -1480,12 +1489,12 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
|
|
||||||
# ── 1. Resolve LLM candidates (utility primary + utility fallbacks; fall
|
# ── 1. Resolve LLM candidates (utility primary + utility fallbacks; fall
|
||||||
# through to default chat as a last resort).
|
# through to default chat as a last resort).
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return "No LLM endpoint available", False
|
return "No LLM endpoint available", False
|
||||||
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates()
|
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates(owner=owner)
|
||||||
|
|
||||||
# ── 2. Enumerate enabled accounts. Match this task's owner AND fall
|
# ── 2. Enumerate enabled accounts. Match this task's owner AND fall
|
||||||
# back to the legacy "unowned account whose imap_user / from_address
|
# back to the legacy "unowned account whose imap_user / from_address
|
||||||
@@ -1902,6 +1911,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
delivered = bool(dispatch_result.get("email_sent"))
|
delivered = bool(dispatch_result.get("email_sent"))
|
||||||
elif channel == "ntfy":
|
elif channel == "ntfy":
|
||||||
delivered = bool(dispatch_result.get("ntfy_sent"))
|
delivered = bool(dispatch_result.get("ntfy_sent"))
|
||||||
|
elif channel == "webhook":
|
||||||
|
delivered = bool(dispatch_result.get("webhook_sent"))
|
||||||
if delivered:
|
if delivered:
|
||||||
newly_notified.update(new_urgent)
|
newly_notified.update(new_urgent)
|
||||||
else:
|
else:
|
||||||
@@ -2001,6 +2012,197 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
return str(e), False
|
return str(e), False
|
||||||
|
|
||||||
|
|
||||||
|
async def action_cookbook_serve(
|
||||||
|
owner: str,
|
||||||
|
task_name: str = "",
|
||||||
|
progress_cb=None,
|
||||||
|
command: str = "",
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
"""Launch a Cookbook model serve as a scheduled task.
|
||||||
|
|
||||||
|
`command` is the JSON config string the task carries in `prompt`,
|
||||||
|
of shape: {"preset": "name"} OR {"repo_id": "...", "cmd": "...", "host": "..."}.
|
||||||
|
Optional `end_after_min: N` schedules a hard-stop N minutes after launch
|
||||||
|
(handled by cookbook_serve_lifecycle_loop in src/cookbook_serve_lifecycle.py).
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time as _time
|
||||||
|
import httpx
|
||||||
|
from pathlib import Path
|
||||||
|
from core.middleware import INTERNAL_TOOL_HEADER, INTERNAL_TOOL_TOKEN
|
||||||
|
from core.atomic_io import atomic_write_json
|
||||||
|
|
||||||
|
headers = {INTERNAL_TOOL_HEADER: INTERNAL_TOOL_TOKEN}
|
||||||
|
try:
|
||||||
|
cfg = json.loads(command or "{}")
|
||||||
|
except Exception:
|
||||||
|
return f"Invalid JSON config: {command!r}", False
|
||||||
|
if not isinstance(cfg, dict):
|
||||||
|
return "Config must be a JSON object", False
|
||||||
|
|
||||||
|
# Resolve the preset (if named) OR fall through with explicit fields.
|
||||||
|
preset_name = (cfg.get("preset") or "").strip()
|
||||||
|
repo_id = (cfg.get("repo_id") or "").strip()
|
||||||
|
cmd = (cfg.get("cmd") or "").strip()
|
||||||
|
host = (cfg.get("host") or cfg.get("remote_host") or "").strip()
|
||||||
|
try:
|
||||||
|
end_after_min = int(cfg.get("end_after_min") or 0)
|
||||||
|
except Exception:
|
||||||
|
end_after_min = 0
|
||||||
|
|
||||||
|
state_path = Path(COOKBOOK_STATE_FILE)
|
||||||
|
try:
|
||||||
|
state = json.loads(state_path.read_text(encoding="utf-8")) if state_path.exists() else {}
|
||||||
|
except Exception:
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
# Preset lookup. Try three matching strategies in order so the
|
||||||
|
# schedule still works even when the user's preset is named
|
||||||
|
# differently from the model's short name:
|
||||||
|
#
|
||||||
|
# 1. Exact preset.name == preset_name (case-insensitive)
|
||||||
|
# 2. preset.model / preset.modelId == repo_id (caller knows the repo)
|
||||||
|
# 3. preset.model's short name (after final /) == preset_name
|
||||||
|
#
|
||||||
|
# Without #2 and #3, scheduling "Qwen3.5-397B-A17B-AWQ" failed when
|
||||||
|
# the saved preset was named "vllm-qwen-397b" or had the model field
|
||||||
|
# populated with the full HF repo path. Either should resolve.
|
||||||
|
def _short(name: str) -> str:
|
||||||
|
return (name or "").rsplit("/", 1)[-1].lower()
|
||||||
|
|
||||||
|
if not cmd or not repo_id:
|
||||||
|
presets = state.get("presets") or []
|
||||||
|
chosen = None
|
||||||
|
# Strategy 1: exact name match.
|
||||||
|
if preset_name:
|
||||||
|
chosen = next(
|
||||||
|
(p for p in presets if isinstance(p, dict)
|
||||||
|
and (p.get("name") or "").lower() == preset_name.lower()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
# Strategy 2: repo_id matches the preset's model field.
|
||||||
|
if chosen is None and repo_id:
|
||||||
|
chosen = next(
|
||||||
|
(p for p in presets if isinstance(p, dict)
|
||||||
|
and (p.get("model") or p.get("modelId") or "").lower() == repo_id.lower()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
# Strategy 3: model's short name matches the preset_name.
|
||||||
|
if chosen is None and preset_name:
|
||||||
|
chosen = next(
|
||||||
|
(p for p in presets if isinstance(p, dict)
|
||||||
|
and _short(p.get("model") or p.get("modelId") or "") == preset_name.lower()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if chosen is not None:
|
||||||
|
repo_id = repo_id or chosen.get("model") or chosen.get("modelId") or ""
|
||||||
|
cmd = cmd or (chosen.get("cmd") or "").strip()
|
||||||
|
host = host or chosen.get("host") or chosen.get("remoteHost") or ""
|
||||||
|
if not repo_id or not cmd or cmd.startswith("(adopted"):
|
||||||
|
# Surface what we tried so the user can name their preset to match.
|
||||||
|
preset_names = [(p.get("name") or "") for p in (state.get("presets") or []) if isinstance(p, dict)]
|
||||||
|
hint = f" Saved presets: {preset_names!r}" if preset_names else ""
|
||||||
|
return (f"No launchable config for {preset_name!r} (repo_id={repo_id!r}). "
|
||||||
|
f"Check Cookbook → Presets has a real cmd, not 'adopted'.{hint}", False)
|
||||||
|
|
||||||
|
# Resolve env_prefix etc. from the host's saved cookbook server entry,
|
||||||
|
# matching the chat agent's serve_model path.
|
||||||
|
body = {"repo_id": repo_id, "cmd": cmd}
|
||||||
|
if host:
|
||||||
|
body["remote_host"] = host
|
||||||
|
env = (state.get("env") or {})
|
||||||
|
srv = next(
|
||||||
|
(s for s in (env.get("servers") or [])
|
||||||
|
if isinstance(s, dict) and (s.get("host") == host or s.get("name") == host)),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
if srv.get("env") == "venv" and srv.get("envPath"):
|
||||||
|
body["env_prefix"] = f"source {srv['envPath']}/bin/activate"
|
||||||
|
elif srv.get("env") == "conda" and srv.get("envPath"):
|
||||||
|
body["env_prefix"] = f"conda activate {srv['envPath']}"
|
||||||
|
if srv.get("hfToken"): body["hf_token"] = srv["hfToken"]
|
||||||
|
if srv.get("port"): body["ssh_port"] = str(srv["port"])
|
||||||
|
if srv.get("platform"): body["platform"] = srv["platform"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
r = await client.post(f"{internal_api_base()}/api/model/serve",
|
||||||
|
json=body, headers=headers)
|
||||||
|
data = r.json() if r.content else {}
|
||||||
|
except Exception as e:
|
||||||
|
return f"Launch HTTP failed: {e}", False
|
||||||
|
if not data.get("ok"):
|
||||||
|
return f"Launch rejected: {data.get('error') or data.get('detail') or 'unknown'}", False
|
||||||
|
|
||||||
|
sid = data.get("session_id") or ""
|
||||||
|
# Register the new task in cookbook_state.json + stamp it with our
|
||||||
|
# scheduler-owner markers. /api/model/serve spawns the tmux session
|
||||||
|
# but leaves the state-write to the UI — when a scheduled action
|
||||||
|
# launches a serve from server-side, NOBODY writes the task into
|
||||||
|
# state, so the Cookbook tab never shows it. We do the write here.
|
||||||
|
if sid:
|
||||||
|
try:
|
||||||
|
# Re-read fresh (the route may have updated state already).
|
||||||
|
try:
|
||||||
|
fresh = json.loads(state_path.read_text(encoding="utf-8"))
|
||||||
|
except Exception:
|
||||||
|
fresh = {}
|
||||||
|
if not isinstance(fresh, dict):
|
||||||
|
fresh = {}
|
||||||
|
tasks = fresh.get("tasks") if isinstance(fresh.get("tasks"), list) else []
|
||||||
|
existing = next(
|
||||||
|
(t for t in tasks if isinstance(t, dict) and t.get("sessionId") == sid),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if existing is None:
|
||||||
|
display_name = repo_id.split("/")[-1] if "/" in repo_id else repo_id
|
||||||
|
placeholder = (
|
||||||
|
f"Launched by scheduled task {task_name!r} — waiting for tmux output…\n"
|
||||||
|
f" session: {sid}\n"
|
||||||
|
f" target: {host or 'local'}\n"
|
||||||
|
f" cmd: {cmd[:200]}{'…' if len(cmd) > 200 else ''}"
|
||||||
|
)
|
||||||
|
existing = {
|
||||||
|
"id": sid,
|
||||||
|
"sessionId": sid,
|
||||||
|
"name": display_name,
|
||||||
|
"modelId": repo_id,
|
||||||
|
"type": "serve",
|
||||||
|
"status": "running",
|
||||||
|
"output": placeholder,
|
||||||
|
"ts": int(_time.time() * 1000),
|
||||||
|
"payload": {"repo_id": repo_id, "remote_host": host or "", "_cmd": cmd},
|
||||||
|
"remoteHost": host or "",
|
||||||
|
"sshPort": "",
|
||||||
|
"platform": "linux",
|
||||||
|
"_serveReady": False,
|
||||||
|
"_endpointAdded": False,
|
||||||
|
}
|
||||||
|
tasks.append(existing)
|
||||||
|
# Stamp ownership + end-at on the task entry.
|
||||||
|
existing["_scheduledByTask"] = task_name or ""
|
||||||
|
existing["_scheduledByOwner"] = owner or ""
|
||||||
|
if end_after_min > 0:
|
||||||
|
existing["_scheduledStopAtMs"] = int(_time.time() * 1000) + end_after_min * 60 * 1000
|
||||||
|
fresh["tasks"] = tasks
|
||||||
|
atomic_write_json(state_path, fresh)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"cookbook_serve: state register/stamp failed: {e}")
|
||||||
|
# Don't try to render absolute clock time in the message — the
|
||||||
|
# server runs in UTC (Docker default), the user reads it as local,
|
||||||
|
# and the offset depends on the user's TZ which the action doesn't
|
||||||
|
# have a reliable handle on. The Tasks UI already shows the RUN
|
||||||
|
# timestamp in the user's local time right above this message, so
|
||||||
|
# "stops 8 min after that" gives the user everything they need.
|
||||||
|
if end_after_min:
|
||||||
|
return (
|
||||||
|
f"Launched {repo_id} (session {sid}); stops {end_after_min} min after this ran",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
return f"Launched {repo_id} (session {sid})", True
|
||||||
|
|
||||||
|
|
||||||
BUILTIN_ACTIONS = {
|
BUILTIN_ACTIONS = {
|
||||||
"tidy_sessions": action_tidy_sessions,
|
"tidy_sessions": action_tidy_sessions,
|
||||||
"tidy_documents": action_tidy_documents,
|
"tidy_documents": action_tidy_documents,
|
||||||
@@ -2020,6 +2222,7 @@ BUILTIN_ACTIONS = {
|
|||||||
"test_skills": action_test_skills,
|
"test_skills": action_test_skills,
|
||||||
"audit_skills": action_audit_skills,
|
"audit_skills": action_audit_skills,
|
||||||
"check_email_urgency": action_check_email_urgency,
|
"check_email_urgency": action_check_email_urgency,
|
||||||
|
"cookbook_serve": action_cookbook_serve,
|
||||||
# ping_notes removed from the registry — runs only inside `_note_pings_loop`.
|
# ping_notes removed from the registry — runs only inside `_note_pings_loop`.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+241
-49
@@ -27,6 +27,7 @@ import hashlib
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
@@ -50,15 +51,55 @@ def _private_caldav_allowed() -> bool:
|
|||||||
return os.environ.get("ODYSSEUS_ALLOW_PRIVATE_CALDAV", "0").lower() in {"1", "true", "yes"}
|
return os.environ.get("ODYSSEUS_ALLOW_PRIVATE_CALDAV", "0").lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_caldav_address(addr: ipaddress._BaseAddress) -> None:
|
||||||
|
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
|
||||||
|
addr = addr.ipv4_mapped
|
||||||
|
if (
|
||||||
|
addr.is_loopback
|
||||||
|
or addr.is_link_local
|
||||||
|
or addr.is_multicast
|
||||||
|
or addr.is_unspecified
|
||||||
|
or addr.is_reserved
|
||||||
|
):
|
||||||
|
raise ValueError("CalDAV URL host is not allowed")
|
||||||
|
if addr.is_private and not _private_caldav_allowed():
|
||||||
|
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
||||||
|
|
||||||
|
|
||||||
def _validate_caldav_ip(host: str) -> None:
|
def _validate_caldav_ip(host: str) -> None:
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(host.strip("[]"))
|
ip = ipaddress.ip_address(host.strip("[]"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return
|
return
|
||||||
if ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_unspecified:
|
_validate_caldav_address(ip)
|
||||||
raise ValueError("CalDAV URL host is not allowed")
|
|
||||||
if ip.is_private and not _private_caldav_allowed():
|
|
||||||
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
def _resolve_caldav_host_ips(host: str) -> list[ipaddress._BaseAddress]:
|
||||||
|
addrs: list[ipaddress._BaseAddress] = []
|
||||||
|
for family, _, _, _, sockaddr in socket.getaddrinfo(host, None):
|
||||||
|
if family not in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
addrs.append(ipaddress.ip_address(sockaddr[0].split("%", 1)[0]))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return addrs
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_caldav_hostname(host: str) -> None:
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(host.strip("[]"))
|
||||||
|
return
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
addrs = _resolve_caldav_host_ips(host)
|
||||||
|
except OSError:
|
||||||
|
raise ValueError("CalDAV URL host does not resolve")
|
||||||
|
if not addrs:
|
||||||
|
raise ValueError("CalDAV URL host does not resolve")
|
||||||
|
for addr in addrs:
|
||||||
|
_validate_caldav_address(addr)
|
||||||
|
|
||||||
|
|
||||||
def validate_caldav_url(raw_url: str) -> str:
|
def validate_caldav_url(raw_url: str) -> str:
|
||||||
@@ -83,13 +124,18 @@ def validate_caldav_url(raw_url: str) -> str:
|
|||||||
if host in _BLOCKED_HOSTS or host.endswith(".localhost"):
|
if host in _BLOCKED_HOSTS or host.endswith(".localhost"):
|
||||||
raise ValueError("CalDAV URL host is not allowed")
|
raise ValueError("CalDAV URL host is not allowed")
|
||||||
_validate_caldav_ip(host)
|
_validate_caldav_ip(host)
|
||||||
|
_validate_caldav_hostname(host)
|
||||||
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
def _stable_cal_id(remote_url: str) -> str:
|
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||||
"""Deterministic local id for a remote CalDAV calendar — same URL
|
"""Deterministic local id for a remote CalDAV calendar, scoped to owner
|
||||||
always maps to the same local row across restarts and re-syncs."""
|
and account so two users — or one user with two accounts — pointing at
|
||||||
h = hashlib.sha256(remote_url.encode("utf-8")).hexdigest()[:24]
|
the same server URL get distinct local rows (avoids PK collision, #2765).
|
||||||
|
The owner and account_id default to "" for the legacy/URL-only path so
|
||||||
|
existing callers without those arguments keep working."""
|
||||||
|
key = f"{owner}\n{account_id}\n{remote_url}"
|
||||||
|
h = hashlib.sha256(key.encode("utf-8")).hexdigest()[:24]
|
||||||
return f"caldav-{h}"
|
return f"caldav-{h}"
|
||||||
|
|
||||||
|
|
||||||
@@ -124,18 +170,103 @@ def _find_existing_event(db, pending, uid_val, calendar_id):
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
|
|
||||||
def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
def _google_caldav_events_url(url: str) -> str | None:
|
||||||
|
"""Map a Google CalDAV *principal* URL to its event-collection URL.
|
||||||
|
|
||||||
|
Google serves the principal at ``…/user`` but events live under ``…/events``
|
||||||
|
— the ``/user`` resource holds no VEVENTs. The `caldav` library's
|
||||||
|
principal→home-set discovery does not reliably enumerate calendars from
|
||||||
|
Google's ``/user`` endpoint, so the sync falls into the "treat the URL as a
|
||||||
|
single calendar" fallback below. Pointed at ``/user`` that fallback issues
|
||||||
|
every calendar-query REPORT against the principal, which returns a clean but
|
||||||
|
empty 200 for all date ranges — the calendar shows no events even though
|
||||||
|
auth succeeded (issue #2507).
|
||||||
|
|
||||||
|
Both Google CalDAV endpoint forms are handled, since some accounts only
|
||||||
|
authenticate against one of them:
|
||||||
|
- newer: ``https://apidata.googleusercontent.com/caldav/v2/<id>/user``
|
||||||
|
- legacy: ``https://www.google.com/calendar/dav/<id>/user``
|
||||||
|
|
||||||
|
Returns the events URL for a recognised Google principal URL, else None so
|
||||||
|
the caller keeps the original URL unchanged.
|
||||||
|
"""
|
||||||
|
parts = urlparse(url)
|
||||||
|
host = (parts.hostname or "").lower()
|
||||||
|
path = parts.path.rstrip("/")
|
||||||
|
if not path.endswith("/user"):
|
||||||
|
return None
|
||||||
|
is_google = (
|
||||||
|
host.endswith("googleusercontent.com") # newer /caldav/v2 form
|
||||||
|
or (host in ("www.google.com", "google.com") and "/calendar/dav/" in path) # legacy form
|
||||||
|
)
|
||||||
|
if not is_google:
|
||||||
|
return None
|
||||||
|
new_path = path[: -len("/user")] + "/events"
|
||||||
|
return urlunparse(parts._replace(path=new_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _open_url_as_calendar(client, url: str):
|
||||||
|
"""Open ``url`` as a single calendar collection.
|
||||||
|
|
||||||
|
Used when principal discovery yields no calendars. Google's principal URL
|
||||||
|
is not an event collection, so map it to the events URL first
|
||||||
|
(see ``_google_caldav_events_url``); other servers' URLs are used as-is.
|
||||||
|
"""
|
||||||
|
target = _google_caldav_events_url(url) or url
|
||||||
|
return client.calendar(url=target)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dav_client(url: str, username: str, password: str):
|
||||||
|
"""Construct a CalDAV client with automatic redirects disabled.
|
||||||
|
|
||||||
|
``validate_caldav_url`` resolves and vets the *initial* host, but caldav's
|
||||||
|
underlying HTTP session follows 3xx redirects by default. So a URL that
|
||||||
|
passes validation can still be redirected — at request time — to
|
||||||
|
loopback / link-local / private space, re-opening the SSRF the host check
|
||||||
|
closes. Pin the session to zero redirects: any 3xx then raises instead of
|
||||||
|
silently following an attacker-chosen ``Location``. This mirrors the
|
||||||
|
test-connection path in ``routes/calendar_routes.py``, which already sets
|
||||||
|
``follow_redirects=False``.
|
||||||
|
|
||||||
|
DAVClient exposes no per-request redirect flag, so we set it on the session
|
||||||
|
after construction (the session is created in ``__init__``).
|
||||||
|
"""
|
||||||
|
import caldav
|
||||||
|
|
||||||
|
client = caldav.DAVClient(url=url, username=username, password=password)
|
||||||
|
# Unconditional: a redirect-disable that only sometimes applies is not a
|
||||||
|
# control. The session exists right after __init__ on every real client;
|
||||||
|
# test_build_dav_client_disables_redirects asserts it against installed
|
||||||
|
# caldav in CI.
|
||||||
|
client.session.max_redirects = 0
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def _should_prune_window(seen_uids: set, parse_failed: bool) -> bool:
|
||||||
|
"""Whether the post-sync prune of vanished CalDAV events is safe to run.
|
||||||
|
|
||||||
|
The prune deletes local ``origin=="caldav"`` rows in the window whose UID the
|
||||||
|
server did not just return. Any parse failure (total or partial) makes
|
||||||
|
``seen_uids`` an incomplete view of the server, so pruning against it can
|
||||||
|
delete events that still exist upstream but could not be read: a total
|
||||||
|
failure wipes the whole window, a partial failure deletes just the
|
||||||
|
unreadable ones. Only prune on a clean read. An empty ``seen_uids`` after a
|
||||||
|
clean read is a genuinely empty window, which is safe to prune.
|
||||||
|
"""
|
||||||
|
return not parse_failed
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_blocking(owner: str, url: str, username: str, password: str, account_id: str = "") -> dict:
|
||||||
"""The actual sync — synchronous, intended to run in a threadpool.
|
"""The actual sync — synchronous, intended to run in a threadpool.
|
||||||
Returns counts: {calendars, events, deleted, errors}."""
|
Returns counts: {calendars, events, deleted, errors}."""
|
||||||
# Lazy imports so a missing `caldav` dep doesn't break app startup —
|
# Lazy imports so a missing `caldav` dep doesn't break app startup —
|
||||||
# the integrations form still works, sync just no-ops with an error.
|
# the integrations form still works, sync just no-ops with an error.
|
||||||
import caldav
|
|
||||||
from caldav.lib.error import AuthorizationError, NotFoundError
|
from caldav.lib.error import AuthorizationError, NotFoundError
|
||||||
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
||||||
|
|
||||||
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
||||||
|
|
||||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
client = _build_dav_client(url, username, password)
|
||||||
|
|
||||||
# Discovery: try principal → calendars first; if the server doesn't
|
# Discovery: try principal → calendars first; if the server doesn't
|
||||||
# support discovery (or the URL points directly at a calendar), fall
|
# support discovery (or the URL points directly at a calendar), fall
|
||||||
@@ -150,14 +281,14 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"CalDAV principal discovery failed, trying URL as calendar: {e}")
|
logger.info(f"CalDAV principal discovery failed, trying URL as calendar: {e}")
|
||||||
try:
|
try:
|
||||||
calendars = [client.calendar(url=url)]
|
calendars = [_open_url_as_calendar(client, url)]
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
result["errors"].append(f"Could not open URL as calendar: {e2}")
|
result["errors"].append(f"Could not open URL as calendar: {e2}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if not calendars:
|
if not calendars:
|
||||||
try:
|
try:
|
||||||
calendars = [client.calendar(url=url)]
|
calendars = [_open_url_as_calendar(client, url)]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["errors"].append(f"No calendars and URL fallback failed: {e}")
|
result["errors"].append(f"No calendars and URL fallback failed: {e}")
|
||||||
return result
|
return result
|
||||||
@@ -170,7 +301,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
for remote_cal in calendars:
|
for remote_cal in calendars:
|
||||||
try:
|
try:
|
||||||
remote_url = str(remote_cal.url)
|
remote_url = str(remote_cal.url)
|
||||||
cal_id = _stable_cal_id(remote_url)
|
cal_id = _stable_cal_id(remote_url, owner=owner, account_id=account_id)
|
||||||
display_name = (remote_cal.name or "").strip() or "CalDAV"
|
display_name = (remote_cal.name or "").strip() or "CalDAV"
|
||||||
|
|
||||||
local_cal = db.query(CalendarCal).filter(
|
local_cal = db.query(CalendarCal).filter(
|
||||||
@@ -184,14 +315,20 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
name=display_name,
|
name=display_name,
|
||||||
color="#5b8abf",
|
color="#5b8abf",
|
||||||
source="caldav",
|
source="caldav",
|
||||||
|
account_id=account_id or None,
|
||||||
)
|
)
|
||||||
db.add(local_cal)
|
db.add(local_cal)
|
||||||
db.commit()
|
db.commit()
|
||||||
else:
|
else:
|
||||||
# Refresh the display name if the user renamed it
|
# Refresh display name and stamp account_id if missing.
|
||||||
# remotely; preserve any local color override.
|
changed = False
|
||||||
if local_cal.name != display_name:
|
if local_cal.name != display_name:
|
||||||
local_cal.name = display_name
|
local_cal.name = display_name
|
||||||
|
changed = True
|
||||||
|
if account_id and not local_cal.account_id:
|
||||||
|
local_cal.account_id = account_id
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
db.commit()
|
db.commit()
|
||||||
result["calendars"] += 1
|
result["calendars"] += 1
|
||||||
|
|
||||||
@@ -205,6 +342,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
# duplicate UIDs within the same batch are updated, not re-inserted
|
# duplicate UIDs within the same batch are updated, not re-inserted
|
||||||
# (which would violate the UNIQUE constraint on commit).
|
# (which would violate the UNIQUE constraint on commit).
|
||||||
pending: dict = {}
|
pending: dict = {}
|
||||||
|
parse_failed = False
|
||||||
try:
|
try:
|
||||||
objs = remote_cal.date_search(start=start, end=end, expand=False)
|
objs = remote_cal.date_search(start=start, end=end, expand=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -216,6 +354,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
ical = iCal.from_ical(obj.data)
|
ical = iCal.from_ical(obj.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["errors"].append(f"{display_name}: parse failed ({e})")
|
result["errors"].append(f"{display_name}: parse failed ({e})")
|
||||||
|
parse_failed = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for comp in ical.walk():
|
for comp in ical.walk():
|
||||||
@@ -292,17 +431,23 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
# are prunable; locally-created events (agent / email triage / a
|
# are prunable; locally-created events (agent / email triage / a
|
||||||
# UI event whose write-back failed) carry origin NULL and must
|
# UI event whose write-back failed) carry origin NULL and must
|
||||||
# never be deleted just because the server didn't return them.
|
# never be deleted just because the server didn't return them.
|
||||||
stale = db.query(CalendarEvent).filter(
|
# Skip the prune on any parse failure: seen_uids is then an
|
||||||
CalendarEvent.calendar_id == local_cal.id,
|
# incomplete view of the server, so pruning against it would
|
||||||
CalendarEvent.origin == "caldav",
|
# delete events that still exist upstream but could not be read
|
||||||
CalendarEvent.dtstart >= start,
|
# (the empty-seen_uids case wipes the whole window; a partial
|
||||||
CalendarEvent.dtstart <= end,
|
# failure deletes just the unreadable rows).
|
||||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
if _should_prune_window(seen_uids, parse_failed):
|
||||||
).all()
|
stale = db.query(CalendarEvent).filter(
|
||||||
for ev in stale:
|
CalendarEvent.calendar_id == local_cal.id,
|
||||||
db.delete(ev)
|
CalendarEvent.origin == "caldav",
|
||||||
result["deleted"] += len(stale)
|
CalendarEvent.dtstart >= start,
|
||||||
db.commit()
|
CalendarEvent.dtstart <= end,
|
||||||
|
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||||
|
).all()
|
||||||
|
for ev in stale:
|
||||||
|
db.delete(ev)
|
||||||
|
result["deleted"] += len(stale)
|
||||||
|
db.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("CalDAV sync failed for one calendar")
|
logger.exception("CalDAV sync failed for one calendar")
|
||||||
result["errors"].append(str(e)[:200])
|
result["errors"].append(str(e)[:200])
|
||||||
@@ -313,31 +458,78 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def sync_caldav(owner: str) -> dict:
|
def _load_caldav_accounts(owner: str) -> list:
|
||||||
"""Pull CalDAV state into local DB for `owner`. Returns counts +
|
"""Return the list of CalDAV accounts for *owner*, auto-migrating the legacy
|
||||||
errors. Loads credentials from the user's prefs; no-ops with a
|
single-account ``caldav`` key to the new ``caldav_accounts`` list on first call.
|
||||||
clear error if CalDAV isn't configured."""
|
|
||||||
|
The save step is best-effort: if ``_save_for_user`` is unavailable (e.g. in a
|
||||||
|
test with a minimal prefs mock) the migrated accounts are still returned; the
|
||||||
|
next real call will just re-run the cheap migration again.
|
||||||
|
"""
|
||||||
|
import uuid as _uuid
|
||||||
from routes.prefs_routes import _load_for_user
|
from routes.prefs_routes import _load_for_user
|
||||||
|
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
prefs = _load_for_user(owner) or {}
|
||||||
url = (cfg.get("url") or "").strip()
|
if "caldav_accounts" in prefs:
|
||||||
user = (cfg.get("username") or "").strip()
|
return list(prefs["caldav_accounts"] or [])
|
||||||
pw = cfg.get("password") or ""
|
# Migrate legacy single-account config to the list format.
|
||||||
try:
|
legacy = prefs.get("caldav", {}) or {}
|
||||||
from src.secret_storage import decrypt
|
if legacy.get("url"):
|
||||||
pw = decrypt(pw)
|
accounts = [{
|
||||||
except Exception:
|
"id": str(_uuid.uuid4()),
|
||||||
pass
|
"label": "CalDAV",
|
||||||
if not (url and user and pw):
|
"url": legacy["url"],
|
||||||
|
"username": legacy.get("username", ""),
|
||||||
|
"password": legacy.get("password", ""),
|
||||||
|
}]
|
||||||
|
prefs["caldav_accounts"] = accounts
|
||||||
|
prefs.pop("caldav", None)
|
||||||
|
try:
|
||||||
|
from routes.prefs_routes import _save_for_user
|
||||||
|
_save_for_user(owner, prefs)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass # best-effort; next call re-migrates from the still-present legacy key
|
||||||
|
return accounts
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_caldav(owner: str) -> dict:
|
||||||
|
"""Pull CalDAV state into local DB for `owner` across all configured accounts.
|
||||||
|
Returns aggregated counts + per-account errors."""
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
|
||||||
|
accounts = _load_caldav_accounts(owner)
|
||||||
|
if not accounts:
|
||||||
return {
|
return {
|
||||||
"calendars": 0, "events": 0, "deleted": 0,
|
"calendars": 0, "events": 0, "deleted": 0,
|
||||||
"errors": ["CalDAV is not configured"],
|
"errors": ["CalDAV is not configured"],
|
||||||
}
|
}
|
||||||
try:
|
|
||||||
url = validate_caldav_url(url)
|
totals: dict = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
||||||
return await asyncio.to_thread(_sync_blocking, owner, url, user, pw)
|
for acc in accounts:
|
||||||
except ValueError as e:
|
url = (acc.get("url") or "").strip()
|
||||||
return {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)]}
|
user = (acc.get("username") or "").strip()
|
||||||
except Exception as e:
|
pw = acc.get("password") or ""
|
||||||
logger.exception("CalDAV sync raised")
|
account_id = acc.get("id") or ""
|
||||||
return {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)[:200]]}
|
label = acc.get("label") or url or account_id
|
||||||
|
try:
|
||||||
|
pw = decrypt(pw)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not (url and user and pw):
|
||||||
|
totals["errors"].append(f"{label}: missing URL, username, or password")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
url = validate_caldav_url(url)
|
||||||
|
result = await asyncio.to_thread(_sync_blocking, owner, url, user, pw, account_id)
|
||||||
|
except ValueError as e:
|
||||||
|
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)]}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("CalDAV sync raised for account %s", label)
|
||||||
|
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)[:200]]}
|
||||||
|
totals["calendars"] += result.get("calendars", 0)
|
||||||
|
totals["events"] += result.get("events", 0)
|
||||||
|
totals["deleted"] += result.get("deleted", 0)
|
||||||
|
for err in result.get("errors", []):
|
||||||
|
totals["errors"].append(f"{label}: {err}")
|
||||||
|
return totals
|
||||||
|
|||||||
+59
-23
@@ -23,11 +23,10 @@ from datetime import timezone
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _stable_cal_id(remote_url: str) -> str:
|
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||||
# Reuse the sync module's hashing so a local CalDAV calendar id maps back to
|
# Reuse the sync module's hashing so owner+account_id scoping stays consistent.
|
||||||
# the same remote URL it was pulled from.
|
|
||||||
from src.caldav_sync import _stable_cal_id as _sync_id
|
from src.caldav_sync import _stable_cal_id as _sync_id
|
||||||
return _sync_id(remote_url)
|
return _sync_id(remote_url, owner=owner, account_id=account_id)
|
||||||
|
|
||||||
|
|
||||||
def build_event_ical(ev: dict) -> str:
|
def build_event_ical(ev: dict) -> str:
|
||||||
@@ -76,28 +75,34 @@ def build_event_ical(ev: dict) -> str:
|
|||||||
return cal.to_ical().decode("utf-8")
|
return cal.to_ical().decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def find_remote_calendar(calendars, local_cal_id: str):
|
def find_remote_calendar(calendars, local_cal_id: str, owner: str = "", account_id: str = ""):
|
||||||
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None."""
|
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None.
|
||||||
|
|
||||||
|
``owner`` and ``account_id`` must match what was used when the local calendar
|
||||||
|
id was originally computed in ``_sync_blocking`` so the hash round-trips."""
|
||||||
for cal in calendars:
|
for cal in calendars:
|
||||||
try:
|
try:
|
||||||
if _stable_cal_id(str(cal.url)) == local_cal_id:
|
if _stable_cal_id(str(cal.url), owner=owner, account_id=account_id) == local_cal_id:
|
||||||
return cal
|
return cal
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False) -> dict:
|
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||||
|
owner: str = "", account_id: str = "") -> dict:
|
||||||
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
||||||
|
|
||||||
Returns ``{"ok": bool, ...}``. ``calendars`` is the discovered caldav
|
Returns ``{"ok": bool, ...}``. ``calendars`` is the discovered caldav
|
||||||
calendar list (injected so this is unit-testable with fakes).
|
calendar list (injected so this is unit-testable with fakes).
|
||||||
|
``owner`` and ``account_id`` are forwarded to ``find_remote_calendar``
|
||||||
|
so the URL hash round-trips correctly (#2765).
|
||||||
"""
|
"""
|
||||||
uid = (ev or {}).get("uid") if isinstance(ev, dict) else None
|
uid = (ev or {}).get("uid") if isinstance(ev, dict) else None
|
||||||
if not uid:
|
if not uid:
|
||||||
return {"ok": False, "error": "event uid is required"}
|
return {"ok": False, "error": "event uid is required"}
|
||||||
|
|
||||||
remote = find_remote_calendar(calendars, local_cal_id)
|
remote = find_remote_calendar(calendars, local_cal_id, owner=owner, account_id=account_id)
|
||||||
if remote is None:
|
if remote is None:
|
||||||
return {"ok": False, "error": "remote calendar not found"}
|
return {"ok": False, "error": "remote calendar not found"}
|
||||||
|
|
||||||
@@ -136,13 +141,17 @@ def _discover_calendars(client):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _writeback_blocking(local_cal_id, ev, delete, url, username, password) -> dict:
|
def _writeback_blocking(local_cal_id, ev, delete, url, username, password,
|
||||||
import caldav
|
owner="", account_id="") -> dict:
|
||||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
from src.caldav_sync import _build_dav_client
|
||||||
|
# Redirects disabled here too: the write-back path opens its own DAVClient,
|
||||||
|
# so it needs the same SSRF-via-redirect protection as the pull path.
|
||||||
|
client = _build_dav_client(url, username, password)
|
||||||
calendars = _discover_calendars(client)
|
calendars = _discover_calendars(client)
|
||||||
if not calendars:
|
if not calendars:
|
||||||
return {"ok": False, "error": "no remote calendars discovered"}
|
return {"ok": False, "error": "no remote calendars discovered"}
|
||||||
return push_event(calendars, local_cal_id, ev, delete=delete)
|
return push_event(calendars, local_cal_id, ev, delete=delete,
|
||||||
|
owner=owner, account_id=account_id)
|
||||||
|
|
||||||
|
|
||||||
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||||
@@ -156,18 +165,45 @@ async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
|||||||
if calendar_source != "caldav":
|
if calendar_source != "caldav":
|
||||||
return {"skipped": "not a caldav calendar"}
|
return {"skipped": "not a caldav calendar"}
|
||||||
try:
|
try:
|
||||||
from routes.prefs_routes import _load_for_user
|
from src.caldav_sync import _load_caldav_accounts
|
||||||
from src.secret_storage import decrypt
|
from src.secret_storage import decrypt
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
from core.database import CalendarCal, SessionLocal
|
||||||
url = (cfg.get("url") or "").strip()
|
|
||||||
user = (cfg.get("username") or "").strip()
|
accounts = _load_caldav_accounts(owner)
|
||||||
# Stored encrypted by routes/calendar_routes; decrypt before use so
|
if not accounts:
|
||||||
# the remote sees the real password (decrypt is a no-op on legacy
|
|
||||||
# plaintext). The pull path src/caldav_sync.py already does this.
|
|
||||||
pw = decrypt(cfg.get("password") or "")
|
|
||||||
if not (url and user and pw):
|
|
||||||
return {"skipped": "caldav not configured"}
|
return {"skipped": "caldav not configured"}
|
||||||
result = await asyncio.to_thread(_writeback_blocking, calendar_id, ev, delete, url, user, pw)
|
|
||||||
|
# Find which account owns this calendar.
|
||||||
|
acc = None
|
||||||
|
if len(accounts) > 1:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
cal_row = db.query(CalendarCal).filter(CalendarCal.id == calendar_id).first()
|
||||||
|
cal_account_id = cal_row.account_id if cal_row else None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
if cal_account_id:
|
||||||
|
acc = next((a for a in accounts if a.get("id") == cal_account_id), None)
|
||||||
|
# Fall back to first account (covers single-account and legacy rows with
|
||||||
|
# no account_id stamped).
|
||||||
|
if acc is None:
|
||||||
|
acc = accounts[0]
|
||||||
|
|
||||||
|
url = (acc.get("url") or "").strip()
|
||||||
|
user = (acc.get("username") or "").strip()
|
||||||
|
pw = decrypt(acc.get("password") or "")
|
||||||
|
if not (url and user and pw):
|
||||||
|
return {"skipped": "caldav account credentials incomplete"}
|
||||||
|
from src.caldav_sync import validate_caldav_url
|
||||||
|
try:
|
||||||
|
url = validate_caldav_url(url)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("CalDAV write-back URL rejected: %s", e)
|
||||||
|
return {"ok": False, "error": str(e)[:200]}
|
||||||
|
acc_id = acc.get("id") or ""
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
_writeback_blocking, calendar_id, ev, delete, url, user, pw, owner, acc_id
|
||||||
|
)
|
||||||
if not result.get("ok"):
|
if not result.get("ok"):
|
||||||
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
||||||
return result
|
return result
|
||||||
|
|||||||
+25
-15
@@ -98,6 +98,7 @@ class ChatHandler:
|
|||||||
att_ids: List[str],
|
att_ids: List[str],
|
||||||
sess,
|
sess,
|
||||||
auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Common preprocessing for both chat endpoints.
|
Common preprocessing for both chat endpoints.
|
||||||
@@ -112,7 +113,7 @@ class ChatHandler:
|
|||||||
attachment_meta: List[Dict[str, Any]] = []
|
attachment_meta: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
# Extract URLs and process YouTube transcripts
|
# Extract URLs and process YouTube transcripts
|
||||||
urls = extract_urls(enhanced_message)
|
urls = extract_urls(enhanced_message) if allow_tool_preprocessing else []
|
||||||
youtube_transcripts: List[str] = []
|
youtube_transcripts: List[str] = []
|
||||||
|
|
||||||
has_youtube = False
|
has_youtube = False
|
||||||
@@ -143,24 +144,18 @@ class ChatHandler:
|
|||||||
if has_youtube:
|
if has_youtube:
|
||||||
youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
||||||
|
|
||||||
# Analyze images — skip if vision disabled, or if main model is vision-capable
|
|
||||||
from src.settings import get_setting
|
|
||||||
vision_enabled = get_setting("vision_enabled", True)
|
|
||||||
main_is_vision = await asyncio.to_thread(
|
|
||||||
model_supports_vision, sess.model or "", getattr(sess, "endpoint_url", "") or ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve uploads once with the session owner. Attachment IDs are
|
# Resolve uploads once with the session owner. Attachment IDs are
|
||||||
# bearer-like references; never trust them without an owner check.
|
# bearer-like references; never trust them without an owner check.
|
||||||
files_by_id: Dict[str, Dict] = {}
|
files_by_id: Dict[str, Dict] = {}
|
||||||
owner = getattr(sess, "owner", None)
|
owner = getattr(sess, "owner", None)
|
||||||
if att_ids:
|
effective_att_ids = att_ids if allow_tool_preprocessing else []
|
||||||
for att_id in att_ids:
|
if effective_att_ids:
|
||||||
|
for att_id in effective_att_ids:
|
||||||
fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
||||||
if fi:
|
if fi:
|
||||||
files_by_id[att_id] = fi
|
files_by_id[att_id] = fi
|
||||||
|
|
||||||
for att_id in att_ids:
|
for att_id in effective_att_ids:
|
||||||
fi = files_by_id.get(att_id)
|
fi = files_by_id.get(att_id)
|
||||||
if fi:
|
if fi:
|
||||||
attachment_meta.append({
|
attachment_meta.append({
|
||||||
@@ -172,9 +167,24 @@ class ChatHandler:
|
|||||||
"height": fi.get("height"),
|
"height": fi.get("height"),
|
||||||
})
|
})
|
||||||
|
|
||||||
if att_ids and vision_enabled:
|
# Analyze images only when attachment preprocessing is actually
|
||||||
|
# allowed. The vision capability check can probe local model endpoints,
|
||||||
|
# so guide-only/no-tools turns must not reach it.
|
||||||
|
vision_enabled = False
|
||||||
|
main_is_vision = False
|
||||||
|
if effective_att_ids:
|
||||||
|
from src.settings import get_setting
|
||||||
|
vision_enabled = get_setting("vision_enabled", True)
|
||||||
|
if vision_enabled:
|
||||||
|
main_is_vision = await asyncio.to_thread(
|
||||||
|
model_supports_vision,
|
||||||
|
sess.model or "",
|
||||||
|
getattr(sess, "endpoint_url", "") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if effective_att_ids and vision_enabled:
|
||||||
meta_by_id = {m["id"]: m for m in attachment_meta}
|
meta_by_id = {m["id"]: m for m in attachment_meta}
|
||||||
for att_id in att_ids:
|
for att_id in effective_att_ids:
|
||||||
file_info = files_by_id.get(att_id)
|
file_info = files_by_id.get(att_id)
|
||||||
if file_info and self.upload_handler.is_image_file(
|
if file_info and self.upload_handler.is_image_file(
|
||||||
file_info["name"], file_info.get("mime", "")
|
file_info["name"], file_info.get("mime", "")
|
||||||
@@ -219,7 +229,7 @@ class ChatHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
vl_desc = None
|
vl_desc = None
|
||||||
if not vl_desc:
|
if not vl_desc:
|
||||||
vl_result = analyze_image_with_vl_result(file_info["path"])
|
vl_result = analyze_image_with_vl_result(file_info["path"], owner=owner)
|
||||||
vl_desc = vl_result.get("text", "")
|
vl_desc = vl_result.get("text", "")
|
||||||
vl_model = vl_result.get("model", "")
|
vl_model = vl_result.get("model", "")
|
||||||
if vl_desc and not vl_desc.startswith("["):
|
if vl_desc and not vl_desc.startswith("["):
|
||||||
@@ -239,7 +249,7 @@ class ChatHandler:
|
|||||||
_m["vision_model"] = vl_model
|
_m["vision_model"] = vl_model
|
||||||
|
|
||||||
user_content = build_user_content(
|
user_content = build_user_content(
|
||||||
enhanced_message, att_ids, UPLOAD_DIR, self.upload_handler,
|
enhanced_message, effective_att_ids, UPLOAD_DIR, self.upload_handler,
|
||||||
session_id=getattr(sess, "id", None),
|
session_id=getattr(sess, "id", None),
|
||||||
auto_opened_docs=auto_opened_docs,
|
auto_opened_docs=auto_opened_docs,
|
||||||
owner=owner,
|
owner=owner,
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user