mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-20 03:35:35 -04:00
Merge branch 'dev'
# Conflicts: # routes/task_routes.py # src/caldav_sync.py
This commit is contained in:
@@ -56,6 +56,13 @@ SEARXNG_INSTANCE=http://localhost:8080
|
||||
# SQLite database path (default: sqlite:///./data/app.db)
|
||||
# DATABASE_URL=sqlite:///./data/app.db
|
||||
|
||||
# ============================================================
|
||||
# Data directory
|
||||
# ============================================================
|
||||
# Move everything that lives under data/ - settings, sessions, database, auth,
|
||||
# cache, uploads, etc. - to another path:
|
||||
# ODYSSEUS_DATA_DIR=C:\path\to\dir
|
||||
|
||||
# ============================================================
|
||||
# Auth & Security
|
||||
# ============================================================
|
||||
@@ -112,6 +119,9 @@ SEARXNG_INSTANCE=http://localhost:8080
|
||||
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
||||
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
||||
|
||||
# Embedding API key (if there's one)
|
||||
# EMBEDDING_API_KEY=embedding_api_key_here
|
||||
|
||||
# Embedding model name (must be available at the endpoint above)
|
||||
# EMBEDDING_MODEL=all-minilm:l6-v2
|
||||
|
||||
@@ -144,6 +154,21 @@ SEARXNG_INSTANCE=http://localhost:8080
|
||||
# if you intentionally want scheduled scripts to run remotely.
|
||||
# ODYSSEUS_SCRIPT_HOST=localhost
|
||||
|
||||
# Chat / agent attachment size cap in bytes (default: 10 MB).
|
||||
# Raise this for local installs that need larger PDFs or text documents.
|
||||
# Example: 52428800 = 50 MB.
|
||||
# ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=10485760
|
||||
|
||||
# Other per-feature upload size caps in bytes. All are validated and optional;
|
||||
# defaults shown. An invalid value (non-integer or < 1) fails fast at startup.
|
||||
# ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES=104857600 # gallery image upload (100 MB)
|
||||
# ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES=26214400 # gallery transform input (25 MB)
|
||||
# ODYSSEUS_MEMORY_IMPORT_MAX_BYTES=10485760 # memory import file (10 MB)
|
||||
# ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES=26214400 # personal document upload (25 MB)
|
||||
# ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES=26214400 # email compose attachment (25 MB)
|
||||
# ODYSSEUS_STT_MAX_AUDIO_BYTES=26214400 # speech-to-text audio (25 MB)
|
||||
# ODYSSEUS_ICS_MAX_BYTES=10485760 # calendar .ics import (10 MB)
|
||||
|
||||
# ============================================================
|
||||
# GPU support (Docker Compose)
|
||||
# ============================================================
|
||||
|
||||
@@ -23,7 +23,7 @@ body:
|
||||
required: true
|
||||
- label: This is **not** a security vulnerability. (Vulnerabilities go to [GitHub Security Advisories](https://github.com/pewdiepie-archdaemon/odysseus/security/advisories/new) — see [SECURITY.md](https://github.com/pewdiepie-archdaemon/odysseus/blob/main/SECURITY.md).)
|
||||
required: true
|
||||
- label: I am running the latest code from `main`.
|
||||
- label: I am running the latest code from the `dev` branch (the default branch you get on clone, where fixes land first) and the bug still reproduces there. Please `git pull` the latest `dev` before filing.
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
|
||||
@@ -103,14 +103,21 @@ module.exports = async ({ github, context, core }) => {
|
||||
|
||||
async function swapLabel(num, add, remove) {
|
||||
if (await labelExists(add)) {
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
||||
try {
|
||||
await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: [add] });
|
||||
} catch (e) {
|
||||
// Fail soft on a token that can't write labels so a label permission
|
||||
// problem never masks the actual description verdict.
|
||||
if (e.status !== 403) throw e;
|
||||
core.warning(`Could not add "${add}" — token lacks label write here; skipping.`);
|
||||
}
|
||||
} else {
|
||||
core.warning(`Label "${add}" does not exist in the repo — skipping. Create it once to enable labelling.`);
|
||||
}
|
||||
try {
|
||||
await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name: remove });
|
||||
} catch (e) {
|
||||
if (e.status !== 404 && e.status !== 410) throw e;
|
||||
if (e.status !== 404 && e.status !== 410 && e.status !== 403) throw e;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
@@ -31,6 +33,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4
|
||||
with:
|
||||
node-version: "20"
|
||||
@@ -51,10 +55,40 @@ jobs:
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
# Detect whether this PR only touches documentation files.
|
||||
# If so, skip the expensive pytest run while still reporting a passing check.
|
||||
- name: Check for docs-only changes
|
||||
id: docs-check
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
BASE="${{ github.event.pull_request.base.sha }}"
|
||||
HEAD="${{ github.event.pull_request.head.sha }}"
|
||||
else
|
||||
BASE="${{ github.event.before }}"
|
||||
HEAD="${{ github.sha }}"
|
||||
fi
|
||||
# List all changed files; if every file matches docs/markdown patterns, skip pytest.
|
||||
changed=$(git diff --name-only "$BASE" "$HEAD" 2>/dev/null || git diff --name-only HEAD~1 HEAD)
|
||||
non_docs=$(echo "$changed" | grep -Ev '^(docs/|.*\.md$|\.github/[^/]+\.md$)' || true)
|
||||
if [ -z "$non_docs" ]; then
|
||||
echo "docs_only=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Docs-only change detected — skipping pytest."
|
||||
else
|
||||
echo "docs_only=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
if: steps.docs-check.outputs.docs_only != 'true'
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: pip
|
||||
- run: pip install -r requirements.txt
|
||||
if: steps.docs-check.outputs.docs_only != 'true'
|
||||
- run: mkdir -p data # sqlite DB lives at ./data/app.db
|
||||
if: steps.docs-check.outputs.docs_only != 'true'
|
||||
- run: python -m pytest -q
|
||||
if: steps.docs-check.outputs.docs_only != 'true'
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
name: ci / docker publish
|
||||
|
||||
# Build the Odysseus image and publish to GHCR.
|
||||
# push to main -> :latest, :X.Y.Z (curated release; main is fast-forwarded at releases)
|
||||
# push to dev -> :dev, :X.Y.Z-dev.<sha> (rolling dev + an immutable, traceable pin)
|
||||
# Multi-arch (linux/amd64 + linux/arm64): each arch builds on its own native
|
||||
# runner and pushes by digest, then a merge job stitches the digests into one
|
||||
# manifest list and applies the tags (faster + cleaner than QEMU emulation).
|
||||
# Registry: ghcr.io/<owner>/<repo>.
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [dev, main]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'docs/**'
|
||||
- '.github/ISSUE_TEMPLATE/**'
|
||||
|
||||
concurrency:
|
||||
group: docker-publish-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: build (${{ matrix.arch }})
|
||||
runs-on: ${{ matrix.runner }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: linux/amd64
|
||||
arch: amd64
|
||||
runner: ubuntu-latest
|
||||
- platform: linux/arm64
|
||||
arch: arm64
|
||||
runner: ubuntu-24.04-arm
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Set up Buildx
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7.2.0
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platform }}
|
||||
outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=${{ matrix.arch }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.arch }}
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: digest-${{ matrix.arch }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
merge:
|
||||
name: merge manifest + tag
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Read APP_VERSION + short sha
|
||||
id: ver
|
||||
run: |
|
||||
v=$(grep -E '^APP_VERSION' src/constants.py | head -1 | sed -E 's/.*"([^"]+)".*/\1/')
|
||||
[ -n "$v" ] || { echo "APP_VERSION not found"; exit 1; }
|
||||
echo "version=$v" >> "$GITHUB_OUTPUT"
|
||||
echo "short=${GITHUB_SHA::7}" >> "$GITHUB_OUTPUT"
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digest-*
|
||||
merge-multiple: true
|
||||
- name: Set up Buildx
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||
- name: Log in to GHCR
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Compute tags
|
||||
id: meta
|
||||
uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6.1.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=${{ steps.ver.outputs.version }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=raw,value=dev,enable=${{ github.ref == 'refs/heads/dev' }}
|
||||
type=raw,value=${{ steps.ver.outputs.version }}-dev.${{ steps.ver.outputs.short }},enable=${{ github.ref == 'refs/heads/dev' }}
|
||||
- name: Create manifest list + push tags
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
tags=$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON")
|
||||
digests=$(printf "${REGISTRY}/${IMAGE_NAME}@sha256:%s " *)
|
||||
# word-splitting is intended: $tags and $digests each expand to multiple args
|
||||
# shellcheck disable=SC2086
|
||||
docker buildx imagetools create $tags $digests
|
||||
env:
|
||||
REGISTRY: ${{ env.REGISTRY }}
|
||||
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||
- name: Inspect
|
||||
run: |
|
||||
if [ "$GITHUB_REF" = "refs/heads/main" ]; then ref=latest; else ref=dev; fi
|
||||
docker buildx imagetools inspect "${REGISTRY}/${IMAGE_NAME}:${ref}"
|
||||
env:
|
||||
REGISTRY: ${{ env.REGISTRY }}
|
||||
IMAGE_NAME: ${{ env.IMAGE_NAME }}
|
||||
@@ -14,10 +14,11 @@ jobs:
|
||||
# Skip bots (Dependabot, release-drafter, etc.)
|
||||
if: ${{ github.event.issue.user.type != 'Bot' }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
sparse-checkout: .github/scripts
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions/github-script@v7
|
||||
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: return require('./.github/scripts/check-issue-description.js')({github, context, core})
|
||||
|
||||
@@ -1,28 +1,109 @@
|
||||
name: ci / PR description check
|
||||
name: ci / PR checks
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
# pull_request_target runs in the base-repo context (has secrets) so the check
|
||||
# works on fork PRs. Safe here: the checkout pins to the base branch (no fork
|
||||
# code runs) and the scripts only read context.payload and call the GitHub API.
|
||||
pull_request_target: # zizmor: ignore[dangerous-triggers]
|
||||
types: [opened, edited, synchronize, reopened, ready_for_review]
|
||||
|
||||
# pull_request_target runs in the base-repo context (has secrets).
|
||||
# The checkout below pins to the base branch so no fork code is executed.
|
||||
# The script only reads context.payload and calls the GitHub API.
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
# Default-deny at the workflow level; each job opts into only the scopes it needs.
|
||||
# Note: modifying a PR's labels/comments needs pull-requests:write even though the
|
||||
# REST path is under /issues/{n}/...; issues:write alone returns 403 on PRs.
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
check-description:
|
||||
name: Check PR description
|
||||
runs-on: ubuntu-latest
|
||||
# Skip bots — they open PRs programmatically and have their own process.
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
# Skip bots: they open PRs programmatically and have their own process.
|
||||
if: github.event.pull_request.user.type != 'Bot'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: ${{ github.base_ref }}
|
||||
sparse-checkout: .github/scripts
|
||||
persist-credentials: false
|
||||
|
||||
- uses: actions/github-script@v7
|
||||
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: return require('./.github/scripts/check-pr-description.js')({github, context, core})
|
||||
|
||||
check-title:
|
||||
name: Check PR title (Conventional Commits)
|
||||
runs-on: ubuntu-latest
|
||||
permissions: {}
|
||||
# Skip bots: they open PRs programmatically and have their own process.
|
||||
if: github.event.pull_request.user.type != 'Bot'
|
||||
steps:
|
||||
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title || "";
|
||||
// Conventional Commits: type(optional-scope)(optional !): summary
|
||||
const re = /^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\([\w .\/-]+\))?!?: .+/;
|
||||
if (!re.test(title)) {
|
||||
core.setFailed(
|
||||
`PR title is not in Conventional Commits format:\n "${title}"\n\n` +
|
||||
`Expected: type(scope): summary\n` +
|
||||
`Example: fix(search): handle empty query\n` +
|
||||
`Types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.`
|
||||
);
|
||||
} else {
|
||||
core.info(`PR title OK: ${title}`);
|
||||
}
|
||||
|
||||
check-mergeable:
|
||||
name: Flag unmergeable PRs
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
issues: write
|
||||
# Skip bots: they open PRs programmatically and have their own process.
|
||||
if: github.event.pull_request.user.type != 'Bot'
|
||||
steps:
|
||||
- uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const repo = { owner: context.repo.owner, repo: context.repo.repo };
|
||||
const number = context.payload.pull_request.number;
|
||||
const READY = "ready for review";
|
||||
const CONFLICT = "merge conflict";
|
||||
|
||||
// Ensure the conflict label exists (red). Ignore if already present.
|
||||
try {
|
||||
await github.rest.issues.getLabel({ ...repo, name: CONFLICT });
|
||||
} catch {
|
||||
await github.rest.issues.createLabel({
|
||||
...repo, name: CONFLICT, color: "B60205",
|
||||
description: "Conflicts with the base branch; needs a rebase before review.",
|
||||
}).catch(() => {});
|
||||
}
|
||||
|
||||
// mergeable is computed asynchronously and is often null right after
|
||||
// an event, so poll a few times until GitHub has resolved it.
|
||||
let pr = null;
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const { data } = await github.rest.pulls.get({ ...repo, pull_number: number });
|
||||
if (data.mergeable !== null) { pr = data; break; }
|
||||
await new Promise(r => setTimeout(r, 3000));
|
||||
}
|
||||
if (!pr || pr.draft) return;
|
||||
const labels = pr.labels.map(l => l.name);
|
||||
|
||||
if (pr.mergeable === false) {
|
||||
if (labels.includes(READY)) {
|
||||
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: READY }).catch(() => {});
|
||||
}
|
||||
if (!labels.includes(CONFLICT)) {
|
||||
await github.rest.issues.addLabels({ ...repo, issue_number: number, labels: [CONFLICT] });
|
||||
}
|
||||
} else if (pr.mergeable === true) {
|
||||
if (labels.includes(CONFLICT)) {
|
||||
await github.rest.issues.removeLabel({ ...repo, issue_number: number, name: CONFLICT }).catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,18 @@ Before submitting any change that affects what the app looks like — buttons, i
|
||||
|
||||
If you are unsure whether a change is "visual," it is. Default to attaching a screenshot.
|
||||
|
||||
## Code conventions
|
||||
|
||||
Don't hardcode values that the project already exposes through a constant or a helper. Hardcoded literals drift out of sync, break on non-default deployments, and reintroduce bugs we've already fixed.
|
||||
|
||||
- **Filesystem paths:** never build writable paths from `Path(__file__)...` into the source tree, hardcode `/app/...`, or use a relative `"data/..."` string. Every persisted file and directory has a named constant in `src/constants.py` (for example `AUTH_FILE`, `USER_PREFS_FILE`, `SETTINGS_FILE`, `TTS_CACHE_DIR`, `CHROMA_DIR`). Import and use that named constant; do not re-derive the path locally with `os.path.join(DATA_DIR, "x.json")` or `DATA_DIR / "x.json"`. `DATA_DIR` is the single place that reads `ODYSSEUS_DATA_DIR`, so use it directly only for dynamic paths that have no fixed name (for example per-owner files). If a data file or directory has no constant yet, add one to `src/constants.py`. The source tree is read-only in Docker and `/app/...` does not exist on native runs; guard directory creation so an unwritable path degrades gracefully instead of crashing at import.
|
||||
- **Internal API / loopback URLs:** don't hardcode `http://localhost:7000`. Use `internal_api_base()` from `src.constants` (it honors `ODYSSEUS_INTERNAL_BASE` / `APP_PORT`).
|
||||
- **Ports, limits, model lists, and similar:** reuse the existing constant if one exists; if it doesn't and the value is used in more than one place, add a constant rather than copying the literal.
|
||||
|
||||
If you need a value that has no constant or helper yet, add it to `src/constants.py` (the single source of truth for paths and config; `core/constants.py` only re-exports it for backward compatibility) and import it, rather than repeating a literal across files.
|
||||
|
||||
**Commits:** use [Conventional Commits](https://www.conventionalcommits.org), `type(scope): summary` (e.g. `fix(search): ...`, `feat(notes): ...`, `docs(contributing): ...`). Common types: `fix`, `feat`, `refactor`, `docs`, `test`, `chore`, `ci`. Keep the subject short and imperative; put the "why" in the body when it isn't obvious.
|
||||
|
||||
## Issue Reports
|
||||
|
||||
For bugs, include:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Odysseus
|
||||
|
||||
> **Branch note:** `dev` is the default branch and contains the latest development changes, but it may be unstable. For the more stable curated branch, use [`main`](https://github.com/pewdiepie-archdaemon/odysseus/tree/main).
|
||||
|
||||
```
|
||||
───────────────────────────────────────────────
|
||||
⊹ ࣪ ˖ ૮( ˶ᵔ ᵕ ᵔ˶ )っ Odysseus vers. 1.0
|
||||
@@ -331,6 +333,12 @@ To expose Odysseus on a local network or Tailscale with HTTPS:
|
||||
| `PyMuPDF` | PDF page rendering in the side viewer panel and form-filling. (Note: AGPL-3.0) |
|
||||
| `markitdown` | Office/EPUB document text extraction (converts .docx/.xlsx/.pptx/.xls/.epub to Markdown). |
|
||||
|
||||
### Outlook / Office 365 email
|
||||
Odysseus email accounts currently use IMAP/SMTP username-password auth. Outlook
|
||||
and Microsoft 365 generally require OAuth instead, so normal Microsoft mailbox
|
||||
passwords will fail. See [docs/email-outlook.md](docs/email-outlook.md) for the
|
||||
current limitation and the planned integration direction.
|
||||
|
||||
## Security Notes
|
||||
Odysseus is a self-hosted workspace with powerful local tools: shell access, file uploads, model downloads, web research, email/calendar integrations, and API tokens. Treat it like an admin console.
|
||||
|
||||
@@ -394,6 +402,16 @@ Key settings:
|
||||
| `CHROMADB_HOST` | `localhost` | ChromaDB host for vector memory. Docker overrides this to `chromadb`. |
|
||||
| `CHROMADB_PORT` | `8100` | ChromaDB port for manual host runs. Docker overrides this to `8000`. |
|
||||
| `EMBEDDING_URL` | -- | OpenAI-compatible embeddings endpoint |
|
||||
| `ODYSSEUS_CHAT_UPLOAD_MAX_BYTES` | `10485760` | Chat/agent attachment cap in bytes. Raise for larger local PDFs or text documents. |
|
||||
| `ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES` | `104857600` | Gallery image upload cap in bytes (100 MB). |
|
||||
| `ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES` | `26214400` | Gallery transform input cap in bytes (25 MB). |
|
||||
| `ODYSSEUS_MEMORY_IMPORT_MAX_BYTES` | `10485760` | Memory import file cap in bytes (10 MB). |
|
||||
| `ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES` | `26214400` | Personal document upload cap in bytes (25 MB). |
|
||||
| `ODYSSEUS_EMAIL_COMPOSE_UPLOAD_MAX_BYTES` | `26214400` | Email compose attachment cap in bytes (25 MB). |
|
||||
| `ODYSSEUS_STT_MAX_AUDIO_BYTES` | `26214400` | Speech-to-text audio cap in bytes (25 MB). |
|
||||
| `ODYSSEUS_ICS_MAX_BYTES` | `10485760` | Calendar `.ics` import cap in bytes (10 MB). |
|
||||
|
||||
All upload-limit vars are validated (must be a positive integer) and optional; an invalid value fails fast at startup.
|
||||
|
||||
### Built-in MCP servers (optional setup)
|
||||
|
||||
|
||||
@@ -51,10 +51,10 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
# Core imports
|
||||
from core.constants import (
|
||||
BASE_DIR, STATIC_DIR, SESSIONS_FILE,
|
||||
REQUEST_TIMEOUT, OPENAI_API_KEY,
|
||||
REQUEST_TIMEOUT, OPENAI_API_KEY, AUTH_FILE,
|
||||
)
|
||||
from core.database import SessionLocal, ApiToken
|
||||
from core.middleware import SecurityHeadersMiddleware
|
||||
from core.middleware import SecurityHeadersMiddleware, is_cors_preflight
|
||||
from core.auth import AuthManager
|
||||
from core.exceptions import (
|
||||
SessionNotFoundError, InvalidFileUploadError,
|
||||
@@ -64,6 +64,7 @@ from core.exceptions import (
|
||||
import bcrypt as _bcrypt
|
||||
|
||||
from src.app_helpers import abs_join
|
||||
from src.generated_images import GENERATED_IMAGE_HEADERS, resolve_generated_image_path
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
# ========= LOGGING =========
|
||||
@@ -252,6 +253,15 @@ if AUTH_ENABLED:
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
path = request.url.path
|
||||
# A genuine CORS preflight (OPTIONS + Access-Control-Request-Method)
|
||||
# carries no credentials by design and must reach CORSMiddleware to be
|
||||
# answered. AuthMiddleware is the outermost middleware, so gating the
|
||||
# preflight on auth 401s it before CORS can respond -- which blocks
|
||||
# every cross-origin browser/WebView client before the real request
|
||||
# is sent. Let real preflights through (only OPTIONS w/ the ACRM
|
||||
# header; never a credentialed request).
|
||||
if is_cors_preflight(request.method, request.headers):
|
||||
return await call_next(request)
|
||||
if _is_auth_exempt(path):
|
||||
return await call_next(request)
|
||||
# In-process internal-tool token bypass. Used by the agent
|
||||
@@ -387,13 +397,7 @@ app.mount("/static", _RevalidatingStatic(directory="static"), name="static")
|
||||
@app.get("/api/generated-image/{filename}")
|
||||
async def serve_generated_image(filename: str, request: Request):
|
||||
"""Serve generated images from the data directory."""
|
||||
from pathlib import Path
|
||||
import re
|
||||
if not re.match(r'^[a-f0-9]{8,64}\.(png|jpg|jpeg|webp|gif|mp4|mov|webm|mkv|m4v)$', filename):
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
img_path = Path("data/generated_images") / filename
|
||||
if not img_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
img_path = resolve_generated_image_path(filename)
|
||||
# SECURITY: filename is the only key, so anyone who knows / guesses a
|
||||
# 12-hex content hash could pull another user's image bytes. Require
|
||||
# auth and verify ownership via the gallery row (when one exists).
|
||||
@@ -429,7 +433,7 @@ async def serve_generated_image(filename: str, request: Request):
|
||||
return FileResponse(
|
||||
str(img_path),
|
||||
media_type=mime,
|
||||
headers={"Cache-Control": "public, max-age=31536000, immutable"},
|
||||
headers=GENERATED_IMAGE_HEADERS,
|
||||
)
|
||||
|
||||
# ========= YOUTUBE INIT =========
|
||||
@@ -594,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
|
||||
from routes.copilot_routes import setup_copilot_routes
|
||||
app.include_router(setup_copilot_routes())
|
||||
|
||||
# ChatGPT Subscription device-flow login
|
||||
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
|
||||
app.include_router(setup_chatgpt_subscription_routes())
|
||||
|
||||
# TTS
|
||||
from routes.tts_routes import setup_tts_routes
|
||||
app.include_router(setup_tts_routes(tts_service))
|
||||
@@ -789,6 +797,8 @@ async def serve_backgrounds(request: Request):
|
||||
|
||||
@app.get("/login")
|
||||
async def serve_login(request: Request):
|
||||
if not AUTH_ENABLED:
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
return _serve_html_with_nonce(request, abs_join(BASE_DIR, "static/login.html"))
|
||||
|
||||
@app.get("/api/version")
|
||||
@@ -948,7 +958,7 @@ async def _startup_event():
|
||||
owners = set()
|
||||
try:
|
||||
import json as _json
|
||||
auth_path = "data/auth.json"
|
||||
auth_path = AUTH_FILE
|
||||
with open(auth_path, encoding="utf-8") as f:
|
||||
users = _json.load(f).get("users", {})
|
||||
owners.update(users.keys())
|
||||
@@ -995,7 +1005,7 @@ async def _startup_event():
|
||||
# does not make an existing library look empty after auth/account changes.
|
||||
try:
|
||||
import json as _json
|
||||
auth_path = "data/auth.json"
|
||||
auth_path = AUTH_FILE
|
||||
with open(auth_path, encoding="utf-8") as f:
|
||||
users = _json.load(f).get("users", {})
|
||||
primary_owner = None
|
||||
|
||||
@@ -14,6 +14,8 @@ import uuid
|
||||
|
||||
import bcrypt
|
||||
|
||||
from src.constants import AUTH_FILE
|
||||
|
||||
PAIRING_VERSION = 1
|
||||
COMPANION_SCOPE = "chat"
|
||||
|
||||
@@ -61,7 +63,7 @@ def lan_ip_candidates() -> list[str]:
|
||||
def find_admin_user() -> str | None:
|
||||
"""Resolve an admin username from data/auth.json (schema uses is_admin),
|
||||
falling back to the first user."""
|
||||
auth_path = os.path.join("data", "auth.json")
|
||||
auth_path = AUTH_FILE
|
||||
try:
|
||||
with open(auth_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
+92
-62
@@ -30,14 +30,24 @@ DEFAULT_PRIVILEGES = {
|
||||
"can_manage_memory": True,
|
||||
"max_messages_per_day": 0,
|
||||
"allowed_models": [],
|
||||
"allowed_models_restricted": False,
|
||||
# Explicit "block every model" sentinel. An empty `allowed_models` list is
|
||||
# ambiguous — it's also what gets sent when the admin clicks "[All]" — so
|
||||
# we need a dedicated flag to express "this user may use no models at all"
|
||||
# distinctly from "this user has no restriction".
|
||||
"block_all_models": False,
|
||||
}
|
||||
|
||||
# Admins get everything
|
||||
ADMIN_PRIVILEGES = {k: (True if isinstance(v, bool) else (0 if isinstance(v, int) else [])) for k, v in DEFAULT_PRIVILEGES.items()}
|
||||
ADMIN_PRIVILEGES["allowed_models_restricted"] = False
|
||||
# Admins must never be blocked from using models — the generic dict
|
||||
# comprehension above flips every boolean default to True, which would be
|
||||
# backwards for this sentinel.
|
||||
ADMIN_PRIVILEGES["block_all_models"] = False
|
||||
|
||||
DEFAULT_AUTH_PATH = os.path.join(
|
||||
Path(__file__).parent.parent, "data", "auth.json"
|
||||
)
|
||||
from src.constants import AUTH_FILE
|
||||
DEFAULT_AUTH_PATH = AUTH_FILE
|
||||
TOKEN_TTL = 60 * 60 * 24 * 7 # 7 days
|
||||
|
||||
# Usernames the auth + middleware layer reserve as internal "synthetic owner"
|
||||
@@ -76,6 +86,10 @@ class AuthManager:
|
||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||
self._sessions_lock = threading.RLock()
|
||||
# Guards all mutations of self._config and the on-disk auth.json so
|
||||
# concurrent create/delete/rename/privilege operations don't interleave
|
||||
# and corrupt the user database.
|
||||
self._config_lock = threading.Lock()
|
||||
# Guards the first-run setup check-and-write so concurrent requests
|
||||
# cannot both observe is_configured==False and both create admin accounts.
|
||||
self._setup_lock = threading.Lock()
|
||||
@@ -172,8 +186,9 @@ class AuthManager:
|
||||
|
||||
@signup_enabled.setter
|
||||
def signup_enabled(self, value: bool):
|
||||
self._config["signup_enabled"] = value
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["signup_enabled"] = value
|
||||
self._save()
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
@@ -198,17 +213,18 @@ class AuthManager:
|
||||
if username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to create reserved username '%s'", username)
|
||||
return False
|
||||
if username in self.users:
|
||||
return False
|
||||
if "users" not in self._config:
|
||||
self._config["users"] = {}
|
||||
self._config["users"][username] = {
|
||||
"password_hash": _hash_password(password),
|
||||
"created": time.time(),
|
||||
"is_admin": is_admin,
|
||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||
}
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username in self.users:
|
||||
return False
|
||||
if "users" not in self._config:
|
||||
self._config["users"] = {}
|
||||
self._config["users"][username] = {
|
||||
"password_hash": _hash_password(password),
|
||||
"created": time.time(),
|
||||
"is_admin": is_admin,
|
||||
"privileges": dict(ADMIN_PRIVILEGES if is_admin else DEFAULT_PRIVILEGES),
|
||||
}
|
||||
self._save()
|
||||
logger.info(f"Created user '{username}' (admin={is_admin})")
|
||||
return True
|
||||
|
||||
@@ -221,14 +237,15 @@ class AuthManager:
|
||||
their cookie expired naturally (default ~30 days).
|
||||
"""
|
||||
username = username.strip().lower()
|
||||
if username not in self.users:
|
||||
return False
|
||||
if username == requesting_user:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
del self._config["users"][username]
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if username == requesting_user:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
del self._config["users"][username]
|
||||
self._save()
|
||||
# Purge all sessions belonging to this user. validate_token doesn't
|
||||
# cross-check `self.users`, so without this step a deleted user's
|
||||
# cookie keeps authenticating.
|
||||
@@ -266,14 +283,15 @@ class AuthManager:
|
||||
if new_username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||
return False
|
||||
if old_username not in self.users:
|
||||
return False
|
||||
if new_username in self.users:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if old_username not in self.users:
|
||||
return False
|
||||
if new_username in self.users:
|
||||
return False
|
||||
if not self.users.get(requesting_user, {}).get("is_admin"):
|
||||
return False
|
||||
self._config.setdefault("users", {})[new_username] = self._config["users"].pop(old_username)
|
||||
self._save()
|
||||
|
||||
renamed_sessions = 0
|
||||
with self._sessions_lock:
|
||||
@@ -311,17 +329,18 @@ class AuthManager:
|
||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
||||
"""Update privileges for a user. Can't modify admin privileges."""
|
||||
username = username.strip().lower()
|
||||
if username not in self.users:
|
||||
return False
|
||||
if self.users[username].get("is_admin"):
|
||||
return False # admins always have full access
|
||||
# Only allow known privilege keys
|
||||
current = self.get_privileges(username)
|
||||
for k, v in privileges.items():
|
||||
if k in DEFAULT_PRIVILEGES:
|
||||
current[k] = v
|
||||
self._config["users"][username]["privileges"] = current
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if self.users[username].get("is_admin"):
|
||||
return False # admins always have full access
|
||||
# Only allow known privilege keys
|
||||
current = self.get_privileges(username)
|
||||
for k, v in privileges.items():
|
||||
if k in DEFAULT_PRIVILEGES:
|
||||
current[k] = v
|
||||
self._config["users"][username]["privileges"] = current
|
||||
self._save()
|
||||
logger.info(f"Updated privileges for '{username}': {current}")
|
||||
return True
|
||||
|
||||
@@ -331,8 +350,9 @@ class AuthManager:
|
||||
return False
|
||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||
return False
|
||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||
self._save()
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -350,8 +370,9 @@ class AuthManager:
|
||||
if username not in self.users:
|
||||
return None
|
||||
secret = pyotp.random_base32()
|
||||
self._config["users"][username]["totp_secret_pending"] = secret
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret_pending"] = secret
|
||||
self._save()
|
||||
return secret
|
||||
|
||||
def totp_get_provisioning_uri(self, username: str, secret: str) -> str:
|
||||
@@ -370,13 +391,14 @@ class AuthManager:
|
||||
if not totp.verify(code, valid_window=1):
|
||||
return False
|
||||
# Enable 2FA
|
||||
self._config["users"][username]["totp_secret"] = secret
|
||||
self._config["users"][username]["totp_enabled"] = True
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
# Generate backup codes
|
||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret"] = secret
|
||||
self._config["users"][username]["totp_enabled"] = True
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
# Generate backup codes
|
||||
backup = [secrets.token_hex(4) for _ in range(8)]
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
logger.info(f"2FA enabled for '{username}'")
|
||||
return True
|
||||
|
||||
@@ -395,9 +417,10 @@ class AuthManager:
|
||||
# Check backup codes first
|
||||
backup = user.get("totp_backup_codes", [])
|
||||
if code in backup:
|
||||
backup.remove(code)
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
backup.remove(code)
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
logger.info(f"Backup code used for '{username}' ({len(backup)} remaining)")
|
||||
return True
|
||||
totp = pyotp.TOTP(secret)
|
||||
@@ -408,11 +431,12 @@ class AuthManager:
|
||||
username = username.strip().lower()
|
||||
if not self.verify_password(username, password):
|
||||
return False
|
||||
self._config["users"][username].pop("totp_secret", None)
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
self._config["users"][username].pop("totp_backup_codes", None)
|
||||
self._config["users"][username]["totp_enabled"] = False
|
||||
self._save()
|
||||
with self._config_lock:
|
||||
self._config["users"][username].pop("totp_secret", None)
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
self._config["users"][username].pop("totp_backup_codes", None)
|
||||
self._config["users"][username]["totp_enabled"] = False
|
||||
self._save()
|
||||
logger.info(f"2FA disabled for '{username}'")
|
||||
return True
|
||||
|
||||
@@ -431,6 +455,12 @@ class AuthManager:
|
||||
username = username.strip().lower()
|
||||
if not self.verify_password(username, password):
|
||||
return None
|
||||
return self.create_session_trusted(username)
|
||||
|
||||
def create_session_trusted(self, username: str) -> str:
|
||||
"""Issue a session token for an already-verified user.
|
||||
Call only after verify_password (and TOTP if enabled) have passed."""
|
||||
username = username.strip().lower()
|
||||
token = secrets.token_hex(32)
|
||||
with self._sessions_lock:
|
||||
self._sessions[token] = {
|
||||
|
||||
+11
-39
@@ -1,40 +1,12 @@
|
||||
# src/constants.py
|
||||
"""Application-wide constants and configuration values."""
|
||||
import os
|
||||
# core/constants.py
|
||||
"""Backward-compatible shim — the single source of truth is src/constants.py.
|
||||
|
||||
APP_VERSION = "0.9.1"
|
||||
|
||||
# Base paths
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
||||
|
||||
# Data file paths
|
||||
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
||||
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
||||
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
||||
PERSONAL_DIR = os.path.join(DATA_DIR, "personal_docs")
|
||||
RUNBOOK_DIR = os.path.join(PERSONAL_DIR, "runbook")
|
||||
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
||||
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
||||
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
||||
|
||||
# API Configuration
|
||||
MAX_CONTEXT_MESSAGES = 90
|
||||
REQUEST_TIMEOUT = 20
|
||||
OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
||||
|
||||
# Environment variables with defaults
|
||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
SEARXNG_INSTANCE = os.getenv('SEARXNG_INSTANCE', 'http://localhost:8080')
|
||||
|
||||
|
||||
# Cleanup configuration
|
||||
CLEANUP_ENABLED = os.getenv("CLEANUP_ENABLED", "True").lower() == "true"
|
||||
CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "24"))
|
||||
|
||||
# Default parameters
|
||||
DEFAULT_TEMPERATURE = 1.0
|
||||
DEFAULT_MAX_TOKENS = 0
|
||||
Historically there were two copies of this module (this one lagged behind at
|
||||
APP_VERSION 0.9.1 and was missing the consolidated tool-output constants). To
|
||||
kill the drift, this now simply re-exports everything from src.constants so
|
||||
there is exactly one place that defines paths and reads ODYSSEUS_DATA_DIR.
|
||||
internal_api_base() also lives in src.constants now and is re-exported here so
|
||||
existing `from core.constants import internal_api_base` callers keep working.
|
||||
"""
|
||||
from src.constants import * # noqa: F401,F403
|
||||
from src.constants import internal_api_base # noqa: F401 (explicit: functions aren't covered by some linters' * checks)
|
||||
|
||||
+168
-7
@@ -29,8 +29,9 @@ class TimestampMixin:
|
||||
def updated_at(cls):
|
||||
return Column(DateTime, default=utcnow_naive, onupdate=utcnow_naive, nullable=False)
|
||||
|
||||
# Get database URL from environment, default to SQLite
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./data/app.db")
|
||||
# Get database URL from environment, default to SQLite in DATA_DIR
|
||||
from src.constants import DATA_DIR, AUTH_FILE, MEMORY_FILE, USER_PREFS_FILE, SETTINGS_FILE
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", f"sqlite:///{DATA_DIR}/app.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(
|
||||
@@ -360,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
|
||||
# is the historical default. When non-null, the model picker only shows
|
||||
# the endpoint to that user (admins always see everything).
|
||||
owner = Column(String, nullable=True, index=True)
|
||||
# Optional OAuth/session-backed credential row. Used by subscription-backed
|
||||
# providers that need refresh tokens instead of a static API key.
|
||||
provider_auth_id = Column(String, nullable=True, index=True)
|
||||
|
||||
|
||||
class ProviderAuthSession(TimestampMixin, Base):
|
||||
"""Encrypted OAuth/session credentials for refresh-aware model providers."""
|
||||
__tablename__ = "provider_auth_sessions"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
provider = Column(String, nullable=False, index=True)
|
||||
owner = Column(String, nullable=True, index=True)
|
||||
label = Column(String, nullable=True)
|
||||
base_url = Column(String, nullable=False)
|
||||
access_token = Column(EncryptedText, nullable=True)
|
||||
refresh_token = Column(EncryptedText, nullable=True)
|
||||
last_refresh = Column(DateTime, nullable=True)
|
||||
auth_mode = Column(String, nullable=True)
|
||||
|
||||
class McpServer(TimestampMixin, Base):
|
||||
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
||||
@@ -800,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
|
||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_provider_auth_id_column():
|
||||
"""Add provider_auth_id column to model_endpoints if it doesn't exist."""
|
||||
import sqlite3
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
if columns and "provider_auth_id" not in columns:
|
||||
conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_model_type_column():
|
||||
"""Add model_type column to model_endpoints if it doesn't exist."""
|
||||
import sqlite3
|
||||
@@ -1065,7 +1104,7 @@ def _migrate_assign_legacy_owner():
|
||||
# fell through to "first user" every time.
|
||||
auth_path = os.path.join(os.path.dirname(DATABASE_URL.replace("sqlite:///", "")), "auth.json")
|
||||
if not os.path.isabs(auth_path):
|
||||
auth_path = os.path.join("data", "auth.json")
|
||||
auth_path = AUTH_FILE
|
||||
admin_user = None
|
||||
try:
|
||||
with open(auth_path, "r", encoding="utf-8") as f:
|
||||
@@ -1118,7 +1157,7 @@ def _migrate_assign_legacy_owner():
|
||||
logger.warning(f"Legacy owner migration failed: {e}")
|
||||
|
||||
# Also migrate memory.json
|
||||
mem_path = os.path.join("data", "memory.json")
|
||||
mem_path = MEMORY_FILE
|
||||
try:
|
||||
if os.path.exists(mem_path):
|
||||
with open(mem_path, "r", encoding="utf-8") as f:
|
||||
@@ -1136,7 +1175,7 @@ def _migrate_assign_legacy_owner():
|
||||
logger.warning(f"memory.json legacy migration failed: {e}")
|
||||
|
||||
# Also migrate user_prefs.json to per-user format
|
||||
prefs_path = os.path.join("data", "user_prefs.json")
|
||||
prefs_path = USER_PREFS_FILE
|
||||
try:
|
||||
if os.path.exists(prefs_path):
|
||||
with open(prefs_path, "r", encoding="utf-8") as f:
|
||||
@@ -1458,7 +1497,11 @@ class CalendarCal(TimestampMixin, Base):
|
||||
owner = Column(String, nullable=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
color = Column(String, default="#5b8abf")
|
||||
source = Column(String, default="local") # "local" or "timetree"
|
||||
source = Column(String, default="local") # "local" or "caldav"
|
||||
# UUID of the CalDAV account in user prefs that owns this calendar.
|
||||
# NULL for local calendars and for CalDAV calendars created before
|
||||
# multi-account support was added (treated as "use any configured account").
|
||||
account_id = Column(String, nullable=True, index=True)
|
||||
|
||||
events = relationship("CalendarEvent", back_populates="calendar", cascade="all, delete-orphan")
|
||||
|
||||
@@ -1526,7 +1569,7 @@ def _migrate_seed_email_account():
|
||||
import json as _json
|
||||
import uuid as _uuid
|
||||
from pathlib import Path
|
||||
settings_file = Path("data/settings.json")
|
||||
settings_file = Path(SETTINGS_FILE)
|
||||
if not settings_file.exists():
|
||||
return
|
||||
try:
|
||||
@@ -1594,6 +1637,7 @@ def init_db():
|
||||
_migrate_add_model_type_column()
|
||||
_migrate_add_model_endpoint_refresh_columns()
|
||||
_migrate_add_model_endpoint_owner_column()
|
||||
_migrate_add_provider_auth_id_column()
|
||||
_migrate_add_supports_tools_column()
|
||||
_migrate_add_task_run_model_column()
|
||||
_migrate_add_owner_column()
|
||||
@@ -1622,9 +1666,105 @@ def init_db():
|
||||
_migrate_add_calendar_metadata()
|
||||
_migrate_add_calendar_is_utc()
|
||||
_migrate_add_calendar_origin()
|
||||
_migrate_add_calendar_account_id()
|
||||
_migrate_chat_messages_fts()
|
||||
_migrate_encrypt_email_passwords()
|
||||
_migrate_encrypt_signatures()
|
||||
_migrate_encrypt_endpoint_keys()
|
||||
_migrate_backfill_task_folders()
|
||||
|
||||
|
||||
def _migrate_backfill_task_folders():
|
||||
"""Backfill folder='Tasks' on pre-existing task/research sessions.
|
||||
|
||||
Sessions created by the task scheduler (LLM tasks, action tasks, research
|
||||
runs) now set folder='Tasks' at creation time. This migration tags any
|
||||
older sessions that predate that assignment. Idempotent — only touches
|
||||
rows where folder is NULL or empty and the title matches known prefixes.
|
||||
"""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
cols = [r[1] for r in conn.execute(text("PRAGMA table_info(sessions)"))]
|
||||
if "folder" not in cols:
|
||||
return
|
||||
res = conn.execute(text(
|
||||
"UPDATE sessions SET folder = 'Tasks' "
|
||||
"WHERE (folder IS NULL OR folder = '') "
|
||||
"AND (name LIKE '[Task] %' OR name LIKE '[Research] %')"
|
||||
))
|
||||
conn.commit()
|
||||
if res.rowcount:
|
||||
logging.getLogger(__name__).info(
|
||||
f"Backfilled folder='Tasks' on {res.rowcount} task/research sessions")
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"task folder backfill: {e}")
|
||||
|
||||
|
||||
def _migrate_chat_messages_fts():
|
||||
"""Create and backfill the session transcript FTS index for SQLite."""
|
||||
if not DATABASE_URL.startswith("sqlite"):
|
||||
return
|
||||
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if db_path == ":memory:":
|
||||
return
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
try:
|
||||
conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS temp._odysseus_fts5_probe USING fts5(content)")
|
||||
conn.execute("DROP TABLE IF EXISTS temp._odysseus_fts5_probe")
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"chat_messages FTS migration skipped; FTS5 unavailable: {e}")
|
||||
return
|
||||
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS chat_messages_fts USING fts5(
|
||||
content,
|
||||
message_id UNINDEXED,
|
||||
session_id UNINDEXED,
|
||||
role UNINDEXED
|
||||
);
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ai
|
||||
AFTER INSERT ON chat_messages BEGIN
|
||||
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||
VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_ad
|
||||
AFTER DELETE ON chat_messages BEGIN
|
||||
DELETE FROM chat_messages_fts WHERE message_id = old.id;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS chat_messages_fts_au
|
||||
AFTER UPDATE ON chat_messages BEGIN
|
||||
DELETE FROM chat_messages_fts WHERE message_id = old.id;
|
||||
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||
VALUES (COALESCE(new.content, ''), new.id, new.session_id, new.role);
|
||||
END;
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO chat_messages_fts(content, message_id, session_id, role)
|
||||
SELECT COALESCE(cm.content, ''), cm.id, cm.session_id, cm.role
|
||||
FROM chat_messages cm
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM chat_messages_fts fts
|
||||
WHERE fts.message_id = cm.id
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"chat_messages FTS migration failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _migrate_add_email_smtp_security():
|
||||
@@ -1786,6 +1926,27 @@ def _migrate_add_calendar_origin():
|
||||
logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_calendar_account_id():
|
||||
"""Add `account_id` to calendars so each CalDAV-backed calendar knows which
|
||||
credential set (from caldav_accounts in user prefs) owns it. Idempotent."""
|
||||
import sqlite3
|
||||
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||
if not os.path.exists(db_path):
|
||||
return
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("PRAGMA table_info(calendars)")
|
||||
columns = [row[1] for row in cursor.fetchall()]
|
||||
if columns and "account_id" not in columns:
|
||||
conn.execute("ALTER TABLE calendars ADD COLUMN account_id TEXT")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_calendars_account_id ON calendars(account_id)")
|
||||
conn.commit()
|
||||
logging.getLogger(__name__).info("Migrated: added 'account_id' column to calendars")
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(f"calendars.account_id migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_calendar_metadata():
|
||||
"""Add importance/event_type/last_pinged columns to calendar_events table."""
|
||||
import sqlite3
|
||||
|
||||
@@ -17,6 +17,15 @@ INTERNAL_TOOL_TOKEN = os.environ.get("ODYSSEUS_INTERNAL_TOKEN") or secrets.token
|
||||
INTERNAL_TOOL_HEADER = "X-Odysseus-Internal-Token"
|
||||
|
||||
|
||||
def is_cors_preflight(method: str, headers) -> bool:
|
||||
"""True for a genuine CORS preflight: an OPTIONS request carrying the
|
||||
Access-Control-Request-Method header. Such requests are credential-less by
|
||||
design and must reach CORSMiddleware to be answered -- gating them on auth
|
||||
401s the preflight and breaks every cross-origin browser/WebView client.
|
||||
Pure so it can be unit-tested without standing up the app."""
|
||||
return method == "OPTIONS" and "access-control-request-method" in headers
|
||||
|
||||
|
||||
def require_admin(request: Request):
|
||||
"""Raise 403 if the current user isn't an admin.
|
||||
Allows access when auth is explicitly disabled, or when the request carries
|
||||
@@ -58,11 +67,22 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Tool render endpoints are served inside iframes — allow framing by self
|
||||
is_tool_render = path.startswith("/api/tools/") and path.endswith("/render")
|
||||
# PDF previews are embedded by the in-app document library. Keep the
|
||||
# exception route-scoped so normal app pages remain unframeable.
|
||||
is_document_pdf_preview = path.startswith("/api/document/") and path.endswith("/render-pdf")
|
||||
# Visual report pages are self-contained HTML — need inline scripts + external images
|
||||
is_report = path.startswith("/api/research/report/")
|
||||
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["Referrer-Policy"] = "no-referrer"
|
||||
response.headers["Permissions-Policy"] = "camera=(), microphone=(self), geolocation=()"
|
||||
|
||||
is_https = (
|
||||
request.url.scheme == "https"
|
||||
or request.headers.get("X-Forwarded-Proto") == "https"
|
||||
)
|
||||
if is_https:
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
if is_report:
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
@@ -79,6 +99,12 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
# sandbox="allow-scripts" attribute provides isolation.
|
||||
# Don't overwrite the route's own restrictive CSP either.
|
||||
pass
|
||||
elif is_document_pdf_preview:
|
||||
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'none'; "
|
||||
"frame-ancestors 'self'"
|
||||
)
|
||||
else:
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
# NOTE: `style-src 'unsafe-inline'` is intentionally retained.
|
||||
|
||||
+205
-3
@@ -18,10 +18,22 @@ import ntpath
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
import platform
|
||||
|
||||
IS_WINDOWS = os.name == "nt"
|
||||
IS_POSIX = not IS_WINDOWS
|
||||
# Allows APFEL support and ARM-native binary recommendations on Apple Silicon Macs.
|
||||
IS_APPLE_SILICON = (
|
||||
IS_POSIX
|
||||
and platform.system() == "Darwin"
|
||||
and platform.machine().lower()
|
||||
in {
|
||||
"arm64",
|
||||
"aarch64",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ── File permissions ────────────────────────────────────────────────────────
|
||||
@@ -53,9 +65,8 @@ def detached_popen_kwargs() -> dict:
|
||||
and is detached from any console.
|
||||
"""
|
||||
if IS_WINDOWS:
|
||||
flags = (
|
||||
getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200)
|
||||
| getattr(subprocess, "DETACHED_PROCESS", 0x00000008)
|
||||
flags = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0x00000200) | getattr(
|
||||
subprocess, "DETACHED_PROCESS", 0x00000008
|
||||
)
|
||||
return {"creationflags": flags}
|
||||
return {"start_new_session": True}
|
||||
@@ -150,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
|
||||
("usr", "bin", "bash.exe"),
|
||||
)
|
||||
|
||||
# Paths to add to the remote SSH probe command to find tools like nvidia-smi that may not be on PATH.
|
||||
_SSH_PATH_MEMBERS = (
|
||||
"/usr/bin",
|
||||
"/usr/local/bin",
|
||||
"/usr/local/cuda/bin",
|
||||
"/usr/lib/wsl/lib"
|
||||
)
|
||||
# Fallback locations for nvidia-smi on WSL and other Linux distros where it may not be on PATH.
|
||||
NVIDIA_PATH_CANDIDATES = (
|
||||
"/usr/bin/nvidia-smi",
|
||||
"/usr/local/bin/nvidia-smi",
|
||||
"/usr/local/cuda/bin/nvidia-smi",
|
||||
"/usr/lib/wsl/lib/nvidia-smi",
|
||||
)
|
||||
|
||||
|
||||
def _ssh_path_override() -> str:
|
||||
"""Build the PATH export snippet used for remote SSH shell probes."""
|
||||
return f"export PATH=\"$PATH:{':'.join(_SSH_PATH_MEMBERS)}\"; "
|
||||
|
||||
|
||||
SSH_PATH_OVERRIDE = _ssh_path_override()
|
||||
|
||||
|
||||
def _windows_bash_fallbacks() -> List[str]:
|
||||
roots: List[str] = []
|
||||
@@ -180,6 +214,21 @@ def _is_windows_bash_stub(path: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def git_bash_path(path: str | Path) -> str:
|
||||
"""Convert a path to POSIX style suitable for Git Bash on Windows.
|
||||
|
||||
Transforms drive letters (e.g., 'C:\\path') to POSIX '/c/path',
|
||||
and uses forward slashes.
|
||||
"""
|
||||
p = Path(path)
|
||||
p_str = p.as_posix()
|
||||
if IS_WINDOWS and len(p_str) >= 2 and p_str[1] == ":":
|
||||
drive = p_str[0].lower()
|
||||
return f"/{drive}{p_str[2:]}"
|
||||
return p_str
|
||||
|
||||
|
||||
|
||||
def find_bash() -> Optional[str]:
|
||||
"""Locate a real ``bash`` interpreter, or None.
|
||||
|
||||
@@ -242,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
|
||||
comspec = os.environ.get("ComSpec", "cmd.exe")
|
||||
return [comspec, "/c", str(script_path)]
|
||||
return ["sh", str(script_path)]
|
||||
|
||||
|
||||
def is_wsl() -> bool:
|
||||
"""True if running inside Windows Subsystem for Linux (WSL)."""
|
||||
import sys
|
||||
if sys.platform.startswith("linux") or os.name == "posix":
|
||||
try:
|
||||
with open("/proc/version", "r") as f:
|
||||
if "microsoft" in f.read().lower():
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def translate_path(path_str: str) -> str:
|
||||
"""Translate a path (possibly a Windows path) to the current OS format.
|
||||
|
||||
Particularly handles Windows paths (e.g. C:\\foo or C:/foo) when running
|
||||
under WSL, translating them to /mnt/c/foo.
|
||||
Also handles standard path normalization to avoid string breakages.
|
||||
"""
|
||||
if not path_str:
|
||||
return path_str
|
||||
|
||||
if is_wsl():
|
||||
path_str = path_str.replace("\\", "/")
|
||||
import re
|
||||
m = re.match(r"^([a-zA-Z]):(.*)", path_str)
|
||||
if m:
|
||||
drive = m.group(1).lower()
|
||||
rest = m.group(2)
|
||||
if not rest.startswith("/"):
|
||||
rest = "/" + rest
|
||||
return f"/mnt/{drive}{rest}"
|
||||
|
||||
try:
|
||||
return str(Path(path_str).resolve())
|
||||
except Exception:
|
||||
return path_str
|
||||
|
||||
|
||||
def get_wsl_windows_user_profile() -> Optional[str]:
|
||||
"""Retrieve the Windows host User Profile path from inside WSL."""
|
||||
if not is_wsl():
|
||||
return None
|
||||
try:
|
||||
r = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
|
||||
if r.returncode == 0 and r.stdout.strip():
|
||||
return translate_path(r.stdout.strip())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
users_dir = "/mnt/c/Users"
|
||||
if os.path.isdir(users_dir):
|
||||
for entry in os.listdir(users_dir):
|
||||
if entry not in ("All Users", "Default", "Default User", "desktop.ini", "Public"):
|
||||
path = os.path.join(users_dir, entry)
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _ssh_exec_argv(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
*,
|
||||
remote_cmd: str | None = None,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
) -> list[str]:
|
||||
"""Build a consistent ssh argv for remote command execution."""
|
||||
argv = ["ssh"]
|
||||
if connect_timeout is not None:
|
||||
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||
if strict_host_key_checking is not None:
|
||||
argv.extend(
|
||||
[
|
||||
"-o",
|
||||
"StrictHostKeyChecking=yes"
|
||||
if strict_host_key_checking
|
||||
else "StrictHostKeyChecking=no",
|
||||
]
|
||||
)
|
||||
if ssh_port and ssh_port != "22":
|
||||
argv.extend(["-p", str(ssh_port)])
|
||||
argv.append(remote)
|
||||
if remote_cmd is not None:
|
||||
argv.append(remote_cmd)
|
||||
return argv
|
||||
|
||||
|
||||
def run_ssh_command(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
remote_cmd: str,
|
||||
*,
|
||||
timeout: float,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
text: bool = True,
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""Run an ssh command with centralized timeout and stderr/stdout capture."""
|
||||
return subprocess.run(
|
||||
_ssh_exec_argv(
|
||||
remote,
|
||||
ssh_port,
|
||||
remote_cmd=remote_cmd,
|
||||
connect_timeout=connect_timeout,
|
||||
strict_host_key_checking=strict_host_key_checking,
|
||||
),
|
||||
timeout=timeout,
|
||||
capture_output=True,
|
||||
text=text,
|
||||
)
|
||||
|
||||
|
||||
def _windows_powershell_argv(
|
||||
command: str,
|
||||
*,
|
||||
no_profile: bool = True,
|
||||
non_interactive: bool = True,
|
||||
) -> List[str]:
|
||||
argv: List[str] = ["powershell.exe"]
|
||||
if no_profile:
|
||||
argv.append("-NoProfile")
|
||||
if non_interactive:
|
||||
argv.append("-NonInteractive")
|
||||
argv.extend(["-Command", command])
|
||||
return argv
|
||||
|
||||
|
||||
def run_wsl_windows_powershell(
|
||||
command: str,
|
||||
*,
|
||||
timeout: float = 5,
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
"""Run a PowerShell command on the Windows host from WSL.
|
||||
|
||||
Raises ``RuntimeError`` when called outside WSL.
|
||||
"""
|
||||
|
||||
if not is_wsl():
|
||||
raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
|
||||
return subprocess.run(
|
||||
_windows_powershell_argv(command),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal
|
||||
from .database import Session as DbSession, ChatMessage as DbChatMessage, Document as DbDocument, SessionLocal, utcnow_naive
|
||||
from .models import Session, ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -619,7 +619,7 @@ class SessionManager:
|
||||
|
||||
try:
|
||||
all_sessions = db.query(DbSession).all()
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=auto_archive_days)
|
||||
cutoff_date = utcnow_naive() - timedelta(days=auto_archive_days)
|
||||
|
||||
for db_session in all_sessions:
|
||||
stats['total_checked'] += 1
|
||||
|
||||
@@ -52,12 +52,14 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||
|
||||
@@ -51,12 +51,14 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||
|
||||
@@ -40,12 +40,14 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
- ODYSSEUS_INPROCESS_POLLERS=${ODYSSEUS_INPROCESS_POLLERS:-1}
|
||||
- ODYSSEUS_INPROCESS_TASKS=${ODYSSEUS_INPROCESS_TASKS:-1}
|
||||
- ODYSSEUS_SCRIPT_HOST=${ODYSSEUS_SCRIPT_HOST:-localhost}
|
||||
- ODYSSEUS_CHAT_UPLOAD_MAX_BYTES=${ODYSSEUS_CHAT_UPLOAD_MAX_BYTES:-10485760}
|
||||
- DATA_BRAVE_API_KEY=${DATA_BRAVE_API_KEY:-}
|
||||
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
|
||||
- GOOGLE_PSE_CX=${GOOGLE_PSE_CX:-}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Outlook / Office 365 email accounts
|
||||
|
||||
Odysseus email accounts currently use IMAP and SMTP with username/password
|
||||
authentication. That works for providers that still allow app passwords or
|
||||
mailbox passwords for IMAP/SMTP.
|
||||
|
||||
Microsoft disables basic authentication for Outlook and Microsoft 365 in most
|
||||
modern accounts and tenants. If you try to add an Outlook account with a normal
|
||||
password, Microsoft may return errors such as:
|
||||
|
||||
- `IMAP: AUTHENTICATE failed`
|
||||
- `SMTP: 535 5.7.139 Authentication unsuccessful, basic authentication is disabled`
|
||||
|
||||
This is expected. Odysseus does not support Microsoft OAuth or Graph Mail yet,
|
||||
so Outlook / Office 365 accounts cannot currently be added through the password
|
||||
form. Use another email provider with app-password support, or track the future
|
||||
Microsoft Graph OAuth integration.
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
_common.py
|
||||
|
||||
Shared constants and helpers for built-in MCP servers.
|
||||
"""
|
||||
|
||||
MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_READ_CHARS = 20_000
|
||||
SHELL_TIMEOUT = 60
|
||||
PYTHON_TIMEOUT = 30
|
||||
SEARCH_TIMEOUT = 30
|
||||
|
||||
|
||||
def truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
||||
"""Truncate text to *limit* characters with a suffix note."""
|
||||
if not isinstance(text, str):
|
||||
# Tool output is occasionally None or a non-string; len(None) would
|
||||
# raise. Coerce so this shared helper never crashes a tool response.
|
||||
text = "" if text is None else str(text)
|
||||
if len(text) > limit:
|
||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
||||
return text
|
||||
+182
-125
@@ -31,13 +31,19 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
server = Server("email")
|
||||
EMAIL_SOCKET_TIMEOUT = float(os.environ.get("EMAIL_SOCKET_TIMEOUT", "20"))
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
from src.constants import DATA_DIR as _DATA_DIR, APP_DB, EMAIL_CACHE_DB, SETTINGS_FILE as _SETTINGS_FILE, MAIL_ATTACHMENTS_DIR
|
||||
DATA_DIR = Path(_DATA_DIR)
|
||||
|
||||
|
||||
def _b(value) -> bytes:
|
||||
return str(value).encode()
|
||||
|
||||
|
||||
def _q(name: str) -> str:
|
||||
"""Quote an IMAP mailbox name for commands that take mailbox args."""
|
||||
return '"' + (name or "").replace("\\", "\\\\").replace('"', '\\"') + '"'
|
||||
|
||||
|
||||
def _uid_fetch_rows(data) -> list:
|
||||
return [d for d in (data or []) if isinstance(d, bytes) and b"UID " in d]
|
||||
|
||||
@@ -58,7 +64,7 @@ def _clean_header_value(value) -> str:
|
||||
|
||||
|
||||
def _db_path() -> Path:
|
||||
return DATA_DIR / "app.db"
|
||||
return Path(APP_DB)
|
||||
|
||||
|
||||
def _list_accounts_raw() -> list:
|
||||
@@ -157,7 +163,7 @@ def _load_config(account: str | None = None) -> dict:
|
||||
"trash_folder": os.environ.get("TRASH_FOLDER", "Trash"),
|
||||
"cache_db": os.environ.get(
|
||||
"EMAIL_CACHE_DB",
|
||||
str(DATA_DIR / "email_cache.db"),
|
||||
EMAIL_CACHE_DB,
|
||||
),
|
||||
"account_id": None,
|
||||
"account_name": None,
|
||||
@@ -199,7 +205,7 @@ def _load_config(account: str | None = None) -> dict:
|
||||
else:
|
||||
# Legacy fallback: settings.json flat keys
|
||||
try:
|
||||
settings_path = Path(__file__).resolve().parent.parent / "data" / "settings.json"
|
||||
settings_path = Path(_SETTINGS_FILE)
|
||||
if settings_path.exists():
|
||||
settings = json.loads(settings_path.read_text(encoding="utf-8"))
|
||||
for key in (
|
||||
@@ -239,10 +245,27 @@ def _imap_connect(account: str | None = None):
|
||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||
)
|
||||
if cfg["imap_starttls"]:
|
||||
conn.starttls()
|
||||
try:
|
||||
conn.starttls()
|
||||
except Exception:
|
||||
# Don't leak the open plain socket on a rejected STARTTLS. (#3174)
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
if getattr(conn, "sock", None):
|
||||
conn.sock.settimeout(EMAIL_SOCKET_TIMEOUT)
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
try:
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
except Exception:
|
||||
# A failed login otherwise orphans the connected socket; close it
|
||||
# before propagating (shutdown() is the pre-auth low-level close). (#3174)
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return conn
|
||||
|
||||
|
||||
@@ -418,68 +441,71 @@ def _list_emails(folder="INBOX", max_results=20, unresponded_only=False,
|
||||
Pass unread_only=True and/or unresponded_only=True for attention scans.
|
||||
account selects mailbox (None = default).
|
||||
"""
|
||||
conn = _imap_connect(account)
|
||||
select_status, _ = conn.select(folder, readonly=True)
|
||||
if select_status != "OK":
|
||||
conn.logout()
|
||||
raise ValueError(f"IMAP folder not found: {folder}")
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account)
|
||||
select_status, _ = conn.select(_q(folder), readonly=True)
|
||||
if select_status != "OK":
|
||||
raise ValueError(f"IMAP folder not found: {folder}")
|
||||
|
||||
if unread_only and unresponded_only:
|
||||
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
||||
elif unread_only:
|
||||
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
||||
elif unresponded_only:
|
||||
# Was missing — unresponded_only=True (without unread_only) fell through
|
||||
# to "ALL" and returned answered mail too, despite the documented
|
||||
# "emails without replies" behaviour.
|
||||
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
||||
else:
|
||||
# Include read too — IMAP search "ALL" returns the entire folder
|
||||
status, data = conn.uid("SEARCH", None, "ALL")
|
||||
if unread_only and unresponded_only:
|
||||
status, data = conn.uid("SEARCH", None, "(UNSEEN UNANSWERED)")
|
||||
elif unread_only:
|
||||
status, data = conn.uid("SEARCH", None, "(UNSEEN)")
|
||||
elif unresponded_only:
|
||||
# Was missing — unresponded_only=True (without unread_only) fell through
|
||||
# to "ALL" and returned answered mail too, despite the documented
|
||||
# "emails without replies" behaviour.
|
||||
status, data = conn.uid("SEARCH", None, "(UNANSWERED)")
|
||||
else:
|
||||
# Include read too — IMAP search "ALL" returns the entire folder
|
||||
status, data = conn.uid("SEARCH", None, "ALL")
|
||||
|
||||
if status != "OK" or not data[0]:
|
||||
conn.logout()
|
||||
return []
|
||||
if status != "OK" or not data[0]:
|
||||
return []
|
||||
|
||||
uid_list = list(reversed(data[0].split()))[:max_results]
|
||||
cache = _get_cached_summaries()
|
||||
results = []
|
||||
uid_list = list(reversed(data[0].split()))[:max_results]
|
||||
cache = _get_cached_summaries()
|
||||
results = []
|
||||
|
||||
for uid in uid_list:
|
||||
try:
|
||||
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
||||
if status != "OK":
|
||||
for uid in uid_list:
|
||||
try:
|
||||
status, msg_data = conn.uid("FETCH", uid, "(RFC822.HEADER)")
|
||||
if status != "OK":
|
||||
continue
|
||||
raw_header = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_header)
|
||||
|
||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||
sender = _decode_header(msg.get("From", "unknown"))
|
||||
date_str = msg.get("Date", "")
|
||||
message_id = msg.get("Message-ID", "")
|
||||
|
||||
# Parse sender name
|
||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||
sender_display = sender_name or sender_addr
|
||||
|
||||
# Check cache for summary
|
||||
cached = cache.get(subject, {})
|
||||
summary = cached.get("summary", "")
|
||||
|
||||
results.append({
|
||||
"uid": uid.decode(),
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"from": sender_display,
|
||||
"from_address": sender_addr,
|
||||
"date": date_str,
|
||||
"summary": summary,
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
raw_header = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw_header)
|
||||
|
||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||
sender = _decode_header(msg.get("From", "unknown"))
|
||||
date_str = msg.get("Date", "")
|
||||
message_id = msg.get("Message-ID", "")
|
||||
|
||||
# Parse sender name
|
||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||
sender_display = sender_name or sender_addr
|
||||
|
||||
# Check cache for summary
|
||||
cached = cache.get(subject, {})
|
||||
summary = cached.get("summary", "")
|
||||
|
||||
results.append({
|
||||
"uid": uid.decode(),
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"from": sender_display,
|
||||
"from_address": sender_addr,
|
||||
"date": date_str,
|
||||
"summary": summary,
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
conn.logout()
|
||||
return results
|
||||
return results
|
||||
finally:
|
||||
if conn:
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
|
||||
def _result_sort_time(result: dict) -> datetime:
|
||||
@@ -542,7 +568,7 @@ def _search_emails(query, folders=None, max_results=20, account=None):
|
||||
try:
|
||||
for folder in folders:
|
||||
try:
|
||||
status, _ = conn.select(folder, readonly=True)
|
||||
status, _ = conn.select(_q(folder), readonly=True)
|
||||
if status != "OK":
|
||||
continue
|
||||
status, data = conn.uid("SEARCH", None, search_cmd)
|
||||
@@ -652,54 +678,55 @@ def _extract_attachment_to_disk(msg, index, target_dir):
|
||||
def _read_email(uid=None, message_id=None, folder="INBOX", account=None):
|
||||
"""Read full email content by UID or message-ID. account = mailbox selector."""
|
||||
cfg = _load_config(account)
|
||||
conn = _imap_connect(account)
|
||||
conn.select(folder, readonly=True)
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account)
|
||||
conn.select(_q(folder), readonly=True)
|
||||
|
||||
if message_id and not uid:
|
||||
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
||||
if status != "OK" or not data[0]:
|
||||
conn.logout()
|
||||
return {"error": f"Email not found with Message-ID: {message_id}"}
|
||||
uid = data[0].split()[-1]
|
||||
if message_id and not uid:
|
||||
status, data = conn.uid("SEARCH", None, f'(HEADER Message-ID "{message_id}")')
|
||||
if status != "OK" or not data[0]:
|
||||
return {"error": f"Email not found with Message-ID: {message_id}"}
|
||||
uid = data[0].split()[-1]
|
||||
|
||||
if not uid:
|
||||
conn.logout()
|
||||
return {"error": "No UID or Message-ID provided"}
|
||||
if not uid:
|
||||
return {"error": "No UID or Message-ID provided"}
|
||||
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
if status != "OK":
|
||||
conn.logout()
|
||||
return {"error": f"Failed to fetch email UID {uid}"}
|
||||
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
||||
conn.logout()
|
||||
return {"error": f"Email not found with UID {uid}"}
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
if status != "OK":
|
||||
return {"error": f"Failed to fetch email UID {uid}"}
|
||||
if not msg_data or not msg_data[0] or not isinstance(msg_data[0], tuple) or len(msg_data[0]) < 2:
|
||||
return {"error": f"Email not found with UID {uid}"}
|
||||
|
||||
raw = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw)
|
||||
raw = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||
sender = _decode_header(msg.get("From", "unknown"))
|
||||
date_str = msg.get("Date", "")
|
||||
message_id_header = msg.get("Message-ID", "")
|
||||
body = _extract_text(msg)
|
||||
attachments = _list_attachments_from_msg(msg)
|
||||
subject = _decode_header(msg.get("Subject", "(no subject)"))
|
||||
sender = _decode_header(msg.get("From", "unknown"))
|
||||
date_str = msg.get("Date", "")
|
||||
message_id_header = msg.get("Message-ID", "")
|
||||
body = _extract_text(msg)
|
||||
attachments = _list_attachments_from_msg(msg)
|
||||
|
||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||
sender_name, sender_addr = email.utils.parseaddr(sender)
|
||||
|
||||
conn.logout()
|
||||
return {
|
||||
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
||||
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
||||
"account_id": cfg.get("account_id"),
|
||||
"message_id": message_id_header,
|
||||
"subject": subject,
|
||||
"from": sender_name or sender_addr,
|
||||
"from_address": sender_addr,
|
||||
"date": date_str,
|
||||
"body": body[:8000],
|
||||
"attachments": attachments,
|
||||
}
|
||||
return {
|
||||
"uid": uid.decode() if isinstance(uid, bytes) else str(uid),
|
||||
"account": cfg.get("account_name") or cfg.get("imap_user") or "default",
|
||||
"account_email": cfg.get("imap_user") or cfg.get("from_address") or "",
|
||||
"account_id": cfg.get("account_id"),
|
||||
"message_id": message_id_header,
|
||||
"subject": subject,
|
||||
"from": sender_name or sender_addr,
|
||||
"from_address": sender_addr,
|
||||
"date": date_str,
|
||||
"body": body[:8000],
|
||||
"attachments": attachments,
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
|
||||
def _read_email_across_accounts(uid=None, message_id=None, folder="INBOX"):
|
||||
@@ -768,7 +795,16 @@ def _smtp_connect(account=None, cfg=None):
|
||||
port,
|
||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||
)
|
||||
conn.starttls()
|
||||
try:
|
||||
conn.starttls()
|
||||
except Exception:
|
||||
# Don't leak the open plain socket on a rejected STARTTLS. SMTP has
|
||||
# no shutdown(); close() is the low-level socket close (no QUIT). (#3174)
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
elif security == "ssl":
|
||||
conn = smtplib.SMTP_SSL(
|
||||
cfg["smtp_host"],
|
||||
@@ -782,7 +818,16 @@ def _smtp_connect(account=None, cfg=None):
|
||||
timeout=EMAIL_SOCKET_TIMEOUT,
|
||||
)
|
||||
if cfg["smtp_user"] and cfg["smtp_password"]:
|
||||
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
||||
try:
|
||||
conn.login(cfg["smtp_user"], cfg["smtp_password"])
|
||||
except Exception:
|
||||
# A failed login otherwise orphans the connected socket; close it
|
||||
# before propagating (SMTP has no shutdown(); close() = socket close). (#3174)
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return conn
|
||||
|
||||
|
||||
@@ -827,7 +872,7 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
||||
imap = _imap_connect(send_account)
|
||||
try:
|
||||
sent_folder = _detect_sent_folder(imap)
|
||||
append_st, append_data = imap.append(sent_folder, "\\Seen", None, msg.as_bytes())
|
||||
append_st, append_data = imap.append(_q(sent_folder), "\\Seen", None, msg.as_bytes())
|
||||
if append_st == "OK" and append_data:
|
||||
m = re.search(rb"APPENDUID\s+\d+\s+(\d+)", append_data[0] or b"")
|
||||
if m:
|
||||
@@ -853,10 +898,15 @@ def _send_email(to, subject, body, in_reply_to=None, references=None, cc=None, b
|
||||
|
||||
def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
||||
"""Reply to an existing email by UID. Threads via In-Reply-To/References."""
|
||||
conn = _imap_connect(account)
|
||||
conn.select(folder, readonly=True)
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
conn.logout()
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account)
|
||||
conn.select(_q(folder), readonly=True)
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
finally:
|
||||
if conn:
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
if status != "OK" or not msg_data or not msg_data[0]:
|
||||
return {"error": f"Failed to fetch email UID {uid}"}
|
||||
raw = msg_data[0][1]
|
||||
@@ -896,7 +946,7 @@ def _reply_to_email(uid, body, folder="INBOX", reply_all=False, account=None):
|
||||
def _set_flag(uid, folder, flag, add=True, account=None):
|
||||
"""Add or remove an IMAP flag (e.g. \\Seen, \\Answered, \\Deleted)."""
|
||||
conn = _imap_connect(account)
|
||||
conn.select(folder)
|
||||
conn.select(_q(folder))
|
||||
op = "+FLAGS" if add else "-FLAGS"
|
||||
try:
|
||||
status, data = conn.uid("STORE", _b(uid), op, flag)
|
||||
@@ -918,7 +968,7 @@ def _bulk_set_flag(uids, folder, flag, add=True, account=None):
|
||||
conn = _imap_connect(account)
|
||||
touched = []
|
||||
try:
|
||||
conn.select(folder)
|
||||
conn.select(_q(folder))
|
||||
op = "+FLAGS" if add else "-FLAGS"
|
||||
msg_set = ",".join(str(u) for u in uids)
|
||||
try:
|
||||
@@ -945,7 +995,7 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
||||
conn = _imap_connect(account)
|
||||
moved = 0
|
||||
try:
|
||||
conn.select(source_folder)
|
||||
conn.select(_q(source_folder))
|
||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||
msg_set = ",".join(str(u) for u in uids)
|
||||
try:
|
||||
@@ -956,10 +1006,11 @@ def _bulk_move(uids, source_folder, dest_folder, account=None, role: str = ""):
|
||||
if not existing:
|
||||
return 0
|
||||
moved = len(existing)
|
||||
status, _ = conn.uid("MOVE", _b(msg_set), dest_folder)
|
||||
dest_arg = _q(dest_folder)
|
||||
status, _ = conn.uid("MOVE", _b(msg_set), dest_arg)
|
||||
if status != "OK":
|
||||
# Fallback: UID copy + flag-delete + expunge
|
||||
status, _ = conn.uid("COPY", _b(msg_set), dest_folder)
|
||||
status, _ = conn.uid("COPY", _b(msg_set), dest_arg)
|
||||
if status != "OK":
|
||||
return 0
|
||||
status, _ = conn.uid("STORE", _b(msg_set), "+FLAGS", "\\Deleted")
|
||||
@@ -976,7 +1027,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
||||
ALL, ANSWERED). Used to resolve selectors like all_unread → uids."""
|
||||
conn = _imap_connect(account)
|
||||
try:
|
||||
conn.select(folder, readonly=True)
|
||||
conn.select(_q(folder), readonly=True)
|
||||
status, data = conn.uid("SEARCH", None, criteria)
|
||||
if status != "OK" or not data or not data[0]:
|
||||
return []
|
||||
@@ -988,7 +1039,7 @@ def _search_uids(folder="INBOX", criteria="UNSEEN", account=None):
|
||||
def _move_message(uid, source_folder, dest_folder, account=None, role: str = ""):
|
||||
"""Move a message between folders. Tries IMAP MOVE, falls back to copy+delete."""
|
||||
conn = _imap_connect(account)
|
||||
conn.select(source_folder)
|
||||
conn.select(_q(source_folder))
|
||||
try:
|
||||
dest_folder = _resolve_folder(conn, dest_folder, role or _folder_role_from_name(dest_folder))
|
||||
try:
|
||||
@@ -998,11 +1049,12 @@ def _move_message(uid, source_folder, dest_folder, account=None, role: str = "")
|
||||
existing = _uid_fetch_rows(data)
|
||||
if status != "OK" or not existing:
|
||||
return False
|
||||
status, _ = conn.uid("MOVE", _b(uid), dest_folder)
|
||||
dest_arg = _q(dest_folder)
|
||||
status, _ = conn.uid("MOVE", _b(uid), dest_arg)
|
||||
if status == "OK":
|
||||
return True
|
||||
# Fallback: UID copy + delete
|
||||
status, _ = conn.uid("COPY", _b(uid), dest_folder)
|
||||
status, _ = conn.uid("COPY", _b(uid), dest_arg)
|
||||
if status != "OK":
|
||||
return False
|
||||
status, _ = conn.uid("STORE", _b(uid), "+FLAGS", "\\Deleted")
|
||||
@@ -1031,16 +1083,21 @@ def _archive_email(uid, folder="INBOX", account=None):
|
||||
|
||||
def _download_attachment(uid, index, folder="INBOX", account=None):
|
||||
"""Extract a specific attachment to disk and return its local path."""
|
||||
conn = _imap_connect(account)
|
||||
conn.select(folder, readonly=True)
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
conn.logout()
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account)
|
||||
conn.select(_q(folder), readonly=True)
|
||||
status, msg_data = conn.uid("FETCH", _b(uid), "(BODY.PEEK[])")
|
||||
finally:
|
||||
if conn:
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
if status != "OK":
|
||||
return {"error": f"Failed to fetch email UID {uid}"}
|
||||
raw = msg_data[0][1]
|
||||
msg = email.message_from_bytes(raw)
|
||||
|
||||
target_dir = DATA_DIR / "mail-attachments" / f"{folder}_{uid}"
|
||||
target_dir = Path(MAIL_ATTACHMENTS_DIR) / f"{folder}_{uid}"
|
||||
filepath = _extract_attachment_to_disk(msg, index, target_dir)
|
||||
if not filepath:
|
||||
return {"error": f"Attachment index {index} not found"}
|
||||
|
||||
@@ -16,6 +16,8 @@ from mcp.types import Tool, TextContent
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from src.constants import GENERATED_IMAGES_DIR
|
||||
|
||||
server = Server("image_gen")
|
||||
|
||||
|
||||
@@ -115,14 +117,18 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
|
||||
img = images[0]
|
||||
image_url = None
|
||||
# Prefix the instance's public base URL (existing app_public_url setting) so the
|
||||
# link is fully-qualified and clickable when the model echoes it. Empty = relative
|
||||
# same-origin path (unchanged default).
|
||||
_pub_base = (get_setting("app_public_url", "") or "").rstrip("/")
|
||||
|
||||
if img.get("b64_json"):
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||
img_path = img_dir / filename
|
||||
img_path.write_bytes(base64.b64decode(img["b64_json"]))
|
||||
image_url = f"/api/generated-image/{filename}"
|
||||
image_url = f"{_pub_base}/api/generated-image/{filename}"
|
||||
|
||||
# Save to gallery
|
||||
try:
|
||||
@@ -146,7 +152,13 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
else:
|
||||
return [TextContent(type="text", text="Error: Unexpected image API response format")]
|
||||
|
||||
result = f"Generated image for: {prompt[:100]}\nimage_url: {image_url}\nmodel: {model_id}\nsize: {size}"
|
||||
# "Direct link:" rather than an "image_url:" label — small models copied the
|
||||
# label token ("image_url") into the link href, producing a broken link.
|
||||
result = (
|
||||
f"Generated image for: {prompt[:100]}\n"
|
||||
f"Direct link: {image_url}\n"
|
||||
f"model: {model_id}\nsize: {size}"
|
||||
)
|
||||
return [TextContent(type="text", text=result)]
|
||||
|
||||
except httpx.TimeoutException:
|
||||
|
||||
Generated
+1
-1
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"name": "odysseus-ui",
|
||||
"name": "odysseus",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
# Test-taxonomy markers added at collection time by tests/conftest.py. The
|
||||
# stable area_* markers are declared here; the dynamic sub_<filename-token>
|
||||
# markers are registered before collection by pytest_configure in
|
||||
# tests/conftest.py, so unknown-mark warnings still flag genuine typos outside
|
||||
# the taxonomy. See tests/_taxonomy.py and tests/README.md.
|
||||
markers = [
|
||||
"area_security: tests covering auth, owner-scope, SSRF, XSS, confinement, redaction",
|
||||
"area_routes: tests covering HTTP route / API behavior",
|
||||
"area_services: tests covering service-layer behavior (llm, cookbook, email, calendar, ...)",
|
||||
"area_cli: tests covering CLI / script behavior",
|
||||
"area_js: JavaScript / Node-backed tests",
|
||||
"area_helpers: self-tests for the shared test helpers in tests/helpers/",
|
||||
"area_unit: pure parser / utility tests that do not clearly belong elsewhere",
|
||||
"area_uncategorized: tests not yet matched by the taxonomy (fallback)",
|
||||
]
|
||||
|
||||
@@ -31,7 +31,7 @@ from core.database import (
|
||||
CalendarEvent,
|
||||
CalendarCal,
|
||||
)
|
||||
from src.constants import DATA_DIR
|
||||
from src.constants import DATA_DIR, SKILLS_DIR, SKILLS_FILE, GALLERY_DIR, GALLERY_UPLOADS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,7 +107,7 @@ def setup_admin_wipe_routes(session_manager):
|
||||
# Skills live as SKILL.md files under data/skills/. Drop
|
||||
# the entire directory; the SkillsManager re-creates the
|
||||
# tree on next write.
|
||||
skills_dir = os.path.join(DATA_DIR, "skills")
|
||||
skills_dir = SKILLS_DIR
|
||||
count = 0
|
||||
if os.path.isdir(skills_dir):
|
||||
# Count SKILL.md files for the response — quick walk.
|
||||
@@ -115,7 +115,7 @@ def setup_admin_wipe_routes(session_manager):
|
||||
count += sum(1 for f in files if f == "SKILL.md")
|
||||
_rmtree_quiet(skills_dir)
|
||||
# Legacy fallback file
|
||||
legacy = os.path.join(DATA_DIR, "skills.json")
|
||||
legacy = SKILLS_FILE
|
||||
if os.path.exists(legacy):
|
||||
try:
|
||||
os.remove(legacy)
|
||||
@@ -151,8 +151,8 @@ def setup_admin_wipe_routes(session_manager):
|
||||
db.query(GalleryAlbum).delete()
|
||||
db.commit()
|
||||
# Also drop the upload dir so disk doesn't keep orphans.
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
|
||||
_rmtree_quiet(GALLERY_DIR)
|
||||
_rmtree_quiet(GALLERY_UPLOADS_DIR)
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "calendar":
|
||||
|
||||
@@ -155,22 +155,30 @@ def setup_api_token_routes() -> APIRouter:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
scope_list = _normalize_scopes(payload.get("scopes"))
|
||||
scopes_value = ",".join(scope_list)
|
||||
with get_db_session() as db:
|
||||
token = db.query(ApiToken).filter(ApiToken.id == token_id).first()
|
||||
if not token:
|
||||
raise HTTPException(404, "Token not found")
|
||||
if isinstance(payload.get("name"), str) and payload["name"].strip():
|
||||
token.name = payload["name"].strip()[:MAX_NAME_LEN]
|
||||
token.scopes = scopes_value
|
||||
# Only touch scopes when the caller actually sent them. A partial
|
||||
# update such as a rename ({"name": ...} with no "scopes" key) must
|
||||
# not silently reset the token to the default scope — that dropped
|
||||
# every previously granted scope.
|
||||
if "scopes" in payload:
|
||||
token.scopes = ",".join(_normalize_scopes(payload.get("scopes")))
|
||||
db.add(token)
|
||||
current_scopes = [
|
||||
s.strip()
|
||||
for s in (getattr(token, "scopes", "") or DEFAULT_SCOPES).split(",")
|
||||
if s.strip()
|
||||
]
|
||||
response = {
|
||||
"id": token_id,
|
||||
"name": getattr(token, "name", ""),
|
||||
"owner": getattr(token, "owner", None),
|
||||
"token_prefix": getattr(token, "token_prefix", ""),
|
||||
"scopes": scope_list,
|
||||
"scopes": current_scopes,
|
||||
}
|
||||
_invalidate_cache(request)
|
||||
return response
|
||||
|
||||
+23
-4
@@ -131,10 +131,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
return {"ok": False, "requires_totp": True, "username": username}
|
||||
if not auth_manager.totp_verify(username, body.totp_code):
|
||||
raise HTTPException(401, "Invalid 2FA code")
|
||||
# All checks passed — create session
|
||||
token = await asyncio.to_thread(auth_manager.create_session, username, body.password)
|
||||
if not token:
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
# All checks passed — create session (password already verified above)
|
||||
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
|
||||
cookie_kwargs = dict(
|
||||
key=SESSION_COOKIE,
|
||||
value=token,
|
||||
@@ -585,6 +583,27 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
|
||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
|
||||
|
||||
if preset == "discord_webhook":
|
||||
import httpx
|
||||
webhook_url = (integ.get("base_url") or "").strip()
|
||||
if not webhook_url:
|
||||
return {"ok": False, "message": "No webhook URL set — paste the full Discord webhook URL into the Base URL field."}
|
||||
payload = {
|
||||
"embeds": [{
|
||||
"title": "Odysseus connectivity test",
|
||||
"description": "If you see this, your Discord Webhook integration is wired up correctly.",
|
||||
"color": 5793266,
|
||||
}]
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.post(webhook_url, json=payload)
|
||||
if r.is_success:
|
||||
return {"ok": True, "message": "Test embed sent — check your Discord channel to confirm it arrived."}
|
||||
return {"ok": False, "message": f"Discord returned HTTP {r.status_code}: {r.text[:200]}"}
|
||||
except Exception as e:
|
||||
return {"ok": False, "message": f"Request failed: {e}"[:400]}
|
||||
|
||||
# All other presets: GET against a known health endpoint.
|
||||
# Fall back to detecting from name if preset is missing.
|
||||
health_paths = {
|
||||
|
||||
+56
-12
@@ -101,24 +101,68 @@ def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRo
|
||||
# ── Skills ──
|
||||
if "skills" in body and isinstance(body["skills"], list):
|
||||
existing = skills_manager.load_all()
|
||||
existing_ids = {s.get("id") for s in existing}
|
||||
existing_titles = {s.get("title", "").strip().lower() for s in existing}
|
||||
existing_names = {s.get("name") for s in existing if s.get("name")}
|
||||
existing_ids = {s.get("id") for s in existing if s.get("id")}
|
||||
existing_titles = {
|
||||
(s.get("title") or s.get("description") or "").strip().lower()
|
||||
for s in existing
|
||||
}
|
||||
added = 0
|
||||
for skill in body["skills"]:
|
||||
if not isinstance(skill, dict) or not skill.get("title"):
|
||||
if not isinstance(skill, dict):
|
||||
continue
|
||||
# Skip if same id or same title already exists
|
||||
if skill.get("id") in existing_ids:
|
||||
title = (
|
||||
skill.get("title") or skill.get("description")
|
||||
or skill.get("name") or ""
|
||||
).strip()
|
||||
if not title:
|
||||
continue
|
||||
if skill["title"].strip().lower() in existing_titles:
|
||||
sid = skill.get("id") or skill.get("name")
|
||||
if sid and sid in existing_ids:
|
||||
continue
|
||||
if user and not skill.get("owner"):
|
||||
skill["owner"] = user
|
||||
existing.append(skill)
|
||||
existing_ids.add(skill.get("id"))
|
||||
existing_titles.add(skill["title"].strip().lower())
|
||||
nm = skill.get("name")
|
||||
if nm and nm in existing_names:
|
||||
continue
|
||||
if title.lower() in existing_titles:
|
||||
continue
|
||||
owner = skill.get("owner")
|
||||
if user and not owner:
|
||||
owner = user
|
||||
# Skills live on disk as SKILL.md files; the old JSON-era
|
||||
# skills_manager.save() no longer exists. Write each new skill
|
||||
# via add_skill (source="user" skips auto-dedup — this is an
|
||||
# explicit backup restore).
|
||||
result = skills_manager.add_skill(
|
||||
title=title,
|
||||
name=skill.get("name"),
|
||||
description=skill.get("description"),
|
||||
problem=skill.get("problem", ""),
|
||||
solution=skill.get("solution", ""),
|
||||
steps=skill.get("steps"),
|
||||
tags=skill.get("tags"),
|
||||
source="user",
|
||||
teacher_model=skill.get("teacher_model"),
|
||||
confidence=skill.get("confidence", 0.8),
|
||||
owner=owner,
|
||||
category=skill.get("category", "general"),
|
||||
when_to_use=skill.get("when_to_use"),
|
||||
procedure=skill.get("procedure"),
|
||||
pitfalls=skill.get("pitfalls"),
|
||||
verification=skill.get("verification"),
|
||||
platforms=skill.get("platforms"),
|
||||
requires_toolsets=skill.get("requires_toolsets"),
|
||||
fallback_for_toolsets=skill.get("fallback_for_toolsets"),
|
||||
status=skill.get("status", "draft"),
|
||||
version=skill.get("version", "1.0.0"),
|
||||
)
|
||||
if result.get("_deduped"):
|
||||
continue
|
||||
if result.get("name"):
|
||||
existing_names.add(result["name"])
|
||||
if result.get("id"):
|
||||
existing_ids.add(result["id"])
|
||||
existing_titles.add(title.lower())
|
||||
added += 1
|
||||
skills_manager.save(existing)
|
||||
imported.append(f"{added} skills")
|
||||
|
||||
# ── Presets ──
|
||||
|
||||
+254
-69
@@ -1,6 +1,7 @@
|
||||
"""Calendar routes — local SQLite-backed calendar CRUD."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional, List
|
||||
@@ -12,7 +13,7 @@ from dateutil.rrule import rrulestr
|
||||
|
||||
from core.database import SessionLocal, CalendarCal, CalendarEvent
|
||||
from src.auth_helpers import require_user
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, ICS_MAX_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,6 +101,15 @@ def _ics_escape(text: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _safe_ics_filename(name: str) -> str:
|
||||
"""Return a conservative .ics filename safe for Content-Disposition."""
|
||||
stem = name if isinstance(name, str) else ""
|
||||
stem = re.sub(r"[^A-Za-z0-9._-]", "_", stem).strip("._-")
|
||||
if not stem:
|
||||
stem = "calendar"
|
||||
return f"{stem[:128]}.ics"
|
||||
|
||||
|
||||
def _resolve_base_uid(uid: str) -> str:
|
||||
"""Extract the base series UID from a compound occurrence UID.
|
||||
|
||||
@@ -248,6 +258,17 @@ def parse_due_for_user(s: str) -> str:
|
||||
if t is not None:
|
||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||
|
||||
# Time-first: "3pm today", "11pm today", "9am tomorrow"
|
||||
m = _re.match(r'^(.+?)\s+(today|tonight|tomorrow|tmrw|yesterday)$', lower)
|
||||
if m:
|
||||
time_part, word = m.group(1).strip(), m.group(2)
|
||||
base = today
|
||||
if word in ("tomorrow", "tmrw"): base = today + _td(days=1)
|
||||
elif word == "yesterday": base = today - _td(days=1)
|
||||
t = _parse_time(time_part)
|
||||
if t is not None:
|
||||
return base.replace(hour=t[0], minute=t[1]).isoformat()
|
||||
|
||||
m = _re.match(r'^in\s+(\d+)\s*(hour|hr|minute|min|day)s?\s*$', lower)
|
||||
if m:
|
||||
n = int(m.group(1)); unit = m.group(2)
|
||||
@@ -399,7 +420,17 @@ def _parse_dt(s: str) -> datetime:
|
||||
# Last resort: dateutil's fuzzy parser
|
||||
try:
|
||||
from dateutil import parser as _du
|
||||
return _du.parse(s)
|
||||
parsed = _du.parse(s)
|
||||
# Strip tz like every other return path above — this function's
|
||||
# contract is naive datetimes (CalendarEvent.dtstart is naive). An
|
||||
# offset-bearing non-ISO input (e.g. RFC-2822 "Mon, 05 Jan 2026
|
||||
# 14:00:00 +0900") otherwise leaked tz-aware into the naive column and
|
||||
# crashed read-back comparisons in _expand_rrule with "can't compare
|
||||
# offset-naive and offset-aware datetimes".
|
||||
if parsed.tzinfo is not None:
|
||||
from datetime import timezone as _tz
|
||||
return parsed.astimezone(_tz.utc).replace(tzinfo=None)
|
||||
return parsed
|
||||
except Exception:
|
||||
raise ValueError(f"could not parse datetime: {s!r}")
|
||||
|
||||
@@ -440,6 +471,9 @@ def _event_to_dict(ev: CalendarEvent) -> dict:
|
||||
|
||||
# ── Recurrence expansion ──
|
||||
|
||||
_RRULE_EXPANSION_LIMIT = 1000
|
||||
|
||||
|
||||
def _expand_rrule(
|
||||
ev: CalendarEvent, start: datetime, end: datetime
|
||||
) -> List[dict]:
|
||||
@@ -462,6 +496,7 @@ def _expand_rrule(
|
||||
d = _event_to_dict(ev)
|
||||
d["is_recurrence"] = False
|
||||
d["series_uid"] = ev.uid
|
||||
d["truncated"] = False
|
||||
return [d]
|
||||
|
||||
# Parse the rrule, applying it to the base dtstart.
|
||||
@@ -487,6 +522,7 @@ def _expand_rrule(
|
||||
d = _event_to_dict(ev)
|
||||
d["is_recurrence"] = False
|
||||
d["series_uid"] = ev.uid
|
||||
d["truncated"] = False
|
||||
# Malformed RRULE rows are fetched by the recurring SQL branch
|
||||
# with only dtstart < end_dt — the base event may not actually
|
||||
# overlap the window. Only return if it does.
|
||||
@@ -499,22 +535,26 @@ def _expand_rrule(
|
||||
# (matching non-recurring overlap semantics: dtstart < end AND
|
||||
# dtend > start).
|
||||
expand_start = start - duration
|
||||
occurrences = rule.between(expand_start, end, inc=True)
|
||||
if not occurrences:
|
||||
return []
|
||||
|
||||
results = []
|
||||
truncated = False
|
||||
base = _event_to_dict(ev)
|
||||
|
||||
for occ_start in occurrences:
|
||||
for occ_start in rule.xafter(expand_start, inc=True):
|
||||
if occ_start >= end:
|
||||
break
|
||||
|
||||
occ_end = occ_start + duration
|
||||
|
||||
# Overlap filter: occurrence must intersect [start, end).
|
||||
# This enforces exclusive-end semantics (occ_start >= end is
|
||||
# excluded) and includes multi-day crossings (occ_end > start).
|
||||
if occ_start >= end or occ_end <= start:
|
||||
if occ_end <= start:
|
||||
continue
|
||||
|
||||
if len(results) >= _RRULE_EXPANSION_LIMIT:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
# Build the compound uid: {base_uid}::{date} or ::{datetime}
|
||||
if ev.all_day:
|
||||
occ_uid = f"{ev.uid}::{occ_start.strftime('%Y-%m-%d')}"
|
||||
@@ -525,6 +565,7 @@ def _expand_rrule(
|
||||
d["uid"] = occ_uid
|
||||
d["series_uid"] = ev.uid
|
||||
d["is_recurrence"] = True
|
||||
d["truncated"] = False
|
||||
|
||||
if ev.all_day:
|
||||
d["dtstart"] = occ_start.strftime("%Y-%m-%d")
|
||||
@@ -537,6 +578,10 @@ def _expand_rrule(
|
||||
|
||||
results.append(d)
|
||||
|
||||
if truncated:
|
||||
for d in results:
|
||||
d["truncated"] = True
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -545,72 +590,178 @@ def _expand_rrule(
|
||||
def setup_calendar_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/calendar", tags=["calendar"])
|
||||
|
||||
# CalDAV connect form (Integrations → Calendar). Storage is local
|
||||
# SQLite; sync (src/caldav_sync.py) pulls remote events into it on
|
||||
# calendar open and periodically via the scheduler.
|
||||
# ── CalDAV multi-account helpers ─────────────────────────────────────────
|
||||
|
||||
def _get_caldav_accounts(owner: str) -> list:
|
||||
from src.caldav_sync import _load_caldav_accounts
|
||||
return _load_caldav_accounts(owner)
|
||||
|
||||
def _save_caldav_accounts(owner: str, accounts: list) -> None:
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
prefs = _load_for_user(owner) or {}
|
||||
prefs["caldav_accounts"] = accounts
|
||||
prefs.pop("caldav", None)
|
||||
_save_for_user(owner, prefs)
|
||||
|
||||
# ── CalDAV config routes (backward-compat single-account API) ────────────
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config(request: Request):
|
||||
"""Legacy single-account endpoint — returns the first configured account."""
|
||||
owner = _require_user(request)
|
||||
from routes.prefs_routes import _load_for_user
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
caldav_password = cfg.get("password") or ""
|
||||
if caldav_password:
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
if not accounts:
|
||||
return {"url": "", "username": "", "password": "", "has_password": False, "local": True}
|
||||
first = accounts[0]
|
||||
pw = first.get("password") or ""
|
||||
has_pw = False
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
caldav_password = decrypt(caldav_password)
|
||||
has_pw = bool(decrypt(pw))
|
||||
except Exception:
|
||||
pass
|
||||
# Surface url+username but never hand the password back to the
|
||||
# client — saved-state UI shouldn't leak the credential.
|
||||
has_pw = bool(pw)
|
||||
return {
|
||||
"url": cfg.get("url", "") or "",
|
||||
"username": cfg.get("username", "") or "",
|
||||
"url": first.get("url", "") or "",
|
||||
"username": first.get("username", "") or "",
|
||||
"password": "",
|
||||
"has_password": bool(caldav_password),
|
||||
"local": not bool(cfg.get("url")),
|
||||
"has_password": has_pw,
|
||||
"local": not bool(first.get("url")),
|
||||
}
|
||||
|
||||
@router.post("/config")
|
||||
async def save_config(request: Request):
|
||||
"""Legacy single-account endpoint — upserts the first account."""
|
||||
owner = _require_user(request)
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
prefs = _load_for_user(owner) or {}
|
||||
cfg = dict(prefs.get("caldav") or {})
|
||||
# Empty url => clear the whole entry (treat as "remove integration").
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
if not (body.get("url") or "").strip():
|
||||
prefs.pop("caldav", None)
|
||||
_save_for_user(owner, prefs)
|
||||
_save_caldav_accounts(owner, [])
|
||||
return {"ok": True, "cleared": True}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
cfg["url"] = validate_caldav_url(body.get("url", ""))
|
||||
validated_url = validate_caldav_url(body.get("url", ""))
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
cfg["username"] = (body.get("username") or "").strip()
|
||||
# Preserve the stored password when the client sends an empty
|
||||
# one (edit form re-submitted without re-typing the password).
|
||||
# cfg already holds the existing (already-encrypted) password from
|
||||
# prefs, so we only touch it when a new password is supplied —
|
||||
# re-encrypting the stored value would double-encrypt it.
|
||||
if accounts:
|
||||
acc = dict(accounts[0])
|
||||
else:
|
||||
import uuid as _uuid
|
||||
acc = {"id": str(_uuid.uuid4()), "label": "CalDAV"}
|
||||
acc["url"] = validated_url
|
||||
acc["username"] = (body.get("username") or "").strip()
|
||||
if body.get("password"):
|
||||
from src.secret_storage import encrypt
|
||||
cfg["password"] = encrypt(body["password"])
|
||||
prefs["caldav"] = cfg
|
||||
_save_for_user(owner, prefs)
|
||||
acc["password"] = encrypt(body["password"])
|
||||
new_accounts = [acc] + (accounts[1:] if len(accounts) > 1 else [])
|
||||
_save_caldav_accounts(owner, new_accounts)
|
||||
return {"ok": True}
|
||||
|
||||
# ── CalDAV multi-account CRUD ─────────────────────────────────────────────
|
||||
|
||||
@router.get("/config/accounts")
|
||||
async def list_caldav_accounts(request: Request):
|
||||
"""Return all configured CalDAV accounts (passwords never returned)."""
|
||||
owner = _require_user(request)
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
safe = []
|
||||
for acc in accounts:
|
||||
pw = acc.get("password") or ""
|
||||
has_pw = False
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
has_pw = bool(decrypt(pw))
|
||||
except Exception:
|
||||
has_pw = bool(pw)
|
||||
safe.append({
|
||||
"id": acc.get("id", ""),
|
||||
"label": acc.get("label", "") or acc.get("url", ""),
|
||||
"url": acc.get("url", "") or "",
|
||||
"username": acc.get("username", "") or "",
|
||||
"has_password": has_pw,
|
||||
})
|
||||
return {"accounts": safe}
|
||||
|
||||
@router.post("/config/accounts")
|
||||
async def add_caldav_account(request: Request):
|
||||
"""Add a new CalDAV account."""
|
||||
import uuid as _uuid
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
url = validate_caldav_url(body.get("url", ""))
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if not body.get("password"):
|
||||
raise HTTPException(400, "Password is required")
|
||||
from src.secret_storage import encrypt
|
||||
new_acc = {
|
||||
"id": str(_uuid.uuid4()),
|
||||
"label": (body.get("label") or "").strip() or "CalDAV",
|
||||
"url": url,
|
||||
"username": (body.get("username") or "").strip(),
|
||||
"password": encrypt(body["password"]),
|
||||
}
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
accounts.append(new_acc)
|
||||
_save_caldav_accounts(owner, accounts)
|
||||
return {"ok": True, "id": new_acc["id"]}
|
||||
|
||||
@router.put("/config/accounts/{account_id}")
|
||||
async def update_caldav_account(account_id: str, request: Request):
|
||||
"""Update an existing CalDAV account by id."""
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
idx = next((i for i, a in enumerate(accounts) if a.get("id") == account_id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, "Account not found")
|
||||
acc = dict(accounts[idx])
|
||||
if body.get("url"):
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
acc["url"] = validate_caldav_url(body["url"])
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if body.get("label") is not None:
|
||||
acc["label"] = (body.get("label") or "").strip() or "CalDAV"
|
||||
if body.get("username") is not None:
|
||||
acc["username"] = (body.get("username") or "").strip()
|
||||
if body.get("password"):
|
||||
from src.secret_storage import encrypt
|
||||
acc["password"] = encrypt(body["password"])
|
||||
accounts[idx] = acc
|
||||
_save_caldav_accounts(owner, accounts)
|
||||
return {"ok": True}
|
||||
|
||||
@router.delete("/config/accounts/{account_id}")
|
||||
async def delete_caldav_account(account_id: str, request: Request):
|
||||
"""Remove a CalDAV account by id."""
|
||||
owner = _require_user(request)
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
new_accounts = [a for a in accounts if a.get("id") != account_id]
|
||||
if len(new_accounts) == len(accounts):
|
||||
raise HTTPException(404, "Account not found")
|
||||
_save_caldav_accounts(owner, new_accounts)
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/test")
|
||||
async def test_connection(request: Request):
|
||||
"""Actually probe the configured CalDAV server with a PROPFIND
|
||||
request (the same handshake every CalDAV client uses). Accepts
|
||||
an optional {url, username, password} body so the user can test
|
||||
a configuration BEFORE saving it; falls back to the stored
|
||||
creds otherwise. Returns {ok, error?} with a useful message on
|
||||
failure (status code, auth issue, network error)."""
|
||||
"""Probe a CalDAV server with a PROPFIND. Accepts an optional body:
|
||||
{url, username, password} to test before saving, or {account_id} to
|
||||
test an already-saved account. Falls back to the first saved account
|
||||
when nothing is provided."""
|
||||
owner = _require_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
@@ -620,19 +771,24 @@ def setup_calendar_routes() -> APIRouter:
|
||||
user = (body.get("username") or "").strip()
|
||||
pw = body.get("password") or ""
|
||||
if not (url and user and pw):
|
||||
# Fall back to saved settings for this user.
|
||||
from routes.prefs_routes import _load_for_user
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
url = url or (cfg.get("url") or "")
|
||||
user = user or (cfg.get("username") or "")
|
||||
if not pw:
|
||||
pw = cfg.get("password") or ""
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
# Look up a saved account: by id if supplied, else first account.
|
||||
accounts = _get_caldav_accounts(owner)
|
||||
acc = None
|
||||
if body.get("account_id"):
|
||||
acc = next((a for a in accounts if a.get("id") == body["account_id"]), None)
|
||||
if acc is None and accounts:
|
||||
acc = accounts[0]
|
||||
if acc:
|
||||
url = url or (acc.get("url") or "")
|
||||
user = user or (acc.get("username") or "")
|
||||
if not pw:
|
||||
pw = acc.get("password") or ""
|
||||
if pw:
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
if not (url and user and pw):
|
||||
return {"ok": False, "error": "Missing URL, username, or password"}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
@@ -695,6 +851,28 @@ def setup_calendar_routes() -> APIRouter:
|
||||
from src.caldav_sync import sync_caldav
|
||||
return await sync_caldav(owner)
|
||||
|
||||
@router.delete("/calendars/{cal_id}")
|
||||
async def delete_calendar(cal_id: str, request: Request):
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cal = db.query(CalendarCal).filter(
|
||||
CalendarCal.id == cal_id,
|
||||
CalendarCal.owner == owner,
|
||||
).first()
|
||||
if not cal:
|
||||
raise HTTPException(404, "Calendar not found")
|
||||
db.delete(cal)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete calendar %s: %s", cal_id, e)
|
||||
raise HTTPException(500, "Failed to delete calendar")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/calendars")
|
||||
async def list_calendars(request: Request):
|
||||
owner = _require_user(request)
|
||||
@@ -703,7 +881,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
_ensure_default_calendar(db, owner)
|
||||
cals = db.query(CalendarCal).filter(CalendarCal.owner == owner).all()
|
||||
return {"calendars": [
|
||||
{"name": c.name, "href": c.id, "color": c.color}
|
||||
{"name": c.name, "href": c.id, "color": c.color, "source": c.source}
|
||||
for c in cals
|
||||
]}
|
||||
except HTTPException:
|
||||
@@ -766,8 +944,12 @@ def setup_calendar_routes() -> APIRouter:
|
||||
expanded.extend(_expand_rrule(e, start_dt, end_dt))
|
||||
|
||||
# Sort by occurrence start time for consistent frontend ordering.
|
||||
truncated = any(e.get("truncated") for e in expanded)
|
||||
expanded.sort(key=lambda d: d["dtstart"])
|
||||
return {"events": expanded}
|
||||
response: dict = {"events": expanded}
|
||||
if truncated:
|
||||
response["truncated"] = True
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -988,9 +1170,9 @@ def setup_calendar_routes() -> APIRouter:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 10 MB hard cap on ICS upload. Loading the whole file into memory is
|
||||
# unavoidable with python-icalendar, so an unbounded upload would OOM.
|
||||
_ICS_MAX_BYTES = 10 * 1024 * 1024
|
||||
# Hard cap on ICS upload (ICS_MAX_BYTES, default 10 MB). Loading the whole
|
||||
# file into memory is unavoidable with python-icalendar, so an unbounded
|
||||
# upload would OOM.
|
||||
|
||||
@router.post("/import")
|
||||
async def import_ics(request: Request, file: UploadFile = File(...), calendar_name: str = ""):
|
||||
@@ -1000,7 +1182,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
owner = _require_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
content = await read_upload_limited(file, _ICS_MAX_BYTES, "ICS file")
|
||||
content = await read_upload_limited(file, ICS_MAX_BYTES, "ICS file")
|
||||
try:
|
||||
cal_data = iCal.from_ical(content)
|
||||
except Exception as e:
|
||||
@@ -1168,11 +1350,14 @@ def setup_calendar_routes() -> APIRouter:
|
||||
lines.append("END:VCALENDAR")
|
||||
|
||||
ics_data = "\r\n".join(lines)
|
||||
safe_name = cal.name.replace(" ", "_").replace("/", "_")
|
||||
download_name = _safe_ics_filename(cal.name)
|
||||
return Response(
|
||||
content=ics_data,
|
||||
media_type="text/calendar",
|
||||
headers={"Content-Disposition": f'attachment; filename="{safe_name}.ics"'},
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_name}"',
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
},
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1194,7 +1379,7 @@ def setup_calendar_routes() -> APIRouter:
|
||||
"tomorrow", "next Tuesday", "in 30 minutes" resolve correctly.
|
||||
Uses the "utility" endpoint (small / fast model) to keep latency low.
|
||||
"""
|
||||
_require_user(request)
|
||||
owner = _require_user(request)
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
from src.text_helpers import strip_think
|
||||
@@ -1220,9 +1405,9 @@ def setup_calendar_routes() -> APIRouter:
|
||||
if tz_hint:
|
||||
set_user_tz_name(tz_hint)
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||
if not url or not model:
|
||||
return {"ok": False, "error": "No LLM endpoint configured"}
|
||||
|
||||
|
||||
+130
-38
@@ -75,7 +75,7 @@ def _enforce_chat_privileges(request, sess) -> None:
|
||||
allowlist, or HTTPException(429) if the user has hit their daily message
|
||||
cap. No-op for unauthenticated callers or when auth_manager is absent
|
||||
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
||||
which means empty allowed_models / zero cap → no-op for them.
|
||||
which means unrestricted allowed_models / zero cap -> no-op for them.
|
||||
"""
|
||||
try:
|
||||
user = get_current_user(request)
|
||||
@@ -88,8 +88,18 @@ def _enforce_chat_privileges(request, sess) -> None:
|
||||
return
|
||||
|
||||
privs = auth_manager.get_privileges(user) or {}
|
||||
allowed = privs.get("allowed_models") or []
|
||||
if allowed and sess.model and sess.model not in allowed:
|
||||
|
||||
# Explicit "block everything" sentinel takes precedence over the
|
||||
# allowlist — it's the only way to distinguish "user clicked [None]"
|
||||
# (block all) from "user clicked [All]" (no restriction), since both
|
||||
# otherwise produce an empty `allowed_models` list.
|
||||
if privs.get("block_all_models"):
|
||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||
|
||||
allowed_raw = privs.get("allowed_models")
|
||||
allowed = allowed_raw if isinstance(allowed_raw, list) else []
|
||||
restricted = bool(privs.get("allowed_models_restricted")) or bool(allowed)
|
||||
if restricted and sess.model and sess.model not in allowed:
|
||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||
|
||||
cap = int(privs.get("max_messages_per_day") or 0)
|
||||
@@ -194,14 +204,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||
"""
|
||||
import requests as _req
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
||||
from src.endpoint_resolver import (
|
||||
build_chat_url,
|
||||
build_headers,
|
||||
build_models_url,
|
||||
normalize_base,
|
||||
resolve_endpoint_runtime,
|
||||
)
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
|
||||
current_url = sess.endpoint_url or ""
|
||||
owner = getattr(sess, "owner", None)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(
|
||||
q = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True
|
||||
).all()
|
||||
)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -210,26 +232,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
# Skip current endpoint
|
||||
if current_url and base in current_url:
|
||||
continue
|
||||
# Quick ping
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
if ping_url:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
models = json.loads(ep.cached_models or "[]")
|
||||
if not models:
|
||||
continue
|
||||
# Found a working endpoint — update session
|
||||
new_model = models[0]
|
||||
chat_url = build_chat_url(base)
|
||||
new_headers = build_headers(ep.api_key, base)
|
||||
new_headers = build_headers(api_key, base)
|
||||
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||
|
||||
sess.model = new_model
|
||||
sess.endpoint_url = chat_url
|
||||
@@ -241,7 +270,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||
"model": new_model,
|
||||
"endpoint_url": chat_url,
|
||||
"headers": json.dumps(new_headers),
|
||||
"headers": persisted_headers,
|
||||
})
|
||||
_db.commit()
|
||||
finally:
|
||||
@@ -275,11 +304,16 @@ def extract_preset(chat_handler, preset_id) -> PresetInfo:
|
||||
async def preprocess(
|
||||
chat_handler, message, att_ids, sess,
|
||||
auto_opened_docs: Optional[list] = None,
|
||||
allow_tool_preprocessing: bool = True,
|
||||
) -> PreprocessedMessage:
|
||||
"""Run chat_handler.preprocess_message and wrap the result."""
|
||||
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
||||
await chat_handler.preprocess_message(
|
||||
message, att_ids, sess, auto_opened_docs=auto_opened_docs
|
||||
message,
|
||||
att_ids,
|
||||
sess,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
)
|
||||
return PreprocessedMessage(
|
||||
@@ -329,16 +363,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _has_auth_keys(headers) -> bool:
|
||||
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||
return isinstance(headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||
)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
)
|
||||
if has_auth:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
has_auth = _has_auth_keys(sess.headers)
|
||||
if has_auth and not is_chatgpt_subscription:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||
db = SessionLocal()
|
||||
try:
|
||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||
@@ -354,10 +398,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
for ep in q.all():
|
||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||
continue
|
||||
if not ep.api_key:
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||
return
|
||||
if not api_key:
|
||||
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||
return
|
||||
sess.headers = build_headers(api_key, base)
|
||||
if is_chatgpt_subscription:
|
||||
# The bearer is short-lived and re-resolved per request, so it
|
||||
# stays request-local and is never written to the plaintext
|
||||
# sessions.headers column. Proactively strip any bearer an
|
||||
# older code path may have persisted so it does not linger.
|
||||
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||
stored = stale_q.first()
|
||||
if stored is not None and _has_auth_keys(stored.headers):
|
||||
stale_q.update({"headers": {}})
|
||||
db.commit()
|
||||
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||
return
|
||||
base = normalize_base(ep.base_url or "")
|
||||
sess.headers = build_headers(ep.api_key, base)
|
||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
update_q = update_q.filter(DBSession.owner == owner)
|
||||
@@ -401,7 +465,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
owner = getattr(sess, "owner", None)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for ep in endpoints:
|
||||
try:
|
||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||
@@ -448,6 +517,7 @@ async def build_chat_context(
|
||||
webhook_manager=None,
|
||||
use_enhanced_message: bool = False,
|
||||
agent_mode: bool = False,
|
||||
allow_tool_preprocessing: bool = True,
|
||||
) -> ChatContext:
|
||||
"""Build the full context (preface + messages) for an LLM call.
|
||||
|
||||
@@ -465,6 +535,7 @@ async def build_chat_context(
|
||||
preprocessed = await preprocess(
|
||||
chat_handler, message, att_ids or [], sess,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
# Add user message to history
|
||||
@@ -483,6 +554,9 @@ async def build_chat_context(
|
||||
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
||||
# When off, the "Available skills" index is not added to the prompt.
|
||||
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
||||
if not allow_tool_preprocessing:
|
||||
mem_enabled = False
|
||||
skills_enabled = False
|
||||
logger.debug(
|
||||
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
||||
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
||||
@@ -490,11 +564,11 @@ async def build_chat_context(
|
||||
|
||||
# Use RAG?
|
||||
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
||||
if incognito:
|
||||
if incognito or not allow_tool_preprocessing:
|
||||
use_rag_val = False
|
||||
|
||||
# If pre-fetched search context was provided (compare mode), skip live web search
|
||||
skip_web = bool(search_context)
|
||||
skip_web = bool(search_context) or not allow_tool_preprocessing
|
||||
|
||||
# Build context preface
|
||||
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
||||
@@ -521,7 +595,7 @@ async def build_chat_context(
|
||||
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
||||
|
||||
# Inject pre-fetched search context (compare mode)
|
||||
if search_context:
|
||||
if search_context and allow_tool_preprocessing:
|
||||
preface.append(untrusted_context_message("prefetched search context", search_context))
|
||||
|
||||
# YouTube transcripts
|
||||
@@ -530,7 +604,11 @@ async def build_chat_context(
|
||||
|
||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||
# re-hit slow local /models endpoints on every participant turn.
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||
sess.endpoint_url,
|
||||
sess.model,
|
||||
owner=getattr(sess, "owner", None),
|
||||
)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
@@ -539,7 +617,7 @@ async def build_chat_context(
|
||||
|
||||
# Auto-compact
|
||||
messages, context_length, was_compacted = await maybe_compact(
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers, owner=user,
|
||||
)
|
||||
messages = trim_for_context(messages, context_length)
|
||||
|
||||
@@ -772,7 +850,19 @@ def save_assistant_response(
|
||||
):
|
||||
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
||||
md = dict(last_metrics) if last_metrics else {}
|
||||
md["model"] = sess.model
|
||||
def _model_value(value) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if not isinstance(value, str):
|
||||
value = str(value)
|
||||
return value.strip()
|
||||
|
||||
requested_model = _model_value(md.get("requested_model") or md.get("selected_model") or getattr(sess, "model", ""))
|
||||
actual_model = _model_value(md.get("model") or md.get("actual_model") or requested_model)
|
||||
if requested_model:
|
||||
md["requested_model"] = requested_model
|
||||
if actual_model:
|
||||
md["model"] = actual_model
|
||||
if character_name:
|
||||
md["character_name"] = character_name
|
||||
if web_sources:
|
||||
@@ -841,12 +931,13 @@ def run_post_response_tasks(
|
||||
skills_manager=None,
|
||||
owner: str = None,
|
||||
extract_skills: bool = True,
|
||||
allow_background_extraction: bool = True,
|
||||
):
|
||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||
if allow_background_extraction and not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||
from services.memory.memory_extractor import extract_and_store
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
@@ -873,6 +964,7 @@ def run_post_response_tasks(
|
||||
)
|
||||
if (
|
||||
extract_skills
|
||||
and allow_background_extraction
|
||||
and auto_skills_enabled
|
||||
and not incognito
|
||||
and not compare_mode
|
||||
|
||||
+206
-72
@@ -20,6 +20,7 @@ from src import agent_runs
|
||||
from src.model_context import estimate_tokens
|
||||
from src.chat_helpers import coerce_message_and_session
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||
from src.session_search import search_session_messages
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
@@ -39,6 +40,7 @@ from routes.chat_helpers import (
|
||||
_enforce_chat_privileges,
|
||||
)
|
||||
from src.action_intents import classify_tool_intent as _classify_tool_intent
|
||||
from src.tool_policy import build_effective_tool_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -167,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
Covers the window between endpoint setup and the first chat send: the
|
||||
picker showed a model in the dropdown but the session record never got
|
||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||
Without this, we'd POST the upstream with model="" and get a generic
|
||||
401/503 instead of using the model the user already picked.
|
||||
|
||||
Returns True iff sess.model was repaired.
|
||||
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||
"""
|
||||
if getattr(sess, "model", None):
|
||||
return False
|
||||
current_model = (getattr(sess, "model", "") or "").strip()
|
||||
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||
is_chatgpt_subscription = False
|
||||
if current_model:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||
if not is_chatgpt_subscription:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Prefer the endpoint whose base URL matches the session — we know the
|
||||
@@ -192,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
break
|
||||
if not ep:
|
||||
return False
|
||||
if not is_chatgpt_subscription:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||
except Exception:
|
||||
cached = []
|
||||
if not cached:
|
||||
visible = []
|
||||
else:
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if is_chatgpt_subscription:
|
||||
live_models = []
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
if api_key:
|
||||
live_models = fetch_available_models(api_key)
|
||||
if live_models:
|
||||
ep.cached_models = json.dumps(live_models)
|
||||
db.commit()
|
||||
except Exception:
|
||||
live_models = []
|
||||
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||
# Cached rows are only trusted above to avoid revalidating a model
|
||||
# that is already present in the visible picker list.
|
||||
cached = live_models
|
||||
if not cached:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
if not visible:
|
||||
return False
|
||||
model = visible[0]
|
||||
@@ -211,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
# Persist so the next request, websocket reconnect, or page reload
|
||||
# picks up the same model (we'd otherwise re-pick on every send
|
||||
# and silently switch on the user if the cached order shifts).
|
||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
||||
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||
db_session = db_session_q.first()
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.model = model
|
||||
logger.info(
|
||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
||||
"Recovered session model for %s — picked %r from endpoint %s",
|
||||
session_id, model, ep.id,
|
||||
)
|
||||
return True
|
||||
@@ -304,8 +351,13 @@ def setup_chat_routes(
|
||||
# non-streaming path can't be used to bypass).
|
||||
_enforce_chat_privileges(request, sess)
|
||||
|
||||
tool_policy = build_effective_tool_policy(last_user_message=message)
|
||||
allow_tool_preprocessing = not tool_policy.block_all_tool_calls
|
||||
|
||||
# Inline memory command
|
||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||
memory_response = None
|
||||
if not tool_policy.blocks("manage_memory"):
|
||||
memory_response = await chat_handler.handle_memory_command(sess, message)
|
||||
if memory_response:
|
||||
return {"response": memory_response}
|
||||
|
||||
@@ -319,10 +371,15 @@ def setup_chat_routes(
|
||||
use_web=use_web,
|
||||
time_filter=time_filter,
|
||||
webhook_manager=webhook_manager,
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
# Research injection
|
||||
if use_research:
|
||||
research_blocked_by_policy = (
|
||||
tool_policy.blocks("trigger_research")
|
||||
or tool_policy.blocks("manage_research")
|
||||
)
|
||||
if use_research and not research_blocked_by_policy:
|
||||
try:
|
||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||
research_ctx = await research_handler.call_research_service(
|
||||
@@ -357,6 +414,7 @@ def setup_chat_routes(
|
||||
ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||
character_name=ctx.preset.character_name,
|
||||
owner=ctx.user,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
|
||||
return {"response": reply}
|
||||
@@ -394,6 +452,7 @@ def setup_chat_routes(
|
||||
search_context = form_data.get("search_context") # pre-fetched web search results (compare mode)
|
||||
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
|
||||
incognito = str(form_data.get("incognito", "")).lower() == "true"
|
||||
plan_mode = str(form_data.get("plan_mode", "")).lower() == "true"
|
||||
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
|
||||
# Workspace: confine the agent's file/shell tools to this folder. Validate
|
||||
# it's a real directory; ignore (no confinement) otherwise.
|
||||
@@ -401,6 +460,17 @@ def setup_chat_routes(
|
||||
if workspace:
|
||||
_ws_real = os.path.realpath(os.path.expanduser(workspace))
|
||||
workspace = _ws_real if os.path.isdir(_ws_real) else ""
|
||||
# Plan mode is a modifier on agent mode — it only makes sense with tools.
|
||||
if plan_mode:
|
||||
chat_mode = "agent"
|
||||
# An approved plan being EXECUTED: the frontend sends the checklist back
|
||||
# on each turn so we can pin it in context. This way a long plan on a
|
||||
# weak model survives history truncation — the agent can always re-read
|
||||
# the plan. Ignored while still proposing (plan_mode on). Capped so a
|
||||
# huge plan can't blow the prompt.
|
||||
approved_plan = ""
|
||||
if not plan_mode:
|
||||
approved_plan = (form_data.get("approved_plan") or "").strip()[:8192]
|
||||
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
|
||||
# below). Skill extraction should only learn from real agent sessions,
|
||||
# not chats we quietly promoted for a notes/calendar intent.
|
||||
@@ -479,11 +549,6 @@ def setup_chat_routes(
|
||||
do_research = True
|
||||
logger.info(f"Session {session} in research_pending — auto-triggering research")
|
||||
|
||||
# Persist session mode (research > agent > chat)
|
||||
_effective_mode = 'research' if do_research else (chat_mode or 'chat')
|
||||
if _effective_mode in ('agent', 'research', 'chat'):
|
||||
set_session_mode(session, _effective_mode)
|
||||
|
||||
att_ids = []
|
||||
if body and isinstance(body.get("attachments"), list):
|
||||
att_ids = [str(x) for x in body["attachments"]]
|
||||
@@ -494,6 +559,10 @@ def setup_chat_routes(
|
||||
pass
|
||||
|
||||
no_memory = str(form_data.get("no_memory", "")).lower() == "true"
|
||||
pre_context_tool_policy = build_effective_tool_policy(
|
||||
last_user_message=message,
|
||||
)
|
||||
allow_tool_preprocessing = not pre_context_tool_policy.block_all_tool_calls
|
||||
|
||||
# Build shared context (stream path uses enhanced_message for context preface)
|
||||
ctx = await build_chat_context(
|
||||
@@ -515,6 +584,7 @@ def setup_chat_routes(
|
||||
# manage_skills (agent mode). In plain chat or incognito the
|
||||
# index would be useless / unwanted noise.
|
||||
agent_mode=(chat_mode == "agent"),
|
||||
allow_tool_preprocessing=allow_tool_preprocessing,
|
||||
)
|
||||
|
||||
_research_flags = {"do": do_research} # Mutable container for generator scope
|
||||
@@ -659,6 +729,32 @@ def setup_chat_routes(
|
||||
if chat_mode == 'chat':
|
||||
disabled_tools.update({"bash", "python", "read_file", "write_file", "web_search", "web_fetch", "search_chats", "manage_tasks"})
|
||||
|
||||
# Plan mode: investigate read-only, propose a plan, don't mutate. Block
|
||||
# every tool not on the read-only allowlist. (stream_agent_loop enforces
|
||||
# this again + drops MCP, so this is belt-and-suspenders.)
|
||||
if plan_mode:
|
||||
from src.tool_security import plan_mode_disabled_tools
|
||||
disabled_tools.update(plan_mode_disabled_tools())
|
||||
|
||||
tool_policy = build_effective_tool_policy(
|
||||
disabled_tools=disabled_tools,
|
||||
last_user_message=message,
|
||||
)
|
||||
disabled_tools = tool_policy.all_disabled_names()
|
||||
research_blocked_by_policy = bool(
|
||||
tool_policy.blocks("trigger_research")
|
||||
or tool_policy.blocks("manage_research")
|
||||
)
|
||||
effective_do_research = bool(
|
||||
do_research and _research_flags["do"] and not research_blocked_by_policy
|
||||
)
|
||||
|
||||
# Persist session mode after policy/privilege gates so blocked research
|
||||
# turns remain ordinary chat/agent streams and saved messages.
|
||||
_effective_mode = 'research' if effective_do_research else (chat_mode or 'chat')
|
||||
if _effective_mode in ('agent', 'research', 'chat'):
|
||||
set_session_mode(session, _effective_mode)
|
||||
|
||||
async def stream_with_save() -> AsyncGenerator[str, None]:
|
||||
# _effective_mode is read-only here; closure captures it from
|
||||
# the outer scope. (Was `nonlocal` but never reassigned.)
|
||||
@@ -666,7 +762,7 @@ def setup_chat_routes(
|
||||
web_sources = ctx.web_sources
|
||||
|
||||
# Register active stream for partial-save safety net
|
||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": do_research, "mode": _effective_mode}
|
||||
_active_streams[session] = {"status": "streaming", "partial": "", "query": message, "is_research": effective_do_research, "mode": _effective_mode}
|
||||
|
||||
if ctx.preprocessed.attachment_meta:
|
||||
yield f"data: {json.dumps({'type': 'attachments', 'data': ctx.preprocessed.attachment_meta})}\n\n"
|
||||
@@ -690,7 +786,7 @@ def setup_chat_routes(
|
||||
yield f"data: {json.dumps({'type': 'memories_used', 'data': ctx.used_memories})}\n\n"
|
||||
|
||||
# Run research as a background task (survives page refresh)
|
||||
if do_research and _research_flags["do"]:
|
||||
if effective_do_research:
|
||||
_r_ep, _r_model, _r_headers = _resolve_research_endpoint(sess)
|
||||
_auth_keys = list(_r_headers.keys()) if _r_headers else []
|
||||
logger.info(f"Research endpoint resolved: model={_r_model}, endpoint={_r_ep}, auth_keys={_auth_keys}, sess_headers_keys={list(sess.headers.keys()) if isinstance(sess.headers, dict) else type(sess.headers)}")
|
||||
@@ -829,7 +925,7 @@ def setup_chat_routes(
|
||||
_fallback_candidates = []
|
||||
|
||||
# Send model name early so the frontend can show it during streaming
|
||||
_model_suffix = "Research" if do_research else None
|
||||
_model_suffix = "Research" if effective_do_research else None
|
||||
_model_info = {"type": "model_info", "model": sess.model}
|
||||
if _model_suffix:
|
||||
_model_info["suffix"] = _model_suffix
|
||||
@@ -839,6 +935,12 @@ def setup_chat_routes(
|
||||
|
||||
if _is_image_generation_session(sess, owner=_user):
|
||||
from src.settings import get_setting
|
||||
if tool_policy.blocks("generate_image"):
|
||||
_blocked_msg = tool_policy.reason_for("generate_image")
|
||||
yield f'data: {json.dumps({"delta": _blocked_msg})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
_active_streams.pop(session, None)
|
||||
return
|
||||
if not get_setting("image_gen_enabled", True):
|
||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -873,6 +975,8 @@ def setup_chat_routes(
|
||||
elif chat_mode == "chat":
|
||||
_chat_start = time.time()
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
_requested_model = sess.model
|
||||
_actual_model = None
|
||||
# ── Chat mode: call stream_llm directly, NO tools, NO document access ──
|
||||
try:
|
||||
_chat_candidates = [(sess.endpoint_url, sess.model, sess.headers)] + _fallback_candidates
|
||||
@@ -905,10 +1009,18 @@ def setup_chat_routes(
|
||||
# Selected model failed; a fallback answered.
|
||||
# Forward the notice and remember the real model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
_actual_model = _actual_model or _answered_by
|
||||
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||
yield chunk
|
||||
elif data.get("type") == "model_actual":
|
||||
_actual_model = data.get("model") or _actual_model
|
||||
data["requested_model"] = _requested_model
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
elif data.get("type") == "usage":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
_reported_model = last_metrics.get("model")
|
||||
last_metrics["requested_model"] = _requested_model
|
||||
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||
if ctx.context_length and last_metrics.get("input_tokens"):
|
||||
pct = min(round((last_metrics["input_tokens"] / ctx.context_length) * 100, 1), 100.0)
|
||||
last_metrics["context_percent"] = pct
|
||||
@@ -945,7 +1057,8 @@ def setup_chat_routes(
|
||||
"tokens_per_second": _tps,
|
||||
"context_percent": _ctx_pct,
|
||||
"context_length": ctx.context_length,
|
||||
"model": sess.model,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
"usage_source": "estimated",
|
||||
}
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
@@ -957,7 +1070,7 @@ def setup_chat_routes(
|
||||
rag_sources=ctx.rag_sources,
|
||||
research_sources=research_sources,
|
||||
used_memories=ctx.used_memories,
|
||||
do_research=do_research,
|
||||
do_research=effective_do_research,
|
||||
incognito=incognito,
|
||||
)
|
||||
if _saved_id:
|
||||
@@ -967,14 +1080,22 @@ def setup_chat_routes(
|
||||
last_metrics, ctx.uprefs, memory_manager, memory_vector, webhook_manager,
|
||||
incognito=incognito, compare_mode=compare_mode,
|
||||
character_name=ctx.preset.character_name,
|
||||
owner=_user,
|
||||
owner=_user,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
_stream_set(session, status="done")
|
||||
yield chunk
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
if full_response:
|
||||
logger.info("Client disconnected mid-stream (chat mode) for session %s, saving partial (%d chars)", session, len(full_response))
|
||||
_stopped_content, _stopped_md = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
||||
_stopped_content, _stopped_md = clean_thinking_for_save(
|
||||
full_response,
|
||||
{
|
||||
"stopped": True,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
},
|
||||
)
|
||||
sess.add_message(ChatMessage("assistant", _stopped_content, metadata=_stopped_md))
|
||||
if not incognito:
|
||||
session_manager.save_sessions()
|
||||
@@ -986,6 +1107,8 @@ def setup_chat_routes(
|
||||
_agent_rounds = 0
|
||||
_agent_tool_calls = 0
|
||||
_answered_by = None # set if the selected model failed and a fallback answered
|
||||
_requested_model = sess.model
|
||||
_actual_model = None
|
||||
try:
|
||||
from src.settings import get_setting
|
||||
from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
|
||||
@@ -1012,9 +1135,12 @@ def setup_chat_routes(
|
||||
active_document=active_doc,
|
||||
session_id=session,
|
||||
disabled_tools=disabled_tools if disabled_tools else None,
|
||||
tool_policy=tool_policy,
|
||||
owner=_user,
|
||||
fallbacks=_fallback_candidates,
|
||||
workspace=workspace or None,
|
||||
plan_mode=plan_mode,
|
||||
approved_plan=approved_plan or None,
|
||||
):
|
||||
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
|
||||
try:
|
||||
@@ -1035,6 +1161,8 @@ def setup_chat_routes(
|
||||
"doc_stream_open", "doc_stream_delta",
|
||||
"doc_update", "doc_suggestions", "ui_control",
|
||||
"rounds_exhausted",
|
||||
"ask_user",
|
||||
"plan_update",
|
||||
):
|
||||
if data.get("type") == "agent_step":
|
||||
_agent_rounds = max(_agent_rounds, data.get("round", 1))
|
||||
@@ -1047,10 +1175,18 @@ def setup_chat_routes(
|
||||
# model so metrics reflect it, not the masked
|
||||
# selected model.
|
||||
_answered_by = data.get("answered_by") or _answered_by
|
||||
_actual_model = _actual_model or _answered_by
|
||||
data["selected_model"] = data.get("selected_model") or _requested_model
|
||||
yield chunk
|
||||
elif data.get("type") == "model_actual":
|
||||
_actual_model = data.get("model") or _actual_model
|
||||
data["requested_model"] = _requested_model
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
elif data.get("type") == "metrics":
|
||||
last_metrics = data.get("data", {})
|
||||
last_metrics["model"] = _answered_by or sess.model
|
||||
_reported_model = last_metrics.get("model")
|
||||
last_metrics["requested_model"] = last_metrics.get("requested_model") or _requested_model
|
||||
last_metrics["model"] = _reported_model or _actual_model or _answered_by or _requested_model
|
||||
yield f'data: {json.dumps({"type": "metrics", "data": last_metrics})}\n\n'
|
||||
except json.JSONDecodeError:
|
||||
yield chunk
|
||||
@@ -1078,6 +1214,7 @@ def setup_chat_routes(
|
||||
skills_manager=skills_manager,
|
||||
owner=_user,
|
||||
extract_skills=user_requested_agent,
|
||||
allow_background_extraction=not tool_policy.block_all_tool_calls,
|
||||
)
|
||||
_stream_set(session, status="done")
|
||||
yield chunk
|
||||
@@ -1091,7 +1228,14 @@ def setup_chat_routes(
|
||||
try:
|
||||
if full_response:
|
||||
logger.info("Client disconnected mid-stream for session %s, saving partial response (%d chars)", session, len(full_response))
|
||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(full_response, {"stopped": True, "model": sess.model})
|
||||
_stopped_content2, _stopped_md2 = clean_thinking_for_save(
|
||||
full_response,
|
||||
{
|
||||
"stopped": True,
|
||||
"model": _actual_model or _answered_by or _requested_model,
|
||||
"requested_model": _requested_model,
|
||||
},
|
||||
)
|
||||
sess.add_message(ChatMessage("assistant", _stopped_content2, metadata=_stopped_md2))
|
||||
if not incognito:
|
||||
session_manager.save_sessions()
|
||||
@@ -1110,11 +1254,30 @@ def setup_chat_routes(
|
||||
finally:
|
||||
_active_streams.pop(session, None)
|
||||
|
||||
# Run the stream as a DETACHED background task so it survives the client
|
||||
# closing the tab / navigating away (true terminal-agent behavior). The
|
||||
# SSE response just subscribes (replay buffered output + live); dropping
|
||||
# the SSE only removes a subscriber — the run keeps going and saves the
|
||||
# assistant message on completion regardless. Reconnect via /api/chat/resume.
|
||||
# Compare panes are short-lived, single-shot generations whose sessions
|
||||
# exist only to drive that one pane — there's nothing to "resume" and
|
||||
# the user expects the pane's Stop button (which aborts the fetch,
|
||||
# closing this SSE) to promptly cancel the upstream LLM call. Detaching
|
||||
# them would keep burning upstream tokens/compute after the pane is
|
||||
# stopped or the comparison is abandoned, and would surface a stale
|
||||
# "still streaming" /resume target for a session nobody will revisit.
|
||||
#
|
||||
# So: stream them directly (no agent_runs wrapping). Starlette cancels
|
||||
# the underlying async generator (raising CancelledError/GeneratorExit
|
||||
# inside it) as soon as it notices the client disconnected — which the
|
||||
# mode-specific except blocks above already handle by saving the
|
||||
# partial response exactly once. This stops the upstream call promptly
|
||||
# without waiting on the next streamed chunk.
|
||||
#
|
||||
# Normal chat/agent streams keep the DETACHED behavior below: they
|
||||
# survive the client closing the tab / navigating away (true
|
||||
# terminal-agent semantics). The SSE response just subscribes (replay
|
||||
# buffered output + live); dropping the SSE only removes a subscriber —
|
||||
# the run keeps going and saves the assistant message on completion
|
||||
# regardless. Reconnect via /api/chat/resume.
|
||||
if compare_mode:
|
||||
return StreamingResponse(_safe_stream(), media_type="text/event-stream")
|
||||
|
||||
agent_runs.start(session, _safe_stream())
|
||||
return StreamingResponse(agent_runs.subscribe(session), media_type="text/event-stream")
|
||||
|
||||
@@ -1185,45 +1348,16 @@ def setup_chat_routes(
|
||||
return []
|
||||
|
||||
_user = get_current_user(request)
|
||||
query_term = q.strip()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
base_q = (
|
||||
db.query(DBChatMessage, DBSession.name)
|
||||
.join(DBSession, DBChatMessage.session_id == DBSession.id)
|
||||
.filter(
|
||||
DBSession.archived == False,
|
||||
DBChatMessage.content.ilike(f"%{query_term}%"),
|
||||
DBChatMessage.role.in_(["user", "assistant"]),
|
||||
)
|
||||
return [
|
||||
result.to_dict()
|
||||
for result in search_session_messages(
|
||||
q,
|
||||
limit=limit,
|
||||
owner=_user,
|
||||
restrict_owner=_user is not None,
|
||||
include_legacy_owner=False,
|
||||
)
|
||||
if _user:
|
||||
base_q = base_q.filter(DBSession.owner == _user)
|
||||
rows = base_q.order_by(DBChatMessage.timestamp.desc()).limit(limit).all()
|
||||
|
||||
results = []
|
||||
for msg, session_name in rows:
|
||||
content = msg.content or ""
|
||||
lower_content = content.lower()
|
||||
idx = lower_content.find(query_term.lower())
|
||||
if idx == -1:
|
||||
snippet = content[:120]
|
||||
else:
|
||||
start = max(0, idx - 50)
|
||||
end = min(len(content), idx + len(query_term) + 50)
|
||||
snippet = ("..." if start > 0 else "") + content[start:end] + ("..." if end < len(content) else "")
|
||||
|
||||
results.append({
|
||||
"session_id": msg.session_id,
|
||||
"session_name": session_name or "Untitled",
|
||||
"role": msg.role,
|
||||
"content_snippet": snippet,
|
||||
"timestamp": msg.timestamp.isoformat() if msg.timestamp else None,
|
||||
})
|
||||
|
||||
return results
|
||||
finally:
|
||||
db.close()
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# POST /api/rewrite — lightweight rewrite of last AI message (no tools)
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""ChatGPT Subscription device-flow setup routes."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import chatgpt_subscription
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
|
||||
access_token = tokens.get("access_token")
|
||||
refresh_token = tokens.get("refresh_token")
|
||||
if not access_token or not refresh_token:
|
||||
raise ValueError("ChatGPT token response was missing access_token or refresh_token")
|
||||
|
||||
base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
|
||||
models = chatgpt_subscription.fetch_available_models(access_token)
|
||||
if not models:
|
||||
raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
auth = (
|
||||
db.query(ProviderAuthSession)
|
||||
.filter(
|
||||
ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
ProviderAuthSession.owner == owner,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if auth is None:
|
||||
auth = ProviderAuthSession(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
owner=owner,
|
||||
label="ChatGPT Subscription",
|
||||
base_url=base,
|
||||
auth_mode="chatgpt",
|
||||
)
|
||||
db.add(auth)
|
||||
auth.base_url = base
|
||||
auth.access_token = access_token
|
||||
auth.refresh_token = refresh_token
|
||||
auth.last_refresh = utcnow_naive()
|
||||
auth.auth_mode = "chatgpt"
|
||||
|
||||
ep = (
|
||||
db.query(ModelEndpoint)
|
||||
.filter(
|
||||
ModelEndpoint.base_url == base,
|
||||
ModelEndpoint.provider_auth_id == auth.id,
|
||||
ModelEndpoint.owner == owner,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if ep is None:
|
||||
ep = ModelEndpoint(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
name="ChatGPT Subscription",
|
||||
base_url=base,
|
||||
model_type="llm",
|
||||
endpoint_kind="api",
|
||||
owner=owner,
|
||||
)
|
||||
db.add(ep)
|
||||
ep.name = "ChatGPT Subscription"
|
||||
ep.base_url = base
|
||||
ep.api_key = None
|
||||
ep.provider_auth_id = auth.id
|
||||
ep.is_enabled = True
|
||||
ep.supports_tools = False
|
||||
ep.model_type = "llm"
|
||||
ep.endpoint_kind = "api"
|
||||
ep.model_refresh_mode = "manual"
|
||||
ep.cached_models = json.dumps(models)
|
||||
db.commit()
|
||||
result = {
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"models": models,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
from routes.model_routes import _invalidate_models_cache
|
||||
|
||||
_invalidate_models_cache()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
|
||||
try:
|
||||
data = chatgpt_subscription.request_device_code()
|
||||
except Exception as exc:
|
||||
raise chatgpt_subscription.to_http_exception(exc)
|
||||
|
||||
device_auth_id = data.get("device_auth_id")
|
||||
user_code = data.get("user_code")
|
||||
if not device_auth_id or not user_code:
|
||||
raise HTTPException(502, "ChatGPT did not return a complete device code")
|
||||
verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"
|
||||
return DeviceFlowStart(
|
||||
pending={
|
||||
"device_auth_id": device_auth_id,
|
||||
"user_code": user_code,
|
||||
"owner": get_current_user(request) or None,
|
||||
},
|
||||
response={
|
||||
"user_code": user_code,
|
||||
"verification_uri": verification_uri,
|
||||
},
|
||||
interval=int(data.get("interval") or 5),
|
||||
expires_in=int(data.get("expires_in") or 900),
|
||||
)
|
||||
|
||||
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||
try:
|
||||
data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
|
||||
except Exception as exc:
|
||||
logger.debug("ChatGPT device poll failed: %s", exc)
|
||||
return DeviceFlowPoll.pending(str(exc))
|
||||
|
||||
authorization_code = data.get("authorization_code")
|
||||
code_verifier = data.get("code_verifier")
|
||||
if authorization_code and code_verifier:
|
||||
try:
|
||||
tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
|
||||
result = _provision_endpoint(tokens, pending["owner"])
|
||||
except Exception as exc:
|
||||
logger.exception("ChatGPT Subscription endpoint provisioning failed")
|
||||
raise chatgpt_subscription.to_http_exception(exc)
|
||||
return DeviceFlowPoll.authorized(result)
|
||||
|
||||
err = data.get("error") or data.get("status")
|
||||
if err in ("authorization_pending", "pending", None):
|
||||
return DeviceFlowPoll.pending()
|
||||
if err == "slow_down":
|
||||
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||
if err in ("expired_token", "access_denied", "denied"):
|
||||
return DeviceFlowPoll.failed(err)
|
||||
return DeviceFlowPoll.pending(err or "unknown")
|
||||
|
||||
|
||||
def setup_chatgpt_subscription_routes():
|
||||
return create_device_flow_router(
|
||||
prefix="/api/chatgpt-subscription",
|
||||
tags=["chatgpt-subscription"],
|
||||
store=_DEVICE_FLOW_STORE,
|
||||
start_flow=_start_device_flow,
|
||||
poll_flow=_poll_device_flow,
|
||||
)
|
||||
+16
-6
@@ -15,8 +15,9 @@ from typing import Any
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from src.auth_helpers import require_authenticated_request, require_user
|
||||
from src.tool_implementations import do_manage_notes
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
|
||||
|
||||
COOKBOOK_READ_SCOPES = {"cookbook:read", "cookbook:launch"}
|
||||
@@ -41,7 +42,9 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
||||
the scope-gated owner (not the "api" pseudo-user the bearer middleware sets).
|
||||
Restores the original value when done. Works for sync and async handlers."""
|
||||
orig = getattr(request.state, "current_user", None)
|
||||
orig_api_token = getattr(request.state, "api_token", None)
|
||||
request.state.current_user = owner
|
||||
request.state.api_token = False
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
@@ -49,6 +52,13 @@ async def _as_owner(request: Request, owner: str, fn, *args, **kwargs):
|
||||
return result
|
||||
finally:
|
||||
request.state.current_user = orig
|
||||
if orig_api_token is None:
|
||||
try:
|
||||
delattr(request.state, "api_token")
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
request.state.api_token = orig_api_token
|
||||
|
||||
|
||||
def _scope_owner(request: Request, allowed: set[str]) -> str:
|
||||
@@ -146,7 +156,7 @@ def setup_codex_routes(
|
||||
|
||||
@router.get("/plugin.zip")
|
||||
def plugin_zip(request: Request):
|
||||
require_user(request)
|
||||
require_authenticated_request(request)
|
||||
root = Path(__file__).resolve().parent.parent / "integrations" / "codex"
|
||||
if not root.exists():
|
||||
raise HTTPException(404, "Codex plugin bundle not found")
|
||||
@@ -415,8 +425,8 @@ def setup_codex_routes(
|
||||
|
||||
def _read_cookbook_state() -> dict:
|
||||
from pathlib import Path as _Path
|
||||
import os as _os, json as _json
|
||||
p = _Path(_os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
||||
import json as _json
|
||||
p = _Path(COOKBOOK_STATE_FILE)
|
||||
if not p.exists():
|
||||
return {}
|
||||
try:
|
||||
@@ -724,7 +734,7 @@ def setup_codex_routes(
|
||||
import time as _t, json as _json
|
||||
from core.atomic_io import atomic_write_json
|
||||
from pathlib import Path as _Path
|
||||
cookbook_state_path = _Path("/app/data/cookbook_state.json")
|
||||
cookbook_state_path = _Path(COOKBOOK_STATE_FILE)
|
||||
try:
|
||||
state = _json.loads(cookbook_state_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
@@ -762,7 +772,7 @@ def setup_claude_routes() -> APIRouter:
|
||||
|
||||
@router.get("/plugin.zip")
|
||||
def plugin_zip(request: Request):
|
||||
require_user(request)
|
||||
require_authenticated_request(request)
|
||||
# Only ship the skills/ subtree so extracting at ~/.claude/ doesn't dump
|
||||
# README.md or other bundle metadata into the user's claude config dir.
|
||||
skills_root = Path(__file__).resolve().parent.parent / "integrations" / "claude" / "skills"
|
||||
|
||||
+110
-22
@@ -12,6 +12,7 @@ import logging
|
||||
from core.database import Comparison, SessionLocal
|
||||
from core.session_manager import SessionManager
|
||||
from src.auth_helpers import get_current_user
|
||||
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,6 +39,24 @@ def _owned_endpoint_by_url(db, base_url, owner):
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
def _owned_endpoint_by_id(db, endpoint_id, owner):
|
||||
"""ModelEndpoint whose id == `endpoint_id` and is VISIBLE to `owner` (their
|
||||
own rows + legacy null-owner "shared" rows); None otherwise.
|
||||
|
||||
Preferred over _owned_endpoint_by_url for credential resolution: two visible
|
||||
endpoints can share the same base_url but hold DIFFERENT api_keys (e.g. two
|
||||
accounts on the same provider). A base_url-only match returns whichever row
|
||||
sorts first, so it can copy the WRONG owner-scoped key into the [CMP] session.
|
||||
An id pins the exact registered endpoint, so /api/compare/start prefers it and
|
||||
only falls back to URL matching for legacy / admin raw-URL callers. Owner
|
||||
scoping is identical to _owned_endpoint_by_url (a null/empty owner is a no-op).
|
||||
"""
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id)
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
class RecordVoteRequest(BaseModel):
|
||||
prompt: str
|
||||
models: List[str]
|
||||
@@ -54,8 +73,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
prompt: str = Form(...),
|
||||
model_a: str = Form(...),
|
||||
model_b: str = Form(...),
|
||||
endpoint_a: str = Form(...),
|
||||
endpoint_b: str = Form(...),
|
||||
endpoint_a: str = Form(""),
|
||||
endpoint_b: str = Form(""),
|
||||
endpoint_a_id: str = Form(""),
|
||||
endpoint_b_id: str = Form(""),
|
||||
is_blind: str = Form("true"),
|
||||
):
|
||||
"""Create two ephemeral sessions and a comparison record.
|
||||
@@ -63,10 +84,10 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
Returns the comparison ID and the two session IDs so the client
|
||||
can fire two independent SSE streams to /api/chat_stream.
|
||||
"""
|
||||
user = getattr(request.state, 'current_user', None)
|
||||
comp_id = str(uuid.uuid4())
|
||||
sid_a = str(uuid.uuid4())
|
||||
sid_b = str(uuid.uuid4())
|
||||
user = getattr(request.state, 'current_user', None)
|
||||
|
||||
# Blind mapping: randomly assign left/right
|
||||
blind = str(is_blind).lower() == "true"
|
||||
@@ -87,31 +108,94 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
# de-anonymizing the comparison before the user votes (issue #1285).
|
||||
slot_name = {session_left: "Model A", session_right: "Model B"}
|
||||
|
||||
# Create ephemeral sessions (prefixed [CMP])
|
||||
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
|
||||
# SECURITY: resolve and validate BOTH endpoints before creating any
|
||||
# session. Compare copies a registered endpoint's Authorization header
|
||||
# into the [CMP] session, so validating one endpoint while creating its
|
||||
# session, then rejecting the other, would leave a partial compare
|
||||
# session behind with that header attached. Doing all the owner-scope
|
||||
# resolution + raw-URL rejection up front means a 403 on either endpoint
|
||||
# aborts the whole request with nothing created and no header copied.
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
|
||||
resolved = []
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for sid, model, endpoint, endpoint_id in [
|
||||
(sid_a, model_a, endpoint_a, endpoint_a_id),
|
||||
(sid_b, model_b, endpoint_b, endpoint_b_id),
|
||||
]:
|
||||
# Prefer an explicit endpoint id: it pins the EXACT registered
|
||||
# endpoint (and its api_key), even when two endpoints visible to
|
||||
# the caller share a base_url with different keys — a URL-only
|
||||
# match would copy whichever row sorts first, i.e. possibly the
|
||||
# wrong key. Fall back to URL resolution only for legacy / admin
|
||||
# raw-URL callers that don't send an id.
|
||||
eid = endpoint_id.strip() if isinstance(endpoint_id, str) else ""
|
||||
if eid:
|
||||
ep = _owned_endpoint_by_id(db, eid, user)
|
||||
if ep is None:
|
||||
# An id the caller can't see (wrong owner / deleted) must
|
||||
# NOT silently fall back to a same-URL row with a different
|
||||
# key — that's exactly the mix-up ids exist to prevent.
|
||||
raise HTTPException(404, "Model endpoint not found")
|
||||
# The id already resolved the endpoint; ignore any raw URL the
|
||||
# caller also sent and dial the stored config instead.
|
||||
endpoint = ep.base_url
|
||||
elif not endpoint:
|
||||
raise HTTPException(
|
||||
422, "endpoint_a/endpoint_b or endpoint_a_id/endpoint_b_id is required"
|
||||
)
|
||||
else:
|
||||
# Resolve the supplied URL to a ModelEndpoint the caller owns
|
||||
# (their own rows + legacy null-owner shared rows), scoped so a
|
||||
# comparison can't borrow another user's private endpoint key.
|
||||
base = normalize_base(endpoint)
|
||||
ep = _owned_endpoint_by_url(db, base, user)
|
||||
# Reject *unregistered* raw URLs for signed-in non-admins; a
|
||||
# matched registered endpoint supplies an id so the caller can
|
||||
# still compare endpoints they own. Blanket-rejecting here (the
|
||||
# earlier `endpoint_id=None` call) locked non-admins out of
|
||||
# compare entirely, since compare resolves endpoints by URL with
|
||||
# no endpoint_id. Mirrors the gallery inpaint/harmonize checks.
|
||||
# Raised here (phase 1), before any session exists.
|
||||
_reject_raw_endpoint_url_for_non_admin(
|
||||
request, user, str(ep.id) if ep is not None else None, endpoint
|
||||
)
|
||||
# Bind the [CMP] session to the RESOLVED endpoint, not the raw
|
||||
# caller-supplied string. When the URL matches a registered
|
||||
# endpoint visible to the caller, use that row's own normalized
|
||||
# base URL (the same value owner scoping + endpoint validation
|
||||
# already vetted) so the session dials exactly where the stored
|
||||
# config points. The raw `endpoint` only survives for callers
|
||||
# allowed to pass one — admins / single-user mode, where
|
||||
# `_reject_raw_endpoint_url_for_non_admin` is a no-op and `ep`
|
||||
# is None. Mirrors the registered-endpoint path in session_routes.
|
||||
session_endpoint_url = (
|
||||
build_chat_url(normalize_base(ep.base_url)) if ep is not None else endpoint
|
||||
)
|
||||
# Headers come only from a matched endpoint's key; None when
|
||||
# `ep` is None (raw admin URL or no match), so a comparison can
|
||||
# never inherit another user's key/headers.
|
||||
headers = build_headers(ep.api_key, ep.base_url) if (ep and ep.api_key) else None
|
||||
resolved.append((sid, model, session_endpoint_url, headers))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Both endpoints validated — only now create the ephemeral [CMP]
|
||||
# sessions and copy any resolved headers.
|
||||
for sid, model, session_endpoint_url, headers in resolved:
|
||||
name = f"[CMP] {slot_name[sid]}" if blind else f"[CMP] {model.split('/')[-1]}"
|
||||
session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=name,
|
||||
endpoint_url=endpoint,
|
||||
endpoint_url=session_endpoint_url,
|
||||
model=model,
|
||||
rag=False,
|
||||
owner=user,
|
||||
)
|
||||
# Copy API key from endpoint config
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
# Find matching endpoint by URL, scoped to the caller so a
|
||||
# comparison can't borrow another user's private endpoint key.
|
||||
base = normalize_base(endpoint)
|
||||
ep = _owned_endpoint_by_url(db, base, user)
|
||||
if ep and ep.api_key:
|
||||
s = session_manager.sessions.get(sid)
|
||||
if s:
|
||||
s.headers = build_headers(ep.api_key, ep.base_url)
|
||||
finally:
|
||||
db.close()
|
||||
if headers:
|
||||
s = session_manager.sessions.get(sid)
|
||||
if s:
|
||||
s.headers = headers
|
||||
|
||||
# Store comparison record
|
||||
db = SessionLocal()
|
||||
@@ -121,8 +205,12 @@ def setup_compare_routes(session_manager: SessionManager):
|
||||
prompt=prompt,
|
||||
model_a=model_a,
|
||||
model_b=model_b,
|
||||
endpoint_a=endpoint_a,
|
||||
endpoint_b=endpoint_b,
|
||||
# Record the URL the session actually dials. For URL callers this
|
||||
# is their raw input; for id-only callers (empty endpoint_a/_b)
|
||||
# fall back to the resolved endpoint URL so the column stays
|
||||
# meaningful and non-null. resolved is in [a, b] order.
|
||||
endpoint_a=endpoint_a or resolved[0][2],
|
||||
endpoint_b=endpoint_b or resolved[1][2],
|
||||
is_blind=blind,
|
||||
blind_mapping=json.dumps(mapping),
|
||||
owner=user,
|
||||
|
||||
+53
-18
@@ -11,20 +11,24 @@ import uuid
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Query, Depends, Response
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
|
||||
from fastapi import APIRouter, Query, Depends, Response, HTTPException
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from core.middleware import require_admin
|
||||
from src.url_safety import check_outbound_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
||||
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
|
||||
from src.constants import DATA_DIR as _DATA_DIR, SETTINGS_FILE as _SETTINGS_FILE, CONTACTS_FILE as _CONTACTS_FILE
|
||||
DATA_DIR = Path(_DATA_DIR)
|
||||
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||
LOCAL_CONTACTS_FILE = Path(_CONTACTS_FILE)
|
||||
|
||||
|
||||
def _load_settings():
|
||||
@@ -53,6 +57,21 @@ def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
|
||||
return bool((cfg.get("url") or "").strip())
|
||||
|
||||
|
||||
def _validate_carddav_url(url: str) -> str:
|
||||
cleaned = (url if isinstance(url, str) else "").strip().rstrip("/")
|
||||
ok, reason = check_outbound_url(
|
||||
cleaned,
|
||||
block_private=os.getenv("CARDDAV_BLOCK_PRIVATE_IPS", "false").lower() == "true",
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError(f"Rejected CardDAV URL: {reason}")
|
||||
return cleaned
|
||||
|
||||
|
||||
def _carddav_base_url(cfg: Dict) -> str:
|
||||
return _validate_carddav_url(cfg.get("url") or "")
|
||||
|
||||
|
||||
def _normalize_contact(contact: Dict) -> Dict:
|
||||
emails = []
|
||||
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
||||
@@ -219,14 +238,18 @@ _contact_cache = {"contacts": [], "fetched_at": None}
|
||||
def _abs_url(href: str) -> str:
|
||||
"""Combine a multistatus <href> (an absolute path like
|
||||
/user/contacts/x.vcf) with the configured CardDAV server origin so we
|
||||
get a fully-qualified URL to PUT/DELETE. If href is already absolute
|
||||
(http...), return it as-is."""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
if href.startswith("http://") or href.startswith("https://"):
|
||||
return href
|
||||
get a fully-qualified URL to PUT/DELETE. Absolute hrefs are accepted only
|
||||
for the configured origin; a cross-origin href is treated as a path on the
|
||||
configured server so a malicious CardDAV response cannot redirect later
|
||||
writes/deletes to cloud metadata or another host."""
|
||||
cfg = _get_carddav_config()
|
||||
p = urlparse(cfg["url"])
|
||||
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
|
||||
base = _carddav_base_url(cfg)
|
||||
base_p = urlparse(base)
|
||||
joined = urljoin(base.rstrip("/") + "/", href or "")
|
||||
joined_p = urlparse(joined)
|
||||
if (joined_p.scheme, joined_p.netloc) != (base_p.scheme, base_p.netloc):
|
||||
joined = urlunparse((base_p.scheme, base_p.netloc, joined_p.path or "/", "", joined_p.query, ""))
|
||||
return _validate_carddav_url(joined)
|
||||
|
||||
|
||||
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
||||
@@ -297,6 +320,7 @@ def _fetch_contacts(force=False):
|
||||
return contacts
|
||||
|
||||
try:
|
||||
cfg["url"] = _carddav_base_url(cfg)
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
@@ -353,8 +377,8 @@ def _create_contact(name: str, email: str) -> bool:
|
||||
|
||||
contact_uid = str(uuid.uuid4())
|
||||
vcard = _build_vcard(name, email, contact_uid)
|
||||
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
|
||||
try:
|
||||
url = _carddav_base_url(cfg) + "/" + contact_uid + ".vcf"
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
@@ -382,7 +406,7 @@ def _vcard_url(uid: str) -> str:
|
||||
escape the collection and target an arbitrary CardDAV resource."""
|
||||
from urllib.parse import quote
|
||||
cfg = _get_carddav_config()
|
||||
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
return _carddav_base_url(cfg) + "/" + quote(uid, safe="") + ".vcf"
|
||||
|
||||
|
||||
def _import_vcards(text: str) -> Dict:
|
||||
@@ -413,6 +437,11 @@ def _import_vcards(text: str) -> Dict:
|
||||
if imported:
|
||||
_save_local_contacts(contacts)
|
||||
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
||||
try:
|
||||
base_url = _carddav_base_url(cfg)
|
||||
except ValueError as e:
|
||||
logger.warning("CardDAV import URL rejected: %s", e)
|
||||
return {"imported": 0, "failed": 0, "total": 0, "error": str(e)}
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
# Split into individual cards. re.split drops the BEGIN line, so we
|
||||
# re-add it. Normalize CRLF.
|
||||
@@ -441,7 +470,7 @@ def _import_vcards(text: str) -> Dict:
|
||||
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
||||
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
||||
vcard = block.replace("\n", "\r\n") + "\r\n"
|
||||
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
url = base_url + "/" + quote(uid, safe="") + ".vcf"
|
||||
try:
|
||||
r = httpx.put(
|
||||
url, data=vcard.encode("utf-8"),
|
||||
@@ -601,8 +630,8 @@ def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -
|
||||
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
||||
# Use the real resource href (handles externally-created contacts whose
|
||||
# filename != UID); falls back to the <uid>.vcf guess.
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
url = _resolve_resource_url(uid)
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.put(
|
||||
url,
|
||||
@@ -630,8 +659,8 @@ def _delete_contact(uid: str) -> bool:
|
||||
_save_local_contacts(remaining)
|
||||
return True
|
||||
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
url = _resolve_resource_url(uid)
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.delete(url, auth=auth, timeout=10)
|
||||
if r.status_code in (200, 204):
|
||||
@@ -747,7 +776,13 @@ def setup_contacts_routes():
|
||||
settings = _load_settings()
|
||||
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
||||
if key in data:
|
||||
settings[key] = data[key]
|
||||
if key == "carddav_url" and str(data[key] or "").strip():
|
||||
try:
|
||||
settings[key] = _validate_carddav_url(data[key])
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
else:
|
||||
settings[key] = data[key]
|
||||
_save_settings(settings)
|
||||
# Force re-fetch
|
||||
_contact_cache["fetched_at"] = None
|
||||
|
||||
+312
-14
@@ -11,6 +11,8 @@ import shlex
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -195,6 +197,20 @@ def _pip_install_attempt(pip_cmd: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _pip_command(python_cmd: str) -> str:
|
||||
"""Return a pip command for either a pip executable or a Python executable."""
|
||||
cmd = python_cmd.strip()
|
||||
if " -m pip" in cmd or cmd in {"pip", "pip3"}:
|
||||
return python_cmd
|
||||
if cmd in {"python", "python3", "python.exe"} or cmd.endswith(("/python", "/python3", "\\python.exe")):
|
||||
return f"{python_cmd} -m pip"
|
||||
return python_cmd
|
||||
|
||||
|
||||
def _pip_break_system_packages_check(pip_cmd: str) -> str:
|
||||
return f"{pip_cmd} install --help 2>/dev/null | grep -q -- --break-system-packages"
|
||||
|
||||
|
||||
def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m pip", upgrade: bool = False) -> str:
|
||||
"""Build a bash pip install fallback chain that surfaces errors.
|
||||
|
||||
@@ -206,33 +222,44 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
|
||||
exit code is preserved (no ``| tail`` masking) and the last 5 lines of
|
||||
pip output appear in the Cookbook log on failure.
|
||||
"""
|
||||
from core.platform_compat import IS_WINDOWS
|
||||
upgrade_flag = " -U" if upgrade else ""
|
||||
# Shell-quote the package spec: an extras spec like ``llama-cpp-python[server]``
|
||||
# contains brackets that bash would treat as a glob, so it must be quoted
|
||||
# before being embedded in the install command. Plain names (e.g.
|
||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||
pkg = shlex.quote(package)
|
||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
||||
user = _pip_install_attempt(f"{python_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||
# llama-cpp-python source builds are brittle on older distro pip/packaging
|
||||
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
|
||||
# this package is requested so dependency-install tasks are reliable.
|
||||
if "llama-cpp-python" in package:
|
||||
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
|
||||
pip_cmd = _pip_command(python_cmd)
|
||||
base = _pip_install_attempt(f"{pip_cmd} install -q{upgrade_flag} {pkg}")
|
||||
user = _pip_install_attempt(f"{pip_cmd} install --user -q{upgrade_flag} {pkg}")
|
||||
user_break_system = _pip_install_attempt(f"{pip_cmd} install --user --break-system-packages -q{upgrade_flag} {pkg}")
|
||||
user_fallback = f"( {user} || {{ {_pip_break_system_packages_check(pip_cmd)} && {user_break_system}; }} )"
|
||||
# Derive the python executable for the venv detection check.
|
||||
# Must use the same interpreter that pip belongs to; hardcoding
|
||||
# python3 breaks when pip lives in a venv that only has "python".
|
||||
if " -m pip" in python_cmd:
|
||||
python_exe = python_cmd.replace(" -m pip", "")
|
||||
elif python_cmd.strip() == "pip":
|
||||
if " -m pip" in pip_cmd:
|
||||
python_exe = pip_cmd.replace(" -m pip", "")
|
||||
elif pip_cmd.strip() == "pip":
|
||||
python_exe = "python"
|
||||
elif python_cmd.strip() == "pip3":
|
||||
elif pip_cmd.strip() == "pip3":
|
||||
python_exe = "python3"
|
||||
else:
|
||||
python_exe = "python3"
|
||||
venv_check = f'{python_exe} -c "import sys; sys.exit(0 if sys.prefix != sys.base_prefix else 1)"'
|
||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv → `&&` tries
|
||||
# --user. When IN a venv `! venv_check` fails → `&&` skips --user and the
|
||||
# Negated: `! venv_check` succeeds (exit 0) when NOT in a venv -> `&&` tries
|
||||
# --user. When IN a venv `! venv_check` fails -> `&&` skips --user and the
|
||||
# group exits non-zero, propagating the base-install failure instead of
|
||||
# masking it as success (the `|| { venv_check || … }` shape from #903
|
||||
# swallowed the exit code because venv_check's exit-0 became the group's
|
||||
# result).
|
||||
return f"{base} || {{ ! {venv_check} && {user}; }}"
|
||||
# result). `--break-system-packages` is only attempted when the active pip
|
||||
# supports it; older pip versions abort with "no such option" otherwise.
|
||||
return f"{base} || {{ ! {venv_check} && {user_fallback}; }}"
|
||||
|
||||
|
||||
def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) -> str:
|
||||
@@ -263,6 +290,55 @@ def _venv_safe_local_pip_install_cmd(cmd: str, *, local: bool, in_venv: bool) ->
|
||||
return shlex.join(stripped)
|
||||
|
||||
|
||||
def _pip_install_command_without_break_system_packages(cmd: str) -> str:
|
||||
try:
|
||||
parts = shlex.split(cmd)
|
||||
except ValueError:
|
||||
return cmd
|
||||
stripped = [part for part in parts if part != "--break-system-packages"]
|
||||
return shlex.join(stripped)
|
||||
|
||||
|
||||
def _pip_install_help_check_from_cmd(cmd: str) -> str | None:
|
||||
try:
|
||||
parts = shlex.split(cmd)
|
||||
except ValueError:
|
||||
return None
|
||||
try:
|
||||
install_index = parts.index("install")
|
||||
except ValueError:
|
||||
return None
|
||||
if install_index <= 0:
|
||||
return None
|
||||
pip_prefix = parts[:install_index]
|
||||
return f"{shlex.join(pip_prefix + ['install', '--help'])} 2>/dev/null | grep -q -- --break-system-packages"
|
||||
|
||||
|
||||
def _append_pip_install_runner_lines(runner_lines: list[str], cmd: str) -> None:
|
||||
"""Append a pip install command, guarding --break-system-packages support.
|
||||
|
||||
The Dependencies UI may submit ``python3 -m pip install --user
|
||||
--break-system-packages ...`` for non-venv installs. That flag is useful on
|
||||
PEP-668-locked distros, but older pip (including Ubuntu 22.04's apt pip in
|
||||
the NVIDIA CUDA base image) aborts with "no such option". Branch at runner
|
||||
time so stale browser JS and remote targets are handled by the server too.
|
||||
"""
|
||||
if "--break-system-packages" not in (cmd or ""):
|
||||
runner_lines.append(cmd)
|
||||
return
|
||||
help_check = _pip_install_help_check_from_cmd(cmd)
|
||||
without_break = _pip_install_command_without_break_system_packages(cmd)
|
||||
if not help_check or without_break == cmd:
|
||||
runner_lines.append(cmd)
|
||||
return
|
||||
runner_lines.append(f"if {help_check}; then")
|
||||
runner_lines.append(f" {cmd}")
|
||||
runner_lines.append("else")
|
||||
runner_lines.append(' echo "[odysseus] pip does not support --break-system-packages; installing without it."')
|
||||
runner_lines.append(f" {without_break}")
|
||||
runner_lines.append("fi")
|
||||
|
||||
|
||||
def _user_shell_path_bootstrap() -> list[str]:
|
||||
return [
|
||||
'ODYSSEUS_USER_SHELL="${SHELL:-}"',
|
||||
@@ -271,11 +347,14 @@ def _user_shell_path_bootstrap() -> list[str]:
|
||||
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
||||
'fi',
|
||||
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
||||
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
|
||||
]
|
||||
|
||||
|
||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
||||
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
|
||||
"""Build the standalone Python scanner used by /api/model/cached.
|
||||
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
|
||||
"""
|
||||
lines = [
|
||||
"import json, os, re, shutil, subprocess, urllib.request",
|
||||
"models = []",
|
||||
@@ -338,6 +417,15 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||
" if f.name.endswith('.incomplete'): ic = True",
|
||||
" snap = os.path.join(cache, d, 'snapshots')",
|
||||
" # Windows HF cache stores files directly in snapshots/; blobs/ may be empty.",
|
||||
" # Fallback: scan snapshots for real files when blobs yielded nothing.",
|
||||
" if sz == 0 and os.path.isdir(snap):",
|
||||
" for sd in os.listdir(snap):",
|
||||
" sf = os.path.join(snap, sd)",
|
||||
" if not os.path.isdir(sf): continue",
|
||||
" for f in os.scandir(sf):",
|
||||
" if f.is_file(): nf += 1; sz += f.stat().st_size",
|
||||
" if f.name.endswith('.incomplete'): ic = True",
|
||||
" is_diffusion = False; gguf_files = []",
|
||||
" if os.path.isdir(snap):",
|
||||
" for sd in os.listdir(snap):",
|
||||
@@ -346,6 +434,21 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" if os.path.exists(os.path.join(sf, 'model_index.json')): is_diffusion = True",
|
||||
" for f in collect_ggufs(sf): f['rel_path'] = sd + '/' + f['rel_path']; gguf_files.append(f)",
|
||||
" models.append({'repo_id':rid,'size_bytes':sz,'nb_files':nf,'has_incomplete':ic,'path':cache,'is_diffusion':is_diffusion,'is_gguf':bool(gguf_files),'gguf_files':gguf_files})",
|
||||
"def hf_cache_paths():",
|
||||
" candidates = []",
|
||||
" def add(p):",
|
||||
" if not p: return",
|
||||
" p = os.path.expanduser(p)",
|
||||
" if p not in candidates: candidates.append(p)",
|
||||
" add(os.environ.get('HUGGINGFACE_HUB_CACHE'))",
|
||||
" hf_home = os.environ.get('HF_HOME')",
|
||||
" if hf_home: add(os.path.join(hf_home, 'hub'))",
|
||||
" add('~/.cache/huggingface/hub')",
|
||||
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
||||
" # When HOME is /root, expanduser() misses that persisted cache.",
|
||||
" add('/app/.cache/huggingface/hub')",
|
||||
f" add({add_hf_cache!r})" if add_hf_cache else "",
|
||||
" return candidates",
|
||||
"def scan_dir(p):",
|
||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||
" for d in sorted(os.listdir(p)):",
|
||||
@@ -409,7 +512,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
||||
" seen.add(name)",
|
||||
" models.append({'repo_id':name,'size_bytes':size_bytes,'nb_files':1,'has_incomplete':False,'path':'ollama','backend':'ollama','is_ollama':True})",
|
||||
" return",
|
||||
"scan_hf(os.path.expanduser('~/.cache/huggingface/hub'))",
|
||||
"for _hf_cache in hf_cache_paths(): scan_hf(_hf_cache)",
|
||||
"scan_ollama()",
|
||||
"scan_ollama_api()",
|
||||
]
|
||||
@@ -525,6 +628,7 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
||||
# Backticks and raw newlines are never legitimate here.
|
||||
if any(c in v for c in ("`", "\n", "\r")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
|
||||
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
||||
m = _GGUF_PRELUDE_RE.match(v)
|
||||
if m:
|
||||
@@ -533,9 +637,19 @@ def _validate_serve_cmd(v: str | None) -> str | None:
|
||||
for part in rest.split("||"):
|
||||
_check_serve_binary(part.strip())
|
||||
return v
|
||||
|
||||
# Otherwise: a single invocation — no shell metacharacters allowed.
|
||||
# Temporarily replace safe $(printf %s ...) expressions with a placeholder
|
||||
# to avoid triggering the metacharacter/command-injection checks.
|
||||
cleaned_v = v
|
||||
printf_matches = list(re.finditer(r"\$\(\s*printf\s+%s\s+([^\n()]*?)\)", v))
|
||||
for match in printf_matches:
|
||||
inner = match.group(1)
|
||||
if not any(c in inner for c in (";", "&&", "||", "$(", "`")):
|
||||
cleaned_v = cleaned_v.replace(match.group(0), "/placeholder/safe/path.gguf")
|
||||
|
||||
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
||||
if any(c in v for c in (";", "&&", "||", "$(")):
|
||||
if any(c in cleaned_v for c in (";", "&&", "||", "$(")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
_check_serve_binary(v)
|
||||
return v
|
||||
@@ -559,6 +673,21 @@ def _append_serve_preflight_exit_lines(runner_lines: list[str], *, keep_shell_op
|
||||
runner_lines.append('fi')
|
||||
|
||||
|
||||
def _append_vllm_linux_preflight_lines(runner_lines: list[str]) -> None:
|
||||
"""Append Linux vLLM readiness lines that identify the runtime being used."""
|
||||
# Keep the user install bin visible for Odysseus-managed `pip install --user`
|
||||
# installs, but then report the actual CLI path so external runtimes are clear.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('ODYSSEUS_VLLM_BIN="$(command -v vllm 2>/dev/null || true)"')
|
||||
runner_lines.append('if [ -z "$ODYSSEUS_VLLM_BIN" ]; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('else')
|
||||
runner_lines.append(' echo "[odysseus] vLLM CLI: $ODYSSEUS_VLLM_BIN"')
|
||||
runner_lines.append(' ODYSSEUS_VLLM_VERSION="$("$ODYSSEUS_VLLM_BIN" --version 2>&1 | head -n 1 || true)"')
|
||||
runner_lines.append(' if [ -n "$ODYSSEUS_VLLM_VERSION" ]; then echo "[odysseus] vLLM version: $ODYSSEUS_VLLM_VERSION"; fi')
|
||||
runner_lines.append('fi')
|
||||
|
||||
def _append_serve_exit_code_lines(
|
||||
runner_lines: list[str],
|
||||
*,
|
||||
@@ -804,3 +933,172 @@ def _ssh_ps(host, script_path, port=None):
|
||||
|
||||
# Windows session dir — stored in user's temp on the remote
|
||||
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
||||
|
||||
|
||||
def _diagnose_serve_output(text: str) -> dict | None:
|
||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||
|
||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||
the agent/tool path the same structured signal so it can retry with an
|
||||
adjusted command instead of guessing from raw tmux output.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
tail = text[-6000:]
|
||||
patterns = [
|
||||
(
|
||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||
"No GPU memory left for KV cache after loading model.",
|
||||
[
|
||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||
"GPU ran out of memory during startup or warmup.",
|
||||
[
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"not divisib|must be divisible|attention heads.*divisible",
|
||||
"Tensor parallel size is incompatible with the model.",
|
||||
[
|
||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||
"Context length is too large for available GPU memory.",
|
||||
[
|
||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||
"Auto tool choice requires an explicit tool call parser.",
|
||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||
),
|
||||
(
|
||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||
"Model requires custom code or newer model support.",
|
||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||
),
|
||||
(
|
||||
r"There is no module or parameter named ['\"]lm_head\.input_scale['\"]|lm_head\.input_scale|weight_scale_2",
|
||||
"vLLM cannot load this ModelOpt LM-head quantized checkpoint with the current runtime.",
|
||||
[
|
||||
{
|
||||
"label": "upgrade vLLM through the environment that provides this CLI, or use a compatible checkpoint",
|
||||
"op": "manual",
|
||||
}
|
||||
],
|
||||
),
|
||||
(
|
||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||
"vLLM/Transformers kernel package mismatch.",
|
||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||
),
|
||||
(
|
||||
r"Address already in use|bind.*address.*in use",
|
||||
"Port is already in use.",
|
||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||
),
|
||||
(
|
||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||
"No GPUs are visible to the serve process.",
|
||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||
),
|
||||
(
|
||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||
"This machine may have integrated or unsupported graphics only.",
|
||||
[
|
||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||
"vLLM is not installed or not in PATH on this server.",
|
||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||
),
|
||||
(
|
||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||
"SGLang is not installed or not in PATH on this server.",
|
||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||
),
|
||||
(
|
||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||
),
|
||||
(
|
||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||
),
|
||||
(
|
||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||
"Diffusion serving requires PyTorch and diffusers.",
|
||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||
),
|
||||
(
|
||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||
"Model access is gated or unauthorized.",
|
||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||
),
|
||||
]
|
||||
for pattern, message, suggestions in patterns:
|
||||
if re.search(pattern, tail, re.I):
|
||||
return {"message": message, "suggestions": suggestions}
|
||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||
):
|
||||
return {
|
||||
"message": "Python traceback detected during serve startup.",
|
||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def run_ssh_command_async(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
remote_cmd: str,
|
||||
*,
|
||||
timeout: float,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
stdin_data: bytes | None = None,
|
||||
) -> tuple[int, bytes, bytes]:
|
||||
"""Run an ssh command with centralized timeout and stderr/stdout capture.
|
||||
Async version of core.platform_compat.run_ssh_command_sync.
|
||||
"""
|
||||
import asyncio
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*_ssh_exec_argv(
|
||||
remote,
|
||||
ssh_port,
|
||||
remote_cmd=remote_cmd,
|
||||
connect_timeout=connect_timeout,
|
||||
strict_host_key_checking=strict_host_key_checking,
|
||||
),
|
||||
stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
proc.communicate(input=stdin_data), timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
await proc.communicate()
|
||||
raise
|
||||
return proc.returncode or 0, stdout, stderr
|
||||
|
||||
+189
-206
@@ -15,19 +15,26 @@ from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
SSH_PATH_OVERRIDE,
|
||||
NVIDIA_PATH_CANDIDATES,
|
||||
detached_popen_kwargs,
|
||||
find_bash,
|
||||
git_bash_path,
|
||||
kill_process_tree,
|
||||
pid_alive,
|
||||
safe_chmod,
|
||||
which_tool,
|
||||
translate_path,
|
||||
get_wsl_windows_user_profile,
|
||||
)
|
||||
from routes.shell_routes import TMUX_LOG_DIR
|
||||
from src.constants import COOKBOOK_STATE_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,8 +45,10 @@ from routes.cookbook_helpers import (
|
||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||
_ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache,
|
||||
_user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
||||
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||
_append_pip_install_runner_lines,
|
||||
_diagnose_serve_output, run_ssh_command_async,
|
||||
ModelDownloadRequest, ServeRequest,
|
||||
)
|
||||
|
||||
@@ -54,7 +63,7 @@ _HF_TOKEN_STATUS_SNIPPET = (
|
||||
|
||||
def setup_cookbook_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["cookbook"])
|
||||
_cookbook_state_path = Path(os.environ.get("DATA_DIR", "data")) / "cookbook_state.json"
|
||||
_cookbook_state_path = Path(COOKBOOK_STATE_FILE)
|
||||
|
||||
def _mask_secret(value: str) -> str:
|
||||
if not value:
|
||||
@@ -81,127 +90,6 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
task["payload"].pop("hf_token", None)
|
||||
return state
|
||||
|
||||
def _diagnose_serve_output(text: str) -> dict | None:
|
||||
"""Server-side mirror of the Cookbook UI's common serve diagnoses.
|
||||
|
||||
The browser uses cookbook-diagnosis.js for clickable fixes. This gives
|
||||
the agent/tool path the same structured signal so it can retry with an
|
||||
adjusted command instead of guessing from raw tmux output.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
tail = text[-6000:]
|
||||
patterns = [
|
||||
(
|
||||
r"No available memory for the cache blocks|Available KV cache memory:.*-",
|
||||
"No GPU memory left for KV cache after loading model.",
|
||||
[
|
||||
{"label": "retry with GPU memory utilization 0.95", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.95"},
|
||||
{"label": "retry with context 2048", "op": "replace", "flag": "--max-model-len", "value": "2048"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"CUDA out of memory|torch\.cuda\.OutOfMemoryError|CUDA error: out of memory|warming up sampler|max_num_seqs.*gpu_memory_utilization",
|
||||
"GPU ran out of memory during startup or warmup.",
|
||||
[
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
{"label": "retry with GPU memory utilization 0.80", "op": "replace", "flag": "--gpu-memory-utilization", "value": "0.80"},
|
||||
{"label": "retry with --enforce-eager", "op": "append", "arg": "--enforce-eager"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"not divisib|must be divisible|attention heads.*divisible",
|
||||
"Tensor parallel size is incompatible with the model.",
|
||||
[
|
||||
{"label": "retry with tensor parallel size 1", "op": "replace", "flag": "--tensor-parallel-size", "value": "1"},
|
||||
{"label": "retry with tensor parallel size 2", "op": "replace", "flag": "--tensor-parallel-size", "value": "2"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"KV cache.*too (small|large)|max_model_len.*exceeds|maximum.*context",
|
||||
"Context length is too large for available GPU memory.",
|
||||
[
|
||||
{"label": "retry with context 8192", "op": "replace", "flag": "--max-model-len", "value": "8192"},
|
||||
{"label": "retry with context 4096", "op": "replace", "flag": "--max-model-len", "value": "4096"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"enable-auto-tool-choice requires --tool-call-parser",
|
||||
"Auto tool choice requires an explicit tool call parser.",
|
||||
[{"label": "retry with Hermes tool parser", "op": "append", "arg": "--tool-call-parser hermes"}],
|
||||
),
|
||||
(
|
||||
r"Please pass.*trust.remote.code=True|contains custom code which must be executed to correctly load|does not recognize this architecture|model type.*but Transformers does not",
|
||||
"Model requires custom code or newer model support.",
|
||||
[{"label": "retry with --trust-remote-code", "op": "append", "arg": "--trust-remote-code"}],
|
||||
),
|
||||
(
|
||||
r"Either a revision or a version must be specified|transformers\.integrations\.hub_kernels|kernels/layer",
|
||||
"vLLM/Transformers kernel package mismatch.",
|
||||
[{"label": "update vLLM, Transformers, and kernels on this server", "op": "dependency", "package": "vllm transformers kernels"}],
|
||||
),
|
||||
(
|
||||
r"Address already in use|bind.*address.*in use",
|
||||
"Port is already in use.",
|
||||
[{"label": "retry on port 8001", "op": "replace", "flag": "--port", "value": "8001"}],
|
||||
),
|
||||
(
|
||||
r"No CUDA GPUs are available|no GPU.*found|CUDA_VISIBLE_DEVICES.*invalid",
|
||||
"No GPUs are visible to the serve process.",
|
||||
[{"label": "clear Cookbook GPU selection or choose available GPUs", "op": "settings", "field": "gpus", "value": ""}],
|
||||
),
|
||||
(
|
||||
r"Failed to infer device type|NVML Shared Library Not Found|No module named 'amdsmi'|platform is not available",
|
||||
"vLLM could not find a supported GPU (CUDA or ROCm). "
|
||||
"This machine may have integrated or unsupported graphics only.",
|
||||
[
|
||||
{"label": "switch to llama.cpp (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
{"label": "switch to Ollama (CPU/Metal, works without a discrete GPU)", "op": "manual"},
|
||||
],
|
||||
),
|
||||
(
|
||||
r"vllm.*command not found|No module named vllm|ERROR: vLLM is not installed",
|
||||
"vLLM is not installed or not in PATH on this server.",
|
||||
[{"label": "install vLLM in Cookbook Dependencies", "op": "dependency", "package": "vllm"}],
|
||||
),
|
||||
(
|
||||
r"sglang.*command not found|No module named sglang|SGLang is not installed",
|
||||
"SGLang is not installed or not in PATH on this server.",
|
||||
[{"label": "install SGLang in Cookbook Dependencies", "op": "dependency", "package": "sglang[all]"}],
|
||||
),
|
||||
(
|
||||
r"llama-server.*command not found|llama\.cpp.*not found|No module named.*llama_cpp|No module named 'starlette_context'|git: command not found|cmake: command not found",
|
||||
"llama.cpp / llama-cpp-python dependencies are missing.",
|
||||
[{"label": "install llama.cpp dependencies or llama-cpp-python[server]", "op": "dependency", "package": "llama-cpp-python[server]"}],
|
||||
),
|
||||
(
|
||||
r"No GGUF found on this host|no \.gguf file|No GGUF file found",
|
||||
"No GGUF file found for this model on this host. The llama.cpp backend needs a .gguf file.",
|
||||
[{"label": "download a GGUF build of this model (repo name usually ends in -GGUF, file like Q4_K_M.gguf)", "op": "manual"}],
|
||||
),
|
||||
(
|
||||
r"No module named 'torch'|No module named torch|No module named 'diffusers'|No module named diffusers",
|
||||
"Diffusion serving requires PyTorch and diffusers.",
|
||||
[{"label": "install diffusers[torch] in Cookbook Dependencies", "op": "dependency", "package": "diffusers[torch]"}],
|
||||
),
|
||||
(
|
||||
r"403 Forbidden|401 Unauthorized|Access to model.*is restricted|gated repo|not in the authorized list|awaiting a review",
|
||||
"Model access is gated or unauthorized.",
|
||||
[{"label": "set HF token and request model access on HuggingFace", "op": "manual"}],
|
||||
),
|
||||
]
|
||||
for pattern, message, suggestions in patterns:
|
||||
if re.search(pattern, tail, re.I):
|
||||
return {"message": message, "suggestions": suggestions}
|
||||
if re.search(r"Traceback \(most recent call last\)", tail, re.I) and not re.search(
|
||||
r"Application startup complete|GET /v1/|Uvicorn running on", tail, re.I
|
||||
):
|
||||
return {
|
||||
"message": "Python traceback detected during serve startup.",
|
||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||
}
|
||||
return None
|
||||
|
||||
def _state_for_client(state):
|
||||
"""Return cookbook state without raw secrets for browser clients."""
|
||||
_strip_task_secrets(state)
|
||||
@@ -295,6 +183,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
safe_chmod(key_path.with_suffix(".pub"), 0o644)
|
||||
return {"ok": True, "public_key": _read_cookbook_public_key()}
|
||||
|
||||
|
||||
def _needs_binary(cmd: str, binary: str) -> bool:
|
||||
return bool(re.search(rf"(^|[\s;&|()]){re.escape(binary)}($|[\s;&|()])", cmd or ""))
|
||||
|
||||
@@ -355,8 +244,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# POSIX form + shell-quoting so drive paths / spaces survive.
|
||||
inner = TMUX_LOG_DIR / f"{session_id}_run.sh"
|
||||
inner.write_text("\n".join(bash_lines) + "\n", encoding="utf-8")
|
||||
lp = shlex.quote(log_path.as_posix())
|
||||
ip = shlex.quote(inner.as_posix())
|
||||
lp = shlex.quote(git_bash_path(log_path))
|
||||
ip = shlex.quote(git_bash_path(inner))
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"bash {ip} > {lp} 2>&1\n",
|
||||
@@ -472,6 +361,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines = []
|
||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||
if req.hf_token:
|
||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||
if req.env_prefix:
|
||||
@@ -545,7 +436,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Install hf CLI + optional hf_transfer best-effort. Retries disable
|
||||
# hf_transfer because the Rust parallel path is fast but has been
|
||||
# flaky near the end of very large multi-file downloads.
|
||||
# Use --break-system-packages on PEP-668 systems (Arch, newer Debian) so it doesn't bail.
|
||||
# The helper tries active pip first, then guarded user-site fallbacks.
|
||||
runner_lines.append(f"command -v hf >/dev/null 2>&1 || {_pip_install_fallback_chain('huggingface_hub', python_cmd='pip', upgrade=True)}")
|
||||
if req.disable_hf_transfer:
|
||||
runner_lines.append("export HF_HUB_ENABLE_HF_TRANSFER=0")
|
||||
@@ -673,24 +564,35 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
for d in model_dir.split(','):
|
||||
d = d.strip()
|
||||
if d:
|
||||
model_dirs.append(d)
|
||||
paths_code = _cached_model_scan_script(model_dirs)
|
||||
translated_d = translate_path(d) if not host else d
|
||||
model_dirs.append(translated_d)
|
||||
win_hf_hub = None
|
||||
if not host:
|
||||
win_profile = get_wsl_windows_user_profile()
|
||||
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
|
||||
|
||||
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
|
||||
|
||||
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
||||
scan_py.write_text(paths_code, encoding="utf-8")
|
||||
scan_payload = scan_py.read_bytes()
|
||||
|
||||
if host:
|
||||
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||
if platform == "windows":
|
||||
# Windows: use 'python' and pipe via stdin with double-quote wrapping
|
||||
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
|
||||
remote_cmd = "python -"
|
||||
else:
|
||||
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
|
||||
remote_cmd = (
|
||||
"if command -v python3 >/dev/null 2>&1; then python3 -; "
|
||||
"elif command -v python >/dev/null 2>&1; then python -; "
|
||||
"else echo \"python3/python not found\" >&2; exit 127; fi"
|
||||
)
|
||||
rc, stdout_b, stderr_b = await run_ssh_command_async(
|
||||
host,
|
||||
ssh_port,
|
||||
remote_cmd,
|
||||
timeout=60,
|
||||
stdin_data=scan_payload,
|
||||
)
|
||||
else:
|
||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||
@@ -710,7 +612,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||
|
||||
models = []
|
||||
try:
|
||||
@@ -915,6 +817,10 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
existing.name = display_name
|
||||
if supports_tools is not None:
|
||||
existing.supports_tools = supports_tools
|
||||
# Wipe stale model lists so the picker re-probes and discovers
|
||||
# the newly-served model instead of showing the old one.
|
||||
existing.cached_models = None
|
||||
existing.hidden_models = None
|
||||
db.commit()
|
||||
logger.info(f"Updated existing local model endpoint: {base_url}")
|
||||
return existing.id
|
||||
@@ -971,11 +877,27 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
in_venv=sys.prefix != sys.base_prefix,
|
||||
)
|
||||
is_pip_install = bool(req.cmd and "pip install" in req.cmd)
|
||||
remote = req.remote_host
|
||||
is_windows = req.platform == "windows"
|
||||
local_windows = IS_WINDOWS and not remote
|
||||
if is_windows or local_windows:
|
||||
if req.cmd.startswith("python3 "):
|
||||
req.cmd = "python " + req.cmd[len("python3 "):]
|
||||
if is_pip_install and ("llama-cpp-python" in req.cmd or "llama_cpp" in req.cmd) and (is_windows or local_windows):
|
||||
if "--extra-index-url" not in req.cmd:
|
||||
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
|
||||
if is_pip_install:
|
||||
# Keep big dependency wheel builds (vLLM, …) off the home filesystem's
|
||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||
# and leave the dep installed-but-unusable (#1459).
|
||||
req.cmd = _pip_install_no_cache(req.cmd)
|
||||
# Accept common aliases and enforce server extras for llama-cpp so
|
||||
# `python -m llama_cpp.server` has all runtime dependencies.
|
||||
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
|
||||
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
|
||||
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
|
||||
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||
# v2 review HIGH-14: tightened from the previous regex which
|
||||
@@ -1028,6 +950,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines = []
|
||||
ps_lines.append('$sessionDir = "$env:TEMP\\odysseus-sessions"')
|
||||
ps_lines.append('New-Item -ItemType Directory -Force -Path $sessionDir | Out-Null')
|
||||
ps_lines.append('$env:PYTHONIOENCODING = "utf-8"')
|
||||
ps_lines.append('$env:PYTHONUTF8 = "1"')
|
||||
if req.hf_token:
|
||||
ps_lines.append(f"$env:HF_TOKEN = '{_ps_squote(req.hf_token)}'")
|
||||
if req.gpus:
|
||||
@@ -1046,7 +970,7 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
ps_lines.append('try { python -c "import llama_cpp" 2>$null } catch {}')
|
||||
ps_lines.append('if ($LASTEXITCODE -ne 0) {')
|
||||
ps_lines.append(' Write-Host "Installing llama-cpp-python..."')
|
||||
ps_lines.append(' python -m pip install llama-cpp-python[server]')
|
||||
ps_lines.append(' python -m pip install llama-cpp-python[server] --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu')
|
||||
ps_lines.append('}')
|
||||
elif "vllm" in req.cmd:
|
||||
ps_lines.append('Write-Host "ERROR: vLLM is not supported on Windows. Use Ollama or llama.cpp instead."')
|
||||
@@ -1121,45 +1045,57 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# ollama is found (otherwise macOS falls back to a slow source build).
|
||||
# /opt/homebrew = Apple Silicon, /usr/local = Intel; harmless on Linux.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$HOME/bin:$HOME/llama.cpp/build/bin:/opt/homebrew/bin:/usr/local/bin:$PATH"')
|
||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||
runner_lines.append(' mkdir -p ~/bin')
|
||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||
# enables it automatically, no flag), CUDA on Linux when present,
|
||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||
# a prebuilt llama-server and skips this whole source build.)
|
||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' # If the native build failed, fall back to the Python bindings.')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('fi')
|
||||
if local_windows:
|
||||
# LOCAL Windows: no native source compilation (no cmake/compiler on Git Bash).
|
||||
# Just check python bindings (using native `python` binary) and fall back to pip install.
|
||||
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server not found — installing Python bindings..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='python')} || true")
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('if ! command -v llama-server &>/dev/null && ! python -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
else:
|
||||
runner_lines.append('if [ -d /data/data/com.termux ]; then')
|
||||
runner_lines.append(' # Termux: no native build — use the Python bindings (CPU).')
|
||||
runner_lines.append(' if ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' pkg install -y cmake 2>/dev/null')
|
||||
runner_lines.append(' pip install numpy diskcache jinja2 2>/dev/null')
|
||||
runner_lines.append(' CMAKE_ARGS="-DGGML_BLAS=OFF -DGGML_LLAMAFILE=OFF" pip install \'llama-cpp-python[server]\' --no-build-isolation --no-cache-dir 2>&1 || true')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('elif ! command -v llama-server &>/dev/null; then')
|
||||
runner_lines.append(' echo "Native llama-server not found — building from source (one-time, may take a few minutes)..."')
|
||||
runner_lines.append(' mkdir -p ~/bin')
|
||||
runner_lines.append(' cd ~ && [ -d llama.cpp ] || git clone --depth 1 https://github.com/ggml-org/llama.cpp')
|
||||
# Build with the right accelerator: Metal on macOS (llama.cpp
|
||||
# enables it automatically, no flag), CUDA on Linux when present,
|
||||
# else a plain CPU build. nproc is Linux-only — fall back to
|
||||
# `sysctl hw.ncpu` on macOS. (Tip: `brew install llama.cpp` ships
|
||||
# a prebuilt llama-server and skips this whole source build.)
|
||||
runner_lines.append(' NPROC="$(nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)"')
|
||||
runner_lines.append(' if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' command -v cmake >/dev/null 2>&1 || echo "WARNING: cmake not found — install it with: brew install cmake (or: brew install llama.cpp for a prebuilt llama-server)."')
|
||||
# Start from a clean cache: a prior failed configure (e.g. a CUDA
|
||||
# attempt) poisons build/CMakeCache.txt, so a plain `cmake -B build`
|
||||
# would reuse the bad settings and fail again. CMAKE_BUILD_TYPE is
|
||||
# explicit so the binary is optimized (Metal auto-enables on macOS).
|
||||
runner_lines.append(' cd ~/llama.cpp && rm -rf build && cmake -B build -DCMAKE_BUILD_TYPE=Release \\')
|
||||
runner_lines.append(' && cmake --build build -j"$NPROC" --target llama-server \\')
|
||||
runner_lines.append(' && ln -sf ~/llama.cpp/build/bin/llama-server ~/bin/llama-server')
|
||||
runner_lines.append(' else')
|
||||
_append_llama_cpp_linux_accel_build_lines(runner_lines)
|
||||
runner_lines.append(' fi')
|
||||
# If the native build failed, fall back to the Python bindings.
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "llama-server build failed — installing Python bindings as fallback..."')
|
||||
runner_lines.append(f" {_pip_install_fallback_chain('llama-cpp-python[server]', python_cmd='pip')} || true")
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' if ! command -v llama-server &>/dev/null && ! python3 -c "import llama_cpp" 2>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: llama.cpp serving is not available after install/build attempts."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append('fi')
|
||||
elif "ollama" in req.cmd:
|
||||
handled_ollama_serve = True
|
||||
_ollama_default_host = "0.0.0.0" if remote else "127.0.0.1"
|
||||
@@ -1181,13 +1117,23 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append(' ODYSSEUS_OLLAMA_PORT="$_ody_try_port"')
|
||||
runner_lines.append(' break')
|
||||
runner_lines.append(' fi')
|
||||
runner_lines.append(' exec 3<&-; exec 3>&-')
|
||||
runner_lines.append('done')
|
||||
runner_lines.append(' echo "[odysseus] Ollama API ready on port ${ODYSSEUS_OLLAMA_PORT}: ${ODYSSEUS_OLLAMA_URL}"')
|
||||
runner_lines.append(' echo "[odysseus] This task is monitoring an existing Ollama server; stopping it here will not stop an external Docker/system service."')
|
||||
if local_windows:
|
||||
# Windows detached process has no TTY; exec bash -i crashes.
|
||||
# Keep the monitoring task alive with a sleep loop.
|
||||
runner_lines.append(' while true; do sleep 60; done')
|
||||
else:
|
||||
runner_lines.append(' exec bash -i')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('if ! command -v ollama &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: Ollama not found on this server. Install it from https://ollama.com/download or `curl -fsSL https://ollama.com/install.sh | sh`."')
|
||||
runner_lines.append(' echo')
|
||||
runner_lines.append(' echo "=== Process exited with code 127 ==="')
|
||||
runner_lines.append(' exec bash -i')
|
||||
if local_windows:
|
||||
runner_lines.append(' exit 127')
|
||||
else:
|
||||
runner_lines.append(' exec bash -i')
|
||||
runner_lines.append('fi')
|
||||
runner_lines.append('ODYSSEUS_OLLAMA_URL="http://${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}"')
|
||||
if remote and _ollama_host in ("0.0.0.0", "::"):
|
||||
@@ -1195,24 +1141,20 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines.append('echo "[odysseus] Ollama has no built-in authentication; expose this only on a trusted LAN/VPN or provide an explicit OLLAMA_HOST with your own access controls."')
|
||||
runner_lines.append('echo "Starting ollama server on ${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}..."')
|
||||
runner_lines.append('OLLAMA_HOST="${ODYSSEUS_OLLAMA_HOST}:${ODYSSEUS_OLLAMA_PORT}" ollama serve')
|
||||
runner_lines.append('_ody_exit=$?')
|
||||
runner_lines.append('echo')
|
||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||
runner_lines.append('exec bash -i')
|
||||
if local_windows:
|
||||
_append_serve_exit_code_lines(runner_lines, keep_shell_open=False)
|
||||
else:
|
||||
runner_lines.append('_ody_exit=$?')
|
||||
runner_lines.append('echo')
|
||||
runner_lines.append('echo "=== Process exited with code ${_ody_exit} ==="')
|
||||
runner_lines.append('exec bash -i')
|
||||
elif "vllm serve" in req.cmd:
|
||||
# vLLM is CUDA/ROCm-only and does not run on macOS at all.
|
||||
runner_lines.append('if [ "$(uname -s)" = "Darwin" ]; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM does not run on macOS. Use Ollama or llama.cpp (Metal) instead."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=1')
|
||||
runner_lines.append('fi')
|
||||
# Put ~/.local/bin on PATH first — without a venv, vllm installs
|
||||
# there via --user and the non-login serve shell otherwise can't
|
||||
# find the `vllm` CLI ("command not found"). Mirrors llama.cpp above.
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! command -v vllm &>/dev/null; then')
|
||||
runner_lines.append(' echo "ERROR: vLLM is not installed."')
|
||||
runner_lines.append(' ODYSSEUS_PREFLIGHT_EXIT=127')
|
||||
runner_lines.append('fi')
|
||||
_append_vllm_linux_preflight_lines(runner_lines)
|
||||
elif "sglang.launch_server" in req.cmd:
|
||||
runner_lines.append('export PATH="$HOME/.local/bin:$PATH"')
|
||||
runner_lines.append('if ! command -v sglang &>/dev/null; then')
|
||||
@@ -1236,7 +1178,10 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
runner_lines,
|
||||
keep_shell_open=not local_windows,
|
||||
)
|
||||
runner_lines.append(req.cmd)
|
||||
if is_pip_install:
|
||||
_append_pip_install_runner_lines(runner_lines, req.cmd)
|
||||
else:
|
||||
runner_lines.append(req.cmd)
|
||||
if local_windows:
|
||||
# Detached background process — no interactive shell to keep open.
|
||||
# Print the exit marker the status poller looks for, then stop.
|
||||
@@ -1397,8 +1342,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||
else:
|
||||
# Linux: auto-install tmux (via whichever package manager is available)
|
||||
# and huggingface_hub + hf_transfer (falling back to --user/--break-system-packages
|
||||
# on PEP-668 locked distros like Arch / newer Debian).
|
||||
# and huggingface_hub + hf_transfer (falling back to --user, then
|
||||
# guarded --break-system-packages on PEP-668 locked distros).
|
||||
setup_script = (
|
||||
# Install tmux if missing — try common package managers; skip if no sudo
|
||||
"if ! command -v tmux >/dev/null 2>&1; then "
|
||||
@@ -1410,10 +1355,15 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
" fi; "
|
||||
"fi; "
|
||||
"command -v tmux >/dev/null 2>&1 || echo 'WARNING: tmux missing and auto-install failed (need passwordless sudo). Install manually.'; "
|
||||
# Install Python bits. Try system install first; fall back to --user --break-system-packages on PEP 668 systems.
|
||||
# Install Python bits. Try system install first; fall back to --user,
|
||||
# then use --break-system-packages only when pip supports it.
|
||||
"pip install -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null; "
|
||||
"pip install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"( pip install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||
"pip install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ) || "
|
||||
"pip3 install --user -q huggingface_hub hf_transfer 2>/dev/null || "
|
||||
"( pip3 install --help 2>/dev/null | grep -q -- --break-system-packages && "
|
||||
"pip3 install --user --break-system-packages -q huggingface_hub hf_transfer 2>/dev/null ); "
|
||||
"python3 -c 'from huggingface_hub import snapshot_download; print(\"OK\")'"
|
||||
)
|
||||
cmd = f"ssh {pf}{host} '{setup_script}'"
|
||||
@@ -1436,11 +1386,38 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
||||
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
||||
if host:
|
||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
||||
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
candidates = [query]
|
||||
stripped = query.strip()
|
||||
if stripped.startswith("nvidia-smi "):
|
||||
args = stripped[len("nvidia-smi "):]
|
||||
candidates.append(
|
||||
"bash -lc "
|
||||
+ shlex.quote(
|
||||
f"{SSH_PATH_OVERRIDE}"
|
||||
f"nvidia-smi {args}"
|
||||
)
|
||||
)
|
||||
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||
candidates.append(f"{nvidia_path} {args}")
|
||||
|
||||
last_err = "nvidia-smi failed"
|
||||
for candidate in candidates:
|
||||
try:
|
||||
rc, stdout, stderr = await run_ssh_command_async(
|
||||
host,
|
||||
ssh_port,
|
||||
candidate,
|
||||
connect_timeout=5,
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return None, "nvidia-smi timed out"
|
||||
if rc == 0:
|
||||
return stdout.decode("utf-8", errors="replace"), None
|
||||
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
|
||||
if err:
|
||||
last_err = err
|
||||
return None, last_err
|
||||
else:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*shlex.split(query),
|
||||
@@ -2203,7 +2180,13 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"inc=os.path.isdir(blobs) and any(x.endswith('.incomplete') for x in os.listdir(blobs));"
|
||||
"sys.exit(0 if ok and not inc else 1)"
|
||||
)
|
||||
cmd = ["python3", "-c", py, repo_id]
|
||||
if remote_host:
|
||||
cmd = ["python3", "-c", py, repo_id]
|
||||
else:
|
||||
# Local Windows: python3 can hit the Microsoft Store stub. Use the
|
||||
# real Python Odysseus is running under (guaranteed to exist).
|
||||
import sys as _sys_local
|
||||
cmd = [_sys_local.executable, "-c", py, repo_id]
|
||||
try:
|
||||
if remote_host:
|
||||
ssh_base = ["ssh"]
|
||||
|
||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Form, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from core.middleware import require_admin
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import copilot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
||||
# browser. Entries expire with the GitHub device code.
|
||||
#
|
||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
||||
# on a worker that never saw the start, returning "Unknown or expired login
|
||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
||||
_PENDING: Dict[str, Dict] = {}
|
||||
_PENDING_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _prune_expired() -> None:
|
||||
now = time.time()
|
||||
with _PENDING_LOCK:
|
||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
||||
_PENDING.pop(k, None)
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
return result
|
||||
|
||||
|
||||
def setup_copilot_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
||||
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = str(form.get("enterprise_url") or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
|
||||
@router.post("/device/start")
|
||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = (enterprise_url or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
interval = int(data.get("interval") or 5)
|
||||
expires_in = int(data.get("expires_in") or 900)
|
||||
poll_id = uuid.uuid4().hex
|
||||
with _PENDING_LOCK:
|
||||
_PENDING[poll_id] = {
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"interval": interval,
|
||||
"owner": get_current_user(request) or None,
|
||||
"expires_at": time.time() + expires_in,
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return {
|
||||
"poll_id": poll_id,
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return DeviceFlowStart(
|
||||
pending={
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"owner": get_current_user(request) or None,
|
||||
},
|
||||
response={
|
||||
"user_code": data.get("user_code"),
|
||||
"verification_uri": data.get("verification_uri"),
|
||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||
"interval": interval,
|
||||
"expires_in": expires_in,
|
||||
}
|
||||
},
|
||||
interval=int(data.get("interval") or 5),
|
||||
expires_in=int(data.get("expires_in") or 900),
|
||||
)
|
||||
|
||||
@router.post("/device/poll")
|
||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
with _PENDING_LOCK:
|
||||
pending = _PENDING.get(poll_id)
|
||||
if not pending:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
|
||||
# Enforce GitHub's polling interval server-side so a chatty client
|
||||
# can't trip slow_down.
|
||||
now = time.time()
|
||||
if now < pending.get("next_poll_at", 0):
|
||||
return {"status": "pending"}
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
except Exception as e:
|
||||
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
return DeviceFlowPoll.authorized(result)
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "authorized", "endpoint": result}
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
return DeviceFlowPoll.pending()
|
||||
if err == "slow_down":
|
||||
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||
if err in ("expired_token", "access_denied"):
|
||||
return DeviceFlowPoll.failed(err)
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return DeviceFlowPoll.pending(err or "unknown")
|
||||
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
||||
return {"status": "pending"}
|
||||
if err == "slow_down":
|
||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["interval"] = new_interval
|
||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
||||
return {"status": "pending"}
|
||||
if err in ("expired_token", "access_denied"):
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "failed", "error": err}
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return {"status": "pending", "detail": err or "unknown"}
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
def setup_copilot_routes():
|
||||
return create_device_flow_router(
|
||||
prefix="/api/copilot",
|
||||
tags=["copilot"],
|
||||
store=_DEVICE_FLOW_STORE,
|
||||
start_flow=_start_device_flow,
|
||||
poll_flow=_poll_device_flow,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
|
||||
from core.middleware import require_admin
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceFlowStart:
|
||||
"""Provider-specific start result consumed by the shared route wrapper."""
|
||||
|
||||
pending: Mapping[str, Any]
|
||||
response: Mapping[str, Any]
|
||||
interval: int = 5
|
||||
expires_in: int = 900
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DeviceFlowPoll:
|
||||
"""Normalized provider poll outcome."""
|
||||
|
||||
status: str
|
||||
endpoint: Optional[Mapping[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
detail: Optional[str] = None
|
||||
interval: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||
return cls(status="pending", detail=detail)
|
||||
|
||||
@classmethod
|
||||
def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||
return cls(status="slow_down", interval=interval, detail=detail)
|
||||
|
||||
@classmethod
|
||||
def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
|
||||
return cls(status="authorized", endpoint=endpoint)
|
||||
|
||||
@classmethod
|
||||
def failed(cls, error: str) -> "DeviceFlowPoll":
|
||||
return cls(status="failed", error=error)
|
||||
|
||||
|
||||
class PendingDeviceFlowStore:
|
||||
"""Thread-safe in-memory pending device-flow store.
|
||||
|
||||
Device codes and provider-side secrets stay inside this process. Each entry
|
||||
stores provider payload separately from poll metadata so provider callbacks
|
||||
only receive the fields they created.
|
||||
"""
|
||||
|
||||
def __init__(self, *, time_func: Callable[[], float] = time.time):
|
||||
self._pending: dict[str, dict[str, Any]] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._time = time_func
|
||||
|
||||
def _now(self) -> float:
|
||||
return float(self._time())
|
||||
|
||||
def prune_expired(self) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
for key in [k for k, v in self._pending.items() if v.get("expires_at", 0) < now]:
|
||||
self._pending.pop(key, None)
|
||||
|
||||
def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
|
||||
self.prune_expired()
|
||||
poll_id = uuid.uuid4().hex
|
||||
with self._lock:
|
||||
self._pending[poll_id] = {
|
||||
"payload": dict(payload),
|
||||
"interval": max(int(interval or 5), 1),
|
||||
"expires_at": self._now() + max(int(expires_in or 900), 1),
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
return poll_id
|
||||
|
||||
def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
|
||||
self.prune_expired()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is None:
|
||||
return None
|
||||
return dict(entry.get("payload") or {})
|
||||
|
||||
def is_throttled(self, poll_id: str) -> bool:
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
return bool(entry and self._now() < float(entry.get("next_poll_at") or 0))
|
||||
|
||||
def schedule_next(self, poll_id: str) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is not None:
|
||||
entry["next_poll_at"] = now + int(entry.get("interval") or 5)
|
||||
|
||||
def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
|
||||
now = self._now()
|
||||
with self._lock:
|
||||
entry = self._pending.get(poll_id)
|
||||
if entry is not None:
|
||||
new_interval = int(interval or (int(entry.get("interval") or 5) + 5))
|
||||
entry["interval"] = max(new_interval, 1)
|
||||
entry["next_poll_at"] = now + entry["interval"]
|
||||
|
||||
def pop(self, poll_id: str) -> None:
|
||||
with self._lock:
|
||||
self._pending.pop(poll_id, None)
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
|
||||
if inspect.isawaitable(value):
|
||||
return await value
|
||||
return value
|
||||
|
||||
|
||||
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
|
||||
response: dict[str, Any] = {"status": "pending"}
|
||||
if detail:
|
||||
response["detail"] = detail
|
||||
return response
|
||||
|
||||
|
||||
def create_device_flow_router(
|
||||
*,
|
||||
prefix: str,
|
||||
tags: Iterable[str],
|
||||
store: PendingDeviceFlowStore,
|
||||
start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
|
||||
poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
|
||||
) -> APIRouter:
|
||||
"""Create standard `/device/start|poll|cancel` routes for a provider."""
|
||||
|
||||
router = APIRouter(prefix=prefix, tags=list(tags))
|
||||
|
||||
@router.post("/device/start")
|
||||
async def device_start(request: Request):
|
||||
require_admin(request)
|
||||
form = await request.form()
|
||||
start = await _maybe_await(start_flow(request, form))
|
||||
interval = int(start.interval or 5)
|
||||
expires_in = int(start.expires_in or 900)
|
||||
poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
|
||||
response = dict(start.response)
|
||||
response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in})
|
||||
return response
|
||||
|
||||
@router.post("/device/poll")
|
||||
async def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
payload = store.get_payload(poll_id)
|
||||
if payload is None:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
if store.is_throttled(poll_id):
|
||||
return {"status": "pending"}
|
||||
|
||||
try:
|
||||
outcome = await _maybe_await(poll_flow(request, payload))
|
||||
except Exception:
|
||||
store.pop(poll_id)
|
||||
raise
|
||||
|
||||
if outcome.status == "authorized":
|
||||
store.pop(poll_id)
|
||||
return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
|
||||
if outcome.status == "failed":
|
||||
store.pop(poll_id)
|
||||
return {"status": "failed", "error": outcome.error or "denied"}
|
||||
if outcome.status == "slow_down":
|
||||
store.slow_down(poll_id, outcome.interval)
|
||||
return _pending_response(outcome.detail)
|
||||
|
||||
store.schedule_next(poll_id)
|
||||
return _pending_response(outcome.detail)
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
store.pop(poll_id)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
+66
-36
@@ -7,14 +7,24 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Form
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import case, func, or_
|
||||
from core.database import SessionLocal, Document, DocumentVersion
|
||||
from core.database import Session as DbSession
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import MAIL_ATTACHMENTS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_session_or_404(db, session_id: str, user: Optional[str]):
|
||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and session.owner != user:
|
||||
raise HTTPException(404, "Session not found")
|
||||
return session
|
||||
|
||||
|
||||
def _aggregate_language_facets(lang_rows):
|
||||
"""Sum document counts per display language for the library facet.
|
||||
|
||||
@@ -30,6 +40,19 @@ def _aggregate_language_facets(lang_rows):
|
||||
return out
|
||||
|
||||
|
||||
def _library_language_for_document(doc: Document) -> str:
|
||||
"""Return the display language used by the document library.
|
||||
|
||||
PDF documents are stored as markdown wrappers so the editor can preserve
|
||||
extracted text, form fields, and annotations. The library should still
|
||||
identify them as PDFs instead of exposing that internal wrapper format.
|
||||
"""
|
||||
from src.pdf_form_doc import find_source_upload_id
|
||||
|
||||
if find_source_upload_id(doc.current_content or ""):
|
||||
return "pdf"
|
||||
return doc.language or "text"
|
||||
|
||||
|
||||
from routes.document_helpers import (
|
||||
DocumentCreate, DocumentUpdate, DocumentPatch,
|
||||
@@ -69,17 +92,12 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# the doc is owner-stamped, so it lives in the library on its own.
|
||||
session = None
|
||||
if req.session_id:
|
||||
session = db.query(DbSession).filter(DbSession.id == req.session_id).first()
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
# Match the lenient ownership model the rest of the app uses
|
||||
# (see _owner_filter): only block when an AUTHENTICATED user is
|
||||
# writing into a DIFFERENT user's session. In single-user /
|
||||
# unconfigured / localhost-bypass mode the middleware leaves
|
||||
# current_user unset (None), and those sessions are already
|
||||
# served freely everywhere else.
|
||||
if user and session.owner and session.owner != user:
|
||||
raise HTTPException(403, "Cannot create document in another user's session")
|
||||
# unconfigured / localhost-bypass mode, falsey users preserve
|
||||
# the existing lenient path.
|
||||
session = _get_session_or_404(db, req.session_id, user)
|
||||
|
||||
doc_id = str(uuid.uuid4())
|
||||
ver_id = str(uuid.uuid4())
|
||||
@@ -171,11 +189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
if session_id:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
sess = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if not sess:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and sess.owner and sess.owner != user:
|
||||
raise HTTPException(403, "Cannot import into another user's session")
|
||||
_get_session_or_404(db, session_id, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -198,7 +212,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
|
||||
title = os.path.splitext(meta.get("original_name") or meta.get("name") or upload_id)[0]
|
||||
try:
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||
except Exception:
|
||||
body_text = None
|
||||
|
||||
@@ -260,18 +274,29 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from sqlalchemy import or_
|
||||
pdf_marker_cond = or_(
|
||||
Document.current_content.like('%<!-- pdf_source upload_id="%'),
|
||||
Document.current_content.like('%<!-- pdf_form_source upload_id="%'),
|
||||
)
|
||||
library_language_expr = case(
|
||||
(pdf_marker_cond, "pdf"),
|
||||
(Document.language.is_(None), "text"),
|
||||
else_=Document.language,
|
||||
)
|
||||
# Archived view shows ONLY archived docs; the default view excludes
|
||||
# them (NULL = legacy rows that predate the column = not archived).
|
||||
_arch_cond = (Document.archived == True) if archived else or_(
|
||||
Document.archived == False, Document.archived.is_(None))
|
||||
# Language facet counts (owner-filtered)
|
||||
# Language facet counts (owner-filtered). PDF documents are stored
|
||||
# as markdown wrappers, so group by the library display language
|
||||
# instead of the raw stored language.
|
||||
lang_q = (
|
||||
db.query(Document.language, func.count(Document.id))
|
||||
db.query(library_language_expr, func.count(Document.id))
|
||||
.outerjoin(DbSession, Document.session_id == DbSession.id)
|
||||
.filter(Document.is_active == True).filter(_arch_cond)
|
||||
)
|
||||
lang_q = _owner_session_filter(lang_q, user)
|
||||
lang_rows = lang_q.group_by(Document.language).all()
|
||||
lang_rows = lang_q.group_by(library_language_expr).all()
|
||||
languages = _aggregate_language_facets(lang_rows)
|
||||
|
||||
# Session count (owner-filtered)
|
||||
@@ -303,12 +328,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
Document.title.ilike(term) | Document.current_content.ilike(term)
|
||||
)
|
||||
|
||||
# Language filter
|
||||
# Language filter. "pdf" is a display language derived from the
|
||||
# source marker; "markdown" excludes those wrappers.
|
||||
if language:
|
||||
if language == "text":
|
||||
q = q.filter((Document.language == None) | (Document.language == "text"))
|
||||
elif language == "pdf":
|
||||
q = q.filter(pdf_marker_cond)
|
||||
else:
|
||||
q = q.filter(Document.language == language)
|
||||
if language == "markdown":
|
||||
q = q.filter(~pdf_marker_cond)
|
||||
|
||||
# Total before pagination
|
||||
total = q.count()
|
||||
@@ -332,7 +362,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
"session_id": doc.session_id,
|
||||
"session_name": session_name,
|
||||
"title": doc.title,
|
||||
"language": doc.language or "text",
|
||||
"language": _library_language_for_document(doc),
|
||||
"preview": (doc.current_content or "")[:500],
|
||||
"version_count": doc.version_count,
|
||||
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
||||
@@ -359,18 +389,17 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
try:
|
||||
if not user:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
# v2 review HIGH-9: raise 403 explicitly when the caller
|
||||
# can't see this session, instead of returning [] which the
|
||||
# UI treats identically to "no docs" and silently masks
|
||||
# auth failures.
|
||||
if not session:
|
||||
raise HTTPException(404, "Session not found")
|
||||
if user and session.owner and session.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
docs = db.query(Document).filter(
|
||||
_get_session_or_404(db, session_id, user)
|
||||
q = db.query(Document).filter(
|
||||
Document.session_id == session_id
|
||||
).order_by(Document.created_at.desc()).all()
|
||||
)
|
||||
if user:
|
||||
q = q.filter(or_(Document.owner == user, Document.owner.is_(None)))
|
||||
docs = q.order_by(Document.created_at.desc()).all()
|
||||
return [_doc_to_dict(d) for d in docs]
|
||||
finally:
|
||||
db.close()
|
||||
@@ -437,7 +466,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
raise HTTPException(404, "Source PDF could not be located")
|
||||
|
||||
try:
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path))
|
||||
body_text = strip_pdf_content_marker(_process_pdf(pdf_path, owner=user))
|
||||
except Exception as e:
|
||||
logger.error(f"extract_pdf_text failed for {pdf_path}: {e}")
|
||||
raise HTTPException(500, f"Extraction failed: {e}")
|
||||
@@ -606,6 +635,8 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
doc.language = req.language
|
||||
if req.session_id is not None:
|
||||
# Empty string = unlink from session
|
||||
if req.session_id:
|
||||
_get_session_or_404(db, req.session_id, user)
|
||||
doc.session_id = req.session_id if req.session_id else None
|
||||
if not req.session_id:
|
||||
# Tab closed / doc detached from its session — drop the
|
||||
@@ -855,10 +886,10 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
user = get_current_user(request)
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=user or None)
|
||||
if not url or not model:
|
||||
# Fall back to default endpoint
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||
if not url or not model:
|
||||
raise HTTPException(500, "No endpoint configured for AI tidy")
|
||||
|
||||
@@ -1158,7 +1189,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
settings = _load_vl_settings()
|
||||
vl_model = settings.get("vision_model", "")
|
||||
try:
|
||||
url, model_id, headers = _resolve_vl_model(vl_model)
|
||||
url, model_id, headers = _resolve_vl_model(vl_model, owner=user)
|
||||
except Exception as e:
|
||||
raise HTTPException(503, f"No vision model available: {e}")
|
||||
|
||||
@@ -1512,10 +1543,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# don't import from a routes file (cycle-prone). Same env override
|
||||
# as email_routes (ODYSSEUS_MAIL_ATTACHMENTS_DIR).
|
||||
from pathlib import Path as _Path
|
||||
import os as _os
|
||||
_DATA_DIR = _Path(__file__).resolve().parent.parent / "data"
|
||||
_BASE = _os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(_DATA_DIR / "mail-attachments"))
|
||||
_COMPOSE_DIR = _Path(_BASE) / "_compose"
|
||||
_COMPOSE_DIR = _Path(MAIL_ATTACHMENTS_DIR) / "_compose"
|
||||
_COMPOSE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
user = get_current_user(request)
|
||||
@@ -1631,9 +1659,11 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
# context (To/Subject/In-Reply-To/References).
|
||||
try:
|
||||
from routes.email_routes import _imap, _decode_header
|
||||
from routes.email_helpers import _q
|
||||
except Exception:
|
||||
_imap = None
|
||||
_decode_header = lambda x: x or ""
|
||||
_q = lambda x: x or ""
|
||||
|
||||
to_addr = ""
|
||||
from_name = ""
|
||||
@@ -1643,7 +1673,7 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
|
||||
if _imap:
|
||||
try:
|
||||
with _imap(doc.source_email_account_id or None) as conn:
|
||||
conn.select(doc.source_email_folder, readonly=True)
|
||||
conn.select(_q(doc.source_email_folder), readonly=True)
|
||||
status, data = conn.fetch(doc.source_email_uid.encode(), "(RFC822.HEADER)")
|
||||
if status == "OK" and data and data[0]:
|
||||
raw_hdr = data[0][1]
|
||||
|
||||
+98
-33
@@ -71,6 +71,38 @@ def _send_smtp_message(cfg: dict, from_addr: str, recipients: list[str], message
|
||||
smtp.sendmail(from_addr, recipients, message)
|
||||
|
||||
|
||||
def _friendly_email_auth_error(protocol: str, host: str, error: object) -> str:
|
||||
"""Return a clearer setup error for known provider auth policies."""
|
||||
raw = str(error or "")
|
||||
lower = raw.lower()
|
||||
host_lower = (host or "").lower()
|
||||
microsoft_host = any(
|
||||
marker in host_lower
|
||||
for marker in (
|
||||
"outlook.office365.com",
|
||||
"smtp.office365.com",
|
||||
"office365.com",
|
||||
"outlook.com",
|
||||
"hotmail.com",
|
||||
"live.com",
|
||||
)
|
||||
)
|
||||
microsoft_basic_auth_failure = (
|
||||
"5.7.139" in lower
|
||||
or "basic authentication is disabled" in lower
|
||||
or ("authenticate failed" in lower and microsoft_host)
|
||||
or ("authentication unsuccessful" in lower and microsoft_host)
|
||||
)
|
||||
if microsoft_basic_auth_failure:
|
||||
return (
|
||||
"Microsoft no longer accepts normal mailbox passwords for "
|
||||
"Outlook/Office 365 IMAP/SMTP in most accounts. Odysseus "
|
||||
"does not support Microsoft OAuth/Graph mail yet, so Outlook "
|
||||
"accounts cannot be added with this password form."
|
||||
)
|
||||
return raw[:200]
|
||||
|
||||
|
||||
def _strip_think(text: str) -> str:
|
||||
"""Email-flavored think strip — thin wrapper over the central helper.
|
||||
|
||||
@@ -254,16 +286,17 @@ def _cleanup_compose_uploads(tokens) -> None:
|
||||
pass
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
||||
from src.constants import DATA_DIR as _DATA_DIR, MAIL_ATTACHMENTS_DIR, SETTINGS_FILE as _SETTINGS_FILE, SCHEDULED_EMAILS_DB
|
||||
DATA_DIR = Path(_DATA_DIR)
|
||||
SETTINGS_FILE = Path(_SETTINGS_FILE)
|
||||
# Override at deploy time via ODYSSEUS_MAIL_ATTACHMENTS_DIR. Defaults to a
|
||||
# subdir of the install's data/ tree so the app works out-of-the-box without
|
||||
# a hardcoded /home/<user>/ path.
|
||||
ATTACHMENTS_DIR = Path(os.environ.get("ODYSSEUS_MAIL_ATTACHMENTS_DIR", str(DATA_DIR / "mail-attachments")))
|
||||
ATTACHMENTS_DIR = Path(MAIL_ATTACHMENTS_DIR)
|
||||
ATTACHMENTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
COMPOSE_UPLOADS_DIR = ATTACHMENTS_DIR / "_compose"
|
||||
COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
|
||||
SCHEDULED_DB = Path(SCHEDULED_EMAILS_DB)
|
||||
|
||||
|
||||
OWNER_SCOPED_EMAIL_CACHE_TABLES = {
|
||||
@@ -705,7 +738,16 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
||||
port = int(port or 993)
|
||||
if starttls:
|
||||
conn = imaplib.IMAP4(host, port, timeout=timeout)
|
||||
conn.starttls()
|
||||
try:
|
||||
conn.starttls()
|
||||
except Exception:
|
||||
# Don't leak the open plain socket if the STARTTLS upgrade is
|
||||
# rejected; close it before propagating. (#3174)
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
elif port == 993:
|
||||
conn = imaplib.IMAP4_SSL(host, port, timeout=timeout)
|
||||
else:
|
||||
@@ -714,6 +756,10 @@ def _open_imap_connection(host: str, port: int, *, starttls: bool, timeout: int
|
||||
conn.sock.settimeout(timeout)
|
||||
except Exception:
|
||||
pass
|
||||
# Raise the IMAP line-length limit from the default 1 MB to 50 MB so that
|
||||
# large mailboxes (tens of thousands of messages) don't crash with
|
||||
# "got more than 1000000 bytes" on UID SEARCH ALL. (#2883)
|
||||
imaplib._MAXLINE = 50_000_000
|
||||
return conn
|
||||
|
||||
def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
@@ -734,7 +780,18 @@ def _imap_connect(account_id: str | None = None, owner: str = ""):
|
||||
starttls=bool(cfg.get("imap_starttls")),
|
||||
timeout=_IMAP_TIMEOUT_SECONDS,
|
||||
)
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
try:
|
||||
conn.login(cfg["imap_user"], cfg["imap_password"])
|
||||
except Exception:
|
||||
# A failed AUTHENTICATE (e.g. an Office 365 app password on an
|
||||
# MFA-enabled tenant, #3174) otherwise orphans the already-connected
|
||||
# socket; close it before propagating so a misconfigured account
|
||||
# can't leak one descriptor per retry / background poller pass.
|
||||
try:
|
||||
conn.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
return conn
|
||||
|
||||
|
||||
@@ -798,20 +855,28 @@ def _imap(account_id: str | None = None, owner: str = ""):
|
||||
def _decode_header(raw):
|
||||
if not raw:
|
||||
return ""
|
||||
parts = email.header.decode_header(raw)
|
||||
decoded = []
|
||||
for data, charset in parts:
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
except (LookupError, ValueError):
|
||||
# Unknown/invalid MIME charset (e.g. a malformed or spam header
|
||||
# like =?x-unknown-charset?B?...?=). errors="replace" only covers
|
||||
# byte-decode errors, not codec lookup, so fall back to utf-8.
|
||||
decoded.append(data.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(data)
|
||||
return " ".join(decoded)
|
||||
try:
|
||||
# make_header concatenates per RFC 2047: no spurious space between an
|
||||
# encoded-word and adjacent plain text (plain runs keep their own
|
||||
# whitespace), and the whitespace between two adjacent encoded-words is
|
||||
# dropped. The old " ".join produced "Re: Jose"-style double spaces on
|
||||
# every non-ASCII subject or sender.
|
||||
return str(email.header.make_header(email.header.decode_header(raw)))
|
||||
except Exception:
|
||||
# Malformed header or unknown/invalid MIME charset (e.g. a spam header
|
||||
# like =?x-unknown-charset?B?...?=) makes make_header raise LookupError;
|
||||
# fall back to a lossy per-part decode. errors="replace" only covers
|
||||
# byte-decode errors, not codec lookup, hence the explicit utf-8 retry.
|
||||
decoded = []
|
||||
for data, charset in email.header.decode_header(raw):
|
||||
if isinstance(data, bytes):
|
||||
try:
|
||||
decoded.append(data.decode(charset or "utf-8", errors="replace"))
|
||||
except (LookupError, ValueError):
|
||||
decoded.append(data.decode("utf-8", errors="replace"))
|
||||
else:
|
||||
decoded.append(data)
|
||||
return "".join(decoded)
|
||||
|
||||
|
||||
def _detect_sent_folder(conn):
|
||||
@@ -1136,13 +1201,9 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
if exclude_uid:
|
||||
seen_uids.add((exclude_folder or "INBOX", str(exclude_uid)))
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = _imap_connect(account_id, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning(f"sender-thread-context: imap connect failed: {e}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||
if len(blocks) >= limit:
|
||||
break
|
||||
@@ -1209,11 +1270,14 @@ def _fetch_sender_thread_context(sender_addr: str,
|
||||
if atts_text:
|
||||
lines.append(atts_text)
|
||||
blocks.append("\n".join(lines))
|
||||
except Exception as e:
|
||||
logger.warning(f"sender-thread-context: imap failed: {e}")
|
||||
finally:
|
||||
try: conn.close()
|
||||
except Exception: pass
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
if conn:
|
||||
try: conn.close()
|
||||
except Exception: pass
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
if not blocks:
|
||||
return ""
|
||||
@@ -1316,6 +1380,7 @@ def _pre_retrieve_context(
|
||||
if not terms_list:
|
||||
return context_snippets, terms_list
|
||||
|
||||
ctx_conn = None
|
||||
try:
|
||||
ctx_conn = _imap_connect(account_id, owner=owner)
|
||||
for folder in ["INBOX", "Sent", "Archive", "Drafts"]:
|
||||
@@ -1352,12 +1417,12 @@ def _pre_retrieve_context(
|
||||
except Exception as _e:
|
||||
logger.warning(f" search {folder} {term!r} failed: {_e}")
|
||||
continue
|
||||
try:
|
||||
ctx_conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as _e:
|
||||
logger.warning(f"IMAP context search failed: {_e}")
|
||||
finally:
|
||||
if ctx_conn:
|
||||
try: ctx_conn.logout()
|
||||
except Exception: pass
|
||||
|
||||
try:
|
||||
from routes.contacts_routes import _fetch_contacts
|
||||
|
||||
@@ -210,7 +210,7 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
|
||||
if auto_cal:
|
||||
for sent_name in ("Sent", "INBOX/Sent", "Sent Items", "[Gmail]/Sent Mail"):
|
||||
try:
|
||||
st, _ = conn.select(sent_name, readonly=True)
|
||||
st, _ = conn.select(_q(sent_name), readonly=True)
|
||||
if st == "OK":
|
||||
folders_to_scan.append(sent_name)
|
||||
break
|
||||
@@ -1046,7 +1046,7 @@ def _scheduled_poll_once() -> dict:
|
||||
try:
|
||||
with _imap(row_account_id, owner=row_owner) as imap:
|
||||
sent_folder = _detect_sent_folder(imap)
|
||||
imap.append(sent_folder, "\\Seen", None, outer.as_bytes())
|
||||
imap.append(_q(sent_folder), "\\Seen", None, outer.as_bytes())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to append scheduled {sid} to Sent: {e}")
|
||||
|
||||
|
||||
@@ -32,9 +32,10 @@ from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from fastapi import APIRouter, Query, UploadFile, File, BackgroundTasks, HTTPException, Depends, Request
|
||||
from fastapi.responses import FileResponse
|
||||
from src.constants import DATA_DIR
|
||||
|
||||
from src.llm_core import llm_call_async
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, EMAIL_COMPOSE_UPLOAD_MAX_BYTES
|
||||
|
||||
from routes.email_helpers import (
|
||||
_strip_think, _extract_reply, _apply_email_style_mechanics, require_owner, require_user, _assert_owns_account,
|
||||
@@ -47,6 +48,7 @@ from routes.email_helpers import (
|
||||
_extract_attachment_to_disk, _extract_html, _extract_text,
|
||||
_fetch_sender_thread_context, _pre_retrieve_context,
|
||||
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
|
||||
_friendly_email_auth_error,
|
||||
SendEmailRequest, ExtractStyleRequest,
|
||||
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
|
||||
attachment_extract_dir, _email_cache_owner_clause,
|
||||
@@ -56,7 +58,6 @@ from routes.email_pollers import _start_poller
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ODYSSEUS_MAIL_ORIGIN = "odysseus-ui"
|
||||
EMAIL_COMPOSE_UPLOAD_MAX_BYTES = 25 * 1024 * 1024
|
||||
|
||||
|
||||
def _email_tag_owner_aliases(account_id: str | None, owner: str = "") -> list[str]:
|
||||
@@ -2904,7 +2905,7 @@ def setup_email_routes():
|
||||
from pathlib import Path as _P
|
||||
import json as _json
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
path = _P(f"data/email_urgency_state_{_slug}.json")
|
||||
path = _P(DATA_DIR) / f"email_urgency_state_{_slug}.json"
|
||||
if not path.exists():
|
||||
return {"total_unread": 0, "total_urgent": 0, "max_score": 0, "per_uid": {}}
|
||||
try:
|
||||
@@ -3162,7 +3163,7 @@ def setup_email_routes():
|
||||
try: conn.logout()
|
||||
except Exception: pass
|
||||
except Exception as e:
|
||||
imap_result = {"ok": False, "error": str(e)[:200]}
|
||||
imap_result = {"ok": False, "error": _friendly_email_auth_error("IMAP", imap_host, e)}
|
||||
|
||||
smtp_host = (body.get("smtp_host") or "").strip()
|
||||
if smtp_host:
|
||||
@@ -3184,7 +3185,7 @@ def setup_email_routes():
|
||||
try: smtp.quit()
|
||||
except Exception: pass
|
||||
except Exception as e:
|
||||
smtp_result = {"ok": False, "error": str(e)[:200]}
|
||||
smtp_result = {"ok": False, "error": _friendly_email_auth_error("SMTP", smtp_host, e)}
|
||||
|
||||
return {
|
||||
"ok": imap_result["ok"] and (smtp_result is None or smtp_result["ok"]),
|
||||
|
||||
+65
-22
@@ -7,12 +7,12 @@ import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Form, Depends
|
||||
from core.constants import BASE_DIR
|
||||
from core.constants import EMBEDDING_ENDPOINT_FILE, FASTEMBED_CACHE_DIR
|
||||
from core.middleware import require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
|
||||
_ENDPOINT_FILE = EMBEDDING_ENDPOINT_FILE
|
||||
|
||||
# Track in-progress downloads
|
||||
_downloading: dict = {}
|
||||
@@ -35,13 +35,7 @@ def _cache_dir() -> str:
|
||||
default lived in /tmp, which many systems wipe on reboot — forcing a
|
||||
full re-download of the embedding model after every restart.
|
||||
"""
|
||||
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
||||
if env:
|
||||
return env
|
||||
return os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"data", "fastembed_cache",
|
||||
)
|
||||
return FASTEMBED_CACHE_DIR
|
||||
|
||||
|
||||
def _model_cache_name(hf_source: str) -> str:
|
||||
@@ -49,19 +43,35 @@ def _model_cache_name(hf_source: str) -> str:
|
||||
return "models--" + hf_source.replace("/", "--")
|
||||
|
||||
|
||||
def _model_cache_path(hf_source: str) -> Path:
|
||||
"""Return a confined cache path for a fastembed HF source."""
|
||||
root = Path(_cache_dir()).expanduser().resolve()
|
||||
raw_path = root / _model_cache_name(hf_source)
|
||||
if raw_path.is_symlink():
|
||||
raise ValueError("Model cache path must not be a symlink")
|
||||
path = raw_path.resolve(strict=False)
|
||||
try:
|
||||
path.relative_to(root)
|
||||
except ValueError:
|
||||
raise ValueError("Model cache path escapes cache root")
|
||||
return path
|
||||
|
||||
|
||||
def _is_downloaded(hf_source: str) -> bool:
|
||||
"""Check if a model is already cached."""
|
||||
cache = _cache_dir()
|
||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
||||
if not os.path.isdir(model_dir):
|
||||
try:
|
||||
model_dir = _model_cache_path(hf_source)
|
||||
except ValueError:
|
||||
return False
|
||||
if not model_dir.is_dir():
|
||||
return False
|
||||
# Check for actual model files (not just empty dir)
|
||||
snapshots = os.path.join(model_dir, "snapshots")
|
||||
if os.path.isdir(snapshots):
|
||||
return any(os.listdir(snapshots))
|
||||
snapshots = model_dir / "snapshots"
|
||||
if snapshots.is_dir():
|
||||
return any(snapshots.iterdir())
|
||||
# Also check for blobs (older cache format)
|
||||
blobs = os.path.join(model_dir, "blobs")
|
||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
||||
blobs = model_dir / "blobs"
|
||||
return blobs.is_dir() and any(blobs.iterdir())
|
||||
|
||||
|
||||
def _active_model() -> str:
|
||||
@@ -119,8 +129,10 @@ def setup_embedding_routes():
|
||||
|
||||
cached_size = None
|
||||
if downloaded and hf_src:
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
cached_size = _dir_size_mb(model_path)
|
||||
try:
|
||||
cached_size = _dir_size_mb(str(_model_cache_path(hf_src)))
|
||||
except ValueError:
|
||||
cached_size = None
|
||||
|
||||
result.append({
|
||||
"model": m["model"],
|
||||
@@ -217,8 +229,11 @@ def setup_embedding_routes():
|
||||
if not hf_src:
|
||||
raise HTTPException(400, "No cache source for this model")
|
||||
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
if not os.path.isdir(model_path):
|
||||
try:
|
||||
model_path = _model_cache_path(hf_src)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
if not model_path.is_dir():
|
||||
return {"deleted": False, "message": "Model not cached"}
|
||||
|
||||
shutil.rmtree(model_path)
|
||||
@@ -237,7 +252,7 @@ def setup_embedding_routes():
|
||||
}
|
||||
|
||||
@router.post("/endpoint")
|
||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
||||
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||
"""Save a custom embedding endpoint URL."""
|
||||
url = url.strip()
|
||||
if not url:
|
||||
@@ -261,6 +276,7 @@ def setup_embedding_routes():
|
||||
resp = httpx.post(
|
||||
url,
|
||||
json={"input": ["test"], "model": model or "test"},
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -271,10 +287,16 @@ def setup_embedding_routes():
|
||||
data = {"url": url}
|
||||
if model:
|
||||
data["model"] = model
|
||||
if api_key:
|
||||
from src.secret_storage import encrypt
|
||||
data["api_key"] = encrypt(api_key)
|
||||
|
||||
_save_custom_endpoint(data)
|
||||
os.environ["EMBEDDING_URL"] = url
|
||||
if model:
|
||||
os.environ["EMBEDDING_MODEL"] = model
|
||||
if api_key:
|
||||
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||
|
||||
# Reset the RAG singleton so it picks up the new endpoint
|
||||
import src.rag_singleton as _rs
|
||||
@@ -288,6 +310,16 @@ def setup_embedding_routes():
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.embedding_lanes import reset_embedding_lane_state
|
||||
reset_embedding_lane_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.tool_index import reset_tool_index
|
||||
reset_tool_index()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||
try:
|
||||
@@ -308,6 +340,7 @@ def setup_embedding_routes():
|
||||
# Remove from environment
|
||||
os.environ.pop("EMBEDDING_URL", None)
|
||||
os.environ.pop("EMBEDDING_MODEL", None)
|
||||
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||
|
||||
# Reset the RAG singleton so it falls back to fastembed
|
||||
import src.rag_singleton as _rs
|
||||
@@ -318,6 +351,16 @@ def setup_embedding_routes():
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.embedding_lanes import reset_embedding_lane_state
|
||||
reset_embedding_lane_state()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from src.tool_index import reset_tool_index
|
||||
reset_tool_index()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client
|
||||
try:
|
||||
|
||||
+45
-6
@@ -16,22 +16,54 @@ from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.responses import Response
|
||||
|
||||
from src.constants import EMOJI_CACHE_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
|
||||
_CACHE_DIR = Path(EMOJI_CACHE_DIR)
|
||||
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
||||
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
||||
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
||||
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
||||
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
||||
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
|
||||
_MAX_SVG_BYTES = 256 * 1024
|
||||
_BLOCKED_SVG_RE = re.compile(
|
||||
br"<\s*(?:script|foreignObject|iframe|object|embed|image)\b|"
|
||||
br"\bon[a-z0-9_-]+\s*=",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_EXTERNAL_REF_RE = re.compile(
|
||||
br"\b(?:href|xlink:href)\s*=\s*['\"](?:https?:|//|data:|javascript:)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_SVG_SECURITY_HEADERS = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Content-Security-Policy": "sandbox",
|
||||
"Cross-Origin-Resource-Policy": "same-origin",
|
||||
}
|
||||
_SVG_HEADERS = {
|
||||
"Cache-Control": "public, max-age=31536000, immutable",
|
||||
**_SVG_SECURITY_HEADERS,
|
||||
}
|
||||
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
||||
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
||||
# request can still pick up the real glyph once the CDN is reachable.
|
||||
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
||||
_BLANK_HEADERS = {"Cache-Control": "no-store"}
|
||||
_BLANK_HEADERS = {"Cache-Control": "no-store", **_SVG_SECURITY_HEADERS}
|
||||
|
||||
|
||||
def _is_safe_svg(content: bytes) -> bool:
|
||||
if not isinstance(content, bytes) or not content:
|
||||
return False
|
||||
if len(content) > _MAX_SVG_BYTES:
|
||||
return False
|
||||
if b"<svg" not in content[:256].lower():
|
||||
return False
|
||||
if _BLOCKED_SVG_RE.search(content) or _EXTERNAL_REF_RE.search(content):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def setup_emoji_routes() -> APIRouter:
|
||||
@@ -49,14 +81,21 @@ def setup_emoji_routes() -> APIRouter:
|
||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
fp = _CACHE_DIR / f"{code}.svg"
|
||||
if fp.exists():
|
||||
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
try:
|
||||
content = fp.read_bytes()
|
||||
if _is_safe_svg(content):
|
||||
return Response(content, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
fp.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.warning("emoji cache read %s failed: %s", code, e)
|
||||
return _blank()
|
||||
|
||||
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
||||
# it. OpenMoji filenames are the codepoints uppercased.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
||||
if r.status_code == 200 and b"<svg" in r.content[:256]:
|
||||
if r.status_code == 200 and _is_safe_svg(r.content):
|
||||
try:
|
||||
fp.write_bytes(r.content)
|
||||
except Exception:
|
||||
|
||||
+144
-54
@@ -12,8 +12,13 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint
|
||||
from core.database import Session as DbSession
|
||||
from src.auth_helpers import get_current_user, require_privilege
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.auth_helpers import get_current_user, owner_filter, require_privilege
|
||||
from src.upload_limits import (
|
||||
read_upload_limited,
|
||||
GALLERY_UPLOAD_MAX_BYTES,
|
||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
||||
)
|
||||
from src.constants import GENERATED_IMAGES_DIR
|
||||
|
||||
from routes.gallery_helpers import (
|
||||
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
||||
@@ -21,17 +26,88 @@ from routes.gallery_helpers import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GALLERY_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_UPLOAD_MAX_BYTES", str(100 * 1024 * 1024)))
|
||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES = int(os.getenv("ODYSSEUS_GALLERY_TRANSFORM_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024)))
|
||||
|
||||
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||
if not callable(is_admin):
|
||||
return False
|
||||
try:
|
||||
return bool(is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_gallery_filename(filename: str) -> str:
|
||||
"""Return a local filename safe to join under generated_images."""
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(filename or "").name)[:128]
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]", "_", Path(str(filename or "")).name)[:128]
|
||||
if not safe_name or safe_name in {".", ".."}:
|
||||
safe_name = uuid.uuid4().hex[:12]
|
||||
return safe_name
|
||||
|
||||
|
||||
GALLERY_IMAGE_DIR = Path(GENERATED_IMAGES_DIR)
|
||||
|
||||
|
||||
def _gallery_image_path(filename: str) -> Path:
|
||||
"""Resolve a stored gallery filename without leaving generated_images."""
|
||||
if not isinstance(filename, str):
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
safe_name = _sanitize_gallery_filename(filename)
|
||||
original = str(filename or "")
|
||||
root = GALLERY_IMAGE_DIR.resolve()
|
||||
path = (GALLERY_IMAGE_DIR / safe_name).resolve()
|
||||
try:
|
||||
if os.path.commonpath([str(root), str(path)]) != str(root):
|
||||
raise ValueError
|
||||
except Exception:
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
if safe_name != original:
|
||||
raise HTTPException(400, "Unsafe gallery filename")
|
||||
return path
|
||||
|
||||
|
||||
def _normalize_image_endpoint_base(url: str) -> str:
|
||||
base = (url or "").strip().rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3].rstrip("/")
|
||||
return base
|
||||
|
||||
|
||||
def _visible_image_endpoint_query(db, owner: str | None):
|
||||
from src.auth_helpers import owner_filter
|
||||
q = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.model_type == "image",
|
||||
ModelEndpoint.is_enabled == True, # noqa: E712
|
||||
)
|
||||
return owner_filter(q, ModelEndpoint, owner)
|
||||
|
||||
|
||||
def _first_visible_image_endpoint(db, owner: str | None):
|
||||
endpoints = _visible_image_endpoint_query(db, owner).all()
|
||||
if owner:
|
||||
for ep in endpoints:
|
||||
if getattr(ep, "owner", None) == owner:
|
||||
return ep
|
||||
return endpoints[0] if endpoints else None
|
||||
|
||||
|
||||
def _visible_image_endpoint_for_base(db, base: str, owner: str | None):
|
||||
target = _normalize_image_endpoint_base(base)
|
||||
if not target:
|
||||
return None
|
||||
fallback = None
|
||||
for ep in _visible_image_endpoint_query(db, owner).all():
|
||||
if _normalize_image_endpoint_base(getattr(ep, "base_url", "")) == target:
|
||||
if owner and getattr(ep, "owner", None) == owner:
|
||||
return ep
|
||||
if fallback is None:
|
||||
fallback = ep
|
||||
return fallback
|
||||
|
||||
|
||||
def setup_gallery_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["gallery"])
|
||||
|
||||
@@ -55,6 +131,9 @@ def setup_gallery_routes() -> APIRouter:
|
||||
file_hash = hashlib.sha256(content).hexdigest()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if album_id and user is not None:
|
||||
_get_or_404_album(db, album_id, user)
|
||||
|
||||
# SECURITY: scope the dup-detect to THIS user — otherwise a
|
||||
# caller can probe whether someone else uploaded the same
|
||||
# file (the response leaks the existing row's id+filename).
|
||||
@@ -69,7 +148,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
return {"ok": False, "duplicate": True, "filename": existing.filename,
|
||||
"id": existing.id, "message": "Duplicate photo skipped"}
|
||||
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ext = file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "png"
|
||||
@@ -135,7 +214,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
raise HTTPException(400, "No image provided")
|
||||
|
||||
content = await read_upload_limited(file, GALLERY_UPLOAD_MAX_BYTES, "Gallery replacement")
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
img_path = img_dir / _sanitize_gallery_filename(img.filename)
|
||||
img_path.write_bytes(content)
|
||||
@@ -211,7 +290,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not user or img.owner != user:
|
||||
raise HTTPException(403, "Not your image")
|
||||
|
||||
img_path = Path("data/generated_images") / img.filename
|
||||
img_path = _gallery_image_path(img.filename)
|
||||
if not img_path.exists():
|
||||
raise HTTPException(404, "Image file not found")
|
||||
|
||||
@@ -248,7 +327,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
"""AI upscale using img2img with the diffusion server."""
|
||||
import base64, httpx
|
||||
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
form = await request.form()
|
||||
file = form.get("image")
|
||||
if not file: raise HTTPException(400, "No image")
|
||||
@@ -260,7 +339,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
# Find image endpoint
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -291,7 +370,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
"""Style transfer using img2img with the diffusion server."""
|
||||
import base64, httpx
|
||||
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
form = await request.form()
|
||||
file = form.get("image")
|
||||
prompt = form.get("prompt", "")
|
||||
@@ -303,7 +382,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.model_type == "image", ModelEndpoint.is_enabled == True).first()
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -505,18 +584,24 @@ def setup_gallery_routes() -> APIRouter:
|
||||
albums = q.order_by(GalleryAlbum.created_at.desc()).all()
|
||||
result = []
|
||||
for a in albums:
|
||||
count = db.query(GalleryImage).filter(
|
||||
_count_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
).count()
|
||||
)
|
||||
if user:
|
||||
_count_q = _count_q.filter(GalleryImage.owner == user)
|
||||
count = _count_q.count()
|
||||
cover_url = None
|
||||
if a.cover_id:
|
||||
cover = db.query(GalleryImage).filter(GalleryImage.id == a.cover_id).first()
|
||||
if cover:
|
||||
cover_url = f"/api/generated-image/{cover.filename}"
|
||||
elif count > 0:
|
||||
first = db.query(GalleryImage).filter(
|
||||
_cover_q = db.query(GalleryImage).filter(
|
||||
GalleryImage.album_id == a.id, GalleryImage.is_active == True
|
||||
).order_by(GalleryImage.created_at.desc()).first()
|
||||
)
|
||||
if user:
|
||||
_cover_q = _cover_q.filter(GalleryImage.owner == user)
|
||||
first = _cover_q.order_by(GalleryImage.created_at.desc()).first()
|
||||
if first:
|
||||
cover_url = f"/api/generated-image/{first.filename}"
|
||||
result.append({
|
||||
@@ -649,7 +734,14 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if req.favorite is not None:
|
||||
img.favorite = req.favorite
|
||||
if req.album_id is not None:
|
||||
img.album_id = req.album_id if req.album_id else None
|
||||
if req.album_id:
|
||||
# Validate the target album belongs to the caller before
|
||||
# moving the image into it — mirrors add_to_album, so you
|
||||
# cannot file your image into another user's album.
|
||||
_get_or_404_album(db, req.album_id, user)
|
||||
img.album_id = req.album_id
|
||||
else:
|
||||
img.album_id = None
|
||||
db.commit()
|
||||
db.refresh(img)
|
||||
return _image_to_dict(img)
|
||||
@@ -692,11 +784,11 @@ def setup_gallery_routes() -> APIRouter:
|
||||
used = set()
|
||||
with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for img in imgs:
|
||||
src = os.path.join("data", "generated_images", img.filename)
|
||||
if not os.path.exists(src):
|
||||
src = _gallery_image_path(img.filename)
|
||||
if not src.exists():
|
||||
continue
|
||||
ext = os.path.splitext(img.filename)[1] or ".png"
|
||||
base = (img.prompt or "").strip() or os.path.splitext(img.filename)[0]
|
||||
ext = src.suffix or ".png"
|
||||
base = (img.prompt or "").strip() or src.stem
|
||||
base = re.sub(r"[^\w\-. ]+", "", base)[:60].strip() or img.id
|
||||
name = f"{base}{ext}"
|
||||
i = 1
|
||||
@@ -818,9 +910,9 @@ def setup_gallery_routes() -> APIRouter:
|
||||
|
||||
img_filename = img.filename
|
||||
# Remove the file from disk
|
||||
img_path = os.path.join("data", "generated_images", img_filename)
|
||||
if os.path.exists(img_path):
|
||||
os.remove(img_path)
|
||||
img_path = _gallery_image_path(img_filename)
|
||||
if img_path.exists():
|
||||
img_path.unlink()
|
||||
|
||||
# Soft-delete the record
|
||||
img.is_active = False
|
||||
@@ -923,7 +1015,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
the request for /v1/images/edits (multipart, inverted mask). Otherwise
|
||||
proxy through to a self-hosted diffusion server's /v1/images/inpaint."""
|
||||
import httpx
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
body = await request.json()
|
||||
# Use endpoint from request body (editor dropdown) or fall back to DB lookup
|
||||
base = (body.pop("_endpoint", "") or "").rstrip("/")
|
||||
@@ -942,14 +1034,11 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not base:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
eps = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.model_type == "image",
|
||||
).all()
|
||||
if not eps:
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
if not ep:
|
||||
raise HTTPException(400, "No image generation endpoint configured. Serve a diffusion model via Cookbook first.")
|
||||
base = eps[0].base_url.rstrip("/")
|
||||
api_key = eps[0].api_key
|
||||
base = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -966,10 +1055,12 @@ def setup_gallery_routes() -> APIRouter:
|
||||
_target = _norm_url(base)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for ep in db.query(ModelEndpoint).all():
|
||||
if _norm_url(ep.base_url) == _target:
|
||||
api_key = ep.api_key
|
||||
break
|
||||
ep = _visible_image_endpoint_for_base(db, _target, user)
|
||||
if ep:
|
||||
base = (ep.base_url or base).rstrip("/")
|
||||
api_key = ep.api_key
|
||||
elif user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered image endpoint")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1121,7 +1212,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
you get edge blending + lighting unification while keeping the
|
||||
composition recognisable."""
|
||||
import httpx, base64 as _b64
|
||||
require_privilege(request, "can_generate_images")
|
||||
user = require_privilege(request, "can_generate_images")
|
||||
body = await request.json()
|
||||
|
||||
image_b64 = body.get("image")
|
||||
@@ -1148,23 +1239,22 @@ def setup_gallery_routes() -> APIRouter:
|
||||
if not base:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
eps = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
ModelEndpoint.model_type == "image",
|
||||
).all()
|
||||
if not eps:
|
||||
ep = _first_visible_image_endpoint(db, user)
|
||||
if not ep:
|
||||
raise HTTPException(400, "No image generation endpoint configured.")
|
||||
base = eps[0].base_url.rstrip("/")
|
||||
api_key = eps[0].api_key
|
||||
base = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for ep in db.query(ModelEndpoint).all():
|
||||
if ep.base_url.rstrip("/").removesuffix("/v1").rstrip("/") == base.rstrip("/").removesuffix("/v1").rstrip("/"):
|
||||
api_key = ep.api_key
|
||||
break
|
||||
ep = _visible_image_endpoint_for_base(db, base, user)
|
||||
if ep:
|
||||
base = (ep.base_url or base).rstrip("/")
|
||||
api_key = ep.api_key
|
||||
elif user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered image endpoint")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1636,9 +1726,10 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
album = _get_or_404_album(db, album_id, user)
|
||||
db.query(GalleryImage).filter(GalleryImage.album_id == album_id).update(
|
||||
{"album_id": None}, synchronize_session=False
|
||||
)
|
||||
q = db.query(GalleryImage).filter(GalleryImage.album_id == album_id)
|
||||
if user is not None:
|
||||
q = q.filter(GalleryImage.owner == user)
|
||||
q.update({"album_id": None}, synchronize_session=False)
|
||||
db.delete(album)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
@@ -1709,7 +1800,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
try:
|
||||
img = _get_or_404_image(db, image_id, user)
|
||||
|
||||
img_path = Path("data/generated_images") / img.filename
|
||||
img_path = _gallery_image_path(img.filename)
|
||||
if not img_path.exists():
|
||||
raise HTTPException(404, "Image file not found")
|
||||
|
||||
@@ -1727,7 +1818,7 @@ def setup_gallery_routes() -> APIRouter:
|
||||
return {"error": "Vision is disabled — enable it in Settings → Vision"}
|
||||
configured = vl_settings.get("vision_model", "")
|
||||
try:
|
||||
chat_url, model_name, headers = _resolve_vl_model(configured)
|
||||
chat_url, model_name, headers = _resolve_vl_model(configured, owner=user)
|
||||
except ValueError:
|
||||
return {"error": "No vision model configured — set one in Settings → Vision"}
|
||||
if not chat_url:
|
||||
@@ -1808,4 +1899,3 @@ def setup_gallery_routes() -> APIRouter:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -490,7 +490,13 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
# Copy messages up to keep_count
|
||||
msgs_to_copy = source.history[:keep_count]
|
||||
for msg in msgs_to_copy:
|
||||
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
|
||||
# Copy the metadata dict. Sharing it would let the fork's
|
||||
# persistence (add_message -> _persist_message stamps
|
||||
# _db_id/timestamp onto the dict) mutate the SOURCE session's
|
||||
# in-memory messages, corrupting their _db_id and breaking
|
||||
# edit/delete-by-id on the original conversation.
|
||||
meta = dict(msg.metadata) if isinstance(msg.metadata, dict) else None
|
||||
new_session.add_message(ChatMessage(msg.role, msg.content, meta))
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", getattr(source, 'owner', None))
|
||||
@@ -522,6 +528,8 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
async def compact_session(request: Request, session_id: str):
|
||||
"""Manually trigger context compaction for a session."""
|
||||
_verify_session_owner(request, session_id)
|
||||
from src.auth_helpers import effective_user
|
||||
owner = effective_user(request)
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
@@ -555,7 +563,7 @@ def setup_history_routes(session_manager) -> APIRouter:
|
||||
)
|
||||
|
||||
# Use utility model if available
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility", owner=owner or None)
|
||||
compact_url = util_url or session.endpoint_url
|
||||
compact_model = util_model or session.model
|
||||
compact_headers = util_headers if util_url else session.headers
|
||||
|
||||
@@ -13,7 +13,7 @@ import httpx
|
||||
|
||||
from core.database import McpServer, SessionLocal
|
||||
from core.middleware import require_admin
|
||||
from src.constants import DATA_DIR
|
||||
from src.constants import DATA_DIR, MCP_OAUTH_DIR
|
||||
from src.mcp_manager import McpManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,7 +23,7 @@ router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
||||
|
||||
def _mcp_oauth_base_dir() -> Path:
|
||||
"""Directory that may contain OAuth files managed by Odysseus."""
|
||||
return (Path(DATA_DIR) / "mcp_oauth").resolve(strict=False)
|
||||
return Path(MCP_OAUTH_DIR).resolve(strict=False)
|
||||
|
||||
|
||||
def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
|
||||
|
||||
@@ -29,11 +29,10 @@ from src.llm_core import llm_call_async
|
||||
from services.memory.memory_extractor import audit_memories
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, MEMORY_IMPORT_MAX_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_IMPORT_MAX_BYTES = int(os.getenv("ODYSSEUS_MEMORY_IMPORT_MAX_BYTES", str(10 * 1024 * 1024)))
|
||||
|
||||
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
||||
"""Set up memory-related routes."""
|
||||
@@ -371,7 +370,7 @@ def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionM
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
text = _process_pdf(tmp_path)
|
||||
text = _process_pdf(tmp_path, owner=_owner(request))
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
else:
|
||||
|
||||
+188
-26
@@ -5,6 +5,7 @@ import re
|
||||
import uuid
|
||||
import json
|
||||
import socket
|
||||
import hashlib
|
||||
import time as _time
|
||||
import logging
|
||||
import httpx
|
||||
@@ -282,8 +283,11 @@ _HOST_TO_CURATED = (
|
||||
("fireworks.ai", "fireworks"),
|
||||
("googleapis.com", "google"),
|
||||
("x.ai", "xai"),
|
||||
|
||||
("openrouter.ai", "openrouter"),
|
||||
("ollama.com", "ollama"),
|
||||
("opencode.ai/zen/go", "opencode-go"),
|
||||
("opencode.ai/zen", "opencode-zen"),
|
||||
)
|
||||
|
||||
|
||||
@@ -490,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
||||
def _is_chat_model(model_id: str) -> bool:
|
||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||
mid = model_id.lower()
|
||||
if mid in {"gpt-5.1-codex"}:
|
||||
return True
|
||||
for prefix in _NON_CHAT_PREFIXES:
|
||||
if mid.startswith(prefix):
|
||||
return False
|
||||
@@ -502,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||
|
||||
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||
endpoint backed by that auth row is removed, the stored credentials should
|
||||
be cleared instead of lingering. Returns True if a row was deleted.
|
||||
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||
reference count so it does not keep its own auth alive.
|
||||
"""
|
||||
if not auth_id:
|
||||
return False
|
||||
from core.database import ProviderAuthSession
|
||||
still_referenced = db.query(ModelEndpoint.id).filter(
|
||||
ModelEndpoint.provider_auth_id == auth_id,
|
||||
ModelEndpoint.id != exclude_ep_id,
|
||||
).first()
|
||||
if still_referenced is not None:
|
||||
return False
|
||||
auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
|
||||
if auth_row is None:
|
||||
return False
|
||||
db.delete(auth_row)
|
||||
return True
|
||||
|
||||
|
||||
def _is_discovery_only_provider(provider: str) -> bool:
|
||||
"""Provider that only supports model discovery, not live probing.
|
||||
|
||||
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||
chat-completions or general health endpoint, so completion probes and
|
||||
reachability pings are skipped — status is derived from cached models.
|
||||
"""
|
||||
return provider == "chatgpt-subscription"
|
||||
|
||||
|
||||
def _resolve_probe_key(ep) -> Optional[str]:
|
||||
"""API key/bearer to probe an endpoint with.
|
||||
|
||||
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||
"""
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||
return key
|
||||
except Exception as e:
|
||||
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||
return None
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||
provider = _detect_provider(base)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# Responses/Codex API, not chat-completions: a completion probe would
|
||||
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Say OK"},
|
||||
@@ -618,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
if _detect_provider(base) == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
if api_key:
|
||||
return fetch_available_models(api_key, timeout=timeout)
|
||||
return []
|
||||
if _detect_provider(base) == "anthropic":
|
||||
# Try Anthropic's /v1/models endpoint first
|
||||
url = build_models_url(base)
|
||||
@@ -644,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||
return list(ANTHROPIC_MODELS)
|
||||
url = build_models_url(base)
|
||||
if not url:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||
return list(fallback or [])
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
@@ -697,7 +770,6 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
return list(fallback)
|
||||
return []
|
||||
|
||||
|
||||
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
@@ -713,6 +785,10 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
# APFEL-specific detection
|
||||
host = (parsed_base.hostname or "").lower()
|
||||
looks_like_apfel = "apfel" in host or parsed_base.port == 11435
|
||||
|
||||
def _result_from_response(r) -> Dict[str, Any]:
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
@@ -734,7 +810,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
last_error: Optional[str] = None
|
||||
|
||||
try:
|
||||
if looks_like_ollama:
|
||||
# APFEL does not behave like Ollama; use its health endpoint.
|
||||
if looks_like_apfel:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
root = root[: -len(suffix)].rstrip("/")
|
||||
break
|
||||
try:
|
||||
r = httpx.get(root + "/health", timeout=timeout, verify=llm_verify())
|
||||
result = _result_from_response(r)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
last_error = result.get("error")
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
elif looks_like_ollama:
|
||||
root = base
|
||||
for suffix in ("/v1", "/api"):
|
||||
if root.endswith(suffix):
|
||||
@@ -754,14 +846,31 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
|
||||
try:
|
||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
return _result_from_response(r)
|
||||
result = _result_from_response(r)
|
||||
# If the bare base URL returns a non-auth 4xx (e.g. 404), try /models
|
||||
# as a fallback. OpenAI-compatible servers like llama-swap return 404
|
||||
# on the base /v1 prefix but 200 on /v1/models. Auth failures (401/403)
|
||||
# are definitive — probing /models would just repeat the same rejection.
|
||||
if (
|
||||
not result["reachable"]
|
||||
and result.get("status_code") is not None
|
||||
and 400 <= result["status_code"] < 500
|
||||
and result["status_code"] not in (401, 403)
|
||||
):
|
||||
models_url = build_models_url(base)
|
||||
try:
|
||||
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
result2 = _result_from_response(r2)
|
||||
if result2["reachable"]:
|
||||
return result2
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
|
||||
def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) -> str:
|
||||
"""Return a provider-aware error message for failed endpoint probes."""
|
||||
ping = ping or {}
|
||||
@@ -850,6 +959,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
|
||||
return [m for m in merged if m not in hidden]
|
||||
|
||||
|
||||
def _api_key_fingerprint(api_key: Optional[str]) -> str:
|
||||
"""Stable, non-secret label for distinguishing same-URL credentials."""
|
||||
key = (api_key or "").strip()
|
||||
if not key:
|
||||
return ""
|
||||
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
|
||||
def setup_model_routes(model_discovery):
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
@@ -951,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||
ep,
|
||||
owner=getattr(ep, "owner", None),
|
||||
)
|
||||
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||
except Exception as e:
|
||||
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||
continue
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
@@ -1104,8 +1232,9 @@ def setup_model_routes(model_discovery):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error('Auth gate error in GET /api/models, failing closed: %s', e)
|
||||
raise HTTPException(status_code=500, detail='Internal error')
|
||||
# Admins see every endpoint (they manage the global pool); regular
|
||||
# users get the owner-scoped view.
|
||||
_is_admin = False
|
||||
@@ -1219,12 +1348,20 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
try:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# No general health endpoint — an unauthenticated GET just
|
||||
# 401s. Report status from cached models instead of pinging.
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = None
|
||||
entry["model_count"] = cached_count
|
||||
else:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
@@ -1257,7 +1394,7 @@ def setup_model_routes(model_discovery):
|
||||
if ep_id and ep_id not in endpoints_cache:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep:
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
ep_data = endpoints_cache.get(ep_id)
|
||||
if not ep_data:
|
||||
# Try to find by base_url from the model's endpoint field
|
||||
@@ -1296,7 +1433,7 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"api_key": ep.api_key,
|
||||
"api_key": _resolve_probe_key(ep),
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1385,18 +1522,21 @@ def setup_model_routes(model_discovery):
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
# admin-pinned IDs that a probe would never surface.
|
||||
status = "online" if (all_models or pinned) else "offline"
|
||||
base = _normalize_base(r.base_url)
|
||||
ping = None
|
||||
if not all_models and not pinned and r.is_enabled:
|
||||
# Discovery-only providers have no health endpoint — an
|
||||
# unauthenticated ping just 401s, so don't bother.
|
||||
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
"base_url": r.base_url,
|
||||
"has_key": bool(r.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
|
||||
"is_enabled": r.is_enabled,
|
||||
"models": visible,
|
||||
"pinned_models": pinned,
|
||||
@@ -1463,21 +1603,34 @@ def setup_model_routes(model_discovery):
|
||||
)
|
||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||
|
||||
# Dedupe: if an endpoint with the same base_url already exists and
|
||||
# is reachable by the caller (shared or owned by them), return it
|
||||
# instead of creating a duplicate row. Fixes "Scan for Servers"
|
||||
# re-adding manually-added endpoints under their host:port name.
|
||||
# Dedupe: if an endpoint with the same base_url and compatible
|
||||
# credentials already exists and is reachable by the caller (shared or
|
||||
# owned by them), return it instead of creating a duplicate row. Keep
|
||||
# same-url/different-key rows distinct so users can group the same
|
||||
# provider URL under multiple credentials.
|
||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||
_caller = _gcu_dedup(request) or None
|
||||
_incoming_api_key = api_key.strip()
|
||||
_db_dedup = SessionLocal()
|
||||
try:
|
||||
existing = (
|
||||
_same_url_rows = (
|
||||
_db_dedup.query(ModelEndpoint)
|
||||
.filter(ModelEndpoint.base_url == base_url)
|
||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
||||
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
||||
.first()
|
||||
.all()
|
||||
)
|
||||
existing = None
|
||||
_empty_key_existing = None
|
||||
for _candidate in _same_url_rows:
|
||||
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
|
||||
if _candidate_key == _incoming_api_key:
|
||||
existing = _candidate
|
||||
break
|
||||
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
|
||||
_empty_key_existing = _candidate
|
||||
if existing is None and _incoming_api_key and _empty_key_existing is not None:
|
||||
existing = _empty_key_existing
|
||||
if existing:
|
||||
changed = False
|
||||
# Persist any incoming pinned IDs onto the existing row. An
|
||||
@@ -1526,6 +1679,8 @@ def setup_model_routes(model_discovery):
|
||||
"id": existing.id,
|
||||
"name": existing.name,
|
||||
"base_url": existing.base_url,
|
||||
"has_key": bool(existing.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
|
||||
"models": _visible_models(
|
||||
existing_models,
|
||||
getattr(existing, "hidden_models", None),
|
||||
@@ -1599,6 +1754,8 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep_id,
|
||||
"name": name.strip(),
|
||||
"base_url": base_url,
|
||||
"has_key": bool(api_key.strip()),
|
||||
"api_key_fingerprint": _api_key_fingerprint(api_key),
|
||||
"models": _merge_model_ids(model_ids, _pinned),
|
||||
"pinned_models": _pinned,
|
||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||
@@ -1648,7 +1805,7 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1712,7 +1869,7 @@ def setup_model_routes(model_discovery):
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
@@ -1948,6 +2105,8 @@ def setup_model_routes(model_discovery):
|
||||
"name": ep.name,
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"has_key": bool(ep.api_key),
|
||||
"api_key_fingerprint": _api_key_fingerprint(ep.api_key),
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||
@@ -2049,7 +2208,9 @@ def setup_model_routes(model_discovery):
|
||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||
auth_id = getattr(ep, "provider_auth_id", None)
|
||||
db.delete(ep)
|
||||
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
@@ -2059,6 +2220,7 @@ def setup_model_routes(model_discovery):
|
||||
"cleared_user_preferences": cleared_user_preferences,
|
||||
"cleared_sessions": cleared_sessions,
|
||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||
"cleared_provider_auth": cleared_provider_auth,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
+161
-16
@@ -11,6 +11,7 @@ from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, Note
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import DATA_DIR
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -95,6 +96,32 @@ def _note_to_dict(note: Note) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _reminder_text_from_note(note: Note) -> tuple[str, str]:
|
||||
"""Return the reminder title/body from a stored note row."""
|
||||
title = (note.title or "Note reminder").strip() or "Note reminder"
|
||||
if note.items:
|
||||
try:
|
||||
items = json.loads(note.items)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
items = None
|
||||
if isinstance(items, list):
|
||||
pending: list[str] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("done") or item.get("checked"):
|
||||
continue
|
||||
text = str(item.get("text") or "").strip()
|
||||
if text:
|
||||
pending.append(text)
|
||||
if pending:
|
||||
shown = "\n".join(f"- {text}" for text in pending[:8])
|
||||
extra = f"\n...and {len(pending) - 8} more" if len(pending) > 8 else ""
|
||||
return title, f"Pending ({len(pending)}):\n{shown}{extra}"
|
||||
return title, f"{len(items)} item{'s' if len(items) != 1 else ''}"
|
||||
return title, (note.content or "").strip()[:400]
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reminder dispatch — module-level so background tasks (built-in actions)
|
||||
@@ -114,8 +141,9 @@ async def dispatch_reminder(
|
||||
note_id: str,
|
||||
owner: str = "",
|
||||
queue_browser: bool = True,
|
||||
settings_override: dict | None = None,
|
||||
) -> dict:
|
||||
"""Fire a reminder via the configured channel (browser/email/ntfy).
|
||||
"""Fire a reminder via the configured channel (browser/email/ntfy/webhook).
|
||||
|
||||
Args:
|
||||
title: short headline shown to the user
|
||||
@@ -129,7 +157,7 @@ async def dispatch_reminder(
|
||||
nothing is "sent" synchronously for it — the channel just routes there.
|
||||
"""
|
||||
from src.settings import load_settings
|
||||
settings = load_settings()
|
||||
settings = {**load_settings(), **(settings_override or {})}
|
||||
channel = settings.get("reminder_channel", "browser")
|
||||
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
||||
title = (title or "").strip()
|
||||
@@ -143,7 +171,7 @@ async def dispatch_reminder(
|
||||
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
||||
from pathlib import Path as _P
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
cache_path = _P(f"data/note_pings_{_slug}.json")
|
||||
cache_path = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||
if cache_path.exists():
|
||||
cache = _json.loads(cache_path.read_text(encoding="utf-8"))
|
||||
last = cache.get(cache_key)
|
||||
@@ -160,13 +188,14 @@ async def dispatch_reminder(
|
||||
# Treat those as browser-only dedupe so email reminders can be
|
||||
# retried by the backend scanner after a failed frontend path.
|
||||
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
||||
if should_skip and channel in ("email", "ntfy"):
|
||||
if should_skip and channel in ("email", "ntfy", "webhook"):
|
||||
should_skip = last_channel == channel
|
||||
if should_skip:
|
||||
return {
|
||||
"synthesis": None,
|
||||
"email_sent": False,
|
||||
"ntfy_sent": False,
|
||||
"webhook_sent": False,
|
||||
"browser_sent": True,
|
||||
"skipped": True,
|
||||
}
|
||||
@@ -179,9 +208,9 @@ async def dispatch_reminder(
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner or None)
|
||||
if url and model:
|
||||
raw = await llm_call_async(
|
||||
url=url, model=model,
|
||||
@@ -360,6 +389,76 @@ async def dispatch_reminder(
|
||||
email_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder email send failed: {e}")
|
||||
|
||||
webhook_sent = False
|
||||
webhook_error = ""
|
||||
if channel == "webhook":
|
||||
try:
|
||||
import httpx
|
||||
import json as _wjson
|
||||
from src.integrations import load_integrations
|
||||
# Built-in payload defaults for known presets so users don't have
|
||||
# to configure a template just to use a standard service.
|
||||
_PRESET_TEMPLATE_DEFAULTS = {
|
||||
"discord_webhook": '{"embeds": [{"title": "{{title}}", "description": "{{message}}", "color": 5793266}]}',
|
||||
}
|
||||
intg_id = settings.get("reminder_webhook_integration_id", "").strip()
|
||||
template = settings.get("reminder_webhook_payload_template", "").strip()
|
||||
if not intg_id:
|
||||
webhook_error = "No webhook integration selected"
|
||||
else:
|
||||
intg = next(
|
||||
(i for i in load_integrations()
|
||||
if i.get("id") == intg_id and i.get("base_url")),
|
||||
None,
|
||||
)
|
||||
if not intg:
|
||||
webhook_error = f"Integration {intg_id!r} not found or missing base URL"
|
||||
else:
|
||||
# Fall back to a built-in default for known presets so
|
||||
# users don't have to configure a template for standard
|
||||
# services like Discord.
|
||||
if not template:
|
||||
template = _PRESET_TEMPLATE_DEFAULTS.get(intg.get("preset", ""), "")
|
||||
if not template:
|
||||
webhook_error = "No payload template configured"
|
||||
else:
|
||||
# Render template: JSON-escape the values so the result
|
||||
# is always valid JSON regardless of special characters.
|
||||
# dumps() returns `"value"` — strip outer quotes.
|
||||
msg = (synthesis or note_body or title or "Reminder")[:4000]
|
||||
_t = _wjson.dumps(title or "Reminder")[1:-1]
|
||||
_m = _wjson.dumps(msg)[1:-1]
|
||||
rendered = template.replace("{{title}}", _t).replace("{{message}}", _m)
|
||||
hdrs = {"Content-Type": "application/json"}
|
||||
api_key = intg.get("api_key", "")
|
||||
auth_type = (intg.get("auth_type") or "none").lower()
|
||||
if api_key:
|
||||
if auth_type == "bearer":
|
||||
hdrs["Authorization"] = f"Bearer {api_key}"
|
||||
elif auth_type == "header":
|
||||
hdrs[intg.get("auth_header") or "Authorization"] = api_key
|
||||
url = intg["base_url"].rstrip("/")
|
||||
# SSRF guard — matches the pattern used by webhook_routes,
|
||||
# CalDAV, search, and embeddings. Blocks link-local / metadata
|
||||
# addresses (169.254.x.x) by default; set
|
||||
# REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS=true to also block
|
||||
# RFC-1918 ranges for locked-down deployments.
|
||||
import os as _os
|
||||
from src.url_safety import check_outbound_url as _chk
|
||||
_block = _os.getenv("REMINDER_WEBHOOK_BLOCK_PRIVATE_IPS", "false").lower() == "true"
|
||||
_ok, _reason = _chk(url, block_private=_block)
|
||||
if not _ok:
|
||||
webhook_error = f"Webhook URL rejected: {_reason}"
|
||||
else:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(url, content=rendered.encode(), headers=hdrs)
|
||||
webhook_sent = resp.is_success
|
||||
if not webhook_sent:
|
||||
webhook_error = f"Webhook returned HTTP {resp.status_code}"
|
||||
except Exception as e:
|
||||
webhook_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder webhook send failed: {e}")
|
||||
|
||||
ntfy_sent = False
|
||||
ntfy_error = ""
|
||||
if channel == "ntfy":
|
||||
@@ -415,7 +514,7 @@ async def dispatch_reminder(
|
||||
# second send for the same note within 25 min. Without this, a note
|
||||
# whose due_date fires while the user has the app open got TWO emails
|
||||
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
||||
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
|
||||
if (email_sent or ntfy_sent or webhook_sent or browser_sent or local_browser_sent) and note_id:
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime as _dt, timezone as _tz
|
||||
@@ -425,13 +524,13 @@ async def dispatch_reminder(
|
||||
_STATE = cache_path
|
||||
if _STATE is None:
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
_STATE = _P(f"data/note_pings_{_slug}.json")
|
||||
_STATE = _P(DATA_DIR) / f"note_pings_{_slug}.json"
|
||||
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
_cache = cache or (_json.loads(_STATE.read_text(encoding="utf-8")) if _STATE.exists() else {})
|
||||
except Exception:
|
||||
_cache = {}
|
||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
|
||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "webhook" if webhook_sent else "browser"
|
||||
_cache[cache_key or str(note_id)] = {
|
||||
"at": _dt.now(_tz.utc).isoformat(),
|
||||
"channel": sent_channel,
|
||||
@@ -441,11 +540,14 @@ async def dispatch_reminder(
|
||||
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
||||
|
||||
return {
|
||||
"channel": channel,
|
||||
"synthesis": synthesis,
|
||||
"email_sent": email_sent,
|
||||
"email_error": email_error,
|
||||
"ntfy_sent": ntfy_sent,
|
||||
"ntfy_error": ntfy_error,
|
||||
"webhook_sent": webhook_sent,
|
||||
"webhook_error": webhook_error,
|
||||
"browser_sent": browser_sent or local_browser_sent,
|
||||
}
|
||||
|
||||
@@ -467,6 +569,23 @@ def setup_note_routes(task_scheduler=None):
|
||||
def _owner(request: Request) -> Optional[str]:
|
||||
return get_current_user(request)
|
||||
|
||||
def _is_admin_or_single_user(request: Request, user: str | None) -> bool:
|
||||
if user == "internal-tool":
|
||||
return True
|
||||
if not user:
|
||||
# require_user() already admitted this request, which only happens
|
||||
# for auth-disabled, loopback-bypass, or unconfigured single-user
|
||||
# modes. There is no separate non-admin account boundary there.
|
||||
return True
|
||||
try:
|
||||
from core.auth import AuthManager
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None) or AuthManager()
|
||||
if not getattr(auth_mgr, "is_configured", True):
|
||||
return True
|
||||
return bool(auth_mgr.is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# --- LIST ---
|
||||
@router.get("")
|
||||
def list_notes(
|
||||
@@ -684,20 +803,46 @@ def setup_note_routes(task_scheduler=None):
|
||||
"""
|
||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||
from src.auth_helpers import require_user as _ru
|
||||
_ru(request)
|
||||
user = _ru(request)
|
||||
body = await request.json()
|
||||
note_id = body.get("note_id")
|
||||
title = (body.get("title") or "").strip()
|
||||
note_body = (body.get("body") or "").strip()
|
||||
note_id = str(body.get("note_id") or "").strip()
|
||||
if not note_id:
|
||||
raise HTTPException(400, "note_id required")
|
||||
|
||||
# Delegate to the module-level helper so background tasks can reuse
|
||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
||||
caller = _owner(request)
|
||||
is_test = note_id.startswith("test-")
|
||||
is_admin = _is_admin_or_single_user(request, user or caller)
|
||||
_override: dict = {}
|
||||
if is_test:
|
||||
if not is_admin:
|
||||
raise HTTPException(403, "Admin only")
|
||||
title = (body.get("title") or "Test Reminder").strip() or "Test Reminder"
|
||||
note_body = (body.get("body") or "").strip()
|
||||
# Optional overrides let the admin settings test button pass the
|
||||
# current UI values directly so it never races a pending save.
|
||||
if body.get("channel"):
|
||||
_override["reminder_channel"] = body["channel"]
|
||||
if body.get("webhook_integration_id"):
|
||||
_override["reminder_webhook_integration_id"] = body["webhook_integration_id"]
|
||||
if body.get("webhook_payload_template"):
|
||||
_override["reminder_webhook_payload_template"] = body["webhook_payload_template"]
|
||||
else:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
if caller is not None and note.owner != caller:
|
||||
raise HTTPException(404, "Note not found")
|
||||
title, note_body = _reminder_text_from_note(note)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return await dispatch_reminder(
|
||||
title=title, note_body=note_body, note_id=note_id,
|
||||
owner=_owner(request) or "",
|
||||
owner=caller or "",
|
||||
queue_browser=False,
|
||||
settings_override=_override or None,
|
||||
)
|
||||
|
||||
# --- REORDER NOTES ---
|
||||
|
||||
+13
-12
@@ -6,16 +6,14 @@ import uuid
|
||||
from typing import List, Tuple
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||
from src.request_models import DirectoryRequest
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
|
||||
from src.rag_singleton import get_rag_manager
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from src.auth_helpers import require_privilege, require_user
|
||||
from core.middleware import require_admin
|
||||
from src.upload_handler import secure_filename
|
||||
from src.upload_limits import PERSONAL_UPLOAD_MAX_BYTES
|
||||
|
||||
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
|
||||
MAX_PERSONAL_UPLOAD_BYTES = int(
|
||||
os.getenv("ODYSSEUS_PERSONAL_UPLOAD_MAX_BYTES", str(25 * 1024 * 1024))
|
||||
)
|
||||
UPLOADS_DIR = PERSONAL_UPLOADS_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -194,7 +192,7 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
@router.post("/upload")
|
||||
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
||||
"""Upload files directly into RAG. Supports text and PDF."""
|
||||
user = get_current_user(request)
|
||||
user = require_privilege(request, "can_use_documents")
|
||||
rag = _rag()
|
||||
if not rag:
|
||||
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
||||
@@ -208,8 +206,8 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
for upload in files:
|
||||
try:
|
||||
file_path, stored_name, safe_name = _unique_personal_upload_path(upload_dir, upload.filename)
|
||||
content_bytes = await upload.read(MAX_PERSONAL_UPLOAD_BYTES + 1)
|
||||
if len(content_bytes) > MAX_PERSONAL_UPLOAD_BYTES:
|
||||
content_bytes = await upload.read(PERSONAL_UPLOAD_MAX_BYTES + 1)
|
||||
if len(content_bytes) > PERSONAL_UPLOAD_MAX_BYTES:
|
||||
logger.warning(f"Rejected oversized personal upload: {upload.filename!r}")
|
||||
total_failed += 1
|
||||
continue
|
||||
@@ -286,9 +284,12 @@ def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
except ValueError:
|
||||
# commonpath raises on mixed drives / non-comparable paths
|
||||
in_uploads = False
|
||||
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
|
||||
os.remove(abs_target)
|
||||
deleted_from_disk = True
|
||||
if in_uploads and abs_target != base_abs:
|
||||
try:
|
||||
os.remove(abs_target)
|
||||
deleted_from_disk = True
|
||||
except FileNotFoundError:
|
||||
pass # already gone — race with another request or cleanup
|
||||
|
||||
# Exclude the file from the listing (persists across restarts)
|
||||
personal_docs_manager.exclude_file(filepath)
|
||||
|
||||
@@ -4,8 +4,9 @@ import os
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import USER_PREFS_FILE
|
||||
|
||||
PREFS_FILE = os.path.join("data", "user_prefs.json")
|
||||
PREFS_FILE = USER_PREFS_FILE
|
||||
|
||||
|
||||
def _load():
|
||||
|
||||
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from src.request_models import PresetUpdateRequest
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import effective_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -100,7 +101,8 @@ def setup_preset_routes(preset_manager) -> APIRouter:
|
||||
|
||||
try:
|
||||
model_spec = data.get("model") or ""
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
user = effective_user(request)
|
||||
url, model, headers = _resolve_model(model_spec, owner=user)
|
||||
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
||||
return {"success": True, "prompt": result.strip()}
|
||||
except Exception as e:
|
||||
|
||||
+61
-46
@@ -14,6 +14,7 @@ from fastapi.responses import HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.auth_helpers import _auth_disabled, get_current_user
|
||||
from src.constants import DEEP_RESEARCH_DIR
|
||||
|
||||
_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9-]{1,128}$")
|
||||
|
||||
@@ -37,13 +38,15 @@ def _first_chat_model(models) -> str:
|
||||
return (models[0] if models else "")
|
||||
|
||||
|
||||
def _resolve_research_endpoint(sess) -> tuple:
|
||||
def _resolve_research_endpoint(sess, owner: Optional[str] = None) -> tuple:
|
||||
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
||||
owner = owner or getattr(sess, "owner", None) or None
|
||||
url, model, headers = resolve_endpoint(
|
||||
"research",
|
||||
fallback_url=sess.endpoint_url,
|
||||
fallback_model=sess.model,
|
||||
fallback_headers=sess.headers,
|
||||
owner=owner,
|
||||
)
|
||||
return url, model, headers
|
||||
|
||||
@@ -72,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
|
||||
"""Resolve a ModelEndpoint row into (chat_url, model, headers).
|
||||
|
||||
Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
|
||||
panel-selected research endpoints. ChatGPT Subscription endpoints keep
|
||||
OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.
|
||||
"""
|
||||
from src.endpoint_resolver import (
|
||||
build_chat_url,
|
||||
build_headers,
|
||||
resolve_endpoint_runtime as resolve_model_endpoint_runtime,
|
||||
)
|
||||
|
||||
try:
|
||||
base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint credentials for research: %s", e)
|
||||
return None
|
||||
|
||||
ep_model = (model or "").strip()
|
||||
if not ep_model:
|
||||
try:
|
||||
models = json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
if not ep_model:
|
||||
return None
|
||||
return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||
|
||||
|
||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
router = APIRouter(tags=["research"])
|
||||
|
||||
@@ -98,7 +133,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
if entry is not None:
|
||||
return entry.get("owner", "") == user
|
||||
# Task no longer in memory — check the persisted JSON.
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
@@ -162,7 +197,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
def _assert_owns_research(session_id: str, user: str) -> None:
|
||||
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
||||
Use BEFORE returning any data or mutating the file."""
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -225,7 +260,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
):
|
||||
user = _require_user(request)
|
||||
"""List all completed research for the Library panel."""
|
||||
data_dir = Path("data/deep_research")
|
||||
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||
items = []
|
||||
for p in data_dir.glob("*.json"):
|
||||
try:
|
||||
@@ -275,7 +310,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
summary, stats — used by the Library preview panel."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -292,7 +327,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
@@ -312,7 +347,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
"""Delete a research result from disk."""
|
||||
user = _require_user(request)
|
||||
_validate_session_id(session_id)
|
||||
data_dir = Path("data/deep_research")
|
||||
data_dir = Path(DEEP_RESEARCH_DIR)
|
||||
json_path = data_dir / f"{session_id}.json"
|
||||
deleted = False
|
||||
if json_path.exists():
|
||||
@@ -368,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
|
||||
if body.endpoint_id:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped: never resolve another user's private endpoint
|
||||
@@ -377,35 +411,26 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found or disabled")
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = body.model or ""
|
||||
if not ep_model:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||
if not resolved:
|
||||
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility", owner=user)
|
||||
# When neither research nor utility is configured, use the user's
|
||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default", owner=user)
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||
if not ep_url:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||
@@ -414,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = ""
|
||||
if ep.cached_models:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models)
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||
if resolved:
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
if not ep_url:
|
||||
@@ -494,7 +510,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
result = research_handler.get_result(session_id)
|
||||
if result is None:
|
||||
p = Path("data/deep_research") / f"{session_id}.json"
|
||||
p = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if p.exists():
|
||||
d = json.loads(p.read_text(encoding="utf-8"))
|
||||
return {
|
||||
@@ -534,7 +550,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
sources = research_handler.get_sources(session_id) or []
|
||||
query = ""
|
||||
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
path = Path(DEEP_RESEARCH_DIR) / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
disk = json.loads(path.read_text(encoding="utf-8"))
|
||||
@@ -572,19 +588,18 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep_headers = dict(r_headers)
|
||||
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("chat"))
|
||||
_merge(*resolve_endpoint("chat", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("research"))
|
||||
_merge(*resolve_endpoint("research", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("utility"))
|
||||
_merge(*resolve_endpoint("utility", owner=user))
|
||||
if not ep_url or not ep_model:
|
||||
# Last resort: any enabled endpoint
|
||||
# Last resort: this user's enabled endpoint, plus legacy shared rows.
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
fallback_url = build_chat_url(base)
|
||||
@@ -594,7 +609,7 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
try:
|
||||
models = json.loads(ep.cached_models)
|
||||
if models:
|
||||
fallback_model = models[0]
|
||||
fallback_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
_merge(fallback_url, fallback_model, fallback_headers)
|
||||
|
||||
+48
-33
@@ -10,8 +10,9 @@ import logging
|
||||
from core.session_manager import SessionManager
|
||||
from core.models import ChatMessage
|
||||
from src.request_models import SessionResponse
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage
|
||||
from src.auth_helpers import get_current_user, effective_user
|
||||
from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive
|
||||
from src.auth_helpers import get_current_user, effective_user, _auth_disabled
|
||||
from src.session_actions import is_session_recently_active
|
||||
|
||||
|
||||
def _sanitize_export_filename(name: str) -> str:
|
||||
@@ -92,35 +93,30 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||
"""Verify the current user owns the session. Raises 404 if not.
|
||||
"""Verify the current user owns the session, honoring single-user modes.
|
||||
|
||||
Ownership is checked against the DB row when one exists (unchanged). If
|
||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
||||
that lives only in ``session_manager`` because it was never persisted, or
|
||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
||||
user can still manage and delete it. Without this fallback such sessions are
|
||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
||||
|
||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
||||
that only care about persisted sessions keep their exact prior behavior.
|
||||
Authenticated requests must match the stored DB or in-memory owner. When
|
||||
auth is disabled and no user is present, treat the app as single-user mode:
|
||||
verify that the session exists, but do not compare its stored owner. This
|
||||
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||
rows created while auth was previously enabled.
|
||||
"""
|
||||
user = effective_user(request)
|
||||
if not user:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
if not user and not _auth_disabled():
|
||||
raise HTTPException(401, "Authentication required")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
row = db.query(DbSession.owner).filter(DbSession.id == session_id).first()
|
||||
finally:
|
||||
db.close()
|
||||
if row is not None:
|
||||
if row.owner != user:
|
||||
if user and row.owner != user:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
return
|
||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||
if session_manager is not None:
|
||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
||||
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||
return
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
|
||||
@@ -262,7 +258,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
last_msg_map = {}
|
||||
mode_map = {}
|
||||
msg_count_map = {}
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False).all()
|
||||
rows = db.query(DbSession.id, DbSession.folder, DbSession.total_input_tokens, DbSession.total_output_tokens, DbSession.is_important, DbSession.created_at, DbSession.updated_at, DbSession.last_message_at, DbSession.mode, DbSession.message_count).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
for row in rows:
|
||||
folder_map[row.id] = row.folder
|
||||
token_map[row.id] = (row.total_input_tokens or 0) + (row.total_output_tokens or 0)
|
||||
@@ -284,12 +280,14 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
r[0] for r in db.query(Document.session_id)
|
||||
.filter(Document.is_active == True,
|
||||
Document.current_content != None,
|
||||
func.trim(Document.current_content) != "")
|
||||
func.trim(Document.current_content) != "",
|
||||
Document.owner == user)
|
||||
.distinct().all()
|
||||
)
|
||||
img_session_ids = set(
|
||||
r[0] for r in db.query(GalleryImage.session_id)
|
||||
.filter(GalleryImage.session_id != None)
|
||||
.filter(GalleryImage.session_id != None,
|
||||
GalleryImage.owner == user)
|
||||
.distinct().all()
|
||||
)
|
||||
finally:
|
||||
@@ -370,8 +368,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
pass
|
||||
elif not model_to_use:
|
||||
from src.llm_core import list_model_ids
|
||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
ids = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not ids:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
# Default to the first CHAT model — endpoints often list embedding/
|
||||
@@ -385,8 +388,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.llm_core import list_model_ids
|
||||
import os as _os
|
||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
avail = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not avail:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
if model_to_use not in avail:
|
||||
@@ -543,22 +551,25 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
ids = body.get("ids", [])
|
||||
except Exception:
|
||||
ids = []
|
||||
deleted_count = 0
|
||||
for sid in ids:
|
||||
try:
|
||||
_verify_session_owner(request, sid, session_manager)
|
||||
session_manager.delete_session(sid)
|
||||
|
||||
# Enforce "starred" protection consistent with single-session delete
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db.query(_CM).filter(_CM.session_id == sid).delete()
|
||||
db.query(DbSession).filter(DbSession.id == sid).delete()
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
db_sess = db.query(DbSession).filter(DbSession.id == sid).first()
|
||||
if db_sess and db_sess.is_important:
|
||||
continue
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if session_manager.delete_session(sid):
|
||||
deleted_count += 1
|
||||
except Exception:
|
||||
pass
|
||||
return {"deleted": len(ids)}
|
||||
return {"deleted": deleted_count}
|
||||
|
||||
@router.delete("/session/{sid}")
|
||||
def delete_session(request: Request, sid: str):
|
||||
@@ -924,7 +935,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||
owner = getattr(session, "owner", None) or effective_user(request)
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||
if not url or not model:
|
||||
@@ -1006,7 +1018,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
}
|
||||
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
||||
try:
|
||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).all()
|
||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
|
||||
folder_map = {r.id: r.folder for r in rows}
|
||||
# Precompute per-session message counts in TWO aggregate queries
|
||||
# instead of 1–3 queries PER session — with many chats the per-row
|
||||
@@ -1017,6 +1029,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.query(DbMsg.session_id, _sa_func.count(DbMsg.id))
|
||||
.filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all()
|
||||
)
|
||||
cleanup_now = utcnow_naive()
|
||||
for row in rows:
|
||||
# Never delete important sessions
|
||||
if getattr(row, 'is_important', False):
|
||||
@@ -1029,6 +1042,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if hasattr(session_manager, 'delete_session'):
|
||||
session_manager.delete_session(row.id)
|
||||
continue
|
||||
if is_session_recently_active(row, now=cleanup_now):
|
||||
continue
|
||||
msg_count = _counts.get(row.id, 0)
|
||||
should_delete = False
|
||||
if msg_count == 0:
|
||||
|
||||
+279
-58
@@ -13,6 +13,7 @@ import tempfile
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
||||
|
||||
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
||||
# on Windows, so importing them unconditionally crashed app startup there
|
||||
@@ -37,6 +38,7 @@ from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
detached_popen_kwargs,
|
||||
find_bash,
|
||||
git_bash_path,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,6 +94,7 @@ def _venv_activate_prefix(venv: str | None) -> str:
|
||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||
return f". {act} && "
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||
@@ -169,7 +172,10 @@ def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
||||
and (dists.get("torch") or modules.get("torch", {}).get("real_module"))
|
||||
)
|
||||
if name == "hf_transfer":
|
||||
return bool(dists.get("hf-transfer") or modules.get("hf_transfer", {}).get("real_module"))
|
||||
return bool(
|
||||
dists.get("hf-transfer")
|
||||
or modules.get("hf_transfer", {}).get("real_module")
|
||||
)
|
||||
return bool(dists.get(name) or modules.get(name, {}).get("real_module"))
|
||||
|
||||
|
||||
@@ -194,8 +200,14 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
if binaries.get("llama-server"):
|
||||
parts.append(f"native llama-server: {binaries['llama-server']}")
|
||||
if dists.get("llama-cpp-python"):
|
||||
parts.append(f"python package: llama-cpp-python {dists['llama-cpp-python']}")
|
||||
return "; ".join(parts) if parts else "No native llama-server or llama-cpp-python server package found."
|
||||
parts.append(
|
||||
f"python package: llama-cpp-python {dists['llama-cpp-python']}"
|
||||
)
|
||||
return (
|
||||
"; ".join(parts)
|
||||
if parts
|
||||
else "No native llama-server or llama-cpp-python server package found."
|
||||
)
|
||||
if name == "diffusers":
|
||||
if _package_installed_from_probe(name, probe):
|
||||
return f"diffusers {dists.get('diffusers', 'available')} with torch {dists.get('torch', 'available')}"
|
||||
@@ -205,7 +217,9 @@ def _package_status_note(name: str, probe: dict) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageUpdateStatus:
|
||||
def _package_pip_update_status(
|
||||
pkg: dict, probe: dict | None = None
|
||||
) -> PackageUpdateStatus:
|
||||
"""Return whether the Dependencies UI should offer a generic pip update.
|
||||
|
||||
"Installed" means Cookbook can use the dependency. It does not always mean
|
||||
@@ -213,12 +227,28 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
||||
native llama-server can come from a package manager/source build, and a CLI
|
||||
may be on PATH without matching Python package metadata.
|
||||
"""
|
||||
if pkg.get("name") == "APFEL":
|
||||
return PackageUpdateStatus(
|
||||
False,
|
||||
"", # Note is empty because IT DOES allow for updates outside of PIP.
|
||||
)
|
||||
|
||||
if pkg.get("kind") == "system" or not pkg.get("pip"):
|
||||
return PackageUpdateStatus(False, "Update this system dependency outside Odysseus.")
|
||||
return PackageUpdateStatus(
|
||||
False, "Update this system dependency outside Odysseus."
|
||||
)
|
||||
|
||||
name = pkg.get("name")
|
||||
binaries = probe.get("binaries") if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict) else {}
|
||||
dists = probe.get("dists") if isinstance(probe, dict) and isinstance(probe.get("dists"), dict) else {}
|
||||
binaries = (
|
||||
probe.get("binaries")
|
||||
if isinstance(probe, dict) and isinstance(probe.get("binaries"), dict)
|
||||
else {}
|
||||
)
|
||||
dists = (
|
||||
probe.get("dists")
|
||||
if isinstance(probe, dict) and isinstance(probe.get("dists"), dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
if name == "llama_cpp" and binaries.get("llama-server"):
|
||||
return PackageUpdateStatus(
|
||||
@@ -231,7 +261,9 @@ def _package_pip_update_status(pkg: dict, probe: dict | None = None) -> PackageU
|
||||
"Using a vLLM CLI on PATH without Python package metadata; update it outside Odysseus.",
|
||||
)
|
||||
|
||||
return PackageUpdateStatus(True, "Update uses pip in the selected Python environment.")
|
||||
return PackageUpdateStatus(
|
||||
True, "Update uses pip in the selected Python environment."
|
||||
)
|
||||
|
||||
|
||||
def _prepend_user_install_bins_to_path() -> None:
|
||||
@@ -250,7 +282,9 @@ def _prepend_user_install_bins_to_path() -> None:
|
||||
candidates = []
|
||||
candidates.append(os.path.expanduser("~/.local/bin"))
|
||||
|
||||
parts = os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||
parts = (
|
||||
os.environ.get("PATH", "").split(os.pathsep) if os.environ.get("PATH") else []
|
||||
)
|
||||
changed = False
|
||||
for path in reversed([p for p in candidates if p]):
|
||||
if path not in parts:
|
||||
@@ -357,9 +391,11 @@ PTY_UNSUPPORTED_ERROR = "pty_unsupported"
|
||||
|
||||
class ShellExecRequest(BaseModel):
|
||||
command: str
|
||||
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
|
||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||
timeout: int | None = (
|
||||
None # optional override; 0 = no timeout (run until client disconnects)
|
||||
)
|
||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||
|
||||
|
||||
async def _create_shell(command: str, **kwargs):
|
||||
@@ -368,8 +404,16 @@ async def _create_shell(command: str, **kwargs):
|
||||
POSIX: /bin/sh via create_subprocess_shell (unchanged behaviour).
|
||||
Windows: prefer a real bash (Git Bash/WSL) so bash-syntax commands behave
|
||||
the same as on Linux; fall back to cmd.exe when no bash is installed.
|
||||
Powershell commands are executed directly via cmd.exe /c to avoid quoting
|
||||
and env variable expansion errors under Git Bash.
|
||||
"""
|
||||
if IS_WINDOWS:
|
||||
# PowerShell commands (used by the frontend for Windows log-file polling
|
||||
# and session management) must run directly — passing them through
|
||||
# bash -c mangles $env:VAR syntax and breaks the command.
|
||||
cmd_trim = command.strip()
|
||||
if cmd_trim.startswith("powershell") or cmd_trim.startswith("cmd "):
|
||||
return await asyncio.create_subprocess_shell(command, **kwargs)
|
||||
bash = find_bash()
|
||||
if bash:
|
||||
return await asyncio.create_subprocess_exec(bash, "-c", command, **kwargs)
|
||||
@@ -386,9 +430,7 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
|
||||
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
||||
@@ -399,7 +441,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout}s",
|
||||
"exit_code": -1,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
||||
|
||||
@@ -481,7 +527,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
|
||||
@@ -503,7 +549,7 @@ async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
if buf:
|
||||
@@ -534,6 +580,7 @@ def _pty_read(fd: int) -> bytes | None:
|
||||
"""Blocking read from PTY fd. Called via run_in_executor.
|
||||
Returns bytes on data, None on timeout (no data yet)."""
|
||||
import select
|
||||
|
||||
r, _, _ = select.select([fd], [], [], 1.0)
|
||||
if r:
|
||||
try:
|
||||
@@ -557,10 +604,10 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"#!/bin/bash\n"
|
||||
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
|
||||
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
|
||||
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
|
||||
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
|
||||
f'ODYSSEUS_USER_SHELL="${{SHELL:-}}"\n'
|
||||
f'if [ -n "$ODYSSEUS_USER_SHELL" ] && [ -x "$ODYSSEUS_USER_SHELL" ]; then\n'
|
||||
f' ODYSSEUS_USER_PATH="$("$ODYSSEUS_USER_SHELL" -ic \'printf "__ODYSSEUS_PATH__%s\\n" "$PATH"\' 2>/dev/null | sed -n \'s/^__ODYSSEUS_PATH__//p\' | tail -n 1 || true)"\n'
|
||||
f' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi\n'
|
||||
f"fi\n"
|
||||
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
||||
f"EC=${{PIPESTATUS[0]}}\n"
|
||||
@@ -570,7 +617,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
encoding="utf-8",
|
||||
)
|
||||
script_path.chmod(0o755)
|
||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
||||
logger.info(
|
||||
"tmux wrapper script created: session=%s path=%s", session_id, script_path
|
||||
)
|
||||
|
||||
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
||||
|
||||
@@ -602,7 +651,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Read new lines from log
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
new_lines = lines[lines_sent:]
|
||||
for line in new_lines:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
@@ -630,7 +681,9 @@ async def _generate_tmux(cmd: str, request: Request):
|
||||
# Session ended — do one final read
|
||||
await asyncio.sleep(0.5)
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
try:
|
||||
@@ -672,8 +725,8 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
if bash:
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"{cmd} > {shlex.quote(str(log_path))} 2>&1\n"
|
||||
f"echo $? > {shlex.quote(str(exit_path))}\n",
|
||||
f"{cmd} > {shlex.quote(git_bash_path(log_path))} 2>&1\n"
|
||||
f"echo $? > {shlex.quote(git_bash_path(exit_path))}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
argv = [bash, str(script_path)]
|
||||
@@ -711,7 +764,9 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
return
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
lines_sent = len(lines)
|
||||
@@ -723,11 +778,18 @@ async def _generate_win_detached(cmd: str, request: Request):
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
|
||||
lines = log_path.read_text(
|
||||
encoding="utf-8", errors="replace"
|
||||
).splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
lines_sent = len(lines)
|
||||
exit_code = int((exit_path.read_text(encoding="utf-8", errors="replace").strip() or "0"))
|
||||
exit_code = int(
|
||||
(
|
||||
exit_path.read_text(encoding="utf-8", errors="replace").strip()
|
||||
or "0"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
exit_code = 0
|
||||
break
|
||||
@@ -753,7 +815,9 @@ def setup_shell_routes() -> APIRouter:
|
||||
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
||||
|
||||
logger.info("User shell exec requested: length=%d", len(cmd))
|
||||
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
|
||||
result = await _exec_shell(
|
||||
cmd, timeout=req.timeout if req.timeout is not None else EXEC_TIMEOUT
|
||||
)
|
||||
return result
|
||||
|
||||
@router.post("/api/shell/stream")
|
||||
@@ -762,9 +826,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
_require_admin(request)
|
||||
cmd = req.command.strip()
|
||||
if not cmd:
|
||||
|
||||
async def empty():
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
||||
|
||||
return StreamingResponse(empty(), media_type="text/event-stream")
|
||||
|
||||
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
||||
@@ -781,7 +847,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
if use_tmux:
|
||||
# tmux is POSIX-only; Windows uses a detached-process + logfile tail
|
||||
# that preserves the "survives disconnect" behaviour.
|
||||
gen = _generate_win_detached(cmd, request) if IS_WINDOWS else _generate_tmux(cmd, request)
|
||||
gen = (
|
||||
_generate_win_detached(cmd, request)
|
||||
if IS_WINDOWS
|
||||
else _generate_tmux(cmd, request)
|
||||
)
|
||||
return StreamingResponse(gen, media_type="text/event-stream")
|
||||
|
||||
if use_pty and not IS_WINDOWS:
|
||||
@@ -813,7 +883,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
if buf:
|
||||
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
|
||||
await q.put(
|
||||
(
|
||||
name,
|
||||
buf.decode(errors="replace").rstrip("\r\n"),
|
||||
)
|
||||
)
|
||||
break
|
||||
buf += chunk
|
||||
while True:
|
||||
@@ -821,7 +896,7 @@ def setup_shell_routes() -> APIRouter:
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
buf = buf[idx + sep_len :]
|
||||
if line:
|
||||
await q.put((name, line))
|
||||
finally:
|
||||
@@ -880,7 +955,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
@router.get("/api/cookbook/packages")
|
||||
async def list_packages(request: Request, host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
|
||||
async def list_packages(
|
||||
request: Request,
|
||||
host: str | None = None,
|
||||
ssh_port: str | None = None,
|
||||
venv: str | None = None,
|
||||
):
|
||||
"""Check which optional packages are installed.
|
||||
|
||||
Local-target packages are checked in-process. Remote-target packages
|
||||
@@ -890,7 +970,13 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""
|
||||
_require_admin(request)
|
||||
_reject_cross_site(request)
|
||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json, site, sys
|
||||
import importlib
|
||||
import importlib.metadata as importlib_metadata
|
||||
import shlex
|
||||
import json as _json
|
||||
import site
|
||||
import sys
|
||||
|
||||
_prepend_user_install_bins_to_path()
|
||||
importlib.invalidate_caches()
|
||||
try:
|
||||
@@ -905,26 +991,115 @@ def setup_shell_routes() -> APIRouter:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
packages = [
|
||||
# ── System ── OS binaries, not pip packages
|
||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
||||
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
|
||||
{
|
||||
"name": "tmux",
|
||||
"pip": "",
|
||||
"desc": "Required for Linux/Termux Cookbook background downloads and serves",
|
||||
"category": "System",
|
||||
"target": "remote",
|
||||
"kind": "system",
|
||||
"install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper.",
|
||||
},
|
||||
{
|
||||
"name": "docker",
|
||||
"pip": "",
|
||||
"desc": "Required only for Docker-backed launch commands",
|
||||
"category": "System",
|
||||
"target": "remote",
|
||||
"kind": "system",
|
||||
"install_hint": "Install Docker on the selected server and allow this user to run docker.",
|
||||
},
|
||||
# ── LLM ── installs on GPU servers for model serving/downloading
|
||||
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
|
||||
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
|
||||
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
|
||||
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
|
||||
{
|
||||
"name": "hf_transfer",
|
||||
"pip": "hf_transfer",
|
||||
"desc": "Fast model downloads from HuggingFace",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "llama_cpp",
|
||||
"pip": "llama-cpp-python[server]",
|
||||
"desc": "Serve GGUF models via llama.cpp",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "sglang",
|
||||
"pip": "sglang[all]",
|
||||
"desc": "Serve HF safetensors models via SGLang",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "vllm",
|
||||
"pip": "vllm",
|
||||
"desc": "High-throughput LLM serving engine",
|
||||
"category": "LLM",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "APFEL",
|
||||
"pip": "",
|
||||
"desc": "OpenAI-compatible API for Apple Foundational Models on Apple Silicon",
|
||||
"category": "LLM",
|
||||
"target": "local",
|
||||
"kind": "system",
|
||||
"install_cmd": "brew install apfel",
|
||||
"update_cmd": "brew upgrade apfel",
|
||||
"install_hint": "Requires a native Apple Silicon Mac with Apple Foundational Models support. Installable via Homebrew on supported Macs.",
|
||||
},
|
||||
# ── Image ── editor + diffusion model serving
|
||||
{"name": "diffusers", "pip": "diffusers[torch]", "desc": "Image generation pipelines (SD, Flux) with PyTorch", "category": "Image", "target": "remote"},
|
||||
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
|
||||
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
|
||||
{
|
||||
"name": "diffusers",
|
||||
"pip": "diffusers[torch]",
|
||||
"desc": "Image generation pipelines (SD, Flux) with PyTorch",
|
||||
"category": "Image",
|
||||
"target": "remote",
|
||||
},
|
||||
{
|
||||
"name": "rembg",
|
||||
"pip": "rembg[gpu]",
|
||||
"desc": "AI background removal for image editor",
|
||||
"category": "Image",
|
||||
"target": "local",
|
||||
},
|
||||
{
|
||||
"name": "realesrgan",
|
||||
"pip": "realesrgan",
|
||||
"desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.",
|
||||
"category": "Image",
|
||||
"target": "local",
|
||||
},
|
||||
# ── Tools ──
|
||||
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
|
||||
{
|
||||
"name": "playwright",
|
||||
"pip": "playwright",
|
||||
"desc": "Browser automation for web tools",
|
||||
"category": "Tools",
|
||||
"target": "local",
|
||||
},
|
||||
]
|
||||
|
||||
# Most packages should not be installed through external means. Hence, set the default of the
|
||||
# install_cmd and update_cmd to None, which indicates that the recommended way to install/update is through the Cookbook # server setup or pip. Only system packages, should have explicit install/update commands provided.
|
||||
for pkg in packages:
|
||||
pkg.setdefault("install_cmd", None)
|
||||
pkg.setdefault("update_cmd", None)
|
||||
# Remote check: for remote-target packages, probe the selected server's
|
||||
# venv over SSH so a remote `pip install` actually reflects here.
|
||||
remote_status: dict = {}
|
||||
remote_details: dict = {}
|
||||
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
|
||||
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
|
||||
remote_names = [
|
||||
p["name"]
|
||||
for p in packages
|
||||
if p.get("target") == "remote" and p.get("kind") != "system"
|
||||
]
|
||||
remote_system_names = [
|
||||
p["name"]
|
||||
for p in packages
|
||||
if p.get("target") == "remote" and p.get("kind") == "system"
|
||||
]
|
||||
if host and remote_names:
|
||||
try:
|
||||
py = _package_probe_script(remote_names)
|
||||
@@ -934,7 +1109,9 @@ def setup_shell_routes() -> APIRouter:
|
||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
*argv,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -958,11 +1135,15 @@ def setup_shell_routes() -> APIRouter:
|
||||
checks = []
|
||||
for name in remote_system_names:
|
||||
qn = shlex.quote(name)
|
||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
||||
checks.append(
|
||||
f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi"
|
||||
)
|
||||
inner = " ; ".join(checks)
|
||||
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
*argv,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
@@ -987,11 +1168,25 @@ def setup_shell_routes() -> APIRouter:
|
||||
if note:
|
||||
pkg["status_note"] = note
|
||||
elif pkg.get("kind") == "system":
|
||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||
if pkg["name"] == "APFEL":
|
||||
pkg["applicable"] = IS_APPLE_SILICON
|
||||
pkg["installed"] = which_tool("apfel") is not None
|
||||
pkg["status_note"] = (
|
||||
"Available on Apple Silicon (arm64) devices; exposed through a local OpenAI-compatible API."
|
||||
if IS_APPLE_SILICON
|
||||
else "Requires a native Apple Silicon Mac with Apple Foundational Models support."
|
||||
)
|
||||
else:
|
||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||
elif pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||
pkg["installed"] = True
|
||||
pkg["status_note"] = f"native llama-server: {shutil.which('llama-server')}"
|
||||
probe = {"binaries": {"llama-server": shutil.which("llama-server")}, "dists": {}}
|
||||
pkg["status_note"] = (
|
||||
f"native llama-server: {shutil.which('llama-server')}"
|
||||
)
|
||||
probe = {
|
||||
"binaries": {"llama-server": shutil.which("llama-server")},
|
||||
"dists": {},
|
||||
}
|
||||
elif pkg["name"] == "vllm":
|
||||
_vllm_cli = shutil.which("vllm")
|
||||
pkg["installed"] = _vllm_cli is not None
|
||||
@@ -1014,6 +1209,12 @@ def setup_shell_routes() -> APIRouter:
|
||||
pkg["installed"] = False
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
pkg["installed"] = False
|
||||
except Exception:
|
||||
# Installed but crashes on import — e.g. a CUDA build of
|
||||
# llama-cpp-python raising FileNotFoundError when the CUDA
|
||||
# toolkit dir is absent. One broken optional package must not
|
||||
# 500 the entire packages panel; report it as not usable.
|
||||
pkg["installed"] = False
|
||||
|
||||
if pkg.get("installed"):
|
||||
update_status = _package_pip_update_status(pkg, probe)
|
||||
@@ -1037,15 +1238,30 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
||||
_require_admin(request)
|
||||
import sys as _sys
|
||||
|
||||
body = await request.json()
|
||||
pip_name = body.get("pip")
|
||||
if not pip_name:
|
||||
return {"ok": False, "error": "No package specified"}
|
||||
# Validate against known packages to prevent arbitrary pip install
|
||||
known = {
|
||||
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers", "diffusers[torch]",
|
||||
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
|
||||
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan", "vllm",
|
||||
"rembg[gpu]",
|
||||
"hf_transfer",
|
||||
"llama-cpp-python[server]",
|
||||
"sglang[all]",
|
||||
"diffusers",
|
||||
"diffusers[torch]",
|
||||
"TTS",
|
||||
"bark",
|
||||
"faster-whisper",
|
||||
"playwright",
|
||||
"realesrgan",
|
||||
"gfpgan",
|
||||
"insightface",
|
||||
"onnxruntime-gpu",
|
||||
"onnxruntime",
|
||||
"hdbscan",
|
||||
"vllm",
|
||||
}
|
||||
if pip_name not in known:
|
||||
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
||||
@@ -1071,6 +1287,7 @@ def setup_shell_routes() -> APIRouter:
|
||||
"""
|
||||
_require_admin(request)
|
||||
from routes.cookbook_helpers import _llama_cpp_rebuild_cmd
|
||||
|
||||
body = await request.json()
|
||||
engine = str(body.get("engine") or "llamacpp").strip()
|
||||
if engine != "llamacpp":
|
||||
@@ -1079,7 +1296,11 @@ def setup_shell_routes() -> APIRouter:
|
||||
ssh_port = body.get("ssh_port")
|
||||
cmd = _llama_cpp_rebuild_cmd()
|
||||
try:
|
||||
argv = (_ssh_base_argv(host, ssh_port) + [cmd]) if host else ["bash", "-lc", cmd]
|
||||
argv = (
|
||||
(_ssh_base_argv(host, ssh_port) + [cmd])
|
||||
if host
|
||||
else ["bash", "-lc", cmd]
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
try:
|
||||
|
||||
+44
-16
@@ -21,10 +21,44 @@ from src.auth_helpers import get_current_user
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DATA_URL_RE = re.compile(
|
||||
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_DATA_URL_RE = re.compile(r"^data:image/png;base64,(?P<data>.+)$", re.IGNORECASE | re.DOTALL)
|
||||
_ANY_IMAGE_DATA_URL_RE = re.compile(r"^data:image/[^;]+;base64,", re.IGNORECASE)
|
||||
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
|
||||
_MAX_SIGNATURE_BYTES = 2 * 1024 * 1024
|
||||
_MAX_SIGNATURE_B64 = ((_MAX_SIGNATURE_BYTES + 2) // 3) * 4
|
||||
_MAX_SIGNATURE_DIMENSION = 4096
|
||||
|
||||
|
||||
def _normalize_signature_png(raw: str) -> str:
|
||||
raw = (raw or "").strip()
|
||||
m = _DATA_URL_RE.match(raw)
|
||||
if m:
|
||||
b64 = m.group("data")
|
||||
elif _ANY_IMAGE_DATA_URL_RE.match(raw):
|
||||
raise HTTPException(400, "Signature data must be a PNG image")
|
||||
else:
|
||||
b64 = raw
|
||||
if len(b64) > _MAX_SIGNATURE_B64:
|
||||
raise HTTPException(400, "Signature PNG is too large")
|
||||
try:
|
||||
payload = base64.b64decode(b64, validate=True)
|
||||
except Exception:
|
||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||
if not payload:
|
||||
raise HTTPException(400, "Signature PNG is empty")
|
||||
if len(payload) > _MAX_SIGNATURE_BYTES:
|
||||
raise HTTPException(400, "Signature PNG is too large")
|
||||
if not payload.startswith(_PNG_MAGIC):
|
||||
raise HTTPException(400, "Signature data must be a PNG image")
|
||||
return base64.b64encode(payload).decode("ascii")
|
||||
|
||||
|
||||
def _signature_dimension(value: Optional[int]) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int) or value < 1 or value > _MAX_SIGNATURE_DIMENSION:
|
||||
raise HTTPException(400, "Signature dimensions are invalid")
|
||||
return value
|
||||
|
||||
|
||||
class SignatureCreate(BaseModel):
|
||||
@@ -67,24 +101,18 @@ def setup_signature_routes() -> APIRouter:
|
||||
@router.post("/api/signatures")
|
||||
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
raw = (req.data or "").strip()
|
||||
m = _DATA_URL_RE.match(raw)
|
||||
b64 = m.group("data") if m else raw
|
||||
try:
|
||||
payload = base64.b64decode(b64, validate=True)
|
||||
if not payload:
|
||||
raise ValueError("empty payload")
|
||||
except Exception:
|
||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||
b64 = _normalize_signature_png(req.data)
|
||||
width = _signature_dimension(req.width)
|
||||
height = _signature_dimension(req.height)
|
||||
|
||||
sig = Signature(
|
||||
id=str(uuid.uuid4()),
|
||||
owner=user,
|
||||
name=(req.name or "Signature").strip() or "Signature",
|
||||
data_png=b64,
|
||||
width=req.width,
|
||||
height=req.height,
|
||||
svg=req.svg,
|
||||
width=width,
|
||||
height=height,
|
||||
svg=None,
|
||||
)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
+107
-1
@@ -11,6 +11,8 @@ import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -51,6 +53,10 @@ class SkillAddRequest(BaseModel):
|
||||
steps: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SkillImportUrlRequest(BaseModel):
|
||||
url: str = Field(..., min_length=8, max_length=2000)
|
||||
|
||||
|
||||
class SkillUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
@@ -1014,7 +1020,7 @@ def _resolve_audit_models(owner=None):
|
||||
spec = (get_setting("teacher_model", "") or "").strip()
|
||||
if spec:
|
||||
from src.ai_interaction import _resolve_model
|
||||
t_url, t_model, t_headers = _resolve_model(spec)
|
||||
t_url, t_model, t_headers = _resolve_model(spec, owner=owner)
|
||||
if t_url and t_model:
|
||||
teacher = (t_url, t_model, t_headers)
|
||||
except Exception as e:
|
||||
@@ -1103,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
idx = skills_manager.index_for(owner=user)
|
||||
return {"index": idx, "count": len(idx)}
|
||||
|
||||
@router.get("/slash-catalog")
|
||||
async def get_slash_catalog(request: Request):
|
||||
"""Return skills that are available as slash commands.
|
||||
|
||||
Mirrors the agent prompt's published-skill index so the UI never offers
|
||||
a slash command the model would not normally be allowed to discover.
|
||||
"""
|
||||
user = _owner(request)
|
||||
all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)}
|
||||
entries = []
|
||||
for s in skills_manager.index_for(owner=user):
|
||||
name = (s.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
full = all_skills.get(name) or {}
|
||||
category = (s.get("category") or full.get("category") or "general").strip() or "general"
|
||||
entries.append({
|
||||
"type": "skill",
|
||||
"token": f"/{name}",
|
||||
"name": name,
|
||||
"category": f"Skills / {category}",
|
||||
"help": s.get("description") or full.get("description") or "",
|
||||
"usage": f"/{name} <request>",
|
||||
"uses": int(full.get("uses") or 0),
|
||||
"last_used": full.get("last_used"),
|
||||
})
|
||||
entries.sort(key=lambda row: row["name"])
|
||||
return {"skills": entries, "count": len(entries)}
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_skills(request: Request):
|
||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||
@@ -1203,6 +1238,36 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
save_settings(settings)
|
||||
return {"ok": True, "name": name, "is_overridden": False}
|
||||
|
||||
@router.post("/import-from-url")
|
||||
async def import_skill_from_url(request: Request, body: SkillImportUrlRequest):
|
||||
"""Install a SKILL.md bundle from a public GitHub URL (skills.sh links supported)."""
|
||||
require_admin(request)
|
||||
user = _owner(request)
|
||||
from services.memory.skill_importer import (
|
||||
SkillImportError,
|
||||
fetch_skill_bundle,
|
||||
)
|
||||
|
||||
try:
|
||||
files, _src = fetch_skill_bundle(body.url.strip())
|
||||
entry = skills_manager.import_bundle_from_files(
|
||||
files,
|
||||
owner=user,
|
||||
source_url=body.url.strip(),
|
||||
)
|
||||
except SkillImportError as e:
|
||||
raise HTTPException(400, str(e)) from e
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning("skill import fetch failed: %s", e)
|
||||
detail = str(e).strip() or "Could not download skill from URL"
|
||||
raise HTTPException(502, detail) from e
|
||||
except Exception as e:
|
||||
logger.error("skill import failed: %s", e)
|
||||
raise HTTPException(500, "Skill import failed") from e
|
||||
|
||||
_fire_skill_added(user)
|
||||
return {"ok": True, "skill": entry, "files": len(files)}
|
||||
|
||||
@router.post("/add")
|
||||
async def add_skill(request: Request, body: SkillAddRequest):
|
||||
user = _owner(request)
|
||||
@@ -1236,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
_fire_skill_added(user)
|
||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||
|
||||
@router.post("/{skill_id}/invoke")
|
||||
async def invoke_skill(request: Request, skill_id: str):
|
||||
"""Build a skill-pinned prompt for slash-command invocation.
|
||||
|
||||
This is intentionally server-side so availability, ownership, and usage
|
||||
accounting use the same rules as the SkillsManager.
|
||||
"""
|
||||
user = _owner(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
body = {}
|
||||
request_text = (body.get("request") or "").strip() if isinstance(body, dict) else ""
|
||||
|
||||
invokable = {
|
||||
s.get("name"): s for s in skills_manager.index_for(owner=user)
|
||||
if (s.get("name") or "").strip()
|
||||
}
|
||||
match = invokable.get(skill_id)
|
||||
if not match:
|
||||
raise HTTPException(404, "Skill is not available for slash invocation")
|
||||
|
||||
name = match.get("name")
|
||||
md = skills_manager.read_skill_md(name, owner=user)
|
||||
if md is None:
|
||||
raise HTTPException(404, "Skill source unavailable")
|
||||
|
||||
skills_manager.record_use(name, owner=user)
|
||||
message = (
|
||||
"Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
|
||||
f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
|
||||
+ (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
|
||||
)
|
||||
return {
|
||||
"ok": True,
|
||||
"type": "skill",
|
||||
"name": name,
|
||||
"command": f"/{name}",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
@router.get("/{skill_id}")
|
||||
async def get_skill(request: Request, skill_id: str):
|
||||
user = _owner(request)
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File
|
||||
import logging
|
||||
|
||||
from src.upload_limits import read_upload_limited
|
||||
from src.upload_limits import read_upload_limited, STT_MAX_AUDIO_BYTES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
STT_MAX_AUDIO_BYTES = 25 * 1024 * 1024
|
||||
|
||||
|
||||
def setup_stt_routes(stt_service):
|
||||
"""Setup STT routes with the provided STT service"""
|
||||
|
||||
+37
-22
@@ -11,7 +11,9 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||
from core.constants import internal_api_base
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.constants import DATA_DIR, EMAIL_URGENCY_CACHE_DIR
|
||||
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
|
||||
@@ -56,7 +58,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=10) as client:
|
||||
r = client.delete(
|
||||
f"http://localhost:7000/api/calendar/events/{uid}",
|
||||
f"{internal_api_base()}/api/calendar/events/{uid}",
|
||||
headers=headers,
|
||||
)
|
||||
if r.status_code >= 400:
|
||||
@@ -81,7 +83,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
try:
|
||||
with httpx.Client(timeout=10) as client:
|
||||
# Find the Cookbook calendar.
|
||||
cal_r = client.get("http://localhost:7000/api/calendar/calendars", headers=headers)
|
||||
cal_r = client.get(f"{internal_api_base()}/api/calendar/calendars", headers=headers)
|
||||
if cal_r.status_code >= 400:
|
||||
return
|
||||
cals = (cal_r.json() or {}).get("calendars", [])
|
||||
@@ -98,7 +100,7 @@ def _maybe_cascade_calendar_event(task) -> None:
|
||||
start = (now - _td(days=30)).isoformat()
|
||||
end = (now + _td(days=365)).isoformat()
|
||||
ev_r = client.get(
|
||||
"http://localhost:7000/api/calendar/events",
|
||||
f"{internal_api_base()}/api/calendar/events",
|
||||
params={"start": start, "end": end, "calendar": cal_href},
|
||||
headers=headers,
|
||||
)
|
||||
@@ -291,20 +293,24 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
def _owner(request: Request):
|
||||
return get_current_user(request)
|
||||
|
||||
async def _generate_task_name(prompt: str) -> str:
|
||||
async def _generate_task_name(prompt: str, owner: Optional[str] = None) -> str:
|
||||
"""Use LLM to generate a short task name from the prompt."""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import Session as DbSession
|
||||
db = SessionLocal()
|
||||
try:
|
||||
recent = db.query(DbSession).filter(
|
||||
q = db.query(DbSession).filter(
|
||||
DbSession.endpoint_url.isnot(None),
|
||||
DbSession.model.isnot(None),
|
||||
).order_by(DbSession.created_at.desc()).first()
|
||||
)
|
||||
if owner:
|
||||
q = q.filter(DbSession.owner == owner)
|
||||
recent = q.order_by(DbSession.created_at.desc()).first()
|
||||
if not recent:
|
||||
return prompt[:50].strip()
|
||||
url, model = recent.endpoint_url, recent.model
|
||||
headers = recent.headers or {}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -315,6 +321,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
{"role": "user", "content": prompt[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
)
|
||||
title = result.strip().strip('"\'').strip()
|
||||
@@ -429,6 +436,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
|
||||
target_id = (then_task_id or "").strip()
|
||||
if not target_id:
|
||||
return None
|
||||
if current_task_id and target_id == current_task_id:
|
||||
raise HTTPException(400, "Task cannot chain to itself")
|
||||
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
|
||||
if user:
|
||||
q = q.filter(ScheduledTask.owner == user)
|
||||
target = q.first()
|
||||
if not target:
|
||||
raise HTTPException(404, "Chained task not found")
|
||||
return target.id
|
||||
|
||||
@router.post("")
|
||||
async def create_task(request: Request, req: TaskCreate):
|
||||
user = _owner(request)
|
||||
@@ -465,7 +486,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||
elif req.prompt:
|
||||
name = await _generate_task_name(req.prompt)
|
||||
name = await _generate_task_name(req.prompt, owner=user)
|
||||
else:
|
||||
name = "Untitled Task"
|
||||
|
||||
@@ -492,6 +513,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
task_id = str(uuid.uuid4())
|
||||
db = SessionLocal()
|
||||
try:
|
||||
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
|
||||
notifications_enabled = (
|
||||
False if req.task_type == "action" and req.notifications_enabled is None
|
||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||
@@ -527,7 +549,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
output_target=req.output_target,
|
||||
model=req.model or None,
|
||||
endpoint_url=req.endpoint_url or None,
|
||||
then_task_id=req.then_task_id or None,
|
||||
then_task_id=then_task_id,
|
||||
webhook_token=webhook_token,
|
||||
notifications_enabled=notifications_enabled,
|
||||
)
|
||||
@@ -609,7 +631,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
|
||||
removed_files = 0
|
||||
if action == "check_email_urgency":
|
||||
cache_dir = Path("data/email_urgency_cache")
|
||||
cache_dir = Path(EMAIL_URGENCY_CACHE_DIR)
|
||||
if cache_dir.exists():
|
||||
for child in cache_dir.glob("*.json"):
|
||||
try:
|
||||
@@ -618,7 +640,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
except Exception:
|
||||
pass
|
||||
owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (user or "default"))
|
||||
for state_path in [Path(f"data/email_urgency_state_{owner_slug}.json")]:
|
||||
for state_path in [Path(DATA_DIR) / f"email_urgency_state_{owner_slug}.json"]:
|
||||
try:
|
||||
if state_path.exists():
|
||||
state_path.unlink()
|
||||
@@ -680,15 +702,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
if req.trigger_count is not None:
|
||||
task.trigger_count = req.trigger_count
|
||||
if req.then_task_id is not None:
|
||||
if req.then_task_id:
|
||||
chain_target = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == req.then_task_id
|
||||
).first()
|
||||
if not chain_target:
|
||||
raise HTTPException(400, "Chained task not found")
|
||||
if chain_target.owner != user:
|
||||
raise HTTPException(403, "Cannot chain to another user's task")
|
||||
task.then_task_id = req.then_task_id or None
|
||||
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
|
||||
if req.notifications_enabled is not None:
|
||||
task.notifications_enabled = bool(req.notifications_enabled)
|
||||
if req.cron_expression is not None:
|
||||
@@ -969,7 +983,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
||||
)
|
||||
try:
|
||||
from src.agent_tools import get_mcp_manager
|
||||
from src.tool_utils import get_mcp_manager
|
||||
mcp = get_mcp_manager()
|
||||
if mcp:
|
||||
for tool in mcp.get_all_tools():
|
||||
@@ -1064,6 +1078,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
desc = (body.get("description") or "").strip()
|
||||
if not desc:
|
||||
return {"success": False, "message": "Nothing to parse"}
|
||||
user = _owner(request)
|
||||
|
||||
now = _dt.now()
|
||||
# Give the model the current date/time + weekday so relative phrasing
|
||||
@@ -1090,9 +1105,9 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
||||
)
|
||||
try:
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=user or None)
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=user or None)
|
||||
if not (url and model):
|
||||
return {"success": False, "message": "No model endpoint configured"}
|
||||
raw = await llm_call_async(
|
||||
|
||||
+51
-34
@@ -13,9 +13,43 @@ from src.upload_handler import count_recent_uploads
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
||||
UPLOAD_RESPONSE_HEADERS = {"X-Content-Type-Options": "nosniff"}
|
||||
|
||||
def setup_upload_routes(upload_handler):
|
||||
"""Setup upload routes with the provided handler"""
|
||||
|
||||
def _upload_root() -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
return os.path.realpath(getattr(upload_handler, "upload_dir", UPLOAD_DIR))
|
||||
|
||||
def _path_inside_upload_dir(path: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([_upload_root(), os.path.realpath(path)]) == _upload_root()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _resolve_upload_path(file_id: str) -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
upload_root = getattr(upload_handler, "upload_dir", UPLOAD_DIR)
|
||||
direct = os.path.join(upload_root, file_id)
|
||||
if os.path.lexists(direct):
|
||||
if not _path_inside_upload_dir(direct):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if os.path.isfile(direct):
|
||||
return direct
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
for root, _dirs, files in os.walk(upload_root, followlinks=False):
|
||||
if file_id not in files:
|
||||
continue
|
||||
path = os.path.join(root, file_id)
|
||||
if not _path_inside_upload_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
raise HTTPException(404, "File not found")
|
||||
|
||||
@router.post("")
|
||||
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
||||
@@ -91,23 +125,11 @@ def setup_upload_routes(upload_handler):
|
||||
client isn't downloading the full-resolution photo just to show it tiny."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
# Search upload directories for the file
|
||||
from src.constants import UPLOAD_DIR
|
||||
import mimetypes as _mt
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
# Look up original filename and owner from uploads.json
|
||||
original_name = file_id
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
@@ -123,13 +145,14 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
mime = _mt.guess_type(path)[0] or "application/octet-stream"
|
||||
path = _resolve_upload_path(file_id)
|
||||
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or "application/octet-stream"
|
||||
from fastapi.responses import FileResponse
|
||||
# Downscaled thumbnail for image previews — generated once and cached.
|
||||
if thumb and mime.startswith("image/"):
|
||||
try:
|
||||
from PIL import Image, ImageOps
|
||||
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
|
||||
thumb_dir = os.path.join(_upload_root(), ".thumbs")
|
||||
os.makedirs(thumb_dir, exist_ok=True)
|
||||
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
||||
if (not os.path.exists(thumb_path)
|
||||
@@ -145,17 +168,21 @@ def setup_upload_routes(upload_handler):
|
||||
if im.mode not in ("RGB", "L"):
|
||||
im = im.convert("RGB")
|
||||
im.save(thumb_path, "JPEG", quality=80)
|
||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
||||
return FileResponse(thumb_path, media_type="image/jpeg", headers=UPLOAD_RESPONSE_HEADERS)
|
||||
except Exception as e:
|
||||
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
||||
# Fall through to the full image.
|
||||
return FileResponse(path, media_type=mime, filename=original_name)
|
||||
return FileResponse(
|
||||
path,
|
||||
media_type=mime,
|
||||
filename=original_name,
|
||||
headers=UPLOAD_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
def _load_upload_info(file_id: str):
|
||||
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
||||
from src.constants import UPLOAD_DIR
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
uploads_db = os.path.join(_upload_root(), "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db, encoding="utf-8") as f:
|
||||
db = json.load(f)
|
||||
@@ -163,8 +190,7 @@ def setup_upload_routes(upload_handler):
|
||||
return info
|
||||
|
||||
def _vision_cache_path(file_id: str) -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
|
||||
cache_dir = os.path.join(_upload_root(), ".vision")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return os.path.join(cache_dir, file_id + ".txt")
|
||||
|
||||
@@ -175,17 +201,6 @@ def setup_upload_routes(upload_handler):
|
||||
subsequent loads are instant. Pass force=1 to recompute."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
from src.constants import UPLOAD_DIR
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
info = _load_upload_info(file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
@@ -196,8 +211,9 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
path = _resolve_upload_path(file_id)
|
||||
import mimetypes as _mt
|
||||
mime = _mt.guess_type(path)[0] or ""
|
||||
mime = (info or {}).get("mime") or _mt.guess_type(path)[0] or ""
|
||||
if not mime.startswith("image/"):
|
||||
raise HTTPException(400, "Not an image")
|
||||
cache_path = _vision_cache_path(file_id)
|
||||
@@ -209,7 +225,7 @@ def setup_upload_routes(upload_handler):
|
||||
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
||||
from src.document_processor import analyze_image_with_vl
|
||||
try:
|
||||
text = analyze_image_with_vl(path) or ""
|
||||
text = analyze_image_with_vl(path, owner=current_user) or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
||||
raise HTTPException(500, f"Vision analysis failed: {e}")
|
||||
@@ -238,6 +254,7 @@ def setup_upload_routes(upload_handler):
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
_resolve_upload_path(file_id)
|
||||
body = await request.json()
|
||||
text = (body or {}).get("text", "")
|
||||
if not isinstance(text, str):
|
||||
|
||||
@@ -17,10 +17,11 @@ from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from core.platform_compat import IS_WINDOWS, safe_chmod, which_tool
|
||||
from src.constants import VAULT_FILE as _VAULT_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VAULT_FILE = Path("data/vault.json")
|
||||
VAULT_FILE = Path(_VAULT_FILE)
|
||||
|
||||
|
||||
def _find_bw() -> str:
|
||||
|
||||
+23
-10
@@ -194,6 +194,8 @@ def setup_webhook_routes(
|
||||
"together": "https://api.together.xyz/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"ollama": "https://ollama.com/api",
|
||||
"opencode-zen": "https://opencode.ai/zen/v1",
|
||||
"opencode-go": "https://opencode.ai/zen/go/v1",
|
||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||
"venice": "https://api.venice.ai/api/v1",
|
||||
}
|
||||
@@ -323,22 +325,33 @@ def setup_webhook_routes(
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
model = body.model or "auto"
|
||||
api_key = ep.api_key
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||
|
||||
if model == "auto":
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
models_url = build_models_url(base_url)
|
||||
hdrs = build_headers(api_key, base_url)
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
if models_url:
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
import json as _json
|
||||
ids = _json.loads(ep.cached_models or "[]")
|
||||
model = ids[0] if ids else "auto"
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not discover models from endpoint")
|
||||
|
||||
@@ -13,6 +13,8 @@ import json
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.constants import MEMORY_FILE, SKILLS_FILE
|
||||
|
||||
|
||||
def claim_json_entries(entries, owner):
|
||||
count = 0
|
||||
@@ -35,8 +37,8 @@ def main():
|
||||
|
||||
# 1. Memories (JSON files)
|
||||
for label, path in [
|
||||
("memory.json", "data/memory.json"),
|
||||
("skills.json", "data/skills.json"),
|
||||
("memory.json", MEMORY_FILE),
|
||||
("skills.json", SKILLS_FILE),
|
||||
]:
|
||||
if not os.path.exists(path):
|
||||
print(f" {label}: not found, skipping")
|
||||
|
||||
@@ -34,6 +34,7 @@ import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -52,7 +53,63 @@ async def lifespan(application):
|
||||
|
||||
|
||||
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
# Conservative defaults — server is designed for server-to-server use from
|
||||
# the Odysseus backend. Wildcard CORS + the 127.0.0.1 default bind used to
|
||||
# leave the server reachable via DNS-rebinding from any browser tab on the
|
||||
# same host. The CLI flags below extend these allowlists for operators who
|
||||
# need browser access; the safe defaults handle the common case.
|
||||
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
|
||||
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
|
||||
|
||||
|
||||
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
|
||||
"""Allowed Host header values: the bind address + loopback variants +
|
||||
any operator-supplied --allowed-host values. Duplicates and empty
|
||||
strings are dropped; order is stable for predictable middleware setup."""
|
||||
seen = []
|
||||
for h in (bind_host, *_DEFAULT_ALLOWED_HOSTS, *(extras or [])):
|
||||
h = (h or "").strip()
|
||||
if h and h not in seen:
|
||||
seen.append(h)
|
||||
return seen
|
||||
|
||||
|
||||
def _compute_cors_origins(extras=None) -> list:
|
||||
"""CORS allowlist: default-deny (empty), extended only by explicit
|
||||
--allowed-origin values. Server-to-server callers don't set an Origin
|
||||
header so they're unaffected; this only narrows browser access."""
|
||||
seen = []
|
||||
for o in (*_DEFAULT_CORS_ORIGINS, *(extras or [])):
|
||||
o = (o or "").strip()
|
||||
if o and o not in seen:
|
||||
seen.append(o)
|
||||
return seen
|
||||
|
||||
|
||||
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
|
||||
"""Replace `application`'s user middleware stack with the diffusion server
|
||||
security middleware: the TrustedHost allowlist and, when origins are
|
||||
supplied, CORS. Used at module load and by the __main__ CLI path before
|
||||
serving starts. Raises before mutating if the middleware stack has already
|
||||
been built. Order is preserved: TrustedHost first, then CORS (added last ->
|
||||
outermost)."""
|
||||
if application.middleware_stack is not None:
|
||||
raise RuntimeError("security middleware must be configured before the app starts serving")
|
||||
application.user_middleware.clear()
|
||||
application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts))
|
||||
if allowed_origins:
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=list(allowed_origins),
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
|
||||
# Install defaults at module load so importing the app for tests / direct
|
||||
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
@@ -1089,7 +1146,25 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
||||
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
||||
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
||||
parser.add_argument("--allowed-host", action="append", default=[],
|
||||
help="Additional Host header value to accept (DNS-rebinding allowlist). "
|
||||
"Can be repeated. Loopback values are always included.")
|
||||
parser.add_argument("--allowed-origin", action="append", default=[],
|
||||
help="Additional CORS origin to allow. Can be repeated. Defaults to "
|
||||
"no cross-origin access — only pass this if you need a browser "
|
||||
"on a specific origin to call the server.")
|
||||
_args = parser.parse_args()
|
||||
|
||||
# Replace the module-load middleware stack with the CLI-configured one so
|
||||
# operator-supplied --allowed-host / --allowed-origin values take effect
|
||||
# before the first request is served. user_middleware is consulted lazily
|
||||
# when the middleware stack is built on the first request, so mutating it
|
||||
# here is safe.
|
||||
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
|
||||
final_origins = _compute_cors_origins(_args.allowed_origin)
|
||||
_configure_security_middleware(app, final_hosts, final_origins)
|
||||
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
|
||||
final_hosts, final_origins or "(none — default-deny)")
|
||||
|
||||
app.state.model_path = _args.model
|
||||
uvicorn.run(app, host=_args.host, port=_args.port)
|
||||
|
||||
@@ -19,6 +19,9 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from src.constants import PERSONAL_DIR
|
||||
|
||||
# Configure logging for the script
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -45,7 +48,7 @@ def main():
|
||||
rag_manager = RAGManager()
|
||||
|
||||
# Directory to scan
|
||||
docs_directory = "data/personal_docs"
|
||||
docs_directory = PERSONAL_DIR
|
||||
directory_path = Path(docs_directory)
|
||||
|
||||
# Check if directory exists
|
||||
|
||||
@@ -63,10 +63,10 @@ def migrate_memories():
|
||||
"""Migrate memory vectors from FAISS to ChromaDB."""
|
||||
from src.chroma_client import get_chroma_client
|
||||
from src.embeddings import get_embedding_client
|
||||
from src.constants import DATA_DIR
|
||||
from src.constants import MEMORY_VECTORS_DIR, MEMORY_FILE
|
||||
|
||||
ids_path = os.path.join(DATA_DIR, "memory_vectors", "ids.json")
|
||||
memory_path = os.path.join(DATA_DIR, "memory.json")
|
||||
ids_path = os.path.join(MEMORY_VECTORS_DIR, "ids.json")
|
||||
memory_path = MEMORY_FILE
|
||||
|
||||
if not os.path.exists(ids_path):
|
||||
logger.info("No memory FAISS index found, skipping memory migration")
|
||||
|
||||
@@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json"
|
||||
import tempfile
|
||||
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
||||
|
||||
from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE
|
||||
|
||||
|
||||
|
||||
def fail(msg: str, code: int = 1) -> None:
|
||||
sys.stderr.write(f"error: {msg}\n")
|
||||
@@ -160,7 +163,26 @@ def cmd_gpus(args) -> None:
|
||||
prefix = _ssh_prefix(args.host, args.ssh_port)
|
||||
cmd = prefix + (query.split() if not prefix else [query])
|
||||
try:
|
||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||
if prefix:
|
||||
candidates = [query]
|
||||
args_part = query[len("nvidia-smi "):]
|
||||
candidates.append(
|
||||
"bash -lc "
|
||||
+ repr(
|
||||
f"{SSH_PATH_OVERRIDE}"
|
||||
f"nvidia-smi {args_part}"
|
||||
)
|
||||
)
|
||||
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||
candidates.append(f"{nvidia_path} {args_part}")
|
||||
|
||||
out = None
|
||||
for candidate in candidates:
|
||||
out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15)
|
||||
if out.returncode == 0:
|
||||
break
|
||||
else:
|
||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||
except FileNotFoundError:
|
||||
# No nvidia-smi locally → try the Metal fallback before giving up.
|
||||
if not prefix:
|
||||
|
||||
@@ -25,6 +25,24 @@ from pathlib import Path
|
||||
|
||||
_DATA_DIR = _REPO_ROOT / "data" / "deep_research"
|
||||
|
||||
# The CLI's --status takes the user-facing label "complete", but the writer
|
||||
# in services/research/research_handler.py stores `status="done"` when a run
|
||||
# finishes (and the legacy src/research_handler.py does the same). Without
|
||||
# this alias, --status complete filters every finished record out and the
|
||||
# user sees an empty list. Map at filter time so the on-disk corpus is the
|
||||
# source of truth and the CLI surface stays the friendlier word. The other
|
||||
# choices ("running", "cancelled", "error") are stored verbatim, so they
|
||||
# fall through unchanged.
|
||||
_STATUS_CLI_TO_STORED = {"complete": "done"}
|
||||
|
||||
|
||||
def _status_matches(stored, requested: str) -> bool:
|
||||
stored = (stored or "")
|
||||
if not isinstance(stored, str):
|
||||
stored = ""
|
||||
target = _STATUS_CLI_TO_STORED.get(requested, requested)
|
||||
return stored == target
|
||||
|
||||
|
||||
def _load_path(path: Path) -> dict | None:
|
||||
try:
|
||||
@@ -72,7 +90,7 @@ def cmd_list(args):
|
||||
data = _load_path(path)
|
||||
if data is None:
|
||||
continue
|
||||
if args.status and (data.get("status") or "") != args.status:
|
||||
if args.status and not _status_matches(data.get("status"), args.status):
|
||||
continue
|
||||
out.append(_summarize(rp_id, data))
|
||||
out.sort(key=lambda r: r.get("started_at") or "", reverse=True)
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from src.rag_manager import RAGManager
|
||||
from src.constants import CHROMA_DIR
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +35,7 @@ class DocsService:
|
||||
results = await service.query("what is async await?")
|
||||
"""
|
||||
|
||||
def __init__(self, persist_dir: str = "data/chroma"):
|
||||
def __init__(self, persist_dir: str = CHROMA_DIR):
|
||||
self.rag = RAGManager(persist_directory=persist_dir)
|
||||
|
||||
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
||||
|
||||
+93
-48
@@ -4,6 +4,13 @@ import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import shlex
|
||||
|
||||
from core.platform_compat import (
|
||||
NVIDIA_PATH_CANDIDATES,
|
||||
SSH_PATH_OVERRIDE,
|
||||
run_ssh_command,
|
||||
)
|
||||
|
||||
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
||||
# from 30 min so changing filters doesn't keep re-probing the rig every
|
||||
@@ -21,16 +28,17 @@ def _run(cmd):
|
||||
if _remote_host:
|
||||
# Run command on remote host via SSH
|
||||
if isinstance(cmd, list):
|
||||
cmd_str = " ".join(cmd)
|
||||
cmd_str = shlex.join(str(c) for c in cmd)
|
||||
else:
|
||||
cmd_str = cmd
|
||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
||||
if _remote_port and _remote_port != "22":
|
||||
ssh_cmd += ["-p", _remote_port]
|
||||
ssh_cmd += [_remote_host, cmd_str]
|
||||
r = subprocess.run(
|
||||
ssh_cmd,
|
||||
capture_output=True, text=True, timeout=15,
|
||||
r = run_ssh_command(
|
||||
_remote_host,
|
||||
_remote_port,
|
||||
cmd_str,
|
||||
timeout=15,
|
||||
connect_timeout=5,
|
||||
strict_host_key_checking=False,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
@@ -76,21 +84,29 @@ def _detect_nvidia():
|
||||
global _last_gpu_error
|
||||
_last_gpu_error = None
|
||||
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||
# Remote fallback: a non-interactive SSH shell often has a minimal PATH
|
||||
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin), so the
|
||||
# first call silently returns nothing → "No GPU" on hosts that DO have GPUs.
|
||||
# Fallback: a non-interactive shell (or WSL) often has a minimal PATH
|
||||
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin,
|
||||
# /usr/lib/wsl/lib), so the first call silently returns nothing →
|
||||
# "No GPU" on machines that DO have GPUs.
|
||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||
if not out and _remote_host:
|
||||
out = _run(
|
||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin\"; "
|
||||
f"bash -lc '{SSH_PATH_OVERRIDE}"
|
||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||
)
|
||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
||||
# above still comes back empty even though the binary is right there.
|
||||
if not out and _remote_host:
|
||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi"):
|
||||
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
||||
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
||||
# that may not be in the server process's PATH.
|
||||
if not out:
|
||||
for _p in NVIDIA_PATH_CANDIDATES:
|
||||
# Use list form so subprocess.run (local) resolves the absolute path
|
||||
# correctly instead of treating the whole string as an executable name.
|
||||
if _remote_host:
|
||||
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
||||
else:
|
||||
out = _run([_p, "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||
if out:
|
||||
break
|
||||
if not out:
|
||||
@@ -468,39 +484,55 @@ def _detect_windows():
|
||||
"""
|
||||
# Single PowerShell command that gathers all hardware info at once
|
||||
ps_cmd = (
|
||||
"$r = @{}; "
|
||||
"$os = Get-CimInstance Win32_OperatingSystem; "
|
||||
"$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1); "
|
||||
"$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1); "
|
||||
"$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1; "
|
||||
"$r.cpu_name = $cpu.Name; "
|
||||
"$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum; "
|
||||
"$r.arch = $cpu.AddressWidth; "
|
||||
"""
|
||||
$r = @{}
|
||||
$os = Get-CimInstance Win32_OperatingSystem
|
||||
$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1)
|
||||
$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1)
|
||||
$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1
|
||||
$r.cpu_name = $cpu.Name
|
||||
$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum
|
||||
$r.arch = $cpu.AddressWidth
|
||||
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
||||
"try { "
|
||||
" $nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null; "
|
||||
" if ($LASTEXITCODE -eq 0 -and $nv) { "
|
||||
" $gpus = @(); "
|
||||
" foreach ($line in $nv -split \"`n\") { "
|
||||
" $p = $line -split ','; "
|
||||
" if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
||||
" }; "
|
||||
" $r.gpu_name = $gpus[0].name; "
|
||||
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
|
||||
" $r.gpu_count = $gpus.Count; "
|
||||
" $r.gpu_backend = 'cuda'; "
|
||||
" } "
|
||||
"} catch {}; "
|
||||
"if (-not $r.gpu_name) { "
|
||||
" $wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1; "
|
||||
" if ($wmiGpu) { "
|
||||
" $r.gpu_name = $wmiGpu.Name; "
|
||||
" $r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1); "
|
||||
" $r.gpu_count = 1; "
|
||||
" $r.gpu_backend = 'cpu_x86'; " # WMI doesn't tell us CUDA/ROCm
|
||||
" } "
|
||||
"}; "
|
||||
"$r | ConvertTo-Json -Compress"
|
||||
try {
|
||||
$nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null
|
||||
if ($LASTEXITCODE -eq 0 -and $nv) {
|
||||
$gpus = @()
|
||||
foreach ($line in $nv -split "`n") {
|
||||
$p = $line -split ','
|
||||
if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name = $p[1].Trim(); vram_mb = [double]$p[0].Trim() } }
|
||||
}
|
||||
$r.gpu_name = $gpus[0].name
|
||||
$r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1)
|
||||
$r.gpu_count = $gpus.Count
|
||||
$r.gpu_backend = 'cuda'
|
||||
}
|
||||
}
|
||||
catch {}
|
||||
if (-not $r.gpu_name) {
|
||||
$wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1
|
||||
$GPUDriverKey = "HKLM:\\SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e968-e325-11ce-bfc1-08002be10318}\\0*"
|
||||
$GPUDeviceID = $wmiGpu.PNPDeviceID.Split('&')[0..1] -join '&'
|
||||
$VRAMfromRegistry = Get-ItemProperty -Path $GPUDriverKey |
|
||||
Where-Object { $_.MatchingDeviceId -like "${GPUDeviceID}*" } |
|
||||
# Sometimes there happen to be multiple driver classes for the same gpu.
|
||||
Select-Object -ExpandProperty HardwareInformation.qwMemorySize -ErrorAction SilentlyContinue -First 1
|
||||
if ($wmiGpu) {
|
||||
$r.gpu_name = $wmiGpu.Name
|
||||
# Edge case: driver is broken, otherwise $wmiGpu.AdapterRAM is redundant
|
||||
if ($VRAMfromRegistry -ge $wmiGpu.AdapterRAM) {
|
||||
$r.gpu_vram_gb = [math]::Round($VRAMfromRegistry / 1073741824, 1)
|
||||
}
|
||||
else {
|
||||
$r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1)
|
||||
}
|
||||
$r.gpu_count = 1
|
||||
# WMI doesn't tell us CUDA/ROCm
|
||||
$r.gpu_backend = 'cpu_x86';
|
||||
}
|
||||
}
|
||||
$r | ConvertTo-Json -Compress
|
||||
"""
|
||||
)
|
||||
if _remote_host:
|
||||
# Remote: ship a single command string over SSH. The remote shell parses
|
||||
@@ -566,6 +598,19 @@ def _detect_windows():
|
||||
_cache_by_host = {} # host -> (timestamp, result)
|
||||
|
||||
|
||||
def _cache_key(host: str, ssh_port: str, platform_name: str):
|
||||
"""Build a stable cache key that isolates remote SSH context.
|
||||
|
||||
Same host aliases can have different hardware due to visibility, forwarding etc.
|
||||
To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key.
|
||||
"""
|
||||
return (
|
||||
host or "_local",
|
||||
str(ssh_port or ""),
|
||||
str(platform_name or "").lower(),
|
||||
)
|
||||
|
||||
|
||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||
@@ -575,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||
"""
|
||||
global _remote_host, _remote_port, _remote_platform
|
||||
|
||||
cache_key = host or "_local"
|
||||
cache_key = _cache_key(host, ssh_port, platform)
|
||||
now = time.time()
|
||||
if not fresh and cache_key in _cache_by_host:
|
||||
ts, cached = _cache_by_host[cache_key]
|
||||
|
||||
@@ -192,11 +192,19 @@ def _fallback_memory_candidates(messages) -> list[dict]:
|
||||
if place:
|
||||
add(f"User lives in {place}.", "identity")
|
||||
|
||||
m = re.search(r"\bi (?:prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
||||
m = re.search(r"\bi (prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
||||
if m:
|
||||
preference = _clean_memory_value(m.group(1), 100)
|
||||
preference = _clean_memory_value(m.group(2), 100)
|
||||
if preference:
|
||||
add(f"User prefers {preference}.", "preference")
|
||||
# The same pattern catches likes and dislikes; keep the stored
|
||||
# sentiment faithful instead of recording every match as a
|
||||
# preference ("I hate cilantro" must not become "User prefers
|
||||
# cilantro").
|
||||
verb = m.group(1).lower()
|
||||
if verb in ("hate", "do not like", "don't like"):
|
||||
add(f"User dislikes {preference}.", "preference")
|
||||
else:
|
||||
add(f"User prefers {preference}.", "preference")
|
||||
|
||||
m = re.search(
|
||||
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
||||
@@ -228,6 +236,43 @@ def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) ->
|
||||
return False
|
||||
|
||||
|
||||
def _parse_extraction_json(raw: str) -> list:
|
||||
"""Parse the extraction LLM's reply into a list of facts, tolerating
|
||||
reasoning-model noise.
|
||||
|
||||
The model emits <think>…</think> (and sometimes a prose preamble or a
|
||||
```json fence) AROUND the JSON array; without stripping it, json.loads
|
||||
bombs and the run silently yields "0 candidates". Pure str -> list (no
|
||||
LLM/network); returns [] on any parse failure instead of raising.
|
||||
"""
|
||||
text = (raw or "").strip()
|
||||
try:
|
||||
from src.text_helpers import strip_think as _strip_think
|
||||
text = _strip_think(text, prose=True, prompt_echo=True).strip()
|
||||
except Exception:
|
||||
pass
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
# JSON may still be embedded in surrounding commentary (leading prose or
|
||||
# trailing remarks like "[...] Done!") — slice from the first '[' to the
|
||||
# last ']' whenever both exist. Slice unconditionally: a reply that starts
|
||||
# with '[' can still carry trailing commentary that breaks json.loads.
|
||||
_start = text.find("[")
|
||||
_end = text.rfind("]")
|
||||
if 0 <= _start < _end:
|
||||
text = text[_start : _end + 1]
|
||||
|
||||
try:
|
||||
facts = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||
return []
|
||||
except Exception:
|
||||
logger.debug("Memory extraction returned non-JSON: %r", (raw or "")[:120])
|
||||
return []
|
||||
return facts if isinstance(facts, list) else []
|
||||
|
||||
|
||||
async def extract_and_store(
|
||||
session,
|
||||
memory_manager,
|
||||
@@ -276,9 +321,34 @@ async def extract_and_store(
|
||||
|
||||
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
||||
|
||||
# Flatten the window into a SINGLE user message instead of appending the
|
||||
# raw alternating role messages. Passed as raw chat messages, the model
|
||||
# treats the window as a conversation to CONTINUE rather than a transcript
|
||||
# to ANALYZE, so it reliably extracts nothing — typically returning `[]`
|
||||
# (and, depending on the input, sometimes an empty or <think>-only
|
||||
# completion when the window ends on an assistant turn). This was the real
|
||||
# cause of auto-memory logging "0 candidates" on every run. Reframing it as
|
||||
# one "analyze this transcript, return the JSON array" user message makes
|
||||
# the model actually extract. Controlled repro on this model: 0/6 trials
|
||||
# with the old structure vs 6/6 with this one. The skill extractor flattens
|
||||
# for the same reason.
|
||||
def _flatten_msg(m):
|
||||
c = m.get("content", "")
|
||||
if isinstance(c, list):
|
||||
c = " ".join(
|
||||
b.get("text", "") for b in c
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
)
|
||||
return f"{m.get('role', '?')}: {c}"
|
||||
|
||||
transcript = "\n\n".join(_flatten_msg(m) for m in stripped_recent)
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
||||
] + stripped_recent
|
||||
{"role": "user", "content": (
|
||||
"Conversation to analyze:\n\n" + transcript
|
||||
+ "\n\nReturn the JSON array of durable facts now (or [] if none)."
|
||||
)},
|
||||
]
|
||||
|
||||
facts = []
|
||||
try:
|
||||
@@ -287,19 +357,20 @@ async def extract_and_store(
|
||||
model,
|
||||
extraction_messages,
|
||||
temperature=0.1,
|
||||
max_tokens=500,
|
||||
# A reasoning model spends most of its budget on <think> tokens
|
||||
# BEFORE emitting the JSON, so the old 500 truncated the response
|
||||
# before any JSON appeared → every run logged "0 candidates". The
|
||||
# audit path hit the same wall and raised to 16384; extraction's
|
||||
# output (a short facts list) is small, so an ample ceiling is
|
||||
# enough once thinking has room.
|
||||
max_tokens=4096,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Parse JSON from response (handle markdown fences if model wraps them)
|
||||
text = raw.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
|
||||
try:
|
||||
facts = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Memory extraction returned non-JSON")
|
||||
# Parse JSON, tolerating reasoning-model noise (<think> blocks, a
|
||||
# ```json fence, and leading/trailing commentary). See
|
||||
# _parse_extraction_json — returns [] rather than raising.
|
||||
facts = _parse_extraction_json(raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import os
|
||||
from .memory import MemoryManager
|
||||
from .memory_vector import MemoryVectorStore
|
||||
from src.memory_provider import MemoryRecord, NativeMemoryProvider
|
||||
from src.constants import DATA_DIR
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,7 +39,7 @@ class MemoryService:
|
||||
results = await service.recall("preferences")
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
def __init__(self, data_dir: str = DATA_DIR):
|
||||
self.manager = MemoryManager(data_dir)
|
||||
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
||||
os.path.join(data_dir, "memory_vectors")
|
||||
|
||||
@@ -63,6 +63,46 @@ def _has_duplicate_title(skills, title: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_json_object(text: str) -> Optional[dict]:
|
||||
"""Best-effort extraction of a JSON object from an LLM response.
|
||||
|
||||
The response may be wrapped in code fences or surrounded by prose, and some
|
||||
models emit a stray brace in the prose before the real object
|
||||
(e.g. "uses {placeholder} then {...}"). Slicing first-'{' .. last-'}' then
|
||||
grabs an unparseable span and the skill is silently lost. Try the whole
|
||||
string first, then each '{' start position in turn, returning the first
|
||||
candidate that parses to a JSON object (dict). Returns None if none do.
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
s = text.strip()
|
||||
if s.startswith("```"):
|
||||
s = s.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
end = s.rfind("}")
|
||||
if end == -1:
|
||||
return None
|
||||
|
||||
def _as_dict(candidate):
|
||||
try:
|
||||
obj = json.loads(candidate)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return None
|
||||
return obj if isinstance(obj, dict) else None
|
||||
|
||||
# The clean, common case: the whole (de-fenced) string is the object.
|
||||
obj = _as_dict(s)
|
||||
if obj is not None:
|
||||
return obj
|
||||
# Otherwise scan each '{' candidate up to the last '}'.
|
||||
start = s.find("{")
|
||||
while 0 <= start < end:
|
||||
obj = _as_dict(s[start : end + 1])
|
||||
if obj is not None:
|
||||
return obj
|
||||
start = s.find("{", start + 1)
|
||||
return None
|
||||
|
||||
|
||||
async def maybe_extract_skill(
|
||||
session,
|
||||
skills_manager,
|
||||
@@ -169,21 +209,14 @@ async def maybe_extract_skill(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Parse JSON
|
||||
text = response.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
# After strip_think, the JSON may still be embedded inside surrounding
|
||||
# commentary — slice from the first '{' to the matching last '}'.
|
||||
if text and text[0] != "{":
|
||||
_start = text.find("{")
|
||||
_end = text.rfind("}")
|
||||
if 0 <= _start < _end:
|
||||
text = text[_start : _end + 1]
|
||||
|
||||
data = json.loads(text)
|
||||
if not data or not isinstance(data, dict):
|
||||
logger.debug("[skill-extract] parsed JSON not a dict, dropping")
|
||||
# Parse JSON. The object may be wrapped in code fences or surrounded by
|
||||
# commentary (and may contain a stray/invalid brace fragment before
|
||||
# the real object — including one that makes the response itself look
|
||||
# like it starts with '{'), so use a tolerant extractor that tries the
|
||||
# whole string first and then each '{' candidate left-to-right.
|
||||
data = _extract_json_object(response)
|
||||
if not data:
|
||||
logger.debug("[skill-extract] no JSON object found in response, dropping")
|
||||
return None
|
||||
|
||||
title = data.get("title", "").strip()
|
||||
|
||||
@@ -0,0 +1,283 @@
|
||||
"""Import SKILL.md bundles from public GitHub (or skills.sh → GitHub) URLs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from src.url_safety import check_outbound_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILES = 64
|
||||
MAX_TOTAL_BYTES = 2_000_000
|
||||
MAX_FILE_BYTES = 400_000
|
||||
ALLOWED_SUFFIXES = (
|
||||
".md", ".txt", ".json", ".yaml", ".yml", ".py", ".sh", ".toml",
|
||||
".js", ".ts", ".css", ".html", ".xml", ".csv",
|
||||
)
|
||||
TEXT_NAMES = {"skill.md", "license", "license.md", "readme.md"}
|
||||
_GITHUB_HOSTS = frozenset({
|
||||
"github.com", "www.github.com", "api.github.com", "raw.githubusercontent.com",
|
||||
})
|
||||
|
||||
|
||||
def _github_host(url: str) -> str:
|
||||
return (urlparse(str(url)).hostname or "").lower()
|
||||
|
||||
|
||||
def _assert_github_url(url: str, *, context: str = "URL") -> None:
|
||||
host = _github_host(url)
|
||||
if host not in _GITHUB_HOSTS:
|
||||
raise SkillImportError(
|
||||
f"{context} must stay on GitHub (got {host or 'unknown host'})"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedSource:
|
||||
owner: str
|
||||
repo: str
|
||||
ref: str
|
||||
path: str # directory or file path inside repo (no leading slash)
|
||||
|
||||
|
||||
class SkillImportError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _safe_relpath(rel: str) -> str:
|
||||
rel = (rel or "").replace("\\", "/").strip().lstrip("/")
|
||||
if not rel or rel.startswith("..") or "/../" in f"/{rel}/":
|
||||
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||
parts = [p for p in rel.split("/") if p and p != "."]
|
||||
if any(p == ".." for p in parts):
|
||||
raise SkillImportError(f"unsafe path: {rel!r}")
|
||||
return "/".join(parts)
|
||||
|
||||
|
||||
def _is_text_file(name: str) -> bool:
|
||||
low = name.lower()
|
||||
if low in TEXT_NAMES:
|
||||
return True
|
||||
return any(low.endswith(s) for s in ALLOWED_SUFFIXES)
|
||||
|
||||
|
||||
def parse_skill_source(url: str) -> ResolvedSource:
|
||||
"""Normalize skills.sh / GitHub web URLs into owner/repo/ref/path."""
|
||||
raw = (url or "").strip()
|
||||
if not raw:
|
||||
raise SkillImportError("URL is required")
|
||||
|
||||
# skills.sh often links to GitHub; try to unwrap ?url= or redirect target later.
|
||||
if "skills.sh" in raw and "github.com" not in raw:
|
||||
ok, reason = check_outbound_url(raw)
|
||||
if not ok:
|
||||
raise SkillImportError(reason)
|
||||
with httpx.Client(follow_redirects=True, timeout=20.0) as client:
|
||||
r = client.get(raw)
|
||||
if r.status_code >= 400:
|
||||
raise _github_response_error(r)
|
||||
final = str(r.url)
|
||||
_assert_github_url(final, context="redirect target")
|
||||
# Page may embed a github link; prefer final URL if redirected.
|
||||
if "github.com" in final:
|
||||
raw = final
|
||||
else:
|
||||
m = re.search(r"https?://github\.com/[^\s\"')]+", r.text or "")
|
||||
if m:
|
||||
raw = m.group(0).rstrip(".,)")
|
||||
|
||||
parsed = urlparse(raw)
|
||||
host = _github_host(raw)
|
||||
if host not in _GITHUB_HOSTS:
|
||||
raise SkillImportError(
|
||||
"Only GitHub URLs are supported (https://github.com/... or raw.githubusercontent.com/...)"
|
||||
)
|
||||
|
||||
if host == "raw.githubusercontent.com":
|
||||
# /owner/repo/ref/path/to/file
|
||||
bits = [p for p in parsed.path.split("/") if p]
|
||||
if len(bits) < 4:
|
||||
raise SkillImportError("Invalid raw GitHub URL")
|
||||
owner, repo, ref = bits[0], bits[1], bits[2]
|
||||
path = "/".join(bits[3:])
|
||||
return ResolvedSource(owner=owner, repo=repo, ref=ref, path=path)
|
||||
|
||||
bits = [p for p in parsed.path.split("/") if p]
|
||||
if len(bits) < 2:
|
||||
raise SkillImportError("Invalid GitHub URL")
|
||||
owner, repo = bits[0], bits[1]
|
||||
ref = "main"
|
||||
path = ""
|
||||
|
||||
if len(bits) >= 4 and bits[2] in ("tree", "blob"):
|
||||
ref = bits[3]
|
||||
path = "/".join(bits[4:])
|
||||
elif len(bits) == 2:
|
||||
path = ""
|
||||
else:
|
||||
raise SkillImportError("GitHub URL must include /tree/<branch>/... or /blob/<branch>/...")
|
||||
|
||||
return ResolvedSource(owner=owner, repo=repo, ref=ref, path=path)
|
||||
|
||||
|
||||
def _raw_url(src: ResolvedSource, rel_path: str) -> str:
|
||||
rel = _safe_relpath(rel_path)
|
||||
return f"https://raw.githubusercontent.com/{src.owner}/{src.repo}/{quote(src.ref, safe='')}/{quote(rel, safe='/')}"
|
||||
|
||||
|
||||
def _api_contents_url(src: ResolvedSource, rel_path: str = "") -> str:
|
||||
rel = _safe_relpath(rel_path) if rel_path else ""
|
||||
base = f"https://api.github.com/repos/{src.owner}/{src.repo}/contents"
|
||||
if rel:
|
||||
base += f"/{quote(rel, safe='/')}"
|
||||
return f"{base}?ref={quote(src.ref, safe='')}"
|
||||
|
||||
|
||||
def _github_response_error(response: httpx.Response) -> SkillImportError:
|
||||
"""Turn a failed GitHub HTTP response into a user-visible import error."""
|
||||
status = response.status_code
|
||||
detail = ""
|
||||
try:
|
||||
body = response.json()
|
||||
if isinstance(body, dict):
|
||||
detail = str(body.get("message") or "").strip()
|
||||
except Exception:
|
||||
detail = (response.text or "").strip()[:200]
|
||||
|
||||
low = detail.lower()
|
||||
if status == 403 and "rate limit" in low:
|
||||
return SkillImportError(
|
||||
"GitHub API rate limit exceeded — try again in a bit"
|
||||
+ (f" ({detail})" if detail else "")
|
||||
)
|
||||
if status == 404:
|
||||
return SkillImportError("path not found on GitHub")
|
||||
if detail:
|
||||
return SkillImportError(f"GitHub request failed ({status}): {detail}")
|
||||
return SkillImportError(f"GitHub request failed ({status})")
|
||||
|
||||
|
||||
def _fetch_bytes(url: str) -> bytes:
|
||||
ok, reason = check_outbound_url(url)
|
||||
if not ok:
|
||||
raise SkillImportError(reason)
|
||||
with httpx.Client(follow_redirects=True, timeout=30.0) as client:
|
||||
r = client.get(url, headers={"Accept": "application/vnd.github+json"})
|
||||
if r.status_code >= 400:
|
||||
raise _github_response_error(r)
|
||||
_assert_github_url(str(r.url), context="redirect target")
|
||||
if len(r.content) > MAX_FILE_BYTES:
|
||||
raise SkillImportError(f"file too large: {url}")
|
||||
return r.content
|
||||
|
||||
|
||||
def _fetch_text(url: str) -> str:
|
||||
data = _fetch_bytes(url)
|
||||
try:
|
||||
return data.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
raise SkillImportError(f"non-text file: {url}") from e
|
||||
|
||||
|
||||
def _list_github_dir(src: ResolvedSource, rel_dir: str, out: Dict[str, str], *, depth: int = 0) -> None:
|
||||
if depth > 4 or len(out) >= MAX_FILES:
|
||||
return
|
||||
url = _api_contents_url(src, rel_dir)
|
||||
ok, reason = check_outbound_url(url)
|
||||
if not ok:
|
||||
raise SkillImportError(reason)
|
||||
with httpx.Client(follow_redirects=True, timeout=30.0) as client:
|
||||
r = client.get(url, headers={"Accept": "application/vnd.github+json"})
|
||||
if r.status_code >= 400:
|
||||
raise _github_response_error(r)
|
||||
_assert_github_url(str(r.url), context="redirect target")
|
||||
entries = r.json()
|
||||
if not isinstance(entries, list):
|
||||
raise SkillImportError("expected a directory on GitHub")
|
||||
total = sum(len(v.encode("utf-8")) for v in out.values())
|
||||
for ent in entries:
|
||||
if len(out) >= MAX_FILES or total >= MAX_TOTAL_BYTES:
|
||||
break
|
||||
if not isinstance(ent, dict):
|
||||
continue
|
||||
name = ent.get("name") or ""
|
||||
ent_type = ent.get("type")
|
||||
rel = _safe_relpath(f"{rel_dir}/{name}" if rel_dir else name)
|
||||
if ent_type == "dir":
|
||||
_list_github_dir(src, rel, out, depth=depth + 1)
|
||||
total = sum(len(v.encode("utf-8")) for v in out.values())
|
||||
continue
|
||||
if ent_type != "file" or not _is_text_file(name):
|
||||
continue
|
||||
dl = ent.get("download_url")
|
||||
if not dl:
|
||||
continue
|
||||
_assert_github_url(dl, context="download URL")
|
||||
text = _fetch_text(dl)
|
||||
total += len(text.encode("utf-8"))
|
||||
if total > MAX_TOTAL_BYTES:
|
||||
raise SkillImportError("skill bundle exceeds size limit")
|
||||
out[rel] = text
|
||||
|
||||
|
||||
def fetch_skill_bundle(url: str) -> Tuple[Dict[str, str], ResolvedSource]:
|
||||
"""Download SKILL.md and sibling text assets. Returns relative_path → content."""
|
||||
src = parse_skill_source(url)
|
||||
files: Dict[str, str] = {}
|
||||
|
||||
path = _safe_relpath(src.path) if src.path else ""
|
||||
if path.lower().endswith("skill.md"):
|
||||
files[path] = _fetch_text(_raw_url(src, path))
|
||||
parent = "/".join(path.split("/")[:-1])
|
||||
if parent:
|
||||
try:
|
||||
_list_github_dir(src, parent, files)
|
||||
except SkillImportError:
|
||||
pass
|
||||
return files, src
|
||||
|
||||
if path:
|
||||
try:
|
||||
_fetch_text(_raw_url(src, f"{path}/SKILL.md"))
|
||||
_list_github_dir(src, path, files)
|
||||
return files, src
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
text = _fetch_text(_raw_url(src, path))
|
||||
if path.lower().endswith(".md"):
|
||||
files[path] = text
|
||||
return files, src
|
||||
except Exception:
|
||||
pass
|
||||
_list_github_dir(src, path, files)
|
||||
else:
|
||||
_list_github_dir(src, "", files)
|
||||
|
||||
if not any(p.lower().endswith("skill.md") for p in files):
|
||||
# Flat repo root with SKILL.md only
|
||||
try:
|
||||
files["SKILL.md"] = _fetch_text(_raw_url(src, "SKILL.md"))
|
||||
except Exception as e:
|
||||
raise SkillImportError(
|
||||
"No SKILL.md found — link to a skill folder or SKILL.md on GitHub"
|
||||
) from e
|
||||
return files, src
|
||||
|
||||
|
||||
def pick_skill_md(files: Dict[str, str]) -> Tuple[str, str]:
|
||||
for rel, content in files.items():
|
||||
if rel.lower().endswith("skill.md"):
|
||||
return rel, content
|
||||
raise SkillImportError("bundle has no SKILL.md")
|
||||
|
||||
|
||||
def default_category_from_source(src: ResolvedSource) -> str:
|
||||
return "imported"
|
||||
@@ -381,6 +381,54 @@ class SkillsManager:
|
||||
|
||||
return sk.to_dict()
|
||||
|
||||
def import_bundle_from_files(
|
||||
self,
|
||||
files: Dict[str, str],
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
source_url: str = "",
|
||||
category: str = "imported",
|
||||
) -> Dict:
|
||||
"""Install a fetched skill bundle (relative path → text) under skills/."""
|
||||
from .skill_importer import SkillImportError, pick_skill_md, _safe_relpath
|
||||
from core.atomic_io import atomic_write_text
|
||||
|
||||
if not files:
|
||||
raise SkillImportError("empty bundle")
|
||||
_rel, skill_md = pick_skill_md(files)
|
||||
sk = Skill.from_markdown(skill_md)
|
||||
nm = slugify(sk.name or _rel.split("/")[-2] or "skill")
|
||||
cat = slugify(category or sk.category or "imported", fallback="imported")
|
||||
|
||||
existing = {s["name"] for s in self.load_all()}
|
||||
base = nm
|
||||
i = 2
|
||||
while nm in existing:
|
||||
nm = f"{base}-{i}"
|
||||
i += 1
|
||||
|
||||
skill_dir = self._skill_dir(cat, nm)
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
|
||||
# Preserve bundle layout (templates/, references/, etc.) under the skill dir.
|
||||
for rel, content in files.items():
|
||||
safe = _safe_relpath(rel)
|
||||
dest = os.path.join(skill_dir, safe)
|
||||
os.makedirs(os.path.dirname(dest), exist_ok=True)
|
||||
atomic_write_text(dest, content)
|
||||
|
||||
sk.name = nm
|
||||
sk.category = cat
|
||||
sk.owner = owner
|
||||
sk.source = "imported"
|
||||
if source_url:
|
||||
extra = (sk.body_extra or "").strip()
|
||||
note = f"Imported from {source_url}"
|
||||
sk.body_extra = f"{extra}\n\n{note}".strip() if extra else note
|
||||
atomic_write_text(self._skill_file(cat, nm), sk.to_markdown())
|
||||
sk.path = self._skill_file(cat, nm)
|
||||
return sk.to_dict()
|
||||
|
||||
def update_skill(self, skill_id: str, updates: Dict, owner: Optional[str] = None) -> bool:
|
||||
"""`skill_id` is the slug name. Allows updating any field plus
|
||||
renames if `name` changes (file is moved on disk).
|
||||
|
||||
@@ -15,10 +15,11 @@ from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.research_utils import is_low_quality
|
||||
from src.constants import DEEP_RESEARCH_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RESEARCH_DATA_DIR = Path("data/deep_research")
|
||||
RESEARCH_DATA_DIR = Path(DEEP_RESEARCH_DIR)
|
||||
|
||||
|
||||
class ResearchHandler:
|
||||
|
||||
@@ -6,21 +6,29 @@ from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.constants import DATA_DIR
|
||||
|
||||
from .cache import cache_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dedicated error logger with file handler
|
||||
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
|
||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||
_error_handler.setLevel(logging.WARNING)
|
||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||
# Dedicated error logger — write to the data logs directory (writable on both
|
||||
# native runs and Docker, where DATA_DIR resolves to the bind-mounted volume).
|
||||
_log_dir = Path(DATA_DIR) / "logs"
|
||||
_error_log_path = _log_dir / "search_engine_error.log"
|
||||
error_logger = logging.getLogger("search_engine_error")
|
||||
error_logger.addHandler(_error_handler)
|
||||
error_logger.propagate = False
|
||||
try:
|
||||
_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||
_error_handler.setLevel(logging.WARNING)
|
||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||
error_logger.addHandler(_error_handler)
|
||||
except Exception as _e:
|
||||
logging.getLogger(__name__).warning("search_engine_error log handler unavailable: %s", _e)
|
||||
|
||||
# Analytics file
|
||||
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
|
||||
# Analytics file — also in the writable logs volume.
|
||||
ANALYTICS_FILE = _log_dir / "search_analytics.json"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
@@ -6,17 +6,23 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from core.constants import DATA_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache directories
|
||||
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
|
||||
CACHE_DIR = Path(DATA_DIR) / "cache"
|
||||
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
||||
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
||||
CACHE_MAX_ENTRIES = 1000
|
||||
|
||||
# Create cache directories
|
||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
# Create cache directories. Guarded so an unwritable path (e.g. a read-only
|
||||
# mount) degrades to no-disk-cache instead of crashing module import.
|
||||
try:
|
||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as _e:
|
||||
logger.warning("Search cache directory unavailable (%s); disk cache disabled", _e)
|
||||
|
||||
# Track cache size for LRU eviction
|
||||
search_cache_index: Dict[str, datetime] = {}
|
||||
|
||||
@@ -259,6 +259,9 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_logger.warning(f"HTTP {e.response.status_code} fetching {url}: {e}")
|
||||
return _empty_result(url, f"HTTP {e.response.status_code}: {e}")
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
||||
return _empty_result(url, f"NetworkError: {e}")
|
||||
|
||||
@@ -76,6 +76,19 @@ def _domain(url: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _has_word(text: str, term: str) -> bool:
|
||||
"""True if ``term`` appears in ``text`` as a whole word.
|
||||
|
||||
Query terms are matched on word boundaries so a short term doesn't match
|
||||
inside an unrelated word: "us" must not match "business"/"music", "port"
|
||||
must not match "transport"/"support". This mirrors the tokenization used to
|
||||
build ``query_terms`` (``\\b\\w+\\b``). #1473 converted the title and sports
|
||||
checks to word boundaries; the snippet and subject-term checks below use
|
||||
the same helper so the whole file stays consistent.
|
||||
"""
|
||||
return re.search(rf"\b{re.escape(term)}\b", text) is not None
|
||||
|
||||
|
||||
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||
@@ -87,14 +100,14 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
if not title:
|
||||
return 0.0
|
||||
title_lc = title.lower()
|
||||
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
|
||||
matches = sum(1 for term in query_terms if _has_word(title_lc, term))
|
||||
return matches / len(query_terms) if query_terms else 0.0
|
||||
|
||||
def snippet_score(snippet: str) -> float:
|
||||
if not snippet:
|
||||
return 0.0
|
||||
length_factor = min(len(snippet), 200) / 200
|
||||
term_hits = sum(1 for term in query_terms if term in snippet.lower())
|
||||
term_hits = sum(1 for term in query_terms if _has_word(snippet.lower(), term))
|
||||
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
||||
return (length_factor + term_factor) / 2
|
||||
|
||||
@@ -127,7 +140,7 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
# A country/news query should not rank a page whose title/snippet barely
|
||||
# mentions the country above actual news pages for that country.
|
||||
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
||||
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
|
||||
if subject_terms and not any(_has_word(text, t) or _has_word(netloc, t) for t in subject_terms):
|
||||
adjustment -= 1.0
|
||||
return adjustment
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import httpx
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from src.constants import TTS_CACHE_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +37,7 @@ class TTSService:
|
||||
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/tts_cache"):
|
||||
def __init__(self, cache_dir: str = TTS_CACHE_DIR):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._kokoro = None # lazy-init
|
||||
|
||||
@@ -6,23 +6,30 @@ initial admin user. Safe to re-run (skips what already exists).
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
from src.constants import (
|
||||
DATA_DIR, AUTH_FILE, UPLOAD_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR,
|
||||
TTS_CACHE_DIR, GENERATED_IMAGES_DIR, DEEP_RESEARCH_DIR, CHROMA_DIR,
|
||||
RAG_DIR, MEMORY_VECTORS_DIR,
|
||||
)
|
||||
|
||||
DIRS = [
|
||||
DATA_DIR,
|
||||
os.path.join(DATA_DIR, "uploads"),
|
||||
os.path.join(DATA_DIR, "personal_docs"),
|
||||
os.path.join(DATA_DIR, "personal_uploads"),
|
||||
os.path.join(DATA_DIR, "tts_cache"),
|
||||
os.path.join(DATA_DIR, "generated_images"),
|
||||
os.path.join(DATA_DIR, "deep_research"),
|
||||
os.path.join(DATA_DIR, "chroma"),
|
||||
os.path.join(DATA_DIR, "rag"),
|
||||
os.path.join(DATA_DIR, "memory_vectors"),
|
||||
UPLOAD_DIR,
|
||||
PERSONAL_DIR,
|
||||
PERSONAL_UPLOADS_DIR,
|
||||
TTS_CACHE_DIR,
|
||||
GENERATED_IMAGES_DIR,
|
||||
DEEP_RESEARCH_DIR,
|
||||
CHROMA_DIR,
|
||||
RAG_DIR,
|
||||
MEMORY_VECTORS_DIR,
|
||||
os.path.join(BASE_DIR, "logs"),
|
||||
]
|
||||
|
||||
@@ -72,7 +79,7 @@ def _prompt_admin_credentials():
|
||||
|
||||
def create_default_admin():
|
||||
"""Create an initial admin user if none exists."""
|
||||
auth_path = os.path.join(DATA_DIR, "auth.json")
|
||||
auth_path = AUTH_FILE
|
||||
if os.path.exists(auth_path):
|
||||
print(" [skip] auth.json already exists")
|
||||
return "exists"
|
||||
@@ -117,7 +124,16 @@ def create_default_admin():
|
||||
print(f" Temporary password: {password}")
|
||||
print(f" ** Change it after first login. Set ODYSSEUS_ADMIN_PASSWORD to choose your own. **")
|
||||
return "created"
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
if "incompatible architecture" in str(e).lower():
|
||||
# bcrypt is present but built for the wrong CPU architecture — the
|
||||
# same Apple Silicon mismatch check_arch() guards against, caught here
|
||||
# for the rarer case of an x86 wheel inside an arm64 venv.
|
||||
print(" [error] bcrypt loaded with the wrong CPU architecture.")
|
||||
print(" Rebuild the venv with an arm64 Python:")
|
||||
print(" rm -rf venv && /opt/homebrew/bin/python3.11 -m venv venv")
|
||||
print(" ./venv/bin/pip install -r requirements.txt")
|
||||
return "skipped"
|
||||
print(" [warn] bcrypt not installed — skipping admin user creation")
|
||||
print(" Run: pip install bcrypt")
|
||||
return "skipped"
|
||||
@@ -167,9 +183,52 @@ def check_deps():
|
||||
print(" [ok] tmux installed")
|
||||
|
||||
|
||||
def check_arch():
|
||||
"""Stop early, with guidance, if we're on Apple Silicon but running an
|
||||
Intel (x86_64) Python through Rosetta.
|
||||
|
||||
A venv built with such an interpreter installs and loads compiled packages
|
||||
(bcrypt, pydantic-core, onnxruntime, …) for the wrong CPU architecture, then
|
||||
dies deep inside an import with a cryptic
|
||||
"(mach-o file, but is an incompatible architecture)" error. Catching it here
|
||||
turns that into one clear, actionable message.
|
||||
"""
|
||||
if sys.platform != "darwin" or platform.machine() == "arm64":
|
||||
return # Not macOS, or already an arm64-native interpreter — nothing to do.
|
||||
|
||||
# platform.machine() == "x86_64": either a genuine Intel Mac (fine) or an x86
|
||||
# interpreter running under Rosetta on Apple Silicon (the case we must catch).
|
||||
try:
|
||||
translated = subprocess.run(
|
||||
["sysctl", "-n", "sysctl.proc_translated"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
).stdout.strip()
|
||||
except Exception:
|
||||
translated = ""
|
||||
if translated != "1":
|
||||
return # Genuine Intel Mac — carry on.
|
||||
|
||||
print("\n [error] This is an Apple Silicon Mac, but setup is running under an")
|
||||
print(" Intel (x86_64) Python through Rosetta. Compiled packages would")
|
||||
print(' load as the wrong architecture and crash with "incompatible')
|
||||
print(' architecture" later on.')
|
||||
print("\n Rebuild the environment with Homebrew's arm64 Python:")
|
||||
print(" brew install python@3.11 # if you don't have it yet")
|
||||
print(" rm -rf venv")
|
||||
print(" /opt/homebrew/bin/python3.11 -m venv venv")
|
||||
print(" ./venv/bin/pip install -r requirements.txt")
|
||||
print(" ./venv/bin/python setup.py")
|
||||
print("\n Tip: ./start-macos.sh does all of this with the right Python.\n")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
print("\n=== Odysseus Setup ===\n")
|
||||
|
||||
# Fail fast with a clear message if the CPU architecture is wrong (Apple
|
||||
# Silicon under an x86/Rosetta Python) before importing anything native.
|
||||
check_arch()
|
||||
|
||||
print("1. Creating directories...")
|
||||
create_dirs()
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ _CALENDAR_ACTION = (
|
||||
r"delete|deleting|remove|removing|cancel|cancelling|canceling)"
|
||||
)
|
||||
_CALENDAR_THING = r"(?:calendar|calendar\s+(?:entry|item)|event|meeting|appointment|entry|call)"
|
||||
_CALENDAR_READ_THING = r"(?:calendar|schedule|events?|meetings?|appointments?|classes?)"
|
||||
_EXPLANATORY_PREFIX = re.compile(
|
||||
r"^\s*(?:how\s+(?:do|can)\s+i|can\s+you\s+explain|what\s+about|tell\s+me\s+how|show\s+me\s+how)\b",
|
||||
re.I,
|
||||
@@ -59,6 +60,14 @@ _ROUTING_PATTERNS: tuple[tuple[str, str, Pattern[str]], ...] = tuple(
|
||||
("calendar", "calendar target action request", rf"\b{_CALENDAR_ACTION}\b.{{0,120}}\b(?:to|on|in|into|for)\s+(?:my\s+|the\s+|this\s+)?calendar\b"),
|
||||
("calendar", "put item on calendar request", r"\bput\s+.+\bon\s+(?:my\s+)?calendar\b"),
|
||||
|
||||
# Calendar/event lookup. A question such as "Do I have Taekwondo
|
||||
# classes this week?" needs the calendar tool; plain chat cannot know.
|
||||
("calendar", "calendar lookup request", rf"\b(?:list|show|check|find)\b.{{0,120}}\b(?:my\s+|the\s+)?(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||
("calendar", "calendar lookup question", rf"\b(?:what|which)\b.{{0,120}}\b(?:upcoming|next|today'?s?|tomorrow'?s?|this\s+week'?s?)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||
("calendar", "calendar availability question", rf"\bdo\s+i\s+have\b.{{0,120}}\b(?:upcoming|next|today|tomorrow|this\s+week)\b.{{0,120}}\b{_CALENDAR_READ_THING}\b"),
|
||||
("calendar", "calendar agenda question", r"\bwhat(?:'s| is)\s+on\s+(?:my\s+)?calendar\b"),
|
||||
("calendar", "next calendar item question", r"\bwhen\s+(?:is|are)\s+(?:my\s+)?next\s+(?:event|meeting|appointment|class)\b"),
|
||||
|
||||
# Notes, todos, checklists, and reminders.
|
||||
("notes", "reminder request", r"\bremind\s+me\b"),
|
||||
("notes", "assistant note/todo action request", rf"{_ACTION_QUESTION}(?:add|create|make|take|jot|write\s+down|set)\b.{{0,120}}\b(?:note|todo|task|checklist|reminder)\b"),
|
||||
|
||||
+683
-198
File diff suppressed because it is too large
Load Diff
+5
-31
@@ -14,16 +14,17 @@ Sub-modules:
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from src.tool_utils import _truncate, get_mcp_manager, set_mcp_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants (kept here — sub-modules import from here)
|
||||
# Constants (re-exported for backward compatibility — single source of truth
|
||||
# is src.constants; always prefer importing from there for new code)
|
||||
# ---------------------------------------------------------------------------
|
||||
MAX_AGENT_ROUNDS = 50
|
||||
SHELL_TIMEOUT = 60
|
||||
PYTHON_TIMEOUT = 30
|
||||
MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_READ_CHARS = 20_000
|
||||
|
||||
# Tool types that trigger execution
|
||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||
@@ -34,7 +35,7 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
||||
"send_to_session",
|
||||
"pipeline",
|
||||
"manage_session", "manage_memory", "list_models",
|
||||
"ui_control", "generate_image",
|
||||
"ui_control", "generate_image", "ask_user", "update_plan",
|
||||
"manage_tasks", "api_call", "ask_teacher", "manage_skills",
|
||||
"suggest_document",
|
||||
"manage_endpoints", "manage_mcp", "manage_webhooks",
|
||||
@@ -63,33 +64,6 @@ TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_fi
|
||||
|
||||
ToolBlock = namedtuple("ToolBlock", ["tool_type", "content"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP Manager (kept here — used by execution and agent_loop)
|
||||
# ---------------------------------------------------------------------------
|
||||
_mcp_manager = None
|
||||
|
||||
def set_mcp_manager(manager):
|
||||
"""Set the global MCP manager instance."""
|
||||
global _mcp_manager
|
||||
_mcp_manager = manager
|
||||
|
||||
def get_mcp_manager():
|
||||
"""Get the global MCP manager instance."""
|
||||
return _mcp_manager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (kept here — used by sub-modules)
|
||||
# ---------------------------------------------------------------------------
|
||||
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
||||
# Callers treat the result as text, so always return a string: coerce a
|
||||
# non-string (None -> "", otherwise str(...)) instead of returning it raw,
|
||||
# which would just move the crash downstream.
|
||||
if not isinstance(text, str):
|
||||
text = "" if text is None else str(text)
|
||||
if len(text) > limit:
|
||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
||||
return text
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Re-exports from sub-modules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
+44
-28
@@ -14,6 +14,8 @@ import uuid
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from src.constants import GENERATED_IMAGES_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AI_CHAT_TIMEOUT = 120 # seconds for a single LLM call
|
||||
@@ -55,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||
|
||||
|
||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||
@@ -96,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic: match against hardcoded model list
|
||||
@@ -112,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
else:
|
||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = []
|
||||
|
||||
@@ -1119,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
||||
total_models = 0
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
model_ids = []
|
||||
if provider == "anthropic":
|
||||
model_ids = list(ANTHROPIC_MODELS)
|
||||
else:
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = ["(endpoint offline)"]
|
||||
|
||||
@@ -1268,7 +1284,7 @@ async def do_ui_control(content: str, session_id: Optional[str] = None, owner: O
|
||||
toggle <name> <on|off> — Toggle a setting (web, bash, rag, research, incognito, document_editor)
|
||||
set_mode <agent|chat> — Switch between agent and chat mode
|
||||
switch_model <model> — Change the model for the current session
|
||||
set_theme <preset> — Apply a theme preset (dark, light, paper, nord, dracula, gruvbox, gpt, claude, lavender, etc.)
|
||||
set_theme <preset> — Apply a built-in theme preset (dark, light, midnight, paper, cyberpunk, retrowave, forest, ocean, ume, copper, terminal, organs, lavender, gpt, claude, cute)
|
||||
create_theme <name> <bg> <fg> <panel> <border> <accent> [key=val ...] — Create custom theme. Optional key=val: advanced color overrides AND background effects: bgPattern=<none|dots|synapse|rain|constellations|perlin-flow|petals|sparkles|embers>, bgEffectColor=#RRGGBB, bgEffectIntensity=<num>, bgEffectSize=<num>, frosted=true|false
|
||||
open_panel <name> — Open a panel (documents, gallery, email, sessions, notes, memories, skills, settings, cookbook)
|
||||
open_email_reply <uid> [folder] [reply|reply-all|ai-reply] — Open a reply draft document for an email; does not send
|
||||
@@ -1715,7 +1731,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
||||
|
||||
# GPT image models always return b64_json; DALL-E may return url
|
||||
if img.get("b64_json"):
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||
img_path = img_dir / filename
|
||||
@@ -1728,7 +1744,7 @@ async def do_generate_image(content: str, session_id: Optional[str] = None, owne
|
||||
try:
|
||||
dl_resp = httpx.get(img["url"], timeout=60)
|
||||
if dl_resp.status_code == 200:
|
||||
img_dir = Path("data/generated_images")
|
||||
img_dir = Path(GENERATED_IMAGES_DIR)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{uuid.uuid4().hex[:12]}.png"
|
||||
img_path = img_dir / filename
|
||||
|
||||
+22
-1
@@ -10,7 +10,7 @@ def get_current_user(request: Request) -> Optional[str]:
|
||||
return getattr(request.state, 'current_user', None)
|
||||
|
||||
|
||||
def effective_user(request: Request):
|
||||
def effective_user(request: Request) -> Optional[str]:
|
||||
"""The real human behind the request, for ownership/attribution.
|
||||
|
||||
Cookie sessions resolve to the logged-in username. Bearer ``ody_`` callers
|
||||
@@ -34,6 +34,24 @@ def effective_user(request: Request):
|
||||
return get_current_user(request)
|
||||
|
||||
|
||||
def _is_api_token_request(request: Request) -> bool:
|
||||
"""Return True when middleware authenticated a bearer API token."""
|
||||
return bool(getattr(request.state, "api_token", False))
|
||||
|
||||
|
||||
def require_authenticated_request(request: Request) -> str:
|
||||
"""Allow either a browser session or a valid bearer API token.
|
||||
|
||||
This is intentionally narrower than :func:`require_user`: use it only for
|
||||
routes that need authentication but do not read or mutate owner-scoped
|
||||
user data. Owner-scoped routes should use ``require_user`` for browser
|
||||
sessions or their own API-token scope/owner gate.
|
||||
"""
|
||||
if _is_api_token_request(request):
|
||||
return effective_user(request) or ""
|
||||
return require_user(request)
|
||||
|
||||
|
||||
def _auth_disabled() -> bool:
|
||||
"""True when the operator has explicitly turned off auth via .env.
|
||||
Mirrors the AUTH_ENABLED parse in app.py / core/middleware.py so the
|
||||
@@ -60,6 +78,9 @@ def require_user(request: Request) -> str:
|
||||
Use this on routes that touch user data so middleware misconfig can't
|
||||
open them up.
|
||||
"""
|
||||
if _is_api_token_request(request):
|
||||
raise HTTPException(403, "API tokens must use a scope-aware API route")
|
||||
|
||||
u = get_current_user(request)
|
||||
if u:
|
||||
return u
|
||||
|
||||
+6
-4
@@ -33,13 +33,15 @@ from core.atomic_io import atomic_write_json
|
||||
from core.platform_compat import (
|
||||
detached_popen_kwargs,
|
||||
find_bash,
|
||||
git_bash_path,
|
||||
kill_process_tree,
|
||||
pid_alive,
|
||||
)
|
||||
|
||||
_DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
_JOBS_DIR = _DATA_DIR / "bg_jobs"
|
||||
_STORE = _DATA_DIR / "bg_jobs.json"
|
||||
from src.constants import BG_JOBS_DIR, BG_JOBS_FILE
|
||||
|
||||
_JOBS_DIR = Path(BG_JOBS_DIR)
|
||||
_STORE = Path(BG_JOBS_FILE)
|
||||
|
||||
# A job that runs longer than this is presumed stuck and reaped (the agent
|
||||
# still gets a "timed out" follow-up so nothing hangs forever).
|
||||
@@ -106,7 +108,7 @@ def launch(command: str, session_id: str, cwd: Optional[str] = None,
|
||||
# handles drive paths and spaces correctly.
|
||||
cmd_path = _JOBS_DIR / f"{job_id}.cmd.sh"
|
||||
cmd_path.write_text(command + "\n", encoding="utf-8")
|
||||
lp, xp, cp = (shlex.quote(p.as_posix()) for p in (log_path, exit_path, cmd_path))
|
||||
lp, xp, cp = (shlex.quote(git_bash_path(p)) for p in (log_path, exit_path, cmd_path))
|
||||
script_path = _JOBS_DIR / f"{job_id}.sh"
|
||||
script_path.write_text(
|
||||
f"bash {cp} > {lp} 2>&1\n"
|
||||
|
||||
+34
-23
@@ -12,6 +12,8 @@ from typing import Tuple
|
||||
|
||||
from src.auth_helpers import owner_filter
|
||||
from core.platform_compat import IS_WINDOWS, find_bash
|
||||
from core.constants import internal_api_base
|
||||
from src.constants import DATA_DIR, DEEP_RESEARCH_DIR, TIDY_CALENDAR_STATE_FILE, EMAIL_URGENCY_CACHE_DIR, COOKBOOK_STATE_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -166,7 +168,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
drop_items = decision.get("drop") if isinstance(decision, dict) else None
|
||||
if isinstance(keep_items, list) and isinstance(drop_items, list):
|
||||
by_id = {m.get("id"): m for m in group_memories if m.get("id")}
|
||||
keep_ids = set()
|
||||
cleaned_by_id = {}
|
||||
for item in keep_items:
|
||||
if not isinstance(item, dict):
|
||||
@@ -177,7 +178,6 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
text = (item.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
keep_ids.add(mid)
|
||||
cleaned = {
|
||||
"category": (item.get("category") or by_id[mid].get("category") or "fact").strip(),
|
||||
}
|
||||
@@ -186,11 +186,20 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
cleaned["text"] = text
|
||||
cleaned_by_id[mid] = cleaned
|
||||
|
||||
# If the model only saw a truncated memory, do not let
|
||||
# that partial view delete or rewrite the full memory.
|
||||
keep_ids.update(mid for mid in truncated_ids if mid in by_id)
|
||||
# Delete only memories the model EXPLICITLY dropped, never
|
||||
# ones it merely omitted from `keep`. Treating the
|
||||
# complement of `keep` as deletions meant a model that
|
||||
# forgot to re-list an id (common) silently destroyed that
|
||||
# memory. Honor the explicit `drop` set instead.
|
||||
drop_ids = {
|
||||
d.get("id")
|
||||
for d in drop_items
|
||||
if isinstance(d, dict) and d.get("id") in by_id
|
||||
}
|
||||
# Never delete a memory the model only saw truncated.
|
||||
drop_ids -= truncated_ids
|
||||
|
||||
if keep_ids:
|
||||
if drop_ids or cleaned_by_id:
|
||||
changed_text = 0
|
||||
group_ref_ids = {id(m) for m in group_memories}
|
||||
kept_all = []
|
||||
@@ -199,7 +208,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
kept_all.append(mem)
|
||||
continue
|
||||
mid = mem.get("id")
|
||||
if mid not in keep_ids:
|
||||
if mid in drop_ids:
|
||||
continue
|
||||
cleaned = cleaned_by_id.get(mid) or {}
|
||||
if mid in truncated_ids:
|
||||
@@ -211,7 +220,7 @@ async def action_consolidate_memory(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
mem["category"] = cleaned["category"]
|
||||
kept_all.append(mem)
|
||||
|
||||
removed = len(group_memories) - len(keep_ids)
|
||||
removed = sum(1 for m in group_memories if m.get("id") in drop_ids)
|
||||
total_scanned += len(group_memories)
|
||||
if removed or changed_text:
|
||||
all_memories = kept_all
|
||||
@@ -348,7 +357,7 @@ async def action_tidy_research(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
try:
|
||||
from pathlib import Path
|
||||
import json as _json
|
||||
research_dir = Path("data/deep_research")
|
||||
research_dir = Path(DEEP_RESEARCH_DIR)
|
||||
if not research_dir.exists():
|
||||
raise TaskNoop("no research directory")
|
||||
files = list(research_dir.glob("*.json"))
|
||||
@@ -386,7 +395,7 @@ async def action_tidy_calendar(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
from core.database import SessionLocal, CalendarEvent
|
||||
from sqlalchemy import func
|
||||
|
||||
STATE_FILE = Path("data/tidy_calendar_state.json")
|
||||
STATE_FILE = Path(TIDY_CALENDAR_STATE_FILE)
|
||||
last_watermark = None
|
||||
try:
|
||||
if STATE_FILE.exists():
|
||||
@@ -593,9 +602,9 @@ async def action_classify_events(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
if not events:
|
||||
return "No upcoming events to classify", True
|
||||
|
||||
llm_url, llm_model, llm_headers = resolve_endpoint("utility")
|
||||
llm_url, llm_model, llm_headers = resolve_endpoint("utility", owner=owner)
|
||||
if not llm_url:
|
||||
llm_url, llm_model, llm_headers = resolve_endpoint("default")
|
||||
llm_url, llm_model, llm_headers = resolve_endpoint("default", owner=owner)
|
||||
llm_available = bool(llm_url and llm_model)
|
||||
|
||||
# Pull user memories so the LLM has personal context (relationships,
|
||||
@@ -867,9 +876,9 @@ async def action_learn_sender_signatures(owner: str, **kwargs) -> Tuple[str, boo
|
||||
if not eligible:
|
||||
return "All sender sigs already cached (or no eligible senders)", True
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||
if not url or not model:
|
||||
return "No LLM endpoint available", False
|
||||
|
||||
@@ -1303,12 +1312,12 @@ async def action_ping_notes(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
# users' entries (review C4). Legacy path kept as fallback so a
|
||||
# single-user install (empty owner) doesn't lose its history.
|
||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
STATE = _P(f"data/note_pings_{_owner_slug}.json")
|
||||
STATE = _P(DATA_DIR) / f"note_pings_{_owner_slug}.json"
|
||||
STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||
# One-time migration: if legacy global file exists and per-owner file
|
||||
# doesn't, seed from global (entries for OTHER owners still get pruned
|
||||
# on their first run — acceptable, prevents silent loss).
|
||||
_legacy = _P("data/note_pings.json")
|
||||
_legacy = _P(DATA_DIR) / "note_pings.json"
|
||||
if _legacy.exists() and not STATE.exists():
|
||||
try:
|
||||
STATE.write_text(_legacy.read_text(encoding="utf-8"), encoding="utf-8")
|
||||
@@ -1465,8 +1474,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
# notified_uids / urgency counts. Empty owner falls back to a generic
|
||||
# filename for single-user installs (matches prior behaviour).
|
||||
_owner_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
STATE_PATH = _P(f"data/email_urgency_state_{_owner_slug}.json")
|
||||
CACHE_DIR = _P("data/email_urgency_cache")
|
||||
STATE_PATH = _P(DATA_DIR) / f"email_urgency_state_{_owner_slug}.json"
|
||||
CACHE_DIR = _P(EMAIL_URGENCY_CACHE_DIR)
|
||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
AGE_CUTOFF = _dt.utcnow() - _td(days=7)
|
||||
@@ -1480,12 +1489,12 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
|
||||
# ── 1. Resolve LLM candidates (utility primary + utility fallbacks; fall
|
||||
# through to default chat as a last resort).
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if not url or not model:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||
if not url or not model:
|
||||
return "No LLM endpoint available", False
|
||||
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates()
|
||||
candidates = [(url, model, headers)] + resolve_utility_fallback_candidates(owner=owner)
|
||||
|
||||
# ── 2. Enumerate enabled accounts. Match this task's owner AND fall
|
||||
# back to the legacy "unowned account whose imap_user / from_address
|
||||
@@ -1902,6 +1911,8 @@ async def action_check_email_urgency(owner: str, **kwargs) -> Tuple[str, bool]:
|
||||
delivered = bool(dispatch_result.get("email_sent"))
|
||||
elif channel == "ntfy":
|
||||
delivered = bool(dispatch_result.get("ntfy_sent"))
|
||||
elif channel == "webhook":
|
||||
delivered = bool(dispatch_result.get("webhook_sent"))
|
||||
if delivered:
|
||||
newly_notified.update(new_urgent)
|
||||
else:
|
||||
@@ -2040,7 +2051,7 @@ async def action_cookbook_serve(
|
||||
except Exception:
|
||||
end_after_min = 0
|
||||
|
||||
state_path = Path("/app/data/cookbook_state.json")
|
||||
state_path = Path(COOKBOOK_STATE_FILE)
|
||||
try:
|
||||
state = json.loads(state_path.read_text(encoding="utf-8")) if state_path.exists() else {}
|
||||
except Exception:
|
||||
@@ -2116,7 +2127,7 @@ async def action_cookbook_serve(
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
r = await client.post("http://localhost:7000/api/model/serve",
|
||||
r = await client.post(f"{internal_api_base()}/api/model/serve",
|
||||
json=body, headers=headers)
|
||||
data = r.json() if r.content else {}
|
||||
except Exception as e:
|
||||
|
||||
+241
-51
@@ -27,6 +27,7 @@ import hashlib
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
@@ -50,15 +51,55 @@ def _private_caldav_allowed() -> bool:
|
||||
return os.environ.get("ODYSSEUS_ALLOW_PRIVATE_CALDAV", "0").lower() in {"1", "true", "yes"}
|
||||
|
||||
|
||||
def _validate_caldav_address(addr: ipaddress._BaseAddress) -> None:
|
||||
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
|
||||
addr = addr.ipv4_mapped
|
||||
if (
|
||||
addr.is_loopback
|
||||
or addr.is_link_local
|
||||
or addr.is_multicast
|
||||
or addr.is_unspecified
|
||||
or addr.is_reserved
|
||||
):
|
||||
raise ValueError("CalDAV URL host is not allowed")
|
||||
if addr.is_private and not _private_caldav_allowed():
|
||||
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
||||
|
||||
|
||||
def _validate_caldav_ip(host: str) -> None:
|
||||
try:
|
||||
ip = ipaddress.ip_address(host.strip("[]"))
|
||||
except ValueError:
|
||||
return
|
||||
if ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_unspecified:
|
||||
raise ValueError("CalDAV URL host is not allowed")
|
||||
if ip.is_private and not _private_caldav_allowed():
|
||||
raise ValueError("Private CalDAV IPs require ODYSSEUS_ALLOW_PRIVATE_CALDAV=1")
|
||||
_validate_caldav_address(ip)
|
||||
|
||||
|
||||
def _resolve_caldav_host_ips(host: str) -> list[ipaddress._BaseAddress]:
|
||||
addrs: list[ipaddress._BaseAddress] = []
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(host, None):
|
||||
if family not in (socket.AF_INET, socket.AF_INET6):
|
||||
continue
|
||||
try:
|
||||
addrs.append(ipaddress.ip_address(sockaddr[0].split("%", 1)[0]))
|
||||
except ValueError:
|
||||
continue
|
||||
return addrs
|
||||
|
||||
|
||||
def _validate_caldav_hostname(host: str) -> None:
|
||||
try:
|
||||
ipaddress.ip_address(host.strip("[]"))
|
||||
return
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
addrs = _resolve_caldav_host_ips(host)
|
||||
except OSError:
|
||||
raise ValueError("CalDAV URL host does not resolve")
|
||||
if not addrs:
|
||||
raise ValueError("CalDAV URL host does not resolve")
|
||||
for addr in addrs:
|
||||
_validate_caldav_address(addr)
|
||||
|
||||
|
||||
def validate_caldav_url(raw_url: str) -> str:
|
||||
@@ -83,15 +124,18 @@ def validate_caldav_url(raw_url: str) -> str:
|
||||
if host in _BLOCKED_HOSTS or host.endswith(".localhost"):
|
||||
raise ValueError("CalDAV URL host is not allowed")
|
||||
_validate_caldav_ip(host)
|
||||
_validate_caldav_hostname(host)
|
||||
return urlunparse(parsed._replace(fragment="")).rstrip("/")
|
||||
|
||||
|
||||
def _stable_cal_id(remote_url: str, owner: str = "") -> str:
|
||||
"""Deterministic local id for a remote CalDAV calendar — same URL
|
||||
always maps to the same local row across restarts and re-syncs.
|
||||
Owner is included in the hash to prevent PK collisions when multiple
|
||||
users sync the same CalDAV endpoint."""
|
||||
h = hashlib.sha256(f"{owner}:{remote_url}".encode("utf-8")).hexdigest()[:24]
|
||||
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||
"""Deterministic local id for a remote CalDAV calendar, scoped to owner
|
||||
and account so two users — or one user with two accounts — pointing at
|
||||
the same server URL get distinct local rows (avoids PK collision, #2765).
|
||||
The owner and account_id default to "" for the legacy/URL-only path so
|
||||
existing callers without those arguments keep working."""
|
||||
key = f"{owner}\n{account_id}\n{remote_url}"
|
||||
h = hashlib.sha256(key.encode("utf-8")).hexdigest()[:24]
|
||||
return f"caldav-{h}"
|
||||
|
||||
|
||||
@@ -126,18 +170,103 @@ def _find_existing_event(db, pending, uid_val, calendar_id):
|
||||
).first()
|
||||
|
||||
|
||||
def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
def _google_caldav_events_url(url: str) -> str | None:
|
||||
"""Map a Google CalDAV *principal* URL to its event-collection URL.
|
||||
|
||||
Google serves the principal at ``…/user`` but events live under ``…/events``
|
||||
— the ``/user`` resource holds no VEVENTs. The `caldav` library's
|
||||
principal→home-set discovery does not reliably enumerate calendars from
|
||||
Google's ``/user`` endpoint, so the sync falls into the "treat the URL as a
|
||||
single calendar" fallback below. Pointed at ``/user`` that fallback issues
|
||||
every calendar-query REPORT against the principal, which returns a clean but
|
||||
empty 200 for all date ranges — the calendar shows no events even though
|
||||
auth succeeded (issue #2507).
|
||||
|
||||
Both Google CalDAV endpoint forms are handled, since some accounts only
|
||||
authenticate against one of them:
|
||||
- newer: ``https://apidata.googleusercontent.com/caldav/v2/<id>/user``
|
||||
- legacy: ``https://www.google.com/calendar/dav/<id>/user``
|
||||
|
||||
Returns the events URL for a recognised Google principal URL, else None so
|
||||
the caller keeps the original URL unchanged.
|
||||
"""
|
||||
parts = urlparse(url)
|
||||
host = (parts.hostname or "").lower()
|
||||
path = parts.path.rstrip("/")
|
||||
if not path.endswith("/user"):
|
||||
return None
|
||||
is_google = (
|
||||
host.endswith("googleusercontent.com") # newer /caldav/v2 form
|
||||
or (host in ("www.google.com", "google.com") and "/calendar/dav/" in path) # legacy form
|
||||
)
|
||||
if not is_google:
|
||||
return None
|
||||
new_path = path[: -len("/user")] + "/events"
|
||||
return urlunparse(parts._replace(path=new_path))
|
||||
|
||||
|
||||
def _open_url_as_calendar(client, url: str):
|
||||
"""Open ``url`` as a single calendar collection.
|
||||
|
||||
Used when principal discovery yields no calendars. Google's principal URL
|
||||
is not an event collection, so map it to the events URL first
|
||||
(see ``_google_caldav_events_url``); other servers' URLs are used as-is.
|
||||
"""
|
||||
target = _google_caldav_events_url(url) or url
|
||||
return client.calendar(url=target)
|
||||
|
||||
|
||||
def _build_dav_client(url: str, username: str, password: str):
|
||||
"""Construct a CalDAV client with automatic redirects disabled.
|
||||
|
||||
``validate_caldav_url`` resolves and vets the *initial* host, but caldav's
|
||||
underlying HTTP session follows 3xx redirects by default. So a URL that
|
||||
passes validation can still be redirected — at request time — to
|
||||
loopback / link-local / private space, re-opening the SSRF the host check
|
||||
closes. Pin the session to zero redirects: any 3xx then raises instead of
|
||||
silently following an attacker-chosen ``Location``. This mirrors the
|
||||
test-connection path in ``routes/calendar_routes.py``, which already sets
|
||||
``follow_redirects=False``.
|
||||
|
||||
DAVClient exposes no per-request redirect flag, so we set it on the session
|
||||
after construction (the session is created in ``__init__``).
|
||||
"""
|
||||
import caldav
|
||||
|
||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
||||
# Unconditional: a redirect-disable that only sometimes applies is not a
|
||||
# control. The session exists right after __init__ on every real client;
|
||||
# test_build_dav_client_disables_redirects asserts it against installed
|
||||
# caldav in CI.
|
||||
client.session.max_redirects = 0
|
||||
return client
|
||||
|
||||
|
||||
def _should_prune_window(seen_uids: set, parse_failed: bool) -> bool:
|
||||
"""Whether the post-sync prune of vanished CalDAV events is safe to run.
|
||||
|
||||
The prune deletes local ``origin=="caldav"`` rows in the window whose UID the
|
||||
server did not just return. Any parse failure (total or partial) makes
|
||||
``seen_uids`` an incomplete view of the server, so pruning against it can
|
||||
delete events that still exist upstream but could not be read: a total
|
||||
failure wipes the whole window, a partial failure deletes just the
|
||||
unreadable ones. Only prune on a clean read. An empty ``seen_uids`` after a
|
||||
clean read is a genuinely empty window, which is safe to prune.
|
||||
"""
|
||||
return not parse_failed
|
||||
|
||||
|
||||
def _sync_blocking(owner: str, url: str, username: str, password: str, account_id: str = "") -> dict:
|
||||
"""The actual sync — synchronous, intended to run in a threadpool.
|
||||
Returns counts: {calendars, events, deleted, errors}."""
|
||||
# Lazy imports so a missing `caldav` dep doesn't break app startup —
|
||||
# the integrations form still works, sync just no-ops with an error.
|
||||
import caldav
|
||||
from caldav.lib.error import AuthorizationError, NotFoundError
|
||||
from core.database import CalendarCal, CalendarEvent, SessionLocal
|
||||
|
||||
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
||||
|
||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
||||
client = _build_dav_client(url, username, password)
|
||||
|
||||
# Discovery: try principal → calendars first; if the server doesn't
|
||||
# support discovery (or the URL points directly at a calendar), fall
|
||||
@@ -152,14 +281,14 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
except Exception as e:
|
||||
logger.info(f"CalDAV principal discovery failed, trying URL as calendar: {e}")
|
||||
try:
|
||||
calendars = [client.calendar(url=url)]
|
||||
calendars = [_open_url_as_calendar(client, url)]
|
||||
except Exception as e2:
|
||||
result["errors"].append(f"Could not open URL as calendar: {e2}")
|
||||
return result
|
||||
|
||||
if not calendars:
|
||||
try:
|
||||
calendars = [client.calendar(url=url)]
|
||||
calendars = [_open_url_as_calendar(client, url)]
|
||||
except Exception as e:
|
||||
result["errors"].append(f"No calendars and URL fallback failed: {e}")
|
||||
return result
|
||||
@@ -172,7 +301,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
for remote_cal in calendars:
|
||||
try:
|
||||
remote_url = str(remote_cal.url)
|
||||
cal_id = _stable_cal_id(remote_url, owner)
|
||||
cal_id = _stable_cal_id(remote_url, owner=owner, account_id=account_id)
|
||||
display_name = (remote_cal.name or "").strip() or "CalDAV"
|
||||
|
||||
local_cal = db.query(CalendarCal).filter(
|
||||
@@ -186,14 +315,20 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
name=display_name,
|
||||
color="#5b8abf",
|
||||
source="caldav",
|
||||
account_id=account_id or None,
|
||||
)
|
||||
db.add(local_cal)
|
||||
db.commit()
|
||||
else:
|
||||
# Refresh the display name if the user renamed it
|
||||
# remotely; preserve any local color override.
|
||||
# Refresh display name and stamp account_id if missing.
|
||||
changed = False
|
||||
if local_cal.name != display_name:
|
||||
local_cal.name = display_name
|
||||
changed = True
|
||||
if account_id and not local_cal.account_id:
|
||||
local_cal.account_id = account_id
|
||||
changed = True
|
||||
if changed:
|
||||
db.commit()
|
||||
result["calendars"] += 1
|
||||
|
||||
@@ -207,6 +342,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
# duplicate UIDs within the same batch are updated, not re-inserted
|
||||
# (which would violate the UNIQUE constraint on commit).
|
||||
pending: dict = {}
|
||||
parse_failed = False
|
||||
try:
|
||||
objs = remote_cal.date_search(start=start, end=end, expand=False)
|
||||
except Exception as e:
|
||||
@@ -218,6 +354,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
ical = iCal.from_ical(obj.data)
|
||||
except Exception as e:
|
||||
result["errors"].append(f"{display_name}: parse failed ({e})")
|
||||
parse_failed = True
|
||||
continue
|
||||
|
||||
for comp in ical.walk():
|
||||
@@ -294,17 +431,23 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
# are prunable; locally-created events (agent / email triage / a
|
||||
# UI event whose write-back failed) carry origin NULL and must
|
||||
# never be deleted just because the server didn't return them.
|
||||
stale = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.calendar_id == local_cal.id,
|
||||
CalendarEvent.origin == "caldav",
|
||||
CalendarEvent.dtstart >= start,
|
||||
CalendarEvent.dtstart <= end,
|
||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||
).all()
|
||||
for ev in stale:
|
||||
db.delete(ev)
|
||||
result["deleted"] += len(stale)
|
||||
db.commit()
|
||||
# Skip the prune on any parse failure: seen_uids is then an
|
||||
# incomplete view of the server, so pruning against it would
|
||||
# delete events that still exist upstream but could not be read
|
||||
# (the empty-seen_uids case wipes the whole window; a partial
|
||||
# failure deletes just the unreadable rows).
|
||||
if _should_prune_window(seen_uids, parse_failed):
|
||||
stale = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.calendar_id == local_cal.id,
|
||||
CalendarEvent.origin == "caldav",
|
||||
CalendarEvent.dtstart >= start,
|
||||
CalendarEvent.dtstart <= end,
|
||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||
).all()
|
||||
for ev in stale:
|
||||
db.delete(ev)
|
||||
result["deleted"] += len(stale)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
logger.exception("CalDAV sync failed for one calendar")
|
||||
result["errors"].append(str(e)[:200])
|
||||
@@ -315,31 +458,78 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
async def sync_caldav(owner: str) -> dict:
|
||||
"""Pull CalDAV state into local DB for `owner`. Returns counts +
|
||||
errors. Loads credentials from the user's prefs; no-ops with a
|
||||
clear error if CalDAV isn't configured."""
|
||||
def _load_caldav_accounts(owner: str) -> list:
|
||||
"""Return the list of CalDAV accounts for *owner*, auto-migrating the legacy
|
||||
single-account ``caldav`` key to the new ``caldav_accounts`` list on first call.
|
||||
|
||||
The save step is best-effort: if ``_save_for_user`` is unavailable (e.g. in a
|
||||
test with a minimal prefs mock) the migrated accounts are still returned; the
|
||||
next real call will just re-run the cheap migration again.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
from routes.prefs_routes import _load_for_user
|
||||
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
url = (cfg.get("url") or "").strip()
|
||||
user = (cfg.get("username") or "").strip()
|
||||
pw = cfg.get("password") or ""
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
if not (url and user and pw):
|
||||
prefs = _load_for_user(owner) or {}
|
||||
if "caldav_accounts" in prefs:
|
||||
return list(prefs["caldav_accounts"] or [])
|
||||
# Migrate legacy single-account config to the list format.
|
||||
legacy = prefs.get("caldav", {}) or {}
|
||||
if legacy.get("url"):
|
||||
accounts = [{
|
||||
"id": str(_uuid.uuid4()),
|
||||
"label": "CalDAV",
|
||||
"url": legacy["url"],
|
||||
"username": legacy.get("username", ""),
|
||||
"password": legacy.get("password", ""),
|
||||
}]
|
||||
prefs["caldav_accounts"] = accounts
|
||||
prefs.pop("caldav", None)
|
||||
try:
|
||||
from routes.prefs_routes import _save_for_user
|
||||
_save_for_user(owner, prefs)
|
||||
except (ImportError, AttributeError):
|
||||
pass # best-effort; next call re-migrates from the still-present legacy key
|
||||
return accounts
|
||||
return []
|
||||
|
||||
|
||||
async def sync_caldav(owner: str) -> dict:
|
||||
"""Pull CalDAV state into local DB for `owner` across all configured accounts.
|
||||
Returns aggregated counts + per-account errors."""
|
||||
from src.secret_storage import decrypt
|
||||
|
||||
accounts = _load_caldav_accounts(owner)
|
||||
if not accounts:
|
||||
return {
|
||||
"calendars": 0, "events": 0, "deleted": 0,
|
||||
"errors": ["CalDAV is not configured"],
|
||||
}
|
||||
try:
|
||||
url = validate_caldav_url(url)
|
||||
return await asyncio.to_thread(_sync_blocking, owner, url, user, pw)
|
||||
except ValueError as e:
|
||||
return {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)]}
|
||||
except Exception as e:
|
||||
logger.exception("CalDAV sync raised")
|
||||
return {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)[:200]]}
|
||||
|
||||
totals: dict = {"calendars": 0, "events": 0, "deleted": 0, "errors": []}
|
||||
for acc in accounts:
|
||||
url = (acc.get("url") or "").strip()
|
||||
user = (acc.get("username") or "").strip()
|
||||
pw = acc.get("password") or ""
|
||||
account_id = acc.get("id") or ""
|
||||
label = acc.get("label") or url or account_id
|
||||
try:
|
||||
pw = decrypt(pw)
|
||||
except Exception:
|
||||
pass
|
||||
if not (url and user and pw):
|
||||
totals["errors"].append(f"{label}: missing URL, username, or password")
|
||||
continue
|
||||
try:
|
||||
url = validate_caldav_url(url)
|
||||
result = await asyncio.to_thread(_sync_blocking, owner, url, user, pw, account_id)
|
||||
except ValueError as e:
|
||||
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)]}
|
||||
except Exception as e:
|
||||
logger.exception("CalDAV sync raised for account %s", label)
|
||||
result = {"calendars": 0, "events": 0, "deleted": 0, "errors": [str(e)[:200]]}
|
||||
totals["calendars"] += result.get("calendars", 0)
|
||||
totals["events"] += result.get("events", 0)
|
||||
totals["deleted"] += result.get("deleted", 0)
|
||||
for err in result.get("errors", []):
|
||||
totals["errors"].append(f"{label}: {err}")
|
||||
return totals
|
||||
|
||||
+59
-23
@@ -23,11 +23,10 @@ from datetime import timezone
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stable_cal_id(remote_url: str) -> str:
|
||||
# Reuse the sync module's hashing so a local CalDAV calendar id maps back to
|
||||
# the same remote URL it was pulled from.
|
||||
def _stable_cal_id(remote_url: str, owner: str = "", account_id: str = "") -> str:
|
||||
# Reuse the sync module's hashing so owner+account_id scoping stays consistent.
|
||||
from src.caldav_sync import _stable_cal_id as _sync_id
|
||||
return _sync_id(remote_url)
|
||||
return _sync_id(remote_url, owner=owner, account_id=account_id)
|
||||
|
||||
|
||||
def build_event_ical(ev: dict) -> str:
|
||||
@@ -76,28 +75,34 @@ def build_event_ical(ev: dict) -> str:
|
||||
return cal.to_ical().decode("utf-8")
|
||||
|
||||
|
||||
def find_remote_calendar(calendars, local_cal_id: str):
|
||||
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None."""
|
||||
def find_remote_calendar(calendars, local_cal_id: str, owner: str = "", account_id: str = ""):
|
||||
"""Find the remote calendar whose URL hashes to ``local_cal_id``, or None.
|
||||
|
||||
``owner`` and ``account_id`` must match what was used when the local calendar
|
||||
id was originally computed in ``_sync_blocking`` so the hash round-trips."""
|
||||
for cal in calendars:
|
||||
try:
|
||||
if _stable_cal_id(str(cal.url)) == local_cal_id:
|
||||
if _stable_cal_id(str(cal.url), owner=owner, account_id=account_id) == local_cal_id:
|
||||
return cal
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False) -> dict:
|
||||
def push_event(calendars, local_cal_id: str, ev: dict, *, delete: bool = False,
|
||||
owner: str = "", account_id: str = "") -> dict:
|
||||
"""Create/update (or delete) ``ev`` on the matching remote calendar.
|
||||
|
||||
Returns ``{"ok": bool, ...}``. ``calendars`` is the discovered caldav
|
||||
calendar list (injected so this is unit-testable with fakes).
|
||||
``owner`` and ``account_id`` are forwarded to ``find_remote_calendar``
|
||||
so the URL hash round-trips correctly (#2765).
|
||||
"""
|
||||
uid = (ev or {}).get("uid") if isinstance(ev, dict) else None
|
||||
if not uid:
|
||||
return {"ok": False, "error": "event uid is required"}
|
||||
|
||||
remote = find_remote_calendar(calendars, local_cal_id)
|
||||
remote = find_remote_calendar(calendars, local_cal_id, owner=owner, account_id=account_id)
|
||||
if remote is None:
|
||||
return {"ok": False, "error": "remote calendar not found"}
|
||||
|
||||
@@ -136,13 +141,17 @@ def _discover_calendars(client):
|
||||
return []
|
||||
|
||||
|
||||
def _writeback_blocking(local_cal_id, ev, delete, url, username, password) -> dict:
|
||||
import caldav
|
||||
client = caldav.DAVClient(url=url, username=username, password=password)
|
||||
def _writeback_blocking(local_cal_id, ev, delete, url, username, password,
|
||||
owner="", account_id="") -> dict:
|
||||
from src.caldav_sync import _build_dav_client
|
||||
# Redirects disabled here too: the write-back path opens its own DAVClient,
|
||||
# so it needs the same SSRF-via-redirect protection as the pull path.
|
||||
client = _build_dav_client(url, username, password)
|
||||
calendars = _discover_calendars(client)
|
||||
if not calendars:
|
||||
return {"ok": False, "error": "no remote calendars discovered"}
|
||||
return push_event(calendars, local_cal_id, ev, delete=delete)
|
||||
return push_event(calendars, local_cal_id, ev, delete=delete,
|
||||
owner=owner, account_id=account_id)
|
||||
|
||||
|
||||
async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||
@@ -156,18 +165,45 @@ async def writeback_event(owner: str, calendar_source: str, calendar_id: str,
|
||||
if calendar_source != "caldav":
|
||||
return {"skipped": "not a caldav calendar"}
|
||||
try:
|
||||
from routes.prefs_routes import _load_for_user
|
||||
from src.caldav_sync import _load_caldav_accounts
|
||||
from src.secret_storage import decrypt
|
||||
cfg = (_load_for_user(owner) or {}).get("caldav", {}) or {}
|
||||
url = (cfg.get("url") or "").strip()
|
||||
user = (cfg.get("username") or "").strip()
|
||||
# Stored encrypted by routes/calendar_routes; decrypt before use so
|
||||
# the remote sees the real password (decrypt is a no-op on legacy
|
||||
# plaintext). The pull path src/caldav_sync.py already does this.
|
||||
pw = decrypt(cfg.get("password") or "")
|
||||
if not (url and user and pw):
|
||||
from core.database import CalendarCal, SessionLocal
|
||||
|
||||
accounts = _load_caldav_accounts(owner)
|
||||
if not accounts:
|
||||
return {"skipped": "caldav not configured"}
|
||||
result = await asyncio.to_thread(_writeback_blocking, calendar_id, ev, delete, url, user, pw)
|
||||
|
||||
# Find which account owns this calendar.
|
||||
acc = None
|
||||
if len(accounts) > 1:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cal_row = db.query(CalendarCal).filter(CalendarCal.id == calendar_id).first()
|
||||
cal_account_id = cal_row.account_id if cal_row else None
|
||||
finally:
|
||||
db.close()
|
||||
if cal_account_id:
|
||||
acc = next((a for a in accounts if a.get("id") == cal_account_id), None)
|
||||
# Fall back to first account (covers single-account and legacy rows with
|
||||
# no account_id stamped).
|
||||
if acc is None:
|
||||
acc = accounts[0]
|
||||
|
||||
url = (acc.get("url") or "").strip()
|
||||
user = (acc.get("username") or "").strip()
|
||||
pw = decrypt(acc.get("password") or "")
|
||||
if not (url and user and pw):
|
||||
return {"skipped": "caldav account credentials incomplete"}
|
||||
from src.caldav_sync import validate_caldav_url
|
||||
try:
|
||||
url = validate_caldav_url(url)
|
||||
except ValueError as e:
|
||||
logger.warning("CalDAV write-back URL rejected: %s", e)
|
||||
return {"ok": False, "error": str(e)[:200]}
|
||||
acc_id = acc.get("id") or ""
|
||||
result = await asyncio.to_thread(
|
||||
_writeback_blocking, calendar_id, ev, delete, url, user, pw, owner, acc_id
|
||||
)
|
||||
if not result.get("ok"):
|
||||
logger.warning("CalDAV write-back did not apply: %s", result.get("error") or result)
|
||||
return result
|
||||
|
||||
+25
-15
@@ -98,6 +98,7 @@ class ChatHandler:
|
||||
att_ids: List[str],
|
||||
sess,
|
||||
auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
||||
allow_tool_preprocessing: bool = True,
|
||||
) -> tuple:
|
||||
"""
|
||||
Common preprocessing for both chat endpoints.
|
||||
@@ -112,7 +113,7 @@ class ChatHandler:
|
||||
attachment_meta: List[Dict[str, Any]] = []
|
||||
|
||||
# Extract URLs and process YouTube transcripts
|
||||
urls = extract_urls(enhanced_message)
|
||||
urls = extract_urls(enhanced_message) if allow_tool_preprocessing else []
|
||||
youtube_transcripts: List[str] = []
|
||||
|
||||
has_youtube = False
|
||||
@@ -143,24 +144,18 @@ class ChatHandler:
|
||||
if has_youtube:
|
||||
youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
||||
|
||||
# Analyze images — skip if vision disabled, or if main model is vision-capable
|
||||
from src.settings import get_setting
|
||||
vision_enabled = get_setting("vision_enabled", True)
|
||||
main_is_vision = await asyncio.to_thread(
|
||||
model_supports_vision, sess.model or "", getattr(sess, "endpoint_url", "") or ""
|
||||
)
|
||||
|
||||
# Resolve uploads once with the session owner. Attachment IDs are
|
||||
# bearer-like references; never trust them without an owner check.
|
||||
files_by_id: Dict[str, Dict] = {}
|
||||
owner = getattr(sess, "owner", None)
|
||||
if att_ids:
|
||||
for att_id in att_ids:
|
||||
effective_att_ids = att_ids if allow_tool_preprocessing else []
|
||||
if effective_att_ids:
|
||||
for att_id in effective_att_ids:
|
||||
fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
||||
if fi:
|
||||
files_by_id[att_id] = fi
|
||||
|
||||
for att_id in att_ids:
|
||||
for att_id in effective_att_ids:
|
||||
fi = files_by_id.get(att_id)
|
||||
if fi:
|
||||
attachment_meta.append({
|
||||
@@ -172,9 +167,24 @@ class ChatHandler:
|
||||
"height": fi.get("height"),
|
||||
})
|
||||
|
||||
if att_ids and vision_enabled:
|
||||
# Analyze images only when attachment preprocessing is actually
|
||||
# allowed. The vision capability check can probe local model endpoints,
|
||||
# so guide-only/no-tools turns must not reach it.
|
||||
vision_enabled = False
|
||||
main_is_vision = False
|
||||
if effective_att_ids:
|
||||
from src.settings import get_setting
|
||||
vision_enabled = get_setting("vision_enabled", True)
|
||||
if vision_enabled:
|
||||
main_is_vision = await asyncio.to_thread(
|
||||
model_supports_vision,
|
||||
sess.model or "",
|
||||
getattr(sess, "endpoint_url", "") or "",
|
||||
)
|
||||
|
||||
if effective_att_ids and vision_enabled:
|
||||
meta_by_id = {m["id"]: m for m in attachment_meta}
|
||||
for att_id in att_ids:
|
||||
for att_id in effective_att_ids:
|
||||
file_info = files_by_id.get(att_id)
|
||||
if file_info and self.upload_handler.is_image_file(
|
||||
file_info["name"], file_info.get("mime", "")
|
||||
@@ -219,7 +229,7 @@ class ChatHandler:
|
||||
except Exception:
|
||||
vl_desc = None
|
||||
if not vl_desc:
|
||||
vl_result = analyze_image_with_vl_result(file_info["path"])
|
||||
vl_result = analyze_image_with_vl_result(file_info["path"], owner=owner)
|
||||
vl_desc = vl_result.get("text", "")
|
||||
vl_model = vl_result.get("model", "")
|
||||
if vl_desc and not vl_desc.startswith("["):
|
||||
@@ -239,7 +249,7 @@ class ChatHandler:
|
||||
_m["vision_model"] = vl_model
|
||||
|
||||
user_content = build_user_content(
|
||||
enhanced_message, att_ids, UPLOAD_DIR, self.upload_handler,
|
||||
enhanced_message, effective_att_ids, UPLOAD_DIR, self.upload_handler,
|
||||
session_id=getattr(sess, "id", None),
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
owner=owner,
|
||||
|
||||
+13
-3
@@ -13,6 +13,8 @@ from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from typing import List, Optional
|
||||
|
||||
from src.upload_limits import format_byte_limit, get_chat_upload_max_bytes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,7 +24,14 @@ def extract_urls(text: str) -> List[str]:
|
||||
urls = re.findall(url_pattern, text)
|
||||
cleaned_urls = []
|
||||
for url in urls:
|
||||
url = re.sub(r'[.,;:!?\)]+$', '', url)
|
||||
# Strip trailing sentence punctuation, but keep a balanced ')' so URLs
|
||||
# that legitimately end in one are preserved, e.g. the Wikipedia link
|
||||
# ".../Python_(programming_language)". A ')' is only dropped when it is
|
||||
# unbalanced (more ')' than '('), which is the prose-glued case such as
|
||||
# "(see https://example.com)".
|
||||
url = re.sub(r'[.,;:!?]+$', '', url)
|
||||
while url.endswith(')') and url.count(')') > url.count('('):
|
||||
url = re.sub(r'[.,;:!?]+$', '', url[:-1])
|
||||
cleaned_urls.append(url)
|
||||
return cleaned_urls
|
||||
|
||||
@@ -201,12 +210,13 @@ def validate_file_upload(file: UploadFile) -> UploadFile:
|
||||
}
|
||||
)
|
||||
|
||||
if file_size > 10 * 1024 * 1024:
|
||||
upload_limit = get_chat_upload_max_bytes()
|
||||
if file_size > upload_limit:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "FILE_TOO_LARGE",
|
||||
"message": "File size exceeds 10MB limit"
|
||||
"message": f"File size exceeds {format_byte_limit(upload_limit)} limit"
|
||||
}
|
||||
)
|
||||
except IOError as e:
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
"""ChatGPT subscription / Codex backend OAuth helpers.
|
||||
|
||||
This provider is intentionally separate from OpenAI API-key endpoints. It uses
|
||||
OpenAI account OAuth device authorization, stores refresh tokens server-side,
|
||||
and resolves a fresh bearer token at request time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
|
||||
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
||||
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
||||
or "https://chatgpt.com/backend-api/codex"
|
||||
)
|
||||
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
|
||||
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
|
||||
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
|
||||
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
|
||||
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
||||
with _AUTH_REFRESH_LOCKS_GUARD:
|
||||
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_AUTH_REFRESH_LOCKS[auth_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatGPTSubscriptionError(RuntimeError):
|
||||
"""Base error for ChatGPT subscription provider failures."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
|
||||
"""Stored OAuth credentials are invalid or expired beyond refresh."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
|
||||
"""Upstream quota/rate limit; reconnecting will not fix it."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
|
||||
"""No matching owner-scoped auth session exists."""
|
||||
|
||||
|
||||
def is_chatgpt_subscription_base(url: str) -> bool:
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url or "")
|
||||
host = (parsed.hostname or "").lower().rstrip(".")
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
except Exception:
|
||||
return False
|
||||
return host == "chatgpt.com" and (
|
||||
path == "/backend-api/codex" or path.startswith("/backend-api/codex/")
|
||||
)
|
||||
|
||||
|
||||
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
|
||||
headers = {
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Origin": "https://chatgpt.com",
|
||||
"Referer": "https://chatgpt.com/codex",
|
||||
"User-Agent": "Odysseus ChatGPT Subscription",
|
||||
}
|
||||
if access_token:
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
return headers
|
||||
|
||||
|
||||
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
|
||||
if not access_token:
|
||||
return []
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
|
||||
headers=chatgpt_headers(access_token),
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
data = response.json()
|
||||
except Exception:
|
||||
return []
|
||||
entries = data.get("models", []) if isinstance(data, dict) else []
|
||||
sortable: list[tuple[int, str]] = []
|
||||
for item in entries:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
slug = item.get("slug")
|
||||
if not isinstance(slug, str) or not slug.strip():
|
||||
continue
|
||||
visibility = item.get("visibility", "")
|
||||
if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
|
||||
continue
|
||||
priority = item.get("priority")
|
||||
rank = int(priority) if isinstance(priority, (int, float)) else 10_000
|
||||
sortable.append((rank, slug.strip()))
|
||||
sortable.sort(key=lambda item: (item[0], item[1]))
|
||||
ordered: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _, slug in sortable:
|
||||
if slug not in seen:
|
||||
ordered.append(slug)
|
||||
seen.add(slug)
|
||||
return ordered
|
||||
|
||||
|
||||
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
|
||||
if response.status_code < 400:
|
||||
return
|
||||
code = ""
|
||||
message = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}."
|
||||
try:
|
||||
payload = response.json()
|
||||
err = payload.get("error") if isinstance(payload, dict) else None
|
||||
if isinstance(err, dict):
|
||||
code = str(err.get("code") or err.get("type") or "").strip()
|
||||
msg = err.get("message")
|
||||
if msg:
|
||||
message = f"ChatGPT Subscription {action} failed: {msg}"
|
||||
elif isinstance(err, str):
|
||||
code = err.strip()
|
||||
desc = payload.get("error_description") or payload.get("message")
|
||||
if desc:
|
||||
message = f"ChatGPT Subscription {action} failed: {desc}"
|
||||
except Exception:
|
||||
pass
|
||||
if response.status_code == 429:
|
||||
raise ChatGPTSubscriptionRateLimited(
|
||||
"ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
|
||||
)
|
||||
if response.status_code in (401, 403) or code in {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}:
|
||||
raise ChatGPTSubscriptionReauthRequired(message)
|
||||
raise ChatGPTSubscriptionError(message)
|
||||
|
||||
|
||||
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
|
||||
_raise_for_oauth_response(response, action)
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as exc:
|
||||
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
|
||||
if not isinstance(data, dict):
|
||||
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
|
||||
return data
|
||||
|
||||
|
||||
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode",
|
||||
json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "device-code request")
|
||||
if not data.get("device_auth_id") or not data.get("user_code"):
|
||||
raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
|
||||
data.setdefault("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device")
|
||||
data.setdefault("interval", 5)
|
||||
data.setdefault("expires_in", 900)
|
||||
return data
|
||||
|
||||
|
||||
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token",
|
||||
json={"device_auth_id": device_auth_id, "user_code": user_code},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code in (403, 404):
|
||||
return {"status": "pending", "error": "authorization_pending"}
|
||||
return _json_or_error(response, "device-code poll")
|
||||
|
||||
|
||||
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": authorization_code,
|
||||
"redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
|
||||
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "token exchange")
|
||||
if not data.get("access_token"):
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
|
||||
return data
|
||||
|
||||
|
||||
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
|
||||
del access_token
|
||||
if not refresh_token:
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
|
||||
response = httpx.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "token refresh")
|
||||
if not data.get("access_token"):
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
|
||||
return data
|
||||
|
||||
|
||||
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
|
||||
parts = (token or "").split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError("not a JWT")
|
||||
segment = parts[1]
|
||||
segment += "=" * (-len(segment) % 4)
|
||||
raw = base64.urlsafe_b64decode(segment.encode("ascii"))
|
||||
payload = json.loads(raw.decode("utf-8"))
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
|
||||
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
|
||||
try:
|
||||
exp = int(_decode_jwt_payload(access_token).get("exp") or 0)
|
||||
except Exception:
|
||||
return True
|
||||
return exp <= int(time.time()) + int(skew_seconds)
|
||||
|
||||
|
||||
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(ProviderAuthSession).filter(
|
||||
ProviderAuthSession.id == auth_id,
|
||||
ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
)
|
||||
if owner:
|
||||
q = q.filter(ProviderAuthSession.owner == owner)
|
||||
row = q.first()
|
||||
if row is None:
|
||||
raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")
|
||||
|
||||
access_token = row.access_token or ""
|
||||
if force_refresh or access_token_is_expiring(access_token):
|
||||
with _refresh_lock_for(auth_id):
|
||||
db.refresh(row)
|
||||
access_token = row.access_token or ""
|
||||
refresh_token = row.refresh_token or ""
|
||||
if force_refresh or access_token_is_expiring(access_token):
|
||||
refreshed = refresh_oauth_tokens(access_token, refresh_token)
|
||||
row.access_token = refreshed["access_token"]
|
||||
if refreshed.get("refresh_token"):
|
||||
row.refresh_token = refreshed["refresh_token"]
|
||||
row.last_refresh = utcnow_naive()
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
access_token = row.access_token or ""
|
||||
|
||||
return {
|
||||
"provider": CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
"base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
|
||||
"api_key": access_token,
|
||||
"auth_mode": row.auth_mode or "chatgpt",
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def to_http_exception(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, ChatGPTSubscriptionRateLimited):
|
||||
return HTTPException(429, str(exc))
|
||||
if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
|
||||
return HTTPException(401, f"{exc} Reconnect the provider.")
|
||||
return HTTPException(502, str(exc))
|
||||
|
||||
|
||||
def build_responses_input(messages: list[dict]) -> list[dict]:
|
||||
input_items: list[dict] = []
|
||||
for msg in messages or []:
|
||||
role = msg.get("role") or "user"
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
text = "\n".join(str(part.get("text") or part.get("content") or "") for part in content if isinstance(part, dict))
|
||||
else:
|
||||
text = "" if content is None else str(content)
|
||||
input_type = "output_text" if role == "assistant" else "input_text"
|
||||
input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
|
||||
return input_items
|
||||
+10
-8
@@ -4,6 +4,8 @@ from typing import List, Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from src.constants import DATA_DIR as _DATA_DIR_CONST
|
||||
|
||||
# Cross-platform OS flag, exposed here so callers can `from src.config import
|
||||
# IS_WINDOWS`. Defined locally (a trivial `os.name == "nt"`) rather than imported
|
||||
# from core.platform_compat, to keep this dependency-light config module from
|
||||
@@ -20,13 +22,13 @@ class DataConfig(BaseSettings):
|
||||
base_dir: Path = Field(default=Path(__file__).parent.parent, description="Base directory for the application")
|
||||
|
||||
# Data paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Main data directory")
|
||||
uploads_dir: Path = Field(default=Path("data/uploads"), description="Directory for uploaded files")
|
||||
sessions_file: Path = Field(default=Path("data/sessions.json"), description="Sessions storage file")
|
||||
memory_file: Path = Field(default=Path("data/memory.json"), description="Memory storage file")
|
||||
memory_doc: Path = Field(default=Path("data/memory_doc.md"), description="Memory document file")
|
||||
personal_dir: Path = Field(default=Path("data/personal_docs"), description="Personal documents directory")
|
||||
runbook_dir: Path = Field(default=Path("data/personal_docs/runbook"), description="Runbook directory")
|
||||
data_dir: Path = Field(default=Path(_DATA_DIR_CONST), description="Main data directory")
|
||||
uploads_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "uploads", description="Directory for uploaded files")
|
||||
sessions_file: Path = Field(default=Path(_DATA_DIR_CONST) / "sessions.json", description="Sessions storage file")
|
||||
memory_file: Path = Field(default=Path(_DATA_DIR_CONST) / "memory.json", description="Memory storage file")
|
||||
memory_doc: Path = Field(default=Path(_DATA_DIR_CONST) / "memory_doc.md", description="Memory document file")
|
||||
personal_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "personal_docs", description="Personal documents directory")
|
||||
runbook_dir: Path = Field(default=Path(_DATA_DIR_CONST) / "personal_docs" / "runbook", description="Runbook directory")
|
||||
|
||||
# Upload settings
|
||||
max_upload_size: int = Field(default=10 * 1024 * 1024, description="Maximum upload size in bytes (10MB)")
|
||||
@@ -139,7 +141,7 @@ class AppConfig(BaseSettings):
|
||||
base_dir = Path(__file__).parent.parent
|
||||
|
||||
# Convert string paths to Path objects relative to base_dir
|
||||
data_dir = base_dir / "data"
|
||||
data_dir = Path(_DATA_DIR_CONST)
|
||||
|
||||
# Get values from the input dict or use defaults
|
||||
max_upload_size = v.get("max_upload_size", 10 * 1024 * 1024) if isinstance(v, dict) else 10 * 1024 * 1024
|
||||
|
||||
+65
-2
@@ -7,9 +7,12 @@ APP_VERSION = "1.0.0"
|
||||
# Base paths
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
||||
STATIC_DIR = os.path.join(BASE_DIR, "static")
|
||||
DATA_DIR = os.path.join(BASE_DIR, "data")
|
||||
DATA_DIR = os.getenv("ODYSSEUS_DATA_DIR", os.path.join(BASE_DIR, "data"))
|
||||
|
||||
# Data file paths
|
||||
# Single source of truth: every persisted file/dir lives under DATA_DIR, which
|
||||
# is the ONLY place ODYSSEUS_DATA_DIR is read. Import these constants instead of
|
||||
# re-deriving paths from __file__ or a relative "data" literal.
|
||||
SESSIONS_FILE = os.path.join(DATA_DIR, "sessions.json")
|
||||
MEMORY_FILE = os.path.join(DATA_DIR, "memory.json")
|
||||
MEMORY_DOC = os.path.join(DATA_DIR, "memory_doc.md")
|
||||
@@ -18,6 +21,47 @@ RUNBOOK_DIR = os.path.join(PERSONAL_DIR, "runbook")
|
||||
UPLOAD_DIR = os.path.join(DATA_DIR, "uploads")
|
||||
FEATURES_FILE = os.path.join(DATA_DIR, "features.json")
|
||||
SETTINGS_FILE = os.path.join(DATA_DIR, "settings.json")
|
||||
AUTH_FILE = os.path.join(DATA_DIR, "auth.json")
|
||||
USER_PREFS_FILE = os.path.join(DATA_DIR, "user_prefs.json")
|
||||
PRESETS_FILE = os.path.join(DATA_DIR, "presets.json")
|
||||
INTEGRATIONS_FILE = os.path.join(DATA_DIR, "integrations.json")
|
||||
CONTACTS_FILE = os.path.join(DATA_DIR, "contacts.json")
|
||||
APP_KEY_FILE = os.path.join(DATA_DIR, ".app_key")
|
||||
EMBEDDING_ENDPOINT_FILE = os.path.join(DATA_DIR, "embedding_endpoint.json")
|
||||
COOKBOOK_STATE_FILE = os.path.join(DATA_DIR, "cookbook_state.json")
|
||||
BG_JOBS_FILE = os.path.join(DATA_DIR, "bg_jobs.json")
|
||||
VAULT_FILE = os.path.join(DATA_DIR, "vault.json")
|
||||
TIDY_CALENDAR_STATE_FILE = os.path.join(DATA_DIR, "tidy_calendar_state.json")
|
||||
SKILLS_FILE = os.path.join(DATA_DIR, "skills.json")
|
||||
APP_DB = os.path.join(DATA_DIR, "app.db")
|
||||
SCHEDULED_EMAILS_DB = os.path.join(DATA_DIR, "scheduled_emails.db")
|
||||
EMAIL_CACHE_DB = os.path.join(DATA_DIR, "email_cache.db")
|
||||
|
||||
# Data subdirectories
|
||||
PERSONAL_UPLOADS_DIR = os.path.join(DATA_DIR, "personal_uploads")
|
||||
EMOJI_CACHE_DIR = os.path.join(DATA_DIR, "emoji_cache")
|
||||
RAG_DIR = os.path.join(DATA_DIR, "rag")
|
||||
CHROMA_DIR = os.path.join(DATA_DIR, "chroma")
|
||||
BG_JOBS_DIR = os.path.join(DATA_DIR, "bg_jobs")
|
||||
DEEP_RESEARCH_DIR = os.path.join(DATA_DIR, "deep_research")
|
||||
MCP_OAUTH_DIR = os.path.join(DATA_DIR, "mcp_oauth")
|
||||
GENERATED_IMAGES_DIR = os.path.join(DATA_DIR, "generated_images")
|
||||
TTS_CACHE_DIR = os.path.join(DATA_DIR, "tts_cache")
|
||||
EMAIL_URGENCY_CACHE_DIR = os.path.join(DATA_DIR, "email_urgency_cache")
|
||||
SKILLS_DIR = os.path.join(DATA_DIR, "skills")
|
||||
GALLERY_DIR = os.path.join(DATA_DIR, "gallery")
|
||||
GALLERY_UPLOADS_DIR = os.path.join(DATA_DIR, "gallery_uploads")
|
||||
MEMORY_VECTORS_DIR = os.path.join(DATA_DIR, "memory_vectors")
|
||||
|
||||
# Paths with an intentional dedicated env override, defaulting under DATA_DIR.
|
||||
MAIL_ATTACHMENTS_DIR = os.getenv("ODYSSEUS_MAIL_ATTACHMENTS_DIR", os.path.join(DATA_DIR, "mail-attachments"))
|
||||
FASTEMBED_CACHE_DIR = os.getenv("FASTEMBED_CACHE_PATH", os.path.join(DATA_DIR, "fastembed_cache"))
|
||||
|
||||
# Agent tool output limits (single source of truth — imported by tool_execution.py,
|
||||
# tool_implementations.py, agent_tools.py, and any other module that needs them)
|
||||
MAX_OUTPUT_CHARS = 10_000 # cap for bash/python/web_search/web_fetch output
|
||||
MAX_READ_CHARS = 20_000 # cap for read_file / document preview
|
||||
MAX_DIFF_LINES = 400 # cap for edit_file unified-diff display
|
||||
|
||||
# API Configuration
|
||||
MAX_CONTEXT_MESSAGES = 90
|
||||
@@ -28,7 +72,7 @@ OPENAI_COMPAT_PATH = "/v1/chat/completions"
|
||||
DEFAULT_HOST = os.getenv("LLM_HOST", "localhost")
|
||||
LLM_HOSTS = [h.strip() for h in os.getenv("LLM_HOSTS", "").split(",") if h.strip()]
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
SEARXNG_INSTANCE = os.getenv('SEARXNG_INSTANCE', 'http://localhost:8080')
|
||||
SEARXNG_INSTANCE = os.getenv("SEARXNG_INSTANCE", "http://localhost:8080")
|
||||
|
||||
|
||||
# Cleanup configuration
|
||||
@@ -38,3 +82,22 @@ CLEANUP_INTERVAL_HOURS = int(os.getenv("CLEANUP_INTERVAL_HOURS", "24"))
|
||||
# Default parameters
|
||||
DEFAULT_TEMPERATURE = 1.0
|
||||
DEFAULT_MAX_TOKENS = 0
|
||||
|
||||
|
||||
def internal_api_base() -> str:
|
||||
"""Base URL for in-process loopback calls to Odysseus's own API.
|
||||
|
||||
Agent tools and background jobs reach admin-gated routes by calling the
|
||||
running server over HTTP. Resolution order:
|
||||
1. ODYSSEUS_INTERNAL_BASE - explicit override (e.g. behind a TLS proxy).
|
||||
2. APP_PORT - http://127.0.0.1:$APP_PORT (docker-compose).
|
||||
3. Fallback http://127.0.0.1:7000 - legacy default.
|
||||
|
||||
127.0.0.1 (not "localhost") avoids IPv6/DNS ambiguity for a strictly-local
|
||||
call. Without this, loopback tools fail with "All connection attempts
|
||||
failed" whenever the server is not on port 7000.
|
||||
"""
|
||||
override = os.environ.get("ODYSSEUS_INTERNAL_BASE")
|
||||
if override:
|
||||
return override.rstrip("/")
|
||||
return f"http://127.0.0.1:{os.environ.get('APP_PORT', '7000')}"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user