mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Merge branch 'dev'
# Conflicts:
#   routes/task_routes.py
#   src/caldav_sync.py
This commit is contained in:
@@ -56,6 +56,13 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# SQLite database path (default: sqlite:///./data/app.db)
|
# SQLite database path (default: sqlite:///./data/app.db)
|
||||||
# DATABASE_URL=sqlite:///./data/app.db
|
# DATABASE_URL=sqlite:///./data/app.db
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Data directory
|
||||||
|
# ============================================================
|
||||||
|
# Move everything that lives under data/ - settings, sessions, database, auth,
|
||||||
|
# cache, uploads, etc. - to another path:
|
||||||
|
# ODYSSEUS_DATA_DIR=C:\path\to\dir
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Auth & Security
|
# Auth & Security
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -112,6 +119,9 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
||||||
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
||||||
|
|
||||||
|
# Embedding API key (only needed if the embedding endpoint requires one)
|
||||||
|
# EMBEDDING_API_KEY=embedding_api_key_here
|
||||||
|
|
||||||
# Embedding model name (must be available at the endpoint above)
|
# Embedding model name (must be available at the endpoint above)
|
||||||
# EMBEDDING_MODEL=all-minilm:l6-v2
|
# EMBEDDING_MODEL=all-minilm:l6-v2
|
||||||
|
|
||||||
@@ -144,6 +154,21 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# if you intentionally want scheduled scripts to run remotely.
|
# if you intentionally want scheduled scripts to run remotely.
|
||||||
# ODYSSEUS_SCRIPT_HOST=localhost
|
# ODYSSEUS_SCRIPT_HOST=localhost
|
||||||
|
|
||||||
|
# Chat / agent attachment size cap in bytes (default: 10 MB).
|
||||||
|
# Raise this for local installs that need larger PDFs or text documents.
|
||||||
|
# Example: 52428800 = 50 MB.
|
||||||
|
# ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=10485760
|
||||||
|
|
||||||
|
# Other per-feature upload size caps in bytes. All are validated and optional;
|
||||||
|
# defaults shown. An invalid value (non-integer or < 1) fails fast at startup.
|
||||||
|
# ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES=104857600 # gallery image upload (100 MB)
|
||||||
|
# ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES=26214400 # gallery transform input (25 MB)
|
||||||
|
# ODYSSEUS_MEMORY_IMPORT_MAX_BYTES=10485760 # memory import file (10 MB)
|
||||||
|
# ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES=26214400 # personal document upload (25 MB)
|
||||||
|
# ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES=26214400 # email compose attachment (25 MB)
|
||||||
|
# ODYSSEUS_STT_MAX_AUDIO_BYTES=26214400 # speech-to-text audio (25 MB)
|
||||||
|
# ODYSSEUS_ICS_MAX_BYTES=10485760 # calendar .ics import (10 MB)
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# GPU support (Docker Compose)
|
# GPU support (Docker Compose)
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ body:
|
|||||||
required: true
|
required: true
|
||||||
- label: This is **not** a security vulnerability. (Vulnerabilities go to [GitHub Security Advisories](https://github.com/pewdiepie-archdaemon/odysseus/security/advisories/new) — see [SECURITY.md](https://github.com/pewdiepie-archdaemon/odysseus/blob/main/SECURITY.md).)
|
- label: This is **not** a security vulnerability. (Vulnerabilities go to [GitHub Security Advisories](https://github.com/pewdiepie-archdaemon/odysseus/security/advisories/new) — see [SECURITY.md](https://github.com/pewdiepie-archdaemon/odysseus/blob/main/SECURITY.md).)
|
||||||
required: true
|
required: true
|
||||||
- label: I am running the latest code from `main`.
|
- label: I am running the latest code from the `dev` branch (the default branch you get on clone, where fixes land first) and the bug still reproduces there. Please `git pull` the latest `dev` before filing.
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
|
|||||||
@@ -103,14 +103,21 @@ module.exports = async ({ github, context, core }) => {
|
|||||||
|
|
||||||
async function swapLabel(num, add, remove) {
|
async function swapLabel(num, add, remove) {
|
||||||
if (await labelExists(add)) {
|
if (await labelExists(add)) {
|
||||||
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
try {
|
||||||
|
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
||||||
|
} catch (e) {
|
||||||
|
// Fail soft on a token that can't write labels so a label permission
|
||||||
|
// problem never masks the actual description verdict.
|
||||||
|
if (e.status !== 403) throw e;
|
||||||
|
core.warning(`Could not add "${add}" — token lacks label write here; skipping.`);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
core.warning(`Label "${add}" does not exist in the repo — skipping. Create it once to enable labelling.`);
|
core.warning(`Label "${add}" does not exist in the repo — skipping. Create it once to enable labelling.`);
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name: remove });
|
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name: remove });
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
if (e.status !== 404 && e.status !== 410) throw e;
|
if (e.status !== 404 && e.status !== 410 && e.status !== 403) throw e;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
@@ -31,6 +33,8 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||||
with:
|
with:
|
||||||
node-version: "20"
|
node-version: "20"
|
||||||
@@ -51,10 +55,40 @@ jobs:
|
|||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
# Detect whether this PR only touches documentation files.
|
||||||
|
# If so, skip the expensive pytest run while still reporting a passing check.
|
||||||
|
- name: Check for docs-only changes
|
||||||
|
id: docs-check
|
||||||
|
run: |
|
||||||
|
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||||
|
BASE="${{ github.event.pull_request.base.sha }}"
|
||||||
|
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||||
|
else
|
||||||
|
BASE="${{ github.event.before }}"
|
||||||
|
HEAD="${{ github.sha }}"
|
||||||
|
fi
|
||||||
|
# List all changed files; if every file matches docs/markdown patterns, skip pytest.
|
||||||
|
changed=$(git diff --name-only "$BASE" "$HEAD" 2>/dev/null || git diff --name-only HEAD~1 HEAD)
|
||||||
|
non_docs=$(echo "$changed" | grep -Ev '^(docs/|.*\.md$|\.github/[^/]+\.md$)' || true)
|
||||||
|
if [ -z "$non_docs" ]; then
|
||||||
|
echo "docs_only=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "Docs-only change detected — skipping pytest."
|
||||||
|
else
|
||||||
|
echo "docs_only=false" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: pip
|
cache: pip
|
||||||
- run: pip install -r requirements.txt
|
- run: pip install -r requirements.txt
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
- run: mkdir -p data # sqlite DB lives at ./data/app.db
|
- run: mkdir -p data # sqlite DB lives at ./data/app.db
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
- run: python -m pytest -q
|
- run: python -m pytest -q
|
||||||
|
if: steps.docs-check.outputs.docs_only != 'true'
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
name: ci / docker publish
|
||||||
|
|
||||||
|
# Build the Odysseus image and publish to GHCR.
|
||||||
|
# push to main -> :latest, :X.Y.Z (curated release; main is fast-forwarded at releases)
|
||||||
|
# push to dev -> :dev, :X.Y.Z-dev.<sha> (rolling dev + an immutable, traceable pin)
|
||||||
|
# Multi-arch (linux/amd64 + linux/arm64): each arch builds on its own native
|
||||||
|
# runner and pushes by digest, then a merge job stitches the digests into one
|
||||||
|
# manifest list and applies the tags (faster + cleaner than QEMU emulation).
|
||||||
|
# Registry: ghcr.io/<owner>/<repo>.
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [dev, main]
|
||||||
|
paths-ignore:
|
||||||
|
- '**.md'
|
||||||
|
- 'docs/**'
|
||||||
|
- '.github/ISSUE_TEMPLATE/**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: docker-publish-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
env:
|
||||||
|
REGISTRY: ghcr.io
|
||||||
|
IMAGE_NAME: ${{ github.repository }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: build (${{ matrix.arch }})
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- platform: linux/amd64
|
||||||
|
arch: amd64
|
||||||
|
runner: ubuntu-latest
|
||||||
|
- platform: linux/arm64
|
||||||
|
arch: arm64
|
||||||
|
runner: ubuntu-24.04-arm
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
- name: Set up Buildx
|
||||||
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||||
|
- name: Log in to GHCR
|
||||||
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Build and push by digest
|
||||||
|
id: build
|
||||||
|
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7.2.0
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: ${{ matrix.platform }}
|
||||||
|
outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||||
|
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||||
|
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||||
|
- name: Export digest
|
||||||
|
run: |
|
||||||
|
mkdir -p /tmp/digests
|
||||||
|
digest="${{ steps.build.outputs.digest }}"
|
||||||
|
touch "/tmp/digests/${digest#sha256:}"
|
||||||
|
- name: Upload digest
|
||||||
|
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||||
|
with:
|
||||||
|
name: digest-${{ matrix.arch }}
|
||||||
|
path: /tmp/digests/*
|
||||||
|
if-no-files-found: error
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
|
merge:
|
||||||
|
name: merge manifest + tag
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: build
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
- name: Read APP_VERSION + short sha
|
||||||
|
id: ver
|
||||||
|
run: |
|
||||||
|
v=$(grep -E '^APP_VERSION' src/constants.py | head -1 | sed -E 's/.*"([^"]+)".*/\1/')
|
||||||
|
[ -n "$v" ] || { echo "APP_VERSION not found"; exit 1; }
|
||||||
|
echo "version=$v" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "short=${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT"
|
||||||
|
- name: Download digests
|
||||||
|
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||||
|
with:
|
||||||
|
path: /tmp/digests
|
||||||
|
pattern: digest-*
|
||||||
|
merge-multiple: true
|
||||||
|
- name: Set up Buildx
|
||||||
|
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||||
|
- name: Log in to GHCR
|
||||||
|
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Compute tags
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6.1.0
|
||||||
|
with:
|
||||||
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
tags: |
|
||||||
|
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||||
|
type=raw,value=${{ steps.ver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||||
|
type=raw,value=dev,enable=${{ github.ref == 'refs/heads/dev' }}
|
||||||
|
type=raw,value=${{ steps.ver.outputs.version }}-dev.${{ steps.ver.outputs.short }},enable=${{ github.ref == 'refs/heads/dev' }}
|
||||||
|
- name: Create manifest list + push tags
|
||||||
|
working-directory: /tmp/digests
|
||||||
|
run: |
|
||||||
|
tags=$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||||
|
digests=$(printf "${REGISTRY}/${IMAGE_NAME}@sha256:%s " *)
|
||||||
|
# word-splitting is intended: $tags and $digests each expand to multiple args
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
docker buildx imagetools create $tags $digests
|
||||||
|
env:
|
||||||
|
REGISTRY: ${{ env.REGISTRY }}
|
||||||
|
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||||
|
- name: Inspect
|
||||||
|
run: |
|
||||||
|
if [ "$GITHUB_REF" = "refs/heads/main" ]; then ref=latest; else ref=dev; fi
|
||||||
|
docker buildx imagetools inspect "${REGISTRY}/${IMAGE_NAME}:${ref}"
|
||||||
|
env:
|
||||||
|
REGISTRY: ${{ env.REGISTRY }}
|
||||||
|
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||||
@@ -14,10 +14,11 @@ jobs:
|
|||||||
# Skip bots (Dependabot, release-drafter, etc.)
|
# Skip bots (Dependabot, release-drafter, etc.)
|
||||||
if: ${{ github.event.issue.user.type != 'Bot' }}
|
if: ${{ github.event.issue.user.type != 'Bot' }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
with:
|
with:
|
||||||
sparse-checkout: .github/scripts
|
sparse-checkout: .github/scripts
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
- uses: actions/github-script@v7
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
with:
|
with:
|
||||||
script: return require('./.github/scripts/check-issue-description.js')({github, context, core})
|
script: return require('./.github/scripts/check-issue-description.js')({github, context, core})
|
||||||
|
|||||||
@@ -1,28 +1,109 @@
|
|||||||
name: ci / PR description check
|
name: ci / PR checks
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
# pull_request_target runs in the base-repo context (has secrets) so the check
|
||||||
types: [opened, edited, synchronize, reopened]
|
# works on fork PRs. Safe here: the checkout pins to the base branch (no fork
|
||||||
|
# code runs) and the scripts only read context.payload and call the GitHub API.
|
||||||
|
pull_request_target: # zizmor: ignore[dangerous-triggers]
|
||||||
|
types: [opened, edited, synchronize, reopened, ready_for_review]
|
||||||
|
|
||||||
# pull_request_target runs in the base-repo context (has secrets).
|
# Default-deny at the workflow level; each job opts into only the scopes it needs.
|
||||||
# The checkout below pins to the base branch so no fork code is executed.
|
# Note: modifying a PR's labels/comments needs pull-requests:write even though the
|
||||||
# The script only reads context.payload and calls the GitHub API.
|
# REST path is under /issues/{n}/...; issues:write alone returns 403 on PRs.
|
||||||
permissions:
|
permissions: {}
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-description:
|
check-description:
|
||||||
name: Check PR description
|
name: Check PR description
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
# Skip bots — they open PRs programmatically and have their own process.
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
issues: write
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
if: github.event.pull_request.user.type != 'Bot'
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.base_ref }}
|
ref: ${{ github.base_ref }}
|
||||||
sparse-checkout: .github/scripts
|
sparse-checkout: .github/scripts
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
- uses: actions/github-script@v7
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
with:
|
with:
|
||||||
script: return require('./.github/scripts/check-pr-description.js')({github, context, core})
|
script: return require('./.github/scripts/check-pr-description.js')({github, context, core})
|
||||||
|
|
||||||
|
check-title:
|
||||||
|
name: Check PR title (Conventional Commits)
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions: {}
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
|
steps:
|
||||||
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title || "";
|
||||||
|
// Conventional Commits: type(optional-scope)(optional !): summary
|
||||||
|
const re = /^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\([\w .\/-]+\))?!?: .+/;
|
||||||
|
if (!re.test(title)) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title is not in Conventional Commits format:\n "${title}"\n\n` +
|
||||||
|
`Expected: type(scope): summary\n` +
|
||||||
|
`Example: fix(search): handle empty query\n` +
|
||||||
|
`Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.`
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
core.info(`PR title OK: ${title}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
check-mergeable:
|
||||||
|
name: Flag unmergeable PRs
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
pull-requests: write
|
||||||
|
issues: write
|
||||||
|
# Skip bots: they open PRs programmatically and have their own process.
|
||||||
|
if: github.event.pull_request.user.type != 'Bot'
|
||||||
|
steps:
|
||||||
|
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const repo = { owner: context.repo.owner, repo: context.repo.repo };
|
||||||
|
const number = context.payload.pull_request.number;
|
||||||
|
const READY = "ready for review";
|
||||||
|
const CONFLICT = "merge conflict";
|
||||||
|
|
||||||
|
// Ensure the conflict label exists (red). Ignore if already present.
|
||||||
|
try {
|
||||||
|
await github.rest.issues.getLabel({ ...repo, name: CONFLICT });
|
||||||
|
} catch {
|
||||||
|
await github.rest.issues.createLabel({
|
||||||
|
...repo, name: CONFLICT, color: "B60205",
|
||||||
|
description: "Conflicts with the base branch; needs a rebase before review.",
|
||||||
|
}).catch(() => {});
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeable is computed asynchronously and is often null right after
|
||||||
|
// an event, so poll a few times until GitHub has resolved it.
|
||||||
|
let pr = null;
|
||||||
|
for (let i = 0; i < 5; i++) {
|
||||||
|
const { data } = await github.rest.pulls.get({ ...repo, pull_number: number });
|
||||||
|
if (data.mergeable !== null) { pr = data; break; }
|
||||||
|
await new Promise(r => setTimeout(r, 3000));
|
||||||
|
}
|
||||||
|
if (!pr || pr.draft) return;
|
||||||
|
const labels = pr.labels.map(l => l.name);
|
||||||
|
|
||||||
|
if (pr.mergeable === false) {
|
||||||
|
if (labels.includes(READY)) {
|
||||||
|
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: READY }).catch(() => {});
|
||||||
|
}
|
||||||
|
if (!labels.includes(CONFLICT)) {
|
||||||
|
await github.rest.issues.addLabels({ ...repo, issue_number: number, labels: [CONFLICT] });
|
||||||
|
}
|
||||||
|
} else if (pr.mergeable === true) {
|
||||||
|
if (labels.includes(CONFLICT)) {
|
||||||
|
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: CONFLICT }).catch(() => {});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -94,6 +94,18 @@ Before submitting any change that affects what the app looks like — buttons, i
|
|||||||
|
|
||||||
If you are unsure whether a change is "visual," it is. Default to attaching a screenshot.
|
If you are unsure whether a change is "visual," it is. Default to attaching a screenshot.
|
||||||
|
|
||||||
|
## Code conventions
|
||||||
|
|
||||||
|
Don't hardcode values that the project already exposes through a constant or a helper. Hardcoded literals drift out of sync, break on non-default deployments, and reintroduce bugs we've already fixed.
|
||||||
|
|
||||||
|
- **Filesystem paths:** never build writable paths from `Path(__file__)...` into the source tree, hardcode `/app/...`, or use a relative `"data/..."` string. Every persisted file and directory has a named constant in `src/constants.py` (for example `AUTH_FILE`, `USER_PREFS_FILE`, `SETTINGS_FILE`, `TTS_CACHE_DIR`, `CHROMA_DIR`). Import and use that named constant; do not re-derive the path locally with `os.path.join(DATA_DIR, "x.json")` or `DATA_DIR / "x.json"`. `DATA_DIR` is the single place that reads `ODYSSEUS_DATA_DIR`, so use it directly only for dynamic paths that have no fixed name (for example per-owner files). If a data file or directory has no constant yet, add one to `src/constants.py`. The source tree is read-only in Docker and `/app/...` does not exist on native runs; guard directory creation so an unwritable path degrades gracefully instead of crashing at import.
|
||||||
|
- **Internal API / loopback URLs:** don't hardcode `http://localhost:7000`. Use `internal_api_base()` from `src.constants` (it honors `ODYSSEUS_INTERNAL_BASE` / `APP_PORT`).
|
||||||
|
- **Ports, limits, model lists, and similar:** reuse the existing constant if one exists; if it doesn't and the value is used in more than one place, add a constant rather than copying the literal.
|
||||||
|
|
||||||
|
If you need a value that has no constant or helper yet, add it to `src/constants.py` (the single source of truth for paths and config; `core/constants.py` only re-exports it for backward compatibility) and import it, rather than repeating a literal across files.
|
||||||
|
|
||||||
|
**Commits:** use [Conventional Commits](https://www.conventionalcommits.org) in the form `type(scope): summary` (e.g. `fix(search): ...`, `feat(notes): ...`, `docs(contributing): ...`). Common types: `fix`, `feat`, `refactor`, `docs`, `test`, `chore`, `ci`. Keep the subject short and imperative; put the "why" in the body when it isn't obvious.
|
||||||
|
|
||||||
## Issue Reports
|
## Issue Reports
|
||||||
|
|
||||||
For bugs, include:
|
For bugs, include:
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
# Odysseus
|
# Odysseus
|
||||||
|
|
||||||
|
> **Branch note:** `dev` is the default branch and contains the latest development changes, but it may be unstable. For the more stable curated branch, use [`main`](https://github.com/pewdiepie-archdaemon/odysseus/tree/main).
|
||||||
|
|
||||||
```
|
```
|
||||||
───────────────────────────────────────────────
|
───────────────────────────────────────────────
|
||||||
⊹ ࣪ ˖ ૮( ˶ᵔ ᵕ ᵔ˶ )っ Odysseus vers. 1.0
|
⊹ ࣪ ˖ ૮( ˶ᵔ ᵕ ᵔ˶ )っ Odysseus vers. 1.0
|
||||||
@@ -331,6 +333,12 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
|||||||
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
||||||
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
||||||
|
|
||||||
|
### Outlook / Office 365 email
|
||||||
|
Odysseus email accounts currently use IMAP/SMTP username-password auth. Outlook
|
||||||
|
and Microsoft 365 generally require OAuth instead, so normal Microsoft mailbox
|
||||||
|
passwords will fail. See [docs/email-outlook.md](docs/email-outlook.md) for the
|
||||||
|
current limitation and the planned integration direction.
|
||||||
|
|
||||||
## Security Notes
|
## Security Notes
|
||||||
Odysseus is a self-hosted workspace with powerful local tools: shell access, file uploads, model downloads, web research, email/calendar integrations, and API tokens. Treat it like an admin console.
|
Odysseus is a self-hosted workspace with powerful local tools: shell access, file uploads, model downloads, web research, email/calendar integrations, and API tokens. Treat it like an admin console.
|
||||||
|
|
||||||
@@ -394,6 +402,16 @@ Key settings:
|
|||||||
| `CHROMADB_HOST` | `localhost` | ChromaDB host for vector memory. Docker overrides this to `chromadb`. |
|
| `CHROMADB_HOST` | `localhost` | ChromaDB host for vector memory. Docker overrides this to `chromadb`. |
|
||||||
| `CHROMADB_PORT` | `8100` | ChromaDB port for manual host runs. Docker overrides this to `8000`. |
|
| `CHROMADB_PORT` | `8100` | ChromaDB port for manual host runs. Docker overrides this to `8000`. |
|
||||||
| `EMBEDDING_URL` | -- | OpenAI-compatible embeddings endpoint |
|
| `EMBEDDING_URL` | -- | OpenAI-compatible embeddings endpoint |
|
||||||
|
| `ODYSSEUS_CHAT_UPLOAD_MAX_BYTES` | `10485760` | Chat/agent attachment cap in bytes. Raise for larger local PDFs or text documents. |
|
||||||
|
| `ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES` | `104857600` | Gallery image upload cap in bytes (100 MB). |
|
||||||
|
| `ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES` | `26214400` | Gallery transform input cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_MEMORY_IMPORT_MAX_BYTES` | `10485760` | Memory import file cap in bytes (10 MB). |
|
||||||
|
| `ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES` | `26214400` | Personal document upload cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES` | `26214400` | Email compose attachment cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_STT_MAX_AUDIO_BYTES` | `26214400` | Speech-to-text audio cap in bytes (25 MB). |
|
||||||
|
| `ODYSSEUS_ICS_MAX_BYTES` | `10485760` | Calendar `.ics` import cap in bytes (10 MB). |
|
||||||
|
|
||||||
|
All upload-limit vars are validated (must be a positive integer) and optional; an invalid value fails fast at startup.
|
||||||
|
|
||||||
### Built-in MCP servers (optional setup)
|
### Built-in MCP servers (optional setup)
|
||||||
|
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
# Core imports
|
# Core imports
|
||||||
from core.constants import (
|
from core.constants import (
|
||||||
BASE_DIR, STATIC_DIR, SESSIONS_FILE,
|
BASE_DIR, STATIC_DIR, SESSIONS_FILE,
|
||||||
REQUEST_TIMEOUT, OPENAI_API_KEY,
|
REQUEST_TIMEOUT, OPENAI_API_KEY, AUTH_FILE,
|
||||||
)
|
)
|
||||||
from core.database import SessionLocal, ApiToken
|
from core.database import SessionLocal, ApiToken
|
||||||
from core.middleware import SecurityHeadersMiddleware
|
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
||||||
from core.auth import AuthManager
|
from core.auth import AuthManager
|
||||||
from core.exceptions import (
|
from core.exceptions import (
|
||||||
SessionNotFoundError, InvalidFileUploadError,
|
SessionNotFoundError, InvalidFileUploadError,
|
||||||
@@ -64,6 +64,7 @@ from core.exceptions import (
|
|||||||
import bcrypt as _bcrypt
|
import bcrypt as _bcrypt
|
||||||
|
|
||||||
from src.app_helpers import abs_join
|
from src.app_helpers import abs_join
|
||||||
|
from src.generated_images import GENERATED_IMAGE_HEADERS, resolve_generated_image_path
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
# ========= LOGGING =========
|
# ========= LOGGING =========
|
||||||
@@ -252,6 +253,15 @@ if AUTH_ENABLED:
|
|||||||
class AuthMiddleware(BaseHTTPMiddleware):
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
# A genuine CORS preflight (OPTIONS + Access-Control-Request-Method)
|
||||||
|
# carries no credentials by design and must reach CORSMiddleware to be
|
||||||
|
# answered. AuthMiddleware is the outermost middleware, so gating the
|
||||||
|
# preflight on auth 401s it before CORS can respond -- which blocks
|
||||||
|
# every cross-origin browser/WebView client before the real request
|
||||||
|
# is sent. Let real preflights through (only OPTIONS w/ the ACRM
|
||||||
|
# header; never a credentialed request).
|
||||||
|
if is_cors_preflight(request.method, request.headers):
|
||||||
|
return await call_next(request)
|
||||||
if _is_auth_exempt(path):
|
if _is_auth_exempt(path):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
# In-process internal-tool token bypass. Used by the agent
|
# In-process internal-tool token bypass. Used by the agent
|
||||||
@@ -387,13 +397,7 @@ app.mount("/static", _RevalidatingStatic(directory="static"), name="static")
|
|||||||
@app.get("/api/generated-image/{filename}")
|
@app.get("/api/generated-image/{filename}")
|
||||||
async def serve_generated_image(filename: str, request: Request):
|
async def serve_generated_image(filename: str, request: Request):
|
||||||
"""Serve generated images from the data directory."""
|
"""Serve generated images from the data directory."""
|
||||||
from pathlib import Path
|
img_path = resolve_generated_image_path(filename)
|
||||||
import re
|
|
||||||
if not re.match(r'^[a-f0-9]{8,64}\.(png|jpg|jpeg|webp|gif|mp4|mov|webm|mkv|m4v)$', filename):
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
|
||||||
img_path = Path("data/generated_images") / filename
|
|
||||||
if not img_path.exists():
|
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
|
||||||
# SECURITY: filename is the only key, so anyone who knows / guesses a
|
# SECURITY: filename is the only key, so anyone who knows / guesses a
|
||||||
# 12-hex content hash could pull another user's image bytes. Require
|
# 12-hex content hash could pull another user's image bytes. Require
|
||||||
# auth and verify ownership via the gallery row (when one exists).
|
# auth and verify ownership via the gallery row (when one exists).
|
||||||
@@ -429,7 +433,7 @@ async def serve_generated_image(filename: str, request: Request):
|
|||||||
return FileResponse(
|
return FileResponse(
|
||||||
str(img_path),
|
str(img_path),
|
||||||
media_type=mime,
|
media_type=mime,
|
||||||
headers={"Cache-Control": "public, max-age=31536000, immutable"},
|
headers=GENERATED_IMAGE_HEADERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========= YOUTUBE INIT =========
|
# ========= YOUTUBE INIT =========
|
||||||
@@ -594,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
|
|||||||
from routes.copilot_routes import setup_copilot_routes
|
from routes.copilot_routes import setup_copilot_routes
|
||||||
app.include_router(setup_copilot_routes())
|
app.include_router(setup_copilot_routes())
|
||||||
|
|
||||||
|
# ChatGPT Subscription device-flow login
|
||||||
|
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
|
||||||
|
app.include_router(setup_chatgpt_subscription_routes())
|
||||||
|
|
||||||
# TTS
|
# TTS
|
||||||
from routes.tts_routes import setup_tts_routes
|
from routes.tts_routes import setup_tts_routes
|
||||||
app.include_router(setup_tts_routes(tts_service))
|
app.include_router(setup_tts_routes(tts_service))
|
||||||
@@ -789,6 +797,8 @@ async def serve_backgrounds(request: Request):
|
|||||||
|
|
||||||
@app.get("/login")
async def serve_login(request: Request):
    """Serve the login page, or redirect to the root when auth is disabled.

    With ``AUTH_ENABLED`` false the handler answers with a 302 redirect to
    "/" instead of rendering ``static/login.html``.
    """
    if AUTH_ENABLED:
        login_page = abs_join(BASE_DIR, "static/login.html")
        return _serve_html_with_nonce(request, login_page)
    return RedirectResponse(url="/", status_code=302)
|
||||||
|
|
||||||
@app.get("/api/version")
|
@app.get("/api/version")
|
||||||
@@ -948,7 +958,7 @@ async def _startup_event():
|
|||||||
owners = set()
|
owners = set()
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
auth_path = "data/auth.json"
|
auth_path = AUTH_FILE
|
||||||
with open(auth_path, encoding="utf-8") as f:
|
with open(auth_path, encoding="utf-8") as f:
|
||||||
users = _json.load(f).get("users", {})
|
users = _json.load(f).get("users", {})
|
||||||
owners.update(users.keys())
|
owners.update(users.keys())
|
||||||
@@ -995,7 +1005,7 @@ async def _startup_event():
|
|||||||
# does not make an existing library look empty after auth/account changes.
|
# does not make an existing library look empty after auth/account changes.
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
auth_path = "data/auth.json"
|
auth_path = AUTH_FILE
|
||||||
with open(auth_path, encoding="utf-8") as f:
|
with open(auth_path, encoding="utf-8") as f:
|
||||||
users = _json.load(f).get("users", {})
|
users = _json.load(f).get("users", {})
|
||||||
primary_owner = None
|
primary_owner = None
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import uuid
|
|||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
|
from src.constants import AUTH_FILE
|
||||||
|
|
||||||
PAIRING_VERSION = 1
|
PAIRING_VERSION = 1
|
||||||
COMPANION_SCOPE = "chat"
|
COMPANION_SCOPE = "chat"
|
||||||
|
|
||||||
@@ -61,7 +63,7 @@ def lan_ip_candidates() -> list[str]:
|
|||||||
def find_admin_user() -> str | None:
|
def find_admin_user() -> str | None:
|
||||||
"""Resolve an admin username from data/auth.json (schema uses is_admin),
|
"""Resolve an admin username from data/auth.json (schema uses is_admin),
|
||||||
falling back to the first user."""
|
falling back to the first user."""
|
||||||
auth_path = os.path.join("data", "auth.json")
|
auth_path = AUTH_FILE
|
||||||
try:
|
try:
|
||||||
with open(auth_path, "r", encoding="utf-8") as f:
|
with open(auth_path, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|||||||
+92
-62
@@ -30,14 +30,24 @@ DEFAULT_PRIVILEGES = {
|
|||||||
"can_manage_memory": True,
|
"can_manage_memory": True,
|
||||||
"max_messages_per_day": 0,
|
"max_messages_per_day": 0,
|
||||||
"allowed_models": [],
|
"allowed_models": [],
|
||||||
|
"allowed_models_restricted": False,
|
||||||
|
# Explicit "block every model" sentinel. An empty `allowed_models` list is
|
||||||
|
# ambiguous — it's also what gets sent when the admin clicks "[All]" — so
|
||||||
|
# we need a dedicated flag to express "this user may use no models at all"
|
||||||
|
# distinctly from "this user has no restriction".
|
||||||
|
"block_all_models": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Admins get everything
|
# Admins get everything
|
||||||
ADMIN_PRIVILEGES = {k: (True if isinstance(v, bool) else (0 if isinstance(v, int) else [])) for k, v in DEFAULT_PRIVILEGES.items()}
|
ADMIN_PRIVILEGES = {k: (True if isinstance(v, bool) else (0 if isinstance(v, int) else [])) for k, v in DEFAULT_PRIVILEGES.items()}
|
||||||
|
ADMIN_PRIVILEGES["allowed_models_restricted"] = False
|
||||||
|
# Admins must never be blocked from using models — the generic dict
|
||||||
|
# comprehension above flips every boolean default to True, which would be
|
||||||
|
# backwards for this sentinel.
|
||||||
|
ADMIN_PRIVILEGES["block_all_models"] = False
|
||||||
|
|
||||||
DEFAULT_AUTH_PATH = os.path.join(
|
from src.constants import AUTH_FILE
|
||||||
Path(__file__).parent.parent, "data", "auth.json"
|
DEFAULT_AUTH_PATH = AUTH_FILE
|
||||||
)
|
|
||||||
TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
||||||
|
|
||||||
# Usernames the auth + middleware layer reserve as internal "synthetic owner"
|
# Usernames the auth + middleware layer reserve as internal "synthetic owner"
|
||||||
@@ -76,6 +86,10 @@ class AuthManager:
|
|||||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||||
self._sessions_lock = threading.RLock()
|
self._sessions_lock = threading.RLock()
|
||||||
|
# Guards all mutations of self._config and the on-disk auth.json so
|
||||||
|
# concurrent create/delete/rename/privilege operations don't interleave
|
||||||
|
# and corrupt the user database.
|
||||||
|
self._config_lock = threading.Lock()
|
||||||
# Guards the first-run setup check-and-write so concurrent requests
|
# Guards the first-run setup check-and-write so concurrent requests
|
||||||
# cannot both observe is_configured==False and both create admin accounts.
|
# cannot both observe is_configured==False and both create admin accounts.
|
||||||
self._setup_lock = threading.Lock()
|
self._setup_lock = threading.Lock()
|
||||||
@@ -172,8 +186,9 @@ class AuthManager:
|
|||||||
|
|
||||||
@signup_enabled.setter
def signup_enabled(self, value: bool):
    """Store the ``signup_enabled`` flag in the auth config and save it.

    Args:
        value: New value written under the ``"signup_enabled"`` config key.
    """
    # Mutate and save under the config lock so this write cannot interleave
    # with other operations that take the same lock around self._config.
    with self._config_lock:
        self._config["signup_enabled"] = value
        self._save()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_configured(self) -> bool:
|
def is_configured(self) -> bool:
|
||||||
@@ -198,17 +213,18 @@ class AuthManager:
|
|||||||
if username in RESERVED_USERNAMES:
|
if username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to create reserved username '%s'", username)
|
logger.warning("Refused to create reserved username '%s'", username)
|
||||||
return False
|
return False
|
||||||
if username in self.users:
|
with self._config_lock:
|
||||||
return False
|
if username in self.users:
|
||||||
if "users" not in self._config:
|
return False
|
||||||
self._config["users"] = {}
|
if "users" not in self._config:
|
||||||
self._config["users"][username] = {
|
self._config["users"] = {}
|
||||||
"password_hash": _hash_password(password),
|
self._config["users"][username] = {
|
||||||
"created": time.time(),
|
"password_hash": _hash_password(password),
|
||||||
"is_admin": is_admin,
|
"created": time.time(),
|
||||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
"is_admin": is_admin,
|
||||||
}
|
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||||
self._save()
|
}
|
||||||
|
self._save()
|
||||||
logger.info(f"Created user '{username}' (admin={is_admin})")
|
logger.info(f"Created user '{username}' (admin={is_admin})")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -221,14 +237,15 @@ class AuthManager:
|
|||||||
their cookie expired naturally (default ~30 days).
|
their cookie expired naturally (default ~30 days).
|
||||||
"""
|
"""
|
||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if username not in self.users:
|
with self._config_lock:
|
||||||
return False
|
if username not in self.users:
|
||||||
if username == requesting_user:
|
return False
|
||||||
return False
|
if username == requesting_user:
|
||||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
return False
|
||||||
return False
|
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||||
del self._config["users"][username]
|
return False
|
||||||
self._save()
|
del self._config["users"][username]
|
||||||
|
self._save()
|
||||||
# Purge all sessions belonging to this user. validate_token doesn't
|
# Purge all sessions belonging to this user. validate_token doesn't
|
||||||
# cross-check `self.users`, so without this step a deleted user's
|
# cross-check `self.users`, so without this step a deleted user's
|
||||||
# cookie keeps authenticating.
|
# cookie keeps authenticating.
|
||||||
@@ -266,14 +283,15 @@ class AuthManager:
|
|||||||
if new_username in RESERVED_USERNAMES:
|
if new_username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||||
return False
|
return False
|
||||||
if old_username not in self.users:
|
with self._config_lock:
|
||||||
return False
|
if old_username not in self.users:
|
||||||
if new_username in self.users:
|
return False
|
||||||
return False
|
if new_username in self.users:
|
||||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
return False
|
||||||
return False
|
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
return False
|
||||||
self._save()
|
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||||
|
self._save()
|
||||||
|
|
||||||
renamed_sessions = 0
|
renamed_sessions = 0
|
||||||
with self._sessions_lock:
|
with self._sessions_lock:
|
||||||
@@ -311,17 +329,18 @@ class AuthManager:
|
|||||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
    """Merge recognised privilege values into a non-admin user's record.

    The username is normalised (stripped, lower-cased) first. Keys that are
    not present in ``DEFAULT_PRIVILEGES`` are ignored. The merged mapping is
    stored on the user and saved while holding the config lock.

    Returns:
        False when the user does not exist or is an admin (admin privileges
        are never modified here); True once the update has been saved.
    """
    username = username.strip().lower()
    with self._config_lock:
        known_users = self.users
        # Unknown users and admins are both rejected; admins always have
        # full access, so there is nothing to edit for them.
        if username not in known_users or known_users[username].get("is_admin"):
            return False
        merged = self.get_privileges(username)
        # Only allow known privilege keys
        for key in privileges:
            if key in DEFAULT_PRIVILEGES:
                merged[key] = privileges[key]
        self._config["users"][username]["privileges"] = merged
        self._save()
    logger.info(f"Updated privileges for '{username}': {merged}")
    return True
|
||||||
|
|
||||||
@@ -331,8 +350,9 @@ class AuthManager:
|
|||||||
return False
|
return False
|
||||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||||
return False
|
return False
|
||||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
with self._config_lock:
|
||||||
self._save()
|
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||||
|
self._save()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -350,8 +370,9 @@ class AuthManager:
|
|||||||
if username not in self.users:
|
if username not in self.users:
|
||||||
return None
|
return None
|
||||||
secret = pyotp.random_base32()
|
secret = pyotp.random_base32()
|
||||||
self._config["users"][username]["totp_secret_pending"] = secret
|
with self._config_lock:
|
||||||
self._save()
|
self._config["users"][username]["totp_secret_pending"] = secret
|
||||||
|
self._save()
|
||||||
return secret
|
return secret
|
||||||
|
|
||||||
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
||||||
@@ -370,13 +391,14 @@ class AuthManager:
|
|||||||
if not totp.verify(code, valid_window=1):
|
if not totp.verify(code, valid_window=1):
|
||||||
return False
|
return False
|
||||||
# Enable 2FA
|
# Enable 2FA
|
||||||
self._config["users"][username]["totp_secret"] = secret
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_enabled"] = True
|
self._config["users"][username]["totp_secret"] = secret
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username]["totp_enabled"] = True
|
||||||
# Generate backup codes
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
# Generate backup codes
|
||||||
self._config["users"][username]["totp_backup_codes"] = backup
|
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||||
self._save()
|
self._config["users"][username]["totp_backup_codes"] = backup
|
||||||
|
self._save()
|
||||||
logger.info(f"2FA enabled for '{username}'")
|
logger.info(f"2FA enabled for '{username}'")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -395,9 +417,10 @@ class AuthManager:
|
|||||||
# Check backup codes first
|
# Check backup codes first
|
||||||
backup = user.get("totp_backup_codes", [])
|
backup = user.get("totp_backup_codes", [])
|
||||||
if code in backup:
|
if code in backup:
|
||||||
backup.remove(code)
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_backup_codes"] = backup
|
backup.remove(code)
|
||||||
self._save()
|
self._config["users"][username]["totp_backup_codes"] = backup
|
||||||
|
self._save()
|
||||||
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
||||||
return True
|
return True
|
||||||
totp = pyotp.TOTP(secret)
|
totp = pyotp.TOTP(secret)
|
||||||
@@ -408,11 +431,12 @@ class AuthManager:
|
|||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if not self.verify_password(username, password):
|
if not self.verify_password(username, password):
|
||||||
return False
|
return False
|
||||||
self._config["users"][username].pop("totp_secret", None)
|
with self._config_lock:
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username].pop("totp_secret", None)
|
||||||
self._config["users"][username].pop("totp_backup_codes", None)
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
self._config["users"][username]["totp_enabled"] = False
|
self._config["users"][username].pop("totp_backup_codes", None)
|
||||||
self._save()
|
self._config["users"][username]["totp_enabled"] = False
|
||||||
|
self._save()
|
||||||
logger.info(f"2FA disabled for '{username}'")
|
logger.info(f"2FA disabled for '{username}'")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -431,6 +455,12 @@ class AuthManager:
|
|||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if not self.verify_password(username, password):
|
if not self.verify_password(username, password):
|
||||||
return None
|
return None
|
||||||
|
return self.create_session_trusted(username)
|
||||||
|
|
||||||
|
def create_session_trusted(self, username: str) -> str:
|
||||||
|
"""Issue a session token for an already-verified user.
|
||||||
|
Call only after verify_password (and TOTP if enabled) have passed."""
|
||||||
|
username = username.strip().lower()
|
||||||
token = secrets.token_hex(32)
|
token = secrets.token_hex(32)
|
||||||
with self._sessions_lock:
|
with self._sessions_lock:
|
||||||
self._sessions[token] = {
|
self._sessions[token] = {
|
||||||
|
|||||||
+11
-39
@@ -1,40 +1,12 @@
|
|||||||
# src/constants.py
|
# core/constants.py
|
||||||
"""Application-wide constants and configuration values."""
|
"""Backward-compatible shim — the single source of truth is src/constants.py.
|
||||||
import os
|
|
||||||
|
|
||||||
APP_VERSION = "0.9.1"
|
Historically there were two copies of this module (this one lagged behind at
|
||||||
|
APP_VERSION 0.9.1 and was missing the consolidated tool-output constants). To
|
||||||
# Base paths
|
kill the drift, this now simply re-exports everything from src.constants so
|
||||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
there is exactly one place that defines paths and reads ODYSSEUS_DATA_DIR.
|
||||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
internal_api_base() also lives in src.constants now and is re-exported here so
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
existing `from core.constants import internal_api_base` callers keep working.
|
||||||
|
"""
|
||||||
# Data file paths
|
from src.constants import * # noqa: F401,F403
|
||||||
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
from src.constants import internal_api_base # noqa: F401 (explicit: functions aren't covered by some linters' * checks)
|
||||||
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
|
||||||
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
|
||||||
PERSONAL_DIR = os.path.join(DATA_DIR, "personal_docs")
|
|
||||||
RUNBOOK_DIR = os.path.join(PERSONAL_DIR, "runbook")
|
|
||||||
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
|
||||||
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
|
||||||
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
|
||||||
|
|
||||||
# API Configuration
|
|
||||||
MAX_CONTEXT_MESSAGES = 90
|
|
||||||
REQUEST_TIMEOUT = 20
|
|
||||||
OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
|
||||||
|
|
||||||
# Environment variables with defaults
|
|
||||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
|
||||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
||||||
SEARXNG_INSTANCE = os.getenv('SEARXNG_INSTANCE', 'http://localhost:8080')
|
|
||||||
|
|
||||||
|
|
||||||
# Cleanup configuration
|
|
||||||
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "True").lower() == "true"
|
|
||||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "24"))
|
|
||||||
|
|
||||||
# Default parameters
|
|
||||||
DEFAULT_TEMPERATURE = 1.0
|
|
||||||
DEFAULT_MAX_TOKENS = 0
|
|
||||||
|
|||||||
+168
-7
@@ -29,8 +29,9 @@ class TimestampMixin:
|
|||||||
def updated_at(cls):
|
def updated_at(cls):
|
||||||
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
||||||
|
|
||||||
# Get database URL from environment, default to SQLite
|
# Get database URL from environment, default to SQLite in DATA_DIR
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./data/app.db")
|
from src.constants import DATA_DIR, AUTH_FILE, MEMORY_FILE, USER_PREFS_FILE, SETTINGS_FILE
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite:///{DATA_DIR}/app.db")
|
||||||
|
|
||||||
# Create engine
|
# Create engine
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@@ -360,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
|
|||||||
# is the historical default. When non-null, the model picker only shows
|
# is the historical default. When non-null, the model picker only shows
|
||||||
# the endpoint to that user (admins always see everything).
|
# the endpoint to that user (admins always see everything).
|
||||||
owner = Column(String, nullable=True, index=True)
|
owner = Column(String, nullable=True, index=True)
|
||||||
|
# Optional OAuth/session-backed credential row. Used by subscription-backed
|
||||||
|
# providers that need refresh tokens instead of a static API key.
|
||||||
|
provider_auth_id = Column(String, nullable=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderAuthSession(TimestampMixin, Base):
    """Encrypted OAuth/session credentials for refresh-aware model providers."""
    __tablename__ = "provider_auth_sessions"

    # String primary key (indexed).
    id = Column(String, primary_key=True, index=True)
    # Required provider identifier; indexed for lookup.
    provider = Column(String, nullable=False, index=True)
    # Optional owner; NULL is allowed. NOTE(review): presumably a username,
    # as with the `owner` columns on other models -- confirm.
    owner = Column(String, nullable=True, index=True)
    # Optional label for the credential row.
    label = Column(String, nullable=True)
    # Required base URL associated with this credential row.
    base_url = Column(String, nullable=False)
    # Token columns are declared with the EncryptedText column type rather
    # than plain String; both may be NULL.
    access_token = Column(EncryptedText, nullable=True)
    refresh_token = Column(EncryptedText, nullable=True)
    # Optional timestamp; NOTE(review): presumably the time of the most
    # recent token refresh -- confirm against the code that writes it.
    last_refresh = Column(DateTime, nullable=True)
    # Optional free-form auth mode string.
    auth_mode = Column(String, nullable=True)
|
||||||
|
|
||||||
class McpServer(TimestampMixin, Base):
|
class McpServer(TimestampMixin, Base):
|
||||||
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
||||||
@@ -800,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
|
|||||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_add_provider_auth_id_column():
    """Add provider_auth_id column to model_endpoints if it doesn't exist.

    Works directly on the SQLite file derived from ``DATABASE_URL``. The
    migration is a no-op when the database file does not exist yet, when the
    ``model_endpoints`` table has no columns reported (table missing), or
    when the column is already present. Failures are logged as warnings and
    never raised, so startup continues.
    """
    import sqlite3

    log = logging.getLogger(__name__)
    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.execute("PRAGMA table_info(model_endpoints)")
        # Row index 1 of PRAGMA table_info is the column name.
        columns = [row[1] for row in cursor.fetchall()]
        if columns and "provider_auth_id" not in columns:
            conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
            conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
            conn.commit()
            log.info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
    except Exception as e:
        log.warning(f"model_endpoints.provider_auth_id migration failed: {e}")
    finally:
        # Always release the connection, including when a statement above
        # raised; previously an exception skipped the close() call.
        if conn is not None:
            conn.close()
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_model_type_column():
|
def _migrate_add_model_type_column():
|
||||||
"""Add model_type column to model_endpoints if it doesn't exist."""
|
"""Add model_type column to model_endpoints if it doesn't exist."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@@ -1065,7 +1104,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
# fell through to "first user" every time.
|
# fell through to "first user" every time.
|
||||||
auth_path = os.path.join(os.path.dirname(DATABASE_URL.replace("sqlite:///", "")), "auth.json")
|
auth_path = os.path.join(os.path.dirname(DATABASE_URL.replace("sqlite:///", "")), "auth.json")
|
||||||
if not os.path.isabs(auth_path):
|
if not os.path.isabs(auth_path):
|
||||||
auth_path = os.path.join("data", "auth.json")
|
auth_path = AUTH_FILE
|
||||||
admin_user = None
|
admin_user = None
|
||||||
try:
|
try:
|
||||||
with open(auth_path, "r", encoding="utf-8") as f:
|
with open(auth_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1118,7 +1157,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
logger.warning(f"Legacy owner migration failed: {e}")
|
logger.warning(f"Legacy owner migration failed: {e}")
|
||||||
|
|
||||||
# Also migrate memory.json
|
# Also migrate memory.json
|
||||||
mem_path = os.path.join("data", "memory.json")
|
mem_path = MEMORY_FILE
|
||||||
try:
|
try:
|
||||||
if os.path.exists(mem_path):
|
if os.path.exists(mem_path):
|
||||||
with open(mem_path, "r", encoding="utf-8") as f:
|
with open(mem_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1136,7 +1175,7 @@ def _migrate_assign_legacy_owner():
|
|||||||
logger.warning(f"memory.json legacy migration failed: {e}")
|
logger.warning(f"memory.json legacy migration failed: {e}")
|
||||||
|
|
||||||
# Also migrate user_prefs.json to per-user format
|
# Also migrate user_prefs.json to per-user format
|
||||||
prefs_path = os.path.join("data", "user_prefs.json")
|
prefs_path = USER_PREFS_FILE
|
||||||
try:
|
try:
|
||||||
if os.path.exists(prefs_path):
|
if os.path.exists(prefs_path):
|
||||||
with open(prefs_path, "r", encoding="utf-8") as f:
|
with open(prefs_path, "r", encoding="utf-8") as f:
|
||||||
@@ -1458,7 +1497,11 @@ class CalendarCal(TimestampMixin, Base):
|
|||||||
owner = Column(String, nullable=True, index=True)
|
owner = Column(String, nullable=True, index=True)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
color = Column(String, default="#5b8abf")
|
color = Column(String, default="#5b8abf")
|
||||||
source = Column(String, default="local") # "local" or "timetree"
|
source = Column(String, default="local") # "local" or "caldav"
|
||||||
|
# UUID of the CalDAV account in user prefs that owns this calendar.
|
||||||
|
# NULL for local calendars and for CalDAV calendars created before
|
||||||
|
# multi-account support was added (treated as "use any configured account").
|
||||||
|
account_id = Column(String, nullable=True, index=True)
|
||||||
|
|
||||||
events = relationship("CalendarEvent", back_populates="calendar", cascade="all, delete-orphan")
|
events = relationship("CalendarEvent", back_populates="calendar", cascade="all, delete-orphan")
|
||||||
|
|
||||||
@@ -1526,7 +1569,7 @@ def _migrate_seed_email_account():
|
|||||||
import json as _json
|
import json as _json
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
settings_file = Path("data/settings.json")
|
settings_file = Path(SETTINGS_FILE)
|
||||||
if not settings_file.exists():
|
if not settings_file.exists():
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -1594,6 +1637,7 @@ def init_db():
|
|||||||
_migrate_add_model_type_column()
|
_migrate_add_model_type_column()
|
||||||
_migrate_add_model_endpoint_refresh_columns()
|
_migrate_add_model_endpoint_refresh_columns()
|
||||||
_migrate_add_model_endpoint_owner_column()
|
_migrate_add_model_endpoint_owner_column()
|
||||||
|
_migrate_add_provider_auth_id_column()
|
||||||
_migrate_add_supports_tools_column()
|
_migrate_add_supports_tools_column()
|
||||||
_migrate_add_task_run_model_column()
|
_migrate_add_task_run_model_column()
|
||||||
_migrate_add_owner_column()
|
_migrate_add_owner_column()
|
||||||
@@ -1622,9 +1666,105 @@ def init_db():
|
|||||||
_migrate_add_calendar_metadata()
|
_migrate_add_calendar_metadata()
|
||||||
_migrate_add_calendar_is_utc()
|
_migrate_add_calendar_is_utc()
|
||||||
_migrate_add_calendar_origin()
|
_migrate_add_calendar_origin()
|
||||||
|
_migrate_add_calendar_account_id()
|
||||||
|
_migrate_chat_messages_fts()
|
||||||
_migrate_encrypt_email_passwords()
|
_migrate_encrypt_email_passwords()
|
||||||
_migrate_encrypt_signatures()
|
_migrate_encrypt_signatures()
|
||||||
_migrate_encrypt_endpoint_keys()
|
_migrate_encrypt_endpoint_keys()
|
||||||
|
_migrate_backfill_task_folders()
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_backfill_task_folders():
    """Tag legacy task/research sessions with folder='Tasks'.

    Rows in ``sessions`` whose ``folder`` is NULL or empty and whose ``name``
    starts with ``[Task] `` or ``[Research] `` get ``folder = 'Tasks'``.
    Running it again is harmless because already-tagged rows no longer match.
    Does nothing when the ``sessions`` table has no ``folder`` column. Any
    error is logged as a warning rather than raised.
    """
    log = logging.getLogger(__name__)
    try:
        with engine.connect() as conn:
            # Row index 1 of PRAGMA table_info is the column name.
            column_names = {row[1] for row in conn.execute(text("PRAGMA table_info(sessions)"))}
            if "folder" in column_names:
                backfill = text(
                    "UPDATE sessions SET folder = 'Tasks' "
                    "WHERE (folder IS NULL OR folder = '') "
                    "AND (name LIKE '[Task] %' OR name LIKE '[Research] %')"
                )
                outcome = conn.execute(backfill)
                conn.commit()
                if outcome.rowcount:
                    log.info(
                        f"Backfilled folder='Tasks' on {outcome.rowcount} task/research sessions")
    except Exception as e:
        log.warning(f"task folder backfill: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_chat_messages_fts():
    """Create and backfill the session transcript FTS index for SQLite.

    Steps, all best-effort (failures are logged as warnings, never raised):

    1. Skip unless ``DATABASE_URL`` is a file-backed SQLite database that
       already exists on disk.
    2. Probe for FTS5 by creating and dropping a throwaway virtual table in
       the ``temp`` schema; skip the migration when the probe fails.
    3. Create the ``chat_messages_fts`` virtual table plus insert / delete /
       update triggers on ``chat_messages`` that mirror rows into it.
    4. Insert every ``chat_messages`` row that has no matching ``message_id``
       in the index yet.
    """
    # Imported locally like the other sqlite3-based migrations in this file,
    # so the function does not depend on a module-level import.
    import sqlite3

    log = logging.getLogger(__name__)

    if not DATABASE_URL.startswith("sqlite"):
        return

    db_path = DATABASE_URL.replace("sqlite:///", "")
    if db_path == ":memory:":
        return
    # sqlite3.connect() would silently create an empty database file here;
    # a migration must only ever touch a database that already exists.
    if not os.path.exists(db_path):
        return

    conn = None
    try:
        conn = sqlite3.connect(db_path)

        try:
            conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS temp._odysseus_fts5_probe USING fts5(content)")
            conn.execute("DROP TABLE IF EXISTS temp._odysseus_fts5_probe")
        except Exception as e:
            log.warning(f"chat_messages FTS migration skipped; FTS5 unavailable: {e}")
            return

        conn.executescript(
            """
            CREATE VIRTUAL TABLE IF NOT EXISTS chat_messages_fts USING fts5(
                content,
                message_id UNINDEXED,
                session_id UNINDEXED,
                role UNINDEXED
            );

            CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ai
            AFTER INSERT ON chat_messages BEGIN
                INSERT INTO chat_messages_fts(content, message_id, session_id, role)
                VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
            END;

            CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ad
            AFTER DELETE ON chat_messages BEGIN
                DELETE FROM chat_messages_fts WHERE message_id = old.id;
            END;

            CREATE TRIGGER IF NOT EXISTS chat_messages_fts_au
            AFTER UPDATE ON chat_messages BEGIN
                DELETE FROM chat_messages_fts WHERE message_id = old.id;
                INSERT INTO chat_messages_fts(content, message_id, session_id, role)
                VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
            END;
            """
        )
        # Backfill rows that predate the triggers.
        conn.execute(
            """
            INSERT INTO chat_messages_fts(content, message_id, session_id, role)
            SELECT COALESCE(cm.content, ''), cm.id, cm.session_id, cm.role
            FROM chat_messages cm
            WHERE NOT EXISTS (
                SELECT 1 FROM chat_messages_fts fts
                WHERE fts.message_id = cm.id
            )
            """
        )
        conn.commit()
    except Exception as e:
        log.warning(f"chat_messages FTS migration failed: {e}")
    finally:
        # conn stays None when connect() itself failed; guard explicitly
        # instead of relying on a swallowed AttributeError.
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error as e:
                log.debug(f"chat_messages FTS migration: close failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_email_smtp_security():
|
def _migrate_add_email_smtp_security():
|
||||||
@@ -1786,6 +1926,27 @@ def _migrate_add_calendar_origin():
|
|||||||
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_add_calendar_account_id():
    """Add an ``account_id`` TEXT column (and index) to ``calendars``.

    The column lets each CalDAV-backed calendar record which credential set
    (from caldav_accounts in user prefs) owns it. Idempotent: nothing is
    changed when the database file is missing, when the ``calendars`` table
    has no columns (i.e. does not exist yet), or when ``account_id`` is
    already present. Failures are logged as warnings and not raised.
    """
    import sqlite3

    log = logging.getLogger(__name__)
    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        # PRAGMA table_info rows carry the column name at index 1.
        columns = [row[1] for row in conn.execute("PRAGMA table_info(calendars)").fetchall()]
        if columns and "account_id" not in columns:
            conn.execute("ALTER TABLE calendars ADD COLUMN account_id TEXT")
            conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
            conn.commit()
            log.info("Migrated: added 'account_id' column to calendars")
    except Exception as e:
        log.warning(f"calendars.account_id migration failed: {e}")
    finally:
        # Always release the handle; previously a failing statement skipped
        # close() and leaked the connection.
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error as e:
                log.warning(f"calendars.account_id migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_calendar_metadata():
|
def _migrate_add_calendar_metadata():
|
||||||
"""Add importance/event_type/last_pinged columns to calendar_events table."""
|
"""Add importance/event_type/last_pinged columns to calendar_events table."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|||||||
@@ -17,6 +17,15 @@ INTERNAL_TOOL_TOKEN = os.environ.get("ODYSSEUS_INTERNAL_TOKEN") or secrets.token
|
|||||||
INTERNAL_TOOL_HEADER = "X-Odysseus-Internal-Token"
|
INTERNAL_TOOL_HEADER = "X-Odysseus-Internal-Token"
|
||||||
|
|
||||||
|
|
||||||
|
def is_cors_preflight(method: str, headers) -> bool:
    """Return True for a genuine CORS preflight request.

    A preflight is an OPTIONS request carrying the
    Access-Control-Request-Method header. Such requests are credential-less
    by design and must reach CORSMiddleware to be answered -- gating them on
    auth 401s the preflight and breaks every cross-origin browser/WebView
    client.

    Args:
        method: The HTTP method of the request.
        headers: Any iterable of header names (a framework headers object or
            a plain ``dict``). Names are compared case-insensitively, because
            HTTP field names are case-insensitive. ``None`` or an empty
            container is treated as "no headers".

    Pure so it can be unit-tested without standing up the app.
    """
    if method != "OPTIONS":
        return False
    # Iterate the names rather than using `in`, so a plain dict keyed
    # "Access-Control-Request-Method" is recognised as well.
    return any(
        isinstance(name, str) and name.lower() == "access-control-request-method"
        for name in (headers or ())
    )
|
||||||
|
|
||||||
|
|
||||||
def require_admin(request: Request):
|
def require_admin(request: Request):
|
||||||
"""Raise 403 if the current user isn't an admin.
|
"""Raise 403 if the current user isn't an admin.
|
||||||
Allows access when auth is explicitly disabled, or when the request carries
|
Allows access when auth is explicitly disabled, or when the request carries
|
||||||
@@ -58,11 +67,22 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Tool render endpoints are served inside iframes — allow framing by self
|
# Tool render endpoints are served inside iframes — allow framing by self
|
||||||
is_tool_render = path.startswith("/api/tools/") and path.endswith("/render")
|
is_tool_render = path.startswith("/api/tools/") and path.endswith("/render")
|
||||||
|
# PDF previews are embedded by the in-app document library. Keep the
|
||||||
|
# exception route-scoped so normal app pages remain unframeable.
|
||||||
|
is_document_pdf_preview = path.startswith("/api/document/") and path.endswith("/render-pdf")
|
||||||
# Visual report pages are self-contained HTML — need inline scripts + external images
|
# Visual report pages are self-contained HTML — need inline scripts + external images
|
||||||
is_report = path.startswith("/api/research/report/")
|
is_report = path.startswith("/api/research/report/")
|
||||||
|
|
||||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
response.headers["Referrer-Policy"] = "no-referrer"
|
response.headers["Referrer-Policy"] = "no-referrer"
|
||||||
|
response.headers["Permissions-Policy"] = "camera=(), microphone=(self), geolocation=()"
|
||||||
|
|
||||||
|
is_https = (
|
||||||
|
request.url.scheme == "https"
|
||||||
|
or request.headers.get("X-Forwarded-Proto") == "https"
|
||||||
|
)
|
||||||
|
if is_https:
|
||||||
|
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||||
|
|
||||||
if is_report:
|
if is_report:
|
||||||
response.headers["Content-Security-Policy"] = (
|
response.headers["Content-Security-Policy"] = (
|
||||||
@@ -79,6 +99,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
# sandbox="allow-scripts" attribute provides isolation.
|
# sandbox="allow-scripts" attribute provides isolation.
|
||||||
# Don't overwrite the route's own restrictive CSP either.
|
# Don't overwrite the route's own restrictive CSP either.
|
||||||
pass
|
pass
|
||||||
|
elif is_document_pdf_preview:
|
||||||
|
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||||
|
response.headers["Content-Security-Policy"] = (
|
||||||
|
"default-src 'none'; "
|
||||||
|
"frame-ancestors 'self'"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response.headers["X-Frame-Options"] = "DENY"
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
# NOTE: `style-src 'unsafe-inline'` is intentionally retained.
|
# NOTE: `style-src 'unsafe-inline'` is intentionally retained.
|
||||||
|
|||||||
+205
-3
@@ -18,10 +18,22 @@ import ntpath
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import platform
|
||||||
|
|
||||||
IS_WINDOWS = os.name == "nt"
|
IS_WINDOWS = os.name == "nt"
|
||||||
IS_POSIX = not IS_WINDOWS
|
IS_POSIX = not IS_WINDOWS
|
||||||
|
# Allows APFEL support and ARM-native binary recommendations on Apple Silicon Macs.
|
||||||
|
IS_APPLE_SILICON = (
|
||||||
|
IS_POSIX
|
||||||
|
and platform.system() == "Darwin"
|
||||||
|
and platform.machine().lower()
|
||||||
|
in {
|
||||||
|
"arm64",
|
||||||
|
"aarch64",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── File permissions ────────────────────────────────────────────────────────
|
# ── File permissions ────────────────────────────────────────────────────────
|
||||||
@@ -53,9 +65,8 @@ def detached_popen_kwargs() -> dict:
|
|||||||
and is detached from any console.
|
and is detached from any console.
|
||||||
"""
|
"""
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
flags = (
|
flags = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200) | getattr(
|
||||||
getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200)
|
subprocess, "DETACHED_PROCESS", 0x00000008
|
||||||
| getattr(subprocess, "DETACHED_PROCESS", 0x00000008)
|
|
||||||
)
|
)
|
||||||
return {"creationflags": flags}
|
return {"creationflags": flags}
|
||||||
return {"start_new_session": True}
|
return {"start_new_session": True}
|
||||||
@@ -150,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
|
|||||||
("usr", "bin", "bash.exe"),
|
("usr", "bin", "bash.exe"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Paths to add to the remote SSH probe command to find tools like nvidia-smi that may not be on PATH.
|
||||||
|
_SSH_PATH_MEMBERS = (
|
||||||
|
"/usr/bin",
|
||||||
|
"/usr/local/bin",
|
||||||
|
"/usr/local/cuda/bin",
|
||||||
|
"/usr/lib/wsl/lib"
|
||||||
|
)
|
||||||
|
# Fallback locations for nvidia-smi on WSL and other Linux distros where it may not be on PATH.
|
||||||
|
NVIDIA_PATH_CANDIDATES = (
|
||||||
|
"/usr/bin/nvidia-smi",
|
||||||
|
"/usr/local/bin/nvidia-smi",
|
||||||
|
"/usr/local/cuda/bin/nvidia-smi",
|
||||||
|
"/usr/lib/wsl/lib/nvidia-smi",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_path_override() -> str:
|
||||||
|
"""Build the PATH export snippet used for remote SSH shell probes."""
|
||||||
|
return f"export PATH=\"$PATH:{':'.join(_SSH_PATH_MEMBERS)}\"; "
|
||||||
|
|
||||||
|
|
||||||
|
SSH_PATH_OVERRIDE = _ssh_path_override()
|
||||||
|
|
||||||
|
|
||||||
def _windows_bash_fallbacks() -> List[str]:
|
def _windows_bash_fallbacks() -> List[str]:
|
||||||
roots: List[str] = []
|
roots: List[str] = []
|
||||||
@@ -180,6 +214,21 @@ def _is_windows_bash_stub(path: str) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def git_bash_path(path: str | Path) -> str:
    """Convert a path to POSIX style suitable for Git Bash on Windows.

    The path is rendered with forward slashes. On Windows, a leading drive
    letter is rewritten to the Git Bash mount form, e.g. ``C:\\path`` becomes
    ``/c/path``. On other platforms the forward-slash form is returned as is.

    Args:
        path: A path string or ``Path`` object.

    Returns:
        The converted path string.
    """
    posix_text = Path(path).as_posix()
    first = posix_text[:1]
    has_drive = (
        IS_WINDOWS
        and len(posix_text) >= 2
        and posix_text[1] == ":"
        # Only an ASCII letter can be a drive designator.
        and first.isascii()
        and first.isalpha()
    )
    if not has_drive:
        return posix_text

    remainder = posix_text[2:]
    # A drive-relative path ("C:foo") has no separator after the colon;
    # insert one so the result is "/c/foo" rather than the bogus "/cfoo".
    if remainder and not remainder.startswith("/"):
        remainder = "/" + remainder
    return f"/{first.lower()}{remainder}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def find_bash() -> Optional[str]:
|
def find_bash() -> Optional[str]:
|
||||||
"""Locate a real ``bash`` interpreter, or None.
|
"""Locate a real ``bash`` interpreter, or None.
|
||||||
|
|
||||||
@@ -242,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
|
|||||||
comspec = os.environ.get("ComSpec", "cmd.exe")
|
comspec = os.environ.get("ComSpec", "cmd.exe")
|
||||||
return [comspec, "/c", str(script_path)]
|
return [comspec, "/c", str(script_path)]
|
||||||
return ["sh", str(script_path)]
|
return ["sh", str(script_path)]
|
||||||
|
|
||||||
|
|
||||||
|
def is_wsl() -> bool:
    """Report whether the process runs inside Windows Subsystem for Linux.

    On Linux/POSIX platforms the contents of ``/proc/version`` are checked
    for the substring "microsoft" (case-insensitive). Any error while reading
    that file yields False.
    """
    import sys

    if not (sys.platform.startswith("linux") or os.name == "posix"):
        return False
    try:
        with open("/proc/version", "r") as version_file:
            kernel_banner = version_file.read()
    except Exception:
        return False
    return "microsoft" in kernel_banner.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def translate_path(path_str: str) -> str:
    """Translate a path (possibly a Windows path) to the current OS format.

    Under WSL, backslashes are turned into forward slashes and a leading
    drive letter (``C:\\foo`` or ``C:/foo``) is mapped to ``/mnt/c/foo``.
    In every other case the path is resolved with ``Path.resolve()``; if that
    raises, the (possibly slash-normalised) input is returned unchanged.
    Empty input is returned as is.
    """
    if not path_str:
        return path_str

    if is_wsl():
        import re

        path_str = path_str.replace("\\", "/")
        drive_match = re.match(r"^([a-zA-Z]):(.*)", path_str)
        if drive_match is not None:
            letter, tail = drive_match.groups()
            # Guarantee exactly one separator between the drive and the rest.
            separator = "" if tail.startswith("/") else "/"
            return f"/mnt/{letter.lower()}{separator}{tail}"

    try:
        resolved = Path(path_str).resolve()
    except Exception:
        return path_str
    return str(resolved)
|
||||||
|
|
||||||
|
|
||||||
|
def get_wsl_windows_user_profile() -> Optional[str]:
    """Return the Windows host user-profile directory as seen from WSL.

    Outside WSL the result is None. Inside WSL the function first asks the
    Windows host via PowerShell for ``$env:USERPROFILE`` and translates the
    answer with ``translate_path``. If that fails or yields nothing, it falls
    back to the first directory under ``/mnt/c/Users`` (in ``os.listdir``
    order) that is not one of the well-known system entries. Returns None
    when neither approach finds a directory.
    """
    if not is_wsl():
        return None

    # Preferred: ask Windows directly.
    try:
        reply = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
        profile = reply.stdout.strip() if reply.returncode == 0 else ""
        if profile:
            return translate_path(profile)
    except Exception:
        pass

    # Fallback: scan the mounted C: drive for a user directory.
    system_entries = ("All Users", "Default", "Default User", "desktop.ini", "Public")
    try:
        users_dir = "/mnt/c/Users"
        if os.path.isdir(users_dir):
            for entry in os.listdir(users_dir):
                if entry in system_entries:
                    continue
                candidate = os.path.join(users_dir, entry)
                if os.path.isdir(candidate):
                    return candidate
    except Exception:
        pass
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_exec_argv(
|
||||||
|
remote: str,
|
||||||
|
ssh_port: str | None,
|
||||||
|
*,
|
||||||
|
remote_cmd: str | None = None,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
strict_host_key_checking: bool | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build a consistent ssh argv for remote command execution."""
|
||||||
|
argv = ["ssh"]
|
||||||
|
if connect_timeout is not None:
|
||||||
|
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||||
|
if strict_host_key_checking is not None:
|
||||||
|
argv.extend(
|
||||||
|
[
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=yes"
|
||||||
|
if strict_host_key_checking
|
||||||
|
else "StrictHostKeyChecking=no",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if ssh_port and ssh_port != "22":
|
||||||
|
argv.extend(["-p", str(ssh_port)])
|
||||||
|
argv.append(remote)
|
||||||
|
if remote_cmd is not None:
|
||||||
|
argv.append(remote_cmd)
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_ssh_command(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    text: bool = True,
) -> subprocess.CompletedProcess:
    """Execute ``remote_cmd`` over ssh and return the completed process.

    The argv is assembled by ``_ssh_exec_argv``. stdout and stderr are
    captured, ``timeout`` is passed to ``subprocess.run``, and ``text``
    selects str (True) or bytes (False) output.
    """
    ssh_argv = _ssh_exec_argv(
        remote,
        ssh_port,
        remote_cmd=remote_cmd,
        connect_timeout=connect_timeout,
        strict_host_key_checking=strict_host_key_checking,
    )
    return subprocess.run(ssh_argv, timeout=timeout, capture_output=True, text=text)
|
||||||
|
|
||||||
|
|
||||||
|
def _windows_powershell_argv(
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
no_profile: bool = True,
|
||||||
|
non_interactive: bool = True,
|
||||||
|
) -> List[str]:
|
||||||
|
argv: List[str] = ["powershell.exe"]
|
||||||
|
if no_profile:
|
||||||
|
argv.append("-NoProfile")
|
||||||
|
if non_interactive:
|
||||||
|
argv.append("-NonInteractive")
|
||||||
|
argv.extend(["-Command", command])
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_wsl_windows_powershell(
    command: str,
    *,
    timeout: float = 5,
) -> subprocess.CompletedProcess[str]:
    """Run a PowerShell command on the Windows host from WSL.

    Output is captured as text and ``timeout`` is forwarded to
    ``subprocess.run``.

    Raises ``RuntimeError`` when called outside WSL.
    """
    if is_wsl():
        powershell_argv = _windows_powershell_argv(command)
        return subprocess.run(
            powershell_argv,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import logging
|
|||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal
|
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal, utcnow_naive
|
||||||
from .models import Session, ChatMessage
|
from .models import Session, ChatMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -619,7 +619,7 @@ class SessionManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
all_sessions = db.query(DbSession).all()
|
all_sessions = db.query(DbSession).all()
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=auto_archive_days)
|
cutoff_date = utcnow_naive() - timedelta(days=auto_archive_days)
|
||||||
|
|
||||||
for db_session in all_sessions:
|
for db_session in all_sessions:
|
||||||
stats['total_checked'] += 1
|
stats['total_checked'] += 1
|
||||||
|
|||||||
@@ -52,12 +52,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -51,12 +51,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -40,12 +40,14 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||||
|
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# Outlook / Office 365 email accounts
|
||||||
|
|
||||||
|
Odysseus email accounts currently use IMAP and SMTP with username/password
|
||||||
|
authentication. That works for providers that still allow app passwords or
|
||||||
|
mailbox passwords for IMAP/SMTP.
|
||||||
|
|
||||||
|
Microsoft disables basic authentication for Outlook and Microsoft 365 in most
|
||||||
|
modern accounts and tenants. If you try to add an Outlook account with a normal
|
||||||
|
password, Microsoft may return errors such as:
|
||||||
|
|
||||||
|
- `IMAP: AUTHENTICATE failed`
|
||||||
|
- `SMTP: 535 5.7.139 Authentication unsuccessful, basic authentication is disabled`
|
||||||
|
|
||||||
|
This is expected. Odysseus does not support Microsoft OAuth or Graph Mail yet,
|
||||||
|
so Outlook / Office 365 accounts cannot currently be added through the password
|
||||||
|
form. Use another email provider with app-password support, or track the future
|
||||||
|
Microsoft Graph OAuth integration.
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
"""
|
|
||||||
_common.py
|
|
||||||
|
|
||||||
Shared constants and helpers for built-in MCP servers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Output-size caps (characters) and per-tool timeouts shared by the built-in
# MCP servers.
MAX_OUTPUT_CHARS = 10_000
MAX_READ_CHARS = 20_000
SHELL_TIMEOUT = 60
PYTHON_TIMEOUT = 30
SEARCH_TIMEOUT = 30


def truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
    """Cut *text* down to *limit* characters, appending a note when cut.

    Non-string input is tolerated: ``None`` becomes the empty string and any
    other value is converted with ``str()``, so a tool response never fails
    in this helper. Text at or under the limit is returned unchanged.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    total = len(text)
    if total <= limit:
        return text
    return f"{text[:limit]}\n... (truncated, {total} chars total)"
|
|
||||||
+182
-125
@@ -31,13 +31,19 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|||||||
|
|
||||||
server = Server("email")
|
server = Server("email")
|
||||||
EMAIL_SOCKET_TIMEOUT = float(os.environ.get("EMAIL_SOCKET_TIMEOUT", "20"))
|
EMAIL_SOCKET_TIMEOUT = float(os.environ.get("EMAIL_SOCKET_TIMEOUT", "20"))
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, APP_DB, EMAIL_CACHE_DB, SETTINGS_FILE as _SETTINGS_FILE, MAIL_ATTACHMENTS_DIR
|
||||||
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
|
|
||||||
|
|
||||||
def _b(value) -> bytes:
|
def _b(value) -> bytes:
|
||||||
return str(value).encode()
|
return str(value).encode()
|
||||||
|
|
||||||
|
|
||||||
|
def _q(name: str) -> str:
|
||||||
|
"""Quote an IMAP mailbox name for commands that take mailbox args."""
|
||||||
|
return '"' + (name or "").replace("\\", "\\\\").replace('"', '\\"') + '"'
|
||||||
|
|
||||||
|
|
||||||
def _uid_fetch_rows(data) -> list:
|
def _uid_fetch_rows(data) -> list:
|
||||||
return [d for d in (data or []) if isinstance(d, bytes) and b"UID " in d]
|
return [d for d in (data or []) if isinstance(d, bytes) and b"UID " in d]
|
||||||
|
|
||||||
@@ -58,7 +64,7 @@ def _clean_header_value(value) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _db_path() -> Path:
|
def _db_path() -> Path:
|
||||||
return DATA_DIR / "app.db"
|
return Path(APP_DB)
|
||||||
|
|
||||||
|
|
||||||
def _list_accounts_raw() -> list:
|
def _list_accounts_raw() -> list:
|
||||||
@@ -157,7 +163,7 @@ def _load_config(account: str | None = None) -> dict:
|
|||||||
"trash_folder": os.environ.get("TRASH_FOLDER", "Trash"),
|
"trash_folder": os.environ.get("TRASH_FOLDER", "Trash"),
|
||||||
"cache_db": os.environ.get(
|
"cache_db": os.environ.get(
|
||||||
"EMAIL_CACHE_DB",
|
"EMAIL_CACHE_DB",
|
||||||
str(DATA_DIR / "email_cache.db"),
|
EMAIL_CACHE_DB,
|
||||||
),
|
),
|
||||||
"account_id": None,
|
"account_id": None,
|
||||||
"account_name": None,
|
"account_name": None,
|
||||||
@@ -199,7 +205,7 @@ def _load_config(account: str | None = None) -> dict:
|
|||||||
else:
|
else:
|
||||||
# Legacy fallback: settings.json flat keys
|
# Legacy fallback: settings.json flat keys
|
||||||
try:
|
try:
|
||||||
settings_path = Path(__file__).resolve().parent.parent / "data" / "settings.json"
|
settings_path = Path(_SETTINGS_FILE)
|
||||||
if settings_path.exists():
|
if settings_path.exists():
|
||||||
settings = json.loads(settings_path.read_text(encoding="utf-8"))
|
settings = json.loads(settings_path.read_text(encoding="utf-8"))
|
||||||
for key in (
|
for key in (
|
||||||
@@ -239,10 +245,27 @@ def _imap_connect(account: str | None = None):
|
|||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
if cfg["imap_starttls"]:
|
if cfg["imap_starttls"]:
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket on a rejected STARTTLS. (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
if getattr(conn, "sock", None):
|
if getattr(conn, "sock", None):
|
||||||
conn.sock.settimeout(EMAIL_SOCKET_TIMEOUT)
|
conn.sock.settimeout(EMAIL_SOCKET_TIMEOUT)
|
||||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
try:
|
||||||
|
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed login otherwise orphans the connected socket; close it
|
||||||
|
# before propagating (shutdown() is the pre-auth low-level close). (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -418,68 +441,71 @@ def _list_emails(folder="INBOX", max_results=20, unresponded_only=False,
|
|||||||
Pass unread_only=True and/or unresponded_only=True for attention scans.
|
Pass unread_only=True and/or unresponded_only=True for attention scans.
|
||||||
account selects mailbox (None = default).
|
account selects mailbox (None = default).
|
||||||
"""
|
"""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
select_status, _ = conn.select(folder, readonly=True)
|
try:
|
||||||
if select_status != "OK":
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
select_status, _ = conn.select(_q(folder), readonly=True)
|
||||||
raise ValueError(f"IMAP folder not found: {folder}")
|
if select_status != "OK":
|
||||||
|
raise ValueError(f"IMAP folder not found: {folder}")
|
||||||
|
|
||||||
if unread_only and unresponded_only:
|
if unread_only and unresponded_only:
|
||||||
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
||||||
elif unread_only:
|
elif unread_only:
|
||||||
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
||||||
elif unresponded_only:
|
elif unresponded_only:
|
||||||
# Was missing — unresponded_only=True (without unread_only) fell through
|
# Was missing — unresponded_only=True (without unread_only) fell through
|
||||||
# to "ALL" and returned answered mail too, despite the documented
|
# to "ALL" and returned answered mail too, despite the documented
|
||||||
# "emails without replies" behaviour.
|
# "emails without replies" behaviour.
|
||||||
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
||||||
else:
|
else:
|
||||||
# Include read too — IMAP search "ALL" returns the entire folder
|
# Include read too — IMAP search "ALL" returns the entire folder
|
||||||
status, data = conn.uid("SEARCH", None, "ALL")
|
status, data = conn.uid("SEARCH", None, "ALL")
|
||||||
|
|
||||||
if status != "OK" or not data[0]:
|
if status != "OK" or not data[0]:
|
||||||
conn.logout()
|
return []
|
||||||
return []
|
|
||||||
|
|
||||||
uid_list = list(reversed(data[0].split()))[:max_results]
|
uid_list = list(reversed(data[0].split()))[:max_results]
|
||||||
cache = _get_cached_summaries()
|
cache = _get_cached_summaries()
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for uid in uid_list:
|
for uid in uid_list:
|
||||||
try:
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
|
continue
|
||||||
|
raw_header = msg_data[0][1]
|
||||||
|
msg = email.message_from_bytes(raw_header)
|
||||||
|
|
||||||
|
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||||
|
sender = _decode_header(msg.get("From", "unknown"))
|
||||||
|
date_str = msg.get("Date", "")
|
||||||
|
message_id = msg.get("Message-ID", "")
|
||||||
|
|
||||||
|
# Parse sender name
|
||||||
|
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||||
|
sender_display = sender_name or sender_addr
|
||||||
|
|
||||||
|
# Check cache for summary
|
||||||
|
cached = cache.get(subject, {})
|
||||||
|
summary = cached.get("summary", "")
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"uid": uid.decode(),
|
||||||
|
"message_id": message_id,
|
||||||
|
"subject": subject,
|
||||||
|
"from": sender_display,
|
||||||
|
"from_address": sender_addr,
|
||||||
|
"date": date_str,
|
||||||
|
"summary": summary,
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
continue
|
continue
|
||||||
raw_header = msg_data[0][1]
|
|
||||||
msg = email.message_from_bytes(raw_header)
|
|
||||||
|
|
||||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
return results
|
||||||
sender = _decode_header(msg.get("From", "unknown"))
|
finally:
|
||||||
date_str = msg.get("Date", "")
|
if conn:
|
||||||
message_id = msg.get("Message-ID", "")
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
# Parse sender name
|
|
||||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
|
||||||
sender_display = sender_name or sender_addr
|
|
||||||
|
|
||||||
# Check cache for summary
|
|
||||||
cached = cache.get(subject, {})
|
|
||||||
summary = cached.get("summary", "")
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"uid": uid.decode(),
|
|
||||||
"message_id": message_id,
|
|
||||||
"subject": subject,
|
|
||||||
"from": sender_display,
|
|
||||||
"from_address": sender_addr,
|
|
||||||
"date": date_str,
|
|
||||||
"summary": summary,
|
|
||||||
})
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
conn.logout()
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _result_sort_time(result: dict) -> datetime:
|
def _result_sort_time(result: dict) -> datetime:
|
||||||
@@ -542,7 +568,7 @@ def _search_emails(query, folders=None, max_results=20, account=None):
|
|||||||
try:
|
try:
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
try:
|
try:
|
||||||
status, _ = conn.select(folder, readonly=True)
|
status, _ = conn.select(_q(folder), readonly=True)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
continue
|
continue
|
||||||
status, data = conn.uid("SEARCH", None, search_cmd)
|
status, data = conn.uid("SEARCH", None, search_cmd)
|
||||||
@@ -652,54 +678,55 @@ def _extract_attachment_to_disk(msg, index, target_dir):
|
|||||||
def _read_email(uid=None, message_id=None, folder="INBOX", account=None):
|
def _read_email(uid=None, message_id=None, folder="INBOX", account=None):
|
||||||
"""Read full email content by UID or message-ID. account = mailbox selector."""
|
"""Read full email content by UID or message-ID. account = mailbox selector."""
|
||||||
cfg = _load_config(account)
|
cfg = _load_config(account)
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
|
conn = _imap_connect(account)
|
||||||
|
conn.select(_q(folder), readonly=True)
|
||||||
|
|
||||||
if message_id and not uid:
|
if message_id and not uid:
|
||||||
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
||||||
if status != "OK" or not data[0]:
|
if status != "OK" or not data[0]:
|
||||||
conn.logout()
|
return {"error": f"Email not found with Message-ID: {message_id}"}
|
||||||
return {"error": f"Email not found with Message-ID: {message_id}"}
|
uid = data[0].split()[-1]
|
||||||
uid = data[0].split()[-1]
|
|
||||||
|
|
||||||
if not uid:
|
if not uid:
|
||||||
conn.logout()
|
return {"error": "No UID or Message-ID provided"}
|
||||||
return {"error": "No UID or Message-ID provided"}
|
|
||||||
|
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
conn.logout()
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
||||||
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
return {"error": f"Email not found with UID {uid}"}
|
||||||
conn.logout()
|
|
||||||
return {"error": f"Email not found with UID {uid}"}
|
|
||||||
|
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
msg = email.message_from_bytes(raw)
|
msg = email.message_from_bytes(raw)
|
||||||
|
|
||||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||||
sender = _decode_header(msg.get("From", "unknown"))
|
sender = _decode_header(msg.get("From", "unknown"))
|
||||||
date_str = msg.get("Date", "")
|
date_str = msg.get("Date", "")
|
||||||
message_id_header = msg.get("Message-ID", "")
|
message_id_header = msg.get("Message-ID", "")
|
||||||
body = _extract_text(msg)
|
body = _extract_text(msg)
|
||||||
attachments = _list_attachments_from_msg(msg)
|
attachments = _list_attachments_from_msg(msg)
|
||||||
|
|
||||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||||
|
|
||||||
conn.logout()
|
return {
|
||||||
return {
|
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||||
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
||||||
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
||||||
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
"account_id": cfg.get("account_id"),
|
||||||
"account_id": cfg.get("account_id"),
|
"message_id": message_id_header,
|
||||||
"message_id": message_id_header,
|
"subject": subject,
|
||||||
"subject": subject,
|
"from": sender_name or sender_addr,
|
||||||
"from": sender_name or sender_addr,
|
"from_address": sender_addr,
|
||||||
"from_address": sender_addr,
|
"date": date_str,
|
||||||
"date": date_str,
|
"body": body[:8000],
|
||||||
"body": body[:8000],
|
"attachments": attachments,
|
||||||
"attachments": attachments,
|
}
|
||||||
}
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
|
||||||
def _read_email_across_accounts(uid=None, message_id=None, folder="INBOX"):
|
def _read_email_across_accounts(uid=None, message_id=None, folder="INBOX"):
|
||||||
@@ -768,7 +795,16 @@ def _smtp_connect(account=None, cfg=None):
|
|||||||
port,
|
port,
|
||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket on a rejected STARTTLS. SMTP has
|
||||||
|
# no shutdown(); close() is the low-level socket close (no QUIT). (#3174)
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
elif security == "ssl":
|
elif security == "ssl":
|
||||||
conn = smtplib.SMTP_SSL(
|
conn = smtplib.SMTP_SSL(
|
||||||
cfg["smtp_host"],
|
cfg["smtp_host"],
|
||||||
@@ -782,7 +818,16 @@ def _smtp_connect(account=None, cfg=None):
|
|||||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||||
)
|
)
|
||||||
if cfg["smtp_user"] and cfg["smtp_password"]:
|
if cfg["smtp_user"] and cfg["smtp_password"]:
|
||||||
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
try:
|
||||||
|
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed login otherwise orphans the connected socket; close it
|
||||||
|
# before propagating (SMTP has no shutdown(); close() = socket close). (#3174)
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -827,7 +872,7 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
|||||||
imap = _imap_connect(send_account)
|
imap = _imap_connect(send_account)
|
||||||
try:
|
try:
|
||||||
sent_folder = _detect_sent_folder(imap)
|
sent_folder = _detect_sent_folder(imap)
|
||||||
append_st, append_data = imap.append(sent_folder, "\\Seen", None, msg.as_bytes())
|
append_st, append_data = imap.append(_q(sent_folder), "\\Seen", None, msg.as_bytes())
|
||||||
if append_st == "OK" and append_data:
|
if append_st == "OK" and append_data:
|
||||||
m = re.search(rb"APPENDUID\s+\d+\s+(\d+)", append_data[0] or b"")
|
m = re.search(rb"APPENDUID\s+\d+\s+(\d+)", append_data[0] or b"")
|
||||||
if m:
|
if m:
|
||||||
@@ -853,10 +898,15 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
|||||||
|
|
||||||
def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
||||||
"""Reply to an existing email by UID. Threads via In-Reply-To/References."""
|
"""Reply to an existing email by UID. Threads via In-Reply-To/References."""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
conn.select(_q(folder), readonly=True)
|
||||||
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
if status != "OK" or not msg_data or not msg_data[0]:
|
if status != "OK" or not msg_data or not msg_data[0]:
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
@@ -896,7 +946,7 @@ def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
|||||||
def _set_flag(uid, folder, flag, add=True, account=None):
|
def _set_flag(uid, folder, flag, add=True, account=None):
|
||||||
"""Add or remove an IMAP flag (e.g. \\Seen, \\Answered, \\Deleted)."""
|
"""Add or remove an IMAP flag (e.g. \\Seen, \\Answered, \\Deleted)."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
conn.select(folder)
|
conn.select(_q(folder))
|
||||||
op = "+FLAGS" if add else "-FLAGS"
|
op = "+FLAGS" if add else "-FLAGS"
|
||||||
try:
|
try:
|
||||||
status, data = conn.uid("STORE", _b(uid), op, flag)
|
status, data = conn.uid("STORE", _b(uid), op, flag)
|
||||||
@@ -918,7 +968,7 @@ def _bulk_set_flag(uids, folder, flag, add=True, account=None):
|
|||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
touched = []
|
touched = []
|
||||||
try:
|
try:
|
||||||
conn.select(folder)
|
conn.select(_q(folder))
|
||||||
op = "+FLAGS" if add else "-FLAGS"
|
op = "+FLAGS" if add else "-FLAGS"
|
||||||
msg_set = ",".join(str(u) for u in uids)
|
msg_set = ",".join(str(u) for u in uids)
|
||||||
try:
|
try:
|
||||||
@@ -945,7 +995,7 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
|||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
moved = 0
|
moved = 0
|
||||||
try:
|
try:
|
||||||
conn.select(source_folder)
|
conn.select(_q(source_folder))
|
||||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||||
msg_set = ",".join(str(u) for u in uids)
|
msg_set = ",".join(str(u) for u in uids)
|
||||||
try:
|
try:
|
||||||
@@ -956,10 +1006,11 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
|||||||
if not existing:
|
if not existing:
|
||||||
return 0
|
return 0
|
||||||
moved = len(existing)
|
moved = len(existing)
|
||||||
status, _ = conn.uid("MOVE", _b(msg_set), dest_folder)
|
dest_arg = _q(dest_folder)
|
||||||
|
status, _ = conn.uid("MOVE", _b(msg_set), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
# Fallback: UID copy + flag-delete + expunge
|
# Fallback: UID copy + flag-delete + expunge
|
||||||
status, _ = conn.uid("COPY", _b(msg_set), dest_folder)
|
status, _ = conn.uid("COPY", _b(msg_set), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return 0
|
return 0
|
||||||
status, _ = conn.uid("STORE", _b(msg_set), "+FLAGS", "\\Deleted")
|
status, _ = conn.uid("STORE", _b(msg_set), "+FLAGS", "\\Deleted")
|
||||||
@@ -976,7 +1027,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
|||||||
ALL, ANSWERED). Used to resolve selectors like all_unread → uids."""
|
ALL, ANSWERED). Used to resolve selectors like all_unread → uids."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
try:
|
try:
|
||||||
conn.select(folder, readonly=True)
|
conn.select(_q(folder), readonly=True)
|
||||||
status, data = conn.uid("SEARCH", None, criteria)
|
status, data = conn.uid("SEARCH", None, criteria)
|
||||||
if status != "OK" or not data or not data[0]:
|
if status != "OK" or not data or not data[0]:
|
||||||
return []
|
return []
|
||||||
@@ -988,7 +1039,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
|||||||
def _move_message(uid, source_folder, dest_folder, account=None, role: str = ""):
|
def _move_message(uid, source_folder, dest_folder, account=None, role: str = ""):
|
||||||
"""Move a message between folders. Tries IMAP MOVE, falls back to copy+delete."""
|
"""Move a message between folders. Tries IMAP MOVE, falls back to copy+delete."""
|
||||||
conn = _imap_connect(account)
|
conn = _imap_connect(account)
|
||||||
conn.select(source_folder)
|
conn.select(_q(source_folder))
|
||||||
try:
|
try:
|
||||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||||
try:
|
try:
|
||||||
@@ -998,11 +1049,12 @@ def _move_message(uid, source_folder, dest_folder, account=None, role: str = "")
|
|||||||
existing = _uid_fetch_rows(data)
|
existing = _uid_fetch_rows(data)
|
||||||
if status != "OK" or not existing:
|
if status != "OK" or not existing:
|
||||||
return False
|
return False
|
||||||
status, _ = conn.uid("MOVE", _b(uid), dest_folder)
|
dest_arg = _q(dest_folder)
|
||||||
|
status, _ = conn.uid("MOVE", _b(uid), dest_arg)
|
||||||
if status == "OK":
|
if status == "OK":
|
||||||
return True
|
return True
|
||||||
# Fallback: UID copy + delete
|
# Fallback: UID copy + delete
|
||||||
status, _ = conn.uid("COPY", _b(uid), dest_folder)
|
status, _ = conn.uid("COPY", _b(uid), dest_arg)
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return False
|
return False
|
||||||
status, _ = conn.uid("STORE", _b(uid), "+FLAGS", "\\Deleted")
|
status, _ = conn.uid("STORE", _b(uid), "+FLAGS", "\\Deleted")
|
||||||
@@ -1031,16 +1083,21 @@ def _archive_email(uid, folder="INBOX", account=None):
|
|||||||
|
|
||||||
def _download_attachment(uid, index, folder="INBOX", account=None):
|
def _download_attachment(uid, index, folder="INBOX", account=None):
|
||||||
"""Extract a specific attachment to disk and return its local path."""
|
"""Extract a specific attachment to disk and return its local path."""
|
||||||
conn = _imap_connect(account)
|
conn = None
|
||||||
conn.select(folder, readonly=True)
|
try:
|
||||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
conn = _imap_connect(account)
|
||||||
conn.logout()
|
conn.select(_q(folder), readonly=True)
|
||||||
|
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
if status != "OK":
|
if status != "OK":
|
||||||
return {"error": f"Failed to fetch email UID {uid}"}
|
return {"error": f"Failed to fetch email UID {uid}"}
|
||||||
raw = msg_data[0][1]
|
raw = msg_data[0][1]
|
||||||
msg = email.message_from_bytes(raw)
|
msg = email.message_from_bytes(raw)
|
||||||
|
|
||||||
target_dir = DATA_DIR / "mail-attachments" / f"{folder}_{uid}"
|
target_dir = Path(MAIL_ATTACHMENTS_DIR) / f"{folder}_{uid}"
|
||||||
filepath = _extract_attachment_to_disk(msg, index, target_dir)
|
filepath = _extract_attachment_to_disk(msg, index, target_dir)
|
||||||
if not filepath:
|
if not filepath:
|
||||||
return {"error": f"Attachment index {index} not found"}
|
return {"error": f"Attachment index {index} not found"}
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ from mcp.types import Tool, TextContent
|
|||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
server = Server("image_gen")
|
server = Server("image_gen")
|
||||||
|
|
||||||
|
|
||||||
@@ -115,14 +117,18 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
|
|
||||||
img = images[0]
|
img = images[0]
|
||||||
image_url = None
|
image_url = None
|
||||||
|
# Prefix the instance's public base URL (existing app_public_url setting) so the
|
||||||
|
# link is fully-qualified and clickable when the model echoes it. Empty = relative
|
||||||
|
# same-origin path (unchanged default).
|
||||||
|
_pub_base = (get_setting("app_public_url", "") or "").rstrip("/")
|
||||||
|
|
||||||
if img.get("b64_json"):
|
if img.get("b64_json"):
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
img_path.write_bytes(base64.b64decode(img["b64_json"]))
|
img_path.write_bytes(base64.b64decode(img["b64_json"]))
|
||||||
image_url = f"/api/generated-image/{filename}"
|
image_url = f"{_pub_base}/api/generated-image/{filename}"
|
||||||
|
|
||||||
# Save to gallery
|
# Save to gallery
|
||||||
try:
|
try:
|
||||||
@@ -146,7 +152,13 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
else:
|
else:
|
||||||
return [TextContent(type="text", text="Error: Unexpected image API response format")]
|
return [TextContent(type="text", text="Error: Unexpected image API response format")]
|
||||||
|
|
||||||
result = f"Generated image for: {prompt[:100]}\nimage_url: {image_url}\nmodel: {model_id}\nsize: {size}"
|
# "Direct link:" rather than an "image_url:" label — small models copied the
|
||||||
|
# label token ("image_url") into the link href, producing a broken link.
|
||||||
|
result = (
|
||||||
|
f"Generated image for: {prompt[:100]}\n"
|
||||||
|
f"Direct link: {image_url}\n"
|
||||||
|
f"model: {model_id}\nsize: {size}"
|
||||||
|
)
|
||||||
return [TextContent(type="text", text=result)]
|
return [TextContent(type="text", text=result)]
|
||||||
|
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
|
|||||||
Generated
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "odysseus-ui",
|
"name": "odysseus",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
|
|||||||
@@ -1,3 +1,18 @@
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
# Test-taxonomy markers added at collection time by tests/conftest.py. The
|
||||||
|
# stable area_* markers are declared here; the dynamic sub_<filename-token>
|
||||||
|
# markers are registered before collection by pytest_configure in
|
||||||
|
# tests/conftest.py, so unknown-mark warnings still flag genuine typos outside
|
||||||
|
# the taxonomy. See tests/_taxonomy.py and tests/README.md.
|
||||||
|
markers = [
|
||||||
|
"area_security: tests covering auth, owner-scope, SSRF, XSS, confinement, redaction",
|
||||||
|
"area_routes: tests covering HTTP route / API behavior",
|
||||||
|
"area_services: tests covering service-layer behavior (llm, cookbook, email, calendar, ...)",
|
||||||
|
"area_cli: tests covering CLI / script behavior",
|
||||||
|
"area_js: JavaScript / Node-backed tests",
|
||||||
|
"area_helpers: self-tests for the shared test helpers in tests/helpers/",
|
||||||
|
"area_unit: pure parser / utility tests that do not clearly belong elsewhere",
|
||||||
|
"area_uncategorized: tests not yet matched by the taxonomy (fallback)",
|
||||||
|
]
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from core.database import (
|
|||||||
CalendarEvent,
|
CalendarEvent,
|
||||||
CalendarCal,
|
CalendarCal,
|
||||||
)
|
)
|
||||||
from src.constants import DATA_DIR
|
from src.constants import DATA_DIR, SKILLS_DIR, SKILLS_FILE, GALLERY_DIR, GALLERY_UPLOADS_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
# Skills live as SKILL.md files under data/skills/. Drop
|
# Skills live as SKILL.md files under data/skills/. Drop
|
||||||
# the entire directory; the SkillsManager re-creates the
|
# the entire directory; the SkillsManager re-creates the
|
||||||
# tree on next write.
|
# tree on next write.
|
||||||
skills_dir = os.path.join(DATA_DIR, "skills")
|
skills_dir = SKILLS_DIR
|
||||||
count = 0
|
count = 0
|
||||||
if os.path.isdir(skills_dir):
|
if os.path.isdir(skills_dir):
|
||||||
# Count SKILL.md files for the response — quick walk.
|
# Count SKILL.md files for the response — quick walk.
|
||||||
@@ -115,7 +115,7 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
count += sum(1 for f in files if f == "SKILL.md")
|
count += sum(1 for f in files if f == "SKILL.md")
|
||||||
_rmtree_quiet(skills_dir)
|
_rmtree_quiet(skills_dir)
|
||||||
# Legacy fallback file
|
# Legacy fallback file
|
||||||
legacy = os.path.join(DATA_DIR, "skills.json")
|
legacy = SKILLS_FILE
|
||||||
if os.path.exists(legacy):
|
if os.path.exists(legacy):
|
||||||
try:
|
try:
|
||||||
os.remove(legacy)
|
os.remove(legacy)
|
||||||
@@ -151,8 +151,8 @@ def setup_admin_wipe_routes(session_manager):
|
|||||||
db.query(GalleryAlbum).delete()
|
db.query(GalleryAlbum).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
# Also drop the upload dir so disk doesn't keep orphans.
|
# Also drop the upload dir so disk doesn't keep orphans.
|
||||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
_rmtree_quiet(GALLERY_DIR)
|
||||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
|
_rmtree_quiet(GALLERY_UPLOADS_DIR)
|
||||||
return {"status": "deleted", "kind": kind, "count": count}
|
return {"status": "deleted", "kind": kind, "count": count}
|
||||||
|
|
||||||
if kind == "calendar":
|
if kind == "calendar":
|
||||||
|
|||||||
@@ -155,22 +155,30 @@ def setup_api_token_routes() -> APIRouter:
|
|||||||
payload = await request.json()
|
payload = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
payload = {}
|
payload = {}
|
||||||
scope_list = _normalize_scopes(payload.get("scopes"))
|
|
||||||
scopes_value = ",".join(scope_list)
|
|
||||||
with get_db_session() as db:
|
with get_db_session() as db:
|
||||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||||
if not token:
|
if not token:
|
||||||
raise HTTPException(404, "Token not found")
|
raise HTTPException(404, "Token not found")
|
||||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||||
token.scopes = scopes_value
|
# Only touch scopes when the caller actually sent them. A partial
|
||||||
|
# update such as a rename ({"name": ...} with no "scopes" key) must
|
||||||
|
# not silently reset the token to the default scope — that dropped
|
||||||
|
# every previously granted scope.
|
||||||
|
if "scopes" in payload:
|
||||||
|
token.scopes = ",".join(_normalize_scopes(payload.get("scopes")))
|
||||||
db.add(token)
|
db.add(token)
|
||||||
|
current_scopes = [
|
||||||
|
s.strip()
|
||||||
|
for s in (getattr(token, "scopes", "") or DEFAULT_SCOPES).split(",")
|
||||||
|
if s.strip()
|
||||||
|
]
|
||||||
response = {
|
response = {
|
||||||
"id": token_id,
|
"id": token_id,
|
||||||
"name": getattr(token, "name", ""),
|
"name": getattr(token, "name", ""),
|
||||||
"owner": getattr(token, "owner", None),
|
"owner": getattr(token, "owner", None),
|
||||||
"token_prefix": getattr(token, "token_prefix", ""),
|
"token_prefix": getattr(token, "token_prefix", ""),
|
||||||
"scopes": scope_list,
|
"scopes": current_scopes,
|
||||||
}
|
}
|
||||||
_invalidate_cache(request)
|
_invalidate_cache(request)
|
||||||
return response
|
return response
|
||||||
|
|||||||
+23
-4
@@ -131,10 +131,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
return {"ok": False, "requires_totp": True, "username": username}
|
return {"ok": False, "requires_totp": True, "username": username}
|
||||||
if not auth_manager.totp_verify(username, body.totp_code):
|
if not auth_manager.totp_verify(username, body.totp_code):
|
||||||
raise HTTPException(401, "Invalid 2FA code")
|
raise HTTPException(401, "Invalid 2FA code")
|
||||||
# All checks passed — create session
|
# All checks passed — create session (password already verified above)
|
||||||
token = await asyncio.to_thread(auth_manager.create_session, username, body.password)
|
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
|
||||||
if not token:
|
|
||||||
raise HTTPException(401, "Invalid credentials")
|
|
||||||
cookie_kwargs = dict(
|
cookie_kwargs = dict(
|
||||||
key=SESSION_COOKIE,
|
key=SESSION_COOKIE,
|
||||||
value=token,
|
value=token,
|
||||||
@@ -585,6 +583,27 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
|||||||
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
||||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
||||||
|
|
||||||
|
if preset == "discord_webhook":
|
||||||
|
import httpx
|
||||||
|
webhook_url = (integ.get("base_url") or "").strip()
|
||||||
|
if not webhook_url:
|
||||||
|
return {"ok": False, "message": "No webhook URL set — paste the full Discord webhook URL into the Base URL field."}
|
||||||
|
payload = {
|
||||||
|
"embeds": [{
|
||||||
|
"title": "Odysseus connectivity test",
|
||||||
|
"description": "If you see this, your Discord Webhook integration is wired up correctly.",
|
||||||
|
"color": 5793266,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||||
|
r = await client.post(webhook_url, json=payload)
|
||||||
|
if r.is_success:
|
||||||
|
return {"ok": True, "message": "Test embed sent — check your Discord channel to confirm it arrived."}
|
||||||
|
return {"ok": False, "message": f"Discord returned HTTP {r.status_code}: {r.text[:200]}"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"ok": False, "message": f"Request failed: {e}"[:400]}
|
||||||
|
|
||||||
# All other presets: GET against a known health endpoint.
|
# All other presets: GET against a known health endpoint.
|
||||||
# Fall back to detecting from name if preset is missing.
|
# Fall back to detecting from name if preset is missing.
|
||||||
health_paths = {
|
health_paths = {
|
||||||
|
|||||||
+56
-12
@@ -101,24 +101,68 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
|||||||
# ── Skills ──
|
# ── Skills ──
|
||||||
if "skills" in body and isinstance(body["skills"], list):
|
if "skills" in body and isinstance(body["skills"], list):
|
||||||
existing = skills_manager.load_all()
|
existing = skills_manager.load_all()
|
||||||
existing_ids = {s.get("id") for s in existing}
|
existing_names = {s.get("name") for s in existing if s.get("name")}
|
||||||
existing_titles = {s.get("title", "").strip().lower() for s in existing}
|
existing_ids = {s.get("id") for s in existing if s.get("id")}
|
||||||
|
existing_titles = {
|
||||||
|
(s.get("title") or s.get("description") or "").strip().lower()
|
||||||
|
for s in existing
|
||||||
|
}
|
||||||
added = 0
|
added = 0
|
||||||
for skill in body["skills"]:
|
for skill in body["skills"]:
|
||||||
if not isinstance(skill, dict) or not skill.get("title"):
|
if not isinstance(skill, dict):
|
||||||
continue
|
continue
|
||||||
# Skip if same id or same title already exists
|
title = (
|
||||||
if skill.get("id") in existing_ids:
|
skill.get("title") or skill.get("description")
|
||||||
|
or skill.get("name") or ""
|
||||||
|
).strip()
|
||||||
|
if not title:
|
||||||
continue
|
continue
|
||||||
if skill["title"].strip().lower() in existing_titles:
|
sid = skill.get("id") or skill.get("name")
|
||||||
|
if sid and sid in existing_ids:
|
||||||
continue
|
continue
|
||||||
if user and not skill.get("owner"):
|
nm = skill.get("name")
|
||||||
skill["owner"] = user
|
if nm and nm in existing_names:
|
||||||
existing.append(skill)
|
continue
|
||||||
existing_ids.add(skill.get("id"))
|
if title.lower() in existing_titles:
|
||||||
existing_titles.add(skill["title"].strip().lower())
|
continue
|
||||||
|
owner = skill.get("owner")
|
||||||
|
if user and not owner:
|
||||||
|
owner = user
|
||||||
|
# Skills live on disk as SKILL.md files; the old JSON-era
|
||||||
|
# skills_manager.save() no longer exists. Write each new skill
|
||||||
|
# via add_skill (source="user" skips auto-dedup — this is an
|
||||||
|
# explicit backup restore).
|
||||||
|
result = skills_manager.add_skill(
|
||||||
|
title=title,
|
||||||
|
name=skill.get("name"),
|
||||||
|
description=skill.get("description"),
|
||||||
|
problem=skill.get("problem", ""),
|
||||||
|
solution=skill.get("solution", ""),
|
||||||
|
steps=skill.get("steps"),
|
||||||
|
tags=skill.get("tags"),
|
||||||
|
source="user",
|
||||||
|
teacher_model=skill.get("teacher_model"),
|
||||||
|
confidence=skill.get("confidence", 0.8),
|
||||||
|
owner=owner,
|
||||||
|
category=skill.get("category", "general"),
|
||||||
|
when_to_use=skill.get("when_to_use"),
|
||||||
|
procedure=skill.get("procedure"),
|
||||||
|
pitfalls=skill.get("pitfalls"),
|
||||||
|
verification=skill.get("verification"),
|
||||||
|
platforms=skill.get("platforms"),
|
||||||
|
requires_toolsets=skill.get("requires_toolsets"),
|
||||||
|
fallback_for_toolsets=skill.get("fallback_for_toolsets"),
|
||||||
|
status=skill.get("status", "draft"),
|
||||||
|
version=skill.get("version", "1.0.0"),
|
||||||
|
)
|
||||||
|
if result.get("_deduped"):
|
||||||
|
continue
|
||||||
|
if result.get("name"):
|
||||||
|
existing_names.add(result["name"])
|
||||||
|
if result.get("id"):
|
||||||
|
existing_ids.add(result["id"])
|
||||||
|
existing_titles.add(title.lower())
|
||||||
added += 1
|
added += 1
|
||||||
skills_manager.save(existing)
|
|
||||||
imported.append(f"{added} skills")
|
imported.append(f"{added} skills")
|
||||||
|
|
||||||
# ── Presets ──
|
# ── Presets ──
|
||||||
|
|||||||
+254
-69
@@ -1,6 +1,7 @@
|
|||||||
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, date, timedelta
|
from datetime import datetime, date, timedelta
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
@@ -12,7 +13,7 @@ from dateutil.rrule import rrulestr
|
|||||||
|
|
||||||
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_user
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, ICS_MAX_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,6 +101,15 @@ def _ics_escape(text: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_ics_filename(name: str) -> str:
|
||||||
|
"""Return a conservative .ics filename safe for Content-Disposition."""
|
||||||
|
stem = name if isinstance(name, str) else ""
|
||||||
|
stem = re.sub(r"[^A-Za-z0-9._-]", "_", stem).strip("._-")
|
||||||
|
if not stem:
|
||||||
|
stem = "calendar"
|
||||||
|
return f"{stem[:128]}.ics"
|
||||||
|
|
||||||
|
|
||||||
def _resolve_base_uid(uid: str) -> str:
|
def _resolve_base_uid(uid: str) -> str:
|
||||||
"""Extract the base series UID from a compound occurrence UID.
|
"""Extract the base series UID from a compound occurrence UID.
|
||||||
|
|
||||||
@@ -248,6 +258,17 @@ def parse_due_for_user(s: str) -> str:
|
|||||||
if t is not None:
|
if t is not None:
|
||||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||||
|
|
||||||
|
# Time-first: "3pm today", "11pm today", "9am tomorrow"
|
||||||
|
m = _re.match(r'^(.+?)\s+(today|tonight|tomorrow|tmrw|yesterday)$', lower)
|
||||||
|
if m:
|
||||||
|
time_part, word = m.group(1).strip(), m.group(2)
|
||||||
|
base = today
|
||||||
|
if word in ("tomorrow", "tmrw"): base = today + _td(days=1)
|
||||||
|
elif word == "yesterday": base = today - _td(days=1)
|
||||||
|
t = _parse_time(time_part)
|
||||||
|
if t is not None:
|
||||||
|
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||||
|
|
||||||
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
||||||
if m:
|
if m:
|
||||||
n = int(m.group(1)); unit = m.group(2)
|
n = int(m.group(1)); unit = m.group(2)
|
||||||
@@ -399,7 +420,17 @@ def _parse_dt(s: str) -> datetime:
|
|||||||
# Last resort: dateutil's fuzzy parser
|
# Last resort: dateutil's fuzzy parser
|
||||||
try:
|
try:
|
||||||
from dateutil import parser as _du
|
from dateutil import parser as _du
|
||||||
return _du.parse(s)
|
parsed = _du.parse(s)
|
||||||
|
# Strip tz like every other return path above — this function's
|
||||||
|
# contract is naive datetimes (CalendarEvent.dtstart is naive). An
|
||||||
|
# offset-bearing non-ISO input (e.g. RFC-2822 "Mon, 05 Jan 2026
|
||||||
|
# 14:00:00 +0900") otherwise leaked tz-aware into the naive column and
|
||||||
|
# crashed read-back comparisons in _expand_rrule with "can't compare
|
||||||
|
# offset-naive and offset-aware datetimes".
|
||||||
|
if parsed.tzinfo is not None:
|
||||||
|
from datetime import timezone as _tz
|
||||||
|
return parsed.astimezone(_tz.utc).replace(tzinfo=None)
|
||||||
|
return parsed
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError(f"could not parse datetime: {s!r}")
|
raise ValueError(f"could not parse datetime: {s!r}")
|
||||||
|
|
||||||
@@ -440,6 +471,9 @@ def _event_to_dict(ev: CalendarEvent) -> dict:
|
|||||||
|
|
||||||
# ── Recurrence expansion ──
|
# ── Recurrence expansion ──
|
||||||
|
|
||||||
|
_RRULE_EXPANSION_LIMIT = 1000
|
||||||
|
|
||||||
|
|
||||||
def _expand_rrule(
|
def _expand_rrule(
|
||||||
ev: CalendarEvent, start: datetime, end: datetime
|
ev: CalendarEvent, start: datetime, end: datetime
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
@@ -462,6 +496,7 @@ def _expand_rrule(
|
|||||||
d = _event_to_dict(ev)
|
d = _event_to_dict(ev)
|
||||||
d["is_recurrence"] = False
|
d["is_recurrence"] = False
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
|
d["truncated"] = False
|
||||||
return [d]
|
return [d]
|
||||||
|
|
||||||
# Parse the rrule, applying it to the base dtstart.
|
# Parse the rrule, applying it to the base dtstart.
|
||||||
@@ -487,6 +522,7 @@ def _expand_rrule(
|
|||||||
d = _event_to_dict(ev)
|
d = _event_to_dict(ev)
|
||||||
d["is_recurrence"] = False
|
d["is_recurrence"] = False
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
|
d["truncated"] = False
|
||||||
# Malformed RRULE rows are fetched by the recurring SQL branch
|
# Malformed RRULE rows are fetched by the recurring SQL branch
|
||||||
# with only dtstart < end_dt — the base event may not actually
|
# with only dtstart < end_dt — the base event may not actually
|
||||||
# overlap the window. Only return if it does.
|
# overlap the window. Only return if it does.
|
||||||
@@ -499,22 +535,26 @@ def _expand_rrule(
|
|||||||
# (matching non-recurring overlap semantics: dtstart < end AND
|
# (matching non-recurring overlap semantics: dtstart < end AND
|
||||||
# dtend > start).
|
# dtend > start).
|
||||||
expand_start = start - duration
|
expand_start = start - duration
|
||||||
occurrences = rule.between(expand_start, end, inc=True)
|
|
||||||
if not occurrences:
|
|
||||||
return []
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
truncated = False
|
||||||
base = _event_to_dict(ev)
|
base = _event_to_dict(ev)
|
||||||
|
|
||||||
for occ_start in occurrences:
|
for occ_start in rule.xafter(expand_start, inc=True):
|
||||||
|
if occ_start >= end:
|
||||||
|
break
|
||||||
|
|
||||||
occ_end = occ_start + duration
|
occ_end = occ_start + duration
|
||||||
|
|
||||||
# Overlap filter: occurrence must intersect [start, end).
|
# Overlap filter: occurrence must intersect [start, end).
|
||||||
# This enforces exclusive-end semantics (occ_start >= end is
|
# This enforces exclusive-end semantics (occ_start >= end is
|
||||||
# excluded) and includes multi-day crossings (occ_end > start).
|
# excluded) and includes multi-day crossings (occ_end > start).
|
||||||
if occ_start >= end or occ_end <= start:
|
if occ_end <= start:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if len(results) >= _RRULE_EXPANSION_LIMIT:
|
||||||
|
truncated = True
|
||||||
|
break
|
||||||
|
|
||||||
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
||||||
if ev.all_day:
|
if ev.all_day:
|
||||||
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
||||||
@@ -525,6 +565,7 @@ def _expand_rrule(
|
|||||||
d["uid"] = occ_uid
|
d["uid"] = occ_uid
|
||||||
d["series_uid"] = ev.uid
|
d["series_uid"] = ev.uid
|
||||||
d["is_recurrence"] = True
|
d["is_recurrence"] = True
|
||||||
|
d["truncated"] = False
|
||||||
|
|
||||||
if ev.all_day:
|
if ev.all_day:
|
||||||
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
||||||
@@ -537,6 +578,10 @@ def _expand_rrule(
|
|||||||
|
|
||||||
results.append(d)
|
results.append(d)
|
||||||
|
|
||||||
|
if truncated:
|
||||||
|
for d in results:
|
||||||
|
d["truncated"] = True
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@@ -545,72 +590,178 @@ def _expand_rrule(
|
|||||||
def setup_calendar_routes() -> APIRouter:
|
def setup_calendar_routes() -> APIRouter:
|
||||||
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
||||||
|
|
||||||
# CalDAV connect form (Integrations → Calendar). Storage is local
|
# ── CalDAV multi-account helpers ─────────────────────────────────────────
|
||||||
# SQLite; sync (src/caldav_sync.py) pulls remote events into it on
|
|
||||||
# calendar open and periodically via the scheduler.
|
def _get_caldav_accounts(owner: str) -> list:
    """Load the CalDAV account list stored for ``owner``.

    Thin wrapper over ``src.caldav_sync._load_caldav_accounts``; the import
    is deferred until the helper is called.
    """
    from src.caldav_sync import _load_caldav_accounts as load_accounts

    stored_accounts = load_accounts(owner)
    return stored_accounts
|
||||||
|
|
||||||
|
def _save_caldav_accounts(owner: str, accounts: list) -> None:
    """Persist ``accounts`` as the ``caldav_accounts`` preference of ``owner``.

    Any ``caldav`` key left in the preferences (the single-account entry) is
    removed in the same write.
    """
    from routes.prefs_routes import _load_for_user, _save_for_user

    user_prefs = _load_for_user(owner)
    if not user_prefs:
        user_prefs = {}
    user_prefs["caldav_accounts"] = accounts
    # The single-account "caldav" entry is superseded by the list above.
    user_prefs.pop("caldav", None)
    _save_for_user(owner, user_prefs)
|
||||||
|
|
||||||
|
# ── CalDAV config routes (backward-compat single-account API) ────────────
|
||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
async def get_config(request: Request):
|
async def get_config(request: Request):
|
||||||
|
"""Legacy single-account endpoint — returns the first configured account."""
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
from routes.prefs_routes import _load_for_user
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
if not accounts:
|
||||||
caldav_password = cfg.get("password") or ""
|
return {"url": "", "username": "", "password": "", "has_password": False, "local": True}
|
||||||
if caldav_password:
|
first = accounts[0]
|
||||||
|
pw = first.get("password") or ""
|
||||||
|
has_pw = False
|
||||||
|
if pw:
|
||||||
try:
|
try:
|
||||||
from src.secret_storage import decrypt
|
from src.secret_storage import decrypt
|
||||||
caldav_password = decrypt(caldav_password)
|
has_pw = bool(decrypt(pw))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
has_pw = bool(pw)
|
||||||
# Surface url+username but never hand the password back to the
|
|
||||||
# client — saved-state UI shouldn't leak the credential.
|
|
||||||
return {
|
return {
|
||||||
"url": cfg.get("url", "") or "",
|
"url": first.get("url", "") or "",
|
||||||
"username": cfg.get("username", "") or "",
|
"username": first.get("username", "") or "",
|
||||||
"password": "",
|
"password": "",
|
||||||
"has_password": bool(caldav_password),
|
"has_password": has_pw,
|
||||||
"local": not bool(cfg.get("url")),
|
"local": not bool(first.get("url")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/config")
|
@router.post("/config")
|
||||||
async def save_config(request: Request):
|
async def save_config(request: Request):
|
||||||
|
"""Legacy single-account endpoint — upserts the first account."""
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
body = {}
|
body = {}
|
||||||
prefs = _load_for_user(owner) or {}
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = dict(prefs.get("caldav") or {})
|
|
||||||
# Empty url => clear the whole entry (treat as "remove integration").
|
|
||||||
if not (body.get("url") or "").strip():
|
if not (body.get("url") or "").strip():
|
||||||
prefs.pop("caldav", None)
|
_save_caldav_accounts(owner, [])
|
||||||
_save_for_user(owner, prefs)
|
|
||||||
return {"ok": True, "cleared": True}
|
return {"ok": True, "cleared": True}
|
||||||
from src.caldav_sync import validate_caldav_url
|
from src.caldav_sync import validate_caldav_url
|
||||||
try:
|
try:
|
||||||
cfg["url"] = validate_caldav_url(body.get("url", ""))
|
validated_url = validate_caldav_url(body.get("url", ""))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(400, str(e))
|
raise HTTPException(400, str(e))
|
||||||
cfg["username"] = (body.get("username") or "").strip()
|
if accounts:
|
||||||
# Preserve the stored password when the client sends an empty
|
acc = dict(accounts[0])
|
||||||
# one (edit form re-submitted without re-typing the password).
|
else:
|
||||||
# cfg already holds the existing (already-encrypted) password from
|
import uuid as _uuid
|
||||||
# prefs, so we only touch it when a new password is supplied —
|
acc = {"id": str(_uuid.uuid4()), "label": "CalDAV"}
|
||||||
# re-encrypting the stored value would double-encrypt it.
|
acc["url"] = validated_url
|
||||||
|
acc["username"] = (body.get("username") or "").strip()
|
||||||
if body.get("password"):
|
if body.get("password"):
|
||||||
from src.secret_storage import encrypt
|
from src.secret_storage import encrypt
|
||||||
cfg["password"] = encrypt(body["password"])
|
acc["password"] = encrypt(body["password"])
|
||||||
prefs["caldav"] = cfg
|
new_accounts = [acc] + (accounts[1:] if len(accounts) > 1 else [])
|
||||||
_save_for_user(owner, prefs)
|
_save_caldav_accounts(owner, new_accounts)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
# ── CalDAV multi-account CRUD ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/config/accounts")
async def list_caldav_accounts(request: Request):
    """List every configured CalDAV account without exposing passwords.

    Each entry carries ``id``, ``label`` (falling back to the URL), ``url``,
    ``username`` and a ``has_password`` flag instead of the credential.
    """
    owner = _require_user(request)

    def _password_present(stored) -> bool:
        # An empty/missing stored value means no password at all.
        if not stored:
            return False
        try:
            from src.secret_storage import decrypt
            return bool(decrypt(stored))
        except Exception:
            # Decryption failed: a non-empty stored value still counts as
            # a saved password.
            return bool(stored)

    listing = [
        {
            "id": entry.get("id", ""),
            "label": entry.get("label", "") or entry.get("url", ""),
            "url": entry.get("url", "") or "",
            "username": entry.get("username", "") or "",
            "has_password": _password_present(entry.get("password") or ""),
        }
        for entry in _get_caldav_accounts(owner)
    ]
    return {"accounts": listing}
|
||||||
|
|
||||||
|
@router.post("/config/accounts")
async def add_caldav_account(request: Request):
    """Add a new CalDAV account for the user returned by ``_require_user``.

    Expects a JSON object with ``url`` and ``password`` (both required) and
    optional ``label`` / ``username``. The password is stored encrypted and
    the account receives a freshly generated UUID.

    Returns:
        ``{"ok": True, "id": <new account id>}``.

    Raises:
        HTTPException: 400 when ``validate_caldav_url`` rejects the URL or
            when no password is supplied.
    """
    import uuid as _uuid
    owner = _require_user(request)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON that is not an object (list, string, null, ...) has no
    # ``.get``; treat it like an unparseable body instead of failing with an
    # AttributeError (HTTP 500).
    if not isinstance(body, dict):
        body = {}
    from src.caldav_sync import validate_caldav_url
    try:
        url = validate_caldav_url(body.get("url", ""))
    except ValueError as e:
        raise HTTPException(400, str(e)) from e
    if not body.get("password"):
        raise HTTPException(400, "Password is required")
    from src.secret_storage import encrypt
    new_acc = {
        "id": str(_uuid.uuid4()),
        "label": (body.get("label") or "").strip() or "CalDAV",
        "url": url,
        "username": (body.get("username") or "").strip(),
        "password": encrypt(body["password"]),
    }
    accounts = _get_caldav_accounts(owner)
    accounts.append(new_acc)
    _save_caldav_accounts(owner, accounts)
    return {"ok": True, "id": new_acc["id"]}
|
||||||
|
|
||||||
|
@router.put("/config/accounts/{account_id}")
async def update_caldav_account(account_id: str, request: Request):
    """Update an existing CalDAV account by id.

    Only fields present in the JSON body are changed: a non-empty ``url`` is
    re-validated, ``label`` / ``username`` are replaced when not ``None``, and
    a non-empty ``password`` is re-encrypted. Omitting the password keeps the
    stored one.

    Returns:
        ``{"ok": True}``.

    Raises:
        HTTPException: 404 when no account has ``account_id``; 400 when
            ``validate_caldav_url`` rejects the new URL.
    """
    owner = _require_user(request)
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON that is not an object (list, string, null, ...) has no
    # ``.get``; treat it like an empty update instead of failing with an
    # AttributeError (HTTP 500).
    if not isinstance(body, dict):
        body = {}
    accounts = _get_caldav_accounts(owner)
    idx = next((i for i, a in enumerate(accounts) if a.get("id") == account_id), None)
    if idx is None:
        raise HTTPException(404, "Account not found")
    # Work on a copy so the stored entry is only replaced once every field
    # has been accepted.
    acc = dict(accounts[idx])
    if body.get("url"):
        from src.caldav_sync import validate_caldav_url
        try:
            acc["url"] = validate_caldav_url(body["url"])
        except ValueError as e:
            raise HTTPException(400, str(e)) from e
    if body.get("label") is not None:
        acc["label"] = (body.get("label") or "").strip() or "CalDAV"
    if body.get("username") is not None:
        acc["username"] = (body.get("username") or "").strip()
    if body.get("password"):
        from src.secret_storage import encrypt
        acc["password"] = encrypt(body["password"])
    accounts[idx] = acc
    _save_caldav_accounts(owner, accounts)
    return {"ok": True}
|
||||||
|
|
||||||
|
@router.delete("/config/accounts/{account_id}")
async def delete_caldav_account(account_id: str, request: Request):
    """Delete the CalDAV account whose id equals ``account_id``.

    Responds with 404 when no stored account carries that id; otherwise the
    remaining accounts are saved and ``{"ok": True}`` is returned.
    """
    owner = _require_user(request)
    stored = _get_caldav_accounts(owner)

    remaining = []
    matched = False
    for entry in stored:
        if entry.get("id") == account_id:
            matched = True
        else:
            remaining.append(entry)

    if not matched:
        raise HTTPException(404, "Account not found")

    _save_caldav_accounts(owner, remaining)
    return {"ok": True}
|
||||||
|
|
||||||
@router.post("/test")
|
@router.post("/test")
|
||||||
async def test_connection(request: Request):
|
async def test_connection(request: Request):
|
||||||
"""Actually probe the configured CalDAV server with a PROPFIND
|
"""Probe a CalDAV server with a PROPFIND. Accepts an optional body:
|
||||||
request (the same handshake every CalDAV client uses). Accepts
|
{url, username, password} to test before saving, or {account_id} to
|
||||||
an optional {url, username, password} body so the user can test
|
test an already-saved account. Falls back to the first saved account
|
||||||
a configuration BEFORE saving it; falls back to the stored
|
when nothing is provided."""
|
||||||
creds otherwise. Returns {ok, error?} with a useful message on
|
|
||||||
failure (status code, auth issue, network error)."""
|
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
try:
|
try:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
@@ -620,19 +771,24 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
user = (body.get("username") or "").strip()
|
user = (body.get("username") or "").strip()
|
||||||
pw = body.get("password") or ""
|
pw = body.get("password") or ""
|
||||||
if not (url and user and pw):
|
if not (url and user and pw):
|
||||||
# Fall back to saved settings for this user.
|
# Look up a saved account: by id if supplied, else first account.
|
||||||
from routes.prefs_routes import _load_for_user
|
accounts = _get_caldav_accounts(owner)
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
acc = None
|
||||||
url = url or (cfg.get("url") or "")
|
if body.get("account_id"):
|
||||||
user = user or (cfg.get("username") or "")
|
acc = next((a for a in accounts if a.get("id") == body["account_id"]), None)
|
||||||
if not pw:
|
if acc is None and accounts:
|
||||||
pw = cfg.get("password") or ""
|
acc = accounts[0]
|
||||||
if pw:
|
if acc:
|
||||||
try:
|
url = url or (acc.get("url") or "")
|
||||||
from src.secret_storage import decrypt
|
user = user or (acc.get("username") or "")
|
||||||
pw = decrypt(pw)
|
if not pw:
|
||||||
except Exception:
|
pw = acc.get("password") or ""
|
||||||
pass
|
if pw:
|
||||||
|
try:
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
pw = decrypt(pw)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if not (url and user and pw):
|
if not (url and user and pw):
|
||||||
return {"ok": False, "error": "Missing URL, username, or password"}
|
return {"ok": False, "error": "Missing URL, username, or password"}
|
||||||
from src.caldav_sync import validate_caldav_url
|
from src.caldav_sync import validate_caldav_url
|
||||||
@@ -695,6 +851,28 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
from src.caldav_sync import sync_caldav
|
from src.caldav_sync import sync_caldav
|
||||||
return await sync_caldav(owner)
|
return await sync_caldav(owner)
|
||||||
|
|
||||||
|
@router.delete("/calendars/{cal_id}")
async def delete_calendar(cal_id: str, request: Request):
    """Delete the calendar ``cal_id`` when it is owned by the requesting user.

    Returns ``{"ok": True}`` on success, 404 when no calendar with that id
    belongs to the user, and 500 on any other failure (which is logged).
    """
    owner = _require_user(request)
    session = SessionLocal()
    try:
        # Filtering on owner as well as id keeps users from deleting
        # calendars that belong to someone else.
        lookup = session.query(CalendarCal).filter(
            CalendarCal.id == cal_id,
            CalendarCal.owner == owner,
        )
        target = lookup.first()
        if not target:
            raise HTTPException(404, "Calendar not found")
        session.delete(target)
        session.commit()
        return {"ok": True}
    except HTTPException:
        # Let the 404 above through unchanged.
        raise
    except Exception as exc:
        logger.error("Failed to delete calendar %s: %s", cal_id, exc)
        raise HTTPException(500, "Failed to delete calendar")
    finally:
        session.close()
|
||||||
|
|
||||||
@router.get("/calendars")
|
@router.get("/calendars")
|
||||||
async def list_calendars(request: Request):
|
async def list_calendars(request: Request):
|
||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
@@ -703,7 +881,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
_ensure_default_calendar(db, owner)
|
_ensure_default_calendar(db, owner)
|
||||||
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
||||||
return {"calendars": [
|
return {"calendars": [
|
||||||
{"name": c.name, "href": c.id, "color": c.color}
|
{"name": c.name, "href": c.id, "color": c.color, "source": c.source}
|
||||||
for c in cals
|
for c in cals
|
||||||
]}
|
]}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -766,8 +944,12 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
||||||
|
|
||||||
# Sort by occurrence start time for consistent frontend ordering.
|
# Sort by occurrence start time for consistent frontend ordering.
|
||||||
|
truncated = any(e.get("truncated") for e in expanded)
|
||||||
expanded.sort(key=lambda d: d["dtstart"])
|
expanded.sort(key=lambda d: d["dtstart"])
|
||||||
return {"events": expanded}
|
response: dict = {"events": expanded}
|
||||||
|
if truncated:
|
||||||
|
response["truncated"] = True
|
||||||
|
return response
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -988,9 +1170,9 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# 10 MB hard cap on ICS upload. Loading the whole file into memory is
|
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
|
||||||
# unavoidable with python-icalendar, so an unbounded upload would OOM.
|
# file into memory is unavoidable with python-icalendar, so an unbounded
|
||||||
_ICS_MAX_BYTES = 10 * 1024 * 1024
|
# upload would OOM.
|
||||||
|
|
||||||
@router.post("/import")
|
@router.post("/import")
|
||||||
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
||||||
@@ -1000,7 +1182,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
owner = _require_user(request)
|
owner = _require_user(request)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
content = await read_upload_limited(file, _ICS_MAX_BYTES, "ICS file")
|
content = await read_upload_limited(file, ICS_MAX_BYTES, "ICS file")
|
||||||
try:
|
try:
|
||||||
cal_data = iCal.from_ical(content)
|
cal_data = iCal.from_ical(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1168,11 +1350,14 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
lines.append("END:VCALENDAR")
|
lines.append("END:VCALENDAR")
|
||||||
|
|
||||||
ics_data = "\r\n".join(lines)
|
ics_data = "\r\n".join(lines)
|
||||||
safe_name = cal.name.replace(" ", "_").replace("/", "_")
|
download_name = _safe_ics_filename(cal.name)
|
||||||
return Response(
|
return Response(
|
||||||
content=ics_data,
|
content=ics_data,
|
||||||
media_type="text/calendar",
|
media_type="text/calendar",
|
||||||
headers={"Content-Disposition": f'attachment; filename="{safe_name}.ics"'},
|
headers={
|
||||||
|
"Content-Disposition": f'attachment; filename="{download_name}"',
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
@@ -1194,7 +1379,7 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
||||||
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
||||||
"""
|
"""
|
||||||
_require_user(request)
|
owner = _require_user(request)
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from src.text_helpers import strip_think
|
from src.text_helpers import strip_think
|
||||||
@@ -1220,9 +1405,9 @@ def setup_calendar_routes() -> APIRouter:
|
|||||||
if tz_hint:
|
if tz_hint:
|
||||||
set_user_tz_name(tz_hint)
|
set_user_tz_name(tz_hint)
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return {"ok": False, "error": "No LLM endpoint configured"}
|
return {"ok": False, "error": "No LLM endpoint configured"}
|
||||||
|
|
||||||
|
|||||||
+130
-38
@@ -75,7 +75,7 @@ def _enforce_chat_privileges(request, sess) -> None:
|
|||||||
allowlist, or HTTPException(429) if the user has hit their daily message
|
allowlist, or HTTPException(429) if the user has hit their daily message
|
||||||
cap. No-op for unauthenticated callers or when auth_manager is absent
|
cap. No-op for unauthenticated callers or when auth_manager is absent
|
||||||
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
||||||
which means empty allowed_models / zero cap → no-op for them.
|
which means unrestricted allowed_models / zero cap -> no-op for them.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
@@ -88,8 +88,18 @@ def _enforce_chat_privileges(request, sess) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
privs = auth_manager.get_privileges(user) or {}
|
privs = auth_manager.get_privileges(user) or {}
|
||||||
allowed = privs.get("allowed_models") or []
|
|
||||||
if allowed and sess.model and sess.model not in allowed:
|
# Explicit "block everything" sentinel takes precedence over the
|
||||||
|
# allowlist — it's the only way to distinguish "user clicked [None]"
|
||||||
|
# (block all) from "user clicked [All]" (no restriction), since both
|
||||||
|
# otherwise produce an empty `allowed_models` list.
|
||||||
|
if privs.get("block_all_models"):
|
||||||
|
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||||
|
|
||||||
|
allowed_raw = privs.get("allowed_models")
|
||||||
|
allowed = allowed_raw if isinstance(allowed_raw, list) else []
|
||||||
|
restricted = bool(privs.get("allowed_models_restricted")) or bool(allowed)
|
||||||
|
if restricted and sess.model and sess.model not in allowed:
|
||||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||||
|
|
||||||
cap = int(privs.get("max_messages_per_day") or 0)
|
cap = int(privs.get("max_messages_per_day") or 0)
|
||||||
@@ -194,14 +204,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||||
"""
|
"""
|
||||||
import requests as _req
|
import requests as _req
|
||||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
build_models_url,
|
||||||
|
normalize_base,
|
||||||
|
resolve_endpoint_runtime,
|
||||||
|
)
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
|
||||||
current_url = sess.endpoint_url or ""
|
current_url = sess.endpoint_url or ""
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(
|
q = db.query(ModelEndpoint).filter(
|
||||||
ModelEndpoint.is_enabled == True
|
ModelEndpoint.is_enabled == True
|
||||||
).all()
|
)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -210,26 +232,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
# Skip current endpoint
|
# Skip current endpoint
|
||||||
if current_url and base in current_url:
|
if current_url and base in current_url:
|
||||||
continue
|
continue
|
||||||
# Quick ping
|
|
||||||
ping_url = build_models_url(base)
|
|
||||||
headers = build_headers(ep.api_key, base)
|
|
||||||
try:
|
try:
|
||||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
r.raise_for_status()
|
except Exception:
|
||||||
data = r.json()
|
continue
|
||||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
ping_url = build_models_url(base)
|
||||||
if not models:
|
headers = build_headers(api_key, base)
|
||||||
models = [
|
try:
|
||||||
m.get("name") or m.get("model")
|
if ping_url:
|
||||||
for m in (data.get("models") or [])
|
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||||
if m.get("name") or m.get("model")
|
r.raise_for_status()
|
||||||
]
|
data = r.json()
|
||||||
|
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
|
if not models:
|
||||||
|
models = [
|
||||||
|
m.get("name") or m.get("model")
|
||||||
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
models = json.loads(ep.cached_models or "[]")
|
||||||
if not models:
|
if not models:
|
||||||
continue
|
continue
|
||||||
# Found a working endpoint — update session
|
# Found a working endpoint — update session
|
||||||
new_model = models[0]
|
new_model = models[0]
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
new_headers = build_headers(ep.api_key, base)
|
new_headers = build_headers(api_key, base)
|
||||||
|
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||||
|
|
||||||
sess.model = new_model
|
sess.model = new_model
|
||||||
sess.endpoint_url = chat_url
|
sess.endpoint_url = chat_url
|
||||||
@@ -241,7 +270,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||||
"model": new_model,
|
"model": new_model,
|
||||||
"endpoint_url": chat_url,
|
"endpoint_url": chat_url,
|
||||||
"headers": json.dumps(new_headers),
|
"headers": persisted_headers,
|
||||||
})
|
})
|
||||||
_db.commit()
|
_db.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -275,11 +304,16 @@ def extract_preset(chat_handler, preset_id) -> PresetInfo:
|
|||||||
async def preprocess(
|
async def preprocess(
|
||||||
chat_handler, message, att_ids, sess,
|
chat_handler, message, att_ids, sess,
|
||||||
auto_opened_docs: Optional[list] = None,
|
auto_opened_docs: Optional[list] = None,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> PreprocessedMessage:
|
) -> PreprocessedMessage:
|
||||||
"""Run chat_handler.preprocess_message and wrap the result."""
|
"""Run chat_handler.preprocess_message and wrap the result."""
|
||||||
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
||||||
await chat_handler.preprocess_message(
|
await chat_handler.preprocess_message(
|
||||||
message, att_ids, sess, auto_opened_docs=auto_opened_docs
|
message,
|
||||||
|
att_ids,
|
||||||
|
sess,
|
||||||
|
auto_opened_docs=auto_opened_docs,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return PreprocessedMessage(
|
return PreprocessedMessage(
|
||||||
@@ -329,16 +363,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _has_auth_keys(headers) -> bool:
|
||||||
|
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||||
|
return isinstance(headers, dict) and any(
|
||||||
|
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
try:
|
||||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
)
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||||
if has_auth:
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
has_auth = _has_auth_keys(sess.headers)
|
||||||
|
if has_auth and not is_chatgpt_subscription:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import build_headers, normalize_base
|
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||||
@@ -354,10 +398,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
|||||||
for ep in q.all():
|
for ep in q.all():
|
||||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||||
continue
|
continue
|
||||||
if not ep.api_key:
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||||
|
return
|
||||||
|
if not api_key:
|
||||||
|
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||||
|
return
|
||||||
|
sess.headers = build_headers(api_key, base)
|
||||||
|
if is_chatgpt_subscription:
|
||||||
|
# The bearer is short-lived and re-resolved per request, so it
|
||||||
|
# stays request-local and is never written to the plaintext
|
||||||
|
# sessions.headers column. Proactively strip any bearer an
|
||||||
|
# older code path may have persisted so it does not linger.
|
||||||
|
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||||
|
stored = stale_q.first()
|
||||||
|
if stored is not None and _has_auth_keys(stored.headers):
|
||||||
|
stale_q.update({"headers": {}})
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||||
|
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||||
return
|
return
|
||||||
base = normalize_base(ep.base_url or "")
|
|
||||||
sess.headers = build_headers(ep.api_key, base)
|
|
||||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
if owner:
|
if owner:
|
||||||
update_q = update_q.filter(DBSession.owner == owner)
|
update_q = update_q.filter(DBSession.owner == owner)
|
||||||
@@ -401,7 +465,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
try:
|
try:
|
||||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||||
@@ -448,6 +517,7 @@ async def build_chat_context(
|
|||||||
webhook_manager=None,
|
webhook_manager=None,
|
||||||
use_enhanced_message: bool = False,
|
use_enhanced_message: bool = False,
|
||||||
agent_mode: bool = False,
|
agent_mode: bool = False,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> ChatContext:
|
) -> ChatContext:
|
||||||
"""Build the full context (preface + messages) for an LLM call.
|
"""Build the full context (preface + messages) for an LLM call.
|
||||||
|
|
||||||
@@ -465,6 +535,7 @@ async def build_chat_context(
|
|||||||
preprocessed = await preprocess(
|
preprocessed = await preprocess(
|
||||||
chat_handler, message, att_ids or [], sess,
|
chat_handler, message, att_ids or [], sess,
|
||||||
auto_opened_docs=auto_opened_docs,
|
auto_opened_docs=auto_opened_docs,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add user message to history
|
# Add user message to history
|
||||||
@@ -483,6 +554,9 @@ async def build_chat_context(
|
|||||||
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
||||||
# When off, the "Available skills" index is not added to the prompt.
|
# When off, the "Available skills" index is not added to the prompt.
|
||||||
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
||||||
|
if not allow_tool_preprocessing:
|
||||||
|
mem_enabled = False
|
||||||
|
skills_enabled = False
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
||||||
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
||||||
@@ -490,11 +564,11 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Use RAG?
|
# Use RAG?
|
||||||
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
||||||
if incognito:
|
if incognito or not allow_tool_preprocessing:
|
||||||
use_rag_val = False
|
use_rag_val = False
|
||||||
|
|
||||||
# If pre-fetched search context was provided (compare mode), skip live web search
|
# If pre-fetched search context was provided (compare mode), skip live web search
|
||||||
skip_web = bool(search_context)
|
skip_web = bool(search_context) or not allow_tool_preprocessing
|
||||||
|
|
||||||
# Build context preface
|
# Build context preface
|
||||||
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
||||||
@@ -521,7 +595,7 @@ async def build_chat_context(
|
|||||||
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
||||||
|
|
||||||
# Inject pre-fetched search context (compare mode)
|
# Inject pre-fetched search context (compare mode)
|
||||||
if search_context:
|
if search_context and allow_tool_preprocessing:
|
||||||
preface.append(untrusted_context_message("prefetched search context", search_context))
|
preface.append(untrusted_context_message("prefetched search context", search_context))
|
||||||
|
|
||||||
# YouTube transcripts
|
# YouTube transcripts
|
||||||
@@ -530,7 +604,11 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||||
# re-hit slow local /models endpoints on every participant turn.
|
# re-hit slow local /models endpoints on every participant turn.
|
||||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||||
|
sess.endpoint_url,
|
||||||
|
sess.model,
|
||||||
|
owner=getattr(sess, "owner", None),
|
||||||
|
)
|
||||||
if norm:
|
if norm:
|
||||||
sess.model = norm
|
sess.model = norm
|
||||||
|
|
||||||
@@ -539,7 +617,7 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Auto-compact
|
# Auto-compact
|
||||||
messages, context_length, was_compacted = await maybe_compact(
|
messages, context_length, was_compacted = await maybe_compact(
|
||||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||||
)
|
)
|
||||||
messages = trim_for_context(messages, context_length)
|
messages = trim_for_context(messages, context_length)
|
||||||
|
|
||||||
@@ -772,7 +850,19 @@ def save_assistant_response(
|
|||||||
):
|
):
|
||||||
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
||||||
md = dict(last_metrics) if last_metrics else {}
|
md = dict(last_metrics) if last_metrics else {}
|
||||||
md["model"] = sess.model
|
def _model_value(value) -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
value = str(value)
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
requested_model = _model_value(md.get("requested_model") or md.get("selected_model") or getattr(sess, "model", ""))
|
||||||
|
actual_model = _model_value(md.get("model") or md.get("actual_model") or requested_model)
|
||||||
|
if requested_model:
|
||||||
|
md["requested_model"] = requested_model
|
||||||
|
if actual_model:
|
||||||
|
md["model"] = actual_model
|
||||||
if character_name:
|
if character_name:
|
||||||
md["character_name"] = character_name
|
md["character_name"] = character_name
|
||||||
if web_sources:
|
if web_sources:
|
||||||
@@ -841,12 +931,13 @@ def run_post_response_tasks(
|
|||||||
skills_manager=None,
|
skills_manager=None,
|
||||||
owner: str = None,
|
owner: str = None,
|
||||||
extract_skills: bool = True,
|
extract_skills: bool = True,
|
||||||
|
allow_background_extraction: bool = True,
|
||||||
):
|
):
|
||||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||||
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
if allow_background_extraction and not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||||
from services.memory.memory_extractor import extract_and_store
|
from services.memory.memory_extractor import extract_and_store
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||||
@@ -873,6 +964,7 @@ def run_post_response_tasks(
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
extract_skills
|
extract_skills
|
||||||
|
and allow_background_extraction
|
||||||
and auto_skills_enabled
|
and auto_skills_enabled
|
||||||
and not incognito
|
and not incognito
|
||||||
and not compare_mode
|
and not compare_mode
|
||||||
|
|||||||
+206
-72
@@ -20,6 +20,7 @@ from src import agent_runs
|
|||||||
from src.model_context import estimate_tokens
|
from src.model_context import estimate_tokens
|
||||||
from src.chat_helpers import coerce_message_and_session
|
from src.chat_helpers import coerce_message_and_session
|
||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||||
|
from src.session_search import search_session_messages
|
||||||
from src.prompt_security import untrusted_context_message
|
from src.prompt_security import untrusted_context_message
|
||||||
from core.exceptions import SessionNotFoundError
|
from core.exceptions import SessionNotFoundError
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
@@ -39,6 +40,7 @@ from routes.chat_helpers import (
|
|||||||
_enforce_chat_privileges,
|
_enforce_chat_privileges,
|
||||||
)
|
)
|
||||||
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
||||||
|
from src.tool_policy import build_effective_tool_policy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -167,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
Covers the window between endpoint setup and the first chat send: the
|
Covers the window between endpoint setup and the first chat send: the
|
||||||
picker showed a model in the dropdown but the session record never got
|
picker showed a model in the dropdown but the session record never got
|
||||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||||
Without this, we'd POST the upstream with model="" and get a generic
|
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||||
401/503 instead of using the model the user already picked.
|
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||||
|
|
||||||
Returns True iff sess.model was repaired.
|
|
||||||
"""
|
"""
|
||||||
if getattr(sess, "model", None):
|
current_model = (getattr(sess, "model", "") or "").strip()
|
||||||
return False
|
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
if current_model:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Prefer the endpoint whose base URL matches the session — we know the
|
# Prefer the endpoint whose base URL matches the session — we know the
|
||||||
@@ -192,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
break
|
break
|
||||||
if not ep:
|
if not ep:
|
||||||
return False
|
return False
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||||
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
try:
|
try:
|
||||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||||
except Exception:
|
except Exception:
|
||||||
cached = []
|
cached = []
|
||||||
if not cached:
|
if not cached:
|
||||||
|
visible = []
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
return False
|
return False
|
||||||
try:
|
if is_chatgpt_subscription:
|
||||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
live_models = []
|
||||||
except Exception:
|
if getattr(ep, "provider_auth_id", None):
|
||||||
visible = cached
|
try:
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
if api_key:
|
||||||
|
live_models = fetch_available_models(api_key)
|
||||||
|
if live_models:
|
||||||
|
ep.cached_models = json.dumps(live_models)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
live_models = []
|
||||||
|
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||||
|
# Cached rows are only trusted above to avoid revalidating a model
|
||||||
|
# that is already present in the visible picker list.
|
||||||
|
cached = live_models
|
||||||
|
if not cached:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
|
return False
|
||||||
if not visible:
|
if not visible:
|
||||||
return False
|
return False
|
||||||
model = visible[0]
|
model = visible[0]
|
||||||
@@ -211,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
# Persist so the next request, websocket reconnect, or page reload
|
# Persist so the next request, websocket reconnect, or page reload
|
||||||
# picks up the same model (we'd otherwise re-pick on every send
|
# picks up the same model (we'd otherwise re-pick on every send
|
||||||
# and silently switch on the user if the cached order shifts).
|
# and silently switch on the user if the cached order shifts).
|
||||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||||
|
db_session = db_session_q.first()
|
||||||
if db_session:
|
if db_session:
|
||||||
db_session.model = model
|
db_session.model = model
|
||||||
db_session.updated_at = datetime.utcnow()
|
db_session.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
sess.model = model
|
sess.model = model
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
"Recovered session model for %s — picked %r from endpoint %s",
|
||||||
session_id, model, ep.id,
|
session_id, model, ep.id,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
@@ -304,8 +351,13 @@ def setup_chat_routes(
|
|||||||
# non-streaming path can't be used to bypass).
|
# non-streaming path can't be used to bypass).
|
||||||
_enforce_chat_privileges(request, sess)
|
_enforce_chat_privileges(request, sess)
|
||||||
|
|
||||||
|
tool_policy = build_effective_tool_policy(last_user_message=message)
|
||||||
|
allow_tool_preprocessing = not tool_policy.block_all_tool_calls
|
||||||
|
|
||||||
# Inline memory command
|
# Inline memory command
|
||||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
memory_response = None
|
||||||
|
if not tool_policy.blocks("manage_memory"):
|
||||||
|
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||||
if memory_response:
|
if memory_response:
|
||||||
return {"response": memory_response}
|
return {"response": memory_response}
|
||||||
|
|
||||||
@@ -319,10 +371,15 @@ def setup_chat_routes(
|
|||||||
use_web=use_web,
|
use_web=use_web,
|
||||||
time_filter=time_filter,
|
time_filter=time_filter,
|
||||||
webhook_manager=webhook_manager,
|
webhook_manager=webhook_manager,
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Research injection
|
# Research injection
|
||||||
if use_research:
|
research_blocked_by_policy = (
|
||||||
|
tool_policy.blocks("trigger_research")
|
||||||
|
or tool_policy.blocks("manage_research")
|
||||||
|
)
|
||||||
|
if use_research and not research_blocked_by_policy:
|
||||||
try:
|
try:
|
||||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||||
research_ctx = await research_handler.call_research_service(
|
research_ctx = await research_handler.call_research_service(
|
||||||
@@ -357,6 +414,7 @@ def setup_chat_routes(
|
|||||||
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||||
character_name=ctx.preset.character_name,
|
character_name=ctx.preset.character_name,
|
||||||
owner=ctx.user,
|
owner=ctx.user,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"response": reply}
|
return {"response": reply}
|
||||||
@@ -394,6 +452,7 @@ def setup_chat_routes(
|
|||||||
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
||||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||||
|
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
|
||||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||||
# it's a real directory; ignore (no confinement) otherwise.
|
# it's a real directory; ignore (no confinement) otherwise.
|
||||||
@@ -401,6 +460,17 @@ def setup_chat_routes(
|
|||||||
if workspace:
|
if workspace:
|
||||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||||
|
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||||
|
if plan_mode:
|
||||||
|
chat_mode = "agent"
|
||||||
|
# An approved plan being EXECUTED: the frontend sends the checklist back
|
||||||
|
# on each turn so we can pin it in context. This way a long plan on a
|
||||||
|
# weak model survives history truncation — the agent can always re-read
|
||||||
|
# the plan. Ignored while still proposing (plan_mode on). Capped so a
|
||||||
|
# huge plan can't blow the prompt.
|
||||||
|
approved_plan = ""
|
||||||
|
if not plan_mode:
|
||||||
|
approved_plan = (form_data.get("approved_plan") or "").strip()[:8192]
|
||||||
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
||||||
# below). Skill extraction should only learn from real agent sessions,
|
# below). Skill extraction should only learn from real agent sessions,
|
||||||
# not chats we quietly promoted for a notes/calendar intent.
|
# not chats we quietly promoted for a notes/calendar intent.
|
||||||
@@ -479,11 +549,6 @@ def setup_chat_routes(
|
|||||||
do_research = True
|
do_research = True
|
||||||
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
||||||
|
|
||||||
# Persist session mode (research > agent > chat)
|
|
||||||
_effective_mode = 'research' if do_research else (chat_mode or 'chat')
|
|
||||||
if _effective_mode in ('agent', 'research', 'chat'):
|
|
||||||
set_session_mode(session, _effective_mode)
|
|
||||||
|
|
||||||
att_ids = []
|
att_ids = []
|
||||||
if body and isinstance(body.get("attachments"), list):
|
if body and isinstance(body.get("attachments"), list):
|
||||||
att_ids = [str(x) for x in body["attachments"]]
|
att_ids = [str(x) for x in body["attachments"]]
|
||||||
@@ -494,6 +559,10 @@ def setup_chat_routes(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
||||||
|
pre_context_tool_policy = build_effective_tool_policy(
|
||||||
|
last_user_message=message,
|
||||||
|
)
|
||||||
|
allow_tool_preprocessing = not pre_context_tool_policy.block_all_tool_calls
|
||||||
|
|
||||||
# Build shared context (stream path uses enhanced_message for context preface)
|
# Build shared context (stream path uses enhanced_message for context preface)
|
||||||
ctx = await build_chat_context(
|
ctx = await build_chat_context(
|
||||||
@@ -515,6 +584,7 @@ def setup_chat_routes(
|
|||||||
# manage_skills (agent mode). In plain chat or incognito the
|
# manage_skills (agent mode). In plain chat or incognito the
|
||||||
# index would be useless / unwanted noise.
|
# index would be useless / unwanted noise.
|
||||||
agent_mode=(chat_mode == "agent"),
|
agent_mode=(chat_mode == "agent"),
|
||||||
|
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
_research_flags = {"do": do_research} # Mutable container for generator scope
|
_research_flags = {"do": do_research} # Mutable container for generator scope
|
||||||
@@ -659,6 +729,32 @@ def setup_chat_routes(
|
|||||||
if chat_mode == 'chat':
|
if chat_mode == 'chat':
|
||||||
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
||||||
|
|
||||||
|
# Plan mode: investigate read-only, propose a plan, don't mutate. Block
|
||||||
|
# every tool not on the read-only allowlist. (stream_agent_loop enforces
|
||||||
|
# this again + drops MCP, so this is belt-and-suspenders.)
|
||||||
|
if plan_mode:
|
||||||
|
from src.tool_security import plan_mode_disabled_tools
|
||||||
|
disabled_tools.update(plan_mode_disabled_tools())
|
||||||
|
|
||||||
|
tool_policy = build_effective_tool_policy(
|
||||||
|
disabled_tools=disabled_tools,
|
||||||
|
last_user_message=message,
|
||||||
|
)
|
||||||
|
disabled_tools = tool_policy.all_disabled_names()
|
||||||
|
research_blocked_by_policy = bool(
|
||||||
|
tool_policy.blocks("trigger_research")
|
||||||
|
or tool_policy.blocks("manage_research")
|
||||||
|
)
|
||||||
|
effective_do_research = bool(
|
||||||
|
do_research and _research_flags["do"] and not research_blocked_by_policy
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist session mode after policy/privilege gates so blocked research
|
||||||
|
# turns remain ordinary chat/agent streams and saved messages.
|
||||||
|
_effective_mode = 'research' if effective_do_research else (chat_mode or 'chat')
|
||||||
|
if _effective_mode in ('agent', 'research', 'chat'):
|
||||||
|
set_session_mode(session, _effective_mode)
|
||||||
|
|
||||||
async def stream_with_save() -> AsyncGenerator[str, None]:
|
async def stream_with_save() -> AsyncGenerator[str, None]:
|
||||||
# _effective_mode is read-only here; closure captures it from
|
# _effective_mode is read-only here; closure captures it from
|
||||||
# the outer scope. (Was `nonlocal` but never reassigned.)
|
# the outer scope. (Was `nonlocal` but never reassigned.)
|
||||||
@@ -666,7 +762,7 @@ def setup_chat_routes(
|
|||||||
web_sources = ctx.web_sources
|
web_sources = ctx.web_sources
|
||||||
|
|
||||||
# Register active stream for partial-save safety net
|
# Register active stream for partial-save safety net
|
||||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": do_research, "mode": _effective_mode}
|
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||||
|
|
||||||
if ctx.preprocessed.attachment_meta:
|
if ctx.preprocessed.attachment_meta:
|
||||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||||
@@ -690,7 +786,7 @@ def setup_chat_routes(
|
|||||||
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
||||||
|
|
||||||
# Run research as a background task (survives page refresh)
|
# Run research as a background task (survives page refresh)
|
||||||
if do_research and _research_flags["do"]:
|
if effective_do_research:
|
||||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||||
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
||||||
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
||||||
@@ -829,7 +925,7 @@ def setup_chat_routes(
|
|||||||
_fallback_candidates = []
|
_fallback_candidates = []
|
||||||
|
|
||||||
# Send model name early so the frontend can show it during streaming
|
# Send model name early so the frontend can show it during streaming
|
||||||
_model_suffix = "Research" if do_research else None
|
_model_suffix = "Research" if effective_do_research else None
|
||||||
_model_info = {"type": "model_info", "model": sess.model}
|
_model_info = {"type": "model_info", "model": sess.model}
|
||||||
if _model_suffix:
|
if _model_suffix:
|
||||||
_model_info["suffix"] = _model_suffix
|
_model_info["suffix"] = _model_suffix
|
||||||
@@ -839,6 +935,12 @@ def setup_chat_routes(
|
|||||||
|
|
||||||
if _is_image_generation_session(sess, owner=_user):
|
if _is_image_generation_session(sess, owner=_user):
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
|
if tool_policy.blocks("generate_image"):
|
||||||
|
_blocked_msg = tool_policy.reason_for("generate_image")
|
||||||
|
yield f'data: {json.dumps({"delta": _blocked_msg})}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
_active_streams.pop(session, None)
|
||||||
|
return
|
||||||
if not get_setting("image_gen_enabled", True):
|
if not get_setting("image_gen_enabled", True):
|
||||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
@@ -873,6 +975,8 @@ def setup_chat_routes(
|
|||||||
elif chat_mode == "chat":
|
elif chat_mode == "chat":
|
||||||
_chat_start = time.time()
|
_chat_start = time.time()
|
||||||
_answered_by = None # set if the selected model failed and a fallback answered
|
_answered_by = None # set if the selected model failed and a fallback answered
|
||||||
|
_requested_model = sess.model
|
||||||
|
_actual_model = None
|
||||||
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
||||||
try:
|
try:
|
||||||
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
||||||
@@ -905,10 +1009,18 @@ def setup_chat_routes(
|
|||||||
# Selected model failed; a fallback answered.
|
# Selected model failed; a fallback answered.
|
||||||
# Forward the notice and remember the real model.
|
# Forward the notice and remember the real model.
|
||||||
_answered_by = data.get("answered_by") or _answered_by
|
_answered_by = data.get("answered_by") or _answered_by
|
||||||
|
_actual_model = _actual_model or _answered_by
|
||||||
|
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.get("type") == "model_actual":
|
||||||
|
_actual_model = data.get("model") or _actual_model
|
||||||
|
data["requested_model"] = _requested_model
|
||||||
|
yield f'data: {json.dumps(data)}\n\n'
|
||||||
elif data.get("type") == "usage":
|
elif data.get("type") == "usage":
|
||||||
last_metrics = data.get("data", {})
|
last_metrics = data.get("data", {})
|
||||||
last_metrics["model"] = _answered_by or sess.model
|
_reported_model = last_metrics.get("model")
|
||||||
|
last_metrics["requested_model"] = _requested_model
|
||||||
|
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||||
if ctx.context_length and last_metrics.get("input_tokens"):
|
if ctx.context_length and last_metrics.get("input_tokens"):
|
||||||
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
||||||
last_metrics["context_percent"] = pct
|
last_metrics["context_percent"] = pct
|
||||||
@@ -945,7 +1057,8 @@ def setup_chat_routes(
|
|||||||
"tokens_per_second": _tps,
|
"tokens_per_second": _tps,
|
||||||
"context_percent": _ctx_pct,
|
"context_percent": _ctx_pct,
|
||||||
"context_length": ctx.context_length,
|
"context_length": ctx.context_length,
|
||||||
"model": sess.model,
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
"usage_source": "estimated",
|
"usage_source": "estimated",
|
||||||
}
|
}
|
||||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||||
@@ -957,7 +1070,7 @@ def setup_chat_routes(
|
|||||||
rag_sources=ctx.rag_sources,
|
rag_sources=ctx.rag_sources,
|
||||||
research_sources=research_sources,
|
research_sources=research_sources,
|
||||||
used_memories=ctx.used_memories,
|
used_memories=ctx.used_memories,
|
||||||
do_research=do_research,
|
do_research=effective_do_research,
|
||||||
incognito=incognito,
|
incognito=incognito,
|
||||||
)
|
)
|
||||||
if _saved_id:
|
if _saved_id:
|
||||||
@@ -967,14 +1080,22 @@ def setup_chat_routes(
|
|||||||
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||||
incognito=incognito, compare_mode=compare_mode,
|
incognito=incognito, compare_mode=compare_mode,
|
||||||
character_name=ctx.preset.character_name,
|
character_name=ctx.preset.character_name,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
_stream_set(session, status="done")
|
_stream_set(session, status="done")
|
||||||
yield chunk
|
yield chunk
|
||||||
except (asyncio.CancelledError, GeneratorExit):
|
except (asyncio.CancelledError, GeneratorExit):
|
||||||
if full_response:
|
if full_response:
|
||||||
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
||||||
_stopped_content, _stopped_md = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
_stopped_content, _stopped_md = clean_thinking_for_save(
|
||||||
|
full_response,
|
||||||
|
{
|
||||||
|
"stopped": True,
|
||||||
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
|
},
|
||||||
|
)
|
||||||
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
||||||
if not incognito:
|
if not incognito:
|
||||||
session_manager.save_sessions()
|
session_manager.save_sessions()
|
||||||
@@ -986,6 +1107,8 @@ def setup_chat_routes(
|
|||||||
_agent_rounds = 0
|
_agent_rounds = 0
|
||||||
_agent_tool_calls = 0
|
_agent_tool_calls = 0
|
||||||
_answered_by = None # set if the selected model failed and a fallback answered
|
_answered_by = None # set if the selected model failed and a fallback answered
|
||||||
|
_requested_model = sess.model
|
||||||
|
_actual_model = None
|
||||||
try:
|
try:
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
||||||
@@ -1012,9 +1135,12 @@ def setup_chat_routes(
|
|||||||
active_document=active_doc,
|
active_document=active_doc,
|
||||||
session_id=session,
|
session_id=session,
|
||||||
disabled_tools=disabled_tools if disabled_tools else None,
|
disabled_tools=disabled_tools if disabled_tools else None,
|
||||||
|
tool_policy=tool_policy,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
fallbacks=_fallback_candidates,
|
fallbacks=_fallback_candidates,
|
||||||
workspace=workspace or None,
|
workspace=workspace or None,
|
||||||
|
plan_mode=plan_mode,
|
||||||
|
approved_plan=approved_plan or None,
|
||||||
):
|
):
|
||||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||||
try:
|
try:
|
||||||
@@ -1035,6 +1161,8 @@ def setup_chat_routes(
|
|||||||
"doc_stream_open", "doc_stream_delta",
|
"doc_stream_open", "doc_stream_delta",
|
||||||
"doc_update", "doc_suggestions", "ui_control",
|
"doc_update", "doc_suggestions", "ui_control",
|
||||||
"rounds_exhausted",
|
"rounds_exhausted",
|
||||||
|
"ask_user",
|
||||||
|
"plan_update",
|
||||||
):
|
):
|
||||||
if data.get("type") == "agent_step":
|
if data.get("type") == "agent_step":
|
||||||
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
||||||
@@ -1047,10 +1175,18 @@ def setup_chat_routes(
|
|||||||
# model so metrics reflect it, not the masked
|
# model so metrics reflect it, not the masked
|
||||||
# selected model.
|
# selected model.
|
||||||
_answered_by = data.get("answered_by") or _answered_by
|
_answered_by = data.get("answered_by") or _answered_by
|
||||||
|
_actual_model = _actual_model or _answered_by
|
||||||
|
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||||
yield chunk
|
yield chunk
|
||||||
|
elif data.get("type") == "model_actual":
|
||||||
|
_actual_model = data.get("model") or _actual_model
|
||||||
|
data["requested_model"] = _requested_model
|
||||||
|
yield f'data: {json.dumps(data)}\n\n'
|
||||||
elif data.get("type") == "metrics":
|
elif data.get("type") == "metrics":
|
||||||
last_metrics = data.get("data", {})
|
last_metrics = data.get("data", {})
|
||||||
last_metrics["model"] = _answered_by or sess.model
|
_reported_model = last_metrics.get("model")
|
||||||
|
last_metrics["requested_model"] = last_metrics.get("requested_model") or _requested_model
|
||||||
|
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -1078,6 +1214,7 @@ def setup_chat_routes(
|
|||||||
skills_manager=skills_manager,
|
skills_manager=skills_manager,
|
||||||
owner=_user,
|
owner=_user,
|
||||||
extract_skills=user_requested_agent,
|
extract_skills=user_requested_agent,
|
||||||
|
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||||
)
|
)
|
||||||
_stream_set(session, status="done")
|
_stream_set(session, status="done")
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -1091,7 +1228,14 @@ def setup_chat_routes(
|
|||||||
try:
|
try:
|
||||||
if full_response:
|
if full_response:
|
||||||
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
||||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
_stopped_content2, _stopped_md2 = clean_thinking_for_save(
|
||||||
|
full_response,
|
||||||
|
{
|
||||||
|
"stopped": True,
|
||||||
|
"model": _actual_model or _answered_by or _requested_model,
|
||||||
|
"requested_model": _requested_model,
|
||||||
|
},
|
||||||
|
)
|
||||||
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
||||||
if not incognito:
|
if not incognito:
|
||||||
session_manager.save_sessions()
|
session_manager.save_sessions()
|
||||||
@@ -1110,11 +1254,30 @@ def setup_chat_routes(
|
|||||||
finally:
|
finally:
|
||||||
_active_streams.pop(session, None)
|
_active_streams.pop(session, None)
|
||||||
|
|
||||||
# Run the stream as a DETACHED background task so it survives the client
|
# Compare panes are short-lived, single-shot generations whose sessions
|
||||||
# closing the tab / navigating away (true terminal-agent behavior). The
|
# exist only to drive that one pane — there's nothing to "resume" and
|
||||||
# SSE response just subscribes (replay buffered output + live); dropping
|
# the user expects the pane's Stop button (which aborts the fetch,
|
||||||
# the SSE only removes a subscriber — the run keeps going and saves the
|
# closing this SSE) to promptly cancel the upstream LLM call. Detaching
|
||||||
# assistant message on completion regardless. Reconnect via /api/chat/resume.
|
# them would keep burning upstream tokens/compute after the pane is
|
||||||
|
# stopped or the comparison is abandoned, and would surface a stale
|
||||||
|
# "still streaming" /resume target for a session nobody will revisit.
|
||||||
|
#
|
||||||
|
# So: stream them directly (no agent_runs wrapping). Starlette cancels
|
||||||
|
# the underlying async generator (raising CancelledError/GeneratorExit
|
||||||
|
# inside it) as soon as it notices the client disconnected — which the
|
||||||
|
# mode-specific except blocks above already handle by saving the
|
||||||
|
# partial response exactly once. This stops the upstream call promptly
|
||||||
|
# without waiting on the next streamed chunk.
|
||||||
|
#
|
||||||
|
# Normal chat/agent streams keep the DETACHED behavior below: they
|
||||||
|
# survive the client closing the tab / navigating away (true
|
||||||
|
# terminal-agent semantics). The SSE response just subscribes (replay
|
||||||
|
# buffered output + live); dropping the SSE only removes a subscriber —
|
||||||
|
# the run keeps going and saves the assistant message on completion
|
||||||
|
# regardless. Reconnect via /api/chat/resume.
|
||||||
|
if compare_mode:
|
||||||
|
return StreamingResponse(_safe_stream(), media_type="text/event-stream")
|
||||||
|
|
||||||
agent_runs.start(session, _safe_stream())
|
agent_runs.start(session, _safe_stream())
|
||||||
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
||||||
|
|
||||||
@@ -1185,45 +1348,16 @@ def setup_chat_routes(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
_user = get_current_user(request)
|
_user = get_current_user(request)
|
||||||
query_term = q.strip()
|
return [
|
||||||
db = SessionLocal()
|
result.to_dict()
|
||||||
try:
|
for result in search_session_messages(
|
||||||
base_q = (
|
q,
|
||||||
db.query(DBChatMessage, DBSession.name)
|
limit=limit,
|
||||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
owner=_user,
|
||||||
.filter(
|
restrict_owner=_user is not None,
|
||||||
DBSession.archived == False,
|
include_legacy_owner=False,
|
||||||
DBChatMessage.content.ilike(f"%{query_term}%"),
|
|
||||||
DBChatMessage.role.in_(["user", "assistant"]),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if _user:
|
]
|
||||||
base_q = base_q.filter(DBSession.owner == _user)
|
|
||||||
rows = base_q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for msg, session_name in rows:
|
|
||||||
content = msg.content or ""
|
|
||||||
lower_content = content.lower()
|
|
||||||
idx = lower_content.find(query_term.lower())
|
|
||||||
if idx == -1:
|
|
||||||
snippet = content[:120]
|
|
||||||
else:
|
|
||||||
start = max(0, idx - 50)
|
|
||||||
end = min(len(content), idx + len(query_term) + 50)
|
|
||||||
snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"session_id": msg.session_id,
|
|
||||||
"session_name": session_name or "Untitled",
|
|
||||||
"role": msg.role,
|
|
||||||
"content_snippet": snippet,
|
|
||||||
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
|
|
||||||
})
|
|
||||||
|
|
||||||
return results
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
||||||
|
|||||||
@@ -0,0 +1,170 @@
|
|||||||
|
"""ChatGPT Subscription device-flow setup routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
|
from src.auth_helpers import get_current_user
|
||||||
|
from src import chatgpt_subscription
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
|
|
||||||
|
|
||||||
|
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
    """Store ChatGPT Subscription tokens and upsert the endpoint that uses them.

    The auth session is looked up by provider + owner and the endpoint by
    base URL + auth session + owner; each is created when absent and then
    overwritten with the current values, so re-running the device flow for
    the same owner refreshes the existing rows instead of adding new ones.

    Args:
        tokens: Token response from the authorization-code exchange. Must
            contain non-empty ``access_token`` and ``refresh_token`` values.
        owner: Owner the auth session and endpoint are scoped to, or ``None``.

    Returns:
        A dict with the endpoint's ``id``, ``name`` and ``base_url`` plus the
        discovered ``models``.

    Raises:
        ValueError: If either token is missing/empty, or if model discovery
            returned nothing usable.
    """
    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("ChatGPT token response was missing access_token or refresh_token")

    base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    # Model discovery runs before the database session is opened, so a failed
    # or empty discovery aborts without writing anything.
    models = chatgpt_subscription.fetch_available_models(access_token)
    if not models:
        raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")

    db = SessionLocal()
    try:
        auth = (
            db.query(ProviderAuthSession)
            .filter(
                ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                ProviderAuthSession.owner == owner,
            )
            .first()
        )
        if auth is None:
            auth = ProviderAuthSession(
                id=str(uuid.uuid4())[:8],
                provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                owner=owner,
                label="ChatGPT Subscription",
                base_url=base,
                auth_mode="chatgpt",
            )
            db.add(auth)
        # Applied to new and existing rows alike: a re-login replaces the tokens.
        auth.base_url = base
        auth.access_token = access_token
        auth.refresh_token = refresh_token
        auth.last_refresh = utcnow_naive()
        auth.auth_mode = "chatgpt"

        ep = (
            db.query(ModelEndpoint)
            .filter(
                ModelEndpoint.base_url == base,
                ModelEndpoint.provider_auth_id == auth.id,
                ModelEndpoint.owner == owner,
            )
            .first()
        )
        if ep is None:
            ep = ModelEndpoint(
                id=str(uuid.uuid4())[:8],
                name="ChatGPT Subscription",
                base_url=base,
                model_type="llm",
                endpoint_kind="api",
                owner=owner,
            )
            db.add(ep)
        ep.name = "ChatGPT Subscription"
        ep.base_url = base
        # Credentials come from the linked auth session, not a static API key.
        ep.api_key = None
        ep.provider_auth_id = auth.id
        ep.is_enabled = True
        ep.supports_tools = False
        ep.model_type = "llm"
        ep.endpoint_kind = "api"
        ep.model_refresh_mode = "manual"
        ep.cached_models = json.dumps(models)
        db.commit()
        # Build the result while the session is still open.
        result = {
            "id": ep.id,
            "name": ep.name,
            "base_url": ep.base_url,
            "models": models,
        }
    finally:
        db.close()

    try:
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from routes.model_routes import _invalidate_models_cache

        _invalidate_models_cache()
    except Exception:
        # Best-effort: the endpoint is already committed, so a failed cache
        # invalidation must not fail provisioning - but it should be visible.
        logger.warning(
            "Could not invalidate the models cache after provisioning ChatGPT Subscription",
            exc_info=True,
        )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
    """Request a ChatGPT device code and describe the resulting pending flow.

    The returned ``DeviceFlowStart`` carries a ``pending`` payload (device
    auth id, user code and the requesting user) and a ``response`` payload
    holding only the user code and verification URI. ``interval`` defaults
    to 5 and ``expires_in`` to 900 when the reply omits them.

    Raises:
        HTTPException: 502 when the reply lacks ``device_auth_id`` or
            ``user_code``.
        Exception: Whatever ``chatgpt_subscription.to_http_exception``
            produces when the device-code request itself fails.
    """
    try:
        payload = chatgpt_subscription.request_device_code()
    except Exception as exc:
        raise chatgpt_subscription.to_http_exception(exc)

    auth_id = payload.get("device_auth_id")
    code = payload.get("user_code")
    if not (auth_id and code):
        raise HTTPException(502, "ChatGPT did not return a complete device code")

    # Fall back to the issuer's device page when the reply carries no URI.
    verify_url = payload.get("verification_uri")
    if not verify_url:
        verify_url = f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"

    pending_state = {
        "device_auth_id": auth_id,
        "user_code": code,
        "owner": get_current_user(request) or None,
    }
    client_view = {
        "user_code": code,
        "verification_uri": verify_url,
    }
    poll_every = int(payload.get("interval") or 5)
    lifetime = int(payload.get("expires_in") or 900)
    return DeviceFlowStart(
        pending=pending_state,
        response=client_view,
        interval=poll_every,
        expires_in=lifetime,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
    """Poll ChatGPT once for a pending device flow and map the reply to a poll result.

    Outcomes, in the order they are checked:
      * the poll call raises: reported as still pending, with the error text;
      * the reply has both ``authorization_code`` and ``code_verifier``: the
        code is exchanged for tokens, the endpoint is provisioned, and the
        flow is reported as authorized;
      * ``authorization_pending`` / ``pending`` / no status: pending;
      * ``slow_down``: slow-down, passing the reply's interval if non-zero;
      * ``expired_token`` / ``access_denied`` / ``denied``: failed;
      * anything else: pending, carrying the unrecognised status.

    Raises:
        Exception: Whatever ``chatgpt_subscription.to_http_exception``
            produces when the token exchange or provisioning fails.
    """
    try:
        reply = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
    except Exception as exc:
        # A failed poll is not fatal; the caller simply polls again.
        logger.debug("ChatGPT device poll failed: %s", exc)
        return DeviceFlowPoll.pending(str(exc))

    auth_code = reply.get("authorization_code")
    verifier = reply.get("code_verifier")
    if auth_code and verifier:
        try:
            tokens = chatgpt_subscription.exchange_authorization_code(auth_code, verifier)
            provisioned = _provision_endpoint(tokens, pending["owner"])
        except Exception as exc:
            logger.exception("ChatGPT Subscription endpoint provisioning failed")
            raise chatgpt_subscription.to_http_exception(exc)
        return DeviceFlowPoll.authorized(provisioned)

    state = reply.get("error") or reply.get("status")
    still_waiting = ("authorization_pending", "pending", None)
    terminal = ("expired_token", "access_denied", "denied")
    if state in still_waiting:
        return DeviceFlowPoll.pending()
    if state == "slow_down":
        backoff = int(reply.get("interval") or 0)
        return DeviceFlowPoll.slow_down(backoff or None)
    if state in terminal:
        return DeviceFlowPoll.failed(state)
    return DeviceFlowPoll.pending(state or "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_chatgpt_subscription_routes():
    """Build the ChatGPT Subscription device-flow router.

    Wires this module's start/poll handlers and its pending-flow store into
    ``create_device_flow_router`` under the ``/api/chatgpt-subscription``
    prefix and returns whatever that factory produces.
    """
    router = create_device_flow_router(
        prefix="/api/chatgpt-subscription",
        tags=["chatgpt-subscription"],
        store=_DEVICE_FLOW_STORE,
        start_flow=_start_device_flow,
        poll_flow=_poll_device_flow,
    )
    return router
|
||||||
+16
-6
@@ -15,8 +15,9 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_authenticated_request, require_user
|
||||||
from src.tool_implementations import do_manage_notes
|
from src.tool_implementations import do_manage_notes
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
|
|
||||||
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
|
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
|
||||||
@@ -41,7 +42,9 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
|||||||
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
||||||
Restores the original value when done. Works for sync and async handlers."""
|
Restores the original value when done. Works for sync and async handlers."""
|
||||||
orig = getattr(request.state, "current_user", None)
|
orig = getattr(request.state, "current_user", None)
|
||||||
|
orig_api_token = getattr(request.state, "api_token", None)
|
||||||
request.state.current_user = owner
|
request.state.current_user = owner
|
||||||
|
request.state.api_token = False
|
||||||
try:
|
try:
|
||||||
result = fn(*args, **kwargs)
|
result = fn(*args, **kwargs)
|
||||||
if asyncio.iscoroutine(result):
|
if asyncio.iscoroutine(result):
|
||||||
@@ -49,6 +52,13 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
|||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
request.state.current_user = orig
|
request.state.current_user = orig
|
||||||
|
if orig_api_token is None:
|
||||||
|
try:
|
||||||
|
delattr(request.state, "api_token")
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
request.state.api_token = orig_api_token
|
||||||
|
|
||||||
|
|
||||||
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
||||||
@@ -146,7 +156,7 @@ def setup_codex_routes(
|
|||||||
|
|
||||||
@router.get("/plugin.zip")
|
@router.get("/plugin.zip")
|
||||||
def plugin_zip(request: Request):
|
def plugin_zip(request: Request):
|
||||||
require_user(request)
|
require_authenticated_request(request)
|
||||||
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
||||||
if not root.exists():
|
if not root.exists():
|
||||||
raise HTTPException(404, "Codex plugin bundle not found")
|
raise HTTPException(404, "Codex plugin bundle not found")
|
||||||
@@ -415,8 +425,8 @@ def setup_codex_routes(
|
|||||||
|
|
||||||
def _read_cookbook_state() -> dict:
|
def _read_cookbook_state() -> dict:
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
import os as _os, json as _json
|
import json as _json
|
||||||
p = _Path(_os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
p = _Path(COOKBOOK_STATE_FILE)
|
||||||
if not p.exists():
|
if not p.exists():
|
||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
@@ -724,7 +734,7 @@ def setup_codex_routes(
|
|||||||
import time as _t, json as _json
|
import time as _t, json as _json
|
||||||
from core.atomic_io import atomic_write_json
|
from core.atomic_io import atomic_write_json
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
cookbook_state_path = _Path("/app/data/cookbook_state.json")
|
cookbook_state_path = _Path(COOKBOOK_STATE_FILE)
|
||||||
try:
|
try:
|
||||||
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
|
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -762,7 +772,7 @@ def setup_claude_routes() -> APIRouter:
|
|||||||
|
|
||||||
@router.get("/plugin.zip")
|
@router.get("/plugin.zip")
|
||||||
def plugin_zip(request: Request):
|
def plugin_zip(request: Request):
|
||||||
require_user(request)
|
require_authenticated_request(request)
|
||||||
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
||||||
# README.md or other bundle metadata into the user's claude config dir.
|
# README.md or other bundle metadata into the user's claude config dir.
|
||||||
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
||||||
|
|||||||
+110
-22
@@ -12,6 +12,7 @@ import logging
|
|||||||
from core.database import Comparison, SessionLocal
|
from core.database import Comparison, SessionLocal
|
||||||
from core.session_manager import SessionManager
|
from core.session_manager import SessionManager
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner):
|
|||||||
return owner_filter(q, ModelEndpoint, owner).first()
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _owned_endpoint_by_id(db, endpoint_id, owner):
    """Return the ModelEndpoint with id `endpoint_id` if `owner` may see it, else None.

    "Visible" means the owner's own rows plus legacy null-owner "shared" rows;
    the scoping is the same as in _owned_endpoint_by_url, and a null/empty
    owner leaves the query unfiltered.

    Use this in preference to _owned_endpoint_by_url when resolving
    credentials. Two endpoints visible to the same caller may share a base_url
    while holding DIFFERENT api_keys (e.g. two accounts on one provider), and a
    base_url-only lookup returns whichever row sorts first - potentially
    copying the WRONG owner-scoped key into the [CMP] session. An id pins the
    exact registered endpoint, so /api/compare/start tries it first and falls
    back to URL matching only for legacy / admin raw-URL callers.
    """
    from core.database import ModelEndpoint
    from src.auth_helpers import owner_filter

    by_id = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id)
    visible = owner_filter(by_id, ModelEndpoint, owner)
    return visible.first()
|
||||||
|
|
||||||
|
|
||||||
class RecordVoteRequest(BaseModel):
|
class RecordVoteRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
models: List[str]
|
models: List[str]
|
||||||
@@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
prompt: str = Form(...),
|
prompt: str = Form(...),
|
||||||
model_a: str = Form(...),
|
model_a: str = Form(...),
|
||||||
model_b: str = Form(...),
|
model_b: str = Form(...),
|
||||||
endpoint_a: str = Form(...),
|
endpoint_a: str = Form(""),
|
||||||
endpoint_b: str = Form(...),
|
endpoint_b: str = Form(""),
|
||||||
|
endpoint_a_id: str = Form(""),
|
||||||
|
endpoint_b_id: str = Form(""),
|
||||||
is_blind: str = Form("true"),
|
is_blind: str = Form("true"),
|
||||||
):
|
):
|
||||||
"""Create two ephemeral sessions and a comparison record.
|
"""Create two ephemeral sessions and a comparison record.
|
||||||
@@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
Returns the comparison ID and the two session IDs so the client
|
Returns the comparison ID and the two session IDs so the client
|
||||||
can fire two independent SSE streams to /api/chat_stream.
|
can fire two independent SSE streams to /api/chat_stream.
|
||||||
"""
|
"""
|
||||||
|
user = getattr(request.state, 'current_user', None)
|
||||||
comp_id = str(uuid.uuid4())
|
comp_id = str(uuid.uuid4())
|
||||||
sid_a = str(uuid.uuid4())
|
sid_a = str(uuid.uuid4())
|
||||||
sid_b = str(uuid.uuid4())
|
sid_b = str(uuid.uuid4())
|
||||||
user = getattr(request.state, 'current_user', None)
|
|
||||||
|
|
||||||
# Blind mapping: randomly assign left/right
|
# Blind mapping: randomly assign left/right
|
||||||
blind = str(is_blind).lower() == "true"
|
blind = str(is_blind).lower() == "true"
|
||||||
@@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
# de-anonymizing the comparison before the user votes (issue #1285).
|
# de-anonymizing the comparison before the user votes (issue #1285).
|
||||||
slot_name = {session_left: "Model A", session_right: "Model B"}
|
slot_name = {session_left: "Model A", session_right: "Model B"}
|
||||||
|
|
||||||
# Create ephemeral sessions (prefixed [CMP])
|
# SECURITY: resolve and validate BOTH endpoints before creating any
|
||||||
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
|
# session. Compare copies a registered endpoint's Authorization header
|
||||||
|
# into the [CMP] session, so validating one endpoint while creating its
|
||||||
|
# session, then rejecting the other, would leave a partial compare
|
||||||
|
# session behind with that header attached. Doing all the owner-scope
|
||||||
|
# resolution + raw-URL rejection up front means a 403 on either endpoint
|
||||||
|
# aborts the whole request with nothing created and no header copied.
|
||||||
|
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
|
||||||
|
resolved = []
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
for sid, model, endpoint, endpoint_id in [
|
||||||
|
(sid_a, model_a, endpoint_a, endpoint_a_id),
|
||||||
|
(sid_b, model_b, endpoint_b, endpoint_b_id),
|
||||||
|
]:
|
||||||
|
# Prefer an explicit endpoint id: it pins the EXACT registered
|
||||||
|
# endpoint (and its api_key), even when two endpoints visible to
|
||||||
|
# the caller share a base_url with different keys — a URL-only
|
||||||
|
# match would copy whichever row sorts first, i.e. possibly the
|
||||||
|
# wrong key. Fall back to URL resolution only for legacy / admin
|
||||||
|
# raw-URL callers that don't send an id.
|
||||||
|
eid = endpoint_id.strip() if isinstance(endpoint_id, str) else ""
|
||||||
|
if eid:
|
||||||
|
ep = _owned_endpoint_by_id(db, eid, user)
|
||||||
|
if ep is None:
|
||||||
|
# An id the caller can't see (wrong owner / deleted) must
|
||||||
|
# NOT silently fall back to a same-URL row with a different
|
||||||
|
# key — that's exactly the mix-up ids exist to prevent.
|
||||||
|
raise HTTPException(404, "Model endpoint not found")
|
||||||
|
# The id already resolved the endpoint; ignore any raw URL the
|
||||||
|
# caller also sent and dial the stored config instead.
|
||||||
|
endpoint = ep.base_url
|
||||||
|
elif not endpoint:
|
||||||
|
raise HTTPException(
|
||||||
|
422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Resolve the supplied URL to a ModelEndpoint the caller owns
|
||||||
|
# (their own rows + legacy null-owner shared rows), scoped so a
|
||||||
|
# comparison can't borrow another user's private endpoint key.
|
||||||
|
base = normalize_base(endpoint)
|
||||||
|
ep = _owned_endpoint_by_url(db, base, user)
|
||||||
|
# Reject *unregistered* raw URLs for signed-in non-admins; a
|
||||||
|
# matched registered endpoint supplies an id so the caller can
|
||||||
|
# still compare endpoints they own. Blanket-rejecting here (the
|
||||||
|
# earlier `endpoint_id=None` call) locked non-admins out of
|
||||||
|
# compare entirely, since compare resolves endpoints by URL with
|
||||||
|
# no endpoint_id. Mirrors the gallery inpaint/harmonize checks.
|
||||||
|
# Raised here (phase 1), before any session exists.
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(
|
||||||
|
request, user, str(ep.id) if ep is not None else None, endpoint
|
||||||
|
)
|
||||||
|
# Bind the [CMP] session to the RESOLVED endpoint, not the raw
|
||||||
|
# caller-supplied string. When the URL matches a registered
|
||||||
|
# endpoint visible to the caller, use that row's own normalized
|
||||||
|
# base URL (the same value owner scoping + endpoint validation
|
||||||
|
# already vetted) so the session dials exactly where the stored
|
||||||
|
# config points. The raw `endpoint` only survives for callers
|
||||||
|
# allowed to pass one — admins / single-user mode, where
|
||||||
|
# `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep`
|
||||||
|
# is None. Mirrors the registered-endpoint path in session_routes.
|
||||||
|
session_endpoint_url = (
|
||||||
|
build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint
|
||||||
|
)
|
||||||
|
# Headers come only from a matched endpoint's key; None when
|
||||||
|
# `ep` is None (raw admin URL or no match), so a comparison can
|
||||||
|
# never inherit another user's key/headers.
|
||||||
|
headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None
|
||||||
|
resolved.append((sid, model, session_endpoint_url, headers))
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Both endpoints validated — only now create the ephemeral [CMP]
|
||||||
|
# sessions and copy any resolved headers.
|
||||||
|
for sid, model, session_endpoint_url, headers in resolved:
|
||||||
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
||||||
session_manager.create_session(
|
session_manager.create_session(
|
||||||
session_id=sid,
|
session_id=sid,
|
||||||
name=name,
|
name=name,
|
||||||
endpoint_url=endpoint,
|
endpoint_url=session_endpoint_url,
|
||||||
model=model,
|
model=model,
|
||||||
rag=False,
|
rag=False,
|
||||||
owner=user,
|
owner=user,
|
||||||
)
|
)
|
||||||
# Copy API key from endpoint config
|
if headers:
|
||||||
db = SessionLocal()
|
s = session_manager.sessions.get(sid)
|
||||||
try:
|
if s:
|
||||||
from src.endpoint_resolver import build_headers, normalize_base
|
s.headers = headers
|
||||||
# Find matching endpoint by URL, scoped to the caller so a
|
|
||||||
# comparison can't borrow another user's private endpoint key.
|
|
||||||
base = normalize_base(endpoint)
|
|
||||||
ep = _owned_endpoint_by_url(db, base, user)
|
|
||||||
if ep and ep.api_key:
|
|
||||||
s = session_manager.sessions.get(sid)
|
|
||||||
if s:
|
|
||||||
s.headers = build_headers(ep.api_key, ep.base_url)
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Store comparison record
|
# Store comparison record
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
@@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model_a=model_a,
|
model_a=model_a,
|
||||||
model_b=model_b,
|
model_b=model_b,
|
||||||
endpoint_a=endpoint_a,
|
# Record the URL the session actually dials. For URL callers this
|
||||||
endpoint_b=endpoint_b,
|
# is their raw input; for id-only callers (empty endpoint_a/_b)
|
||||||
|
# fall back to the resolved endpoint URL so the column stays
|
||||||
|
# meaningful and non-null. resolved is in [a, b] order.
|
||||||
|
endpoint_a=endpoint_a or resolved[0][2],
|
||||||
|
endpoint_b=endpoint_b or resolved[1][2],
|
||||||
is_blind=blind,
|
is_blind=blind,
|
||||||
blind_mapping=json.dumps(mapping),
|
blind_mapping=json.dumps(mapping),
|
||||||
owner=user,
|
owner=user,
|
||||||
|
|||||||
+53
-18
@@ -11,20 +11,24 @@ import uuid
|
|||||||
import json
|
import json
|
||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import httpx
|
import httpx
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, Query, Depends, Response
|
from urllib.parse import urljoin, urlparse, urlunparse
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Query, Depends, Response, HTTPException
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
from src.url_safety import check_outbound_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, SETTINGS_FILE as _SETTINGS_FILE, CONTACTS_FILE as _CONTACTS_FILE
|
||||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
|
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||||
|
LOCAL_CONTACTS_FILE = Path(_CONTACTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _load_settings():
|
def _load_settings():
|
||||||
@@ -53,6 +57,21 @@ def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
|
|||||||
return bool((cfg.get("url") or "").strip())
|
return bool((cfg.get("url") or "").strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_carddav_url(url: str) -> str:
    """Normalize a CardDAV URL and run it through the outbound-URL check.

    Non-string input is treated as an empty string. Surrounding whitespace
    and trailing slashes are removed before the URL is checked.

    Returns:
        The cleaned URL, once ``check_outbound_url`` has accepted it.

    Raises:
        ValueError: if ``check_outbound_url`` rejects the cleaned URL; the
            message carries the reason it reported.
    """
    candidate = url if isinstance(url, str) else ""
    candidate = candidate.strip().rstrip("/")
    # Private-address blocking is opt-in: it is only requested when the
    # CARDDAV_BLOCK_PRIVATE_IPS environment variable is set to "true".
    block_private = os.getenv("CARDDAV_BLOCK_PRIVATE_IPS", "false").lower() == "true"
    ok, reason = check_outbound_url(candidate, block_private=block_private)
    if ok:
        return candidate
    raise ValueError(f"Rejected CardDAV URL: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
def _carddav_base_url(cfg: Dict) -> str:
    """Return the validated CardDAV base URL stored under ``cfg["url"]``.

    A missing or falsy ``url`` entry is passed on as an empty string.
    Validation (and any resulting ``ValueError``) comes from
    ``_validate_carddav_url``.
    """
    configured_url = cfg.get("url") or ""
    return _validate_carddav_url(configured_url)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_contact(contact: Dict) -> Dict:
|
def _normalize_contact(contact: Dict) -> Dict:
|
||||||
emails = []
|
emails = []
|
||||||
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
||||||
@@ -219,14 +238,18 @@ _contact_cache = {"contacts": [], "fetched_at": None}
|
|||||||
def _abs_url(href: str) -> str:
    """Turn a multistatus <href> into a fully-qualified URL on the configured
    CardDAV server, suitable for a later PUT/DELETE.

    The href (typically an absolute path such as /user/contacts/x.vcf) is
    resolved against the configured base URL. If the result points at a
    different scheme or host than the configured server, only its path and
    query are kept and re-attached to the configured origin, so a hostile
    CardDAV response cannot steer later writes/deletes to another host.
    The final URL is passed through ``_validate_carddav_url``.
    """
    server = _carddav_base_url(_get_carddav_config())
    server_parts = urlparse(server)
    origin = (server_parts.scheme, server_parts.netloc)

    candidate = urljoin(server.rstrip("/") + "/", href or "")
    candidate_parts = urlparse(candidate)
    if (candidate_parts.scheme, candidate_parts.netloc) == origin:
        return _validate_carddav_url(candidate)

    # Cross-origin href: pin it back onto the configured scheme + host.
    pinned = urlunparse(
        (origin[0], origin[1], candidate_parts.path or "/", "", candidate_parts.query, "")
    )
    return _validate_carddav_url(pinned)
|
||||||
|
|
||||||
|
|
||||||
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
||||||
@@ -297,6 +320,7 @@ def _fetch_contacts(force=False):
|
|||||||
return contacts
|
return contacts
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
cfg["url"] = _carddav_base_url(cfg)
|
||||||
auth = None
|
auth = None
|
||||||
if cfg["username"]:
|
if cfg["username"]:
|
||||||
auth = (cfg["username"], cfg["password"])
|
auth = (cfg["username"], cfg["password"])
|
||||||
@@ -353,8 +377,8 @@ def _create_contact(name: str, email: str) -> bool:
|
|||||||
|
|
||||||
contact_uid = str(uuid.uuid4())
|
contact_uid = str(uuid.uuid4())
|
||||||
vcard = _build_vcard(name, email, contact_uid)
|
vcard = _build_vcard(name, email, contact_uid)
|
||||||
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
|
|
||||||
try:
|
try:
|
||||||
|
url = _carddav_base_url(cfg) + "/" + contact_uid + ".vcf"
|
||||||
auth = None
|
auth = None
|
||||||
if cfg["username"]:
|
if cfg["username"]:
|
||||||
auth = (cfg["username"], cfg["password"])
|
auth = (cfg["username"], cfg["password"])
|
||||||
@@ -382,7 +406,7 @@ def _vcard_url(uid: str) -> str:
|
|||||||
escape the collection and target an arbitrary CardDAV resource."""
|
escape the collection and target an arbitrary CardDAV resource."""
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
cfg = _get_carddav_config()
|
cfg = _get_carddav_config()
|
||||||
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
return _carddav_base_url(cfg) + "/" + quote(uid, safe="") + ".vcf"
|
||||||
|
|
||||||
|
|
||||||
def _import_vcards(text: str) -> Dict:
|
def _import_vcards(text: str) -> Dict:
|
||||||
@@ -413,6 +437,11 @@ def _import_vcards(text: str) -> Dict:
|
|||||||
if imported:
|
if imported:
|
||||||
_save_local_contacts(contacts)
|
_save_local_contacts(contacts)
|
||||||
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
||||||
|
try:
|
||||||
|
base_url = _carddav_base_url(cfg)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("CardDAV import URL rejected: %s", e)
|
||||||
|
return {"imported": 0, "failed": 0, "total": 0, "error": str(e)}
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
# Split into individual cards. re.split drops the BEGIN line, so we
|
# Split into individual cards. re.split drops the BEGIN line, so we
|
||||||
# re-add it. Normalize CRLF.
|
# re-add it. Normalize CRLF.
|
||||||
@@ -441,7 +470,7 @@ def _import_vcards(text: str) -> Dict:
|
|||||||
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
||||||
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
||||||
vcard = block.replace("\n", "\r\n") + "\r\n"
|
vcard = block.replace("\n", "\r\n") + "\r\n"
|
||||||
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
url = base_url + "/" + quote(uid, safe="") + ".vcf"
|
||||||
try:
|
try:
|
||||||
r = httpx.put(
|
r = httpx.put(
|
||||||
url, data=vcard.encode("utf-8"),
|
url, data=vcard.encode("utf-8"),
|
||||||
@@ -601,8 +630,8 @@ def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -
|
|||||||
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
||||||
# Use the real resource href (handles externally-created contacts whose
|
# Use the real resource href (handles externally-created contacts whose
|
||||||
# filename != UID); falls back to the <uid>.vcf guess.
|
# filename != UID); falls back to the <uid>.vcf guess.
|
||||||
url = _resolve_resource_url(uid)
|
|
||||||
try:
|
try:
|
||||||
|
url = _resolve_resource_url(uid)
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
r = httpx.put(
|
r = httpx.put(
|
||||||
url,
|
url,
|
||||||
@@ -630,8 +659,8 @@ def _delete_contact(uid: str) -> bool:
|
|||||||
_save_local_contacts(remaining)
|
_save_local_contacts(remaining)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
url = _resolve_resource_url(uid)
|
|
||||||
try:
|
try:
|
||||||
|
url = _resolve_resource_url(uid)
|
||||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||||
r = httpx.delete(url, auth=auth, timeout=10)
|
r = httpx.delete(url, auth=auth, timeout=10)
|
||||||
if r.status_code in (200, 204):
|
if r.status_code in (200, 204):
|
||||||
@@ -747,7 +776,13 @@ def setup_contacts_routes():
|
|||||||
settings = _load_settings()
|
settings = _load_settings()
|
||||||
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
||||||
if key in data:
|
if key in data:
|
||||||
settings[key] = data[key]
|
if key == "carddav_url" and str(data[key] or "").strip():
|
||||||
|
try:
|
||||||
|
settings[key] = _validate_carddav_url(data[key])
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
else:
|
||||||
|
settings[key] = data[key]
|
||||||
_save_settings(settings)
|
_save_settings(settings)
|
||||||
# Force re-fetch
|
# Force re-fetch
|
||||||
_contact_cache["fetched_at"] = None
|
_contact_cache["fetched_at"] = None
|
||||||
|
|||||||
+312
-14
@@ -11,6 +11,8 @@ import shlex
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.platform_compat import _ssh_exec_argv
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -195,6 +197,20 @@ def _pip_install_attempt(pip_cmd: str) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_command(python_cmd: str) -> str:
|
||||||
|
"""Return a pip command for either a pip executable or a Python executable."""
|
||||||
|
cmd = python_cmd.strip()
|
||||||
|
if " -m pip" in cmd or cmd in {"pip", "pip3"}:
|
||||||
|
return python_cmd
|
||||||
|
if cmd in {"python", "python3", "python.exe"} or cmd.endswith(("/python", "/python3", "\\python.exe")):
|
||||||
|
return f"{python_cmd} -m pip"
|
||||||
|
return python_cmd
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_break_system_packages_check(pip_cmd: str) -> str:
|
||||||
|
return f"{pip_cmd} install --help 2>/dev/null | grep -q -- --break-system-packages"
|
||||||
|
|
||||||
|
|
||||||
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
||||||
"""Build a bash pip install fallback chain that surfaces errors.
|
"""Build a bash pip install fallback chain that surfaces errors.
|
||||||
|
|
||||||
@@ -206,33 +222,44 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
|
|||||||
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
||||||
pip output appear in the Cookbook log on failure.
|
pip output appear in the Cookbook log on failure.
|
||||||
"""
|
"""
|
||||||
|
from core.platform_compat import IS_WINDOWS
|
||||||
upgrade_flag = " -U" if upgrade else ""
|
upgrade_flag = " -U" if upgrade else ""
|
||||||
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
||||||
# contains brackets that bash would treat as a glob, so it must be quoted
|
# contains brackets that bash would treat as a glob, so it must be quoted
|
||||||
# before being embedded in the install command. Plain names (e.g.
|
# before being embedded in the install command. Plain names (e.g.
|
||||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||||
pkg = shlex.quote(package)
|
pkg = shlex.quote(package)
|
||||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
# llama-cpp-python source builds are brittle on older distro pip/packaging
|
||||||
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
|
||||||
|
# this package is requested so dependency-install tasks are reliable.
|
||||||
|
if "llama-cpp-python" in package:
|
||||||
|
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
|
|
||||||
|
pip_cmd = _pip_command(python_cmd)
|
||||||
|
base = _pip_install_attempt(f"{pip_cmd} install -q{upgrade_flag} {pkg}")
|
||||||
|
user = _pip_install_attempt(f"{pip_cmd} install --user -q{upgrade_flag} {pkg}")
|
||||||
|
user_break_system = _pip_install_attempt(f"{pip_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||||
|
user_fallback = f"( {user} || {{ {_pip_break_system_packages_check(pip_cmd)} && {user_break_system}; }} )"
|
||||||
# Derive the python executable for the venv detection check.
|
# Derive the python executable for the venv detection check.
|
||||||
# Must use the same interpreter that pip belongs to; hardcoding
|
# Must use the same interpreter that pip belongs to; hardcoding
|
||||||
# python3 breaks when pip lives in a venv that only has "python".
|
# python3 breaks when pip lives in a venv that only has "python".
|
||||||
if " -m pip" in python_cmd:
|
if " -m pip" in pip_cmd:
|
||||||
python_exe = python_cmd.replace(" -m pip", "")
|
python_exe = pip_cmd.replace(" -m pip", "")
|
||||||
elif python_cmd.strip() == "pip":
|
elif pip_cmd.strip() == "pip":
|
||||||
python_exe = "python"
|
python_exe = "python"
|
||||||
elif python_cmd.strip() == "pip3":
|
elif pip_cmd.strip() == "pip3":
|
||||||
python_exe = "python3"
|
python_exe = "python3"
|
||||||
else:
|
else:
|
||||||
python_exe = "python3"
|
python_exe = "python3"
|
||||||
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
||||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv → `&&` tries
|
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv -> `&&` tries
|
||||||
# --user. When IN a venv `! venv_check` fails → `&&` skips --user and the
|
# --user. When IN a venv `! venv_check` fails -> `&&` skips --user and the
|
||||||
# group exits non-zero, propagating the base-install failure instead of
|
# group exits non-zero, propagating the base-install failure instead of
|
||||||
# masking it as success (the `|| { venv_check || … }` shape from #903
|
# masking it as success (the `|| { venv_check || … }` shape from #903
|
||||||
# swallowed the exit code because venv_check's exit-0 became the group's
|
# swallowed the exit code because venv_check's exit-0 became the group's
|
||||||
# result).
|
# result). `--break-system-packages` is only attempted when the active pip
|
||||||
return f"{base} || {{ ! {venv_check} && {user}; }}"
|
# supports it; older pip versions abort with "no such option" otherwise.
|
||||||
|
return f"{base} || {{ ! {venv_check} && {user_fallback}; }}"
|
||||||
|
|
||||||
|
|
||||||
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
||||||
@@ -263,6 +290,55 @@ def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) ->
|
|||||||
return shlex.join(stripped)
|
return shlex.join(stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_install_command_without_break_system_packages(cmd: str) -> str:
|
||||||
|
try:
|
||||||
|
parts = shlex.split(cmd)
|
||||||
|
except ValueError:
|
||||||
|
return cmd
|
||||||
|
stripped = [part for part in parts if part != "--break-system-packages"]
|
||||||
|
return shlex.join(stripped)
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_install_help_check_from_cmd(cmd: str) -> str | None:
|
||||||
|
try:
|
||||||
|
parts = shlex.split(cmd)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
install_index = parts.index("install")
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if install_index <= 0:
|
||||||
|
return None
|
||||||
|
pip_prefix = parts[:install_index]
|
||||||
|
return f"{shlex.join(pip_prefix + ['install', '--help'])} 2>/dev/null | grep -q -- --break-system-packages"
|
||||||
|
|
||||||
|
|
||||||
|
def _append_pip_install_runner_lines(runner_lines: list[str], cmd: str) -> None:
    """Append a pip install command, guarding --break-system-packages support.

    A command without the flag is appended as-is. A command with the flag is
    wrapped in a shell ``if``/``else`` that first probes whether pip accepts
    ``--break-system-packages`` and otherwise runs the same command with the
    flag removed. Older pip versions abort with "no such option" when given
    the flag, so the decision is made when the runner script executes.

    The command is also appended unchanged when no probe can be derived from
    it or when removing the flag leaves it identical.
    """
    if "--break-system-packages" in (cmd or ""):
        probe = _pip_install_help_check_from_cmd(cmd)
        fallback = _pip_install_command_without_break_system_packages(cmd)
        if probe and fallback != cmd:
            runner_lines.extend(
                [
                    f"if {probe}; then",
                    f" {cmd}",
                    "else",
                    ' echo "[odysseus] pip does not support --break-system-packages; installing without it."',
                    f" {fallback}",
                    "fi",
                ]
            )
            return
    runner_lines.append(cmd)
|
||||||
|
|
||||||
|
|
||||||
def _user_shell_path_bootstrap() -> list[str]:
|
def _user_shell_path_bootstrap() -> list[str]:
|
||||||
return [
|
return [
|
||||||
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
||||||
@@ -271,11 +347,14 @@ def _user_shell_path_bootstrap() -> list[str]:
|
|||||||
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
||||||
'fi',
|
'fi',
|
||||||
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
||||||
|
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
|
||||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
"""Build the standalone Python scanner used by /api/model/cached.
|
||||||
|
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
|
||||||
|
"""
|
||||||
lines = [
|
lines = [
|
||||||
"import json, os, re, shutil, subprocess, urllib.request",
|
"import json, os, re, shutil, subprocess, urllib.request",
|
||||||
"models = []",
|
"models = []",
|
||||||
@@ -338,6 +417,15 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||||
" if f.name.endswith('.incomplete'): ic = True",
|
" if f.name.endswith('.incomplete'): ic = True",
|
||||||
" snap = os.path.join(cache, d, 'snapshots')",
|
" snap = os.path.join(cache, d, 'snapshots')",
|
||||||
|
" # Windows HF cache stores files directly in snapshots/; blobs/ may be empty.",
|
||||||
|
" # Fallback: scan snapshots for real files when blobs yielded nothing.",
|
||||||
|
" if sz == 0 and os.path.isdir(snap):",
|
||||||
|
" for sd in os.listdir(snap):",
|
||||||
|
" sf = os.path.join(snap, sd)",
|
||||||
|
" if not os.path.isdir(sf): continue",
|
||||||
|
" for f in os.scandir(sf):",
|
||||||
|
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||||
|
" if f.name.endswith('.incomplete'): ic = True",
|
||||||
" is_diffusion = False; gguf_files = []",
|
" is_diffusion = False; gguf_files = []",
|
||||||
" if os.path.isdir(snap):",
|
" if os.path.isdir(snap):",
|
||||||
" for sd in os.listdir(snap):",
|
" for sd in os.listdir(snap):",
|
||||||
@@ -346,6 +434,21 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
||||||
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
||||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||||
|
"def hf_cache_paths():",
|
||||||
|
" candidates = []",
|
||||||
|
" def add(p):",
|
||||||
|
" if not p: return",
|
||||||
|
" p = os.path.expanduser(p)",
|
||||||
|
" if p not in candidates: candidates.append(p)",
|
||||||
|
" add(os.environ.get('HUGGINGFACE_HUB_CACHE'))",
|
||||||
|
" hf_home = os.environ.get('HF_HOME')",
|
||||||
|
" if hf_home: add(os.path.join(hf_home, 'hub'))",
|
||||||
|
" add('~/.cache/huggingface/hub')",
|
||||||
|
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
||||||
|
" # When HOME is /root, expanduser() misses that persisted cache.",
|
||||||
|
" add('/app/.cache/huggingface/hub')",
|
||||||
|
f" add({add_hf_cache!r})" if add_hf_cache else "",
|
||||||
|
" return candidates",
|
||||||
"def scan_dir(p):",
|
"def scan_dir(p):",
|
||||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||||
" for d in sorted(os.listdir(p)):",
|
" for d in sorted(os.listdir(p)):",
|
||||||
@@ -409,7 +512,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" seen.add(name)",
|
" seen.add(name)",
|
||||||
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
||||||
" return",
|
" return",
|
||||||
"scan_hf(os.path.expanduser('~/.cache/huggingface/hub'))",
|
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
|
||||||
"scan_ollama()",
|
"scan_ollama()",
|
||||||
"scan_ollama_api()",
|
"scan_ollama_api()",
|
||||||
]
|
]
|
||||||
@@ -525,6 +628,7 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
|||||||
# Backticks and raw newlines are never legitimate here.
|
# Backticks and raw newlines are never legitimate here.
|
||||||
if any(c in v for c in ("`", "\n", "\r")):
|
if any(c in v for c in ("`", "\n", "\r")):
|
||||||
raise HTTPException(400, "Invalid characters in cmd")
|
raise HTTPException(400, "Invalid characters in cmd")
|
||||||
|
|
||||||
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
||||||
m = _GGUF_PRELUDE_RE.match(v)
|
m = _GGUF_PRELUDE_RE.match(v)
|
||||||
if m:
|
if m:
|
||||||
@@ -533,9 +637,19 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
|||||||
for part in rest.split("||"):
|
for part in rest.split("||"):
|
||||||
_check_serve_binary(part.strip())
|
_check_serve_binary(part.strip())
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Otherwise: a single invocation — no shell metacharacters allowed.
|
# Otherwise: a single invocation — no shell metacharacters allowed.
|
||||||
|
# Temporarily replace safe $(printf %s ...) expressions with a placeholder
|
||||||
|
# to avoid triggering the metacharacter/command-injection checks.
|
||||||
|
cleaned_v = v
|
||||||
|
printf_matches = list(re.finditer(r"\$\(\s*printf\s+%s\s+([^\n()]*?)\)", v))
|
||||||
|
for match in printf_matches:
|
||||||
|
inner = match.group(1)
|
||||||
|
if not any(c in inner for c in (";", "&&", "||", "$(", "`")):
|
||||||
|
cleaned_v = cleaned_v.replace(match.group(0), "/placeholder/safe/path.gguf")
|
||||||
|
|
||||||
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
||||||
if any(c in v for c in (";", "&&", "||", "$(")):
|
if any(c in cleaned_v for c in (";", "&&", "||", "$(")):
|
||||||
raise HTTPException(400, "Invalid characters in cmd")
|
raise HTTPException(400, "Invalid characters in cmd")
|
||||||
_check_serve_binary(v)
|
_check_serve_binary(v)
|
||||||
return v
|
return v
|
||||||
@@ -559,6 +673,21 @@ def _append_serve_preflight_exit_lines(runner_lines: list[str], *, keep_shell_op
|
|||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
|
|
||||||
|
|
||||||
|
def _append_vllm_linux_preflight_lines(runner_lines: list[str]) -> None:
|
||||||
|
"""Append Linux vLLM readiness lines that identify the runtime being used."""
|
||||||
|
# Keep the user install bin visible for Odysseus-managed `pip install --user`
|
||||||
|
# installs, but then report the actual CLI path so external runtimes are clear.
|
||||||
|
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||||
|
runner_lines.append('ODYSSEUS_VLLM_BIN="$(command -v vllm 2>/dev/null || true)"')
|
||||||
|
runner_lines.append('if [ -z "$ODYSSEUS_VLLM_BIN" ]; then')
|
||||||
|
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||||
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
|
runner_lines.append('else')
|
||||||
|
runner_lines.append(' echo "[odysseus] vLLM CLI: $ODYSSEUS_VLLM_BIN"')
|
||||||
|
runner_lines.append(' ODYSSEUS_VLLM_VERSION="$("$ODYSSEUS_VLLM_BIN" --version 2>&1 | head -n 1 || true)"')
|
||||||
|
runner_lines.append(' if [ -n "$ODYSSEUS_VLLM_VERSION" ]; then echo "[odysseus] vLLM version: $ODYSSEUS_VLLM_VERSION"; fi')
|
||||||
|
runner_lines.append('fi')
|
||||||
|
|
||||||
def _append_serve_exit_code_lines(
|
def _append_serve_exit_code_lines(
|
||||||
runner_lines: list[str],
|
runner_lines: list[str],
|
||||||
*,
|
*,
|
||||||
@@ -804,3 +933,172 @@ def _ssh_ps(host, script_path, port=None):
|
|||||||
|
|
||||||
# Windows session dir — stored in user's temp on the remote
|
# Windows session dir — stored in user's temp on the remote
|
||||||
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
||||||
|
|
||||||
|
|
||||||
|
def _diagnose_serve_output(text: str) -> dict | None:
|
||||||
|
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||||
|
|
||||||
|
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||||
|
the agent/tool path the same structured signal so it can retry with an
|
||||||
|
adjusted command instead of guessing from raw tmux output.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
tail = text[-6000:]
|
||||||
|
patterns = [
|
||||||
|
(
|
||||||
|
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||||
|
"No GPU memory left for KV cache after loading model.",
|
||||||
|
[
|
||||||
|
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||||
|
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||||
|
"GPU ran out of memory during startup or warmup.",
|
||||||
|
[
|
||||||
|
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||||
|
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||||
|
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"not divisib|must be divisible|attention heads.*divisible",
|
||||||
|
"Tensor parallel size is incompatible with the model.",
|
||||||
|
[
|
||||||
|
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||||
|
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||||
|
"Context length is too large for available GPU memory.",
|
||||||
|
[
|
||||||
|
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||||
|
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||||
|
"Auto tool choice requires an explicit tool call parser.",
|
||||||
|
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||||
|
"Model requires custom code or newer model support.",
|
||||||
|
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"There is no module or parameter named ['\"]lm_head\.input_scale['\"]|lm_head\.input_scale|weight_scale_2",
|
||||||
|
"vLLM cannot load this ModelOpt LM-head quantized checkpoint with the current runtime.",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"label": "upgrade vLLM through the environment that provides this CLI, or use a compatible checkpoint",
|
||||||
|
"op": "manual",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||||
|
"vLLM/Transformers kernel package mismatch.",
|
||||||
|
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Address already in use|bind.*address.*in use",
|
||||||
|
"Port is already in use.",
|
||||||
|
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||||
|
"No GPUs are visible to the serve process.",
|
||||||
|
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||||
|
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||||
|
"This machine may have integrated or unsupported graphics only.",
|
||||||
|
[
|
||||||
|
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||||
|
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||||
|
"vLLM is not installed or not in PATH on this server.",
|
||||||
|
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||||
|
"SGLang is not installed or not in PATH on this server.",
|
||||||
|
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||||
|
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||||
|
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||||
|
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||||
|
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||||
|
"Diffusion serving requires PyTorch and diffusers.",
|
||||||
|
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||||
|
"Model access is gated or unauthorized.",
|
||||||
|
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for pattern, message, suggestions in patterns:
|
||||||
|
if re.search(pattern, tail, re.I):
|
||||||
|
return {"message": message, "suggestions": suggestions}
|
||||||
|
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||||
|
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"message": "Python traceback detected during serve startup.",
|
||||||
|
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def run_ssh_command_async(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    stdin_data: bytes | None = None,
) -> tuple[int, bytes, bytes]:
    """Run an ssh command with centralized timeout and stderr/stdout capture.

    Async version of core.platform_compat.run_ssh_command_sync.

    Args:
        remote: SSH destination, forwarded to ``_ssh_exec_argv``.
        ssh_port: Optional port, forwarded to ``_ssh_exec_argv``.
        remote_cmd: Command to run on the remote side.
        timeout: Seconds allowed for the whole exchange (stdin write plus
            reading stdout/stderr until the process exits).
        connect_timeout: Forwarded to ``_ssh_exec_argv``.
        strict_host_key_checking: Forwarded to ``_ssh_exec_argv``.
        stdin_data: Bytes written to the process's stdin. When ``None`` no
            stdin pipe is created.

    Returns:
        ``(returncode, stdout, stderr)``; a missing return code is reported
        as ``0``.

    Raises:
        asyncio.TimeoutError: The command did not finish within ``timeout``.
            The child is killed and reaped before the error propagates.
    """
    import asyncio

    proc = await asyncio.create_subprocess_exec(
        *_ssh_exec_argv(
            remote,
            ssh_port,
            remote_cmd=remote_cmd,
            connect_timeout=connect_timeout,
            strict_host_key_checking=strict_host_key_checking,
        ),
        # Only open a stdin pipe when there is something to send.
        stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input=stdin_data), timeout=timeout
        )
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and the kill;
            # nothing left to signal, and the timeout must still surface.
            pass
        # Drain the pipes and reap the child so no zombie is left behind.
        await proc.communicate()
        raise
    return proc.returncode or 0, stdout, stderr
|
||||||
|
|||||||
+189
-206
@@ -15,19 +15,26 @@ from pathlib import Path
|
|||||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||||
|
|
||||||
from src.auth_helpers import require_user
|
from src.auth_helpers import require_user
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
pid_alive,
|
pid_alive,
|
||||||
safe_chmod,
|
safe_chmod,
|
||||||
which_tool,
|
which_tool,
|
||||||
|
translate_path,
|
||||||
|
get_wsl_windows_user_profile,
|
||||||
)
|
)
|
||||||
from routes.shell_routes import TMUX_LOG_DIR
|
from routes.shell_routes import TMUX_LOG_DIR
|
||||||
|
from src.constants import COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,8 +45,10 @@ from routes.cookbook_helpers import (
|
|||||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||||
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache,
|
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
||||||
_user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||||
|
_append_pip_install_runner_lines,
|
||||||
|
_diagnose_serve_output, run_ssh_command_async,
|
||||||
ModelDownloadRequest, ServeRequest,
|
ModelDownloadRequest, ServeRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +63,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
|
|||||||
|
|
||||||
def setup_cookbook_routes() -> APIRouter:
|
def setup_cookbook_routes() -> APIRouter:
|
||||||
router = APIRouter(tags=["cookbook"])
|
router = APIRouter(tags=["cookbook"])
|
||||||
_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
|
||||||
|
|
||||||
def _mask_secret(value: str) -> str:
|
def _mask_secret(value: str) -> str:
|
||||||
if not value:
|
if not value:
|
||||||
@@ -81,127 +90,6 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
task["payload"].pop("hf_token", None)
|
task["payload"].pop("hf_token", None)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def _diagnose_serve_output(text: str) -> dict | None:
|
|
||||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
|
||||||
|
|
||||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
|
||||||
the agent/tool path the same structured signal so it can retry with an
|
|
||||||
adjusted command instead of guessing from raw tmux output.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
tail = text[-6000:]
|
|
||||||
patterns = [
|
|
||||||
(
|
|
||||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
|
||||||
"No GPU memory left for KV cache after loading model.",
|
|
||||||
[
|
|
||||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
|
||||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
|
||||||
"GPU ran out of memory during startup or warmup.",
|
|
||||||
[
|
|
||||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
|
||||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
|
||||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"not divisib|must be divisible|attention heads.*divisible",
|
|
||||||
"Tensor parallel size is incompatible with the model.",
|
|
||||||
[
|
|
||||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
|
||||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
|
||||||
"Context length is too large for available GPU memory.",
|
|
||||||
[
|
|
||||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
|
||||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
|
||||||
"Auto tool choice requires an explicit tool call parser.",
|
|
||||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
|
||||||
"Model requires custom code or newer model support.",
|
|
||||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
|
||||||
"vLLM/Transformers kernel package mismatch.",
|
|
||||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Address already in use|bind.*address.*in use",
|
|
||||||
"Port is already in use.",
|
|
||||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
|
||||||
"No GPUs are visible to the serve process.",
|
|
||||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
|
||||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
|
||||||
"This machine may have integrated or unsupported graphics only.",
|
|
||||||
[
|
|
||||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
|
||||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
|
||||||
],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
|
||||||
"vLLM is not installed or not in PATH on this server.",
|
|
||||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
|
||||||
"SGLang is not installed or not in PATH on this server.",
|
|
||||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
|
||||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
|
||||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
|
||||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
|
||||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
|
||||||
"Diffusion serving requires PyTorch and diffusers.",
|
|
||||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
|
||||||
"Model access is gated or unauthorized.",
|
|
||||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for pattern, message, suggestions in patterns:
|
|
||||||
if re.search(pattern, tail, re.I):
|
|
||||||
return {"message": message, "suggestions": suggestions}
|
|
||||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
|
||||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
|
||||||
):
|
|
||||||
return {
|
|
||||||
"message": "Python traceback detected during serve startup.",
|
|
||||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _state_for_client(state):
|
def _state_for_client(state):
|
||||||
"""Return cookbook state without raw secrets for browser clients."""
|
"""Return cookbook state without raw secrets for browser clients."""
|
||||||
_strip_task_secrets(state)
|
_strip_task_secrets(state)
|
||||||
@@ -295,6 +183,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
||||||
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
||||||
|
|
||||||
|
|
||||||
def _needs_binary(cmd: str, binary: str) -> bool:
|
def _needs_binary(cmd: str, binary: str) -> bool:
|
||||||
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
||||||
|
|
||||||
@@ -355,8 +244,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# POSIX form + shell-quoting so drive paths / spaces survive.
|
# POSIX form + shell-quoting so drive paths / spaces survive.
|
||||||
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
||||||
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
||||||
lp = shlex.quote(log_path.as_posix())
|
lp = shlex.quote(git_bash_path(log_path))
|
||||||
ip = shlex.quote(inner.as_posix())
|
ip = shlex.quote(git_bash_path(inner))
|
||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"bash {ip} > {lp} 2>&1\n",
|
f"bash {ip} > {lp} 2>&1\n",
|
||||||
@@ -472,6 +361,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines = []
|
ps_lines = []
|
||||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||||
|
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||||
|
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||||
if req.hf_token:
|
if req.hf_token:
|
||||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||||
if req.env_prefix:
|
if req.env_prefix:
|
||||||
@@ -545,7 +436,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
||||||
# hf_transfer because the Rust parallel path is fast but has been
|
# hf_transfer because the Rust parallel path is fast but has been
|
||||||
# flaky near the end of very large multi-file downloads.
|
# flaky near the end of very large multi-file downloads.
|
||||||
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
|
# The helper tries active pip first, then guarded user-site fallbacks.
|
||||||
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
||||||
if req.disable_hf_transfer:
|
if req.disable_hf_transfer:
|
||||||
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||||
@@ -673,24 +564,35 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
for d in model_dir.split(','):
|
for d in model_dir.split(','):
|
||||||
d = d.strip()
|
d = d.strip()
|
||||||
if d:
|
if d:
|
||||||
model_dirs.append(d)
|
translated_d = translate_path(d) if not host else d
|
||||||
paths_code = _cached_model_scan_script(model_dirs)
|
model_dirs.append(translated_d)
|
||||||
|
win_hf_hub = None
|
||||||
|
if not host:
|
||||||
|
win_profile = get_wsl_windows_user_profile()
|
||||||
|
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
|
||||||
|
|
||||||
|
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
|
||||||
|
|
||||||
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
||||||
scan_py.write_text(paths_code, encoding="utf-8")
|
scan_py.write_text(paths_code, encoding="utf-8")
|
||||||
|
scan_payload = scan_py.read_bytes()
|
||||||
|
|
||||||
if host:
|
if host:
|
||||||
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
|
||||||
if platform == "windows":
|
if platform == "windows":
|
||||||
# Windows: use 'python' and pipe via stdin with double-quote wrapping
|
remote_cmd = "python -"
|
||||||
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
|
|
||||||
else:
|
else:
|
||||||
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
|
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
|
||||||
proc = await asyncio.create_subprocess_shell(
|
remote_cmd = (
|
||||||
cmd,
|
"if command -v python3 >/dev/null 2>&1; then python3 -; "
|
||||||
stdout=asyncio.subprocess.PIPE,
|
"elif command -v python >/dev/null 2>&1; then python -; "
|
||||||
stderr=asyncio.subprocess.PIPE,
|
"else echo \"python3/python not found\" >&2; exit 127; fi"
|
||||||
cwd=str(Path.home()),
|
)
|
||||||
|
rc, stdout_b, stderr_b = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
remote_cmd,
|
||||||
|
timeout=60,
|
||||||
|
stdin_data=scan_payload,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||||
@@ -710,7 +612,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=str(Path.home()),
|
cwd=str(Path.home()),
|
||||||
)
|
)
|
||||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
try:
|
try:
|
||||||
@@ -915,6 +817,10 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
existing.name = display_name
|
existing.name = display_name
|
||||||
if supports_tools is not None:
|
if supports_tools is not None:
|
||||||
existing.supports_tools = supports_tools
|
existing.supports_tools = supports_tools
|
||||||
|
# Wipe stale model lists so the picker re-probes and discovers
|
||||||
|
# the newly-served model instead of showing the old one.
|
||||||
|
existing.cached_models = None
|
||||||
|
existing.hidden_models = None
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Updated existing local model endpoint: {base_url}")
|
logger.info(f"Updated existing local model endpoint: {base_url}")
|
||||||
return existing.id
|
return existing.id
|
||||||
@@ -971,11 +877,27 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
in_venv=sys.prefix != sys.base_prefix,
|
in_venv=sys.prefix != sys.base_prefix,
|
||||||
)
|
)
|
||||||
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
||||||
|
remote = req.remote_host
|
||||||
|
is_windows = req.platform == "windows"
|
||||||
|
local_windows = IS_WINDOWS and not remote
|
||||||
|
if is_windows or local_windows:
|
||||||
|
if req.cmd.startswith("python3 "):
|
||||||
|
req.cmd = "python " + req.cmd[len("python3 "):]
|
||||||
|
if is_pip_install and ("llama-cpp-python" in req.cmd or "llama_cpp" in req.cmd) and (is_windows or local_windows):
|
||||||
|
if "--extra-index-url" not in req.cmd:
|
||||||
|
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
|
|
||||||
if is_pip_install:
|
if is_pip_install:
|
||||||
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
||||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||||
# and leave the dep installed-but-unusable (#1459).
|
# and leave the dep installed-but-unusable (#1459).
|
||||||
req.cmd = _pip_install_no_cache(req.cmd)
|
req.cmd = _pip_install_no_cache(req.cmd)
|
||||||
|
# Accept common aliases and enforce server extras for llama-cpp so
|
||||||
|
# `python -m llama_cpp.server` has all runtime dependencies.
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
|
||||||
|
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
|
||||||
|
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||||
# v2 review HIGH-14: tightened from the previous regex which
|
# v2 review HIGH-14: tightened from the previous regex which
|
||||||
@@ -1028,6 +950,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines = []
|
ps_lines = []
|
||||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||||
|
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||||
|
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||||
if req.hf_token:
|
if req.hf_token:
|
||||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||||
if req.gpus:
|
if req.gpus:
|
||||||
@@ -1046,7 +970,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
||||||
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
||||||
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
||||||
ps_lines.append(' python -m pip install llama-cpp-python[server]')
|
ps_lines.append(' python -m pip install llama-cpp-python[server] --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu')
|
||||||
ps_lines.append('}')
|
ps_lines.append('}')
|
||||||
elif "vllm" in req.cmd:
|
elif "vllm" in req.cmd:
|
||||||
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
||||||
@@ -1121,45 +1045,57 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# ollama is found (otherwise macOS falls back to a slow source build).
|
# ollama is found (otherwise macOS falls back to a slow source build).
|
||||||
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
||||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
if local_windows:
|
||||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
# LOCAL Windows: no native source compilation (no cmake/compiler on Git Bash).
|
||||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
# Just check python bindings (using native `python` binary) and fall back to pip install.
|
||||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
runner_lines.append(' echo "llama-server not found — installing Python bindings..."')
|
||||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='python')} || true")
|
||||||
runner_lines.append(' fi')
|
runner_lines.append('fi')
|
||||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install attempts."')
|
||||||
runner_lines.append(' mkdir -p ~/bin')
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
runner_lines.append('fi')
|
||||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
else:
|
||||||
# enables it automatically, no flag), CUDA on Linux when present,
|
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
# a prebuilt llama-server and skips this whole source build.)
|
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
runner_lines.append(' fi')
|
||||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
runner_lines.append(' mkdir -p ~/bin')
|
||||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
# enables it automatically, no flag), CUDA on Linux when present,
|
||||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||||
runner_lines.append(' else')
|
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
# a prebuilt llama-server and skips this whole source build.)
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||||
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
|
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||||
runner_lines.append(' fi')
|
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||||
runner_lines.append('fi')
|
runner_lines.append(' else')
|
||||||
|
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
# If the native build failed, fall back to the Python bindings.
|
||||||
|
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
|
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||||
|
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||||
|
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||||
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||||
|
runner_lines.append(' fi')
|
||||||
|
runner_lines.append('fi')
|
||||||
elif "ollama" in req.cmd:
|
elif "ollama" in req.cmd:
|
||||||
handled_ollama_serve = True
|
handled_ollama_serve = True
|
||||||
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||||
@@ -1181,13 +1117,23 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
|
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
|
||||||
runner_lines.append(' break')
|
runner_lines.append(' break')
|
||||||
runner_lines.append(' fi')
|
runner_lines.append(' fi')
|
||||||
runner_lines.append(' exec 3<&-; exec 3>&-')
|
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
|
||||||
runner_lines.append('done')
|
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
|
||||||
|
if local_windows:
|
||||||
|
# Windows detached process has no TTY; exec bash -i crashes.
|
||||||
|
# Keep the monitoring task alive with a sleep loop.
|
||||||
|
runner_lines.append(' while true; do sleep 60; done')
|
||||||
|
else:
|
||||||
|
runner_lines.append(' exec bash -i')
|
||||||
|
runner_lines.append('fi')
|
||||||
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
||||||
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
|
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
|
||||||
runner_lines.append(' echo')
|
runner_lines.append(' echo')
|
||||||
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
||||||
runner_lines.append(' exec bash -i')
|
if local_windows:
|
||||||
|
runner_lines.append(' exit 127')
|
||||||
|
else:
|
||||||
|
runner_lines.append(' exec bash -i')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
||||||
if remote and _ollama_host in ("0.0.0.0", "::"):
|
if remote and _ollama_host in ("0.0.0.0", "::"):
|
||||||
@@ -1195,24 +1141,20 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
||||||
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||||
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||||
runner_lines.append('_ody_exit=$?')
|
if local_windows:
|
||||||
runner_lines.append('echo')
|
_append_serve_exit_code_lines(runner_lines, keep_shell_open=False)
|
||||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
else:
|
||||||
runner_lines.append('exec bash -i')
|
runner_lines.append('_ody_exit=$?')
|
||||||
|
runner_lines.append('echo')
|
||||||
|
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||||
|
runner_lines.append('exec bash -i')
|
||||||
elif "vllm serve" in req.cmd:
|
elif "vllm serve" in req.cmd:
|
||||||
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
||||||
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
||||||
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
||||||
runner_lines.append('fi')
|
runner_lines.append('fi')
|
||||||
# Put ~/.local/bin on PATH first — without a venv, vllm installs
|
_append_vllm_linux_preflight_lines(runner_lines)
|
||||||
# there via --user and the non-login serve shell otherwise can't
|
|
||||||
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
|
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
|
||||||
runner_lines.append('if ! command -v vllm &>/dev/null; then')
|
|
||||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
|
||||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
|
||||||
runner_lines.append('fi')
|
|
||||||
elif "sglang.launch_server" in req.cmd:
|
elif "sglang.launch_server" in req.cmd:
|
||||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||||
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
||||||
@@ -1236,7 +1178,10 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
runner_lines,
|
runner_lines,
|
||||||
keep_shell_open=not local_windows,
|
keep_shell_open=not local_windows,
|
||||||
)
|
)
|
||||||
runner_lines.append(req.cmd)
|
if is_pip_install:
|
||||||
|
_append_pip_install_runner_lines(runner_lines, req.cmd)
|
||||||
|
else:
|
||||||
|
runner_lines.append(req.cmd)
|
||||||
if local_windows:
|
if local_windows:
|
||||||
# Detached background process — no interactive shell to keep open.
|
# Detached background process — no interactive shell to keep open.
|
||||||
# Print the exit marker the status poller looks for, then stop.
|
# Print the exit marker the status poller looks for, then stop.
|
||||||
@@ -1397,8 +1342,8 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||||
else:
|
else:
|
||||||
# Linux: auto-install tmux (via whichever package manager is available)
|
# Linux: auto-install tmux (via whichever package manager is available)
|
||||||
# and huggingface_hub + hf_transfer (falling back to --user/--break-system-packages
|
# and huggingface_hub + hf_transfer (falling back to --user, then
|
||||||
# on PEP-668 locked distros like Arch / newer Debian).
|
# guarded --break-system-packages on PEP-668 locked distros).
|
||||||
setup_script = (
|
setup_script = (
|
||||||
# Install tmux if missing — try common package managers; skip if no sudo
|
# Install tmux if missing — try common package managers; skip if no sudo
|
||||||
"if ! command -v tmux >/dev/null 2>&1; then "
|
"if ! command -v tmux >/dev/null 2>&1; then "
|
||||||
@@ -1410,10 +1355,15 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
" fi; "
|
" fi; "
|
||||||
"fi; "
|
"fi; "
|
||||||
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
||||||
# Install Python bits. Try system install first; fall back to --user --break-system-packages on PEP 668 systems.
|
# Install Python bits. Try system install first; fall back to --user,
|
||||||
|
# then use --break-system-packages only when pip supports it.
|
||||||
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null || "
|
"pip install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null; "
|
"( pip install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||||
|
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ) || "
|
||||||
|
"pip3 install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||||
|
"( pip3 install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||||
|
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ); "
|
||||||
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
||||||
)
|
)
|
||||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||||
@@ -1436,11 +1386,38 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
||||||
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
||||||
if host:
|
if host:
|
||||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
candidates = [query]
|
||||||
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
|
stripped = query.strip()
|
||||||
proc = await asyncio.create_subprocess_shell(
|
if stripped.startswith("nvidia-smi "):
|
||||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
args = stripped[len("nvidia-smi "):]
|
||||||
)
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ shlex.quote(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args}")
|
||||||
|
|
||||||
|
last_err = "nvidia-smi failed"
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
rc, stdout, stderr = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
candidate,
|
||||||
|
connect_timeout=5,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return None, "nvidia-smi timed out"
|
||||||
|
if rc == 0:
|
||||||
|
return stdout.decode("utf-8", errors="replace"), None
|
||||||
|
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
|
||||||
|
if err:
|
||||||
|
last_err = err
|
||||||
|
return None, last_err
|
||||||
else:
|
else:
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*shlex.split(query),
|
*shlex.split(query),
|
||||||
@@ -2203,7 +2180,13 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
||||||
"sys.exit(0 if ok and not inc else 1)"
|
"sys.exit(0 if ok and not inc else 1)"
|
||||||
)
|
)
|
||||||
cmd = ["python3", "-c", py, repo_id]
|
if remote_host:
|
||||||
|
cmd = ["python3", "-c", py, repo_id]
|
||||||
|
else:
|
||||||
|
# Local Windows: python3 can hit the Microsoft Store stub. Use the
|
||||||
|
# real Python Odysseus is running under (guaranteed to exist).
|
||||||
|
import sys as _sys_local
|
||||||
|
cmd = [_sys_local.executable, "-c", py, repo_id]
|
||||||
try:
|
try:
|
||||||
if remote_host:
|
if remote_host:
|
||||||
ssh_base = ["ssh"]
|
ssh_base = ["ssh"]
|
||||||
|
|||||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Request, Form, HTTPException
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from core.database import SessionLocal, ModelEndpoint
|
from core.database import SessionLocal, ModelEndpoint
|
||||||
from core.middleware import require_admin
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
from src import copilot
|
from src import copilot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
|
||||||
# browser. Entries expire with the GitHub device code.
|
|
||||||
#
|
|
||||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
|
||||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
|
||||||
# on a worker that never saw the start, returning "Unknown or expired login
|
|
||||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
|
||||||
_PENDING: Dict[str, Dict] = {}
|
|
||||||
_PENDING_LOCK = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_expired() -> None:
|
|
||||||
now = time.time()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
|
||||||
_PENDING.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def setup_copilot_routes() -> APIRouter:
|
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
host = copilot.GITHUB_HOST
|
||||||
|
ent = str(form.get("enterprise_url") or "").strip()
|
||||||
|
if ent:
|
||||||
|
host = copilot.normalize_domain(ent)
|
||||||
|
try:
|
||||||
|
data = copilot.request_device_code(host)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
status = e.response.status_code if e.response is not None else "unknown"
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||||
|
|
||||||
@router.post("/device/start")
|
device_code = data.get("device_code")
|
||||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
if not device_code:
|
||||||
require_admin(request)
|
raise HTTPException(502, "GitHub did not return a device code")
|
||||||
_prune_expired()
|
|
||||||
host = copilot.GITHUB_HOST
|
|
||||||
ent = (enterprise_url or "").strip()
|
|
||||||
if ent:
|
|
||||||
host = copilot.normalize_domain(ent)
|
|
||||||
try:
|
|
||||||
data = copilot.request_device_code(host)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
status = e.response.status_code if e.response is not None else "unknown"
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
|
||||||
|
|
||||||
device_code = data.get("device_code")
|
# verification_uri_complete embeds the user code, so the browser tab we
|
||||||
if not device_code:
|
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||||
raise HTTPException(502, "GitHub did not return a device code")
|
# code pre-filled — one click, no manual code entry.
|
||||||
interval = int(data.get("interval") or 5)
|
return DeviceFlowStart(
|
||||||
expires_in = int(data.get("expires_in") or 900)
|
pending={
|
||||||
poll_id = uuid.uuid4().hex
|
"device_code": device_code,
|
||||||
with _PENDING_LOCK:
|
"host": host,
|
||||||
_PENDING[poll_id] = {
|
"enterprise_url": ent,
|
||||||
"device_code": device_code,
|
"owner": get_current_user(request) or None,
|
||||||
"host": host,
|
},
|
||||||
"enterprise_url": ent,
|
response={
|
||||||
"interval": interval,
|
|
||||||
"owner": get_current_user(request) or None,
|
|
||||||
"expires_at": time.time() + expires_in,
|
|
||||||
"next_poll_at": 0.0,
|
|
||||||
}
|
|
||||||
# verification_uri_complete embeds the user code, so the browser tab we
|
|
||||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
|
||||||
# code pre-filled — one click, no manual code entry.
|
|
||||||
return {
|
|
||||||
"poll_id": poll_id,
|
|
||||||
"user_code": data.get("user_code"),
|
"user_code": data.get("user_code"),
|
||||||
"verification_uri": data.get("verification_uri"),
|
"verification_uri": data.get("verification_uri"),
|
||||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||||
"interval": interval,
|
},
|
||||||
"expires_in": expires_in,
|
interval=int(data.get("interval") or 5),
|
||||||
}
|
expires_in=int(data.get("expires_in") or 900),
|
||||||
|
)
|
||||||
|
|
||||||
@router.post("/device/poll")
|
|
||||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
|
||||||
require_admin(request)
|
|
||||||
_prune_expired()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
pending = _PENDING.get(poll_id)
|
|
||||||
if not pending:
|
|
||||||
raise HTTPException(404, "Unknown or expired login session")
|
|
||||||
|
|
||||||
# Enforce GitHub's polling interval server-side so a chatty client
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||||
# can't trip slow_down.
|
try:
|
||||||
now = time.time()
|
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||||
if now < pending.get("next_poll_at", 0):
|
except Exception as e:
|
||||||
return {"status": "pending"}
|
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||||
|
|
||||||
|
token = data.get("access_token")
|
||||||
|
if token:
|
||||||
|
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||||
try:
|
try:
|
||||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
result = _provision_endpoint(token, base, pending["owner"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
logger.exception("Copilot endpoint provisioning failed")
|
||||||
|
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||||
|
return DeviceFlowPoll.authorized(result)
|
||||||
|
|
||||||
token = data.get("access_token")
|
err = data.get("error")
|
||||||
if token:
|
if err == "authorization_pending":
|
||||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
return DeviceFlowPoll.pending()
|
||||||
try:
|
if err == "slow_down":
|
||||||
result = _provision_endpoint(token, base, pending["owner"])
|
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||||
except Exception as e:
|
if err in ("expired_token", "access_denied"):
|
||||||
logger.exception("Copilot endpoint provisioning failed")
|
return DeviceFlowPoll.failed(err)
|
||||||
with _PENDING_LOCK:
|
# Unknown error — surface but keep the session for another try.
|
||||||
_PENDING.pop(poll_id, None)
|
return DeviceFlowPoll.pending(err or "unknown")
|
||||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "authorized", "endpoint": result}
|
|
||||||
|
|
||||||
err = data.get("error")
|
|
||||||
if err == "authorization_pending":
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err == "slow_down":
|
|
||||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["interval"] = new_interval
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err in ("expired_token", "access_denied"):
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "failed", "error": err}
|
|
||||||
# Unknown error — surface but keep the session for another try.
|
|
||||||
return {"status": "pending", "detail": err or "unknown"}
|
|
||||||
|
|
||||||
@router.post("/device/cancel")
|
def setup_copilot_routes():
|
||||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
return create_device_flow_router(
|
||||||
require_admin(request)
|
prefix="/api/copilot",
|
||||||
with _PENDING_LOCK:
|
tags=["copilot"],
|
||||||
_PENDING.pop(poll_id, None)
|
store=_DEVICE_FLOW_STORE,
|
||||||
return {"status": "cancelled"}
|
start_flow=_start_device_flow,
|
||||||
|
poll_flow=_poll_device_flow,
|
||||||
return router
|
)
|
||||||
|
|||||||
@@ -0,0 +1,193 @@
|
|||||||
|
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
|
|
||||||
|
from core.middleware import require_admin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DeviceFlowStart:
    """Immutable result of a provider's "start" step.

    A provider-specific start callback builds one of these and hands it to
    the shared device-flow route wrapper.
    """

    # Provider-defined state for the pending login (e.g. a device code).
    # NOTE(review): presumably kept server-side and not sent to the client;
    # confirm against the router that consumes this object.
    pending: Mapping[str, Any]
    # Provider-defined fields describing the login to the caller.
    response: Mapping[str, Any]
    # Polling interval; presumably seconds -- confirm against the router.
    interval: int = 5
    # Lifetime of the pending login; presumably seconds -- confirm as above.
    expires_in: int = 900
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DeviceFlowPoll:
    """Immutable, provider-independent description of one poll attempt.

    ``status`` is one of the strings produced by the constructors below:
    ``"pending"``, ``"slow_down"``, ``"authorized"`` or ``"failed"``. The
    remaining fields are optional and default to ``None``.
    """

    status: str
    # Set by ``authorized``: the value the provider returned for the login.
    endpoint: Optional[Mapping[str, Any]] = None
    # Set by ``failed``: the provider's error identifier.
    error: Optional[str] = None
    # Optional free-form note attached to ``pending`` / ``slow_down``.
    detail: Optional[str] = None
    # Optional polling interval attached to ``slow_down``.
    interval: Optional[int] = None

    @classmethod
    def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Build a ``"pending"`` outcome, optionally carrying a detail note."""
        outcome = cls("pending", detail=detail)
        return outcome

    @classmethod
    def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Build a ``"slow_down"`` outcome with an optional interval and note."""
        outcome = cls("slow_down", interval=interval, detail=detail)
        return outcome

    @classmethod
    def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
        """Build an ``"authorized"`` outcome wrapping ``endpoint``."""
        outcome = cls("authorized", endpoint=endpoint)
        return outcome

    @classmethod
    def failed(cls, error: str) -> "DeviceFlowPoll":
        """Build a ``"failed"`` outcome carrying the error identifier."""
        outcome = cls("failed", error=error)
        return outcome
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDeviceFlowStore:
    """Lock-guarded, in-memory registry of device flows awaiting completion.

    Entries live only in a dict on this instance. Each one keeps the
    provider's payload under ``"payload"`` and the polling metadata
    (``interval``, ``expires_at``, ``next_poll_at``) beside it, so
    ``get_payload`` returns nothing but the provider's own fields.
    """

    def __init__(self, *, time_func: Callable[[], float] = time.time):
        # The clock is injectable; every timestamp below is read through it.
        self._time = time_func
        self._lock = threading.Lock()
        self._pending: dict[str, dict[str, Any]] = {}

    def _now(self) -> float:
        """Return the current clock reading as a float."""
        return float(self._time())

    def prune_expired(self) -> None:
        """Remove every entry whose ``expires_at`` lies before the current time."""
        cutoff = self._now()
        with self._lock:
            stale_ids = [
                poll_id
                for poll_id, entry in self._pending.items()
                if entry.get("expires_at", 0) < cutoff
            ]
            for poll_id in stale_ids:
                self._pending.pop(poll_id, None)

    def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
        """Register a new flow and return its freshly generated poll id.

        Falsy ``interval`` / ``expires_in`` fall back to 5 and 900; both are
        clamped to at least 1. Expired entries are pruned first.
        """
        self.prune_expired()
        poll_id = uuid.uuid4().hex
        with self._lock:
            entry: dict[str, Any] = {
                "payload": dict(payload),
                "interval": max(int(interval or 5), 1),
                "expires_at": self._now() + max(int(expires_in or 900), 1),
                # 0.0 means the first poll is never throttled.
                "next_poll_at": 0.0,
            }
            self._pending[poll_id] = entry
        return poll_id

    def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
        """Return a shallow copy of the stored payload, or ``None`` if unknown/expired."""
        self.prune_expired()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is not None:
                return dict(entry.get("payload") or {})
            return None

    def is_throttled(self, poll_id: str) -> bool:
        """Tell whether ``poll_id`` exists and its next allowed poll is still ahead."""
        with self._lock:
            entry = self._pending.get(poll_id)
            if not entry:
                return False
            return self._now() < float(entry.get("next_poll_at") or 0)

    def schedule_next(self, poll_id: str) -> None:
        """Push the next allowed poll one interval into the future."""
        started = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            entry["next_poll_at"] = started + int(entry.get("interval") or 5)

    def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
        """Lengthen the poll interval and reschedule the next allowed poll.

        A truthy ``interval`` replaces the stored one; otherwise the stored
        interval grows by 5. The result is clamped to at least 1.
        """
        started = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            requested = int(interval or (int(entry.get("interval") or 5) + 5))
            entry["interval"] = max(requested, 1)
            entry["next_poll_at"] = started + entry["interval"]

    def pop(self, poll_id: str) -> None:
        """Forget ``poll_id``; unknown ids are ignored."""
        with self._lock:
            self._pending.pop(poll_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_await(value: Any) -> Any:
|
||||||
|
if inspect.isawaitable(value):
|
||||||
|
return await value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
|
||||||
|
response: dict[str, Any] = {"status": "pending"}
|
||||||
|
if detail:
|
||||||
|
response["detail"] = detail
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_device_flow_router(
    *,
    prefix: str,
    tags: Iterable[str],
    store: PendingDeviceFlowStore,
    start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
    poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
) -> APIRouter:
    """Build a router with the standard `/device/start|poll|cancel` routes.

    Args:
        prefix: Path prefix passed to the ``APIRouter``.
        tags: Tags for the router; materialized into a list.
        store: Pending-flow store shared by the three routes.
        start_flow: Provider callback invoked with the request and its form
            data; its result may be awaitable.
        poll_flow: Provider callback invoked with the request and the stored
            payload; its result may be awaitable.

    Returns:
        The configured router. Every route calls ``require_admin`` on the
        request before doing anything else.
    """
    router = APIRouter(prefix=prefix, tags=list(tags))

    @router.post("/device/start")
    async def device_start(request: Request):
        """Start a flow, remember its pending payload, and return poll details."""
        require_admin(request)
        form_data = await request.form()
        started = await _maybe_await(start_flow(request, form_data))
        poll_interval = int(started.interval or 5)
        lifetime = int(started.expires_in or 900)
        poll_id = store.add(started.pending, interval=poll_interval, expires_in=lifetime)
        # Provider fields come first; poll metadata overrides any clashing keys.
        body = dict(started.response)
        body["poll_id"] = poll_id
        body["interval"] = poll_interval
        body["expires_in"] = lifetime
        return body

    @router.post("/device/poll")
    async def device_poll(request: Request, poll_id: str = Form(...)):
        """Poll the provider once, unless the flow is unknown or throttled."""
        require_admin(request)
        stored = store.get_payload(poll_id)
        if stored is None:
            raise HTTPException(404, "Unknown or expired login session")
        if store.is_throttled(poll_id):
            # Too early for another provider call; answer without polling.
            return {"status": "pending"}

        try:
            outcome = await _maybe_await(poll_flow(request, stored))
        except Exception:
            # A provider error ends the flow: drop the entry, then propagate.
            store.pop(poll_id)
            raise

        status = outcome.status
        if status == "authorized":
            store.pop(poll_id)
            return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
        if status == "failed":
            store.pop(poll_id)
            return {"status": "failed", "error": outcome.error or "denied"}

        # Anything else keeps the flow alive; only the rescheduling differs.
        if status == "slow_down":
            store.slow_down(poll_id, outcome.interval)
        else:
            store.schedule_next(poll_id)
        return _pending_response(outcome.detail)

    @router.post("/device/cancel")
    def device_cancel(request: Request, poll_id: str = Form(...)):
        """Discard the pending flow; unknown ids are cancelled silently."""
        require_admin(request)
        store.pop(poll_id)
        return {"status": "cancelled"}

    return router
|
||||||
+66
-36
@@ -7,14 +7,24 @@ from typing import Dict, Any, List, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
||||||
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import case, func, or_
|
||||||
from core.database import SessionLocal, Document, DocumentVersion
|
from core.database import SessionLocal, Document, DocumentVersion
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import MAIL_ATTACHMENTS_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session_or_404(db, session_id: str, user: Optional[str]):
    """Load the session row for ``session_id`` or raise a 404.

    A session whose ``owner`` differs from a truthy ``user`` is reported with
    the same 404 as a missing one. When ``user`` is falsy no ownership check
    is made.
    """
    found = db.query(DbSession).filter(DbSession.id == session_id).first()
    # Short-circuits so ``found.owner`` is only read for an existing row.
    owned_by_other = bool(found) and bool(user) and found.owner != user
    if not found or owned_by_other:
        raise HTTPException(404, "Session not found")
    return found
|
||||||
|
|
||||||
|
|
||||||
def _aggregate_language_facets(lang_rows):
|
def _aggregate_language_facets(lang_rows):
|
||||||
"""Sum document counts per display language for the library facet.
|
"""Sum document counts per display language for the library facet.
|
||||||
|
|
||||||
@@ -30,6 +40,19 @@ def _aggregate_language_facets(lang_rows):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _library_language_for_document(doc: Document) -> str:
    """Pick the language label the document library shows for ``doc``.

    PDF documents are stored as markdown wrappers so the editor can keep
    extracted text, form fields, and annotations. When the content carries a
    source-upload marker the library reports ``"pdf"`` rather than that
    internal wrapper format; otherwise it uses the stored language, falling
    back to ``"text"``.
    """
    from src.pdf_form_doc import find_source_upload_id

    content = doc.current_content or ""
    is_pdf_wrapper = bool(find_source_upload_id(content))
    return "pdf" if is_pdf_wrapper else (doc.language or "text")
|
||||||
|
|
||||||
|
|
||||||
from routes.document_helpers import (
|
from routes.document_helpers import (
|
||||||
DocumentCreate, DocumentUpdate, DocumentPatch,
|
DocumentCreate, DocumentUpdate, DocumentPatch,
|
||||||
@@ -69,17 +92,12 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# the doc is owner-stamped, so it lives in the library on its own.
|
# the doc is owner-stamped, so it lives in the library on its own.
|
||||||
session = None
|
session = None
|
||||||
if req.session_id:
|
if req.session_id:
|
||||||
session = db.query(DbSession).filter(DbSession.id == req.session_id).first()
|
|
||||||
if not session:
|
|
||||||
raise HTTPException(404, "Session not found")
|
|
||||||
# Match the lenient ownership model the rest of the app uses
|
# Match the lenient ownership model the rest of the app uses
|
||||||
# (see _owner_filter): only block when an AUTHENTICATED user is
|
# (see _owner_filter): only block when an AUTHENTICATED user is
|
||||||
# writing into a DIFFERENT user's session. In single-user /
|
# writing into a DIFFERENT user's session. In single-user /
|
||||||
# unconfigured / localhost-bypass mode the middleware leaves
|
# unconfigured / localhost-bypass mode, falsey users preserve
|
||||||
# current_user unset (None), and those sessions are already
|
# the existing lenient path.
|
||||||
# served freely everywhere else.
|
session = _get_session_or_404(db, req.session_id, user)
|
||||||
if user and session.owner and session.owner != user:
|
|
||||||
raise HTTPException(403, "Cannot create document in another user's session")
|
|
||||||
|
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
ver_id = str(uuid.uuid4())
|
ver_id = str(uuid.uuid4())
|
||||||
@@ -171,11 +189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
if session_id:
|
if session_id:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
_get_session_or_404(db, session_id, user)
|
||||||
if not sess:
|
|
||||||
raise HTTPException(404, "Session not found")
|
|
||||||
if user and sess.owner and sess.owner != user:
|
|
||||||
raise HTTPException(403, "Cannot import into another user's session")
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -198,7 +212,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
|
|
||||||
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
||||||
try:
|
try:
|
||||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||||
except Exception:
|
except Exception:
|
||||||
body_text = None
|
body_text = None
|
||||||
|
|
||||||
@@ -260,18 +274,29 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
pdf_marker_cond = or_(
|
||||||
|
Document.current_content.like('%<!-- pdf_source upload_id="%'),
|
||||||
|
Document.current_content.like('%<!-- pdf_form_source upload_id="%'),
|
||||||
|
)
|
||||||
|
library_language_expr = case(
|
||||||
|
(pdf_marker_cond, "pdf"),
|
||||||
|
(Document.language.is_(None), "text"),
|
||||||
|
else_=Document.language,
|
||||||
|
)
|
||||||
# Archived view shows ONLY archived docs; the default view excludes
|
# Archived view shows ONLY archived docs; the default view excludes
|
||||||
# them (NULL = legacy rows that predate the column = not archived).
|
# them (NULL = legacy rows that predate the column = not archived).
|
||||||
_arch_cond = (Document.archived == True) if archived else or_(
|
_arch_cond = (Document.archived == True) if archived else or_(
|
||||||
Document.archived == False, Document.archived.is_(None))
|
Document.archived == False, Document.archived.is_(None))
|
||||||
# Language facet counts (owner-filtered)
|
# Language facet counts (owner-filtered). PDF documents are stored
|
||||||
|
# as markdown wrappers, so group by the library display language
|
||||||
|
# instead of the raw stored language.
|
||||||
lang_q = (
|
lang_q = (
|
||||||
db.query(Document.language, func.count(Document.id))
|
db.query(library_language_expr, func.count(Document.id))
|
||||||
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
||||||
.filter(Document.is_active == True).filter(_arch_cond)
|
.filter(Document.is_active == True).filter(_arch_cond)
|
||||||
)
|
)
|
||||||
lang_q = _owner_session_filter(lang_q, user)
|
lang_q = _owner_session_filter(lang_q, user)
|
||||||
lang_rows = lang_q.group_by(Document.language).all()
|
lang_rows = lang_q.group_by(library_language_expr).all()
|
||||||
languages = _aggregate_language_facets(lang_rows)
|
languages = _aggregate_language_facets(lang_rows)
|
||||||
|
|
||||||
# Session count (owner-filtered)
|
# Session count (owner-filtered)
|
||||||
@@ -303,12 +328,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
Document.title.ilike(term) | Document.current_content.ilike(term)
|
Document.title.ilike(term) | Document.current_content.ilike(term)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Language filter
|
# Language filter. "pdf" is a display language derived from the
|
||||||
|
# source marker; "markdown" excludes those wrappers.
|
||||||
if language:
|
if language:
|
||||||
if language == "text":
|
if language == "text":
|
||||||
q = q.filter((Document.language == None) | (Document.language == "text"))
|
q = q.filter((Document.language == None) | (Document.language == "text"))
|
||||||
|
elif language == "pdf":
|
||||||
|
q = q.filter(pdf_marker_cond)
|
||||||
else:
|
else:
|
||||||
q = q.filter(Document.language == language)
|
q = q.filter(Document.language == language)
|
||||||
|
if language == "markdown":
|
||||||
|
q = q.filter(~pdf_marker_cond)
|
||||||
|
|
||||||
# Total before pagination
|
# Total before pagination
|
||||||
total = q.count()
|
total = q.count()
|
||||||
@@ -332,7 +362,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
"session_id": doc.session_id,
|
"session_id": doc.session_id,
|
||||||
"session_name": session_name,
|
"session_name": session_name,
|
||||||
"title": doc.title,
|
"title": doc.title,
|
||||||
"language": doc.language or "text",
|
"language": _library_language_for_document(doc),
|
||||||
"preview": (doc.current_content or "")[:500],
|
"preview": (doc.current_content or "")[:500],
|
||||||
"version_count": doc.version_count,
|
"version_count": doc.version_count,
|
||||||
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
||||||
@@ -359,18 +389,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(403, "Authentication required")
|
raise HTTPException(403, "Authentication required")
|
||||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
|
||||||
# v2 review HIGH-9: raise 403 explicitly when the caller
|
# v2 review HIGH-9: raise 403 explicitly when the caller
|
||||||
# can't see this session, instead of returning [] which the
|
# can't see this session, instead of returning [] which the
|
||||||
# UI treats identically to "no docs" and silently masks
|
# UI treats identically to "no docs" and silently masks
|
||||||
# auth failures.
|
# auth failures.
|
||||||
if not session:
|
_get_session_or_404(db, session_id, user)
|
||||||
raise HTTPException(404, "Session not found")
|
q = db.query(Document).filter(
|
||||||
if user and session.owner and session.owner != user:
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
docs = db.query(Document).filter(
|
|
||||||
Document.session_id == session_id
|
Document.session_id == session_id
|
||||||
).order_by(Document.created_at.desc()).all()
|
)
|
||||||
|
if user:
|
||||||
|
q = q.filter(or_(Document.owner == user, Document.owner.is_(None)))
|
||||||
|
docs = q.order_by(Document.created_at.desc()).all()
|
||||||
return [_doc_to_dict(d) for d in docs]
|
return [_doc_to_dict(d) for d in docs]
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -437,7 +466,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
raise HTTPException(404, "Source PDF could not be located")
|
raise HTTPException(404, "Source PDF could not be located")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
||||||
raise HTTPException(500, f"Extraction failed: {e}")
|
raise HTTPException(500, f"Extraction failed: {e}")
|
||||||
@@ -606,6 +635,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
doc.language = req.language
|
doc.language = req.language
|
||||||
if req.session_id is not None:
|
if req.session_id is not None:
|
||||||
# Empty string = unlink from session
|
# Empty string = unlink from session
|
||||||
|
if req.session_id:
|
||||||
|
_get_session_or_404(db, req.session_id, user)
|
||||||
doc.session_id = req.session_id if req.session_id else None
|
doc.session_id = req.session_id if req.session_id else None
|
||||||
if not req.session_id:
|
if not req.session_id:
|
||||||
# Tab closed / doc detached from its session — drop the
|
# Tab closed / doc detached from its session — drop the
|
||||||
@@ -855,10 +886,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
url, model, headers = resolve_task_endpoint()
|
url, model, headers = resolve_task_endpoint(owner=user or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
# Fall back to default endpoint
|
# Fall back to default endpoint
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
raise HTTPException(500, "No endpoint configured for AI tidy")
|
raise HTTPException(500, "No endpoint configured for AI tidy")
|
||||||
|
|
||||||
@@ -1158,7 +1189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
settings = _load_vl_settings()
|
settings = _load_vl_settings()
|
||||||
vl_model = settings.get("vision_model", "")
|
vl_model = settings.get("vision_model", "")
|
||||||
try:
|
try:
|
||||||
url, model_id, headers = _resolve_vl_model(vl_model)
|
url, model_id, headers = _resolve_vl_model(vl_model, owner=user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(503, f"No vision model available: {e}")
|
raise HTTPException(503, f"No vision model available: {e}")
|
||||||
|
|
||||||
@@ -1512,10 +1543,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# don't import from a routes file (cycle-prone). Same env override
|
# don't import from a routes file (cycle-prone). Same env override
|
||||||
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
||||||
from pathlib import Path as _Path
|
from pathlib import Path as _Path
|
||||||
import os as _os
|
_COMPOSE_DIR = _Path(MAIL_ATTACHMENTS_DIR) / "_compose"
|
||||||
_DATA_DIR = _Path(__file__).resolve().parent.parent / "data"
|
|
||||||
_BASE = _os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(_DATA_DIR / "mail-attachments"))
|
|
||||||
_COMPOSE_DIR = _Path(_BASE) / "_compose"
|
|
||||||
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
@@ -1631,9 +1659,11 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
# context (To/Subject/In-Reply-To/References).
|
# context (To/Subject/In-Reply-To/References).
|
||||||
try:
|
try:
|
||||||
from routes.email_routes import _imap, _decode_header
|
from routes.email_routes import _imap, _decode_header
|
||||||
|
from routes.email_helpers import _q
|
||||||
except Exception:
|
except Exception:
|
||||||
_imap = None
|
_imap = None
|
||||||
_decode_header = lambda x: x or ""
|
_decode_header = lambda x: x or ""
|
||||||
|
_q = lambda x: x or ""
|
||||||
|
|
||||||
to_addr = ""
|
to_addr = ""
|
||||||
from_name = ""
|
from_name = ""
|
||||||
@@ -1643,7 +1673,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
|||||||
if _imap:
|
if _imap:
|
||||||
try:
|
try:
|
||||||
with _imap(doc.source_email_account_id or None) as conn:
|
with _imap(doc.source_email_account_id or None) as conn:
|
||||||
conn.select(doc.source_email_folder, readonly=True)
|
conn.select(_q(doc.source_email_folder), readonly=True)
|
||||||
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
||||||
if status == "OK" and data and data[0]:
|
if status == "OK" and data and data[0]:
|
||||||
raw_hdr = data[0][1]
|
raw_hdr = data[0][1]
|
||||||
|
|||||||
+98
-33
@@ -71,6 +71,38 @@ def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message
|
|||||||
smtp.sendmail(from_addr, recipients, message)
|
smtp.sendmail(from_addr, recipients, message)
|
||||||
|
|
||||||
|
|
||||||
|
def _friendly_email_auth_error(protocol: str, host: str, error: object) -> str:
|
||||||
|
"""Return a clearer setup error for known provider auth policies."""
|
||||||
|
raw = str(error or "")
|
||||||
|
lower = raw.lower()
|
||||||
|
host_lower = (host or "").lower()
|
||||||
|
microsoft_host = any(
|
||||||
|
marker in host_lower
|
||||||
|
for marker in (
|
||||||
|
"outlook.office365.com",
|
||||||
|
"smtp.office365.com",
|
||||||
|
"office365.com",
|
||||||
|
"outlook.com",
|
||||||
|
"hotmail.com",
|
||||||
|
"live.com",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
microsoft_basic_auth_failure = (
|
||||||
|
"5.7.139" in lower
|
||||||
|
or "basic authentication is disabled" in lower
|
||||||
|
or ("authenticate failed" in lower and microsoft_host)
|
||||||
|
or ("authentication unsuccessful" in lower and microsoft_host)
|
||||||
|
)
|
||||||
|
if microsoft_basic_auth_failure:
|
||||||
|
return (
|
||||||
|
"Microsoft no longer accepts normal mailbox passwords for "
|
||||||
|
"Outlook/Office 365 IMAP/SMTP in most accounts. Odysseus "
|
||||||
|
"does not support Microsoft OAuth/Graph mail yet, so Outlook "
|
||||||
|
"accounts cannot be added with this password form."
|
||||||
|
)
|
||||||
|
return raw[:200]
|
||||||
|
|
||||||
|
|
||||||
def _strip_think(text: str) -> str:
|
def _strip_think(text: str) -> str:
|
||||||
"""Email-flavored think strip — thin wrapper over the central helper.
|
"""Email-flavored think strip — thin wrapper over the central helper.
|
||||||
|
|
||||||
@@ -254,16 +286,17 @@ def _cleanup_compose_uploads(tokens) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
from src.constants import DATA_DIR as _DATA_DIR, MAIL_ATTACHMENTS_DIR, SETTINGS_FILE as _SETTINGS_FILE, SCHEDULED_EMAILS_DB
|
||||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
DATA_DIR = Path(_DATA_DIR)
|
||||||
|
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||||
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
||||||
# subdir of the install's data/ tree so the app works out-of-the-box without
|
# subdir of the install's data/ tree so the app works out-of-the-box without
|
||||||
# a hardcoded /home/<user>/ path.
|
# a hardcoded /home/<user>/ path.
|
||||||
ATTACHMENTS_DIR = Path(os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(DATA_DIR / "mail-attachments")))
|
ATTACHMENTS_DIR = Path(MAIL_ATTACHMENTS_DIR)
|
||||||
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
||||||
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
|
SCHEDULED_DB = Path(SCHEDULED_EMAILS_DB)
|
||||||
|
|
||||||
|
|
||||||
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||||
@@ -705,7 +738,16 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
|||||||
port = int(port or 993)
|
port = int(port or 993)
|
||||||
if starttls:
|
if starttls:
|
||||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||||
conn.starttls()
|
try:
|
||||||
|
conn.starttls()
|
||||||
|
except Exception:
|
||||||
|
# Don't leak the open plain socket if the STARTTLS upgrade is
|
||||||
|
# rejected; close it before propagating. (#3174)
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
elif port == 993:
|
elif port == 993:
|
||||||
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
||||||
else:
|
else:
|
||||||
@@ -714,6 +756,10 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
|||||||
conn.sock.settimeout(timeout)
|
conn.sock.settimeout(timeout)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# Raise the IMAP line-length limit from the default 1 MB to 50 MB so that
|
||||||
|
# large mailboxes (tens of thousands of messages) don't crash with
|
||||||
|
# "got more than 1000000 bytes" on UID SEARCH ALL. (#2883)
|
||||||
|
imaplib._MAXLINE = 50_000_000
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||||
@@ -734,7 +780,18 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
|||||||
starttls=bool(cfg.get("imap_starttls")),
|
starttls=bool(cfg.get("imap_starttls")),
|
||||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
try:
|
||||||
|
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||||
|
except Exception:
|
||||||
|
# A failed AUTHENTICATE (e.g. an Office 365 app password on an
|
||||||
|
# MFA-enabled tenant, #3174) otherwise orphans the already-connected
|
||||||
|
# socket; close it before propagating so a misconfigured account
|
||||||
|
# can't leak one descriptor per retry / background poller pass.
|
||||||
|
try:
|
||||||
|
conn.shutdown()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@@ -798,20 +855,28 @@ def _imap(account_id: str | None = None, owner: str = ""):
|
|||||||
def _decode_header(raw):
|
def _decode_header(raw):
|
||||||
if not raw:
|
if not raw:
|
||||||
return ""
|
return ""
|
||||||
parts = email.header.decode_header(raw)
|
try:
|
||||||
decoded = []
|
# make_header concatenates per RFC 2047: no spurious space between an
|
||||||
for data, charset in parts:
|
# encoded-word and adjacent plain text (plain runs keep their own
|
||||||
if isinstance(data, bytes):
|
# whitespace), and the whitespace between two adjacent encoded-words is
|
||||||
try:
|
# dropped. The old " ".join produced "Re: Jose"-style double spaces on
|
||||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
# every non-ASCII subject or sender.
|
||||||
except (LookupError, ValueError):
|
return str(email.header.make_header(email.header.decode_header(raw)))
|
||||||
# Unknown/invalid MIME charset (e.g. a malformed or spam header
|
except Exception:
|
||||||
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
|
# Malformed header or unknown/invalid MIME charset (e.g. a spam header
|
||||||
# byte-decode errors, not codec lookup, so fall back to utf-8.
|
# like =?x-unknown-charset?B?...?=) makes make_header raise LookupError;
|
||||||
decoded.append(data.decode("utf-8", errors="replace"))
|
# fall back to a lossy per-part decode. errors="replace" only covers
|
||||||
else:
|
# byte-decode errors, not codec lookup, hence the explicit utf-8 retry.
|
||||||
decoded.append(data)
|
decoded = []
|
||||||
return " ".join(decoded)
|
for data, charset in email.header.decode_header(raw):
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
try:
|
||||||
|
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||||
|
except (LookupError, ValueError):
|
||||||
|
decoded.append(data.decode("utf-8", errors="replace"))
|
||||||
|
else:
|
||||||
|
decoded.append(data)
|
||||||
|
return "".join(decoded)
|
||||||
|
|
||||||
|
|
||||||
def _detect_sent_folder(conn):
|
def _detect_sent_folder(conn):
|
||||||
@@ -1136,13 +1201,9 @@ def _fetch_sender_thread_context(sender_addr: str,
|
|||||||
if exclude_uid:
|
if exclude_uid:
|
||||||
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
||||||
|
|
||||||
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = _imap_connect(account_id, owner=owner)
|
conn = _imap_connect(account_id, owner=owner)
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"sender-thread-context: imap connect failed: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||||
if len(blocks) >= limit:
|
if len(blocks) >= limit:
|
||||||
break
|
break
|
||||||
@@ -1209,11 +1270,14 @@ def _fetch_sender_thread_context(sender_addr: str,
|
|||||||
if atts_text:
|
if atts_text:
|
||||||
lines.append(atts_text)
|
lines.append(atts_text)
|
||||||
blocks.append("\n".join(lines))
|
blocks.append("\n".join(lines))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"sender-thread-context: imap failed: {e}")
|
||||||
finally:
|
finally:
|
||||||
try: conn.close()
|
if conn:
|
||||||
except Exception: pass
|
try: conn.close()
|
||||||
try: conn.logout()
|
except Exception: pass
|
||||||
except Exception: pass
|
try: conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return ""
|
return ""
|
||||||
@@ -1316,6 +1380,7 @@ def _pre_retrieve_context(
|
|||||||
if not terms_list:
|
if not terms_list:
|
||||||
return context_snippets, terms_list
|
return context_snippets, terms_list
|
||||||
|
|
||||||
|
ctx_conn = None
|
||||||
try:
|
try:
|
||||||
ctx_conn = _imap_connect(account_id, owner=owner)
|
ctx_conn = _imap_connect(account_id, owner=owner)
|
||||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||||
@@ -1352,12 +1417,12 @@ def _pre_retrieve_context(
|
|||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
||||||
continue
|
continue
|
||||||
try:
|
|
||||||
ctx_conn.logout()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
logger.warning(f"IMAP context search failed: {_e}")
|
logger.warning(f"IMAP context search failed: {_e}")
|
||||||
|
finally:
|
||||||
|
if ctx_conn:
|
||||||
|
try: ctx_conn.logout()
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from routes.contacts_routes import _fetch_contacts
|
from routes.contacts_routes import _fetch_contacts
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
|||||||
if auto_cal:
|
if auto_cal:
|
||||||
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
||||||
try:
|
try:
|
||||||
st, _ = conn.select(sent_name, readonly=True)
|
st, _ = conn.select(_q(sent_name), readonly=True)
|
||||||
if st == "OK":
|
if st == "OK":
|
||||||
folders_to_scan.append(sent_name)
|
folders_to_scan.append(sent_name)
|
||||||
break
|
break
|
||||||
@@ -1046,7 +1046,7 @@ def _scheduled_poll_once() -> dict:
|
|||||||
try:
|
try:
|
||||||
with _imap(row_account_id, owner=row_owner) as imap:
|
with _imap(row_account_id, owner=row_owner) as imap:
|
||||||
sent_folder = _detect_sent_folder(imap)
|
sent_folder = _detect_sent_folder(imap)
|
||||||
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
|
imap.append(_q(sent_folder), "\\Seen", None, outer.as_bytes())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ from email.mime.multipart import MIMEMultipart
|
|||||||
|
|
||||||
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
from src.constants import DATA_DIR
|
||||||
|
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, EMAIL_COMPOSE_UPLOAD_MAX_BYTES
|
||||||
|
|
||||||
from routes.email_helpers import (
|
from routes.email_helpers import (
|
||||||
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
||||||
@@ -47,6 +48,7 @@ from routes.email_helpers import (
|
|||||||
_extract_attachment_to_disk, _extract_html, _extract_text,
|
_extract_attachment_to_disk, _extract_html, _extract_text,
|
||||||
_fetch_sender_thread_context, _pre_retrieve_context,
|
_fetch_sender_thread_context, _pre_retrieve_context,
|
||||||
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
||||||
|
_friendly_email_auth_error,
|
||||||
SendEmailRequest, ExtractStyleRequest,
|
SendEmailRequest, ExtractStyleRequest,
|
||||||
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
||||||
attachment_extract_dir, _email_cache_owner_clause,
|
attachment_extract_dir, _email_cache_owner_clause,
|
||||||
@@ -56,7 +58,6 @@ from routes.email_pollers import _start_poller
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
||||||
EMAIL_COMPOSE_UPLOAD_MAX_BYTES = 25 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
||||||
@@ -2904,7 +2905,7 @@ def setup_email_routes():
|
|||||||
from pathlib import Path as _P
|
from pathlib import Path as _P
|
||||||
import json as _json
|
import json as _json
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
path = _P(f"data/email_urgency_state_{_slug}.json")
|
path = _P(DATA_DIR) / f"email_urgency_state_{_slug}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
||||||
try:
|
try:
|
||||||
@@ -3162,7 +3163,7 @@ def setup_email_routes():
|
|||||||
try: conn.logout()
|
try: conn.logout()
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
imap_result = {"ok": False, "error": str(e)[:200]}
|
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
|
||||||
|
|
||||||
smtp_host = (body.get("smtp_host") or "").strip()
|
smtp_host = (body.get("smtp_host") or "").strip()
|
||||||
if smtp_host:
|
if smtp_host:
|
||||||
@@ -3184,7 +3185,7 @@ def setup_email_routes():
|
|||||||
try: smtp.quit()
|
try: smtp.quit()
|
||||||
except Exception: pass
|
except Exception: pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
smtp_result = {"ok": False, "error": str(e)[:200]}
|
smtp_result = {"ok": False, "error": _friendly_email_auth_error("SMTP", smtp_host, e)}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
||||||
|
|||||||
+65
-22
@@ -7,12 +7,12 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi import APIRouter, HTTPException, Form, Depends
|
from fastapi import APIRouter, HTTPException, Form, Depends
|
||||||
from core.constants import BASE_DIR
|
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
|
_ENDPOINT_FILE = EMBEDDING_ENDPOINT_FILE
|
||||||
|
|
||||||
# Track in-progress downloads
|
# Track in-progress downloads
|
||||||
_downloading: dict = {}
|
_downloading: dict = {}
|
||||||
@@ -35,13 +35,7 @@ def _cache_dir() -> str:
|
|||||||
default lived in /tmp, which many systems wipe on reboot — forcing a
|
default lived in /tmp, which many systems wipe on reboot — forcing a
|
||||||
full re-download of the embedding model after every restart.
|
full re-download of the embedding model after every restart.
|
||||||
"""
|
"""
|
||||||
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
return FASTEMBED_CACHE_DIR
|
||||||
if env:
|
|
||||||
return env
|
|
||||||
return os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
|
||||||
"data", "fastembed_cache",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _model_cache_name(hf_source: str) -> str:
|
def _model_cache_name(hf_source: str) -> str:
|
||||||
@@ -49,19 +43,35 @@ def _model_cache_name(hf_source: str) -> str:
|
|||||||
return "models--" + hf_source.replace("/", "--")
|
return "models--" + hf_source.replace("/", "--")
|
||||||
|
|
||||||
|
|
||||||
|
def _model_cache_path(hf_source: str) -> Path:
|
||||||
|
"""Return a confined cache path for a fastembed HF source."""
|
||||||
|
root = Path(_cache_dir()).expanduser().resolve()
|
||||||
|
raw_path = root / _model_cache_name(hf_source)
|
||||||
|
if raw_path.is_symlink():
|
||||||
|
raise ValueError("Model cache path must not be a symlink")
|
||||||
|
path = raw_path.resolve(strict=False)
|
||||||
|
try:
|
||||||
|
path.relative_to(root)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Model cache path escapes cache root")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
def _is_downloaded(hf_source: str) -> bool:
|
def _is_downloaded(hf_source: str) -> bool:
|
||||||
"""Check if a model is already cached."""
|
"""Check if a model is already cached."""
|
||||||
cache = _cache_dir()
|
try:
|
||||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
model_dir = _model_cache_path(hf_source)
|
||||||
if not os.path.isdir(model_dir):
|
except ValueError:
|
||||||
|
return False
|
||||||
|
if not model_dir.is_dir():
|
||||||
return False
|
return False
|
||||||
# Check for actual model files (not just empty dir)
|
# Check for actual model files (not just empty dir)
|
||||||
snapshots = os.path.join(model_dir, "snapshots")
|
snapshots = model_dir / "snapshots"
|
||||||
if os.path.isdir(snapshots):
|
if snapshots.is_dir():
|
||||||
return any(os.listdir(snapshots))
|
return any(snapshots.iterdir())
|
||||||
# Also check for blobs (older cache format)
|
# Also check for blobs (older cache format)
|
||||||
blobs = os.path.join(model_dir, "blobs")
|
blobs = model_dir / "blobs"
|
||||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
return blobs.is_dir() and any(blobs.iterdir())
|
||||||
|
|
||||||
|
|
||||||
def _active_model() -> str:
|
def _active_model() -> str:
|
||||||
@@ -119,8 +129,10 @@ def setup_embedding_routes():
|
|||||||
|
|
||||||
cached_size = None
|
cached_size = None
|
||||||
if downloaded and hf_src:
|
if downloaded and hf_src:
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
cached_size = _dir_size_mb(model_path)
|
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
|
||||||
|
except ValueError:
|
||||||
|
cached_size = None
|
||||||
|
|
||||||
result.append({
|
result.append({
|
||||||
"model": m["model"],
|
"model": m["model"],
|
||||||
@@ -217,8 +229,11 @@ def setup_embedding_routes():
|
|||||||
if not hf_src:
|
if not hf_src:
|
||||||
raise HTTPException(400, "No cache source for this model")
|
raise HTTPException(400, "No cache source for this model")
|
||||||
|
|
||||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
try:
|
||||||
if not os.path.isdir(model_path):
|
model_path = _model_cache_path(hf_src)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
|
if not model_path.is_dir():
|
||||||
return {"deleted": False, "message": "Model not cached"}
|
return {"deleted": False, "message": "Model not cached"}
|
||||||
|
|
||||||
shutil.rmtree(model_path)
|
shutil.rmtree(model_path)
|
||||||
@@ -237,7 +252,7 @@ def setup_embedding_routes():
|
|||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/endpoint")
|
@router.post("/endpoint")
|
||||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||||
"""Save a custom embedding endpoint URL."""
|
"""Save a custom embedding endpoint URL."""
|
||||||
url = url.strip()
|
url = url.strip()
|
||||||
if not url:
|
if not url:
|
||||||
@@ -261,6 +276,7 @@ def setup_embedding_routes():
|
|||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
url,
|
url,
|
||||||
json={"input": ["test"], "model": model or "test"},
|
json={"input": ["test"], "model": model or "test"},
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
@@ -271,10 +287,16 @@ def setup_embedding_routes():
|
|||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
if model:
|
if model:
|
||||||
data["model"] = model
|
data["model"] = model
|
||||||
|
if api_key:
|
||||||
|
from src.secret_storage import encrypt
|
||||||
|
data["api_key"] = encrypt(api_key)
|
||||||
|
|
||||||
_save_custom_endpoint(data)
|
_save_custom_endpoint(data)
|
||||||
os.environ["EMBEDDING_URL"] = url
|
os.environ["EMBEDDING_URL"] = url
|
||||||
if model:
|
if model:
|
||||||
os.environ["EMBEDDING_MODEL"] = model
|
os.environ["EMBEDDING_MODEL"] = model
|
||||||
|
if api_key:
|
||||||
|
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||||
|
|
||||||
# Reset the RAG singleton so it picks up the new endpoint
|
# Reset the RAG singleton so it picks up the new endpoint
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
@@ -288,6 +310,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||||
try:
|
try:
|
||||||
@@ -308,6 +340,7 @@ def setup_embedding_routes():
|
|||||||
# Remove from environment
|
# Remove from environment
|
||||||
os.environ.pop("EMBEDDING_URL", None)
|
os.environ.pop("EMBEDDING_URL", None)
|
||||||
os.environ.pop("EMBEDDING_MODEL", None)
|
os.environ.pop("EMBEDDING_MODEL", None)
|
||||||
|
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||||
|
|
||||||
# Reset the RAG singleton so it falls back to fastembed
|
# Reset the RAG singleton so it falls back to fastembed
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
@@ -318,6 +351,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client
|
# Reset ChromaDB client
|
||||||
try:
|
try:
|
||||||
|
|||||||
+45
-6
@@ -16,22 +16,54 @@ from pathlib import Path
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from src.constants import EMOJI_CACHE_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
|
_CACHE_DIR = Path(EMOJI_CACHE_DIR)
|
||||||
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
||||||
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
||||||
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
||||||
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
||||||
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
||||||
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
|
_MAX_SVG_BYTES = 256 * 1024
|
||||||
|
_BLOCKED_SVG_RE = re.compile(
|
||||||
|
br"<\s*(?:script|foreignObject|iframe|object|embed|image)\b|"
|
||||||
|
br"\bon[a-z0-9_-]+\s*=",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
_EXTERNAL_REF_RE = re.compile(
|
||||||
|
br"\b(?:href|xlink:href)\s*=\s*['\"](?:https?:|//|data:|javascript:)",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
_SVG_SECURITY_HEADERS = {
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Content-Security-Policy": "sandbox",
|
||||||
|
"Cross-Origin-Resource-Policy": "same-origin",
|
||||||
|
}
|
||||||
|
_SVG_HEADERS = {
|
||||||
|
"Cache-Control": "public, max-age=31536000, immutable",
|
||||||
|
**_SVG_SECURITY_HEADERS,
|
||||||
|
}
|
||||||
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
||||||
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
||||||
# request can still pick up the real glyph once the CDN is reachable.
|
# request can still pick up the real glyph once the CDN is reachable.
|
||||||
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
||||||
_BLANK_HEADERS = {"Cache-Control": "no-store"}
|
_BLANK_HEADERS = {"Cache-Control": "no-store", **_SVG_SECURITY_HEADERS}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_safe_svg(content: bytes) -> bool:
|
||||||
|
if not isinstance(content, bytes) or not content:
|
||||||
|
return False
|
||||||
|
if len(content) > _MAX_SVG_BYTES:
|
||||||
|
return False
|
||||||
|
if b"<svg" not in content[:256].lower():
|
||||||
|
return False
|
||||||
|
if _BLOCKED_SVG_RE.search(content) or _EXTERNAL_REF_RE.search(content):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def setup_emoji_routes() -> APIRouter:
|
def setup_emoji_routes() -> APIRouter:
|
||||||
@@ -49,14 +81,21 @@ def setup_emoji_routes() -> APIRouter:
|
|||||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
fp = _CACHE_DIR / f"{code}.svg"
|
fp = _CACHE_DIR / f"{code}.svg"
|
||||||
if fp.exists():
|
if fp.exists():
|
||||||
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
try:
|
||||||
|
content = fp.read_bytes()
|
||||||
|
if _is_safe_svg(content):
|
||||||
|
return Response(content, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||||
|
fp.unlink(missing_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("emoji cache read %s failed: %s", code, e)
|
||||||
|
return _blank()
|
||||||
|
|
||||||
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
||||||
# it. OpenMoji filenames are the codepoints uppercased.
|
# it. OpenMoji filenames are the codepoints uppercased.
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||||
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
||||||
if r.status_code == 200 and b"<svg" in r.content[:256]:
|
if r.status_code == 200 and _is_safe_svg(r.content):
|
||||||
try:
|
try:
|
||||||
fp.write_bytes(r.content)
|
fp.write_bytes(r.content)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
+144
-54
@@ -12,8 +12,13 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
|||||||
|
|
||||||
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
from src.auth_helpers import get_current_user, require_privilege
|
from src.auth_helpers import get_current_user, owner_filter, require_privilege
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import (
|
||||||
|
read_upload_limited,
|
||||||
|
GALLERY_UPLOAD_MAX_BYTES,
|
||||||
|
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
||||||
|
)
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
from routes.gallery_helpers import (
|
from routes.gallery_helpers import (
|
||||||
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
||||||
@@ -21,17 +26,88 @@ from routes.gallery_helpers import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", str(100 * 1024 * 1024)))
|
|
||||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024)))
|
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||||
|
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||||
|
if not callable(is_admin):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(is_admin(user))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_gallery_filename(filename: str) -> str:
|
def _sanitize_gallery_filename(filename: str) -> str:
|
||||||
"""Return a local filename safe to join under generated_images."""
|
"""Return a local filename safe to join under generated_images."""
|
||||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
|
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128]
|
||||||
if not safe_name or safe_name in {".", ".."}:
|
if not safe_name or safe_name in {".", ".."}:
|
||||||
safe_name = uuid.uuid4().hex[:12]
|
safe_name = uuid.uuid4().hex[:12]
|
||||||
return safe_name
|
return safe_name
|
||||||
|
|
||||||
|
|
||||||
|
GALLERY_IMAGE_DIR = Path(GENERATED_IMAGES_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def _gallery_image_path(filename: str) -> Path:
|
||||||
|
"""Resolve a stored gallery filename without leaving generated_images."""
|
||||||
|
if not isinstance(filename, str):
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
safe_name = _sanitize_gallery_filename(filename)
|
||||||
|
original = str(filename or "")
|
||||||
|
root = GALLERY_IMAGE_DIR.resolve()
|
||||||
|
path = (GALLERY_IMAGE_DIR / safe_name).resolve()
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([str(root), str(path)]) != str(root):
|
||||||
|
raise ValueError
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
if safe_name != original:
|
||||||
|
raise HTTPException(400, "Unsafe gallery filename")
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_image_endpoint_base(url: str) -> str:
|
||||||
|
base = (url or "").strip().rstrip("/")
|
||||||
|
if base.endswith("/v1"):
|
||||||
|
base = base[:-3].rstrip("/")
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
def _visible_image_endpoint_query(db, owner: str | None):
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = db.query(ModelEndpoint).filter(
|
||||||
|
ModelEndpoint.model_type == "image",
|
||||||
|
ModelEndpoint.is_enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
return owner_filter(q, ModelEndpoint, owner)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_visible_image_endpoint(db, owner: str | None):
|
||||||
|
endpoints = _visible_image_endpoint_query(db, owner).all()
|
||||||
|
if owner:
|
||||||
|
for ep in endpoints:
|
||||||
|
if getattr(ep, "owner", None) == owner:
|
||||||
|
return ep
|
||||||
|
return endpoints[0] if endpoints else None
|
||||||
|
|
||||||
|
|
||||||
|
def _visible_image_endpoint_for_base(db, base: str, owner: str | None):
|
||||||
|
target = _normalize_image_endpoint_base(base)
|
||||||
|
if not target:
|
||||||
|
return None
|
||||||
|
fallback = None
|
||||||
|
for ep in _visible_image_endpoint_query(db, owner).all():
|
||||||
|
if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == target:
|
||||||
|
if owner and getattr(ep, "owner", None) == owner:
|
||||||
|
return ep
|
||||||
|
if fallback is None:
|
||||||
|
fallback = ep
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
def setup_gallery_routes() -> APIRouter:
|
def setup_gallery_routes() -> APIRouter:
|
||||||
router = APIRouter(tags=["gallery"])
|
router = APIRouter(tags=["gallery"])
|
||||||
|
|
||||||
@@ -55,6 +131,9 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
file_hash = hashlib.sha256(content).hexdigest()
|
file_hash = hashlib.sha256(content).hexdigest()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
if album_id and user is not None:
|
||||||
|
_get_or_404_album(db, album_id, user)
|
||||||
|
|
||||||
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
||||||
# caller can probe whether someone else uploaded the same
|
# caller can probe whether someone else uploaded the same
|
||||||
# file (the response leaks the existing row's id+filename).
|
# file (the response leaks the existing row's id+filename).
|
||||||
@@ -69,7 +148,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
||||||
"id": existing.id, "message": "Duplicate photo skipped"}
|
"id": existing.id, "message": "Duplicate photo skipped"}
|
||||||
|
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
||||||
@@ -135,7 +214,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
raise HTTPException(400, "No image provided")
|
raise HTTPException(400, "No image provided")
|
||||||
|
|
||||||
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
||||||
img_path.write_bytes(content)
|
img_path.write_bytes(content)
|
||||||
@@ -211,7 +290,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not user or img.owner != user:
|
if not user or img.owner != user:
|
||||||
raise HTTPException(403, "Not your image")
|
raise HTTPException(403, "Not your image")
|
||||||
|
|
||||||
img_path = Path("data/generated_images") / img.filename
|
img_path = _gallery_image_path(img.filename)
|
||||||
if not img_path.exists():
|
if not img_path.exists():
|
||||||
raise HTTPException(404, "Image file not found")
|
raise HTTPException(404, "Image file not found")
|
||||||
|
|
||||||
@@ -248,7 +327,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
"""AI upscale using img2img with the diffusion server."""
|
"""AI upscale using img2img with the diffusion server."""
|
||||||
import base64, httpx
|
import base64, httpx
|
||||||
|
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
file = form.get("image")
|
file = form.get("image")
|
||||||
if not file: raise HTTPException(400, "No image")
|
if not file: raise HTTPException(400, "No image")
|
||||||
@@ -260,7 +339,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
# Find image endpoint
|
# Find image endpoint
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -291,7 +370,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
"""Style transfer using img2img with the diffusion server."""
|
"""Style transfer using img2img with the diffusion server."""
|
||||||
import base64, httpx
|
import base64, httpx
|
||||||
|
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
file = form.get("image")
|
file = form.get("image")
|
||||||
prompt = form.get("prompt", "")
|
prompt = form.get("prompt", "")
|
||||||
@@ -303,7 +382,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -505,18 +584,24 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
||||||
result = []
|
result = []
|
||||||
for a in albums:
|
for a in albums:
|
||||||
count = db.query(GalleryImage).filter(
|
_count_q = db.query(GalleryImage).filter(
|
||||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||||
).count()
|
)
|
||||||
|
if user:
|
||||||
|
_count_q = _count_q.filter(GalleryImage.owner == user)
|
||||||
|
count = _count_q.count()
|
||||||
cover_url = None
|
cover_url = None
|
||||||
if a.cover_id:
|
if a.cover_id:
|
||||||
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
||||||
if cover:
|
if cover:
|
||||||
cover_url = f"/api/generated-image/{cover.filename}"
|
cover_url = f"/api/generated-image/{cover.filename}"
|
||||||
elif count > 0:
|
elif count > 0:
|
||||||
first = db.query(GalleryImage).filter(
|
_cover_q = db.query(GalleryImage).filter(
|
||||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||||
).order_by(GalleryImage.created_at.desc()).first()
|
)
|
||||||
|
if user:
|
||||||
|
_cover_q = _cover_q.filter(GalleryImage.owner == user)
|
||||||
|
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
|
||||||
if first:
|
if first:
|
||||||
cover_url = f"/api/generated-image/{first.filename}"
|
cover_url = f"/api/generated-image/{first.filename}"
|
||||||
result.append({
|
result.append({
|
||||||
@@ -649,7 +734,14 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if req.favorite is not None:
|
if req.favorite is not None:
|
||||||
img.favorite = req.favorite
|
img.favorite = req.favorite
|
||||||
if req.album_id is not None:
|
if req.album_id is not None:
|
||||||
img.album_id = req.album_id if req.album_id else None
|
if req.album_id:
|
||||||
|
# Validate the target album belongs to the caller before
|
||||||
|
# moving the image into it — mirrors add_to_album, so you
|
||||||
|
# cannot file your image into another user's album.
|
||||||
|
_get_or_404_album(db, req.album_id, user)
|
||||||
|
img.album_id = req.album_id
|
||||||
|
else:
|
||||||
|
img.album_id = None
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(img)
|
db.refresh(img)
|
||||||
return _image_to_dict(img)
|
return _image_to_dict(img)
|
||||||
@@ -692,11 +784,11 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
used = set()
|
used = set()
|
||||||
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||||
for img in imgs:
|
for img in imgs:
|
||||||
src = os.path.join("data", "generated_images", img.filename)
|
src = _gallery_image_path(img.filename)
|
||||||
if not os.path.exists(src):
|
if not src.exists():
|
||||||
continue
|
continue
|
||||||
ext = os.path.splitext(img.filename)[1] or ".png"
|
ext = src.suffix or ".png"
|
||||||
base = (img.prompt or "").strip() or os.path.splitext(img.filename)[0]
|
base = (img.prompt or "").strip() or src.stem
|
||||||
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
||||||
name = f"{base}{ext}"
|
name = f"{base}{ext}"
|
||||||
i = 1
|
i = 1
|
||||||
@@ -818,9 +910,9 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
|
|
||||||
img_filename = img.filename
|
img_filename = img.filename
|
||||||
# Remove the file from disk
|
# Remove the file from disk
|
||||||
img_path = os.path.join("data", "generated_images", img_filename)
|
img_path = _gallery_image_path(img_filename)
|
||||||
if os.path.exists(img_path):
|
if img_path.exists():
|
||||||
os.remove(img_path)
|
img_path.unlink()
|
||||||
|
|
||||||
# Soft-delete the record
|
# Soft-delete the record
|
||||||
img.is_active = False
|
img.is_active = False
|
||||||
@@ -923,7 +1015,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
||||||
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
||||||
import httpx
|
import httpx
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
||||||
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
||||||
@@ -942,14 +1034,11 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not base:
|
if not base:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
eps = db.query(ModelEndpoint).filter(
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
ModelEndpoint.is_enabled == True,
|
if not ep:
|
||||||
ModelEndpoint.model_type == "image",
|
|
||||||
).all()
|
|
||||||
if not eps:
|
|
||||||
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
||||||
base = eps[0].base_url.rstrip("/")
|
base = ep.base_url.rstrip("/")
|
||||||
api_key = eps[0].api_key
|
api_key = ep.api_key
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
@@ -966,10 +1055,12 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
_target = _norm_url(base)
|
_target = _norm_url(base)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
for ep in db.query(ModelEndpoint).all():
|
ep = _visible_image_endpoint_for_base(db, _target, user)
|
||||||
if _norm_url(ep.base_url) == _target:
|
if ep:
|
||||||
api_key = ep.api_key
|
base = (ep.base_url or base).rstrip("/")
|
||||||
break
|
api_key = ep.api_key
|
||||||
|
elif user and not _current_user_is_admin(request, user):
|
||||||
|
raise HTTPException(403, "Choose a registered image endpoint")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1121,7 +1212,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
you get edge blending + lighting unification while keeping the
|
you get edge blending + lighting unification while keeping the
|
||||||
composition recognisable."""
|
composition recognisable."""
|
||||||
import httpx, base64 as _b64
|
import httpx, base64 as _b64
|
||||||
require_privilege(request, "can_generate_images")
|
user = require_privilege(request, "can_generate_images")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
|
|
||||||
image_b64 = body.get("image")
|
image_b64 = body.get("image")
|
||||||
@@ -1148,23 +1239,22 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
if not base:
|
if not base:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
eps = db.query(ModelEndpoint).filter(
|
ep = _first_visible_image_endpoint(db, user)
|
||||||
ModelEndpoint.is_enabled == True,
|
if not ep:
|
||||||
ModelEndpoint.model_type == "image",
|
|
||||||
).all()
|
|
||||||
if not eps:
|
|
||||||
raise HTTPException(400, "No image generation endpoint configured.")
|
raise HTTPException(400, "No image generation endpoint configured.")
|
||||||
base = eps[0].base_url.rstrip("/")
|
base = ep.base_url.rstrip("/")
|
||||||
api_key = eps[0].api_key
|
api_key = ep.api_key
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
for ep in db.query(ModelEndpoint).all():
|
ep = _visible_image_endpoint_for_base(db, base, user)
|
||||||
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
|
if ep:
|
||||||
api_key = ep.api_key
|
base = (ep.base_url or base).rstrip("/")
|
||||||
break
|
api_key = ep.api_key
|
||||||
|
elif user and not _current_user_is_admin(request, user):
|
||||||
|
raise HTTPException(403, "Choose a registered image endpoint")
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1636,9 +1726,10 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
album = _get_or_404_album(db, album_id, user)
|
album = _get_or_404_album(db, album_id, user)
|
||||||
db.query(GalleryImage).filter(GalleryImage.album_id == album_id).update(
|
q = db.query(GalleryImage).filter(GalleryImage.album_id == album_id)
|
||||||
{"album_id": None}, synchronize_session=False
|
if user is not None:
|
||||||
)
|
q = q.filter(GalleryImage.owner == user)
|
||||||
|
q.update({"album_id": None}, synchronize_session=False)
|
||||||
db.delete(album)
|
db.delete(album)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
@@ -1709,7 +1800,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
img = _get_or_404_image(db, image_id, user)
|
img = _get_or_404_image(db, image_id, user)
|
||||||
|
|
||||||
img_path = Path("data/generated_images") / img.filename
|
img_path = _gallery_image_path(img.filename)
|
||||||
if not img_path.exists():
|
if not img_path.exists():
|
||||||
raise HTTPException(404, "Image file not found")
|
raise HTTPException(404, "Image file not found")
|
||||||
|
|
||||||
@@ -1727,7 +1818,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
||||||
configured = vl_settings.get("vision_model", "")
|
configured = vl_settings.get("vision_model", "")
|
||||||
try:
|
try:
|
||||||
chat_url, model_name, headers = _resolve_vl_model(configured)
|
chat_url, model_name, headers = _resolve_vl_model(configured, owner=user)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return {"error": "No vision model configured — set one in Settings → Vision"}
|
return {"error": "No vision model configured — set one in Settings → Vision"}
|
||||||
if not chat_url:
|
if not chat_url:
|
||||||
@@ -1808,4 +1899,3 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|
||||||
|
|||||||
@@ -490,7 +490,13 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
# Copy messages up to keep_count
|
# Copy messages up to keep_count
|
||||||
msgs_to_copy = source.history[:keep_count]
|
msgs_to_copy = source.history[:keep_count]
|
||||||
for msg in msgs_to_copy:
|
for msg in msgs_to_copy:
|
||||||
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
|
# Copy the metadata dict. Sharing it would let the fork's
|
||||||
|
# persistence (add_message -> _persist_message stamps
|
||||||
|
# _db_id/timestamp onto the dict) mutate the SOURCE session's
|
||||||
|
# in-memory messages, corrupting their _db_id and breaking
|
||||||
|
# edit/delete-by-id on the original conversation.
|
||||||
|
meta = dict(msg.metadata) if isinstance(msg.metadata, dict) else None
|
||||||
|
new_session.add_message(ChatMessage(msg.role, msg.content, meta))
|
||||||
try:
|
try:
|
||||||
from src.event_bus import fire_event
|
from src.event_bus import fire_event
|
||||||
fire_event("session_created", getattr(source, 'owner', None))
|
fire_event("session_created", getattr(source, 'owner', None))
|
||||||
@@ -522,6 +528,8 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
async def compact_session(request: Request, session_id: str):
|
async def compact_session(request: Request, session_id: str):
|
||||||
"""Manually trigger context compaction for a session."""
|
"""Manually trigger context compaction for a session."""
|
||||||
_verify_session_owner(request, session_id)
|
_verify_session_owner(request, session_id)
|
||||||
|
from src.auth_helpers import effective_user
|
||||||
|
owner = effective_user(request)
|
||||||
try:
|
try:
|
||||||
session = session_manager.get_session(session_id)
|
session = session_manager.get_session(session_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -555,7 +563,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Use utility model if available
|
# Use utility model if available
|
||||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
compact_url = util_url or session.endpoint_url
|
compact_url = util_url or session.endpoint_url
|
||||||
compact_model = util_model or session.model
|
compact_model = util_model or session.model
|
||||||
compact_headers = util_headers if util_url else session.headers
|
compact_headers = util_headers if util_url else session.headers
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import httpx
|
|||||||
|
|
||||||
from core.database import McpServer, SessionLocal
|
from core.database import McpServer, SessionLocal
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from src.constants import DATA_DIR
|
from src.constants import DATA_DIR, MCP_OAUTH_DIR
|
||||||
from src.mcp_manager import McpManager
|
from src.mcp_manager import McpManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
|||||||
|
|
||||||
def _mcp_oauth_base_dir() -> Path:
|
def _mcp_oauth_base_dir() -> Path:
|
||||||
"""Directory that may contain OAuth files managed by Odysseus."""
|
"""Directory that may contain OAuth files managed by Odysseus."""
|
||||||
return (Path(DATA_DIR) / "mcp_oauth").resolve(strict=False)
|
return Path(MCP_OAUTH_DIR).resolve(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
||||||
|
|||||||
@@ -29,11 +29,10 @@ from src.llm_core import llm_call_async
|
|||||||
from services.memory.memory_extractor import audit_memories
|
from services.memory.memory_extractor import audit_memories
|
||||||
from src.auth_helpers import get_current_user, require_user
|
from src.auth_helpers import get_current_user, require_user
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MEMORY_IMPORT_MAX_BYTES = int(os.getenv("ODYSSEUS_MEMORY_IMPORT_MAX_BYTES", str(10 * 1024 * 1024)))
|
|
||||||
|
|
||||||
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
||||||
"""Set up memory-related routes."""
|
"""Set up memory-related routes."""
|
||||||
@@ -371,7 +370,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
|||||||
tmp.write(content)
|
tmp.write(content)
|
||||||
tmp_path = tmp.name
|
tmp_path = tmp.name
|
||||||
try:
|
try:
|
||||||
text = _process_pdf(tmp_path)
|
text = _process_pdf(tmp_path, owner=_owner(request))
|
||||||
finally:
|
finally:
|
||||||
os.unlink(tmp_path)
|
os.unlink(tmp_path)
|
||||||
else:
|
else:
|
||||||
|
|||||||
+188
-26
@@ -5,6 +5,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import socket
|
import socket
|
||||||
|
import hashlib
|
||||||
import time as _time
|
import time as _time
|
||||||
import logging
|
import logging
|
||||||
import httpx
|
import httpx
|
||||||
@@ -282,8 +283,11 @@ _HOST_TO_CURATED = (
|
|||||||
("fireworks.ai", "fireworks"),
|
("fireworks.ai", "fireworks"),
|
||||||
("googleapis.com", "google"),
|
("googleapis.com", "google"),
|
||||||
("x.ai", "xai"),
|
("x.ai", "xai"),
|
||||||
|
|
||||||
("openrouter.ai", "openrouter"),
|
("openrouter.ai", "openrouter"),
|
||||||
("ollama.com", "ollama"),
|
("ollama.com", "ollama"),
|
||||||
|
("opencode.ai/zen/go", "opencode-go"),
|
||||||
|
("opencode.ai/zen", "opencode-zen"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -490,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
|||||||
def _is_chat_model(model_id: str) -> bool:
|
def _is_chat_model(model_id: str) -> bool:
|
||||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||||
mid = model_id.lower()
|
mid = model_id.lower()
|
||||||
|
if mid in {"gpt-5.1-codex"}:
|
||||||
|
return True
|
||||||
for prefix in _NON_CHAT_PREFIXES:
|
for prefix in _NON_CHAT_PREFIXES:
|
||||||
if mid.startswith(prefix):
|
if mid.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
@@ -502,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||||
|
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||||
|
|
||||||
|
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||||
|
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||||
|
endpoint backed by that auth row is removed, the stored credentials should
|
||||||
|
be cleared instead of lingering. Returns True if a row was deleted.
|
||||||
|
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||||
|
reference count so it does not keep its own auth alive.
|
||||||
|
"""
|
||||||
|
if not auth_id:
|
||||||
|
return False
|
||||||
|
from core.database import ProviderAuthSession
|
||||||
|
still_referenced = db.query(ModelEndpoint.id).filter(
|
||||||
|
ModelEndpoint.provider_auth_id == auth_id,
|
||||||
|
ModelEndpoint.id != exclude_ep_id,
|
||||||
|
).first()
|
||||||
|
if still_referenced is not None:
|
||||||
|
return False
|
||||||
|
auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
|
||||||
|
if auth_row is None:
|
||||||
|
return False
|
||||||
|
db.delete(auth_row)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_discovery_only_provider(provider: str) -> bool:
|
||||||
|
"""Provider that only supports model discovery, not live probing.
|
||||||
|
|
||||||
|
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||||
|
chat-completions or general health endpoint, so completion probes and
|
||||||
|
reachability pings are skipped — status is derived from cached models.
|
||||||
|
"""
|
||||||
|
return provider == "chatgpt-subscription"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_probe_key(ep) -> Optional[str]:
|
||||||
|
"""API key/bearer to probe an endpoint with.
|
||||||
|
|
||||||
|
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||||
|
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||||
|
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||||
|
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||||
|
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||||
|
return key
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
|
if _is_discovery_only_provider(provider):
|
||||||
|
# Responses/Codex API, not chat-completions: a completion probe would
|
||||||
|
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||||
|
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Say OK"},
|
{"role": "user", "content": "Say OK"},
|
||||||
@@ -618,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
base = resolve_url(_normalize_base(base_url))
|
base = resolve_url(_normalize_base(base_url))
|
||||||
|
if _detect_provider(base) == "chatgpt-subscription":
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
if api_key:
|
||||||
|
return fetch_available_models(api_key, timeout=timeout)
|
||||||
|
return []
|
||||||
if _detect_provider(base) == "anthropic":
|
if _detect_provider(base) == "anthropic":
|
||||||
# Try Anthropic's /v1/models endpoint first
|
# Try Anthropic's /v1/models endpoint first
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
@@ -644,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||||
return list(ANTHROPIC_MODELS)
|
return list(ANTHROPIC_MODELS)
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
|
if not url:
|
||||||
|
curated_key = _match_provider_curated(base, None)
|
||||||
|
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||||
|
return list(fallback or [])
|
||||||
headers = build_headers(api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
try:
|
try:
|
||||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
@@ -697,7 +770,6 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
return list(fallback)
|
return list(fallback)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||||
"""Reachability probe that does not require installed/listed models."""
|
"""Reachability probe that does not require installed/listed models."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
@@ -713,6 +785,10 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
or "ollama" in (parsed_base.hostname or "").lower()
|
or "ollama" in (parsed_base.hostname or "").lower()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# APFEL-specific detection
|
||||||
|
host = (parsed_base.hostname or "").lower()
|
||||||
|
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
|
||||||
|
|
||||||
def _result_from_response(r) -> Dict[str, Any]:
|
def _result_from_response(r) -> Dict[str, Any]:
|
||||||
if 300 <= r.status_code < 400:
|
if 300 <= r.status_code < 400:
|
||||||
loc = r.headers.get("location", "")
|
loc = r.headers.get("location", "")
|
||||||
@@ -734,7 +810,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
last_error: Optional[str] = None
|
last_error: Optional[str] = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if looks_like_ollama:
|
# APFEL does not behave like Ollama; use its health endpoint.
|
||||||
|
if looks_like_apfel:
|
||||||
|
root = base
|
||||||
|
for suffix in ("/v1", "/api"):
|
||||||
|
if root.endswith(suffix):
|
||||||
|
root = root[: -len(suffix)].rstrip("/")
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
|
||||||
|
result = _result_from_response(r)
|
||||||
|
if result["reachable"]:
|
||||||
|
return result
|
||||||
|
last_error = result.get("error")
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)[:120]
|
||||||
|
|
||||||
|
elif looks_like_ollama:
|
||||||
root = base
|
root = base
|
||||||
for suffix in ("/v1", "/api"):
|
for suffix in ("/v1", "/api"):
|
||||||
if root.endswith(suffix):
|
if root.endswith(suffix):
|
||||||
@@ -754,14 +846,31 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
return _result_from_response(r)
|
result = _result_from_response(r)
|
||||||
|
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
|
||||||
|
# as a fallback. OpenAI-compatible servers like llama-swap return 404
|
||||||
|
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
|
||||||
|
# are definitive — probing /models would just repeat the same rejection.
|
||||||
|
if (
|
||||||
|
not result["reachable"]
|
||||||
|
and result.get("status_code") is not None
|
||||||
|
and 400 <= result["status_code"] < 500
|
||||||
|
and result["status_code"] not in (401, 403)
|
||||||
|
):
|
||||||
|
models_url = build_models_url(base)
|
||||||
|
try:
|
||||||
|
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
|
result2 = _result_from_response(r2)
|
||||||
|
if result2["reachable"]:
|
||||||
|
return result2
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = str(e)[:120]
|
last_error = str(e)[:120]
|
||||||
|
|
||||||
return {"reachable": False, "status_code": None, "error": last_error}
|
return {"reachable": False, "status_code": None, "error": last_error}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
||||||
"""Return a provider-aware error message for failed endpoint probes."""
|
"""Return a provider-aware error message for failed endpoint probes."""
|
||||||
ping = ping or {}
|
ping = ping or {}
|
||||||
@@ -850,6 +959,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
|
|||||||
return [m for m in merged if m not in hidden]
|
return [m for m in merged if m not in hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_fingerprint(api_key: Optional[str]) -> str:
|
||||||
|
"""Stable, non-secret label for distinguishing same-URL credentials."""
|
||||||
|
key = (api_key or "").strip()
|
||||||
|
if not key:
|
||||||
|
return ""
|
||||||
|
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
|
||||||
|
|
||||||
|
|
||||||
def setup_model_routes(model_discovery):
|
def setup_model_routes(model_discovery):
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
@@ -951,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
|||||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||||
if not ok:
|
if not ok:
|
||||||
continue
|
continue
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||||
|
ep,
|
||||||
|
owner=getattr(ep, "owner", None),
|
||||||
|
)
|
||||||
|
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||||
|
continue
|
||||||
groups.setdefault(info["key"], {
|
groups.setdefault(info["key"], {
|
||||||
"base": info["base"],
|
"base": info["base"],
|
||||||
"api_key": info["api_key"],
|
"api_key": info["api_key"],
|
||||||
@@ -1104,8 +1232,9 @@ def setup_model_routes(model_discovery):
|
|||||||
raise HTTPException(401, "Not authenticated")
|
raise HTTPException(401, "Not authenticated")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
|
||||||
|
raise HTTPException(status_code=500, detail='Internal error')
|
||||||
# Admins see every endpoint (they manage the global pool); regular
|
# Admins see every endpoint (they manage the global pool); regular
|
||||||
# users get the owner-scoped view.
|
# users get the owner-scoped view.
|
||||||
_is_admin = False
|
_is_admin = False
|
||||||
@@ -1219,12 +1348,20 @@ def setup_model_routes(model_discovery):
|
|||||||
"endpoint_kind": kind,
|
"endpoint_kind": kind,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
t0 = _time.time()
|
if _is_discovery_only_provider(provider):
|
||||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
# No general health endpoint — an unauthenticated GET just
|
||||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
# 401s. Report status from cached models instead of pinging.
|
||||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
entry["latency_ms"] = None
|
||||||
entry["error"] = ping.get("error")
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
entry["error"] = None
|
||||||
|
entry["model_count"] = cached_count
|
||||||
|
else:
|
||||||
|
t0 = _time.time()
|
||||||
|
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||||
|
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||||
|
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||||
|
entry["error"] = ping.get("error")
|
||||||
|
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
entry["latency_ms"] = None
|
entry["latency_ms"] = None
|
||||||
entry["status"] = "online" if cached_count else "offline"
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
@@ -1257,7 +1394,7 @@ def setup_model_routes(model_discovery):
|
|||||||
if ep_id and ep_id not in endpoints_cache:
|
if ep_id and ep_id not in endpoints_cache:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if ep:
|
if ep:
|
||||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
ep_data = endpoints_cache.get(ep_id)
|
ep_data = endpoints_cache.get(ep_id)
|
||||||
if not ep_data:
|
if not ep_data:
|
||||||
# Try to find by base_url from the model's endpoint field
|
# Try to find by base_url from the model's endpoint field
|
||||||
@@ -1296,7 +1433,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep.id,
|
"id": ep.id,
|
||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
"api_key": ep.api_key,
|
"api_key": _resolve_probe_key(ep),
|
||||||
})
|
})
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -1385,18 +1522,21 @@ def setup_model_routes(model_discovery):
|
|||||||
# Endpoint counts as reachable if it has any model — including
|
# Endpoint counts as reachable if it has any model — including
|
||||||
# admin-pinned IDs that a probe would never surface.
|
# admin-pinned IDs that a probe would never surface.
|
||||||
status = "online" if (all_models or pinned) else "offline"
|
status = "online" if (all_models or pinned) else "offline"
|
||||||
|
base = _normalize_base(r.base_url)
|
||||||
ping = None
|
ping = None
|
||||||
if not all_models and not pinned and r.is_enabled:
|
# Discovery-only providers have no health endpoint — an
|
||||||
|
# unauthenticated ping just 401s, so don't bother.
|
||||||
|
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||||
if ping.get("reachable"):
|
if ping.get("reachable"):
|
||||||
status = "empty"
|
status = "empty"
|
||||||
base = _normalize_base(r.base_url)
|
|
||||||
kind = _effective_endpoint_kind(r, base)
|
kind = _effective_endpoint_kind(r, base)
|
||||||
results.append({
|
results.append({
|
||||||
"id": r.id,
|
"id": r.id,
|
||||||
"name": r.name,
|
"name": r.name,
|
||||||
"base_url": r.base_url,
|
"base_url": r.base_url,
|
||||||
"has_key": bool(r.api_key),
|
"has_key": bool(r.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
|
||||||
"is_enabled": r.is_enabled,
|
"is_enabled": r.is_enabled,
|
||||||
"models": visible,
|
"models": visible,
|
||||||
"pinned_models": pinned,
|
"pinned_models": pinned,
|
||||||
@@ -1463,21 +1603,34 @@ def setup_model_routes(model_discovery):
|
|||||||
)
|
)
|
||||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||||
|
|
||||||
# Dedupe: if an endpoint with the same base_url already exists and
|
# Dedupe: if an endpoint with the same base_url and compatible
|
||||||
# is reachable by the caller (shared or owned by them), return it
|
# credentials already exists and is reachable by the caller (shared or
|
||||||
# instead of creating a duplicate row. Fixes "Scan for Servers"
|
# owned by them), return it instead of creating a duplicate row. Keep
|
||||||
# re-adding manually-added endpoints under their host:port name.
|
# same-url/different-key rows distinct so users can group the same
|
||||||
|
# provider URL under multiple credentials.
|
||||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||||
_caller = _gcu_dedup(request) or None
|
_caller = _gcu_dedup(request) or None
|
||||||
|
_incoming_api_key = api_key.strip()
|
||||||
_db_dedup = SessionLocal()
|
_db_dedup = SessionLocal()
|
||||||
try:
|
try:
|
||||||
existing = (
|
_same_url_rows = (
|
||||||
_db_dedup.query(ModelEndpoint)
|
_db_dedup.query(ModelEndpoint)
|
||||||
.filter(ModelEndpoint.base_url == base_url)
|
.filter(ModelEndpoint.base_url == base_url)
|
||||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
||||||
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
||||||
.first()
|
.all()
|
||||||
)
|
)
|
||||||
|
existing = None
|
||||||
|
_empty_key_existing = None
|
||||||
|
for _candidate in _same_url_rows:
|
||||||
|
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
|
||||||
|
if _candidate_key == _incoming_api_key:
|
||||||
|
existing = _candidate
|
||||||
|
break
|
||||||
|
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
|
||||||
|
_empty_key_existing = _candidate
|
||||||
|
if existing is None and _incoming_api_key and _empty_key_existing is not None:
|
||||||
|
existing = _empty_key_existing
|
||||||
if existing:
|
if existing:
|
||||||
changed = False
|
changed = False
|
||||||
# Persist any incoming pinned IDs onto the existing row. An
|
# Persist any incoming pinned IDs onto the existing row. An
|
||||||
@@ -1526,6 +1679,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": existing.id,
|
"id": existing.id,
|
||||||
"name": existing.name,
|
"name": existing.name,
|
||||||
"base_url": existing.base_url,
|
"base_url": existing.base_url,
|
||||||
|
"has_key": bool(existing.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
|
||||||
"models": _visible_models(
|
"models": _visible_models(
|
||||||
existing_models,
|
existing_models,
|
||||||
getattr(existing, "hidden_models", None),
|
getattr(existing, "hidden_models", None),
|
||||||
@@ -1599,6 +1754,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep_id,
|
"id": ep_id,
|
||||||
"name": name.strip(),
|
"name": name.strip(),
|
||||||
"base_url": base_url,
|
"base_url": base_url,
|
||||||
|
"has_key": bool(api_key.strip()),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(api_key),
|
||||||
"models": _merge_model_ids(model_ids, _pinned),
|
"models": _merge_model_ids(model_ids, _pinned),
|
||||||
"pinned_models": _pinned,
|
"pinned_models": _pinned,
|
||||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||||
@@ -1648,7 +1805,7 @@ def setup_model_routes(model_discovery):
|
|||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found")
|
raise HTTPException(404, "Endpoint not found")
|
||||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1712,7 +1869,7 @@ def setup_model_routes(model_discovery):
|
|||||||
category = _classify_endpoint(base, kind)
|
category = _classify_endpoint(base, kind)
|
||||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||||
try:
|
try:
|
||||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||||
probed = []
|
probed = []
|
||||||
@@ -1948,6 +2105,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"model_type": ep.model_type,
|
"model_type": ep.model_type,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
|
"has_key": bool(ep.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
|
||||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||||
@@ -2049,7 +2208,9 @@ def setup_model_routes(model_discovery):
|
|||||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||||
|
auth_id = getattr(ep, "provider_auth_id", None)
|
||||||
db.delete(ep)
|
db.delete(ep)
|
||||||
|
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
_invalidate_models_cache()
|
_invalidate_models_cache()
|
||||||
_local_probe_cache["data"] = None
|
_local_probe_cache["data"] = None
|
||||||
@@ -2059,6 +2220,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"cleared_user_preferences": cleared_user_preferences,
|
"cleared_user_preferences": cleared_user_preferences,
|
||||||
"cleared_sessions": cleared_sessions,
|
"cleared_sessions": cleared_sessions,
|
||||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||||
|
"cleared_provider_auth": cleared_provider_auth,
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
+161
-16
@@ -11,6 +11,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.database import SessionLocal, Note
|
from core.database import SessionLocal, Note
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import DATA_DIR
|
||||||
from sqlalchemy.orm.attributes import flag_modified
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -95,6 +96,32 @@ def _note_to_dict(note: Note) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _reminder_text_from_note(
    note: "Note",
    *,
    max_pending: int = 8,
    max_body_chars: int = 400,
) -> tuple[str, str]:
    """Return the reminder title/body derived from a stored note row.

    The title is the note's stripped title, or "Note reminder" when the title
    is missing or blank.

    The body depends on ``note.items``:

    * When it holds a checklist (a JSON-encoded list, or an already-decoded
      list), the body lists the text of every item that is a dict, is not
      flagged ``done``/``checked`` and has non-blank text. At most
      ``max_pending`` entries are shown, followed by an "...and N more" line.
      If no such item exists the body is just the total item count.
    * Otherwise (no items, undecodable items, or a decoded value that is not
      a list) the body is the stripped note content truncated to
      ``max_body_chars`` characters.

    Args:
        note: Object exposing ``title``, ``items`` and ``content`` attributes.
        max_pending: Maximum number of pending item lines to include.
        max_body_chars: Maximum length of the plain-content fallback body.

    Returns:
        A ``(title, body)`` tuple of strings.
    """
    title = (note.title or "Note reminder").strip() or "Note reminder"

    checklist = None
    raw_items = note.items
    if raw_items:
        if isinstance(raw_items, list):
            # Already decoded — json.loads() would raise TypeError on a list
            # and the checklist would be silently dropped.
            checklist = raw_items
        else:
            try:
                decoded = json.loads(raw_items)
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                decoded = None
            if isinstance(decoded, list):
                checklist = decoded

    if checklist is None:
        # Read content only on this path so checklist notes never touch it.
        return title, (note.content or "").strip()[:max_body_chars]

    pending: list[str] = []
    for item in checklist:
        if not isinstance(item, dict):
            continue
        if item.get("done") or item.get("checked"):
            continue
        text = str(item.get("text") or "").strip()
        if text:
            pending.append(text)

    if not pending:
        count = len(checklist)
        return title, f"{count} item{'s' if count != 1 else ''}"

    shown = "\n".join(f"- {text}" for text in pending[:max_pending])
    hidden = len(pending) - max_pending
    extra = f"\n...and {hidden} more" if hidden > 0 else ""
    return title, f"Pending ({len(pending)}):\n{shown}{extra}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Reminder dispatch — module-level so background tasks (built-in actions)
|
# Reminder dispatch — module-level so background tasks (built-in actions)
|
||||||
@@ -114,8 +141,9 @@ async def dispatch_reminder(
|
|||||||
note_id: str,
|
note_id: str,
|
||||||
owner: str = "",
|
owner: str = "",
|
||||||
queue_browser: bool = True,
|
queue_browser: bool = True,
|
||||||
|
settings_override: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Fire a reminder via the configured channel (browser/email/ntfy).
|
"""Fire a reminder via the configured channel (browser/email/ntfy/webhook).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
title: short headline shown to the user
|
title: short headline shown to the user
|
||||||
@@ -129,7 +157,7 @@ async def dispatch_reminder(
|
|||||||
nothing is "sent" synchronously for it — the channel just routes there.
|
nothing is "sent" synchronously for it — the channel just routes there.
|
||||||
"""
|
"""
|
||||||
from src.settings import load_settings
|
from src.settings import load_settings
|
||||||
settings = load_settings()
|
settings = {**load_settings(), **(settings_override or {})}
|
||||||
channel = settings.get("reminder_channel", "browser")
|
channel = settings.get("reminder_channel", "browser")
|
||||||
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
||||||
title = (title or "").strip()
|
title = (title or "").strip()
|
||||||
@@ -143,7 +171,7 @@ async def dispatch_reminder(
|
|||||||
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
||||||
from pathlib import Path as _P
|
from pathlib import Path as _P
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
cache_path = _P(f"data/note_pings_{_slug}.json")
|
cache_path = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
||||||
last = cache.get(cache_key)
|
last = cache.get(cache_key)
|
||||||
@@ -160,13 +188,14 @@ async def dispatch_reminder(
|
|||||||
# Treat those as browser-only dedupe so email reminders can be
|
# Treat those as browser-only dedupe so email reminders can be
|
||||||
# retried by the backend scanner after a failed frontend path.
|
# retried by the backend scanner after a failed frontend path.
|
||||||
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
||||||
if should_skip and channel in ("email", "ntfy"):
|
if should_skip and channel in ("email", "ntfy", "webhook"):
|
||||||
should_skip = last_channel == channel
|
should_skip = last_channel == channel
|
||||||
if should_skip:
|
if should_skip:
|
||||||
return {
|
return {
|
||||||
"synthesis": None,
|
"synthesis": None,
|
||||||
"email_sent": False,
|
"email_sent": False,
|
||||||
"ntfy_sent": False,
|
"ntfy_sent": False,
|
||||||
|
"webhook_sent": False,
|
||||||
"browser_sent": True,
|
"browser_sent": True,
|
||||||
"skipped": True,
|
"skipped": True,
|
||||||
}
|
}
|
||||||
@@ -179,9 +208,9 @@ async def dispatch_reminder(
|
|||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||||
if url and model:
|
if url and model:
|
||||||
raw = await llm_call_async(
|
raw = await llm_call_async(
|
||||||
url=url, model=model,
|
url=url, model=model,
|
||||||
@@ -360,6 +389,76 @@ async def dispatch_reminder(
|
|||||||
email_error = str(e) or e.__class__.__name__
|
email_error = str(e) or e.__class__.__name__
|
||||||
logger.warning(f"Reminder email send failed: {e}")
|
logger.warning(f"Reminder email send failed: {e}")
|
||||||
|
|
||||||
|
webhook_sent = False
|
||||||
|
webhook_error = ""
|
||||||
|
if channel == "webhook":
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
import json as _wjson
|
||||||
|
from src.integrations import load_integrations
|
||||||
|
# Built-in payload defaults for known presets so users don't have
|
||||||
|
# to configure a template just to use a standard service.
|
||||||
|
_PRESET_TEMPLATE_DEFAULTS = {
|
||||||
|
"discord_webhook": '{"embeds": [{"title": "{{title}}", "description": "{{message}}", "color": 5793266}]}',
|
||||||
|
}
|
||||||
|
intg_id = settings.get("reminder_webhook_integration_id", "").strip()
|
||||||
|
template = settings.get("reminder_webhook_payload_template", "").strip()
|
||||||
|
if not intg_id:
|
||||||
|
webhook_error = "No webhook integration selected"
|
||||||
|
else:
|
||||||
|
intg = next(
|
||||||
|
(i for i in load_integrations()
|
||||||
|
if i.get("id") == intg_id and i.get("base_url")),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if not intg:
|
||||||
|
webhook_error = f"Integration {intg_id!r} not found or missing base URL"
|
||||||
|
else:
|
||||||
|
# Fall back to a built-in default for known presets so
|
||||||
|
# users don't have to configure a template for standard
|
||||||
|
# services like Discord.
|
||||||
|
if not template:
|
||||||
|
template = _PRESET_TEMPLATE_DEFAULTS.get(intg.get("preset", ""), "")
|
||||||
|
if not template:
|
||||||
|
webhook_error = "No payload template configured"
|
||||||
|
else:
|
||||||
|
# Render template: JSON-escape the values so the result
|
||||||
|
# is always valid JSON regardless of special characters.
|
||||||
|
# dumps() returns `"value"` — strip outer quotes.
|
||||||
|
msg = (synthesis or note_body or title or "Reminder")[:4000]
|
||||||
|
_t = _wjson.dumps(title or "Reminder")[1:-1]
|
||||||
|
_m = _wjson.dumps(msg)[1:-1]
|
||||||
|
rendered = template.replace("{{title}}", _t).replace("{{message}}", _m)
|
||||||
|
hdrs = {"Content-Type": "application/json"}
|
||||||
|
api_key = intg.get("api_key", "")
|
||||||
|
auth_type = (intg.get("auth_type") or "none").lower()
|
||||||
|
if api_key:
|
||||||
|
if auth_type == "bearer":
|
||||||
|
hdrs["Authorization"] = f"Bearer {api_key}"
|
||||||
|
elif auth_type == "header":
|
||||||
|
hdrs[intg.get("auth_header") or "Authorization"] = api_key
|
||||||
|
url = intg["base_url"].rstrip("/")
|
||||||
|
# SSRF guard — matches the pattern used by webhook_routes,
|
||||||
|
# CalDAV, search, and embeddings. Blocks link-local / metadata
|
||||||
|
# addresses (169.254.x.x) by default; set
|
||||||
|
# REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS=true to also block
|
||||||
|
# RFC-1918 ranges for locked-down deployments.
|
||||||
|
import os as _os
|
||||||
|
from src.url_safety import check_outbound_url as _chk
|
||||||
|
_block = _os.getenv("REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS", "false").lower() == "true"
|
||||||
|
_ok, _reason = _chk(url, block_private=_block)
|
||||||
|
if not _ok:
|
||||||
|
webhook_error = f"Webhook URL rejected: {_reason}"
|
||||||
|
else:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.post(url, content=rendered.encode(), headers=hdrs)
|
||||||
|
webhook_sent = resp.is_success
|
||||||
|
if not webhook_sent:
|
||||||
|
webhook_error = f"Webhook returned HTTP {resp.status_code}"
|
||||||
|
except Exception as e:
|
||||||
|
webhook_error = str(e) or e.__class__.__name__
|
||||||
|
logger.warning(f"Reminder webhook send failed: {e}")
|
||||||
|
|
||||||
ntfy_sent = False
|
ntfy_sent = False
|
||||||
ntfy_error = ""
|
ntfy_error = ""
|
||||||
if channel == "ntfy":
|
if channel == "ntfy":
|
||||||
@@ -415,7 +514,7 @@ async def dispatch_reminder(
|
|||||||
# second send for the same note within 25 min. Without this, a note
|
# second send for the same note within 25 min. Without this, a note
|
||||||
# whose due_date fires while the user has the app open got TWO emails
|
# whose due_date fires while the user has the app open got TWO emails
|
||||||
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
||||||
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
|
if (email_sent or ntfy_sent or webhook_sent or browser_sent or local_browser_sent) and note_id:
|
||||||
try:
|
try:
|
||||||
import json as _json
|
import json as _json
|
||||||
from datetime import datetime as _dt, timezone as _tz
|
from datetime import datetime as _dt, timezone as _tz
|
||||||
@@ -425,13 +524,13 @@ async def dispatch_reminder(
|
|||||||
_STATE = cache_path
|
_STATE = cache_path
|
||||||
if _STATE is None:
|
if _STATE is None:
|
||||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
_STATE = _P(f"data/note_pings_{_slug}.json")
|
_STATE = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||||
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
||||||
except Exception:
|
except Exception:
|
||||||
_cache = {}
|
_cache = {}
|
||||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
|
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "webhook" if webhook_sent else "browser"
|
||||||
_cache[cache_key or str(note_id)] = {
|
_cache[cache_key or str(note_id)] = {
|
||||||
"at": _dt.now(_tz.utc).isoformat(),
|
"at": _dt.now(_tz.utc).isoformat(),
|
||||||
"channel": sent_channel,
|
"channel": sent_channel,
|
||||||
@@ -441,11 +540,14 @@ async def dispatch_reminder(
|
|||||||
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"channel": channel,
|
||||||
"synthesis": synthesis,
|
"synthesis": synthesis,
|
||||||
"email_sent": email_sent,
|
"email_sent": email_sent,
|
||||||
"email_error": email_error,
|
"email_error": email_error,
|
||||||
"ntfy_sent": ntfy_sent,
|
"ntfy_sent": ntfy_sent,
|
||||||
"ntfy_error": ntfy_error,
|
"ntfy_error": ntfy_error,
|
||||||
|
"webhook_sent": webhook_sent,
|
||||||
|
"webhook_error": webhook_error,
|
||||||
"browser_sent": browser_sent or local_browser_sent,
|
"browser_sent": browser_sent or local_browser_sent,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,6 +569,23 @@ def setup_note_routes(task_scheduler=None):
|
|||||||
def _owner(request: Request) -> Optional[str]:
    """Return whatever ``get_current_user`` reports for this request."""
    current_user = get_current_user(request)
    return current_user
|
||||||
|
|
||||||
|
def _is_admin_or_single_user(request: Request, user: str | None) -> bool:
|
||||||
|
if user == "internal-tool":
|
||||||
|
return True
|
||||||
|
if not user:
|
||||||
|
# require_user() already admitted this request, which only happens
|
||||||
|
# for auth-disabled, loopback-bypass, or unconfigured single-user
|
||||||
|
# modes. There is no separate non-admin account boundary there.
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
from core.auth import AuthManager
|
||||||
|
auth_mgr = getattr(request.app.state, "auth_manager", None) or AuthManager()
|
||||||
|
if not getattr(auth_mgr, "is_configured", True):
|
||||||
|
return True
|
||||||
|
return bool(auth_mgr.is_admin(user))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
# --- LIST ---
|
# --- LIST ---
|
||||||
@router.get("")
|
@router.get("")
|
||||||
def list_notes(
|
def list_notes(
|
||||||
@@ -684,20 +803,46 @@ def setup_note_routes(task_scheduler=None):
|
|||||||
"""
|
"""
|
||||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||||
from src.auth_helpers import require_user as _ru
|
from src.auth_helpers import require_user as _ru
|
||||||
_ru(request)
|
user = _ru(request)
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
note_id = body.get("note_id")
|
note_id = str(body.get("note_id") or "").strip()
|
||||||
title = (body.get("title") or "").strip()
|
|
||||||
note_body = (body.get("body") or "").strip()
|
|
||||||
if not note_id:
|
if not note_id:
|
||||||
raise HTTPException(400, "note_id required")
|
raise HTTPException(400, "note_id required")
|
||||||
|
|
||||||
# Delegate to the module-level helper so background tasks can reuse
|
caller = _owner(request)
|
||||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
is_test = note_id.startswith("test-")
|
||||||
|
is_admin = _is_admin_or_single_user(request, user or caller)
|
||||||
|
_override: dict = {}
|
||||||
|
if is_test:
|
||||||
|
if not is_admin:
|
||||||
|
raise HTTPException(403, "Admin only")
|
||||||
|
title = (body.get("title") or "Test Reminder").strip() or "Test Reminder"
|
||||||
|
note_body = (body.get("body") or "").strip()
|
||||||
|
# Optional overrides let the admin settings test button pass the
|
||||||
|
# current UI values directly so it never races a pending save.
|
||||||
|
if body.get("channel"):
|
||||||
|
_override["reminder_channel"] = body["channel"]
|
||||||
|
if body.get("webhook_integration_id"):
|
||||||
|
_override["reminder_webhook_integration_id"] = body["webhook_integration_id"]
|
||||||
|
if body.get("webhook_payload_template"):
|
||||||
|
_override["reminder_webhook_payload_template"] = body["webhook_payload_template"]
|
||||||
|
else:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
note = db.query(Note).filter(Note.id == note_id).first()
|
||||||
|
if not note:
|
||||||
|
raise HTTPException(404, "Note not found")
|
||||||
|
if caller is not None and note.owner != caller:
|
||||||
|
raise HTTPException(404, "Note not found")
|
||||||
|
title, note_body = _reminder_text_from_note(note)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
return await dispatch_reminder(
|
return await dispatch_reminder(
|
||||||
title=title, note_body=note_body, note_id=note_id,
|
title=title, note_body=note_body, note_id=note_id,
|
||||||
owner=_owner(request) or "",
|
owner=caller or "",
|
||||||
queue_browser=False,
|
queue_browser=False,
|
||||||
|
settings_override=_override or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- REORDER NOTES ---
|
# --- REORDER NOTES ---
|
||||||
|
|||||||
+13
-12
@@ -6,16 +6,14 @@ import uuid
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||||
from src.request_models import DirectoryRequest
|
from src.request_models import DirectoryRequest
|
||||||
from core.constants import BASE_DIR, PERSONAL_DIR
|
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
|
||||||
from src.rag_singleton import get_rag_manager
|
from src.rag_singleton import get_rag_manager
|
||||||
from src.auth_helpers import get_current_user, require_user
|
from src.auth_helpers import require_privilege, require_user
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from src.upload_handler import secure_filename
|
from src.upload_handler import secure_filename
|
||||||
|
from src.upload_limits import PERSONAL_UPLOAD_MAX_BYTES
|
||||||
|
|
||||||
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
|
UPLOADS_DIR = PERSONAL_UPLOADS_DIR
|
||||||
MAX_PERSONAL_UPLOAD_BYTES = int(
|
|
||||||
os.getenv("ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -194,7 +192,7 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
@router.post("/upload")
|
@router.post("/upload")
|
||||||
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
||||||
"""Upload files directly into RAG. Supports text and PDF."""
|
"""Upload files directly into RAG. Supports text and PDF."""
|
||||||
user = get_current_user(request)
|
user = require_privilege(request, "can_use_documents")
|
||||||
rag = _rag()
|
rag = _rag()
|
||||||
if not rag:
|
if not rag:
|
||||||
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
||||||
@@ -208,8 +206,8 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
for upload in files:
|
for upload in files:
|
||||||
try:
|
try:
|
||||||
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
||||||
content_bytes = await upload.read(MAX_PERSONAL_UPLOAD_BYTES + 1)
|
content_bytes = await upload.read(PERSONAL_UPLOAD_MAX_BYTES + 1)
|
||||||
if len(content_bytes) > MAX_PERSONAL_UPLOAD_BYTES:
|
if len(content_bytes) > PERSONAL_UPLOAD_MAX_BYTES:
|
||||||
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
||||||
total_failed += 1
|
total_failed += 1
|
||||||
continue
|
continue
|
||||||
@@ -286,9 +284,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
# commonpath raises on mixed drives / non-comparable paths
|
# commonpath raises on mixed drives / non-comparable paths
|
||||||
in_uploads = False
|
in_uploads = False
|
||||||
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
|
if in_uploads and abs_target != base_abs:
|
||||||
os.remove(abs_target)
|
try:
|
||||||
deleted_from_disk = True
|
os.remove(abs_target)
|
||||||
|
deleted_from_disk = True
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # already gone — race with another request or cleanup
|
||||||
|
|
||||||
# Exclude the file from the listing (persists across restarts)
|
# Exclude the file from the listing (persists across restarts)
|
||||||
personal_docs_manager.exclude_file(filepath)
|
personal_docs_manager.exclude_file(filepath)
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import USER_PREFS_FILE
|
||||||
|
|
||||||
PREFS_FILE = os.path.join("data", "user_prefs.json")
|
PREFS_FILE = USER_PREFS_FILE
|
||||||
|
|
||||||
|
|
||||||
def _load():
|
def _load():
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from src.request_models import PresetUpdateRequest
|
from src.request_models import PresetUpdateRequest
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
|
from src.auth_helpers import effective_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
model_spec = data.get("model") or ""
|
model_spec = data.get("model") or ""
|
||||||
url, model, headers = _resolve_model(model_spec)
|
user = effective_user(request)
|
||||||
|
url, model, headers = _resolve_model(model_spec, owner=user)
|
||||||
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
||||||
return {"success": True, "prompt": result.strip()}
|
return {"success": True, "prompt": result.strip()}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+61
-46
@@ -14,6 +14,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.auth_helpers import _auth_disabled, get_current_user
|
from src.auth_helpers import _auth_disabled, get_current_user
|
||||||
|
from src.constants import DEEP_RESEARCH_DIR
|
||||||
|
|
||||||
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
||||||
|
|
||||||
@@ -37,13 +38,15 @@ def _first_chat_model(models) -> str:
|
|||||||
return (models[0] if models else "")
|
return (models[0] if models else "")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_research_endpoint(sess) -> tuple:
|
def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple:
|
||||||
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
||||||
|
owner = owner or getattr(sess, "owner", None) or None
|
||||||
url, model, headers = resolve_endpoint(
|
url, model, headers = resolve_endpoint(
|
||||||
"research",
|
"research",
|
||||||
fallback_url=sess.endpoint_url,
|
fallback_url=sess.endpoint_url,
|
||||||
fallback_model=sess.model,
|
fallback_model=sess.model,
|
||||||
fallback_headers=sess.headers,
|
fallback_headers=sess.headers,
|
||||||
|
owner=owner,
|
||||||
)
|
)
|
||||||
return url, model, headers
|
return url, model, headers
|
||||||
|
|
||||||
@@ -72,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
|||||||
return owner_filter(q, ModelEndpoint, owner).first()
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
|
||||||
|
"""Resolve a ModelEndpoint row into (chat_url, model, headers).
|
||||||
|
|
||||||
|
Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
|
||||||
|
panel-selected research endpoints. ChatGPT Subscription endpoints keep
|
||||||
|
OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.
|
||||||
|
"""
|
||||||
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
resolve_endpoint_runtime as resolve_model_endpoint_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not resolve endpoint credentials for research: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
ep_model = (model or "").strip()
|
||||||
|
if not ep_model:
|
||||||
|
try:
|
||||||
|
models = json.loads(ep.cached_models) if ep.cached_models else []
|
||||||
|
if models:
|
||||||
|
ep_model = _first_chat_model(models)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not ep_model:
|
||||||
|
return None
|
||||||
|
return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||||
|
|
||||||
|
|
||||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||||
router = APIRouter(tags=["research"])
|
router = APIRouter(tags=["research"])
|
||||||
|
|
||||||
@@ -98,7 +133,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
if entry is not None:
|
if entry is not None:
|
||||||
return entry.get("owner", "") == user
|
return entry.get("owner", "") == user
|
||||||
# Task no longer in memory — check the persisted JSON.
|
# Task no longer in memory — check the persisted JSON.
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
@@ -162,7 +197,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
def _assert_owns_research(session_id: str, user: str) -> None:
|
def _assert_owns_research(session_id: str, user: str) -> None:
|
||||||
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
||||||
Use BEFORE returning any data or mutating the file."""
|
Use BEFORE returning any data or mutating the file."""
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -225,7 +260,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
):
|
):
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
"""List all completed research for the Library panel."""
|
"""List all completed research for the Library panel."""
|
||||||
data_dir = Path("data/deep_research")
|
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
items = []
|
items = []
|
||||||
for p in data_dir.glob("*.json"):
|
for p in data_dir.glob("*.json"):
|
||||||
try:
|
try:
|
||||||
@@ -275,7 +310,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
summary, stats — used by the Library preview panel."""
|
summary, stats — used by the Library preview panel."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -292,7 +327,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise HTTPException(404, "Research not found")
|
raise HTTPException(404, "Research not found")
|
||||||
try:
|
try:
|
||||||
@@ -312,7 +347,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
"""Delete a research result from disk."""
|
"""Delete a research result from disk."""
|
||||||
user = _require_user(request)
|
user = _require_user(request)
|
||||||
_validate_session_id(session_id)
|
_validate_session_id(session_id)
|
||||||
data_dir = Path("data/deep_research")
|
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
json_path = data_dir / f"{session_id}.json"
|
json_path = data_dir / f"{session_id}.json"
|
||||||
deleted = False
|
deleted = False
|
||||||
if json_path.exists():
|
if json_path.exists():
|
||||||
@@ -368,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
|
|
||||||
if body.endpoint_id:
|
if body.endpoint_id:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped: never resolve another user's private endpoint
|
# Owner-scoped: never resolve another user's private endpoint
|
||||||
@@ -377,35 +411,26 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found or disabled")
|
raise HTTPException(404, "Endpoint not found or disabled")
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||||
ep_url = build_chat_url(base)
|
if not resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||||
ep_model = body.model or ""
|
ep_url, ep_model, ep_headers = resolved
|
||||||
if not ep_model:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
|
||||||
# When neither research nor utility is configured, use the user's
|
# When neither research nor utility is configured, use the user's
|
||||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||||
@@ -414,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||||
ep = _owned_enabled_endpoint(db, user)
|
ep = _owned_enabled_endpoint(db, user)
|
||||||
if ep:
|
if ep:
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||||
ep_url = build_chat_url(base)
|
if resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
ep_url, ep_model, ep_headers = resolved
|
||||||
ep_model = ""
|
|
||||||
if ep.cached_models:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models)
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
@@ -494,7 +510,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
raise HTTPException(404, "No research found for this session")
|
raise HTTPException(404, "No research found for this session")
|
||||||
result = research_handler.get_result(session_id)
|
result = research_handler.get_result(session_id)
|
||||||
if result is None:
|
if result is None:
|
||||||
p = Path("data/deep_research") / f"{session_id}.json"
|
p = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if p.exists():
|
if p.exists():
|
||||||
d = json.loads(p.read_text(encoding="utf-8"))
|
d = json.loads(p.read_text(encoding="utf-8"))
|
||||||
return {
|
return {
|
||||||
@@ -534,7 +550,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
sources = research_handler.get_sources(session_id) or []
|
sources = research_handler.get_sources(session_id) or []
|
||||||
query = ""
|
query = ""
|
||||||
|
|
||||||
path = Path("data/deep_research") / f"{session_id}.json"
|
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||||
if path.exists():
|
if path.exists():
|
||||||
try:
|
try:
|
||||||
disk = json.loads(path.read_text(encoding="utf-8"))
|
disk = json.loads(path.read_text(encoding="utf-8"))
|
||||||
@@ -572,19 +588,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep_headers = dict(r_headers)
|
ep_headers = dict(r_headers)
|
||||||
|
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("chat"))
|
_merge(*resolve_endpoint("chat", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("research"))
|
_merge(*resolve_endpoint("research", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
_merge(*resolve_endpoint("utility"))
|
_merge(*resolve_endpoint("utility", owner=user))
|
||||||
if not ep_url or not ep_model:
|
if not ep_url or not ep_model:
|
||||||
# Last resort: any enabled endpoint
|
# Last resort: this user's enabled endpoint, plus legacy shared rows.
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.database import ModelEndpoint
|
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
ep = _owned_enabled_endpoint(db, user)
|
||||||
if ep:
|
if ep:
|
||||||
base = normalize_base(ep.base_url)
|
base = normalize_base(ep.base_url)
|
||||||
fallback_url = build_chat_url(base)
|
fallback_url = build_chat_url(base)
|
||||||
@@ -594,7 +609,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
try:
|
try:
|
||||||
models = json.loads(ep.cached_models)
|
models = json.loads(ep.cached_models)
|
||||||
if models:
|
if models:
|
||||||
fallback_model = models[0]
|
fallback_model = _first_chat_model(models)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
_merge(fallback_url, fallback_model, fallback_headers)
|
_merge(fallback_url, fallback_model, fallback_headers)
|
||||||
|
|||||||
+48
-33
@@ -10,8 +10,9 @@ import logging
|
|||||||
from core.session_manager import SessionManager
|
from core.session_manager import SessionManager
|
||||||
from core.models import ChatMessage
|
from core.models import ChatMessage
|
||||||
from src.request_models import SessionResponse
|
from src.request_models import SessionResponse
|
||||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||||
from src.auth_helpers import get_current_user, effective_user
|
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||||
|
from src.session_actions import is_session_recently_active
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_export_filename(name: str) -> str:
|
def _sanitize_export_filename(name: str) -> str:
|
||||||
@@ -92,35 +93,30 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||||
"""Verify the current user owns the session. Raises 404 if not.
|
"""Verify the current user owns the session, honoring single-user modes.
|
||||||
|
|
||||||
Ownership is checked against the DB row when one exists (unchanged). If
|
Authenticated requests must match the stored DB or in-memory owner. When
|
||||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
auth is disabled and no user is present, treat the app as single-user mode:
|
||||||
that lives only in ``session_manager`` because it was never persisted, or
|
verify that the session exists, but do not compare its stored owner. This
|
||||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||||
user can still manage and delete it. Without this fallback such sessions are
|
rows created while auth was previously enabled.
|
||||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
|
||||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
|
||||||
|
|
||||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
|
||||||
that only care about persisted sessions keep their exact prior behavior.
|
|
||||||
"""
|
"""
|
||||||
user = effective_user(request)
|
user = effective_user(request)
|
||||||
if not user:
|
if not user and not _auth_disabled():
|
||||||
raise HTTPException(403, "Authentication required")
|
raise HTTPException(401, "Authentication required")
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
if row.owner != user:
|
if user and row.owner != user:
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
return
|
return
|
||||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||||
if session_manager is not None:
|
if session_manager is not None:
|
||||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||||
return
|
return
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
|
|
||||||
@@ -262,7 +258,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
last_msg_map = {}
|
last_msg_map = {}
|
||||||
mode_map = {}
|
mode_map = {}
|
||||||
msg_count_map = {}
|
msg_count_map = {}
|
||||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False).all()
|
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
folder_map[row.id] = row.folder
|
folder_map[row.id] = row.folder
|
||||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||||
@@ -284,12 +280,14 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
r[0] for r in db.query(Document.session_id)
|
r[0] for r in db.query(Document.session_id)
|
||||||
.filter(Document.is_active == True,
|
.filter(Document.is_active == True,
|
||||||
Document.current_content != None,
|
Document.current_content != None,
|
||||||
func.trim(Document.current_content) != "")
|
func.trim(Document.current_content) != "",
|
||||||
|
Document.owner == user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
img_session_ids = set(
|
img_session_ids = set(
|
||||||
r[0] for r in db.query(GalleryImage.session_id)
|
r[0] for r in db.query(GalleryImage.session_id)
|
||||||
.filter(GalleryImage.session_id != None)
|
.filter(GalleryImage.session_id != None,
|
||||||
|
GalleryImage.owner == user)
|
||||||
.distinct().all()
|
.distinct().all()
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
@@ -370,8 +368,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
pass
|
pass
|
||||||
elif not model_to_use:
|
elif not model_to_use:
|
||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
ids = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
# Default to the first CHAT model — endpoints often list embedding/
|
# Default to the first CHAT model — endpoints often list embedding/
|
||||||
@@ -385,8 +388,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
import os as _os
|
import os as _os
|
||||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
avail = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not avail:
|
if not avail:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
if model_to_use not in avail:
|
if model_to_use not in avail:
|
||||||
@@ -543,22 +551,25 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
ids = body.get("ids", [])
|
ids = body.get("ids", [])
|
||||||
except Exception:
|
except Exception:
|
||||||
ids = []
|
ids = []
|
||||||
|
deleted_count = 0
|
||||||
for sid in ids:
|
for sid in ids:
|
||||||
try:
|
try:
|
||||||
_verify_session_owner(request, sid, session_manager)
|
_verify_session_owner(request, sid, session_manager)
|
||||||
session_manager.delete_session(sid)
|
|
||||||
|
# Enforce "starred" protection consistent with single-session delete
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
db.query(_CM).filter(_CM.session_id == sid).delete()
|
db_sess = db.query(DbSession).filter(DbSession.id == sid).first()
|
||||||
db.query(DbSession).filter(DbSession.id == sid).delete()
|
if db_sess and db_sess.is_important:
|
||||||
db.commit()
|
continue
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
if session_manager.delete_session(sid):
|
||||||
|
deleted_count += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return {"deleted": len(ids)}
|
return {"deleted": deleted_count}
|
||||||
|
|
||||||
@router.delete("/session/{sid}")
|
@router.delete("/session/{sid}")
|
||||||
def delete_session(request: Request, sid: str):
|
def delete_session(request: Request, sid: str):
|
||||||
@@ -924,7 +935,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
owner = getattr(session, "owner", None) or effective_user(request)
|
||||||
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
@@ -1006,7 +1018,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
}
|
}
|
||||||
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
||||||
try:
|
try:
|
||||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).all()
|
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
|
||||||
folder_map = {r.id: r.folder for r in rows}
|
folder_map = {r.id: r.folder for r in rows}
|
||||||
# Precompute per-session message counts in TWO aggregate queries
|
# Precompute per-session message counts in TWO aggregate queries
|
||||||
# instead of 1–3 queries PER session — with many chats the per-row
|
# instead of 1–3 queries PER session — with many chats the per-row
|
||||||
@@ -1017,6 +1029,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
|
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
|
||||||
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
|
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
|
||||||
)
|
)
|
||||||
|
cleanup_now = utcnow_naive()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
# Never delete important sessions
|
# Never delete important sessions
|
||||||
if getattr(row, 'is_important', False):
|
if getattr(row, 'is_important', False):
|
||||||
@@ -1029,6 +1042,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
if hasattr(session_manager, 'delete_session'):
|
if hasattr(session_manager, 'delete_session'):
|
||||||
session_manager.delete_session(row.id)
|
session_manager.delete_session(row.id)
|
||||||
continue
|
continue
|
||||||
|
if is_session_recently_active(row, now=cleanup_now):
|
||||||
|
continue
|
||||||
msg_count = _counts.get(row.id, 0)
|
msg_count = _counts.get(row.id, 0)
|
||||||
should_delete = False
|
should_delete = False
|
||||||
if msg_count == 0:
|
if msg_count == 0:
|
||||||
|
|||||||
+279
-58
@@ -13,6 +13,7 @@ import tempfile
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
||||||
|
|
||||||
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
||||||
# on Windows, so importing them unconditionally crashed app startup there
|
# on Windows, so importing them unconditionally crashed app startup there
|
||||||
@@ -37,6 +38,7 @@ from core.platform_compat import (
|
|||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -92,6 +94,7 @@ def _venv_activate_prefix(venv: str | None) -> str:
|
|||||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||||
return f". {act} && "
|
return f". {act} && "
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||||
@@ -169,7 +172,10 @@ def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
|||||||
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
||||||
)
|
)
|
||||||
if name == "hf_transfer":
|
if name == "hf_transfer":
|
||||||
return bool(dists.get("hf-transfer") or modules.get("hf_transfer", {}).get("real_module"))
|
return bool(
|
||||||
|
dists.get("hf-transfer")
|
||||||
|
or modules.get("hf_transfer", {}).get("real_module")
|
||||||
|
)
|
||||||
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
||||||
|
|
||||||
|
|
||||||
@@ -194,8 +200,14 @@ def _package_status_note(name: str, probe: dict) -> str:
|
|||||||
if binaries.get("llama-server"):
|
if binaries.get("llama-server"):
|
||||||
parts.append(f"native llama-server: {binaries['llama-server']}")
|
parts.append(f"native llama-server: {binaries['llama-server']}")
|
||||||
if dists.get("llama-cpp-python"):
|
if dists.get("llama-cpp-python"):
|
||||||
parts.append(f"python package: llama-cpp-python {dists['llama-cpp-python']}")
|
parts.append(
|
||||||
return "; ".join(parts) if parts else "No native llama-server or llama-cpp-python server package found."
|
f"python package: llama-cpp-python {dists['llama-cpp-python']}"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"; ".join(parts)
|
||||||
|
if parts
|
||||||
|
else "No native llama-server or llama-cpp-python server package found."
|
||||||
|
)
|
||||||
if name == "diffusers":
|
if name == "diffusers":
|
||||||
if _package_installed_from_probe(name, probe):
|
if _package_installed_from_probe(name, probe):
|
||||||
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
||||||
@@ -205,7 +217,9 @@ def _package_status_note(name: str, probe: dict) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
|
def _package_pip_update_status(
|
||||||
|
pkg: dict, probe: dict | None = None
|
||||||
|
) -> PackageUpdateStatus:
|
||||||
"""Return whether the Dependencies UI should offer a generic pip update.
|
"""Return whether the Dependencies UI should offer a generic pip update.
|
||||||
|
|
||||||
"Installed" means Cookbook can use the dependency. It does not always mean
|
"Installed" means Cookbook can use the dependency. It does not always mean
|
||||||
@@ -213,12 +227,28 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
|||||||
native llama-server can come from a package manager/source build, and a CLI
|
native llama-server can come from a package manager/source build, and a CLI
|
||||||
may be on PATH without matching Python package metadata.
|
may be on PATH without matching Python package metadata.
|
||||||
"""
|
"""
|
||||||
|
if pkg.get("name") == "APFEL":
|
||||||
|
return PackageUpdateStatus(
|
||||||
|
False,
|
||||||
|
"", # Note is empty because IT DOES allow for updates outside of PIP.
|
||||||
|
)
|
||||||
|
|
||||||
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
||||||
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
|
return PackageUpdateStatus(
|
||||||
|
False, "Update this system dependency outside Odysseus."
|
||||||
|
)
|
||||||
|
|
||||||
name = pkg.get("name")
|
name = pkg.get("name")
|
||||||
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
|
binaries = (
|
||||||
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
|
probe.get("binaries")
|
||||||
|
if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
dists = (
|
||||||
|
probe.get("dists")
|
||||||
|
if isinstance(probe, dict) and isinstance(probe.get("dists"), dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
|
||||||
if name == "llama_cpp" and binaries.get("llama-server"):
|
if name == "llama_cpp" and binaries.get("llama-server"):
|
||||||
return PackageUpdateStatus(
|
return PackageUpdateStatus(
|
||||||
@@ -231,7 +261,9 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
|||||||
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
|
return PackageUpdateStatus(
|
||||||
|
True, "Update uses pip in the selected Python environment."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _prepend_user_install_bins_to_path() -> None:
|
def _prepend_user_install_bins_to_path() -> None:
|
||||||
@@ -250,7 +282,9 @@ def _prepend_user_install_bins_to_path() -> None:
|
|||||||
candidates = []
|
candidates = []
|
||||||
candidates.append(os.path.expanduser("~/.local/bin"))
|
candidates.append(os.path.expanduser("~/.local/bin"))
|
||||||
|
|
||||||
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
parts = (
|
||||||
|
os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||||
|
)
|
||||||
changed = False
|
changed = False
|
||||||
for path in reversed([p for p in candidates if p]):
|
for path in reversed([p for p in candidates if p]):
|
||||||
if path not in parts:
|
if path not in parts:
|
||||||
@@ -357,9 +391,11 @@ PTY_UNSUPPORTED_ERROR = "pty_unsupported"
|
|||||||
|
|
||||||
class ShellExecRequest(BaseModel):
|
class ShellExecRequest(BaseModel):
|
||||||
command: str
|
command: str
|
||||||
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
|
timeout: int | None = (
|
||||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
None # optional override; 0 = no timeout (run until client disconnects)
|
||||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
)
|
||||||
|
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||||
|
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||||
|
|
||||||
|
|
||||||
async def _create_shell(command: str, **kwargs):
|
async def _create_shell(command: str, **kwargs):
|
||||||
@@ -368,8 +404,16 @@ async def _create_shell(command: str, **kwargs):
|
|||||||
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
||||||
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
||||||
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
||||||
|
PowerShell and cmd commands are executed directly via cmd.exe /c to avoid quoting
|
||||||
|
and env variable expansion errors under Git Bash.
|
||||||
"""
|
"""
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
|
# PowerShell commands (used by the frontend for Windows log-file polling
|
||||||
|
# and session management) must run directly — passing them through
|
||||||
|
# bash -c mangles $env:VAR syntax and breaks the command.
|
||||||
|
cmd_trim = command.strip()
|
||||||
|
if cmd_trim.startswith("powershell") or cmd_trim.startswith("cmd "):
|
||||||
|
return await asyncio.create_subprocess_shell(command, **kwargs)
|
||||||
bash = find_bash()
|
bash = find_bash()
|
||||||
if bash:
|
if bash:
|
||||||
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
||||||
@@ -386,9 +430,7 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
|||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=str(Path.home()),
|
cwd=str(Path.home()),
|
||||||
)
|
)
|
||||||
stdout_b, stderr_b = await asyncio.wait_for(
|
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||||
proc.communicate(), timeout=timeout
|
|
||||||
)
|
|
||||||
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||||
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||||
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
||||||
@@ -399,7 +441,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
|||||||
await proc.wait()
|
await proc.wait()
|
||||||
except ProcessLookupError:
|
except ProcessLookupError:
|
||||||
pass
|
pass
|
||||||
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
|
return {
|
||||||
|
"stdout": "",
|
||||||
|
"stderr": f"Command timed out after {timeout}s",
|
||||||
|
"exit_code": -1,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
||||||
|
|
||||||
@@ -481,7 +527,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
|
|
||||||
@@ -503,7 +549,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
if buf:
|
if buf:
|
||||||
@@ -534,6 +580,7 @@ def _pty_read(fd: int) -> bytes | None:
|
|||||||
"""Blocking read from PTY fd. Called via run_in_executor.
|
"""Blocking read from PTY fd. Called via run_in_executor.
|
||||||
Returns bytes on data, None on timeout (no data yet)."""
|
Returns bytes on data, None on timeout (no data yet)."""
|
||||||
import select
|
import select
|
||||||
|
|
||||||
r, _, _ = select.select([fd], [], [], 1.0)
|
r, _, _ = select.select([fd], [], [], 1.0)
|
||||||
if r:
|
if r:
|
||||||
try:
|
try:
|
||||||
@@ -557,10 +604,10 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"#!/bin/bash\n"
|
f"#!/bin/bash\n"
|
||||||
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
|
f'ODYSSEUS_USER_SHELL="${{SHELL:-}}"\n'
|
||||||
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
|
f'if [ -n "$ODYSSEUS_USER_SHELL" ] && [ -x "$ODYSSEUS_USER_SHELL" ]; then\n'
|
||||||
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
|
f' ODYSSEUS_USER_PATH="$("$ODYSSEUS_USER_SHELL" -ic \'printf "__ODYSSEUS_PATH__%s\\n" "$PATH"\' 2>/dev/null | sed -n \'s/^__ODYSSEUS_PATH__//p\' | tail -n 1 || true)"\n'
|
||||||
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
|
f' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi\n'
|
||||||
f"fi\n"
|
f"fi\n"
|
||||||
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
||||||
f"EC=${{PIPESTATUS[0]}}\n"
|
f"EC=${{PIPESTATUS[0]}}\n"
|
||||||
@@ -570,7 +617,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
script_path.chmod(0o755)
|
script_path.chmod(0o755)
|
||||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
logger.info(
|
||||||
|
"tmux wrapper script created: session=%s path=%s", session_id, script_path
|
||||||
|
)
|
||||||
|
|
||||||
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
||||||
|
|
||||||
@@ -602,7 +651,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
# Read new lines from log
|
# Read new lines from log
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
new_lines = lines[lines_sent:]
|
new_lines = lines[lines_sent:]
|
||||||
for line in new_lines:
|
for line in new_lines:
|
||||||
if line.startswith(":::EXIT_CODE:::"):
|
if line.startswith(":::EXIT_CODE:::"):
|
||||||
@@ -630,7 +681,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
|||||||
# Session ended — do one final read
|
# Session ended — do one final read
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
if line.startswith(":::EXIT_CODE:::"):
|
if line.startswith(":::EXIT_CODE:::"):
|
||||||
try:
|
try:
|
||||||
@@ -672,8 +725,8 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
if bash:
|
if bash:
|
||||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"{cmd} > {shlex.quote(str(log_path))} 2>&1\n"
|
f"{cmd} > {shlex.quote(git_bash_path(log_path))} 2>&1\n"
|
||||||
f"echo $? > {shlex.quote(str(exit_path))}\n",
|
f"echo $? > {shlex.quote(git_bash_path(exit_path))}\n",
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
argv = [bash, str(script_path)]
|
argv = [bash, str(script_path)]
|
||||||
@@ -711,7 +764,9 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
lines_sent = len(lines)
|
lines_sent = len(lines)
|
||||||
@@ -723,11 +778,18 @@ async def _generate_win_detached(cmd: str, request: Request):
|
|||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
if log_path.exists():
|
if log_path.exists():
|
||||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
lines = log_path.read_text(
|
||||||
|
encoding="utf-8", errors="replace"
|
||||||
|
).splitlines()
|
||||||
for line in lines[lines_sent:]:
|
for line in lines[lines_sent:]:
|
||||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||||
lines_sent = len(lines)
|
lines_sent = len(lines)
|
||||||
exit_code = int((exit_path.read_text(encoding="utf-8", errors="replace").strip() or "0"))
|
exit_code = int(
|
||||||
|
(
|
||||||
|
exit_path.read_text(encoding="utf-8", errors="replace").strip()
|
||||||
|
or "0"
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
exit_code = 0
|
exit_code = 0
|
||||||
break
|
break
|
||||||
@@ -753,7 +815,9 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
||||||
|
|
||||||
logger.info("User shell exec requested: length=%d", len(cmd))
|
logger.info("User shell exec requested: length=%d", len(cmd))
|
||||||
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
|
result = await _exec_shell(
|
||||||
|
cmd, timeout=req.timeout if req.timeout is not None else EXEC_TIMEOUT
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@router.post("/api/shell/stream")
|
@router.post("/api/shell/stream")
|
||||||
@@ -762,9 +826,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
cmd = req.command.strip()
|
cmd = req.command.strip()
|
||||||
if not cmd:
|
if not cmd:
|
||||||
|
|
||||||
async def empty():
|
async def empty():
|
||||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
||||||
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
||||||
|
|
||||||
return StreamingResponse(empty(), media_type="text/event-stream")
|
return StreamingResponse(empty(), media_type="text/event-stream")
|
||||||
|
|
||||||
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
||||||
@@ -781,7 +847,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if use_tmux:
|
if use_tmux:
|
||||||
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
||||||
# that preserves the "survives disconnect" behaviour.
|
# that preserves the "survives disconnect" behaviour.
|
||||||
gen = _generate_win_detached(cmd, request) if IS_WINDOWS else _generate_tmux(cmd, request)
|
gen = (
|
||||||
|
_generate_win_detached(cmd, request)
|
||||||
|
if IS_WINDOWS
|
||||||
|
else _generate_tmux(cmd, request)
|
||||||
|
)
|
||||||
return StreamingResponse(gen, media_type="text/event-stream")
|
return StreamingResponse(gen, media_type="text/event-stream")
|
||||||
|
|
||||||
if use_pty and not IS_WINDOWS:
|
if use_pty and not IS_WINDOWS:
|
||||||
@@ -813,7 +883,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
chunk = await stream.read(4096)
|
chunk = await stream.read(4096)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
if buf:
|
if buf:
|
||||||
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
|
await q.put(
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
buf.decode(errors="replace").rstrip("\r\n"),
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
buf += chunk
|
buf += chunk
|
||||||
while True:
|
while True:
|
||||||
@@ -821,7 +896,7 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if idx == -1:
|
if idx == -1:
|
||||||
break
|
break
|
||||||
line = buf[:idx].decode(errors="replace")
|
line = buf[:idx].decode(errors="replace")
|
||||||
buf = buf[idx + sep_len:]
|
buf = buf[idx + sep_len :]
|
||||||
if line:
|
if line:
|
||||||
await q.put((name, line))
|
await q.put((name, line))
|
||||||
finally:
|
finally:
|
||||||
@@ -880,7 +955,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||||
|
|
||||||
@router.get("/api/cookbook/packages")
|
@router.get("/api/cookbook/packages")
|
||||||
async def list_packages(request: Request, host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
|
async def list_packages(
|
||||||
|
request: Request,
|
||||||
|
host: str | None = None,
|
||||||
|
ssh_port: str | None = None,
|
||||||
|
venv: str | None = None,
|
||||||
|
):
|
||||||
"""Check which optional packages are installed.
|
"""Check which optional packages are installed.
|
||||||
|
|
||||||
Local-target packages are checked in-process. Remote-target packages
|
Local-target packages are checked in-process. Remote-target packages
|
||||||
@@ -890,7 +970,13 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
_reject_cross_site(request)
|
_reject_cross_site(request)
|
||||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
|
import importlib
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
import shlex
|
||||||
|
import json as _json
|
||||||
|
import site
|
||||||
|
import sys
|
||||||
|
|
||||||
_prepend_user_install_bins_to_path()
|
_prepend_user_install_bins_to_path()
|
||||||
importlib.invalidate_caches()
|
importlib.invalidate_caches()
|
||||||
try:
|
try:
|
||||||
@@ -905,26 +991,115 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
raise HTTPException(400, "Invalid ssh_port")
|
raise HTTPException(400, "Invalid ssh_port")
|
||||||
packages = [
|
packages = [
|
||||||
# ── System ── OS binaries, not pip packages
|
# ── System ── OS binaries, not pip packages
|
||||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
{
|
||||||
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
|
"name": "tmux",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "Required for Linux/Termux Cookbook background downloads and serves",
|
||||||
|
"category": "System",
|
||||||
|
"target": "remote",
|
||||||
|
"kind": "system",
|
||||||
|
"install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "docker",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "Required only for Docker-backed launch commands",
|
||||||
|
"category": "System",
|
||||||
|
"target": "remote",
|
||||||
|
"kind": "system",
|
||||||
|
"install_hint": "Install Docker on the selected server and allow this user to run docker.",
|
||||||
|
},
|
||||||
# ── LLM ── installs on GPU servers for model serving/downloading
|
# ── LLM ── installs on GPU servers for model serving/downloading
|
||||||
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
|
{
|
||||||
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
|
"name": "hf_transfer",
|
||||||
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
|
"pip": "hf_transfer",
|
||||||
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
|
"desc": "Fast model downloads from HuggingFace",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "llama_cpp",
|
||||||
|
"pip": "llama-cpp-python[server]",
|
||||||
|
"desc": "Serve GGUF models via llama.cpp",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "sglang",
|
||||||
|
"pip": "sglang[all]",
|
||||||
|
"desc": "Serve HF safetensors models via SGLang",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "vllm",
|
||||||
|
"pip": "vllm",
|
||||||
|
"desc": "High-throughput LLM serving engine",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "APFEL",
|
||||||
|
"pip": "",
|
||||||
|
"desc": "OpenAI-compatible API for Apple Foundational Models on Apple Silicon",
|
||||||
|
"category": "LLM",
|
||||||
|
"target": "local",
|
||||||
|
"kind": "system",
|
||||||
|
"install_cmd": "brew install apfel",
|
||||||
|
"update_cmd": "brew upgrade apfel",
|
||||||
|
"install_hint": "Requires a native Apple Silicon Mac with Apple Foundational Models support. Installable via Homebrew on supported Macs.",
|
||||||
|
},
|
||||||
# ── Image ── editor + diffusion model serving
|
# ── Image ── editor + diffusion model serving
|
||||||
{"name": "diffusers", "pip": "diffusers[torch]", "desc": "Image generation pipelines (SD, Flux) with PyTorch", "category": "Image", "target": "remote"},
|
{
|
||||||
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
|
"name": "diffusers",
|
||||||
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
|
"pip": "diffusers[torch]",
|
||||||
|
"desc": "Image generation pipelines (SD, Flux) with PyTorch",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "remote",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "rembg",
|
||||||
|
"pip": "rembg[gpu]",
|
||||||
|
"desc": "AI background removal for image editor",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "realesrgan",
|
||||||
|
"pip": "realesrgan",
|
||||||
|
"desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.",
|
||||||
|
"category": "Image",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
# ── Tools ──
|
# ── Tools ──
|
||||||
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
|
{
|
||||||
|
"name": "playwright",
|
||||||
|
"pip": "playwright",
|
||||||
|
"desc": "Browser automation for web tools",
|
||||||
|
"category": "Tools",
|
||||||
|
"target": "local",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Most packages should not be installed through external means. Hence, set the default of the
|
||||||
|
# install_cmd and update_cmd to None, which indicates that the recommended way to install/update is through the Cookbook server setup or pip. Only system packages should have explicit install/update commands provided.
|
||||||
|
for pkg in packages:
|
||||||
|
pkg.setdefault("install_cmd", None)
|
||||||
|
pkg.setdefault("update_cmd", None)
|
||||||
# Remote check: for remote-target packages, probe the selected server's
|
# Remote check: for remote-target packages, probe the selected server's
|
||||||
# venv over SSH so a remote `pip install` actually reflects here.
|
# venv over SSH so a remote `pip install` actually reflects here.
|
||||||
remote_status: dict = {}
|
remote_status: dict = {}
|
||||||
remote_details: dict = {}
|
remote_details: dict = {}
|
||||||
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
|
remote_names = [
|
||||||
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
|
p["name"]
|
||||||
|
for p in packages
|
||||||
|
if p.get("target") == "remote" and p.get("kind") != "system"
|
||||||
|
]
|
||||||
|
remote_system_names = [
|
||||||
|
p["name"]
|
||||||
|
for p in packages
|
||||||
|
if p.get("target") == "remote" and p.get("kind") == "system"
|
||||||
|
]
|
||||||
if host and remote_names:
|
if host and remote_names:
|
||||||
try:
|
try:
|
||||||
py = _package_probe_script(remote_names)
|
py = _package_probe_script(remote_names)
|
||||||
@@ -934,7 +1109,9 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
*argv,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -958,11 +1135,15 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
checks = []
|
checks = []
|
||||||
for name in remote_system_names:
|
for name in remote_system_names:
|
||||||
qn = shlex.quote(name)
|
qn = shlex.quote(name)
|
||||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
checks.append(
|
||||||
|
f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi"
|
||||||
|
)
|
||||||
inner = " ; ".join(checks)
|
inner = " ; ".join(checks)
|
||||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
*argv,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -987,11 +1168,25 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if note:
|
if note:
|
||||||
pkg["status_note"] = note
|
pkg["status_note"] = note
|
||||||
elif pkg.get("kind") == "system":
|
elif pkg.get("kind") == "system":
|
||||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
if pkg["name"] == "APFEL":
|
||||||
|
pkg["applicable"] = IS_APPLE_SILICON
|
||||||
|
pkg["installed"] = which_tool("apfel") is not None
|
||||||
|
pkg["status_note"] = (
|
||||||
|
"Available on Apple Silicon (arm64) devices; exposed through a local OpenAI-compatible API."
|
||||||
|
if IS_APPLE_SILICON
|
||||||
|
else "Requires a native Apple Silicon Mac with Apple Foundational Models support."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||||
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||||
pkg["installed"] = True
|
pkg["installed"] = True
|
||||||
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
|
pkg["status_note"] = (
|
||||||
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
|
f"native llama-server: {shutil.which('llama-server')}"
|
||||||
|
)
|
||||||
|
probe = {
|
||||||
|
"binaries": {"llama-server": shutil.which("llama-server")},
|
||||||
|
"dists": {},
|
||||||
|
}
|
||||||
elif pkg["name"] == "vllm":
|
elif pkg["name"] == "vllm":
|
||||||
_vllm_cli = shutil.which("vllm")
|
_vllm_cli = shutil.which("vllm")
|
||||||
pkg["installed"] = _vllm_cli is not None
|
pkg["installed"] = _vllm_cli is not None
|
||||||
@@ -1014,6 +1209,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
pkg["installed"] = False
|
pkg["installed"] = False
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
pkg["installed"] = False
|
pkg["installed"] = False
|
||||||
|
except Exception:
|
||||||
|
# Installed but crashes on import — e.g. a CUDA build of
|
||||||
|
# llama-cpp-python raising FileNotFoundError when the CUDA
|
||||||
|
# toolkit dir is absent. One broken optional package must not
|
||||||
|
# 500 the entire packages panel; report it as not usable.
|
||||||
|
pkg["installed"] = False
|
||||||
|
|
||||||
if pkg.get("installed"):
|
if pkg.get("installed"):
|
||||||
update_status = _package_pip_update_status(pkg, probe)
|
update_status = _package_pip_update_status(pkg, probe)
|
||||||
@@ -1037,15 +1238,30 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
import sys as _sys
|
import sys as _sys
|
||||||
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
pip_name = body.get("pip")
|
pip_name = body.get("pip")
|
||||||
if not pip_name:
|
if not pip_name:
|
||||||
return {"ok": False, "error": "No package specified"}
|
return {"ok": False, "error": "No package specified"}
|
||||||
# Validate against known packages to prevent arbitrary pip install
|
# Validate against known packages to prevent arbitrary pip install
|
||||||
known = {
|
known = {
|
||||||
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers", "diffusers[torch]",
|
"rembg[gpu]",
|
||||||
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
|
"hf_transfer",
|
||||||
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan", "vllm",
|
"llama-cpp-python[server]",
|
||||||
|
"sglang[all]",
|
||||||
|
"diffusers",
|
||||||
|
"diffusers[torch]",
|
||||||
|
"TTS",
|
||||||
|
"bark",
|
||||||
|
"faster-whisper",
|
||||||
|
"playwright",
|
||||||
|
"realesrgan",
|
||||||
|
"gfpgan",
|
||||||
|
"insightface",
|
||||||
|
"onnxruntime-gpu",
|
||||||
|
"onnxruntime",
|
||||||
|
"hdbscan",
|
||||||
|
"vllm",
|
||||||
}
|
}
|
||||||
if pip_name not in known:
|
if pip_name not in known:
|
||||||
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
||||||
@@ -1071,6 +1287,7 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
||||||
|
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
engine = str(body.get("engine") or "llamacpp").strip()
|
engine = str(body.get("engine") or "llamacpp").strip()
|
||||||
if engine != "llamacpp":
|
if engine != "llamacpp":
|
||||||
@@ -1079,7 +1296,11 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
ssh_port = body.get("ssh_port")
|
ssh_port = body.get("ssh_port")
|
||||||
cmd = _llama_cpp_rebuild_cmd()
|
cmd = _llama_cpp_rebuild_cmd()
|
||||||
try:
|
try:
|
||||||
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
|
argv = (
|
||||||
|
(_ssh_base_argv(host, ssh_port) + [cmd])
|
||||||
|
if host
|
||||||
|
else ["bash", "-lc", cmd]
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(400, str(e))
|
raise HTTPException(400, str(e))
|
||||||
try:
|
try:
|
||||||
|
|||||||
+44
-16
@@ -21,10 +21,44 @@ from src.auth_helpers import get_current_user
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_DATA_URL_RE = re.compile(
|
_DATA_URL_RE = re.compile(r"^data:image/png;base64,(?P<data>.+)$", re.IGNORECASE | re.DOTALL)
|
||||||
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
|
_ANY_IMAGE_DATA_URL_RE = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
|
||||||
re.IGNORECASE | re.DOTALL,
|
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
|
||||||
)
|
_MAX_SIGNATURE_BYTES = 2 * 1024 * 1024
|
||||||
|
_MAX_SIGNATURE_B64 = ((_MAX_SIGNATURE_BYTES + 2) // 3) * 4
|
||||||
|
_MAX_SIGNATURE_DIMENSION = 4096
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_signature_png(raw: str) -> str:
|
||||||
|
raw = (raw or "").strip()
|
||||||
|
m = _DATA_URL_RE.match(raw)
|
||||||
|
if m:
|
||||||
|
b64 = m.group("data")
|
||||||
|
elif _ANY_IMAGE_DATA_URL_RE.match(raw):
|
||||||
|
raise HTTPException(400, "Signature data must be a PNG image")
|
||||||
|
else:
|
||||||
|
b64 = raw
|
||||||
|
if len(b64) > _MAX_SIGNATURE_B64:
|
||||||
|
raise HTTPException(400, "Signature PNG is too large")
|
||||||
|
try:
|
||||||
|
payload = base64.b64decode(b64, validate=True)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||||
|
if not payload:
|
||||||
|
raise HTTPException(400, "Signature PNG is empty")
|
||||||
|
if len(payload) > _MAX_SIGNATURE_BYTES:
|
||||||
|
raise HTTPException(400, "Signature PNG is too large")
|
||||||
|
if not payload.startswith(_PNG_MAGIC):
|
||||||
|
raise HTTPException(400, "Signature data must be a PNG image")
|
||||||
|
return base64.b64encode(payload).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def _signature_dimension(value: Optional[int]) -> Optional[int]:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(value, int) or value < 1 or value > _MAX_SIGNATURE_DIMENSION:
|
||||||
|
raise HTTPException(400, "Signature dimensions are invalid")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class SignatureCreate(BaseModel):
|
class SignatureCreate(BaseModel):
|
||||||
@@ -67,24 +101,18 @@ def setup_signature_routes() -> APIRouter:
|
|||||||
@router.post("/api/signatures")
|
@router.post("/api/signatures")
|
||||||
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
||||||
user = get_current_user(request)
|
user = get_current_user(request)
|
||||||
raw = (req.data or "").strip()
|
b64 = _normalize_signature_png(req.data)
|
||||||
m = _DATA_URL_RE.match(raw)
|
width = _signature_dimension(req.width)
|
||||||
b64 = m.group("data") if m else raw
|
height = _signature_dimension(req.height)
|
||||||
try:
|
|
||||||
payload = base64.b64decode(b64, validate=True)
|
|
||||||
if not payload:
|
|
||||||
raise ValueError("empty payload")
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
|
||||||
|
|
||||||
sig = Signature(
|
sig = Signature(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
owner=user,
|
owner=user,
|
||||||
name=(req.name or "Signature").strip() or "Signature",
|
name=(req.name or "Signature").strip() or "Signature",
|
||||||
data_png=b64,
|
data_png=b64,
|
||||||
width=req.width,
|
width=width,
|
||||||
height=req.height,
|
height=height,
|
||||||
svg=req.svg,
|
svg=None,
|
||||||
)
|
)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
|||||||
+107
-1
@@ -11,6 +11,8 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -51,6 +53,10 @@ class SkillAddRequest(BaseModel):
|
|||||||
steps: List[str] = Field(default_factory=list)
|
steps: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillImportUrlRequest(BaseModel):
|
||||||
|
url: str = Field(..., min_length=8, max_length=2000)
|
||||||
|
|
||||||
|
|
||||||
class SkillUpdateRequest(BaseModel):
|
class SkillUpdateRequest(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@@ -1014,7 +1020,7 @@ def _resolve_audit_models(owner=None):
|
|||||||
spec = (get_setting("teacher_model", "") or "").strip()
|
spec = (get_setting("teacher_model", "") or "").strip()
|
||||||
if spec:
|
if spec:
|
||||||
from src.ai_interaction import _resolve_model
|
from src.ai_interaction import _resolve_model
|
||||||
t_url, t_model, t_headers = _resolve_model(spec)
|
t_url, t_model, t_headers = _resolve_model(spec, owner=owner)
|
||||||
if t_url and t_model:
|
if t_url and t_model:
|
||||||
teacher = (t_url, t_model, t_headers)
|
teacher = (t_url, t_model, t_headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1103,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
idx = skills_manager.index_for(owner=user)
|
idx = skills_manager.index_for(owner=user)
|
||||||
return {"index": idx, "count": len(idx)}
|
return {"index": idx, "count": len(idx)}
|
||||||
|
|
||||||
|
@router.get("/slash-catalog")
|
||||||
|
async def get_slash_catalog(request: Request):
|
||||||
|
"""Return skills that are available as slash commands.
|
||||||
|
|
||||||
|
Mirrors the agent prompt's published-skill index so the UI never offers
|
||||||
|
a slash command the model would not normally be allowed to discover.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)}
|
||||||
|
entries = []
|
||||||
|
for s in skills_manager.index_for(owner=user):
|
||||||
|
name = (s.get("name") or "").strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
full = all_skills.get(name) or {}
|
||||||
|
category = (s.get("category") or full.get("category") or "general").strip() or "general"
|
||||||
|
entries.append({
|
||||||
|
"type": "skill",
|
||||||
|
"token": f"/{name}",
|
||||||
|
"name": name,
|
||||||
|
"category": f"Skills / {category}",
|
||||||
|
"help": s.get("description") or full.get("description") or "",
|
||||||
|
"usage": f"/{name} <request>",
|
||||||
|
"uses": int(full.get("uses") or 0),
|
||||||
|
"last_used": full.get("last_used"),
|
||||||
|
})
|
||||||
|
entries.sort(key=lambda row: row["name"])
|
||||||
|
return {"skills": entries, "count": len(entries)}
|
||||||
|
|
||||||
@router.get("/builtin")
|
@router.get("/builtin")
|
||||||
async def list_builtin_skills(request: Request):
|
async def list_builtin_skills(request: Request):
|
||||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||||
@@ -1203,6 +1238,36 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
save_settings(settings)
|
save_settings(settings)
|
||||||
return {"ok": True, "name": name, "is_overridden": False}
|
return {"ok": True, "name": name, "is_overridden": False}
|
||||||
|
|
||||||
|
@router.post("/import-from-url")
|
||||||
|
async def import_skill_from_url(request: Request, body: SkillImportUrlRequest):
|
||||||
|
"""Install a SKILL.md bundle from a public GitHub URL (skills.sh links supported)."""
|
||||||
|
require_admin(request)
|
||||||
|
user = _owner(request)
|
||||||
|
from services.memory.skill_importer import (
|
||||||
|
SkillImportError,
|
||||||
|
fetch_skill_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
files, _src = fetch_skill_bundle(body.url.strip())
|
||||||
|
entry = skills_manager.import_bundle_from_files(
|
||||||
|
files,
|
||||||
|
owner=user,
|
||||||
|
source_url=body.url.strip(),
|
||||||
|
)
|
||||||
|
except SkillImportError as e:
|
||||||
|
raise HTTPException(400, str(e)) from e
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.warning("skill import fetch failed: %s", e)
|
||||||
|
detail = str(e).strip() or "Could not download skill from URL"
|
||||||
|
raise HTTPException(502, detail) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("skill import failed: %s", e)
|
||||||
|
raise HTTPException(500, "Skill import failed") from e
|
||||||
|
|
||||||
|
_fire_skill_added(user)
|
||||||
|
return {"ok": True, "skill": entry, "files": len(files)}
|
||||||
|
|
||||||
@router.post("/add")
|
@router.post("/add")
|
||||||
async def add_skill(request: Request, body: SkillAddRequest):
|
async def add_skill(request: Request, body: SkillAddRequest):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
@@ -1236,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
_fire_skill_added(user)
|
_fire_skill_added(user)
|
||||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||||
|
|
||||||
|
@router.post("/{skill_id}/invoke")
|
||||||
|
async def invoke_skill(request: Request, skill_id: str):
|
||||||
|
"""Build a skill-pinned prompt for slash-command invocation.
|
||||||
|
|
||||||
|
This is intentionally server-side so availability, ownership, and usage
|
||||||
|
accounting use the same rules as the SkillsManager.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
request_text = (body.get("request") or "").strip() if isinstance(body, dict) else ""
|
||||||
|
|
||||||
|
invokable = {
|
||||||
|
s.get("name"): s for s in skills_manager.index_for(owner=user)
|
||||||
|
if (s.get("name") or "").strip()
|
||||||
|
}
|
||||||
|
match = invokable.get(skill_id)
|
||||||
|
if not match:
|
||||||
|
raise HTTPException(404, "Skill is not available for slash invocation")
|
||||||
|
|
||||||
|
name = match.get("name")
|
||||||
|
md = skills_manager.read_skill_md(name, owner=user)
|
||||||
|
if md is None:
|
||||||
|
raise HTTPException(404, "Skill source unavailable")
|
||||||
|
|
||||||
|
skills_manager.record_use(name, owner=user)
|
||||||
|
message = (
|
||||||
|
"Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
|
||||||
|
f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
|
||||||
|
+ (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"type": "skill",
|
||||||
|
"name": name,
|
||||||
|
"command": f"/{name}",
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
|
||||||
@router.get("/{skill_id}")
|
@router.get("/{skill_id}")
|
||||||
async def get_skill(request: Request, skill_id: str):
|
async def get_skill(request: Request, skill_id: str):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
|
|||||||
@@ -4,12 +4,10 @@
|
|||||||
from fastapi import APIRouter, HTTPException, UploadFile, File
|
from fastapi import APIRouter, HTTPException, UploadFile, File
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from src.upload_limits import read_upload_limited
|
from src.upload_limits import read_upload_limited, STT_MAX_AUDIO_BYTES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
STT_MAX_AUDIO_BYTES = 25 * 1024 * 1024
|
|
||||||
|
|
||||||
|
|
||||||
def setup_stt_routes(stt_service):
|
def setup_stt_routes(stt_service):
|
||||||
"""Setup STT routes with the provided STT service"""
|
"""Setup STT routes with the provided STT service"""
|
||||||
|
|||||||
+37
-22
@@ -11,7 +11,9 @@ from fastapi import APIRouter, HTTPException, Request
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||||
|
from core.constants import internal_api_base
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
|
from src.constants import DATA_DIR, EMAIL_URGENCY_CACHE_DIR
|
||||||
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
||||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||||
|
|
||||||
@@ -56,7 +58,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
|||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=10) as client:
|
with httpx.Client(timeout=10) as client:
|
||||||
r = client.delete(
|
r = client.delete(
|
||||||
f"http://localhost:7000/api/calendar/events/{uid}",
|
f"{internal_api_base()}/api/calendar/events/{uid}",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
if r.status_code >= 400:
|
if r.status_code >= 400:
|
||||||
@@ -81,7 +83,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
|||||||
try:
|
try:
|
||||||
with httpx.Client(timeout=10) as client:
|
with httpx.Client(timeout=10) as client:
|
||||||
# Find the Cookbook calendar.
|
# Find the Cookbook calendar.
|
||||||
cal_r = client.get("http://localhost:7000/api/calendar/calendars", headers=headers)
|
cal_r = client.get(f"{internal_api_base()}/api/calendar/calendars", headers=headers)
|
||||||
if cal_r.status_code >= 400:
|
if cal_r.status_code >= 400:
|
||||||
return
|
return
|
||||||
cals = (cal_r.json() or {}).get("calendars", [])
|
cals = (cal_r.json() or {}).get("calendars", [])
|
||||||
@@ -98,7 +100,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
|||||||
start = (now - _td(days=30)).isoformat()
|
start = (now - _td(days=30)).isoformat()
|
||||||
end = (now + _td(days=365)).isoformat()
|
end = (now + _td(days=365)).isoformat()
|
||||||
ev_r = client.get(
|
ev_r = client.get(
|
||||||
"http://localhost:7000/api/calendar/events",
|
f"{internal_api_base()}/api/calendar/events",
|
||||||
params={"start": start, "end": end, "calendar": cal_href},
|
params={"start": start, "end": end, "calendar": cal_href},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
@@ -291,20 +293,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
def _owner(request: Request):
|
def _owner(request: Request):
|
||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
|
|
||||||
async def _generate_task_name(prompt: str) -> str:
|
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
|
||||||
"""Use LLM to generate a short task name from the prompt."""
|
"""Use LLM to generate a short task name from the prompt."""
|
||||||
try:
|
try:
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
from core.database import Session as DbSession
|
from core.database import Session as DbSession
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
recent = db.query(DbSession).filter(
|
q = db.query(DbSession).filter(
|
||||||
DbSession.endpoint_url.isnot(None),
|
DbSession.endpoint_url.isnot(None),
|
||||||
DbSession.model.isnot(None),
|
DbSession.model.isnot(None),
|
||||||
).order_by(DbSession.created_at.desc()).first()
|
)
|
||||||
|
if owner:
|
||||||
|
q = q.filter(DbSession.owner == owner)
|
||||||
|
recent = q.order_by(DbSession.created_at.desc()).first()
|
||||||
if not recent:
|
if not recent:
|
||||||
return prompt[:50].strip()
|
return prompt[:50].strip()
|
||||||
url, model = recent.endpoint_url, recent.model
|
url, model = recent.endpoint_url, recent.model
|
||||||
|
headers = recent.headers or {}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -315,6 +321,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
{"role": "user", "content": prompt[:500]},
|
{"role": "user", "content": prompt[:500]},
|
||||||
],
|
],
|
||||||
max_tokens=20,
|
max_tokens=20,
|
||||||
|
headers=headers,
|
||||||
timeout=15,
|
timeout=15,
|
||||||
)
|
)
|
||||||
title = result.strip().strip('"\'').strip()
|
title = result.strip().strip('"\'').strip()
|
||||||
@@ -429,6 +436,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
|
||||||
|
target_id = (then_task_id or "").strip()
|
||||||
|
if not target_id:
|
||||||
|
return None
|
||||||
|
if current_task_id and target_id == current_task_id:
|
||||||
|
raise HTTPException(400, "Task cannot chain to itself")
|
||||||
|
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
|
||||||
|
if user:
|
||||||
|
q = q.filter(ScheduledTask.owner == user)
|
||||||
|
target = q.first()
|
||||||
|
if not target:
|
||||||
|
raise HTTPException(404, "Chained task not found")
|
||||||
|
return target.id
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def create_task(request: Request, req: TaskCreate):
|
async def create_task(request: Request, req: TaskCreate):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
@@ -465,7 +486,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||||
elif req.prompt:
|
elif req.prompt:
|
||||||
name = await _generate_task_name(req.prompt)
|
name = await _generate_task_name(req.prompt, owner=user)
|
||||||
else:
|
else:
|
||||||
name = "Untitled Task"
|
name = "Untitled Task"
|
||||||
|
|
||||||
@@ -492,6 +513,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
|
||||||
notifications_enabled = (
|
notifications_enabled = (
|
||||||
False if req.task_type == "action" and req.notifications_enabled is None
|
False if req.task_type == "action" and req.notifications_enabled is None
|
||||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||||
@@ -527,7 +549,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
output_target=req.output_target,
|
output_target=req.output_target,
|
||||||
model=req.model or None,
|
model=req.model or None,
|
||||||
endpoint_url=req.endpoint_url or None,
|
endpoint_url=req.endpoint_url or None,
|
||||||
then_task_id=req.then_task_id or None,
|
then_task_id=then_task_id,
|
||||||
webhook_token=webhook_token,
|
webhook_token=webhook_token,
|
||||||
notifications_enabled=notifications_enabled,
|
notifications_enabled=notifications_enabled,
|
||||||
)
|
)
|
||||||
@@ -609,7 +631,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
|
|
||||||
removed_files = 0
|
removed_files = 0
|
||||||
if action == "check_email_urgency":
|
if action == "check_email_urgency":
|
||||||
cache_dir = Path("data/email_urgency_cache")
|
cache_dir = Path(EMAIL_URGENCY_CACHE_DIR)
|
||||||
if cache_dir.exists():
|
if cache_dir.exists():
|
||||||
for child in cache_dir.glob("*.json"):
|
for child in cache_dir.glob("*.json"):
|
||||||
try:
|
try:
|
||||||
@@ -618,7 +640,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
||||||
for state_path in [Path(f"data/email_urgency_state_{owner_slug}.json")]:
|
for state_path in [Path(DATA_DIR) / f"email_urgency_state_{owner_slug}.json"]:
|
||||||
try:
|
try:
|
||||||
if state_path.exists():
|
if state_path.exists():
|
||||||
state_path.unlink()
|
state_path.unlink()
|
||||||
@@ -680,15 +702,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
if req.trigger_count is not None:
|
if req.trigger_count is not None:
|
||||||
task.trigger_count = req.trigger_count
|
task.trigger_count = req.trigger_count
|
||||||
if req.then_task_id is not None:
|
if req.then_task_id is not None:
|
||||||
if req.then_task_id:
|
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
|
||||||
chain_target = db.query(ScheduledTask).filter(
|
|
||||||
ScheduledTask.id == req.then_task_id
|
|
||||||
).first()
|
|
||||||
if not chain_target:
|
|
||||||
raise HTTPException(400, "Chained task not found")
|
|
||||||
if chain_target.owner != user:
|
|
||||||
raise HTTPException(403, "Cannot chain to another user's task")
|
|
||||||
task.then_task_id = req.then_task_id or None
|
|
||||||
if req.notifications_enabled is not None:
|
if req.notifications_enabled is not None:
|
||||||
task.notifications_enabled = bool(req.notifications_enabled)
|
task.notifications_enabled = bool(req.notifications_enabled)
|
||||||
if req.cron_expression is not None:
|
if req.cron_expression is not None:
|
||||||
@@ -969,7 +983,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
from src.agent_tools import get_mcp_manager
|
from src.tool_utils import get_mcp_manager
|
||||||
mcp = get_mcp_manager()
|
mcp = get_mcp_manager()
|
||||||
if mcp:
|
if mcp:
|
||||||
for tool in mcp.get_all_tools():
|
for tool in mcp.get_all_tools():
|
||||||
@@ -1064,6 +1078,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
desc = (body.get("description") or "").strip()
|
desc = (body.get("description") or "").strip()
|
||||||
if not desc:
|
if not desc:
|
||||||
return {"success": False, "message": "Nothing to parse"}
|
return {"success": False, "message": "Nothing to parse"}
|
||||||
|
user = _owner(request)
|
||||||
|
|
||||||
now = _dt.now()
|
now = _dt.now()
|
||||||
# Give the model the current date/time + weekday so relative phrasing
|
# Give the model the current date/time + weekday so relative phrasing
|
||||||
@@ -1090,9 +1105,9 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
|||||||
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=user or None)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||||
if not (url and model):
|
if not (url and model):
|
||||||
return {"success": False, "message": "No model endpoint configured"}
|
return {"success": False, "message": "No model endpoint configured"}
|
||||||
raw = await llm_call_async(
|
raw = await llm_call_async(
|
||||||
|
|||||||
+51
-34
@@ -13,10 +13,44 @@ from src.upload_handler import count_recent_uploads
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
||||||
|
UPLOAD_RESPONSE_HEADERS = {"X-Content-Type-Options": "nosniff"}
|
||||||
|
|
||||||
def setup_upload_routes(upload_handler):
|
def setup_upload_routes(upload_handler):
|
||||||
"""Setup upload routes with the provided handler"""
|
"""Setup upload routes with the provided handler"""
|
||||||
|
|
||||||
|
def _upload_root() -> str:
|
||||||
|
from src.constants import UPLOAD_DIR
|
||||||
|
return os.path.realpath(getattr(upload_handler, "upload_dir", UPLOAD_DIR))
|
||||||
|
|
||||||
|
def _path_inside_upload_dir(path: str) -> bool:
|
||||||
|
try:
|
||||||
|
return os.path.commonpath([_upload_root(), os.path.realpath(path)]) == _upload_root()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _resolve_upload_path(file_id: str) -> str:
|
||||||
|
from src.constants import UPLOAD_DIR
|
||||||
|
upload_root = getattr(upload_handler, "upload_dir", UPLOAD_DIR)
|
||||||
|
direct = os.path.join(upload_root, file_id)
|
||||||
|
if os.path.lexists(direct):
|
||||||
|
if not _path_inside_upload_dir(direct):
|
||||||
|
raise HTTPException(403, "Access denied")
|
||||||
|
if os.path.isfile(direct):
|
||||||
|
return direct
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
|
for root, _dirs, files in os.walk(upload_root, followlinks=False):
|
||||||
|
if file_id not in files:
|
||||||
|
continue
|
||||||
|
path = os.path.join(root, file_id)
|
||||||
|
if not _path_inside_upload_dir(path):
|
||||||
|
raise HTTPException(403, "Access denied")
|
||||||
|
if os.path.isfile(path):
|
||||||
|
return path
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
|
raise HTTPException(404, "File not found")
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
||||||
"""Upload files with enhanced security and organization."""
|
"""Upload files with enhanced security and organization."""
|
||||||
@@ -91,23 +125,11 @@ def setup_upload_routes(upload_handler):
|
|||||||
client isn't downloading the full-resolution photo just to show it tiny."""
|
client isn't downloading the full-resolution photo just to show it tiny."""
|
||||||
if not upload_handler.validate_upload_id(file_id):
|
if not upload_handler.validate_upload_id(file_id):
|
||||||
raise HTTPException(400, "Invalid file ID")
|
raise HTTPException(400, "Invalid file ID")
|
||||||
# Search upload directories for the file
|
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
import mimetypes as _mt
|
import mimetypes as _mt
|
||||||
path = os.path.join(UPLOAD_DIR, file_id)
|
|
||||||
if not os.path.exists(path):
|
|
||||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
|
||||||
if file_id in files:
|
|
||||||
path = os.path.join(root, file_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise HTTPException(404, "File not found")
|
|
||||||
if not upload_handler.inside_base_dir(path):
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
# Look up original filename and owner from uploads.json
|
# Look up original filename and owner from uploads.json
|
||||||
original_name = file_id
|
original_name = file_id
|
||||||
info = None
|
info = None
|
||||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||||
if os.path.exists(uploads_db):
|
if os.path.exists(uploads_db):
|
||||||
with open(uploads_db, encoding="utf-8") as f:
|
with open(uploads_db, encoding="utf-8") as f:
|
||||||
db = json.load(f)
|
db = json.load(f)
|
||||||
@@ -123,13 +145,14 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
mime = _mt.guess_type(path)[0] or "application/octet-stream"
|
path = _resolve_upload_path(file_id)
|
||||||
|
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or "application/octet-stream"
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
# Downscaled thumbnail for image previews — generated once and cached.
|
# Downscaled thumbnail for image previews — generated once and cached.
|
||||||
if thumb and mime.startswith("image/"):
|
if thumb and mime.startswith("image/"):
|
||||||
try:
|
try:
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
|
thumb_dir = os.path.join(_upload_root(), ".thumbs")
|
||||||
os.makedirs(thumb_dir, exist_ok=True)
|
os.makedirs(thumb_dir, exist_ok=True)
|
||||||
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
||||||
if (not os.path.exists(thumb_path)
|
if (not os.path.exists(thumb_path)
|
||||||
@@ -145,17 +168,21 @@ def setup_upload_routes(upload_handler):
|
|||||||
if im.mode not in ("RGB", "L"):
|
if im.mode not in ("RGB", "L"):
|
||||||
im = im.convert("RGB")
|
im = im.convert("RGB")
|
||||||
im.save(thumb_path, "JPEG", quality=80)
|
im.save(thumb_path, "JPEG", quality=80)
|
||||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
return FileResponse(thumb_path, media_type="image/jpeg", headers=UPLOAD_RESPONSE_HEADERS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
||||||
# Fall through to the full image.
|
# Fall through to the full image.
|
||||||
return FileResponse(path, media_type=mime, filename=original_name)
|
return FileResponse(
|
||||||
|
path,
|
||||||
|
media_type=mime,
|
||||||
|
filename=original_name,
|
||||||
|
headers=UPLOAD_RESPONSE_HEADERS,
|
||||||
|
)
|
||||||
|
|
||||||
def _load_upload_info(file_id: str):
|
def _load_upload_info(file_id: str):
|
||||||
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
info = None
|
info = None
|
||||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||||
if os.path.exists(uploads_db):
|
if os.path.exists(uploads_db):
|
||||||
with open(uploads_db, encoding="utf-8") as f:
|
with open(uploads_db, encoding="utf-8") as f:
|
||||||
db = json.load(f)
|
db = json.load(f)
|
||||||
@@ -163,8 +190,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
def _vision_cache_path(file_id: str) -> str:
|
def _vision_cache_path(file_id: str) -> str:
|
||||||
from src.constants import UPLOAD_DIR
|
cache_dir = os.path.join(_upload_root(), ".vision")
|
||||||
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
return os.path.join(cache_dir, file_id + ".txt")
|
return os.path.join(cache_dir, file_id + ".txt")
|
||||||
|
|
||||||
@@ -175,17 +201,6 @@ def setup_upload_routes(upload_handler):
|
|||||||
subsequent loads are instant. Pass force=1 to recompute."""
|
subsequent loads are instant. Pass force=1 to recompute."""
|
||||||
if not upload_handler.validate_upload_id(file_id):
|
if not upload_handler.validate_upload_id(file_id):
|
||||||
raise HTTPException(400, "Invalid file ID")
|
raise HTTPException(400, "Invalid file ID")
|
||||||
from src.constants import UPLOAD_DIR
|
|
||||||
path = os.path.join(UPLOAD_DIR, file_id)
|
|
||||||
if not os.path.exists(path):
|
|
||||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
|
||||||
if file_id in files:
|
|
||||||
path = os.path.join(root, file_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise HTTPException(404, "File not found")
|
|
||||||
if not upload_handler.inside_base_dir(path):
|
|
||||||
raise HTTPException(403, "Access denied")
|
|
||||||
info = _load_upload_info(file_id)
|
info = _load_upload_info(file_id)
|
||||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||||
@@ -196,8 +211,9 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
|
path = _resolve_upload_path(file_id)
|
||||||
import mimetypes as _mt
|
import mimetypes as _mt
|
||||||
mime = _mt.guess_type(path)[0] or ""
|
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or ""
|
||||||
if not mime.startswith("image/"):
|
if not mime.startswith("image/"):
|
||||||
raise HTTPException(400, "Not an image")
|
raise HTTPException(400, "Not an image")
|
||||||
cache_path = _vision_cache_path(file_id)
|
cache_path = _vision_cache_path(file_id)
|
||||||
@@ -209,7 +225,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
||||||
from src.document_processor import analyze_image_with_vl
|
from src.document_processor import analyze_image_with_vl
|
||||||
try:
|
try:
|
||||||
text = analyze_image_with_vl(path) or ""
|
text = analyze_image_with_vl(path, owner=current_user) or ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
||||||
raise HTTPException(500, f"Vision analysis failed: {e}")
|
raise HTTPException(500, f"Vision analysis failed: {e}")
|
||||||
@@ -238,6 +254,7 @@ def setup_upload_routes(upload_handler):
|
|||||||
raise HTTPException(403, "Access denied")
|
raise HTTPException(403, "Access denied")
|
||||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||||
raise HTTPException(404, "File not found")
|
raise HTTPException(404, "File not found")
|
||||||
|
_resolve_upload_path(file_id)
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
text = (body or {}).get("text", "")
|
text = (body or {}).get("text", "")
|
||||||
if not isinstance(text, str):
|
if not isinstance(text, str):
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
||||||
|
from src.constants import VAULT_FILE as _VAULT_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VAULT_FILE = Path("data/vault.json")
|
VAULT_FILE = Path(_VAULT_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _find_bw() -> str:
|
def _find_bw() -> str:
|
||||||
|
|||||||
+23
-10
@@ -194,6 +194,8 @@ def setup_webhook_routes(
|
|||||||
"together": "https://api.together.xyz/v1",
|
"together": "https://api.together.xyz/v1",
|
||||||
"openrouter": "https://openrouter.ai/api/v1",
|
"openrouter": "https://openrouter.ai/api/v1",
|
||||||
"ollama": "https://ollama.com/api",
|
"ollama": "https://ollama.com/api",
|
||||||
|
"opencode-zen": "https://opencode.ai/zen/v1",
|
||||||
|
"opencode-go": "https://opencode.ai/zen/go/v1",
|
||||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||||
"venice": "https://api.venice.ai/api/v1",
|
"venice": "https://api.venice.ai/api/v1",
|
||||||
}
|
}
|
||||||
@@ -323,22 +325,33 @@ def setup_webhook_routes(
|
|||||||
endpoint_url = build_chat_url(base_url)
|
endpoint_url = build_chat_url(base_url)
|
||||||
model = body.model or "auto"
|
model = body.model or "auto"
|
||||||
api_key = ep.api_key
|
api_key = ep.api_key
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||||
|
endpoint_url = build_chat_url(base_url)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||||
|
|
||||||
if model == "auto":
|
if model == "auto":
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=5) as client:
|
async with httpx.AsyncClient(timeout=5) as client:
|
||||||
models_url = build_models_url(base_url)
|
models_url = build_models_url(base_url)
|
||||||
hdrs = build_headers(api_key, base_url)
|
hdrs = build_headers(api_key, base_url)
|
||||||
resp = await client.get(models_url, headers=hdrs)
|
if models_url:
|
||||||
resp.raise_for_status()
|
resp = await client.get(models_url, headers=hdrs)
|
||||||
data = resp.json()
|
resp.raise_for_status()
|
||||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
data = resp.json()
|
||||||
if not ids:
|
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
ids = [
|
if not ids:
|
||||||
m.get("name") or m.get("model")
|
ids = [
|
||||||
for m in (data.get("models") or [])
|
m.get("name") or m.get("model")
|
||||||
if m.get("name") or m.get("model")
|
for m in (data.get("models") or [])
|
||||||
]
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
import json as _json
|
||||||
|
ids = _json.loads(ep.cached_models or "[]")
|
||||||
model = ids[0] if ids else "auto"
|
model = ids[0] if ids else "auto"
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(500, "Could not discover models from endpoint")
|
raise HTTPException(500, "Could not discover models from endpoint")
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import json
|
|||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from src.constants import MEMORY_FILE, SKILLS_FILE
|
||||||
|
|
||||||
|
|
||||||
def claim_json_entries(entries, owner):
|
def claim_json_entries(entries, owner):
|
||||||
count = 0
|
count = 0
|
||||||
@@ -35,8 +37,8 @@ def main():
|
|||||||
|
|
||||||
# 1. Memories (JSON files)
|
# 1. Memories (JSON files)
|
||||||
for label, path in [
|
for label, path in [
|
||||||
("memory.json", "data/memory.json"),
|
("memory.json", MEMORY_FILE),
|
||||||
("skills.json", "data/skills.json"),
|
("skills.json", SKILLS_FILE),
|
||||||
]:
|
]:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
print(f" {label}: not found, skipping")
|
print(f" {label}: not found, skipping")
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ import torch
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -52,7 +53,63 @@ async def lifespan(application):
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
||||||
|
# Conservative defaults — server is designed for server-to-server use from
|
||||||
|
# the Odysseus backend. Wildcard CORS + the 127.0.0.1 default bind used to
|
||||||
|
# leave the server reachable via DNS-rebinding from any browser tab on the
|
||||||
|
# same host. The CLI flags below extend these allowlists for operators who
|
||||||
|
# need browser access; the safe defaults handle the common case.
|
||||||
|
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
|
||||||
|
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
|
||||||
|
"""Allowed Host header values: the bind address + loopback variants +
|
||||||
|
any operator-supplied --allowed-host values. Duplicates and empty
|
||||||
|
strings are dropped; order is stable for predictable middleware setup."""
|
||||||
|
seen = []
|
||||||
|
for h in (bind_host, *_DEFAULT_ALLOWED_HOSTS, *(extras or [])):
|
||||||
|
h = (h or "").strip()
|
||||||
|
if h and h not in seen:
|
||||||
|
seen.append(h)
|
||||||
|
return seen
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_cors_origins(extras=None) -> list:
    """Build the CORS origin allowlist.

    Starts from ``_DEFAULT_CORS_ORIGINS`` and appends the operator-supplied
    ``extras`` (the --allowed-origin values). Each entry is stripped of
    surrounding whitespace; ``None``/blank entries and repeats are dropped
    while the first-seen order is kept.

    Returns a new list, which is empty when nothing was supplied.
    """
    candidates = [*_DEFAULT_CORS_ORIGINS, *(extras or [])]
    stripped = ((entry or "").strip() for entry in candidates)
    # dict keys keep insertion order, which gives a stable first-seen de-dup.
    return list(dict.fromkeys(origin for origin in stripped if origin))
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
    """Reset ``application``'s user middleware to the security middleware.

    Clears ``application.user_middleware`` and then registers
    ``TrustedHostMiddleware`` with ``allowed_hosts``. ``CORSMiddleware`` is
    registered afterwards only when ``allowed_origins`` is non-empty, limited
    to GET/POST/OPTIONS and the Authorization/Content-Type headers.
    Registration order is TrustedHost first, then CORS (added last, so it
    ends up outermost).

    Raises:
        RuntimeError: if ``application.middleware_stack`` has already been
            built. The check runs before anything is modified.
    """
    stack_already_built = application.middleware_stack is not None
    if stack_already_built:
        raise RuntimeError("security middleware must be configured before the app starts serving")

    # Discard whatever was registered earlier so repeated calls don't stack up.
    application.user_middleware.clear()

    host_allowlist = list(allowed_hosts)
    application.add_middleware(TrustedHostMiddleware, allowed_hosts=host_allowlist)

    if not allowed_origins:
        # Default-deny: with no origins configured, no CORS middleware is added.
        return

    cors_options = {
        "allow_origins": list(allowed_origins),
        "allow_methods": ["GET", "POST", "OPTIONS"],
        "allow_headers": ["Authorization", "Content-Type"],
    }
    application.add_middleware(CORSMiddleware, **cors_options)
|
||||||
|
|
||||||
|
|
||||||
|
# Install defaults at module load so importing the app for tests / direct
|
||||||
|
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||||
|
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
|
||||||
|
|
||||||
|
|
||||||
class ImageRequest(BaseModel):
|
class ImageRequest(BaseModel):
|
||||||
@@ -1089,7 +1146,25 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
||||||
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
||||||
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
||||||
|
parser.add_argument("--allowed-host", action="append", default=[],
|
||||||
|
help="Additional Host header value to accept (DNS-rebinding allowlist). "
|
||||||
|
"Can be repeated. Loopback values are always included.")
|
||||||
|
parser.add_argument("--allowed-origin", action="append", default=[],
|
||||||
|
help="Additional CORS origin to allow. Can be repeated. Defaults to "
|
||||||
|
"no cross-origin access — only pass this if you need a browser "
|
||||||
|
"on a specific origin to call the server.")
|
||||||
_args = parser.parse_args()
|
_args = parser.parse_args()
|
||||||
|
|
||||||
|
# Replace the module-load middleware stack with the CLI-configured one so
|
||||||
|
# operator-supplied --allowed-host / --allowed-origin values take effect
|
||||||
|
# before the first request is served. user_middleware is consulted lazily
|
||||||
|
# when the middleware stack is built on the first request, so mutating it
|
||||||
|
# here is safe.
|
||||||
|
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
|
||||||
|
final_origins = _compute_cors_origins(_args.allowed_origin)
|
||||||
|
_configure_security_middleware(app, final_hosts, final_origins)
|
||||||
|
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
|
||||||
|
final_hosts, final_origins or "(none — default-deny)")
|
||||||
|
|
||||||
app.state.model_path = _args.model
|
app.state.model_path = _args.model
|
||||||
uvicorn.run(app, host=_args.host, port=_args.port)
|
uvicorn.run(app, host=_args.host, port=_args.port)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from src.constants import PERSONAL_DIR
|
||||||
|
|
||||||
# Configure logging for the script
|
# Configure logging for the script
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@@ -45,7 +48,7 @@ def main():
|
|||||||
rag_manager = RAGManager()
|
rag_manager = RAGManager()
|
||||||
|
|
||||||
# Directory to scan
|
# Directory to scan
|
||||||
docs_directory = "data/personal_docs"
|
docs_directory = PERSONAL_DIR
|
||||||
directory_path = Path(docs_directory)
|
directory_path = Path(docs_directory)
|
||||||
|
|
||||||
# Check if directory exists
|
# Check if directory exists
|
||||||
|
|||||||
@@ -63,10 +63,10 @@ def migrate_memories():
|
|||||||
"""Migrate memory vectors from FAISS to ChromaDB."""
|
"""Migrate memory vectors from FAISS to ChromaDB."""
|
||||||
from src.chroma_client import get_chroma_client
|
from src.chroma_client import get_chroma_client
|
||||||
from src.embeddings import get_embedding_client
|
from src.embeddings import get_embedding_client
|
||||||
from src.constants import DATA_DIR
|
from src.constants import MEMORY_VECTORS_DIR, MEMORY_FILE
|
||||||
|
|
||||||
ids_path = os.path.join(DATA_DIR, "memory_vectors", "ids.json")
|
ids_path = os.path.join(MEMORY_VECTORS_DIR, "ids.json")
|
||||||
memory_path = os.path.join(DATA_DIR, "memory.json")
|
memory_path = MEMORY_FILE
|
||||||
|
|
||||||
if not os.path.exists(ids_path):
|
if not os.path.exists(ids_path):
|
||||||
logger.info("No memory FAISS index found, skipping memory migration")
|
logger.info("No memory FAISS index found, skipping memory migration")
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json"
|
|||||||
import tempfile
|
import tempfile
|
||||||
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
||||||
|
|
||||||
|
from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fail(msg: str, code: int = 1) -> None:
|
def fail(msg: str, code: int = 1) -> None:
|
||||||
sys.stderr.write(f"error: {msg}\n")
|
sys.stderr.write(f"error: {msg}\n")
|
||||||
@@ -160,7 +163,26 @@ def cmd_gpus(args) -> None:
|
|||||||
prefix = _ssh_prefix(args.host, args.ssh_port)
|
prefix = _ssh_prefix(args.host, args.ssh_port)
|
||||||
cmd = prefix + (query.split() if not prefix else [query])
|
cmd = prefix + (query.split() if not prefix else [query])
|
||||||
try:
|
try:
|
||||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
if prefix:
|
||||||
|
candidates = [query]
|
||||||
|
args_part = query[len("nvidia-smi "):]
|
||||||
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ repr(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args_part}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args_part}")
|
||||||
|
|
||||||
|
out = None
|
||||||
|
for candidate in candidates:
|
||||||
|
out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15)
|
||||||
|
if out.returncode == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# No nvidia-smi locally → try the Metal fallback before giving up.
|
# No nvidia-smi locally → try the Metal fallback before giving up.
|
||||||
if not prefix:
|
if not prefix:
|
||||||
|
|||||||
@@ -25,6 +25,24 @@ from pathlib import Path
|
|||||||
|
|
||||||
_DATA_DIR = _REPO_ROOT / "data" / "deep_research"
|
_DATA_DIR = _REPO_ROOT / "data" / "deep_research"
|
||||||
|
|
||||||
|
# The CLI's --status takes the user-facing label "complete", but the writer
|
||||||
|
# in services/research/research_handler.py stores `status="done"` when a run
|
||||||
|
# finishes (and the legacy src/research_handler.py does the same). Without
|
||||||
|
# this alias, --status complete filters every finished record out and the
|
||||||
|
# user sees an empty list. Map at filter time so the on-disk corpus is the
|
||||||
|
# source of truth and the CLI surface stays the friendlier word. The other
|
||||||
|
# choices ("running", "cancelled", "error") are stored verbatim, so they
|
||||||
|
# fall through unchanged.
|
||||||
|
_STATUS_CLI_TO_STORED = {"complete": "done"}
|
||||||
|
|
||||||
|
|
||||||
|
def _status_matches(stored, requested: str) -> bool:
|
||||||
|
stored = (stored or "")
|
||||||
|
if not isinstance(stored, str):
|
||||||
|
stored = ""
|
||||||
|
target = _STATUS_CLI_TO_STORED.get(requested, requested)
|
||||||
|
return stored == target
|
||||||
|
|
||||||
|
|
||||||
def _load_path(path: Path) -> dict | None:
|
def _load_path(path: Path) -> dict | None:
|
||||||
try:
|
try:
|
||||||
@@ -72,7 +90,7 @@ def cmd_list(args):
|
|||||||
data = _load_path(path)
|
data = _load_path(path)
|
||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
if args.status and (data.get("status") or "") != args.status:
|
if args.status and not _status_matches(data.get("status"), args.status):
|
||||||
continue
|
continue
|
||||||
out.append(_summarize(rp_id, data))
|
out.append(_summarize(rp_id, data))
|
||||||
out.sort(key=lambda r: r.get("started_at") or "", reverse=True)
|
out.sort(key=lambda r: r.get("started_at") or "", reverse=True)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
from src.rag_manager import RAGManager
|
from src.rag_manager import RAGManager
|
||||||
|
from src.constants import CHROMA_DIR
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -34,7 +35,7 @@ class DocsService:
|
|||||||
results = await service.query("what is async await?")
|
results = await service.query("what is async await?")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, persist_dir: str = "data/chroma"):
|
def __init__(self, persist_dir: str = CHROMA_DIR):
|
||||||
self.rag = RAGManager(persist_directory=persist_dir)
|
self.rag = RAGManager(persist_directory=persist_dir)
|
||||||
|
|
||||||
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
||||||
|
|||||||
+93
-48
@@ -4,6 +4,13 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from core.platform_compat import (
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
run_ssh_command,
|
||||||
|
)
|
||||||
|
|
||||||
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
||||||
# from 30 min so changing filters doesn't keep re-probing the rig every
|
# from 30 min so changing filters doesn't keep re-probing the rig every
|
||||||
@@ -21,16 +28,17 @@ def _run(cmd):
|
|||||||
if _remote_host:
|
if _remote_host:
|
||||||
# Run command on remote host via SSH
|
# Run command on remote host via SSH
|
||||||
if isinstance(cmd, list):
|
if isinstance(cmd, list):
|
||||||
cmd_str = " ".join(cmd)
|
cmd_str = shlex.join(str(c) for c in cmd)
|
||||||
else:
|
else:
|
||||||
cmd_str = cmd
|
cmd_str = cmd
|
||||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
r = run_ssh_command(
|
||||||
if _remote_port and _remote_port != "22":
|
_remote_host,
|
||||||
ssh_cmd += ["-p", _remote_port]
|
_remote_port,
|
||||||
ssh_cmd += [_remote_host, cmd_str]
|
cmd_str,
|
||||||
r = subprocess.run(
|
timeout=15,
|
||||||
ssh_cmd,
|
connect_timeout=5,
|
||||||
capture_output=True, text=True, timeout=15,
|
strict_host_key_checking=False,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||||
@@ -76,21 +84,29 @@ def _detect_nvidia():
|
|||||||
global _last_gpu_error
|
global _last_gpu_error
|
||||||
_last_gpu_error = None
|
_last_gpu_error = None
|
||||||
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||||
# Remote fallback: a non-interactive SSH shell often has a minimal PATH
|
# Fallback: a non-interactive shell (or WSL) often has a minimal PATH
|
||||||
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin), so the
|
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin,
|
||||||
# first call silently returns nothing → "No GPU" on hosts that DO have GPUs.
|
# /usr/lib/wsl/lib), so the first call silently returns nothing →
|
||||||
|
# "No GPU" on machines that DO have GPUs.
|
||||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||||
if not out and _remote_host:
|
if not out and _remote_host:
|
||||||
out = _run(
|
out = _run(
|
||||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin\"; "
|
f"bash -lc '{SSH_PATH_OVERRIDE}"
|
||||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||||
)
|
)
|
||||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||||
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
||||||
# above still comes back empty even though the binary is right there.
|
# above still comes back empty even though the binary is right there.
|
||||||
if not out and _remote_host:
|
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
||||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi"):
|
# that may not be in the server process's PATH.
|
||||||
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
if not out:
|
||||||
|
for _p in NVIDIA_PATH_CANDIDATES:
|
||||||
|
# Use list form so subprocess.run (local) resolves the absolute path
|
||||||
|
# correctly instead of treating the whole string as an executable name.
|
||||||
|
if _remote_host:
|
||||||
|
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
||||||
|
else:
|
||||||
|
out = _run([_p, "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||||
if out:
|
if out:
|
||||||
break
|
break
|
||||||
if not out:
|
if not out:
|
||||||
@@ -468,39 +484,55 @@ def _detect_windows():
|
|||||||
"""
|
"""
|
||||||
# Single PowerShell command that gathers all hardware info at once
|
# Single PowerShell command that gathers all hardware info at once
|
||||||
ps_cmd = (
|
ps_cmd = (
|
||||||
"$r = @{}; "
|
"""
|
||||||
"$os = Get-CimInstance Win32_OperatingSystem; "
|
$r = @{}
|
||||||
"$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1); "
|
$os = Get-CimInstance Win32_OperatingSystem
|
||||||
"$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1); "
|
$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1)
|
||||||
"$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1; "
|
$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1)
|
||||||
"$r.cpu_name = $cpu.Name; "
|
$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1
|
||||||
"$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum; "
|
$r.cpu_name = $cpu.Name
|
||||||
"$r.arch = $cpu.AddressWidth; "
|
$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum
|
||||||
|
$r.arch = $cpu.AddressWidth
|
||||||
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
||||||
"try { "
|
try {
|
||||||
" $nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null; "
|
$nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null
|
||||||
" if ($LASTEXITCODE -eq 0 -and $nv) { "
|
if ($LASTEXITCODE -eq 0 -and $nv) {
|
||||||
" $gpus = @(); "
|
$gpus = @()
|
||||||
" foreach ($line in $nv -split \"`n\") { "
|
foreach ($line in $nv -split "`n") {
|
||||||
" $p = $line -split ','; "
|
$p = $line -split ','
|
||||||
" if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name = $p[1].Trim(); vram_mb = [double]$p[0].Trim() } }
|
||||||
" }; "
|
}
|
||||||
" $r.gpu_name = $gpus[0].name; "
|
$r.gpu_name = $gpus[0].name
|
||||||
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
|
$r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1)
|
||||||
" $r.gpu_count = $gpus.Count; "
|
$r.gpu_count = $gpus.Count
|
||||||
" $r.gpu_backend = 'cuda'; "
|
$r.gpu_backend = 'cuda'
|
||||||
" } "
|
}
|
||||||
"} catch {}; "
|
}
|
||||||
"if (-not $r.gpu_name) { "
|
catch {}
|
||||||
" $wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1; "
|
if (-not $r.gpu_name) {
|
||||||
" if ($wmiGpu) { "
|
$wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1
|
||||||
" $r.gpu_name = $wmiGpu.Name; "
|
$GPUDriverKey = "HKLM:\\SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\0*"
|
||||||
" $r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1); "
|
$GPUDeviceID = $wmiGpu.PNPDeviceID.Split('&')[0..1] -join '&'
|
||||||
" $r.gpu_count = 1; "
|
$VRAMfromRegistry = Get-ItemProperty -Path $GPUDriverKey |
|
||||||
" $r.gpu_backend = 'cpu_x86'; " # WMI doesn't tell us CUDA/ROCm
|
Where-Object { $_.MatchingDeviceId -like "${GPUDeviceID}*" } |
|
||||||
" } "
|
# Sometimes there happen to be multiple driver classes for the same gpu.
|
||||||
"}; "
|
Select-Object -ExpandProperty HardwareInformation.qwMemorySize -ErrorAction SilentlyContinue -First 1
|
||||||
"$r | ConvertTo-Json -Compress"
|
if ($wmiGpu) {
|
||||||
|
$r.gpu_name = $wmiGpu.Name
|
||||||
|
# Edge case: driver is broken, otherwise $wmiGpu.AdapterRAM is redundant
|
||||||
|
if ($VRAMfromRegistry -ge $wmiGpu.AdapterRAM) {
|
||||||
|
$r.gpu_vram_gb = [math]::Round($VRAMfromRegistry / 1073741824, 1)
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
$r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1)
|
||||||
|
}
|
||||||
|
$r.gpu_count = 1
|
||||||
|
# WMI doesn't tell us CUDA/ROCm
|
||||||
|
$r.gpu_backend = 'cpu_x86';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$r | ConvertTo-Json -Compress
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
if _remote_host:
|
if _remote_host:
|
||||||
# Remote: ship a single command string over SSH. The remote shell parses
|
# Remote: ship a single command string over SSH. The remote shell parses
|
||||||
@@ -566,6 +598,19 @@ def _detect_windows():
|
|||||||
_cache_by_host = {} # host -> (timestamp, result)
|
_cache_by_host = {} # host -> (timestamp, result)
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(host: str, ssh_port: str, platform_name: str):
|
||||||
|
"""Build a stable cache key that isolates remote SSH context.
|
||||||
|
|
||||||
|
Same host aliases can have different hardware due to visibility, forwarding etc.
|
||||||
|
To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
host or "_local",
|
||||||
|
str(ssh_port or ""),
|
||||||
|
str(platform_name or "").lower(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||||
@@ -575,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
|
|||||||
"""
|
"""
|
||||||
global _remote_host, _remote_port, _remote_platform
|
global _remote_host, _remote_port, _remote_platform
|
||||||
|
|
||||||
cache_key = host or "_local"
|
cache_key = _cache_key(host, ssh_port, platform)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if not fresh and cache_key in _cache_by_host:
|
if not fresh and cache_key in _cache_by_host:
|
||||||
ts, cached = _cache_by_host[cache_key]
|
ts, cached = _cache_by_host[cache_key]
|
||||||
|
|||||||
@@ -192,11 +192,19 @@ def _fallback_memory_candidates(messages) -> list[dict]:
|
|||||||
if place:
|
if place:
|
||||||
add(f"User lives in {place}.", "identity")
|
add(f"User lives in {place}.", "identity")
|
||||||
|
|
||||||
m = re.search(r"\bi (?:prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
m = re.search(r"\bi (prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
||||||
if m:
|
if m:
|
||||||
preference = _clean_memory_value(m.group(1), 100)
|
preference = _clean_memory_value(m.group(2), 100)
|
||||||
if preference:
|
if preference:
|
||||||
add(f"User prefers {preference}.", "preference")
|
# The same pattern catches likes and dislikes; keep the stored
|
||||||
|
# sentiment faithful instead of recording every match as a
|
||||||
|
# preference ("I hate cilantro" must not become "User prefers
|
||||||
|
# cilantro").
|
||||||
|
verb = m.group(1).lower()
|
||||||
|
if verb in ("hate", "do not like", "don't like"):
|
||||||
|
add(f"User dislikes {preference}.", "preference")
|
||||||
|
else:
|
||||||
|
add(f"User prefers {preference}.", "preference")
|
||||||
|
|
||||||
m = re.search(
|
m = re.search(
|
||||||
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
||||||
@@ -228,6 +236,43 @@ def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) ->
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_extraction_json(raw: str) -> list:
|
||||||
|
"""Parse the extraction LLM's reply into a list of facts, tolerating
|
||||||
|
reasoning-model noise.
|
||||||
|
|
||||||
|
The model emits <think>…</think> (and sometimes a prose preamble or a
|
||||||
|
```json fence) AROUND the JSON array; without stripping it, json.loads
|
||||||
|
bombs and the run silently yields "0 candidates". Pure str -> list (no
|
||||||
|
LLM/network); returns [] on any parse failure instead of raising.
|
||||||
|
"""
|
||||||
|
text = (raw or "").strip()
|
||||||
|
try:
|
||||||
|
from src.text_helpers import strip_think as _strip_think
|
||||||
|
text = _strip_think(text, prose=True, prompt_echo=True).strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if text.startswith("```"):
|
||||||
|
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||||
|
# JSON may still be embedded in surrounding commentary (leading prose or
|
||||||
|
# trailing remarks like "[...] Done!") — slice from the first '[' to the
|
||||||
|
# last ']' whenever both exist. Slice unconditionally: a reply that starts
|
||||||
|
# with '[' can still carry trailing commentary that breaks json.loads.
|
||||||
|
_start = text.find("[")
|
||||||
|
_end = text.rfind("]")
|
||||||
|
if 0 <= _start < _end:
|
||||||
|
text = text[_start : _end + 1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
facts = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||||
|
return []
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||||
|
return []
|
||||||
|
return facts if isinstance(facts, list) else []
|
||||||
|
|
||||||
|
|
||||||
async def extract_and_store(
|
async def extract_and_store(
|
||||||
session,
|
session,
|
||||||
memory_manager,
|
memory_manager,
|
||||||
@@ -276,9 +321,34 @@ async def extract_and_store(
|
|||||||
|
|
||||||
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
||||||
|
|
||||||
|
# Flatten the window into a SINGLE user message instead of appending the
|
||||||
|
# raw alternating role messages. Passed as raw chat messages, the model
|
||||||
|
# treats the window as a conversation to CONTINUE rather than a transcript
|
||||||
|
# to ANALYZE, so it reliably extracts nothing — typically returning `[]`
|
||||||
|
# (and, depending on the input, sometimes an empty or <think>-only
|
||||||
|
# completion when the window ends on an assistant turn). This was the real
|
||||||
|
# cause of auto-memory logging "0 candidates" on every run. Reframing it as
|
||||||
|
# one "analyze this transcript, return the JSON array" user message makes
|
||||||
|
# the model actually extract. Controlled repro on this model: 0/6 trials
|
||||||
|
# with the old structure vs 6/6 with this one. The skill extractor flattens
|
||||||
|
# for the same reason.
|
||||||
|
def _flatten_msg(m):
|
||||||
|
c = m.get("content", "")
|
||||||
|
if isinstance(c, list):
|
||||||
|
c = " ".join(
|
||||||
|
b.get("text", "") for b in c
|
||||||
|
if isinstance(b, dict) and b.get("type") == "text"
|
||||||
|
)
|
||||||
|
return f"{m.get('role', '?')}: {c}"
|
||||||
|
|
||||||
|
transcript = "\n\n".join(_flatten_msg(m) for m in stripped_recent)
|
||||||
extraction_messages = [
|
extraction_messages = [
|
||||||
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
||||||
] + stripped_recent
|
{"role": "user", "content": (
|
||||||
|
"Conversation to analyze:\n\n" + transcript
|
||||||
|
+ "\n\nReturn the JSON array of durable facts now (or [] if none)."
|
||||||
|
)},
|
||||||
|
]
|
||||||
|
|
||||||
facts = []
|
facts = []
|
||||||
try:
|
try:
|
||||||
@@ -287,19 +357,20 @@ async def extract_and_store(
|
|||||||
model,
|
model,
|
||||||
extraction_messages,
|
extraction_messages,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=500,
|
# A reasoning model spends most of its budget on <think> tokens
|
||||||
|
# BEFORE emitting the JSON, so the old 500 truncated the response
|
||||||
|
# before any JSON appeared → every run logged "0 candidates". The
|
||||||
|
# audit path hit the same wall and raised to 16384; extraction's
|
||||||
|
# output (a short facts list) is small, so an ample ceiling is
|
||||||
|
# enough once thinking has room.
|
||||||
|
max_tokens=4096,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse JSON from response (handle markdown fences if model wraps them)
|
# Parse JSON, tolerating reasoning-model noise (<think> blocks, a
|
||||||
text = raw.strip()
|
# ```json fence, and leading/trailing commentary). See
|
||||||
if text.startswith("```"):
|
# _parse_extraction_json — returns [] rather than raising.
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
facts = _parse_extraction_json(raw)
|
||||||
|
|
||||||
try:
|
|
||||||
facts = json.loads(text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.debug("Memory extraction returned non-JSON")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
from .memory import MemoryManager
|
from .memory import MemoryManager
|
||||||
from .memory_vector import MemoryVectorStore
|
from .memory_vector import MemoryVectorStore
|
||||||
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
||||||
|
from src.constants import DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -38,7 +39,7 @@ class MemoryService:
|
|||||||
results = await service.recall("preferences")
|
results = await service.recall("preferences")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_dir: str = "data"):
|
def __init__(self, data_dir: str = DATA_DIR):
|
||||||
self.manager = MemoryManager(data_dir)
|
self.manager = MemoryManager(data_dir)
|
||||||
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
||||||
os.path.join(data_dir, "memory_vectors")
|
os.path.join(data_dir, "memory_vectors")
|
||||||
|
|||||||
@@ -63,6 +63,46 @@ def _has_duplicate_title(skills, title: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_object(text: str) -> Optional[dict]:
|
||||||
|
"""Best-effort extraction of a JSON object from an LLM response.
|
||||||
|
|
||||||
|
The response may be wrapped in code fences or surrounded by prose, and some
|
||||||
|
models emit a stray brace in the prose before the real object
|
||||||
|
(e.g. "uses {placeholder} then {...}"). Slicing first-'{' .. last-'}' then
|
||||||
|
grabs an unparseable span and the skill is silently lost. Try the whole
|
||||||
|
string first, then each '{' start position in turn, returning the first
|
||||||
|
candidate that parses to a JSON object (dict). Returns None if none do.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
s = text.strip()
|
||||||
|
if s.startswith("```"):
|
||||||
|
s = s.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||||
|
end = s.rfind("}")
|
||||||
|
if end == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _as_dict(candidate):
|
||||||
|
try:
|
||||||
|
obj = json.loads(candidate)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return None
|
||||||
|
return obj if isinstance(obj, dict) else None
|
||||||
|
|
||||||
|
# The clean, common case: the whole (de-fenced) string is the object.
|
||||||
|
obj = _as_dict(s)
|
||||||
|
if obj is not None:
|
||||||
|
return obj
|
||||||
|
# Otherwise scan each '{' candidate up to the last '}'.
|
||||||
|
start = s.find("{")
|
||||||
|
while 0 <= start < end:
|
||||||
|
obj = _as_dict(s[start : end + 1])
|
||||||
|
if obj is not None:
|
||||||
|
return obj
|
||||||
|
start = s.find("{", start + 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def maybe_extract_skill(
|
async def maybe_extract_skill(
|
||||||
session,
|
session,
|
||||||
skills_manager,
|
skills_manager,
|
||||||
@@ -169,21 +209,14 @@ async def maybe_extract_skill(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Parse JSON
|
# Parse JSON. The object may be wrapped in code fences or surrounded by
|
||||||
text = response.strip()
|
# commentary (and may contain a stray/invalid brace fragment before
|
||||||
if text.startswith("```"):
|
# the real object — including one that makes the response itself look
|
||||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
# like it starts with '{'), so use a tolerant extractor that tries the
|
||||||
# After strip_think, the JSON may still be embedded inside surrounding
|
# whole string first and then each '{' candidate left-to-right.
|
||||||
# commentary — slice from the first '{' to the matching last '}'.
|
data = _extract_json_object(response)
|
||||||
if text and text[0] != "{":
|
if not data:
|
||||||
_start = text.find("{")
|
logger.debug("[skill-extract] no JSON object found in response, dropping")
|
||||||
_end = text.rfind("}")
|
|
||||||
if 0 <= _start < _end:
|
|
||||||
text = text[_start : _end + 1]
|
|
||||||
|
|
||||||
data = json.loads(text)
|
|
||||||
if not data or not isinstance(data, dict):
|
|
||||||
logger.debug("[skill-extract] parsed JSON not a dict, dropping")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
title = data.get("title", "").strip()
|
title = data.get("title", "").strip()
|
||||||
|
|||||||
@@ -0,0 +1,283 @@
|
|||||||
|
"""Import SKILL.md bundles from public GitHub (or skills.sh → GitHub) URLs."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from src.url_safety import check_outbound_url
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_FILES = 64
|
||||||
|
MAX_TOTAL_BYTES = 2_000_000
|
||||||
|
MAX_FILE_BYTES = 400_000
|
||||||
|
ALLOWED_SUFFIXES = (
|
||||||
|
".md", ".txt", ".json", ".yaml", ".yml", ".py", ".sh", ".toml",
|
||||||
|
".js", ".ts", ".css", ".html", ".xml", ".csv",
|
||||||
|
)
|
||||||
|
TEXT_NAMES = {"skill.md", "license", "license.md", "readme.md"}
|
||||||
|
_GITHUB_HOSTS = frozenset({
|
||||||
|
"github.com", "www.github.com", "api.github.com", "raw.githubusercontent.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _github_host(url: str) -> str:
|
||||||
|
return (urlparse(str(url)).hostname or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_github_url(url: str, *, context: str = "URL") -> None:
    """Raise ``SkillImportError`` unless *url* points at an allowed GitHub host.

    Only the hostname is inspected (via ``_github_host``) and compared against
    ``_GITHUB_HOSTS``.

    Args:
        url: The URL to check.
        context: Label that starts the error message, e.g. "redirect target".

    Raises:
        SkillImportError: If the host is missing or not in ``_GITHUB_HOSTS``.
    """
    host = _github_host(url)
    if host in _GITHUB_HOSTS:
        return
    raise SkillImportError(
        f"{context} must stay on GitHub (got {host or 'unknown host'})"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ResolvedSource:
    """A skill location inside a GitHub repository, as parsed from a URL."""

    owner: str  # repository owner (user or organisation)
    repo: str  # repository name
    ref: str  # ref segment used when building raw/API URLs
    path: str  # directory or file path inside repo (no leading slash)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillImportError(ValueError):
    """Raised when a skill URL cannot be resolved, fetched, or validated.

    Subclasses ``ValueError``, so callers catching ``ValueError`` also catch it.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_relpath(rel: str) -> str:
|
||||||
|
rel = (rel or "").replace("\\", "/").strip().lstrip("/")
|
||||||
|
if not rel or rel.startswith("..") or "/../" in f"/{rel}/":
|
||||||
|
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||||
|
parts = [p for p in rel.split("/") if p and p != "."]
|
||||||
|
if any(p == ".." for p in parts):
|
||||||
|
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||||
|
return "/".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_text_file(name: str) -> bool:
    """Return True if *name* is a file the importer is willing to fetch.

    The comparison is case-insensitive: the lower-cased name must either be
    listed in ``TEXT_NAMES`` or end with one of ``ALLOWED_SUFFIXES``.
    """
    low = name.lower()
    # str.endswith accepts a tuple of suffixes: one call instead of a loop.
    return low in TEXT_NAMES or low.endswith(tuple(ALLOWED_SUFFIXES))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_skill_source(url: str) -> ResolvedSource:
    """Normalize skills.sh / GitHub web URLs into owner/repo/ref/path.

    Accepted forms:
      * ``https://github.com/<owner>/<repo>`` (ref defaults to "main")
      * ``https://github.com/<owner>/<repo>/tree|blob/<ref>/<path...>``
      * ``https://raw.githubusercontent.com/<owner>/<repo>/<ref>/<path...>``
      * a skills.sh URL, which is fetched and unwrapped to a GitHub URL:
        either the redirect target, or the first github.com link in the page.

    Percent-encoded path segments are decoded here because the URL builders
    quote them again; leaving them encoded would double-encode names that
    contain spaces or non-ASCII characters.

    Known limitations (unchanged): the ref is taken as a single path segment,
    so branch names containing '/' are not recognised, and a bare repository
    URL assumes the branch is called "main".

    Args:
        url: The user-supplied URL.

    Returns:
        The resolved source location.

    Raises:
        SkillImportError: If the URL is empty, malformed, blocked by the
            outbound-URL check, not on GitHub, or not in a recognised shape.
    """
    from urllib.parse import unquote

    raw = (url or "").strip()
    if not raw:
        raise SkillImportError("URL is required")

    # skills.sh often links to GitHub; unwrap the redirect target or the link
    # embedded in the page.
    if "skills.sh" in raw and "github.com" not in raw:
        ok, reason = check_outbound_url(raw)
        if not ok:
            raise SkillImportError(reason)
        with httpx.Client(follow_redirects=True, timeout=20.0) as client:
            r = client.get(raw)
        if r.status_code >= 400:
            raise _github_response_error(r)
        final = str(r.url)
        final_host = _github_host(final)
        if final_host in _GITHUB_HOSTS:
            # Redirected straight to GitHub: use the final URL.
            raw = final
        elif final_host == "skills.sh" or final_host.endswith(".skills.sh"):
            # Still on skills.sh: the page may embed a GitHub link. (Asserting
            # a GitHub host here first would make this scan unreachable.) If
            # nothing is found, `raw` stays as-is and is rejected below.
            m = re.search(r"https?://github\.com/[^\s\"')]+", r.text or "")
            if m:
                raw = m.group(0).rstrip(".,)")
        else:
            # Redirected to some third host: refuse to trust its content.
            _assert_github_url(final, context="redirect target")

    try:
        parsed = urlparse(raw)
    except ValueError as e:
        raise SkillImportError("Invalid GitHub URL") from e
    host = _github_host(raw)
    if host not in _GITHUB_HOSTS:
        raise SkillImportError(
            "Only GitHub URLs are supported (https://github.com/... or raw.githubusercontent.com/...)"
        )

    bits = [unquote(p) for p in parsed.path.split("/") if p]

    if host == "raw.githubusercontent.com":
        # /owner/repo/ref/path/to/file
        if len(bits) < 4:
            raise SkillImportError("Invalid raw GitHub URL")
        owner, repo, ref = bits[0], bits[1], bits[2]
        return ResolvedSource(owner=owner, repo=repo, ref=ref, path="/".join(bits[3:]))

    if len(bits) < 2:
        raise SkillImportError("Invalid GitHub URL")
    owner, repo = bits[0], bits[1]
    if len(repo) > 4 and repo.endswith(".git"):
        # Clone-style URL (https://github.com/o/r.git): drop the suffix.
        repo = repo[:-4]

    if len(bits) == 2:
        # Bare repository URL: repository root on the assumed default branch.
        return ResolvedSource(owner=owner, repo=repo, ref="main", path="")
    if len(bits) >= 4 and bits[2] in ("tree", "blob"):
        return ResolvedSource(owner=owner, repo=repo, ref=bits[3], path="/".join(bits[4:]))
    raise SkillImportError("GitHub URL must include /tree/<branch>/... or /blob/<branch>/...")
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_url(src: ResolvedSource, rel_path: str) -> str:
    """Build the raw.githubusercontent.com URL for *rel_path* inside *src*.

    The path is validated with ``_safe_relpath`` first. Owner, repo and ref are
    each percent-encoded as a single segment (so they cannot introduce extra
    path separators); the relative path keeps its '/' separators.

    Raises:
        SkillImportError: If *rel_path* is rejected by ``_safe_relpath``.
    """
    rel = _safe_relpath(rel_path)
    owner = quote(src.owner, safe="")
    repo = quote(src.repo, safe="")
    ref = quote(src.ref, safe="")
    return f"https://raw.githubusercontent.com/{owner}/{repo}/{ref}/{quote(rel, safe='/')}"
|
||||||
|
|
||||||
|
|
||||||
|
def _api_contents_url(src: ResolvedSource, rel_path: str = "") -> str:
    """Build the GitHub "contents" API URL for *rel_path* inside *src*.

    An empty *rel_path* addresses the repository root. Owner and repo are
    percent-encoded as single segments, the path keeps its '/' separators, and
    the ref is passed as the ``ref`` query parameter.

    Raises:
        SkillImportError: If a non-empty *rel_path* is rejected by
            ``_safe_relpath``.
    """
    rel = _safe_relpath(rel_path) if rel_path else ""
    owner = quote(src.owner, safe="")
    repo = quote(src.repo, safe="")
    base = f"https://api.github.com/repos/{owner}/{repo}/contents"
    if rel:
        base += f"/{quote(rel, safe='/')}"
    return f"{base}?ref={quote(src.ref, safe='')}"
|
||||||
|
|
||||||
|
|
||||||
|
def _github_response_error(response: httpx.Response) -> SkillImportError:
    """Turn a failed GitHub HTTP response into a user-visible import error.

    The detail text is the JSON body's ``message`` field when the body is a
    JSON object, or the first 200 characters of the raw body when it is not
    JSON at all.

    Args:
        response: The failed response (its body must already be readable).

    Returns:
        The error to raise; this function builds it but does not raise it.
    """
    status = response.status_code
    detail = ""
    try:
        body = response.json()
    except Exception:
        # Not JSON (e.g. an HTML error page): fall back to a text excerpt.
        detail = (response.text or "").strip()[:200]
    else:
        if isinstance(body, dict):
            detail = str(body.get("message") or "").strip()

    # GitHub reports an exhausted rate limit as either 403 or 429.
    if status in (403, 429) and "rate limit" in detail.lower():
        return SkillImportError(
            "GitHub API rate limit exceeded — try again in a bit"
            + (f" ({detail})" if detail else "")
        )
    if status == 404:
        return SkillImportError("path not found on GitHub")
    if detail:
        return SkillImportError(f"GitHub request failed ({status}): {detail}")
    return SkillImportError(f"GitHub request failed ({status})")
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_bytes(url: str) -> bytes:
    """Download *url* and return its body, enforcing ``MAX_FILE_BYTES``.

    The body is streamed and the download is aborted as soon as it exceeds the
    cap, so an oversized file is never held in memory in full.

    Args:
        url: The URL to fetch; it must pass ``check_outbound_url``.

    Returns:
        The response body.

    Raises:
        SkillImportError: If the URL is blocked, the server answers with a
            status >= 400, a redirect ended on a non-GitHub host, or the body
            is larger than ``MAX_FILE_BYTES``.
    """
    ok, reason = check_outbound_url(url)
    if not ok:
        raise SkillImportError(reason)
    chunks = []
    received = 0
    with httpx.Client(follow_redirects=True, timeout=30.0) as client:
        with client.stream(
            "GET", url, headers={"Accept": "application/vnd.github+json"}
        ) as r:
            if r.status_code >= 400:
                # The error helper inspects the body, which a streamed
                # response only exposes after an explicit read().
                r.read()
                raise _github_response_error(r)
            _assert_github_url(str(r.url), context="redirect target")
            for chunk in r.iter_bytes():
                received += len(chunk)
                if received > MAX_FILE_BYTES:
                    raise SkillImportError(f"file too large: {url}")
                chunks.append(chunk)
    return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_text(url: str) -> str:
    """Fetch *url* via ``_fetch_bytes`` and decode the body as UTF-8.

    Raises:
        SkillImportError: If the body is not valid UTF-8 (reported as a
            non-text file), or for any failure raised by ``_fetch_bytes``.
    """
    data = _fetch_bytes(url)
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError as e:
        raise SkillImportError(f"non-text file: {url}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _list_github_dir(src: ResolvedSource, rel_dir: str, out: Dict[str, str], *, depth: int = 0) -> None:
    """Recursively collect the text files under *rel_dir* into *out*.

    The directory is listed through the GitHub contents API. Sub-directories
    are descended into; files are fetched when ``_is_text_file`` accepts their
    name and the entry carries a ``download_url``. Collection stops quietly
    once ``MAX_FILES`` files or ``MAX_TOTAL_BYTES`` bytes are reached.

    Args:
        src: Repository and ref to list.
        rel_dir: Directory inside the repository ("" for the root).
        out: Mapping of relative path -> file text, filled in place. Paths
            already present are left untouched and not fetched again.
        depth: Current recursion depth; nothing is listed beyond depth 4.

    Raises:
        SkillImportError: If the URL is blocked, GitHub returns an error or an
            unreadable/non-directory listing, an entry has an unsafe path or
            an off-GitHub download URL, a file is too large, or adding a file
            would push the bundle over ``MAX_TOTAL_BYTES``.
    """
    if depth > 4 or len(out) >= MAX_FILES:
        return
    url = _api_contents_url(src, rel_dir)
    ok, reason = check_outbound_url(url)
    if not ok:
        raise SkillImportError(reason)
    with httpx.Client(follow_redirects=True, timeout=30.0) as client:
        r = client.get(url, headers={"Accept": "application/vnd.github+json"})
    if r.status_code >= 400:
        raise _github_response_error(r)
    _assert_github_url(str(r.url), context="redirect target")
    try:
        entries = r.json()
    except ValueError as e:
        # Keep the failure inside the importer's own error type.
        raise SkillImportError("GitHub returned an unreadable directory listing") from e
    if not isinstance(entries, list):
        # The contents API answers with an object (not a list) for a file.
        raise SkillImportError("expected a directory on GitHub")

    def _collected_bytes() -> int:
        """UTF-8 size of everything gathered so far."""
        return sum(len(v.encode("utf-8")) for v in out.values())

    total = _collected_bytes()
    for ent in entries:
        if len(out) >= MAX_FILES or total >= MAX_TOTAL_BYTES:
            break
        if not isinstance(ent, dict):
            continue
        name = ent.get("name") or ""
        if not name:
            # A nameless entry would otherwise resolve to the directory itself.
            continue
        ent_type = ent.get("type")
        rel = _safe_relpath(f"{rel_dir}/{name}" if rel_dir else name)
        if ent_type == "dir":
            _list_github_dir(src, rel, out, depth=depth + 1)
            # The recursive call added files to `out`; refresh the running total.
            total = _collected_bytes()
            continue
        if ent_type != "file" or not _is_text_file(name):
            continue
        if rel in out:
            # Already fetched by the caller (e.g. the SKILL.md that was linked
            # directly): fetching it again would waste a request and count its
            # size twice against MAX_TOTAL_BYTES.
            continue
        dl = ent.get("download_url")
        if not dl:
            continue
        _assert_github_url(dl, context="download URL")
        size = ent.get("size")
        if isinstance(size, int) and size > MAX_FILE_BYTES:
            # Same outcome as fetching it, without downloading anything.
            raise SkillImportError(f"file too large: {dl}")
        text = _fetch_text(dl)
        total += len(text.encode("utf-8"))
        if total > MAX_TOTAL_BYTES:
            raise SkillImportError("skill bundle exceeds size limit")
        out[rel] = text
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_skill_bundle(url: str) -> Tuple[Dict[str, str], ResolvedSource]:
    """Download SKILL.md and sibling text assets.

    Resolution order, by what the URL's path points at:
      1. A file whose path ends in "skill.md": fetch it, then best-effort
         collect the rest of its parent directory.
      2. A directory containing SKILL.md: collect that directory.
      3. A single ``.md`` file: return just that file.
      4. Anything else (or no path): list the directory / repository root.
    If nothing ending in "skill.md" was collected, the repository-root
    SKILL.md is tried as a last resort.

    Args:
        url: A skills.sh or GitHub URL.

    Returns:
        ``(files, src)`` where ``files`` maps relative path -> content and
        ``src`` is the resolved source location.

    Raises:
        SkillImportError: If the URL cannot be resolved or no SKILL.md is found.
    """
    src = parse_skill_source(url)
    files: Dict[str, str] = {}

    path = _safe_relpath(src.path) if src.path else ""
    if path.lower().endswith("skill.md"):
        files[path] = _fetch_text(_raw_url(src, path))
        parent = "/".join(path.split("/")[:-1])
        if parent:
            try:
                _list_github_dir(src, parent, files)
            except SkillImportError as e:
                # Sibling assets are optional; the SKILL.md itself is enough.
                logger.debug("[skill-import] sibling listing failed for %s: %s", parent, e)
        return files, src

    if path:
        try:
            # Probe only: succeeds when `path` is a folder holding SKILL.md.
            _fetch_text(_raw_url(src, f"{path}/SKILL.md"))
            _list_github_dir(src, path, files)
            return files, src
        except Exception as e:
            # Deliberate best effort: fall through to the other interpretations.
            logger.debug("[skill-import] %s not importable as a skill folder: %s", path, e)
        try:
            text = _fetch_text(_raw_url(src, path))
            if path.lower().endswith(".md"):
                files[path] = text
                return files, src
        except Exception as e:
            logger.debug("[skill-import] %s not fetchable as a file: %s", path, e)
        _list_github_dir(src, path, files)
    else:
        _list_github_dir(src, "", files)

    if not any(p.lower().endswith("skill.md") for p in files):
        # Flat repo root with SKILL.md only
        try:
            files["SKILL.md"] = _fetch_text(_raw_url(src, "SKILL.md"))
        except Exception as e:
            raise SkillImportError(
                "No SKILL.md found — link to a skill folder or SKILL.md on GitHub"
            ) from e
    return files, src
|
||||||
|
|
||||||
|
|
||||||
|
def pick_skill_md(files: Dict[str, str]) -> Tuple[str, str]:
    """Choose the bundle's SKILL.md and return ``(relative_path, content)``.

    A file whose base name is exactly "skill.md" (case-insensitive) wins, and
    among several the shallowest path is taken, so a nested sub-skill or a
    file like "old-skill.md" is not mistaken for the bundle's own SKILL.md.
    When no exact match exists, the first path that merely ends in "skill.md"
    is accepted, as before.

    Raises:
        SkillImportError: If no path in *files* ends with "skill.md".
    """
    exact = [
        rel for rel in files
        if rel.replace("\\", "/").rsplit("/", 1)[-1].lower() == "skill.md"
    ]
    if exact:
        # min() keeps the first of equally shallow paths (insertion order).
        best = min(exact, key=lambda rel: rel.replace("\\", "/").count("/"))
        return best, files[best]
    for rel, content in files.items():
        if rel.lower().endswith("skill.md"):
            return rel, content
    raise SkillImportError("bundle has no SKILL.md")
|
||||||
|
|
||||||
|
|
||||||
|
def default_category_from_source(src: ResolvedSource) -> str:
    """Return the category an imported skill is filed under.

    *src* is currently unused: the result is always "imported".
    """
    return "imported"
|
||||||
@@ -381,6 +381,54 @@ class SkillsManager:
|
|||||||
|
|
||||||
return sk.to_dict()
|
return sk.to_dict()
|
||||||
|
|
||||||
|
def import_bundle_from_files(
    self,
    files: Dict[str, str],
    *,
    owner: Optional[str] = None,
    source_url: str = "",
    category: str = "imported",
) -> Dict:
    """Install a fetched skill bundle (relative path → text) under skills/.

    The bundle's SKILL.md is parsed into a ``Skill``; every bundle file is
    written below the skill directory; then the skill file itself is rewritten
    from the parsed skill with its name, category, owner and source updated.

    Args:
        files: Mapping of relative path -> text content.
        owner: Owner recorded on the installed skill.
        source_url: When non-empty, an "Imported from ..." note is appended to
            the skill's ``body_extra``.
        category: Target category; falls back to the skill's own category and
            then to "imported" when empty.

    Returns:
        The installed skill as a dict (``Skill.to_dict()``).

    Raises:
        SkillImportError: If the bundle is empty, has no SKILL.md, or contains
            an unsafe path.
    """
    from .skill_importer import SkillImportError, pick_skill_md, _safe_relpath
    from core.atomic_io import atomic_write_text

    if not files:
        raise SkillImportError("empty bundle")
    _rel, skill_md = pick_skill_md(files)
    sk = Skill.from_markdown(skill_md)

    # Name fallback: the folder holding SKILL.md. A root-level "SKILL.md" has
    # no parent segment, so indexing [-2] unguarded would raise IndexError.
    segments = _rel.split("/")
    parent_name = segments[-2] if len(segments) >= 2 else ""
    nm = slugify(sk.name or parent_name or "skill")
    cat = slugify(category or sk.category or "imported", fallback="imported")

    # De-duplicate against installed skills: name, name-2, name-3, ...
    existing = {s["name"] for s in self.load_all()}
    base = nm
    i = 2
    while nm in existing:
        nm = f"{base}-{i}"
        i += 1

    skill_dir = self._skill_dir(cat, nm)
    os.makedirs(skill_dir, exist_ok=True)

    # Preserve bundle layout (templates/, references/, etc.) under the skill dir.
    # NOTE(review): keys are written as given, so a bundle fetched from a
    # sub-folder lands at <skill_dir>/<that sub-folder>/... rather than directly
    # under <skill_dir> — confirm whether paths should be re-rooted at the
    # SKILL.md's folder.
    for rel, content in files.items():
        safe = _safe_relpath(rel)
        dest = os.path.join(skill_dir, safe)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        atomic_write_text(dest, content)

    sk.name = nm
    sk.category = cat
    sk.owner = owner
    sk.source = "imported"
    if source_url:
        extra = (sk.body_extra or "").strip()
        note = f"Imported from {source_url}"
        sk.body_extra = f"{extra}\n\n{note}".strip() if extra else note
    skill_file = self._skill_file(cat, nm)
    atomic_write_text(skill_file, sk.to_markdown())
    sk.path = skill_file
    return sk.to_dict()
|
||||||
|
|
||||||
def update_skill(self, skill_id: str, updates: Dict, owner: Optional[str] = None) -> bool:
|
def update_skill(self, skill_id: str, updates: Dict, owner: Optional[str] = None) -> bool:
|
||||||
"""`skill_id` is the slug name. Allows updating any field plus
|
"""`skill_id` is the slug name. Allows updating any field plus
|
||||||
renames if `name` changes (file is moved on disk).
|
renames if `name` changes (file is moved on disk).
|
||||||
|
|||||||
@@ -15,10 +15,11 @@ from pathlib import Path
|
|||||||
from typing import Optional, Dict
|
from typing import Optional, Dict
|
||||||
|
|
||||||
from src.research_utils import is_low_quality
|
from src.research_utils import is_low_quality
|
||||||
|
from src.constants import DEEP_RESEARCH_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RESEARCH_DATA_DIR = Path("data/deep_research")
|
RESEARCH_DATA_DIR = Path(DEEP_RESEARCH_DIR)
|
||||||
|
|
||||||
|
|
||||||
class ResearchHandler:
|
class ResearchHandler:
|
||||||
|
|||||||
@@ -6,21 +6,29 @@ from collections import Counter
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from core.constants import DATA_DIR
|
||||||
|
|
||||||
from .cache import cache_metrics
|
from .cache import cache_metrics
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Dedicated error logger with file handler
|
# Dedicated error logger — write to the data logs directory (writable on both
|
||||||
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
|
# native runs and Docker, where DATA_DIR resolves to the bind-mounted volume).
|
||||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
_log_dir = Path(DATA_DIR) / "logs"
|
||||||
_error_handler.setLevel(logging.WARNING)
|
_error_log_path = _log_dir / "search_engine_error.log"
|
||||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
|
||||||
error_logger = logging.getLogger("search_engine_error")
|
error_logger = logging.getLogger("search_engine_error")
|
||||||
error_logger.addHandler(_error_handler)
|
|
||||||
error_logger.propagate = False
|
error_logger.propagate = False
|
||||||
|
try:
|
||||||
|
_log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||||
|
_error_handler.setLevel(logging.WARNING)
|
||||||
|
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||||
|
error_logger.addHandler(_error_handler)
|
||||||
|
except Exception as _e:
|
||||||
|
logging.getLogger(__name__).warning("search_engine_error log handler unavailable: %s", _e)
|
||||||
|
|
||||||
# Analytics file
|
# Analytics file — also in the writable logs volume.
|
||||||
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
|
ANALYTICS_FILE = _log_dir / "search_analytics.json"
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
|
|||||||
@@ -6,17 +6,23 @@ from datetime import datetime, timedelta
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
from core.constants import DATA_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Cache directories
|
# Cache directories
|
||||||
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
|
CACHE_DIR = Path(DATA_DIR) / "cache"
|
||||||
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
||||||
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
||||||
CACHE_MAX_ENTRIES = 1000
|
CACHE_MAX_ENTRIES = 1000
|
||||||
|
|
||||||
# Create cache directories
|
# Create cache directories. Guarded so an unwritable path (e.g. a read-only
|
||||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
# mount) degrades to no-disk-cache instead of crashing module import.
|
||||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
try:
|
||||||
|
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
except OSError as _e:
|
||||||
|
logger.warning("Search cache directory unavailable (%s); disk cache disabled", _e)
|
||||||
|
|
||||||
# Track cache size for LRU eviction
|
# Track cache size for LRU eviction
|
||||||
search_cache_index: Dict[str, datetime] = {}
|
search_cache_index: Dict[str, datetime] = {}
|
||||||
|
|||||||
@@ -259,6 +259,9 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
|||||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
error_logger.warning(f"HTTP {e.response.status_code} fetching {url}: {e}")
|
||||||
|
return _empty_result(url, f"HTTP {e.response.status_code}: {e}")
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
||||||
return _empty_result(url, f"NetworkError: {e}")
|
return _empty_result(url, f"NetworkError: {e}")
|
||||||
|
|||||||
@@ -76,6 +76,19 @@ def _domain(url: str) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _has_word(text: str, term: str) -> bool:
|
||||||
|
"""True if ``term`` appears in ``text`` as a whole word.
|
||||||
|
|
||||||
|
Query terms are matched on word boundaries so a short term doesn't match
|
||||||
|
inside an unrelated word: "us" must not match "business"/"music", "port"
|
||||||
|
must not match "transport"/"support". This mirrors the tokenization used to
|
||||||
|
build ``query_terms`` (``\\b\\w+\\b``). #1473 converted the title and sports
|
||||||
|
checks to word boundaries; the snippet and subject-term checks below use
|
||||||
|
the same helper so the whole file stays consistent.
|
||||||
|
"""
|
||||||
|
return re.search(rf"\b{re.escape(term)}\b", text) is not None
|
||||||
|
|
||||||
|
|
||||||
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||||
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
||||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||||
@@ -87,14 +100,14 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
|||||||
if not title:
|
if not title:
|
||||||
return 0.0
|
return 0.0
|
||||||
title_lc = title.lower()
|
title_lc = title.lower()
|
||||||
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
|
matches = sum(1 for term in query_terms if _has_word(title_lc, term))
|
||||||
return matches / len(query_terms) if query_terms else 0.0
|
return matches / len(query_terms) if query_terms else 0.0
|
||||||
|
|
||||||
def snippet_score(snippet: str) -> float:
|
def snippet_score(snippet: str) -> float:
|
||||||
if not snippet:
|
if not snippet:
|
||||||
return 0.0
|
return 0.0
|
||||||
length_factor = min(len(snippet), 200) / 200
|
length_factor = min(len(snippet), 200) / 200
|
||||||
term_hits = sum(1 for term in query_terms if term in snippet.lower())
|
term_hits = sum(1 for term in query_terms if _has_word(snippet.lower(), term))
|
||||||
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
||||||
return (length_factor + term_factor) / 2
|
return (length_factor + term_factor) / 2
|
||||||
|
|
||||||
@@ -127,7 +140,7 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
|||||||
# A country/news query should not rank a page whose title/snippet barely
|
# A country/news query should not rank a page whose title/snippet barely
|
||||||
# mentions the country above actual news pages for that country.
|
# mentions the country above actual news pages for that country.
|
||||||
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
||||||
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
|
if subject_terms and not any(_has_word(text, t) or _has_word(netloc, t) for t in subject_terms):
|
||||||
adjustment -= 1.0
|
adjustment -= 1.0
|
||||||
return adjustment
|
return adjustment
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import httpx
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from src.constants import TTS_CACHE_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +37,7 @@ class TTSService:
|
|||||||
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cache_dir: str = "data/tts_cache"):
|
def __init__(self, cache_dir: str = TTS_CACHE_DIR):
|
||||||
self.cache_dir = Path(cache_dir)
|
self.cache_dir = Path(cache_dir)
|
||||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self._kokoro = None # lazy-init
|
self._kokoro = None # lazy-init
|
||||||
|
|||||||
@@ -6,23 +6,30 @@ initial admin user. Safe to re-run (skips what already exists).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
sys.path.insert(0, BASE_DIR)
|
||||||
|
from src.constants import (
|
||||||
|
DATA_DIR, AUTH_FILE, UPLOAD_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR,
|
||||||
|
TTS_CACHE_DIR, GENERATED_IMAGES_DIR, DEEP_RESEARCH_DIR, CHROMA_DIR,
|
||||||
|
RAG_DIR, MEMORY_VECTORS_DIR,
|
||||||
|
)
|
||||||
|
|
||||||
DIRS = [
|
DIRS = [
|
||||||
DATA_DIR,
|
DATA_DIR,
|
||||||
os.path.join(DATA_DIR, "uploads"),
|
UPLOAD_DIR,
|
||||||
os.path.join(DATA_DIR, "personal_docs"),
|
PERSONAL_DIR,
|
||||||
os.path.join(DATA_DIR, "personal_uploads"),
|
PERSONAL_UPLOADS_DIR,
|
||||||
os.path.join(DATA_DIR, "tts_cache"),
|
TTS_CACHE_DIR,
|
||||||
os.path.join(DATA_DIR, "generated_images"),
|
GENERATED_IMAGES_DIR,
|
||||||
os.path.join(DATA_DIR, "deep_research"),
|
DEEP_RESEARCH_DIR,
|
||||||
os.path.join(DATA_DIR, "chroma"),
|
CHROMA_DIR,
|
||||||
os.path.join(DATA_DIR, "rag"),
|
RAG_DIR,
|
||||||
os.path.join(DATA_DIR, "memory_vectors"),
|
MEMORY_VECTORS_DIR,
|
||||||
os.path.join(BASE_DIR, "logs"),
|
os.path.join(BASE_DIR, "logs"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -72,7 +79,7 @@ def _prompt_admin_credentials():
|
|||||||
|
|
||||||
def create_default_admin():
|
def create_default_admin():
|
||||||
"""Create an initial admin user if none exists."""
|
"""Create an initial admin user if none exists."""
|
||||||
auth_path = os.path.join(DATA_DIR, "auth.json")
|
auth_path = AUTH_FILE
|
||||||
if os.path.exists(auth_path):
|
if os.path.exists(auth_path):
|
||||||
print(" [skip] auth.json already exists")
|
print(" [skip] auth.json already exists")
|
||||||
return "exists"
|
return "exists"
|
||||||
@@ -117,7 +124,16 @@ def create_default_admin():
|
|||||||
print(f" Temporary password: {password}")
|
print(f" Temporary password: {password}")
|
||||||
print(f" ** Change it after first login. Set ODYSSEUS_ADMIN_PASSWORD to choose your own. **")
|
print(f" ** Change it after first login. Set ODYSSEUS_ADMIN_PASSWORD to choose your own. **")
|
||||||
return "created"
|
return "created"
|
||||||
except ImportError:
|
except ImportError as e:
|
||||||
|
if "incompatible architecture" in str(e).lower():
|
||||||
|
# bcrypt is present but built for the wrong CPU architecture — the
|
||||||
|
# same Apple Silicon mismatch check_arch() guards against, caught here
|
||||||
|
# for the rarer case of an x86 wheel inside an arm64 venv.
|
||||||
|
print(" [error] bcrypt loaded with the wrong CPU architecture.")
|
||||||
|
print(" Rebuild the venv with an arm64 Python:")
|
||||||
|
print(" rm -rf venv && /opt/homebrew/bin/python3.11 -m venv venv")
|
||||||
|
print(" ./venv/bin/pip install -r requirements.txt")
|
||||||
|
return "skipped"
|
||||||
print(" [warn] bcrypt not installed — skipping admin user creation")
|
print(" [warn] bcrypt not installed — skipping admin user creation")
|
||||||
print(" Run: pip install bcrypt")
|
print(" Run: pip install bcrypt")
|
||||||
return "skipped"
|
return "skipped"
|
||||||
@@ -167,9 +183,52 @@ def check_deps():
|
|||||||
print(" [ok] tmux installed")
|
print(" [ok] tmux installed")
|
||||||
|
|
||||||
|
|
||||||
|
def check_arch():
    """Abort setup with guidance when an Intel Python runs through Rosetta.

    On an Apple Silicon Mac, a venv built with an x86_64 interpreter installs
    and loads compiled packages (bcrypt, pydantic-core, onnxruntime, ...) for
    the wrong CPU architecture and later fails deep inside an import with a
    cryptic "(mach-o file, but is an incompatible architecture)" error.
    Detecting the mismatch up front replaces that with one actionable message.

    Returns:
        None when the platform is not macOS, the interpreter reports arm64,
        or the process is not reported as translated. Otherwise prints the
        rebuild instructions and exits the process with status 1.
    """
    # Only macOS is affected, and an arm64 interpreter is already native.
    if sys.platform != "darwin":
        return
    if platform.machine() == "arm64":
        return

    # What remains is either a genuine Intel Mac (fine) or an x86 interpreter
    # running under Rosetta on Apple Silicon (the case to catch), so ask
    # sysctl whether this process is being translated.
    try:
        probe = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True, text=True, timeout=5,
        )
        translated = probe.stdout.strip()
    except Exception:
        # Best effort: if the probe cannot run, treat the process as native.
        translated = ""
    if translated != "1":
        return  # Genuine Intel Mac - carry on.

    guidance = (
        "\n [error] This is an Apple Silicon Mac, but setup is running under an",
        " Intel (x86_64) Python through Rosetta. Compiled packages would",
        ' load as the wrong architecture and crash with "incompatible',
        ' architecture" later on.',
        "\n Rebuild the environment with Homebrew's arm64 Python:",
        " brew install python@3.11 # if you don't have it yet",
        " rm -rf venv",
        " /opt/homebrew/bin/python3.11 -m venv venv",
        " ./venv/bin/pip install -r requirements.txt",
        " ./venv/bin/python setup.py",
        "\n Tip: ./start-macos.sh does all of this with the right Python.\n",
    )
    for line in guidance:
        print(line)
    sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("\n=== Odysseus Setup ===\n")
|
print("\n=== Odysseus Setup ===\n")
|
||||||
|
|
||||||
|
# Fail fast with a clear message if the CPU architecture is wrong (Apple
|
||||||
|
# Silicon under an x86/Rosetta Python) before importing anything native.
|
||||||
|
check_arch()
|
||||||
|
|
||||||
print("1. Creating directories...")
|
print("1. Creating directories...")
|
||||||
create_dirs()
|
create_dirs()
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ _CALENDAR_ACTION = (
|
|||||||
r"delete|deleting|remove|removing|cancel|cancelling|canceling)"
|
r"delete|deleting|remove|removing|cancel|cancelling|canceling)"
|
||||||
)
|
)
|
||||||
_CALENDAR_THING = r"(?:calendar|calendar\s+(?:entry|item)|event|meeting|appointment|entry|call)"
|
_CALENDAR_THING = r"(?:calendar|calendar\s+(?:entry|item)|event|meeting|appointment|entry|call)"
|
||||||
|
_CALENDAR_READ_THING = r"(?:calendar|schedule|events?|meetings?|appointments?|classes?)"
|
||||||
_EXPLANATORY_PREFIX = re.compile(
|
_EXPLANATORY_PREFIX = re.compile(
|
||||||
r"^\s*(?:how\s+(?:do|can)\s+i|can\s+you\s+explain|what\s+about|tell\s+me\s+how|show\s+me\s+how)\b",
|
r"^\s*(?:how\s+(?:do|can)\s+i|can\s+you\s+explain|what\s+about|tell\s+me\s+how|show\s+me\s+how)\b",
|
||||||
re.I,
|
re.I,
|
||||||
@@ -59,6 +60,14 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple(
|
|||||||
("calendar", "calendar target action request", rf"\b{_CALENDAR_ACTION}\b.{{0,120}}\b(?:to|on|in|into|for)\s+(?:my\s+|the\s+|this\s+)?calendar\b"),
|
("calendar", "calendar target action request", rf"\b{_CALENDAR_ACTION}\b.{{0,120}}\b(?:to|on|in|into|for)\s+(?:my\s+|the\s+|this\s+)?calendar\b"),
|
||||||
("calendar", "put item on calendar request", r"\bput\s+.+\bon\s+(?:my\s+)?calendar\b"),
|
("calendar", "put item on calendar request", r"\bput\s+.+\bon\s+(?:my\s+)?calendar\b"),
|
||||||
|
|
||||||
|
# Calendar/event lookup. A question such as "Do I have Taekwondo
|
||||||
|
# classes this week?" needs the calendar tool; plain chat cannot know.
|
||||||
|
("calendar", "calendar lookup request", rf"\b(?:list|show|check|find)\b.{{0,120}}\b(?:my\s+|the\s+)?(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar lookup question", rf"\b(?:what|which)\b.{{0,120}}\b(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar availability question", rf"\bdo\s+i\s+have\b.{{0,120}}\b(?:upcoming|next|today|tomorrow|this\s+week)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||||
|
("calendar", "calendar agenda question", r"\bwhat(?:'s| is)\s+on\s+(?:my\s+)?calendar\b"),
|
||||||
|
("calendar", "next calendar item question", r"\bwhen\s+(?:is|are)\s+(?:my\s+)?next\s+(?:event|meeting|appointment|class)\b"),
|
||||||
|
|
||||||
# Notes, todos, checklists, and reminders.
|
# Notes, todos, checklists, and reminders.
|
||||||
("notes", "reminder request", r"\bremind\s+me\b"),
|
("notes", "reminder request", r"\bremind\s+me\b"),
|
||||||
("notes", "assistant note/todo action request", rf"{_ACTION_QUESTION}(?:add|create|make|take|jot|write\s+down|set)\b.{{0,120}}\b(?:note|todo|task|checklist|reminder)\b"),
|
("notes", "assistant note/todo action request", rf"{_ACTION_QUESTION}(?:add|create|make|take|jot|write\s+down|set)\b.{{0,120}}\b(?:note|todo|task|checklist|reminder)\b"),
|
||||||
|
|||||||
+683
-198
File diff suppressed because it is too large
Load Diff
+5
-31
@@ -14,16 +14,17 @@ Sub-modules:
|
|||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants (kept here — sub-modules import from here)
|
# Constants (re-exported for backward compatibility — single source of truth
|
||||||
|
# is src.constants; always prefer importing from there for new code)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
MAX_AGENT_ROUNDS = 50
|
MAX_AGENT_ROUNDS = 50
|
||||||
SHELL_TIMEOUT = 60
|
SHELL_TIMEOUT = 60
|
||||||
PYTHON_TIMEOUT = 30
|
PYTHON_TIMEOUT = 30
|
||||||
MAX_OUTPUT_CHARS = 10_000
|
|
||||||
MAX_READ_CHARS = 20_000
|
|
||||||
|
|
||||||
# Tool types that trigger execution
|
# Tool types that trigger execution
|
||||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||||
@@ -34,7 +35,7 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
|||||||
"send_to_session",
|
"send_to_session",
|
||||||
"pipeline",
|
"pipeline",
|
||||||
"manage_session", "manage_memory", "list_models",
|
"manage_session", "manage_memory", "list_models",
|
||||||
"ui_control", "generate_image",
|
"ui_control", "generate_image", "ask_user", "update_plan",
|
||||||
"manage_tasks", "api_call", "ask_teacher", "manage_skills",
|
"manage_tasks", "api_call", "ask_teacher", "manage_skills",
|
||||||
"suggest_document",
|
"suggest_document",
|
||||||
"manage_endpoints", "manage_mcp", "manage_webhooks",
|
"manage_endpoints", "manage_mcp", "manage_webhooks",
|
||||||
@@ -63,33 +64,6 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
|||||||
|
|
||||||
ToolBlock = namedtuple("ToolBlock", ["tool_type", "content"])
|
ToolBlock = namedtuple("ToolBlock", ["tool_type", "content"])
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MCP Manager (kept here — used by execution and agent_loop)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
_mcp_manager = None
|
|
||||||
|
|
||||||
def set_mcp_manager(manager):
|
|
||||||
"""Set the global MCP manager instance."""
|
|
||||||
global _mcp_manager
|
|
||||||
_mcp_manager = manager
|
|
||||||
|
|
||||||
def get_mcp_manager():
|
|
||||||
"""Get the global MCP manager instance."""
|
|
||||||
return _mcp_manager
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers (kept here — used by sub-modules)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
|
||||||
# Callers treat the result as text, so always return a string: coerce a
|
|
||||||
# non-string (None -> "", otherwise str(...)) instead of returning it raw,
|
|
||||||
# which would just move the crash downstream.
|
|
||||||
if not isinstance(text, str):
|
|
||||||
text = "" if text is None else str(text)
|
|
||||||
if len(text) > limit:
|
|
||||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
|
||||||
return text
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Re-exports from sub-modules
|
# Re-exports from sub-modules
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+44
-28
@@ -14,6 +14,8 @@ import uuid
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AI_CHAT_TIMEOUT = 120 # seconds for a single LLM call
|
AI_CHAT_TIMEOUT = 120 # seconds for a single LLM call
|
||||||
@@ -55,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
|||||||
# Model resolution
|
# Model resolution
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||||
@@ -96,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
# Anthropic: match against hardcoded model list
|
# Anthropic: match against hardcoded model list
|
||||||
@@ -112,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
else:
|
else:
|
||||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = []
|
model_ids = []
|
||||||
|
|
||||||
@@ -1119,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
|||||||
total_models = 0
|
total_models = 0
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
model_ids = []
|
model_ids = []
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
model_ids = list(ANTHROPIC_MODELS)
|
model_ids = list(ANTHROPIC_MODELS)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = ["(endpoint offline)"]
|
model_ids = ["(endpoint offline)"]
|
||||||
|
|
||||||
@@ -1268,7 +1284,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None, owner: O
|
|||||||
toggle <name> <on|off> — Toggle a setting (web, bash, rag, research, incognito, document_editor)
|
toggle <name> <on|off> — Toggle a setting (web, bash, rag, research, incognito, document_editor)
|
||||||
set_mode <agent|chat> — Switch between agent and chat mode
|
set_mode <agent|chat> — Switch between agent and chat mode
|
||||||
switch_model <model> — Change the model for the current session
|
switch_model <model> — Change the model for the current session
|
||||||
set_theme <preset> — Apply a theme preset (dark, light, paper, nord, dracula, gruvbox, gpt, claude, lavender, etc.)
|
set_theme <preset> — Apply a built-in theme preset (dark, light, midnight, paper, cyberpunk, retrowave, forest, ocean, ume, copper, terminal, organs, lavender, gpt, claude, cute)
|
||||||
create_theme <name> <bg> <fg> <panel> <border> <accent> [key=val ...] — Create custom theme. Optional key=val: advanced color overrides AND background effects: bgPattern=<none|dots|synapse|rain|constellations|perlin-flow|petals|sparkles|embers>, bgEffectColor=#RRGGBB, bgEffectIntensity=<num>, bgEffectSize=<num>, frosted=true|false
|
create_theme <name> <bg> <fg> <panel> <border> <accent> [key=val ...] — Create custom theme. Optional key=val: advanced color overrides AND background effects: bgPattern=<none|dots|synapse|rain|constellations|perlin-flow|petals|sparkles|embers>, bgEffectColor=#RRGGBB, bgEffectIntensity=<num>, bgEffectSize=<num>, frosted=true|false
|
||||||
open_panel <name> — Open a panel (documents, gallery, email, sessions, notes, memories, skills, settings, cookbook)
|
open_panel <name> — Open a panel (documents, gallery, email, sessions, notes, memories, skills, settings, cookbook)
|
||||||
open_email_reply <uid> [folder] [reply|reply-all|ai-reply] — Open a reply draft document for an email; does not send
|
open_email_reply <uid> [folder] [reply|reply-all|ai-reply] — Open a reply draft document for an email; does not send
|
||||||
@@ -1715,7 +1731,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
|
|
||||||
# GPT image models always return b64_json; DALL-E may return url
|
# GPT image models always return b64_json; DALL-E may return url
|
||||||
if img.get("b64_json"):
|
if img.get("b64_json"):
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
@@ -1728,7 +1744,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
|||||||
try:
|
try:
|
||||||
dl_resp = httpx.get(img["url"], timeout=60)
|
dl_resp = httpx.get(img["url"], timeout=60)
|
||||||
if dl_resp.status_code == 200:
|
if dl_resp.status_code == 200:
|
||||||
img_dir = Path("data/generated_images")
|
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||||
img_dir.mkdir(parents=True, exist_ok=True)
|
img_dir.mkdir(parents=True, exist_ok=True)
|
||||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||||
img_path = img_dir / filename
|
img_path = img_dir / filename
|
||||||
|
|||||||
+22
-1
@@ -10,7 +10,7 @@ def get_current_user(request: Request) -> Optional[str]:
|
|||||||
return getattr(request.state, 'current_user', None)
|
return getattr(request.state, 'current_user', None)
|
||||||
|
|
||||||
|
|
||||||
def effective_user(request: Request):
|
def effective_user(request: Request) -> Optional[str]:
|
||||||
"""The real human behind the request, for ownership/attribution.
|
"""The real human behind the request, for ownership/attribution.
|
||||||
|
|
||||||
Cookie sessions resolve to the logged-in username. Bearer ``ody_`` callers
|
Cookie sessions resolve to the logged-in username. Bearer ``ody_`` callers
|
||||||
@@ -34,6 +34,24 @@ def effective_user(request: Request):
|
|||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_api_token_request(request: Request) -> bool:
    """Return True when middleware authenticated a bearer API token.

    Reads ``request.state.api_token``; a missing attribute counts as False.
    """
    token_flag = getattr(request.state, "api_token", False)
    return True if token_flag else False
|
||||||
|
|
||||||
|
|
||||||
|
def require_authenticated_request(request: Request) -> str:
    """Accept either a browser session or a valid bearer API token.

    This is intentionally narrower than :func:`require_user`: use it only for
    routes that need authentication but do not read or mutate owner-scoped
    user data. Owner-scoped routes should use ``require_user`` for browser
    sessions or their own API-token scope/owner gate.

    Returns:
        For API-token requests, the result of ``effective_user`` or ``""``
        when that is falsy; otherwise whatever ``require_user`` returns.
    """
    # Requests without an API token go through the regular session check.
    if not _is_api_token_request(request):
        return require_user(request)
    # Token callers may have no resolvable user; fall back to "".
    return effective_user(request) or ""
|
||||||
|
|
||||||
|
|
||||||
def _auth_disabled() -> bool:
|
def _auth_disabled() -> bool:
|
||||||
"""True when the operator has explicitly turned off auth via .env.
|
"""True when the operator has explicitly turned off auth via .env.
|
||||||
Mirrors the AUTH_ENABLED parse in app.py / core/middleware.py so the
|
Mirrors the AUTH_ENABLED parse in app.py / core/middleware.py so the
|
||||||
@@ -60,6 +78,9 @@ def require_user(request: Request) -> str:
|
|||||||
Use this on routes that touch user data so middleware misconfig can't
|
Use this on routes that touch user data so middleware misconfig can't
|
||||||
open them up.
|
open them up.
|
||||||
"""
|
"""
|
||||||
|
if _is_api_token_request(request):
|
||||||
|
raise HTTPException(403, "API tokens must use a scope-aware API route")
|
||||||
|
|
||||||
u = get_current_user(request)
|
u = get_current_user(request)
|
||||||
if u:
|
if u:
|
||||||
return u
|
return u
|
||||||
|
|||||||
+6
-4
@@ -33,13 +33,15 @@ from core.atomic_io import atomic_write_json
|
|||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
|
git_bash_path,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
pid_alive,
|
pid_alive,
|
||||||
)
|
)
|
||||||
|
|
||||||
_DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
from src.constants import BG_JOBS_DIR, BG_JOBS_FILE
|
||||||
_JOBS_DIR = _DATA_DIR / "bg_jobs"
|
|
||||||
_STORE = _DATA_DIR / "bg_jobs.json"
|
_JOBS_DIR = Path(BG_JOBS_DIR)
|
||||||
|
_STORE = Path(BG_JOBS_FILE)
|
||||||
|
|
||||||
# A job that runs longer than this is presumed stuck and reaped (the agent
|
# A job that runs longer than this is presumed stuck and reaped (the agent
|
||||||
# still gets a "timed out" follow-up so nothing hangs forever).
|
# still gets a "timed out" follow-up so nothing hangs forever).
|
||||||
@@ -106,7 +108,7 @@ def launch(command: str, session_id: str, cwd: Optional[str] = None,
|
|||||||
# handles drive paths and spaces correctly.
|
# handles drive paths and spaces correctly.
|
||||||
cmd_path = _JOBS_DIR / f"{job_id}.cmd.sh"
|
cmd_path = _JOBS_DIR / f"{job_id}.cmd.sh"
|
||||||
cmd_path.write_text(command + "\n", encoding="utf-8")
|
cmd_path.write_text(command + "\n", encoding="utf-8")
|
||||||
lp, xp, cp = (shlex.quote(p.as_posix()) for p in (log_path, exit_path, cmd_path))
|
lp, xp, cp = (shlex.quote(git_bash_path(p)) for p in (log_path, exit_path, cmd_path))
|
||||||
script_path = _JOBS_DIR / f"{job_id}.sh"
|
script_path = _JOBS_DIR / f"{job_id}.sh"
|
||||||
script_path.write_text(
|
script_path.write_text(
|
||||||
f"bash {cp} > {lp} 2>&1\n"
|
f"bash {cp} > {lp} 2>&1\n"
|
||||||
|
|||||||
+34
-23
@@ -12,6 +12,8 @@ from typing import Tuple
|
|||||||
|
|
||||||
from src.auth_helpers import owner_filter
|
from src.auth_helpers import owner_filter
|
||||||
from core.platform_compat import IS_WINDOWS, find_bash
|
from core.platform_compat import IS_WINDOWS, find_bash
|
||||||
|
from core.constants import internal_api_base
|
||||||
|
from src.constants import DATA_DIR, DEEP_RESEARCH_DIR, TIDY_CALENDAR_STATE_FILE, EMAIL_URGENCY_CACHE_DIR, COOKBOOK_STATE_FILE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -166,7 +168,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
drop_items = decision.get("drop") if isinstance(decision, dict) else None
|
drop_items = decision.get("drop") if isinstance(decision, dict) else None
|
||||||
if isinstance(keep_items, list) and isinstance(drop_items, list):
|
if isinstance(keep_items, list) and isinstance(drop_items, list):
|
||||||
by_id = {m.get("id"): m for m in group_memories if m.get("id")}
|
by_id = {m.get("id"): m for m in group_memories if m.get("id")}
|
||||||
keep_ids = set()
|
|
||||||
cleaned_by_id = {}
|
cleaned_by_id = {}
|
||||||
for item in keep_items:
|
for item in keep_items:
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
@@ -177,7 +178,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
text = (item.get("text") or "").strip()
|
text = (item.get("text") or "").strip()
|
||||||
if not text:
|
if not text:
|
||||||
continue
|
continue
|
||||||
keep_ids.add(mid)
|
|
||||||
cleaned = {
|
cleaned = {
|
||||||
"category": (item.get("category") or by_id[mid].get("category") or "fact").strip(),
|
"category": (item.get("category") or by_id[mid].get("category") or "fact").strip(),
|
||||||
}
|
}
|
||||||
@@ -186,11 +186,20 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
cleaned["text"] = text
|
cleaned["text"] = text
|
||||||
cleaned_by_id[mid] = cleaned
|
cleaned_by_id[mid] = cleaned
|
||||||
|
|
||||||
# If the model only saw a truncated memory, do not let
|
# Delete only memories the model EXPLICITLY dropped, never
|
||||||
# that partial view delete or rewrite the full memory.
|
# ones it merely omitted from `keep`. Treating the
|
||||||
keep_ids.update(mid for mid in truncated_ids if mid in by_id)
|
# complement of `keep` as deletions meant a model that
|
||||||
|
# forgot to re-list an id (common) silently destroyed that
|
||||||
|
# memory. Honor the explicit `drop` set instead.
|
||||||
|
drop_ids = {
|
||||||
|
d.get("id")
|
||||||
|
for d in drop_items
|
||||||
|
if isinstance(d, dict) and d.get("id") in by_id
|
||||||
|
}
|
||||||
|
# Never delete a memory the model only saw truncated.
|
||||||
|
drop_ids -= truncated_ids
|
||||||
|
|
||||||
if keep_ids:
|
if drop_ids or cleaned_by_id:
|
||||||
changed_text = 0
|
changed_text = 0
|
||||||
group_ref_ids = {id(m) for m in group_memories}
|
group_ref_ids = {id(m) for m in group_memories}
|
||||||
kept_all = []
|
kept_all = []
|
||||||
@@ -199,7 +208,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
kept_all.append(mem)
|
kept_all.append(mem)
|
||||||
continue
|
continue
|
||||||
mid = mem.get("id")
|
mid = mem.get("id")
|
||||||
if mid not in keep_ids:
|
if mid in drop_ids:
|
||||||
continue
|
continue
|
||||||
cleaned = cleaned_by_id.get(mid) or {}
|
cleaned = cleaned_by_id.get(mid) or {}
|
||||||
if mid in truncated_ids:
|
if mid in truncated_ids:
|
||||||
@@ -211,7 +220,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
mem["category"] = cleaned["category"]
|
mem["category"] = cleaned["category"]
|
||||||
kept_all.append(mem)
|
kept_all.append(mem)
|
||||||
|
|
||||||
removed = len(group_memories) - len(keep_ids)
|
removed = sum(1 for m in group_memories if m.get("id") in drop_ids)
|
||||||
total_scanned += len(group_memories)
|
total_scanned += len(group_memories)
|
||||||
if removed or changed_text:
|
if removed or changed_text:
|
||||||
all_memories = kept_all
|
all_memories = kept_all
|
||||||
@@ -348,7 +357,7 @@ async def action_tidy_research(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json as _json
|
import json as _json
|
||||||
research_dir = Path("data/deep_research")
|
research_dir = Path(DEEP_RESEARCH_DIR)
|
||||||
if not research_dir.exists():
|
if not research_dir.exists():
|
||||||
raise TaskNoop("no research directory")
|
raise TaskNoop("no research directory")
|
||||||
files = list(research_dir.glob("*.json"))
|
files = list(research_dir.glob("*.json"))
|
||||||
@@ -386,7 +395,7 @@ async def action_tidy_calendar(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
from core.database import SessionLocal, CalendarEvent
|
from core.database import SessionLocal, CalendarEvent
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
STATE_FILE = Path("data/tidy_calendar_state.json")
|
STATE_FILE = Path(TIDY_CALENDAR_STATE_FILE)
|
||||||
last_watermark = None
|
last_watermark = None
|
||||||
try:
|
try:
|
||||||
if STATE_FILE.exists():
|
if STATE_FILE.exists():
|
||||||
@@ -593,9 +602,9 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
if not events:
|
if not events:
|
||||||
return "No upcoming events to classify", True
|
return "No upcoming events to classify", True
|
||||||
|
|
||||||
llm_url, llm_model, llm_headers = resolve_endpoint("utility")
|
llm_url, llm_model, llm_headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not llm_url:
|
if not llm_url:
|
||||||
llm_url, llm_model, llm_headers = resolve_endpoint("default")
|
llm_url, llm_model, llm_headers = resolve_endpoint("default", owner=owner)
|
||||||
llm_available = bool(llm_url and llm_model)
|
llm_available = bool(llm_url and llm_model)
|
||||||
|
|
||||||
# Pull user memories so the LLM has personal context (relationships,
|
# Pull user memories so the LLM has personal context (relationships,
|
||||||
@@ -867,9 +876,9 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
|||||||
if not eligible:
|
if not eligible:
|
||||||
return "All sender sigs already cached (or no eligible senders)", True
|
return "All sender sigs already cached (or no eligible senders)", True
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return "No LLM endpoint available", False
|
return "No LLM endpoint available", False
|
||||||
|
|
||||||
@@ -1303,12 +1312,12 @@ async def action_ping_notes(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
# users' entries (review C4). Legacy path kept as fallback so a
|
# users' entries (review C4). Legacy path kept as fallback so a
|
||||||
# single-user install (empty owner) doesn't lose its history.
|
# single-user install (empty owner) doesn't lose its history.
|
||||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
STATE = _P(f"data/note_pings_{_owner_slug}.json")
|
STATE = _P(DATA_DIR) / f"note_pings_{_owner_slug}.json"
|
||||||
STATE.parent.mkdir(parents=True, exist_ok=True)
|
STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# One-time migration: if legacy global file exists and per-owner file
|
# One-time migration: if legacy global file exists and per-owner file
|
||||||
# doesn't, seed from global (entries for OTHER owners still get pruned
|
# doesn't, seed from global (entries for OTHER owners still get pruned
|
||||||
# on their first run — acceptable, prevents silent loss).
|
# on their first run — acceptable, prevents silent loss).
|
||||||
_legacy = _P("data/note_pings.json")
|
_legacy = _P(DATA_DIR) / "note_pings.json"
|
||||||
if _legacy.exists() and not STATE.exists():
|
if _legacy.exists() and not STATE.exists():
|
||||||
try:
|
try:
|
||||||
STATE.write_text(_legacy.read_text(encoding="utf-8"), encoding="utf-8")
|
STATE.write_text(_legacy.read_text(encoding="utf-8"), encoding="utf-8")
|
||||||
@@ -1465,8 +1474,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
# notified_uids / urgency counts. Empty owner falls back to a generic
|
# notified_uids / urgency counts. Empty owner falls back to a generic
|
||||||
# filename for single-user installs (matches prior behaviour).
|
# filename for single-user installs (matches prior behaviour).
|
||||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||||
STATE_PATH = _P(f"data/email_urgency_state_{_owner_slug}.json")
|
STATE_PATH = _P(DATA_DIR) / f"email_urgency_state_{_owner_slug}.json"
|
||||||
CACHE_DIR = _P("data/email_urgency_cache")
|
CACHE_DIR = _P(EMAIL_URGENCY_CACHE_DIR)
|
||||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
AGE_CUTOFF = _dt.utcnow() - _td(days=7)
|
AGE_CUTOFF = _dt.utcnow() - _td(days=7)
|
||||||
@@ -1480,12 +1489,12 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
|
|
||||||
# ── 1. Resolve LLM candidates (utility primary + utility fallbacks; fall
|
# ── 1. Resolve LLM candidates (utility primary + utility fallbacks; fall
|
||||||
# through to default chat as a last resort).
|
# through to default chat as a last resort).
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
return "No LLM endpoint available", False
|
return "No LLM endpoint available", False
|
||||||
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates()
|
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates(owner=owner)
|
||||||
|
|
||||||
# ── 2. Enumerate enabled accounts. Match this task's owner AND fall
|
# ── 2. Enumerate enabled accounts. Match this task's owner AND fall
|
||||||
# back to the legacy "unowned account whose imap_user / from_address
|
# back to the legacy "unowned account whose imap_user / from_address
|
||||||
@@ -1902,6 +1911,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
|||||||
delivered = bool(dispatch_result.get("email_sent"))
|
delivered = bool(dispatch_result.get("email_sent"))
|
||||||
elif channel == "ntfy":
|
elif channel == "ntfy":
|
||||||
delivered = bool(dispatch_result.get("ntfy_sent"))
|
delivered = bool(dispatch_result.get("ntfy_sent"))
|
||||||
|
elif channel == "webhook":
|
||||||
|
delivered = bool(dispatch_result.get("webhook_sent"))
|
||||||
if delivered:
|
if delivered:
|
||||||
newly_notified.update(new_urgent)
|
newly_notified.update(new_urgent)
|
||||||
else:
|
else:
|
||||||
@@ -2040,7 +2051,7 @@ async def action_cookbook_serve(
|
|||||||
except Exception:
|
except Exception:
|
||||||
end_after_min = 0
|
end_after_min = 0
|
||||||
|
|
||||||
state_path = Path("/app/data/cookbook_state.json")
|
state_path = Path(COOKBOOK_STATE_FILE)
|
||||||
try:
|
try:
|
||||||
state = json.loads(state_path.read_text(encoding="utf-8")) if state_path.exists() else {}
|
state = json.loads(state_path.read_text(encoding="utf-8")) if state_path.exists() else {}
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -2116,7 +2127,7 @@ async def action_cookbook_serve(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
r = await client.post("http://localhost:7000/api/model/serve",
|
r = await client.post(f"{internal_api_base()}/api/model/serve",
|
||||||
json=body, headers=headers)
|
json=body, headers=headers)
|
||||||
data = r.json() if r.content else {}
|
data = r.json() if r.content else {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+241
-51
@@ -27,6 +27,7 @@ import hashlib
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import socket
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timedelta, timezone
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
@@ -50,15 +51,55 @@ def _private_caldav_allowed() -> bool:
|
|||||||
return os.environ.get("ODYSSEUS_ALLOW_PRIVATE_CALDAV", "0").lower() in {"1", "true", "yes"}
|
return os.environ.get("ODYSSEUS_ALLOW_PRIVATE_CALDAV", "0").lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_caldav_address(addr: ipaddress._BaseAddress) -> None:
|
||||||
|
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
|
||||||
|
addr = addr.ipv4_mapped
|
||||||
|
if (
|
||||||
|
addr.is_loopback
|
||||||
|
or addr.is_link_local
|
||||||
|
or addr.is_multicast
|
||||||
|
or addr.is_unspecified
|
||||||
|
or addr.is_reserved
|
||||||
|
):
|
||||||
|
raise ValueError("CalDAV URL host is not allowed")
|
||||||
|
if addr.is_private and not _private_caldav_allowed():
|
||||||
|
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
||||||
|
|
||||||
|
|
||||||
def _validate_caldav_ip(host: str) -> None:
|
def _validate_caldav_ip(host: str) -> None:
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(host.strip("[]"))
|
ip = ipaddress.ip_address(host.strip("[]"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return
|
return
|
||||||
if ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_unspecified:
|
_validate_caldav_address(ip)
|
||||||
raise ValueError("CalDAV URL host is not allowed")
|
|
||||||
if ip.is_private and not _private_caldav_allowed():
|
|
||||||
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
def _resolve_caldav_host_ips(host: str) -> list[ipaddress._BaseAddress]:
|
||||||
|
addrs: list[ipaddress._BaseAddress] = []
|
||||||
|
for family, _, _, _, sockaddr in socket.getaddrinfo(host, None):
|
||||||
|
if family not in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
addrs.append(ipaddress.ip_address(sockaddr[0].split("%", 1)[0]))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return addrs
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_caldav_hostname(host: str) -> None:
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(host.strip("[]"))
|
||||||
|
return
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
addrs = _resolve_caldav_host_ips(host)
|
||||||
|
except OSError:
|
||||||
|
raise ValueError("CalDAV URL host does not resolve")
|
||||||
|
if not addrs:
|
||||||
|
raise ValueError("CalDAV URL host does not resolve")
|
||||||
|
for addr in addrs:
|
||||||
|
_validate_caldav_address(addr)
|
||||||
|
|
||||||
|
|
||||||
def validate_caldav_url(raw_url: str) -> str:
|
def validate_caldav_url(raw_url: str) -> str:
|
||||||
@@ -83,15 +124,18 @@ def validate_caldav_url(raw_url: str) -> str:
|
|||||||
if host in _BLOCKED_HOSTS or host.endswith(".localhost"):
|
if host in _BLOCKED_HOSTS or host.endswith(".localhost"):
|
||||||
raise ValueError("CalDAV URL host is not allowed")
|
raise ValueError("CalDAV URL host is not allowed")
|
||||||
_validate_caldav_ip(host)
|
_validate_caldav_ip(host)
|
||||||
|
_validate_caldav_hostname(host)
|
||||||
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
def _stable_cal_id(remote_url: str, owner: str = "") -> str:
|
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||||
"""Deterministic local id for a remote CalDAV calendar — same URL
|
"""Deterministic local id for a remote CalDAV calendar, scoped to owner
|
||||||
always maps to the same local row across restarts and re-syncs.
|
and account so two users — or one user with two accounts — pointing at
|
||||||
Owner is included in the hash to prevent PK collisions when multiple
|
the same server URL get distinct local rows (avoids PK collision, #2765).
|
||||||
users sync the same CalDAV endpoint."""
|
The owner and account_id default to "" for the legacy/URL-only path so
|
||||||
h = hashlib.sha256(f"{owner}:{remote_url}".encode("utf-8")).hexdigest()[:24]
|
existing callers without those arguments keep working."""
|
||||||
|
key = f"{owner}\n{account_id}\n{remote_url}"
|
||||||
|
h = hashlib.sha256(key.encode("utf-8")).hexdigest()[:24]
|
||||||
return f"caldav-{h}"
|
return f"caldav-{h}"
|
||||||
|
|
||||||
|
|
||||||
@@ -126,18 +170,103 @@ def _find_existing_event(db, pending, uid_val, calendar_id):
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
|
|
||||||
def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
def _google_caldav_events_url(url: str) -> str | None:
|
||||||
|
"""Map a Google CalDAV *principal* URL to its event-collection URL.
|
||||||
|
|
||||||
|
Google serves the principal at ``…/user`` but events live under ``…/events``
|
||||||
|
— the ``/user`` resource holds no VEVENTs. The `caldav` library's
|
||||||
|
principal→home-set discovery does not reliably enumerate calendars from
|
||||||
|
Google's ``/user`` endpoint, so the sync falls into the "treat the URL as a
|
||||||
|
single calendar" fallback below. Pointed at ``/user`` that fallback issues
|
||||||
|
every calendar-query REPORT against the principal, which returns a clean but
|
||||||
|
empty 200 for all date ranges — the calendar shows no events even though
|
||||||
|
auth succeeded (issue #2507).
|
||||||
|
|
||||||
|
Both Google CalDAV endpoint forms are handled, since some accounts only
|
||||||
|
authenticate against one of them:
|
||||||
|
- newer: ``https://apidata.googleusercontent.com/caldav/v2/<id>/user``
|
||||||
|
- legacy: ``https://www.google.com/calendar/dav/<id>/user``
|
||||||
|
|
||||||
|
Returns the events URL for a recognised Google principal URL, else None so
|
||||||
|
the caller keeps the original URL unchanged.
|
||||||
|
"""
|
||||||
|
parts = urlparse(url)
|
||||||
|
host = (parts.hostname or "").lower()
|
||||||
|
path = parts.path.rstrip("/")
|
||||||
|
if not path.endswith("/user"):
|
||||||
|
return None
|
||||||
|
is_google = (
|
||||||
|
host.endswith("googleusercontent.com") # newer /caldav/v2 form
|
||||||
|
or (host in ("www.google.com", "google.com") and "/calendar/dav/" in path) # legacy form
|
||||||
|
)
|
||||||
|
if not is_google:
|
||||||
|
return None
|
||||||
|
new_path = path[: -len("/user")] + "/events"
|
||||||
|
return urlunparse(parts._replace(path=new_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _open_url_as_calendar(client, url: str):
    """Open ``url`` as a single calendar collection on ``client``.

    This is the fallback for when principal discovery yields no calendars.
    A Google principal URL is first mapped to its events URL by
    ``_google_caldav_events_url``; any other URL is opened exactly as given.

    Returns:
        Whatever ``client.calendar(url=...)`` returns.
    """
    events_url = _google_caldav_events_url(url)
    if events_url:
        return client.calendar(url=events_url)
    return client.calendar(url=url)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_dav_client(url: str, username: str, password: str):
    """Build a ``caldav.DAVClient`` whose session follows no redirects.

    ``validate_caldav_url`` only vets the *initial* host. If the HTTP session
    followed a 3xx, a URL that passed validation could still be sent on to
    loopback / link-local / private space at request time, re-opening the SSRF
    the host check closes. Capping the session at zero redirects makes any 3xx
    raise instead of silently following an attacker-chosen ``Location``. This
    mirrors the test-connection path in ``routes/calendar_routes.py``, which
    sets ``follow_redirects=False``.

    DAVClient exposes no per-request redirect flag, so the limit is set on the
    session after construction (the session is created in ``__init__``).
    """
    import caldav

    dav_client = caldav.DAVClient(url=url, username=username, password=password)
    # Applied unconditionally: a redirect limit that only sometimes applies is
    # not a control. test_build_dav_client_disables_redirects asserts this
    # against the installed caldav in CI.
    dav_client.session.max_redirects = 0
    return dav_client
|
||||||
|
|
||||||
|
|
||||||
|
def _should_prune_window(seen_uids: set, parse_failed: bool) -> bool:
|
||||||
|
"""Whether the post-sync prune of vanished CalDAV events is safe to run.
|
||||||
|
|
||||||
|
The prune deletes local ``origin=="caldav"`` rows in the window whose UID the
|
||||||
|
server did not just return. Any parse failure (total or partial) makes
|
||||||
|
``seen_uids`` an incomplete view of the server, so pruning against it can
|
||||||
|
delete events that still exist upstream but could not be read: a total
|
||||||
|
failure wipes the whole window, a partial failure deletes just the
|
||||||
|
unreadable ones. Only prune on a clean read. An empty ``seen_uids`` after a
|
||||||
|
clean read is a genuinely empty window, which is safe to prune.
|
||||||
|
"""
|
||||||
|
return not parse_failed
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_blocking(owner: str, url: str, username: str, password: str, account_id: str = "") -> dict:
|
||||||
"""The actual sync — synchronous, intended to run in a threadpool.
|
"""The actual sync — synchronous, intended to run in a threadpool.
|
||||||
Returns counts: {calendars, events, deleted, errors}."""
|
Returns counts: {calendars, events, deleted, errors}."""
|
||||||
# Lazy imports so a missing `caldav` dep doesn't break app startup —
|
# Lazy imports so a missing `caldav` dep doesn't break app startup —
|
||||||
# the integrations form still works, sync just no-ops with an error.
|
# the integrations form still works, sync just no-ops with an error.
|
||||||
import caldav
|
|
||||||
from caldav.lib.error import AuthorizationError, NotFoundError
|
from caldav.lib.error import AuthorizationError, NotFoundError
|
||||||
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
||||||
|
|
||||||
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
||||||
|
|
||||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
client = _build_dav_client(url, username, password)
|
||||||
|
|
||||||
# Discovery: try principal → calendars first; if the server doesn't
|
# Discovery: try principal → calendars first; if the server doesn't
|
||||||
# support discovery (or the URL points directly at a calendar), fall
|
# support discovery (or the URL points directly at a calendar), fall
|
||||||
@@ -152,14 +281,14 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"CalDAV principal discovery failed, trying URL as calendar: {e}")
|
logger.info(f"CalDAV principal discovery failed, trying URL as calendar: {e}")
|
||||||
try:
|
try:
|
||||||
calendars = [client.calendar(url=url)]
|
calendars = [_open_url_as_calendar(client, url)]
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
result["errors"].append(f"Could not open URL as calendar: {e2}")
|
result["errors"].append(f"Could not open URL as calendar: {e2}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if not calendars:
|
if not calendars:
|
||||||
try:
|
try:
|
||||||
calendars = [client.calendar(url=url)]
|
calendars = [_open_url_as_calendar(client, url)]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["errors"].append(f"No calendars and URL fallback failed: {e}")
|
result["errors"].append(f"No calendars and URL fallback failed: {e}")
|
||||||
return result
|
return result
|
||||||
@@ -172,7 +301,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
for remote_cal in calendars:
|
for remote_cal in calendars:
|
||||||
try:
|
try:
|
||||||
remote_url = str(remote_cal.url)
|
remote_url = str(remote_cal.url)
|
||||||
cal_id = _stable_cal_id(remote_url, owner)
|
cal_id = _stable_cal_id(remote_url, owner=owner, account_id=account_id)
|
||||||
display_name = (remote_cal.name or "").strip() or "CalDAV"
|
display_name = (remote_cal.name or "").strip() or "CalDAV"
|
||||||
|
|
||||||
local_cal = db.query(CalendarCal).filter(
|
local_cal = db.query(CalendarCal).filter(
|
||||||
@@ -186,14 +315,20 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
name=display_name,
|
name=display_name,
|
||||||
color="#5b8abf",
|
color="#5b8abf",
|
||||||
source="caldav",
|
source="caldav",
|
||||||
|
account_id=account_id or None,
|
||||||
)
|
)
|
||||||
db.add(local_cal)
|
db.add(local_cal)
|
||||||
db.commit()
|
db.commit()
|
||||||
else:
|
else:
|
||||||
# Refresh the display name if the user renamed it
|
# Refresh display name and stamp account_id if missing.
|
||||||
# remotely; preserve any local color override.
|
changed = False
|
||||||
if local_cal.name != display_name:
|
if local_cal.name != display_name:
|
||||||
local_cal.name = display_name
|
local_cal.name = display_name
|
||||||
|
changed = True
|
||||||
|
if account_id and not local_cal.account_id:
|
||||||
|
local_cal.account_id = account_id
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
db.commit()
|
db.commit()
|
||||||
result["calendars"] += 1
|
result["calendars"] += 1
|
||||||
|
|
||||||
@@ -207,6 +342,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
# duplicate UIDs within the same batch are updated, not re-inserted
|
# duplicate UIDs within the same batch are updated, not re-inserted
|
||||||
# (which would violate the UNIQUE constraint on commit).
|
# (which would violate the UNIQUE constraint on commit).
|
||||||
pending: dict = {}
|
pending: dict = {}
|
||||||
|
parse_failed = False
|
||||||
try:
|
try:
|
||||||
objs = remote_cal.date_search(start=start, end=end, expand=False)
|
objs = remote_cal.date_search(start=start, end=end, expand=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -218,6 +354,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
ical = iCal.from_ical(obj.data)
|
ical = iCal.from_ical(obj.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["errors"].append(f"{display_name}: parse failed ({e})")
|
result["errors"].append(f"{display_name}: parse failed ({e})")
|
||||||
|
parse_failed = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for comp in ical.walk():
|
for comp in ical.walk():
|
||||||
@@ -294,17 +431,23 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
# are prunable; locally-created events (agent / email triage / a
|
# are prunable; locally-created events (agent / email triage / a
|
||||||
# UI event whose write-back failed) carry origin NULL and must
|
# UI event whose write-back failed) carry origin NULL and must
|
||||||
# never be deleted just because the server didn't return them.
|
# never be deleted just because the server didn't return them.
|
||||||
stale = db.query(CalendarEvent).filter(
|
# Skip the prune on any parse failure: seen_uids is then an
|
||||||
CalendarEvent.calendar_id == local_cal.id,
|
# incomplete view of the server, so pruning against it would
|
||||||
CalendarEvent.origin == "caldav",
|
# delete events that still exist upstream but could not be read
|
||||||
CalendarEvent.dtstart >= start,
|
# (the empty-seen_uids case wipes the whole window; a partial
|
||||||
CalendarEvent.dtstart <= end,
|
# failure deletes just the unreadable rows).
|
||||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
if _should_prune_window(seen_uids, parse_failed):
|
||||||
).all()
|
stale = db.query(CalendarEvent).filter(
|
||||||
for ev in stale:
|
CalendarEvent.calendar_id == local_cal.id,
|
||||||
db.delete(ev)
|
CalendarEvent.origin == "caldav",
|
||||||
result["deleted"] += len(stale)
|
CalendarEvent.dtstart >= start,
|
||||||
db.commit()
|
CalendarEvent.dtstart <= end,
|
||||||
|
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||||
|
).all()
|
||||||
|
for ev in stale:
|
||||||
|
db.delete(ev)
|
||||||
|
result["deleted"] += len(stale)
|
||||||
|
db.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("CalDAV sync failed for one calendar")
|
logger.exception("CalDAV sync failed for one calendar")
|
||||||
result["errors"].append(str(e)[:200])
|
result["errors"].append(str(e)[:200])
|
||||||
@@ -315,31 +458,78 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def sync_caldav(owner: str) -> dict:
|
def _load_caldav_accounts(owner: str) -> list:
    """Return the CalDAV accounts configured for *owner*.

    If the prefs already hold a ``caldav_accounts`` key, its value is returned
    as a list. Otherwise a legacy single-account ``caldav`` entry with a URL is
    migrated to a one-element ``caldav_accounts`` list, the legacy key is
    removed from the loaded prefs, and the result is saved.

    The save is best-effort: if ``_save_for_user`` cannot be imported (e.g. a
    test with a minimal prefs mock) the migrated accounts are still returned
    and the next call migrates again. The migrated account id is derived
    deterministically from owner and URL, so a repeated migration yields the
    same id; the id feeds the local calendar id hash, so a fresh random id per
    call would map every sync to new calendar rows.

    Returns:
        A list of account dicts (``id``, ``label``, ``url``, ``username``,
        ``password``), or an empty list when nothing is configured.
    """
    import uuid as _uuid

    from routes.prefs_routes import _load_for_user

    prefs = _load_for_user(owner) or {}
    if "caldav_accounts" in prefs:
        return list(prefs["caldav_accounts"] or [])

    # Migrate legacy single-account config to the list format.
    legacy = prefs.get("caldav", {}) or {}
    legacy_url = legacy.get("url")
    if not legacy_url:
        return []

    accounts = [{
        "id": str(_uuid.uuid5(_uuid.NAMESPACE_URL, f"{owner}\n{legacy_url}")),
        "label": "CalDAV",
        "url": legacy_url,
        "username": legacy.get("username", ""),
        "password": legacy.get("password", ""),
    }]
    prefs["caldav_accounts"] = accounts
    prefs.pop("caldav", None)
    try:
        from routes.prefs_routes import _save_for_user
        _save_for_user(owner, prefs)
    except (ImportError, AttributeError):
        # Best-effort: the migration is cheap and repeatable, so just note it.
        logger.debug("CalDAV account migration not persisted; will re-run on next load")
    return accounts
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_caldav(owner: str) -> dict:
    """Sync every configured CalDAV account of ``owner`` into the local DB.

    Accounts are handled one after another. For each, the stored password is
    passed through ``decrypt`` (a failure there leaves the value as stored),
    the URL is validated, and the blocking sync runs in a worker thread.

    Returns:
        ``{"calendars", "events", "deleted", "errors"}`` summed over all
        accounts, each error prefixed with the account's label. When no
        account is configured, zero counts plus a single
        "CalDAV is not configured" error.
    """
    from src.secret_storage import decrypt

    def _failure(message: str) -> dict:
        # Per-account result carrying zero counts and one error message.
        return {"calendars": 0, "events": 0, "deleted": 0, "errors": [message]}

    accounts = _load_caldav_accounts(owner)
    if not accounts:
        return _failure("CalDAV is not configured")

    counters = ("calendars", "events", "deleted")
    totals: dict = {name: 0 for name in counters}
    totals["errors"] = []

    for account in accounts:
        url = (account.get("url") or "").strip()
        user = (account.get("username") or "").strip()
        secret = account.get("password") or ""
        account_id = account.get("id") or ""
        label = account.get("label") or url or account_id

        try:
            secret = decrypt(secret)
        except Exception:
            pass

        if not url or not user or not secret:
            totals["errors"].append(f"{label}: missing URL, username, or password")
            continue

        try:
            url = validate_caldav_url(url)
            outcome = await asyncio.to_thread(
                _sync_blocking, owner, url, user, secret, account_id
            )
        except ValueError as exc:
            outcome = _failure(str(exc))
        except Exception as exc:
            logger.exception("CalDAV sync raised for account %s", label)
            outcome = _failure(str(exc)[:200])

        for name in counters:
            totals[name] += outcome.get(name, 0)
        totals["errors"].extend(f"{label}: {err}" for err in outcome.get("errors", []))

    return totals
|
||||||
|
|||||||
+59
-23
@@ -23,11 +23,10 @@ from datetime import timezone
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _stable_cal_id(remote_url: str) -> str:
|
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||||
# Reuse the sync module's hashing so a local CalDAV calendar id maps back to
|
# Reuse the sync module's hashing so owner+account_id scoping stays consistent.
|
||||||
# the same remote URL it was pulled from.
|
|
||||||
from src.caldav_sync import _stable_cal_id as _sync_id
|
from src.caldav_sync import _stable_cal_id as _sync_id
|
||||||
return _sync_id(remote_url)
|
return _sync_id(remote_url, owner=owner, account_id=account_id)
|
||||||
|
|
||||||
|
|
||||||
def build_event_ical(ev: dict) -> str:
|
def build_event_ical(ev: dict) -> str:
|
||||||
@@ -76,28 +75,34 @@ def build_event_ical(ev: dict) -> str:
|
|||||||
return cal.to_ical().decode("utf-8")
|
return cal.to_ical().decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def find_remote_calendar(calendars, local_cal_id: str):
|
def find_remote_calendar(calendars, local_cal_id: str, owner: str = "", account_id: str = ""):
|
||||||
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None."""
|
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None.
|
||||||
|
|
||||||
|
``owner`` and ``account_id`` must match what was used when the local calendar
|
||||||
|
id was originally computed in ``_sync_blocking`` so the hash round-trips."""
|
||||||
for cal in calendars:
|
for cal in calendars:
|
||||||
try:
|
try:
|
||||||
if _stable_cal_id(str(cal.url)) == local_cal_id:
|
if _stable_cal_id(str(cal.url), owner=owner, account_id=account_id) == local_cal_id:
|
||||||
return cal
|
return cal
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False) -> dict:
|
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||||
|
owner: str = "", account_id: str = "") -> dict:
|
||||||
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
||||||
|
|
||||||
Returns ``{"ok": bool, ...}``. ``calendars`` is the discovered caldav
|
Returns ``{"ok": bool, ...}``. ``calendars`` is the discovered caldav
|
||||||
calendar list (injected so this is unit-testable with fakes).
|
calendar list (injected so this is unit-testable with fakes).
|
||||||
|
``owner`` and ``account_id`` are forwarded to ``find_remote_calendar``
|
||||||
|
so the URL hash round-trips correctly (#2765).
|
||||||
"""
|
"""
|
||||||
uid = (ev or {}).get("uid") if isinstance(ev, dict) else None
|
uid = (ev or {}).get("uid") if isinstance(ev, dict) else None
|
||||||
if not uid:
|
if not uid:
|
||||||
return {"ok": False, "error": "event uid is required"}
|
return {"ok": False, "error": "event uid is required"}
|
||||||
|
|
||||||
remote = find_remote_calendar(calendars, local_cal_id)
|
remote = find_remote_calendar(calendars, local_cal_id, owner=owner, account_id=account_id)
|
||||||
if remote is None:
|
if remote is None:
|
||||||
return {"ok": False, "error": "remote calendar not found"}
|
return {"ok": False, "error": "remote calendar not found"}
|
||||||
|
|
||||||
@@ -136,13 +141,17 @@ def _discover_calendars(client):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _writeback_blocking(local_cal_id, ev, delete, url, username, password) -> dict:
|
def _writeback_blocking(local_cal_id, ev, delete, url, username, password,
|
||||||
import caldav
|
owner="", account_id="") -> dict:
|
||||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
from src.caldav_sync import _build_dav_client
|
||||||
|
# Redirects disabled here too: the write-back path opens its own DAVClient,
|
||||||
|
# so it needs the same SSRF-via-redirect protection as the pull path.
|
||||||
|
client = _build_dav_client(url, username, password)
|
||||||
calendars = _discover_calendars(client)
|
calendars = _discover_calendars(client)
|
||||||
if not calendars:
|
if not calendars:
|
||||||
return {"ok": False, "error": "no remote calendars discovered"}
|
return {"ok": False, "error": "no remote calendars discovered"}
|
||||||
return push_event(calendars, local_cal_id, ev, delete=delete)
|
return push_event(calendars, local_cal_id, ev, delete=delete,
|
||||||
|
owner=owner, account_id=account_id)
|
||||||
|
|
||||||
|
|
||||||
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||||
@@ -156,18 +165,45 @@ async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
|||||||
if calendar_source != "caldav":
|
if calendar_source != "caldav":
|
||||||
return {"skipped": "not a caldav calendar"}
|
return {"skipped": "not a caldav calendar"}
|
||||||
try:
|
try:
|
||||||
from routes.prefs_routes import _load_for_user
|
from src.caldav_sync import _load_caldav_accounts
|
||||||
from src.secret_storage import decrypt
|
from src.secret_storage import decrypt
|
||||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
from core.database import CalendarCal, SessionLocal
|
||||||
url = (cfg.get("url") or "").strip()
|
|
||||||
user = (cfg.get("username") or "").strip()
|
accounts = _load_caldav_accounts(owner)
|
||||||
# Stored encrypted by routes/calendar_routes; decrypt before use so
|
if not accounts:
|
||||||
# the remote sees the real password (decrypt is a no-op on legacy
|
|
||||||
# plaintext). The pull path src/caldav_sync.py already does this.
|
|
||||||
pw = decrypt(cfg.get("password") or "")
|
|
||||||
if not (url and user and pw):
|
|
||||||
return {"skipped": "caldav not configured"}
|
return {"skipped": "caldav not configured"}
|
||||||
result = await asyncio.to_thread(_writeback_blocking, calendar_id, ev, delete, url, user, pw)
|
|
||||||
|
# Find which account owns this calendar.
|
||||||
|
acc = None
|
||||||
|
if len(accounts) > 1:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
cal_row = db.query(CalendarCal).filter(CalendarCal.id == calendar_id).first()
|
||||||
|
cal_account_id = cal_row.account_id if cal_row else None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
if cal_account_id:
|
||||||
|
acc = next((a for a in accounts if a.get("id") == cal_account_id), None)
|
||||||
|
# Fall back to first account (covers single-account and legacy rows with
|
||||||
|
# no account_id stamped).
|
||||||
|
if acc is None:
|
||||||
|
acc = accounts[0]
|
||||||
|
|
||||||
|
url = (acc.get("url") or "").strip()
|
||||||
|
user = (acc.get("username") or "").strip()
|
||||||
|
pw = decrypt(acc.get("password") or "")
|
||||||
|
if not (url and user and pw):
|
||||||
|
return {"skipped": "caldav account credentials incomplete"}
|
||||||
|
from src.caldav_sync import validate_caldav_url
|
||||||
|
try:
|
||||||
|
url = validate_caldav_url(url)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning("CalDAV write-back URL rejected: %s", e)
|
||||||
|
return {"ok": False, "error": str(e)[:200]}
|
||||||
|
acc_id = acc.get("id") or ""
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
_writeback_blocking, calendar_id, ev, delete, url, user, pw, owner, acc_id
|
||||||
|
)
|
||||||
if not result.get("ok"):
|
if not result.get("ok"):
|
||||||
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
||||||
return result
|
return result
|
||||||
|
|||||||
+25
-15
@@ -98,6 +98,7 @@ class ChatHandler:
|
|||||||
att_ids: List[str],
|
att_ids: List[str],
|
||||||
sess,
|
sess,
|
||||||
auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
allow_tool_preprocessing: bool = True,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Common preprocessing for both chat endpoints.
|
Common preprocessing for both chat endpoints.
|
||||||
@@ -112,7 +113,7 @@ class ChatHandler:
|
|||||||
attachment_meta: List[Dict[str, Any]] = []
|
attachment_meta: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
# Extract URLs and process YouTube transcripts
|
# Extract URLs and process YouTube transcripts
|
||||||
urls = extract_urls(enhanced_message)
|
urls = extract_urls(enhanced_message) if allow_tool_preprocessing else []
|
||||||
youtube_transcripts: List[str] = []
|
youtube_transcripts: List[str] = []
|
||||||
|
|
||||||
has_youtube = False
|
has_youtube = False
|
||||||
@@ -143,24 +144,18 @@ class ChatHandler:
|
|||||||
if has_youtube:
|
if has_youtube:
|
||||||
youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
||||||
|
|
||||||
# Analyze images — skip if vision disabled, or if main model is vision-capable
|
|
||||||
from src.settings import get_setting
|
|
||||||
vision_enabled = get_setting("vision_enabled", True)
|
|
||||||
main_is_vision = await asyncio.to_thread(
|
|
||||||
model_supports_vision, sess.model or "", getattr(sess, "endpoint_url", "") or ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve uploads once with the session owner. Attachment IDs are
|
# Resolve uploads once with the session owner. Attachment IDs are
|
||||||
# bearer-like references; never trust them without an owner check.
|
# bearer-like references; never trust them without an owner check.
|
||||||
files_by_id: Dict[str, Dict] = {}
|
files_by_id: Dict[str, Dict] = {}
|
||||||
owner = getattr(sess, "owner", None)
|
owner = getattr(sess, "owner", None)
|
||||||
if att_ids:
|
effective_att_ids = att_ids if allow_tool_preprocessing else []
|
||||||
for att_id in att_ids:
|
if effective_att_ids:
|
||||||
|
for att_id in effective_att_ids:
|
||||||
fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
||||||
if fi:
|
if fi:
|
||||||
files_by_id[att_id] = fi
|
files_by_id[att_id] = fi
|
||||||
|
|
||||||
for att_id in att_ids:
|
for att_id in effective_att_ids:
|
||||||
fi = files_by_id.get(att_id)
|
fi = files_by_id.get(att_id)
|
||||||
if fi:
|
if fi:
|
||||||
attachment_meta.append({
|
attachment_meta.append({
|
||||||
@@ -172,9 +167,24 @@ class ChatHandler:
|
|||||||
"height": fi.get("height"),
|
"height": fi.get("height"),
|
||||||
})
|
})
|
||||||
|
|
||||||
if att_ids and vision_enabled:
|
# Analyze images only when attachment preprocessing is actually
|
||||||
|
# allowed. The vision capability check can probe local model endpoints,
|
||||||
|
# so guide-only/no-tools turns must not reach it.
|
||||||
|
vision_enabled = False
|
||||||
|
main_is_vision = False
|
||||||
|
if effective_att_ids:
|
||||||
|
from src.settings import get_setting
|
||||||
|
vision_enabled = get_setting("vision_enabled", True)
|
||||||
|
if vision_enabled:
|
||||||
|
main_is_vision = await asyncio.to_thread(
|
||||||
|
model_supports_vision,
|
||||||
|
sess.model or "",
|
||||||
|
getattr(sess, "endpoint_url", "") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if effective_att_ids and vision_enabled:
|
||||||
meta_by_id = {m["id"]: m for m in attachment_meta}
|
meta_by_id = {m["id"]: m for m in attachment_meta}
|
||||||
for att_id in att_ids:
|
for att_id in effective_att_ids:
|
||||||
file_info = files_by_id.get(att_id)
|
file_info = files_by_id.get(att_id)
|
||||||
if file_info and self.upload_handler.is_image_file(
|
if file_info and self.upload_handler.is_image_file(
|
||||||
file_info["name"], file_info.get("mime", "")
|
file_info["name"], file_info.get("mime", "")
|
||||||
@@ -219,7 +229,7 @@ class ChatHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
vl_desc = None
|
vl_desc = None
|
||||||
if not vl_desc:
|
if not vl_desc:
|
||||||
vl_result = analyze_image_with_vl_result(file_info["path"])
|
vl_result = analyze_image_with_vl_result(file_info["path"], owner=owner)
|
||||||
vl_desc = vl_result.get("text", "")
|
vl_desc = vl_result.get("text", "")
|
||||||
vl_model = vl_result.get("model", "")
|
vl_model = vl_result.get("model", "")
|
||||||
if vl_desc and not vl_desc.startswith("["):
|
if vl_desc and not vl_desc.startswith("["):
|
||||||
@@ -239,7 +249,7 @@ class ChatHandler:
|
|||||||
_m["vision_model"] = vl_model
|
_m["vision_model"] = vl_model
|
||||||
|
|
||||||
user_content = build_user_content(
|
user_content = build_user_content(
|
||||||
enhanced_message, att_ids, UPLOAD_DIR, self.upload_handler,
|
enhanced_message, effective_att_ids, UPLOAD_DIR, self.upload_handler,
|
||||||
session_id=getattr(sess, "id", None),
|
session_id=getattr(sess, "id", None),
|
||||||
auto_opened_docs=auto_opened_docs,
|
auto_opened_docs=auto_opened_docs,
|
||||||
owner=owner,
|
owner=owner,
|
||||||
|
|||||||
+13
-3
@@ -13,6 +13,8 @@ from fastapi import HTTPException
|
|||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from src.upload_limits import format_byte_limit, get_chat_upload_max_bytes
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -22,7 +24,14 @@ def extract_urls(text: str) -> List[str]:
|
|||||||
urls = re.findall(url_pattern, text)
|
urls = re.findall(url_pattern, text)
|
||||||
cleaned_urls = []
|
cleaned_urls = []
|
||||||
for url in urls:
|
for url in urls:
|
||||||
url = re.sub(r'[.,;:!?\)]+$', '', url)
|
# Strip trailing sentence punctuation, but keep a balanced ')' so URLs
|
||||||
|
# that legitimately end in one are preserved, e.g. the Wikipedia link
|
||||||
|
# ".../Python_(programming_language)". A ')' is only dropped when it is
|
||||||
|
# unbalanced (more ')' than '('), which is the prose-glued case such as
|
||||||
|
# "(see https://example.com)".
|
||||||
|
url = re.sub(r'[.,;:!?]+$', '', url)
|
||||||
|
while url.endswith(')') and url.count(')') > url.count('('):
|
||||||
|
url = re.sub(r'[.,;:!?]+$', '', url[:-1])
|
||||||
cleaned_urls.append(url)
|
cleaned_urls.append(url)
|
||||||
return cleaned_urls
|
return cleaned_urls
|
||||||
|
|
||||||
@@ -201,12 +210,13 @@ def validate_file_upload(file: UploadFile) -> UploadFile:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_size > 10 * 1024 * 1024:
|
upload_limit = get_chat_upload_max_bytes()
|
||||||
|
if file_size > upload_limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail={
|
detail={
|
||||||
"error": "FILE_TOO_LARGE",
|
"error": "FILE_TOO_LARGE",
|
||||||
"message": "File size exceeds 10MB limit"
|
"message": f"File size exceeds {format_byte_limit(upload_limit)} limit"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
|
|||||||
@@ -0,0 +1,311 @@
|
|||||||
|
"""ChatGPT subscription / Codex backend OAuth helpers.
|
||||||
|
|
||||||
|
This provider is intentionally separate from OpenAI API-key endpoints. It uses
|
||||||
|
OpenAI account OAuth device authorization, stores refresh tokens server-side,
|
||||||
|
and resolves a fresh bearer token at request time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
|
||||||
|
# Base URL of the ChatGPT subscription (Codex) backend. The
# CHATGPT_SUBSCRIPTION_BASE_URL environment variable overrides it; surrounding
# whitespace and trailing slashes are stripped, and an unset or empty value
# falls back to the hard-coded endpoint below.
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
    os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
    or "https://chatgpt.com/backend-api/codex"
)
# Value matched against ProviderAuthSession.provider when loading credentials,
# and echoed back as "provider" in the resolved runtime credentials.
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
# OAuth client id sent with the device-code, token-exchange and refresh requests.
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
# Endpoint used for both the authorization-code exchange and token refresh.
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
# Base of the device-authorization endpoints and of the redirect URI.
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
# Default skew for access_token_is_expiring: a token whose "exp" claim falls
# within this many seconds of now is already treated as expiring.
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
# Registry of one lock per auth session id, filled lazily by _refresh_lock_for.
# The guard lock is held whenever the registry itself is read or written.
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_lock_for(auth_id: str) -> threading.Lock:
    """Return the refresh lock dedicated to ``auth_id``, creating it on first use.

    The registry is only touched while ``_AUTH_REFRESH_LOCKS_GUARD`` is held,
    so callers asking for the same id always receive the same lock object.
    """
    with _AUTH_REFRESH_LOCKS_GUARD:
        try:
            return _AUTH_REFRESH_LOCKS[auth_id]
        except KeyError:
            created = threading.Lock()
            _AUTH_REFRESH_LOCKS[auth_id] = created
            return created
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionError(RuntimeError):
    """Base error for ChatGPT subscription provider failures.

    The more specific provider errors in this module subclass it. An instance
    that is not one of those subclasses is mapped to HTTP 502 by
    ``to_http_exception``.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
    """Stored OAuth credentials are invalid or expired beyond refresh.

    Raised for HTTP 401/403 responses and for OAuth error codes such as
    ``invalid_grant``, when a token response carries no access token, and
    when no refresh token is stored.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
    """Upstream quota/rate limit; reconnecting will not fix it.

    Raised for HTTP 429 responses from the OAuth endpoints.
    """
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
    """No matching owner-scoped auth session exists.

    Raised by ``resolve_runtime_credentials`` when the lookup by id and
    provider (plus owner, when one is given) returns no row.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def is_chatgpt_subscription_base(url: str) -> bool:
    """Tell whether ``url`` points at the ChatGPT Codex backend.

    The host must be exactly ``chatgpt.com`` (case-insensitive, a trailing dot
    is tolerated) and the path must be ``/backend-api/codex`` or something
    beneath it. Anything that cannot be parsed counts as "no".
    """
    from urllib.parse import urlparse

    try:
        parts = urlparse(url or "")
        hostname = (parts.hostname or "").lower().rstrip(".")
        route = (parts.path or "").rstrip("/")
    except Exception:
        return False
    if hostname != "chatgpt.com":
        return False
    prefix = "/backend-api/codex"
    return route == prefix or route.startswith(prefix + "/")
|
||||||
|
|
||||||
|
|
||||||
|
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
    """Build the request headers for the ChatGPT backend.

    The fixed headers are always present; an ``Authorization: Bearer ...``
    entry is appended only when ``access_token`` is non-empty.
    """
    bearer = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    return {
        "Accept": "application/json, text/event-stream",
        "Origin": "https://chatgpt.com",
        "Referer": "https://chatgpt.com/codex",
        "User-Agent": "Odysseus ChatGPT Subscription",
        **bearer,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
    """Return the visible Codex model slugs for this token, best-ranked first.

    Queries the backend's model listing and returns a de-duplicated list of
    slugs ordered by ascending ``priority`` (ties broken alphabetically).
    Entries without a usable numeric priority sort last, and entries whose
    ``visibility`` is ``hide``/``hidden`` are dropped.

    This helper is best-effort: an empty token, a transport error, a non-200
    status, undecodable JSON or an unexpected payload shape all yield ``[]``
    instead of an exception.
    """
    if not access_token:
        return []
    try:
        response = httpx.get(
            "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
            headers=chatgpt_headers(access_token),
            timeout=timeout,
        )
        if response.status_code != 200:
            return []
        data = response.json()
    except Exception:
        return []
    entries = data.get("models", []) if isinstance(data, dict) else []
    if not isinstance(entries, list):
        # e.g. {"models": null} or a scalar: iterating it would raise (or walk
        # dict keys), which would break the "never raises" contract above.
        return []
    sortable: list[tuple[int, str]] = []
    for item in entries:
        if not isinstance(item, dict):
            continue
        slug = item.get("slug")
        if not isinstance(slug, str) or not slug.strip():
            continue
        visibility = item.get("visibility", "")
        if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
            continue
        priority = item.get("priority")
        try:
            rank = int(priority) if isinstance(priority, (int, float)) else 10_000
        except (ValueError, OverflowError):
            # Python's json module parses NaN / Infinity literals into floats
            # that int() rejects; rank such entries like "no priority".
            rank = 10_000
        sortable.append((rank, slug.strip()))
    # Tuples compare by rank first, then slug, which is the intended order.
    sortable.sort()
    # dict.fromkeys keeps the first occurrence of each slug, in sorted order.
    return list(dict.fromkeys(slug for _, slug in sortable))
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
    """Raise the matching provider error for a failed OAuth ``response``.

    Statuses below 400 return silently. Otherwise the JSON body, when there is
    one, is mined for an error code and a human-readable message (both the
    nested ``{"error": {...}}`` and the flat ``{"error": "...",
    "error_description": "..."}`` shapes are understood). Then:

    * HTTP 429 raises ``ChatGPTSubscriptionRateLimited``;
    * HTTP 401/403, or a code that signals dead credentials, raises
      ``ChatGPTSubscriptionReauthRequired``;
    * anything else raises ``ChatGPTSubscriptionError``.
    """
    status = response.status_code
    if status < 400:
        return

    error_code = ""
    summary = f"ChatGPT Subscription {action} failed with HTTP {status}."
    try:
        body = response.json()
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            error_code = str(error.get("code") or error.get("type") or "").strip()
            detail = error.get("message")
            if detail:
                summary = f"ChatGPT Subscription {action} failed: {detail}"
        elif isinstance(error, str):
            error_code = error.strip()
            detail = body.get("error_description") or body.get("message")
            if detail:
                summary = f"ChatGPT Subscription {action} failed: {detail}"
    except Exception:
        # A body that is missing or not JSON just leaves the generic summary.
        pass

    # The rate-limit check comes first so a 429 is never reported as a
    # credential problem, whatever error code the body carried.
    if status == 429:
        raise ChatGPTSubscriptionRateLimited(
            "ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
        )
    reauth_codes = {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}
    if status in (401, 403) or error_code in reauth_codes:
        raise ChatGPTSubscriptionReauthRequired(summary)
    raise ChatGPTSubscriptionError(summary)
|
||||||
|
|
||||||
|
|
||||||
|
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
    """Return the JSON object carried by ``response`` or raise a provider error.

    Error statuses are turned into exceptions by ``_raise_for_oauth_response``
    first. A body that is not valid JSON, or whose top level is not an object,
    raises ``ChatGPTSubscriptionError``.
    """
    _raise_for_oauth_response(response, action)
    try:
        parsed = response.json()
    except Exception as exc:
        raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
    if isinstance(parsed, dict):
        return parsed
    raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
|
||||||
|
|
||||||
|
|
||||||
|
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
    """Start a device-authorization flow and return the issuer's response.

    The response must contain ``device_auth_id`` and ``user_code``; otherwise
    ``ChatGPTSubscriptionError`` is raised. ``verification_uri``, ``interval``
    and ``expires_in`` are filled with defaults when the issuer omits them.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode"
    response = httpx.post(
        endpoint,
        json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    payload = _json_or_error(response, "device-code request")
    if not (payload.get("device_auth_id") and payload.get("user_code")):
        raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
    fallbacks = (
        ("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device"),
        ("interval", 5),
        ("expires_in", 900),
    )
    for key, value in fallbacks:
        payload.setdefault(key, value)
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Poll the issuer once for the outcome of a device authorization.

    HTTP 403 and 404 are reported as a "pending" result rather than as
    failures. Any other response is returned as its JSON object, or raised as
    a provider error by ``_json_or_error``.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token"
    response = httpx.post(
        endpoint,
        json={"device_auth_id": device_auth_id, "user_code": user_code},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    still_pending = response.status_code in (403, 404)
    if not still_pending:
        return _json_or_error(response, "device-code poll")
    return {"status": "pending", "error": "authorization_pending"}
|
||||||
|
|
||||||
|
|
||||||
|
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Trade an authorization code (plus its PKCE verifier) for OAuth tokens.

    Sends a form-encoded ``authorization_code`` grant to the token endpoint and
    returns the decoded token response. A response without ``access_token``
    raises ``ChatGPTSubscriptionReauthRequired``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": authorization_code,
        "redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
        "code_verifier": code_verifier,
    }
    response = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(response, "token exchange")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
    """Obtain a fresh token set using ``refresh_token``.

    ``access_token`` is accepted for signature compatibility but is not used.
    An empty refresh token, or a response without ``access_token``, raises
    ``ChatGPTSubscriptionReauthRequired``.
    """
    # The current access token plays no part in a refresh grant.
    del access_token
    if not refresh_token:
        raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
    }
    response = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(response, "token refresh")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
    """Decode the payload (second segment) of a JWT without verifying it.

    Returns the claims when they decode to a JSON object and ``{}`` for any
    other JSON value. Raises ``ValueError`` when the token has fewer than two
    dot-separated segments; base64/JSON decoding errors propagate unchanged.
    """
    pieces = (token or "").split(".")
    if len(pieces) < 2:
        raise ValueError("not a JWT")
    body = pieces[1]
    # JWT segments drop base64 padding; restore it before decoding.
    missing_padding = -len(body) % 4
    decoded = base64.urlsafe_b64decode((body + "=" * missing_padding).encode("ascii"))
    claims = json.loads(decoded.decode("utf-8"))
    if isinstance(claims, dict):
        return claims
    return {}
|
||||||
|
|
||||||
|
|
||||||
|
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
    """Report whether ``access_token`` expires within ``skew_seconds`` of now.

    A token that cannot be decoded, or whose ``exp`` claim is missing or not
    convertible to an integer, counts as expiring.
    """
    try:
        claims = _decode_jwt_payload(access_token)
        expires_at = int(claims.get("exp") or 0)
    except Exception:
        return True
    deadline = int(time.time()) + int(skew_seconds)
    return expires_at <= deadline
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
    """Load the stored ChatGPT subscription credentials and return a usable token.

    Looks up the ``ProviderAuthSession`` row for ``auth_id`` and, when the
    stored access token is expiring (or ``force_refresh`` is set), refreshes it
    via ``refresh_oauth_tokens`` and persists the new tokens before returning.

    Args:
        auth_id: Primary key of the stored auth session.
        owner: When truthy, the row must also belong to this owner.
        force_refresh: Refresh even if the stored token is not expiring.

    Returns:
        A dict with ``provider``, ``base_url``, ``api_key`` (the access token)
        and ``auth_mode``.

    Raises:
        ChatGPTSubscriptionAuthNotFound: No matching row exists.
        ChatGPTSubscriptionError: Propagated from the refresh request
            (including its reauth / rate-limit subclasses).
    """
    db = SessionLocal()
    try:
        q = db.query(ProviderAuthSession).filter(
            ProviderAuthSession.id == auth_id,
            ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
        )
        # NOTE(review): a falsy owner (None / "") skips owner scoping entirely,
        # so the lookup is by id + provider only - confirm callers intend that.
        if owner:
            q = q.filter(ProviderAuthSession.owner == owner)
        row = q.first()
        if row is None:
            raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")

        access_token = row.access_token or ""
        if force_refresh or access_token_is_expiring(access_token):
            # Serialize refreshes for this auth session within the process.
            with _refresh_lock_for(auth_id):
                # Re-read the row after acquiring the lock: another thread may
                # have refreshed and committed new tokens while we waited.
                db.refresh(row)
                access_token = row.access_token or ""
                refresh_token = row.refresh_token or ""
                # Second check on the re-read token; with force_refresh set
                # the refresh always proceeds.
                if force_refresh or access_token_is_expiring(access_token):
                    refreshed = refresh_oauth_tokens(access_token, refresh_token)
                    row.access_token = refreshed["access_token"]
                    # Keep the old refresh token when the response omits one.
                    if refreshed.get("refresh_token"):
                        row.refresh_token = refreshed["refresh_token"]
                    row.last_refresh = utcnow_naive()
                    db.commit()
                    db.refresh(row)
                    access_token = row.access_token or ""

        return {
            "provider": CHATGPT_SUBSCRIPTION_PROVIDER,
            # A per-row base_url wins over the module default.
            "base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
            "api_key": access_token,
            "auth_mode": row.auth_mode or "chatgpt",
        }
    finally:
        # Always release the session, including on lookup or refresh errors.
        db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def to_http_exception(exc: Exception) -> HTTPException:
    """Translate a provider error into the ``HTTPException`` to send to clients.

    * ``ChatGPTSubscriptionRateLimited`` -> 429 with the error message.
    * ``ChatGPTSubscriptionReauthRequired`` / ``ChatGPTSubscriptionAuthNotFound``
      -> 401 with the message followed by a "Reconnect the provider." hint.
    * anything else -> 502 with the error message.

    The exception is returned, not raised; the caller decides when to raise.
    """
    if isinstance(exc, ChatGPTSubscriptionRateLimited):
        return HTTPException(429, str(exc))
    if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
        detail = str(exc)
        hint = "Reconnect the provider."
        # Some messages (e.g. the missing-refresh-token error) already end with
        # the hint; appending it again would repeat the sentence to the user.
        if not detail.rstrip().endswith(hint):
            detail = f"{detail} {hint}"
        return HTTPException(401, detail)
    return HTTPException(502, str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
def build_responses_input(messages: list[dict]) -> list[dict]:
    """Convert chat-style messages into Responses-API style input items.

    Each message becomes ``{"role": ..., "content": [{"type": ..., "text": ...}]}``:

    * a missing/empty role defaults to ``"user"`` and ``"tool"`` is sent as
      ``"user"``;
    * list content is flattened by joining each dict part's ``text`` (or
      ``content``) with newlines, ignoring non-dict parts;
    * ``None`` content becomes an empty string, anything else is stringified;
    * assistant text is typed ``output_text``, everything else ``input_text``.
    """

    def _flatten(content) -> str:
        # Collapse structured content into one plain-text string.
        if isinstance(content, list):
            chunks = [
                str(part.get("text") or part.get("content") or "")
                for part in content
                if isinstance(part, dict)
            ]
            return "\n".join(chunks)
        if content is None:
            return ""
        return str(content)

    def _convert(message: dict) -> dict:
        # Map one chat message to one Responses input item.
        role = message.get("role") or "user"
        if role == "tool":
            role = "user"
        text = _flatten(message.get("content"))
        kind = "output_text" if role == "assistant" else "input_text"
        return {"role": role, "content": [{"type": kind, "text": text}]}

    return [_convert(message) for message in (messages or [])]
|
||||||
+10
-8
@@ -4,6 +4,8 @@ from typing import List, Optional
|
|||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
from src.constants import DATA_DIR as _DATA_DIR_CONST
|
||||||
|
|
||||||
# Cross-platform OS flag, exposed here so callers can `from src.config import
|
# Cross-platform OS flag, exposed here so callers can `from src.config import
|
||||||
# IS_WINDOWS`. Defined locally (a trivial `os.name == "nt"`) rather than imported
|
# IS_WINDOWS`. Defined locally (a trivial `os.name == "nt"`) rather than imported
|
||||||
# from core.platform_compat, to keep this dependency-light config module from
|
# from core.platform_compat, to keep this dependency-light config module from
|
||||||
@@ -20,13 +22,13 @@ class DataConfig(BaseSettings):
|
|||||||
base_dir: Path = Field(default=Path(__file__).parent.parent, description="Base directory for the application")
|
base_dir: Path = Field(default=Path(__file__).parent.parent, description="Base directory for the application")
|
||||||
|
|
||||||
# Data paths
|
# Data paths
|
||||||
data_dir: Path = Field(default=Path("data"), description="Main data directory")
|
data_dir: Path = Field(default=Path(_DATA_DIR_CONST), description="Main data directory")
|
||||||
uploads_dir: Path = Field(default=Path("data/uploads"), description="Directory for uploaded files")
|
uploads_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "uploads", description="Directory for uploaded files")
|
||||||
sessions_file: Path = Field(default=Path("data/sessions.json"), description="Sessions storage file")
|
sessions_file: Path = Field(default=Path(_DATA_DIR_CONST) / "sessions.json", description="Sessions storage file")
|
||||||
memory_file: Path = Field(default=Path("data/memory.json"), description="Memory storage file")
|
memory_file: Path = Field(default=Path(_DATA_DIR_CONST) / "memory.json", description="Memory storage file")
|
||||||
memory_doc: Path = Field(default=Path("data/memory_doc.md"), description="Memory document file")
|
memory_doc: Path = Field(default=Path(_DATA_DIR_CONST) / "memory_doc.md", description="Memory document file")
|
||||||
personal_dir: Path = Field(default=Path("data/personal_docs"), description="Personal documents directory")
|
personal_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "personal_docs", description="Personal documents directory")
|
||||||
runbook_dir: Path = Field(default=Path("data/personal_docs/runbook"), description="Runbook directory")
|
runbook_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "personal_docs" / "runbook", description="Runbook directory")
|
||||||
|
|
||||||
# Upload settings
|
# Upload settings
|
||||||
max_upload_size: int = Field(default=10 * 1024 * 1024, description="Maximum upload size in bytes (10MB)")
|
max_upload_size: int = Field(default=10 * 1024 * 1024, description="Maximum upload size in bytes (10MB)")
|
||||||
@@ -139,7 +141,7 @@ class AppConfig(BaseSettings):
|
|||||||
base_dir = Path(__file__).parent.parent
|
base_dir = Path(__file__).parent.parent
|
||||||
|
|
||||||
# Convert string paths to Path objects relative to base_dir
|
# Convert string paths to Path objects relative to base_dir
|
||||||
data_dir = base_dir / "data"
|
data_dir = Path(_DATA_DIR_CONST)
|
||||||
|
|
||||||
# Get values from the input dict or use defaults
|
# Get values from the input dict or use defaults
|
||||||
max_upload_size = v.get("max_upload_size", 10 * 1024 * 1024) if isinstance(v, dict) else 10 * 1024 * 1024
|
max_upload_size = v.get("max_upload_size", 10 * 1024 * 1024) if isinstance(v, dict) else 10 * 1024 * 1024
|
||||||
|
|||||||
+65
-2
@@ -7,9 +7,12 @@ APP_VERSION = "1.0.0"
|
|||||||
# Base paths
|
# Base paths
|
||||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
||||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
||||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
# Root of all persisted data. ODYSSEUS_DATA_DIR overrides the default
# <BASE_DIR>/data. An empty value (e.g. a blank entry in .env) is treated as
# unset: os.getenv() would otherwise return "" and every path joined onto
# DATA_DIR below would silently become relative to the current working dir.
DATA_DIR = os.getenv("ODYSSEUS_DATA_DIR") or os.path.join(BASE_DIR, "data")
|
||||||
|
|
||||||
# Data file paths
|
# Data file paths
|
||||||
|
# Single source of truth: every persisted file/dir lives under DATA_DIR, which
|
||||||
|
# is the ONLY place ODYSSEUS_DATA_DIR is read. Import these constants instead of
|
||||||
|
# re-deriving paths from __file__ or a relative "data" literal.
|
||||||
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
||||||
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
||||||
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
||||||
@@ -18,6 +21,47 @@ RUNBOOK_DIR = os.path.join(PERSONAL_DIR, "runbook")
|
|||||||
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
||||||
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
||||||
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
||||||
|
AUTH_FILE = os.path.join(DATA_DIR, "auth.json")
|
||||||
|
USER_PREFS_FILE = os.path.join(DATA_DIR, "user_prefs.json")
|
||||||
|
PRESETS_FILE = os.path.join(DATA_DIR, "presets.json")
|
||||||
|
INTEGRATIONS_FILE = os.path.join(DATA_DIR, "integrations.json")
|
||||||
|
CONTACTS_FILE = os.path.join(DATA_DIR, "contacts.json")
|
||||||
|
APP_KEY_FILE = os.path.join(DATA_DIR, ".app_key")
|
||||||
|
EMBEDDING_ENDPOINT_FILE = os.path.join(DATA_DIR, "embedding_endpoint.json")
|
||||||
|
COOKBOOK_STATE_FILE = os.path.join(DATA_DIR, "cookbook_state.json")
|
||||||
|
BG_JOBS_FILE = os.path.join(DATA_DIR, "bg_jobs.json")
|
||||||
|
VAULT_FILE = os.path.join(DATA_DIR, "vault.json")
|
||||||
|
TIDY_CALENDAR_STATE_FILE = os.path.join(DATA_DIR, "tidy_calendar_state.json")
|
||||||
|
SKILLS_FILE = os.path.join(DATA_DIR, "skills.json")
|
||||||
|
APP_DB = os.path.join(DATA_DIR, "app.db")
|
||||||
|
SCHEDULED_EMAILS_DB = os.path.join(DATA_DIR, "scheduled_emails.db")
|
||||||
|
EMAIL_CACHE_DB = os.path.join(DATA_DIR, "email_cache.db")
|
||||||
|
|
||||||
|
# Data subdirectories
|
||||||
|
PERSONAL_UPLOADS_DIR = os.path.join(DATA_DIR, "personal_uploads")
|
||||||
|
EMOJI_CACHE_DIR = os.path.join(DATA_DIR, "emoji_cache")
|
||||||
|
RAG_DIR = os.path.join(DATA_DIR, "rag")
|
||||||
|
CHROMA_DIR = os.path.join(DATA_DIR, "chroma")
|
||||||
|
BG_JOBS_DIR = os.path.join(DATA_DIR, "bg_jobs")
|
||||||
|
DEEP_RESEARCH_DIR = os.path.join(DATA_DIR, "deep_research")
|
||||||
|
MCP_OAUTH_DIR = os.path.join(DATA_DIR, "mcp_oauth")
|
||||||
|
GENERATED_IMAGES_DIR = os.path.join(DATA_DIR, "generated_images")
|
||||||
|
TTS_CACHE_DIR = os.path.join(DATA_DIR, "tts_cache")
|
||||||
|
EMAIL_URGENCY_CACHE_DIR = os.path.join(DATA_DIR, "email_urgency_cache")
|
||||||
|
SKILLS_DIR = os.path.join(DATA_DIR, "skills")
|
||||||
|
GALLERY_DIR = os.path.join(DATA_DIR, "gallery")
|
||||||
|
GALLERY_UPLOADS_DIR = os.path.join(DATA_DIR, "gallery_uploads")
|
||||||
|
MEMORY_VECTORS_DIR = os.path.join(DATA_DIR, "memory_vectors")
|
||||||
|
|
||||||
|
# Paths with an intentional dedicated env override, defaulting under DATA_DIR.
|
||||||
|
MAIL_ATTACHMENTS_DIR = os.getenv("ODYSSEUS_MAIL_ATTACHMENTS_DIR", os.path.join(DATA_DIR, "mail-attachments"))
|
||||||
|
FASTEMBED_CACHE_DIR = os.getenv("FASTEMBED_CACHE_PATH", os.path.join(DATA_DIR, "fastembed_cache"))
|
||||||
|
|
||||||
|
# Agent tool output limits (single source of truth — imported by tool_execution.py,
|
||||||
|
# tool_implementations.py, agent_tools.py, and any other module that needs them)
|
||||||
|
MAX_OUTPUT_CHARS = 10_000 # cap for bash/python/web_search/web_fetch output
|
||||||
|
MAX_READ_CHARS = 20_000 # cap for read_file / document preview
|
||||||
|
MAX_DIFF_LINES = 400 # cap for edit_file unified-diff display
|
||||||
|
|
||||||
# API Configuration
|
# API Configuration
|
||||||
MAX_CONTEXT_MESSAGES = 90
|
MAX_CONTEXT_MESSAGES = 90
|
||||||
@@ -28,7 +72,7 @@ OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
|||||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
||||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
||||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||||
SEARXNG_INSTANCE = os.getenv('SEARXNG_INSTANCE', 'http://localhost:8080')
|
SEARXNG_INSTANCE = os.getenv("SEARXNG_INSTANCE", "http://localhost:8080")
|
||||||
|
|
||||||
|
|
||||||
# Cleanup configuration
|
# Cleanup configuration
|
||||||
@@ -38,3 +82,22 @@ CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "24"))
|
|||||||
# Default parameters
|
# Default parameters
|
||||||
DEFAULT_TEMPERATURE = 1.0
|
DEFAULT_TEMPERATURE = 1.0
|
||||||
DEFAULT_MAX_TOKENS = 0
|
DEFAULT_MAX_TOKENS = 0
|
||||||
|
|
||||||
|
|
||||||
|
def internal_api_base() -> str:
    """Return the base URL for in-process loopback calls to Odysseus's own API.

    Agent tools and background jobs reach admin-gated routes by calling the
    running server over HTTP. Resolution order:
      1. ODYSSEUS_INTERNAL_BASE - explicit override (e.g. behind a TLS proxy).
      2. APP_PORT - http://127.0.0.1:$APP_PORT (docker-compose).
      3. Fallback http://127.0.0.1:7000 - legacy default.

    127.0.0.1 (not "localhost") avoids IPv6/DNS ambiguity for a strictly-local
    call. Without this, loopback tools fail with "All connection attempts
    failed" whenever the server is not on port 7000.

    Blank (empty or whitespace-only) values of either variable are treated as
    unset, so an empty ``APP_PORT=`` entry cannot produce the malformed URL
    ``http://127.0.0.1:``.

    Returns:
        The base URL with no trailing slash.
    """
    override = os.environ.get("ODYSSEUS_INTERNAL_BASE", "").strip()
    if override:
        # Callers append paths beginning with "/", so drop any trailing slash.
        return override.rstrip("/")

    # `or` (rather than a .get() default) also covers a set-but-empty APP_PORT.
    port = os.environ.get("APP_PORT", "").strip() or "7000"
    return f"http://127.0.0.1:{port}"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user