mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 01:35:36 -04:00
Merge remote-tracking branch 'origin/main' into visual-pr-playground
# Conflicts: # routes/cookbook_routes.py # routes/hwfit_routes.py # services/hwfit/fit.py # services/hwfit/models.py # static/js/cookbook-diagnosis.js # static/js/cookbook-hwfit.js # static/js/cookbook.js # static/js/cookbookRunning.js
This commit is contained in:
@@ -57,6 +57,7 @@ class DocsService:
|
||||
metadata=r.get("metadata"),
|
||||
)
|
||||
for r in results
|
||||
if isinstance(r, dict)
|
||||
]
|
||||
|
||||
async def index(self, directory: str) -> IndexResult:
|
||||
|
||||
+17
-3
@@ -61,7 +61,7 @@ CONTEXT_TARGET = {
|
||||
|
||||
|
||||
def _lookup_bandwidth(gpu_name):
|
||||
if not gpu_name:
|
||||
if not isinstance(gpu_name, str) or not gpu_name:
|
||||
return None
|
||||
gn = gpu_name.lower()
|
||||
for key in _BW_KEYS_SORTED:
|
||||
@@ -280,10 +280,14 @@ def _native_quant(model):
|
||||
return "FP8"
|
||||
if "gptq" in text:
|
||||
m = re.search(r"(?:gptq|int|w)(?:[-_]?)(\d{1,2})(?:bit)?", text)
|
||||
return f"GPTQ-{m.group(1)}bit" if m else "GPTQ"
|
||||
# Canonical catalog label is "GPTQ-Int4"/"GPTQ-Int8" (see models.py
|
||||
# QUANT_BPP / QUANT_QUALITY_PENALTY keys); "GPTQ-4bit" misses both
|
||||
# maps, so BPP and the quality penalty silently fall to defaults.
|
||||
return f"GPTQ-Int{m.group(1)}" if m else "GPTQ-Int4"
|
||||
if "awq" in text:
|
||||
m = re.search(r"(?:awq|int|w)(?:[-_]?)(\d{1,2})(?:bit)?", text)
|
||||
return f"AWQ-{m.group(1)}bit" if m else "AWQ"
|
||||
# Catalog keys are "AWQ-4bit"/"AWQ-8bit"; bare "AWQ" misses the maps.
|
||||
return f"AWQ-{m.group(1)}bit" if m else "AWQ-4bit"
|
||||
if "mlx" in text:
|
||||
m = re.search(r"mlx[-_]?(\d{1,2})bit", text)
|
||||
return f"mlx-{m.group(1)}bit" if m else native_quant
|
||||
@@ -571,6 +575,8 @@ def rank_models(system, use_case=None, limit=50, search=None, sort="score", quan
|
||||
|
||||
system_backend = (system.get("backend") or "").lower()
|
||||
apple_silicon = system_backend in ("mps", "metal", "apple")
|
||||
rocm = system_backend == "rocm"
|
||||
|
||||
# Consumer AMD Radeon (RDNA, gfx10/11/12): the practical local serving path
|
||||
# is GGUF via llama.cpp. vLLM/SGLang on ROCm are validated for datacenter
|
||||
# Instinct (CDNA, gfx9xx) but are unreliable on consumer RDNA — AWQ kernels
|
||||
@@ -589,6 +595,14 @@ def rank_models(system, use_case=None, limit=50, search=None, sort="score", quan
|
||||
if native_q.startswith("mlx-") or "mlx" in (m.get("name") or "").lower():
|
||||
continue
|
||||
|
||||
# ROCm support for vLLM/SGLang quantized safetensors is too brittle to
|
||||
# recommend blindly in the default scan. Keep AWQ/GPTQ/FP8 discoverable
|
||||
# only when the user explicitly picks that format from the quant filter;
|
||||
# otherwise prefer GGUF/Q* entries that Odysseus can route through
|
||||
# llama.cpp/Ollama without pretending "fits VRAM" means "servable".
|
||||
if rocm and is_prequantized(m) and not filter_native:
|
||||
continue
|
||||
|
||||
# On Apple Silicon the only serving engines are llama.cpp and Ollama,
|
||||
# both GGUF-only (vLLM/SGLang are CUDA/ROCm and don't run on macOS). So
|
||||
# a model is Metal-servable ONLY if it ships a real GGUF. Drop everything
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -104,6 +105,8 @@ def _detect_nvidia():
|
||||
return None
|
||||
|
||||
gpus = []
|
||||
# Devices nvidia-smi lists with a real name but a non-numeric memory.total.
|
||||
unified = []
|
||||
# nvidia-smi lists GPUs in index order (0,1,2,...), so the row position is
|
||||
# the CUDA device index we'd pass to CUDA_VISIBLE_DEVICES.
|
||||
for idx, line in enumerate(out.strip().split("\n")):
|
||||
@@ -113,9 +116,32 @@ def _detect_nvidia():
|
||||
vram_mb = float(parts[0])
|
||||
gpus.append({"index": idx, "name": parts[1], "vram_gb": vram_mb / 1024.0})
|
||||
except ValueError:
|
||||
# Grace Blackwell GB10 / DGX Spark and other unified-memory
|
||||
# NVIDIA parts report memory.total as "[N/A]"/"Not Supported"
|
||||
# because the GPU shares the system LPDDR pool instead of
|
||||
# carrying discrete VRAM. Don't drop the device — remember it so
|
||||
# we report a unified-memory GPU below rather than "No GPU" (#1340).
|
||||
if parts[1]:
|
||||
unified.append({"index": idx, "name": parts[1]})
|
||||
continue
|
||||
|
||||
if not gpus:
|
||||
if unified:
|
||||
# Unified-memory CUDA box: report the GPU backed by system RAM so the
|
||||
# Cookbook recommends models and serving works. The pool is shared
|
||||
# (not per-GPU discrete VRAM), so report the RAM total once.
|
||||
ram_gb = round(_get_ram_gb(), 1)
|
||||
gpus = [{"index": g["index"], "name": g["name"], "vram_gb": ram_gb} for g in unified]
|
||||
return {
|
||||
"gpu_name": gpus[0]["name"],
|
||||
"gpu_vram_gb": ram_gb,
|
||||
"gpu_count": len(gpus),
|
||||
"gpus": gpus,
|
||||
"gpu_groups": _group_gpus(gpus),
|
||||
"homogeneous": True,
|
||||
"backend": "cuda",
|
||||
"unified_memory": True,
|
||||
}
|
||||
return None
|
||||
total_vram = sum(g["vram_gb"] for g in gpus)
|
||||
groups = _group_gpus(gpus)
|
||||
@@ -130,6 +156,33 @@ def _detect_nvidia():
|
||||
}
|
||||
|
||||
|
||||
def classify_amd_gfx(gfx):
|
||||
"""Map an AMD ISA target (e.g. "gfx1200") to (gfx, family).
|
||||
|
||||
family is one of:
|
||||
"rdna" — consumer Radeon RX (gfx10xx RDNA1/2, gfx11xx RDNA3, gfx12xx RDNA4)
|
||||
"cdna" — datacenter Instinct (gfx908 MI100, gfx90a MI200, gfx94x/95x MI300+)
|
||||
"gcn" — older GCN/Vega (gfx900/906)
|
||||
"unknown" — empty/unrecognized; callers must treat conservatively
|
||||
|
||||
This drives the serving decision: vLLM/SGLang on ROCm are validated on CDNA
|
||||
but fragile on consumer RDNA (AWQ kernels largely unsupported, FP8 needs
|
||||
out-of-tree patches), so RDNA is steered to GGUF/llama.cpp.
|
||||
"""
|
||||
gfx = (gfx or "").lower().strip()
|
||||
m = re.fullmatch(r"gfx(\d+[a-f]?)", gfx)
|
||||
if not m:
|
||||
return "", "unknown"
|
||||
digits = m.group(1)
|
||||
if digits[:2] in ("10", "11", "12"):
|
||||
return gfx, "rdna"
|
||||
if digits in ("908", "90a") or digits[:2] in ("94", "95"):
|
||||
return gfx, "cdna"
|
||||
if digits[:1] == "9":
|
||||
return gfx, "gcn"
|
||||
return gfx, "unknown"
|
||||
|
||||
|
||||
def _detect_amd():
|
||||
"""Detect AMD GPUs. Handles both discrete cards (with mem_info_vram_total)
|
||||
and APUs / unified-memory SoCs like Strix Halo (which expose
|
||||
@@ -155,6 +208,17 @@ def _detect_amd():
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _amd_arch():
|
||||
"""Best-effort AMD GPU ISA + family from rocminfo.
|
||||
|
||||
rocminfo is the source of truth; its GPU agents report a `Name: gfxNNNN`
|
||||
line (CPU agents report a brand string, not a gfx target), so the first
|
||||
gfx match is the GPU ISA. Returns (gfx, family) — see classify_amd_gfx.
|
||||
"""
|
||||
info = _run(["rocminfo"]) or _run(["/opt/rocm/bin/rocminfo"]) or ""
|
||||
m = re.search(r"gfx\d+[a-f]?", info)
|
||||
return classify_amd_gfx(m.group(0) if m else "")
|
||||
|
||||
try:
|
||||
cards = []
|
||||
is_apu = False
|
||||
@@ -187,6 +251,7 @@ def _detect_amd():
|
||||
return None
|
||||
total_vram = sum(c["vram_gb"] for c in cards)
|
||||
groups = _group_gpus(cards)
|
||||
gfx, family = _amd_arch()
|
||||
# NOTE: for APUs with BIOS UMA carveout (e.g. Strix Halo), vis_vram_total
|
||||
# is the real usable GPU memory — it's physically backed but reserved
|
||||
# by BIOS so it doesn't appear in /proc/meminfo. Don't cap it at system
|
||||
@@ -200,6 +265,13 @@ def _detect_amd():
|
||||
"homogeneous": len(groups) <= 1,
|
||||
"backend": "rocm",
|
||||
"unified_memory": is_apu,
|
||||
# AMD ISA/family so downstream can tell datacenter Instinct (CDNA,
|
||||
# where vLLM/SGLang run AWQ/GPTQ reliably) from consumer Radeon
|
||||
# (RDNA, where the practical path is GGUF via llama.cpp). Empty/
|
||||
# "unknown" when rocminfo isn't available — callers must treat
|
||||
# unknown conservatively, not assume vLLM works.
|
||||
"gpu_arch": gfx,
|
||||
"gpu_family": family,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
@@ -409,7 +481,7 @@ def _detect_windows():
|
||||
" $gpus = @(); "
|
||||
" foreach ($line in $nv -split \"`n\") { "
|
||||
" $p = $line -split ','; "
|
||||
" if ($p.Count -ge 2) { $gpus += @{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
||||
" if ($p.Count -ge 2) { $gpus += [pscustomobject]@{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
||||
" }; "
|
||||
" $r.gpu_name = $gpus[0].name; "
|
||||
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
|
||||
|
||||
@@ -5,7 +5,9 @@ import re
|
||||
QUANT_HIERARCHY = ["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"]
|
||||
|
||||
QUANT_BPP = {
|
||||
"F32": 4.0, "F16": 2.0, "BF16": 2.0, "FP8": 1.0, "INT8": 1.0, "NVFP4": 0.5,
|
||||
"F32": 4.0, "F16": 2.0, "BF16": 2.0, "FP8": 1.0,
|
||||
"FP4": 0.50, "NVFP4": 0.50, "MXFP4": 0.50, "NF4": 0.50,
|
||||
"INT4": 0.50, "INT8": 1.0, "W4A16": 0.50, "W8A8": 1.0, "W8A16": 1.0,
|
||||
"Q8_0": 1.05, "Q6_K": 0.80, "Q5_K_M": 0.68,
|
||||
"Q4_K_M": 0.58, "Q4_0": 0.58, "Q3_K_M": 0.48, "Q2_K": 0.37,
|
||||
"AWQ-4bit": 0.50, "AWQ-8bit": 1.0,
|
||||
@@ -14,7 +16,9 @@ QUANT_BPP = {
|
||||
}
|
||||
|
||||
QUANT_SPEED_MULT = {
|
||||
"F16": 0.6, "BF16": 0.6, "FP8": 0.85, "INT8": 0.85, "NVFP4": 1.1,
|
||||
"F16": 0.6, "BF16": 0.6, "FP8": 0.85,
|
||||
"FP4": 1.15, "NVFP4": 1.15, "MXFP4": 1.15, "NF4": 1.10,
|
||||
"INT4": 1.15, "INT8": 0.85, "W4A16": 1.15, "W8A8": 0.85, "W8A16": 0.85,
|
||||
"Q8_0": 0.8, "Q6_K": 0.95, "Q5_K_M": 1.0,
|
||||
"Q4_K_M": 1.15, "Q4_0": 1.15, "Q3_K_M": 1.25, "Q2_K": 1.35,
|
||||
"AWQ-4bit": 1.2, "AWQ-8bit": 0.85,
|
||||
@@ -23,8 +27,10 @@ QUANT_SPEED_MULT = {
|
||||
}
|
||||
|
||||
QUANT_QUALITY_PENALTY = {
|
||||
"F16": 0.0, "BF16": 0.0, "FP8": 0.0, "INT8": 0.0, "NVFP4": -0.5,
|
||||
"Q8_0": -0.5, "Q6_K": -1.5, "Q5_K_M": -2.5,
|
||||
"F16": 0.0, "BF16": 0.0, "FP8": 0.0,
|
||||
"FP4": -3.0, "NVFP4": -3.0, "MXFP4": -3.0, "NF4": -4.0,
|
||||
"INT4": -4.0, "INT8": 0.0, "W4A16": -4.0, "W8A8": 0.0, "W8A16": 0.0,
|
||||
"Q8_0": 0.0, "Q6_K": -1.0, "Q5_K_M": -2.0,
|
||||
"Q4_K_M": -5.0, "Q4_0": -5.0, "Q3_K_M": -8.0, "Q2_K": -12.0,
|
||||
# Bare "AWQ" and "AWQ-8bit" used to be 0.0 (tied with FP8). In practice
|
||||
# AWQ-anything is a calibrated reconstruction, not raw 8-bit weights —
|
||||
@@ -36,7 +42,9 @@ QUANT_QUALITY_PENALTY = {
|
||||
}
|
||||
|
||||
QUANT_BYTES_PER_PARAM = {
|
||||
"F16": 2.0, "BF16": 2.0, "FP8": 1.0, "INT8": 1.0, "NVFP4": 0.5,
|
||||
"F16": 2.0, "BF16": 2.0, "FP8": 1.0,
|
||||
"FP4": 0.5, "NVFP4": 0.5, "MXFP4": 0.5, "NF4": 0.5,
|
||||
"INT4": 0.5, "INT8": 1.0, "W4A16": 0.5, "W8A8": 1.0, "W8A16": 1.0,
|
||||
"Q8_0": 1.0, "Q6_K": 0.75, "Q5_K_M": 0.625,
|
||||
"Q4_K_M": 0.5, "Q4_0": 0.5, "Q3_K_M": 0.375, "Q2_K": 0.25,
|
||||
"AWQ-4bit": 0.5, "AWQ-8bit": 1.0,
|
||||
@@ -44,8 +52,55 @@ QUANT_BYTES_PER_PARAM = {
|
||||
"mlx-4bit": 0.5, "mlx-8bit": 1.0, "mlx-6bit": 0.75,
|
||||
}
|
||||
|
||||
# Pre-quantized formats that should NOT go through the GGUF quant hierarchy
|
||||
PREQUANTIZED_PREFIXES = ("AWQ-", "GPTQ-", "mlx-", "FP8", "INT8", "NVFP4")
|
||||
# Pre-quantized formats that should NOT go through the GGUF quant hierarchy.
|
||||
# These are native HF/vLLM-style repos, not llama.cpp GGUF quant tiers.
|
||||
PREQUANTIZED_PREFIXES = (
|
||||
"AWQ-", "GPTQ-", "mlx-", "FP8", "FP4", "NVFP4", "MXFP4", "NF4",
|
||||
"INT4", "INT8", "W4A16", "W8A8", "W8A16",
|
||||
)
|
||||
|
||||
|
||||
def infer_quantization_from_name(name):
|
||||
n = (name or "").lower()
|
||||
if "nvfp4" in n:
|
||||
return "NVFP4"
|
||||
if "mxfp4" in n:
|
||||
return "MXFP4"
|
||||
if re.search(r"(^|[-_/])nf4($|[-_/])", n):
|
||||
return "NF4"
|
||||
if re.search(r"(^|[-_/])fp4($|[-_/])", n):
|
||||
return "FP4"
|
||||
if re.search(r"(^|[-_/])w4a16($|[-_/])", n):
|
||||
return "W4A16"
|
||||
if re.search(r"(^|[-_/])w8a8($|[-_/])", n):
|
||||
return "W8A8"
|
||||
if re.search(r"(^|[-_/])w8a16($|[-_/])", n):
|
||||
return "W8A16"
|
||||
is8 = "8bit" in n or "8-bit" in n or "int8" in n
|
||||
if "awq" in n:
|
||||
return "AWQ-8bit" if is8 else "AWQ-4bit"
|
||||
if "gptq" in n:
|
||||
return "GPTQ-Int8" if is8 else "GPTQ-Int4"
|
||||
if "mlx" in n:
|
||||
if "6bit" in n:
|
||||
return "mlx-6bit"
|
||||
return "mlx-8bit" if is8 else "mlx-4bit"
|
||||
if "fp8" in n:
|
||||
return "FP8"
|
||||
if "int4" in n or "4bit" in n or "4-bit" in n:
|
||||
return "INT4"
|
||||
if "int8" in n or "8bit" in n or "8-bit" in n:
|
||||
return "INT8"
|
||||
return ""
|
||||
|
||||
|
||||
def _normalize_model_entry(model):
|
||||
if not isinstance(model, dict):
|
||||
return model
|
||||
inferred = infer_quantization_from_name(model.get("name", ""))
|
||||
if inferred and (model.get("quantization") in (None, "", "Q4_K_M") or model.get("_discovered")):
|
||||
model["quantization"] = inferred
|
||||
return model
|
||||
|
||||
|
||||
def is_prequantized(model):
|
||||
@@ -72,7 +127,13 @@ def params_b(model):
|
||||
pc = pc.strip().upper()
|
||||
m = re.match(r"^([\d.]+)\s*([BKMGT]?)$", pc)
|
||||
if m:
|
||||
val = float(m.group(1))
|
||||
try:
|
||||
val = float(m.group(1))
|
||||
except ValueError:
|
||||
# Malformed count like "1.5.3B" — [\d.]+ matches but float()
|
||||
# rejects it. One bad catalog row must not abort the whole
|
||||
# ranking pass, so treat it as unknown size.
|
||||
return 0.0
|
||||
suffix = m.group(2)
|
||||
if suffix == "B":
|
||||
return val
|
||||
@@ -180,7 +241,7 @@ def get_models():
|
||||
data_path = os.path.join(os.path.dirname(__file__), "data", "hf_models.json")
|
||||
try:
|
||||
with open(data_path, encoding="utf-8") as f:
|
||||
_models_cache = json.load(f)
|
||||
_models_cache = [_normalize_model_entry(m) for m in json.load(f)]
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
_models_cache = []
|
||||
return _models_cache
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
"""Compute intelligent llama.cpp serve profiles from detected hardware.
|
||||
|
||||
Given a system (VRAM/RAM/arch) and a model, produce 1-4 ready-to-launch
|
||||
profiles — Quality / Balanced / Speed — with concrete llama.cpp flags
|
||||
(n_gpu_layers, n_cpu_moe, cache-type, context). This turns the by-hand tuning
|
||||
(how many MoE layers fit on the GPU, when to spend VRAM on a q8 KV cache vs more
|
||||
context, how much headroom to leave for a vision encoder) into a formula.
|
||||
|
||||
Pure/deterministic — no benchmarking, no I/O. Reuses the same VRAM math as
|
||||
fit.py/models.py so "what the Cookbook recommends" and "what it serves" agree.
|
||||
|
||||
NOTE: token/s figures are NOT computed here — real speed on partial-offload MoE
|
||||
is CPU-bound and not reliably predictable from specs. The UI labels profiles by
|
||||
their tradeoff (Quality/Balanced/Speed), and the VRAM fit (the part that decides
|
||||
whether it even loads) is what's computed from real numbers.
|
||||
"""
|
||||
|
||||
from services.hwfit.models import (
|
||||
QUANT_BPP,
|
||||
params_b,
|
||||
_active_params_b,
|
||||
is_prequantized,
|
||||
)
|
||||
|
||||
# GGUF KV-cache cost per token, in bytes-per-active-billion-param, by cache type.
|
||||
# q4_0 is ~half of q8_0 is ~half of f16. The 8e-6 base in estimate_memory_gb is
|
||||
# the q8_0-ish figure; scale from there.
|
||||
_KV_FACTOR = {"q4_0": 0.5, "q8_0": 1.0, "f16": 2.0}
|
||||
|
||||
# Quant ladder from highest quality/size down. A profile that wants "best quant
|
||||
# that fits fully on GPU" walks this until one fits.
|
||||
_QUANT_LADDER = ["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"]
|
||||
|
||||
|
||||
def _weights_gb(model, quant, fixed_gb=None):
|
||||
"""VRAM for the full weights. When fixed_gb is given (serving a specific GGUF
|
||||
file already on disk), use its real size — the quant is whatever the file is,
|
||||
not something we get to pick."""
|
||||
if fixed_gb and fixed_gb > 0:
|
||||
return float(fixed_gb)
|
||||
return params_b(model) * QUANT_BPP.get(quant, 0.58)
|
||||
|
||||
|
||||
def _kv_gb(model, ctx, kv_type):
|
||||
"""KV-cache VRAM at a context length and cache type."""
|
||||
kv_params = _active_params_b(model)
|
||||
return 0.000008 * kv_params * ctx * _KV_FACTOR.get(kv_type, 1.0)
|
||||
|
||||
|
||||
def _n_layers(model):
|
||||
"""Best-effort total transformer block count (for n-cpu-moe math)."""
|
||||
for k in ("num_hidden_layers", "n_layers", "num_layers", "block_count"):
|
||||
v = model.get(k)
|
||||
if isinstance(v, (int, float)) and v > 0:
|
||||
return int(v)
|
||||
# Fallback heuristic by size — most MoE/dense LLMs land 28-64 layers.
|
||||
pb = params_b(model)
|
||||
if pb >= 60:
|
||||
return 64
|
||||
if pb >= 25:
|
||||
return 48
|
||||
if pb >= 12:
|
||||
return 40
|
||||
return 32
|
||||
|
||||
|
||||
def _cpu_moe_for_budget(model, quant, kv_gb, vram_budget_gb, fixed_gb=None):
|
||||
"""How many MoE layers must move to CPU so weights+KV fit vram_budget_gb.
|
||||
|
||||
Returns (n_cpu_moe, fits_fully). When the model already fits, n_cpu_moe=0.
|
||||
Each offloaded layer frees roughly weights/n_layers of VRAM. We only model
|
||||
this for MoE (where --n-cpu-moe applies); dense models just report whether
|
||||
they fit at the given n_gpu_layers=999.
|
||||
"""
|
||||
weights = _weights_gb(model, quant, fixed_gb)
|
||||
needed = weights + kv_gb + 0.6 # +0.6 GB runtime/compute buffers
|
||||
if needed <= vram_budget_gb:
|
||||
return 0, True
|
||||
if not model.get("is_moe"):
|
||||
# Dense: no per-expert offload knob; either it fits or it spills via -ngl.
|
||||
return 0, False
|
||||
layers = _n_layers(model)
|
||||
per_layer = weights / max(layers, 1)
|
||||
overflow = needed - vram_budget_gb
|
||||
import math
|
||||
n = math.ceil(overflow / max(per_layer, 1e-6))
|
||||
n = max(0, min(n, layers)) # clamp
|
||||
return n, False
|
||||
|
||||
|
||||
def compute_serve_profiles(system, model, serve_weights_gb=None, serve_quant=None):
|
||||
"""Return a list of profile dicts for llama.cpp serving of `model` on `system`.
|
||||
|
||||
Each profile: {key, label, quant, n_gpu_layers, n_cpu_moe, cache_type, ctx,
|
||||
est_vram_gb, fits, note}. Empty list if no GGUF path makes
|
||||
sense (caller should fall back to manual flags).
|
||||
|
||||
DOWNLOAD mode (default): the quant isn't chosen yet, so profiles vary it
|
||||
(Quality=Q6, Balanced=Q4, Speed=Q2…) to show download options.
|
||||
|
||||
SERVE mode (serve_weights_gb set): a specific GGUF file already exists on
|
||||
disk — its quant is FIXED. Profiles then keep that quant/size and differ only
|
||||
in the actual serving knobs (n_cpu_moe, KV-cache type, context). serve_quant
|
||||
is the file's quant label (e.g. "Q4_K_M") just for display.
|
||||
"""
|
||||
vram = float(system.get("gpu_vram_gb") or 0)
|
||||
if vram <= 0:
|
||||
return []
|
||||
|
||||
serve_mode = bool(serve_weights_gb and serve_weights_gb > 0)
|
||||
|
||||
# Never propose more context than the model was trained for — asking llama.cpp
|
||||
# for ctx > n_ctx_train triggers a "training context overflow" and, with a
|
||||
# quantized KV cache, an oversized allocation that can crash the GPU
|
||||
# (radv/amdgpu ErrorDeviceLost). Cap every profile at the model's real limit.
|
||||
model_ctx_max = 0
|
||||
for k in ("context_length", "max_position_embeddings", "n_ctx_train", "context"):
|
||||
v = model.get(k)
|
||||
if isinstance(v, (int, float)) and v > 0:
|
||||
model_ctx_max = int(v)
|
||||
break
|
||||
if model_ctx_max <= 0:
|
||||
model_ctx_max = 131072 # conservative default when the catalog omits it
|
||||
|
||||
# Vision models need headroom for the image encoder (~1 GB on top of weights).
|
||||
is_vision = bool(
|
||||
model.get("is_multimodal") or model.get("vision") or model.get("mmproj")
|
||||
or "vl" in str(model.get("name", "")).lower()
|
||||
)
|
||||
headroom = 1.1 if is_vision else 0.4
|
||||
budget = max(vram - headroom, 1.0)
|
||||
|
||||
# Prequantized (AWQ/GPTQ/FP8) served via GGUF fallback use a fixed ~Q4 quant;
|
||||
# GGUF models can pick their quant. Pick a sensible per-profile quant.
|
||||
fixed_quant = model.get("quantization") if is_prequantized(model) else None
|
||||
|
||||
is_moe = bool(model.get("is_moe"))
|
||||
|
||||
def _pick_quant(prefer, require_full_fit):
|
||||
"""Choose a quant for a profile.
|
||||
|
||||
- fixed_quant (AWQ/GPTQ/FP8 served via GGUF): always that.
|
||||
- require_full_fit=True (Speed): walk DOWN from `prefer` to the best quant
|
||||
whose weights fit fully on the GPU (no offload) — fastest.
|
||||
- require_full_fit=False (Quality on MoE): keep `prefer` even if it must
|
||||
offload experts to CPU; that's the whole point of n-cpu-moe on a card
|
||||
too small to hold the weights. For dense models we can't offload
|
||||
per-expert, so fall back to the largest fully-fitting quant.
|
||||
"""
|
||||
if fixed_quant:
|
||||
return fixed_quant
|
||||
start = _QUANT_LADDER.index(prefer) if prefer in _QUANT_LADDER else 3
|
||||
if require_full_fit or not is_moe:
|
||||
for q in _QUANT_LADDER[start:]:
|
||||
if _weights_gb(model, q) + 0.6 <= budget:
|
||||
return q
|
||||
return _QUANT_LADDER[-1]
|
||||
# MoE quality: keep the preferred (big) quant; offload handles overflow.
|
||||
return prefer
|
||||
|
||||
if serve_mode:
|
||||
# Fixed file on disk — quant can't change. Vary only the serving knobs.
|
||||
fq = serve_quant or model.get("quantization") or "GGUF"
|
||||
specs = [
|
||||
# key, label, prefer_quant, full_fit, kv_type, ctx, note
|
||||
("quality", "Quality", fq, False, "q8_0", 131072,
|
||||
"Sharp q8 KV cache + full context. Best long-context accuracy; offloads MoE layers to CPU if needed."),
|
||||
("balanced", "Balanced", fq, False, "q4_0", 131072,
|
||||
"Compact q4 KV at full context — good speed/quality mix."),
|
||||
("speed", "Speed", fq, False, "q4_0", 32768,
|
||||
"Trimmed context + light KV for the fastest tokens/s."),
|
||||
]
|
||||
else:
|
||||
specs = [
|
||||
# key, label, prefer_quant, full_fit, kv_type, ctx, note
|
||||
("quality", "Quality", "Q6_K", False, "q8_0", 131072,
|
||||
"Biggest quant + sharp q8 KV cache. Best answers; offloads MoE layers to CPU if needed."),
|
||||
("balanced", "Balanced", "Q4_K_M", False, "q4_0", 131072,
|
||||
"Q4 weights + compact q4 KV. Good speed/quality mix at full context."),
|
||||
("speed", "Speed", "Q4_K_M", True, "q4_0", 32768,
|
||||
"Smallest offload + trimmed context for the fastest tokens/s."),
|
||||
]
|
||||
|
||||
profiles = []
|
||||
for key, label, prefer_q, full_fit, kv_type, ctx, note in specs:
|
||||
# In serve mode the quant is fixed (the file's); in download mode we pick.
|
||||
quant = prefer_q if serve_mode else _pick_quant(prefer_q, full_fit)
|
||||
# Shrink context if even the chosen KV won't fit alongside weights.
|
||||
# Start from the smaller of the profile's target and the model's limit.
|
||||
cur_ctx = min(ctx, model_ctx_max)
|
||||
while cur_ctx >= 8192:
|
||||
kv = _kv_gb(model, cur_ctx, kv_type)
|
||||
n_cpu_moe, fits = _cpu_moe_for_budget(model, quant, kv, budget, fixed_gb=serve_weights_gb)
|
||||
est = _weights_gb(model, quant, serve_weights_gb) + kv + 0.6
|
||||
# If a non-MoE model can't fit even fully offloaded, try less context.
|
||||
if model.get("is_moe") or fits or cur_ctx <= 8192:
|
||||
profiles.append({
|
||||
"key": key,
|
||||
"label": label,
|
||||
"quant": quant,
|
||||
"n_gpu_layers": 999,
|
||||
"n_cpu_moe": n_cpu_moe,
|
||||
"cache_type": kv_type,
|
||||
"ctx": cur_ctx,
|
||||
# When experts offload, GPU-resident VRAM tops out at the
|
||||
# budget (weights beyond it live in system RAM), so cap the
|
||||
# estimate at `budget`, not the full card — this also leaves
|
||||
# the vision-encoder headroom visible in the number.
|
||||
"est_vram_gb": round(min(est, budget), 1),
|
||||
# For MoE we treat it as fitting via offload; report whether
|
||||
# it fit WITHOUT offload as the "clean" flag.
|
||||
"fits": fits or bool(model.get("is_moe")),
|
||||
"offloads": n_cpu_moe > 0,
|
||||
"note": note,
|
||||
})
|
||||
break
|
||||
cur_ctx //= 2
|
||||
|
||||
# De-dupe identical profiles (e.g. tiny model where all three collapse to the
|
||||
# same all-GPU config) — keep the first/highest-quality label.
|
||||
seen = set()
|
||||
deduped = []
|
||||
for p in profiles:
|
||||
sig = (p["quant"], p["n_cpu_moe"], p["cache_type"], p["ctx"])
|
||||
if sig in seen:
|
||||
continue
|
||||
seen.add(sig)
|
||||
deduped.append(p)
|
||||
return deduped
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def tokenize(text: str) -> List[str]:
|
||||
"""Simple tokenizer that splits on whitespace and removes punctuation."""
|
||||
return [word.strip('.,!?";') for word in text.split()]
|
||||
return [cleaned for word in text.split() if (cleaned := word.strip('.,!?";'))]
|
||||
|
||||
def get_text_similarity(text1: str, text2: str) -> float:
|
||||
"""Calculate Jaccard similarity between two texts."""
|
||||
@@ -59,14 +59,18 @@ class MemoryManager:
|
||||
line = line.strip()
|
||||
# Look for bullet points or numbered lists that might contain memories
|
||||
if re.match(r'^[-*•]|\d+\.', line):
|
||||
# Extract the text after the bullet/number
|
||||
text_match = re.match(r'^[-*•]|\d+\.\s*(.*)', line)
|
||||
# Extract the text after the bullet/number. Group both
|
||||
# markers so the capture applies to either. The previous
|
||||
# `^[-*•]|\d+\.\s*(.*)` put the group on the numbered
|
||||
# branch only, so a bullet line matched with group(1)=None
|
||||
# and crashed on .strip().
|
||||
text_match = re.match(r'^(?:[-*•]|\d+\.)\s*(.*)', line)
|
||||
if text_match:
|
||||
text = text_match.group(1).strip()
|
||||
if text:
|
||||
memories.append({
|
||||
"text": text,
|
||||
"timestamp": int(datetime.now().timestamp()),
|
||||
"timestamp": int(time.time()),
|
||||
"session_id": session_id
|
||||
})
|
||||
# If we see a heading that suggests memories
|
||||
@@ -101,6 +105,7 @@ class MemoryManager:
|
||||
def ensure_file_exists(self):
|
||||
"""Create memory file if it doesn't exist."""
|
||||
if not os.path.exists(self.memory_file):
|
||||
os.makedirs(os.path.dirname(self.memory_file), exist_ok=True)
|
||||
with open(self.memory_file, 'w', encoding='utf-8') as f:
|
||||
json.dump([], f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ def _fingerprint_entries(entries) -> str:
|
||||
only on id+text+category. Any add/edit/delete invalidates it."""
|
||||
items = sorted(
|
||||
(str(e.get("id", "")), e.get("text", ""), e.get("category", ""))
|
||||
for e in entries
|
||||
for e in _memory_dicts(entries)
|
||||
)
|
||||
h = hashlib.sha256()
|
||||
for triple in items:
|
||||
@@ -42,6 +42,12 @@ def _fingerprint_entries(entries) -> str:
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _memory_dicts(entries):
|
||||
for entry in entries or []:
|
||||
if isinstance(entry, dict):
|
||||
yield entry
|
||||
|
||||
|
||||
def _load_tidy_state(memory_manager) -> dict:
|
||||
path = _tidy_state_path(memory_manager)
|
||||
try:
|
||||
@@ -211,7 +217,7 @@ def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) ->
|
||||
new_tokens = set(new_text.lower().split())
|
||||
if not new_tokens:
|
||||
return False
|
||||
for entry in existing:
|
||||
for entry in _memory_dicts(existing):
|
||||
old_tokens = set(entry.get("text", "").lower().split())
|
||||
if not old_tokens:
|
||||
continue
|
||||
@@ -235,6 +241,10 @@ async def extract_and_store(
|
||||
Designed to run as a background task (asyncio.create_task).
|
||||
Errors are logged, never raised.
|
||||
"""
|
||||
if not endpoint_url or not model:
|
||||
logger.debug("[memory-extract] No model or URL provided, skipping")
|
||||
return
|
||||
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
@@ -245,11 +255,30 @@ async def extract_and_store(
|
||||
if len(recent) < 2:
|
||||
return # Need at least a user message and assistant response
|
||||
|
||||
fallback_facts = _fallback_memory_candidates(recent)
|
||||
# Strip media (images/audio) from messages — background memory extraction
|
||||
# only needs the text. The VL-generated descriptions are already in the
|
||||
# text content of the messages. This avoids sending image tokens to
|
||||
# non-vision models and prevents accidental "vision grounding" triggers.
|
||||
stripped_recent = []
|
||||
for msg in recent:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
# Filter out multimodal blocks that aren't text
|
||||
text_only = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
if not text_only and content:
|
||||
continue
|
||||
content = text_only
|
||||
stripped_recent.append({"role": role, "content": content})
|
||||
|
||||
if not stripped_recent:
|
||||
return
|
||||
|
||||
fallback_facts = _fallback_memory_candidates(stripped_recent)
|
||||
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
||||
] + recent
|
||||
] + stripped_recent
|
||||
|
||||
facts = []
|
||||
try:
|
||||
@@ -303,9 +332,18 @@ async def extract_and_store(
|
||||
if not fact_text or len(fact_text) < 5:
|
||||
continue
|
||||
|
||||
# Dedup: check vector similarity first (fast), then exact text match
|
||||
# Dedup: check vector similarity first (fast), then exact text match.
|
||||
# A runtime embedding/ChromaDB failure (backend OOM, model evicted,
|
||||
# remote endpoint down) must not abort the whole batch — fall through
|
||||
# to the text/fuzzy dedup below instead of losing every validated
|
||||
# fact extracted this session. (`.healthy` is only set at init, so
|
||||
# it does not catch failures that develop later.)
|
||||
if memory_vector and memory_vector.healthy:
|
||||
existing_id = memory_vector.find_similar(fact_text, threshold=0.72)
|
||||
try:
|
||||
existing_id = memory_vector.find_similar(fact_text, threshold=0.72)
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory dedup (vector) unavailable, using text fallback: {e}")
|
||||
existing_id = None
|
||||
if existing_id:
|
||||
logger.debug(f"Memory dedup (vector): '{fact_text[:50]}' matches {existing_id}")
|
||||
continue
|
||||
@@ -330,9 +368,14 @@ async def extract_and_store(
|
||||
|
||||
existing.append(entry)
|
||||
|
||||
# Add to vector index
|
||||
# Add to vector index. The JSON store (saved below) is the source of
|
||||
# truth and the keyword path can still retrieve this entry, so a vector
|
||||
# write failure must not drop the fact or abort the remaining batch.
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.add(entry["id"], fact_text)
|
||||
try:
|
||||
memory_vector.add(entry["id"], fact_text)
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory vector add failed for {entry['id']}: {e}")
|
||||
|
||||
added += 1
|
||||
|
||||
@@ -510,17 +553,20 @@ async def audit_memories(
|
||||
for e in all_entries:
|
||||
if e.get("owner") is None and e["id"] not in audited_ids and e["id"] not in {o["id"] for o in other_entries}:
|
||||
other_entries.append(e)
|
||||
memory_manager.save(final_entries + other_entries)
|
||||
saved_entries = final_entries + other_entries
|
||||
else:
|
||||
memory_manager.save(final_entries)
|
||||
saved_entries = final_entries
|
||||
memory_manager.save(saved_entries)
|
||||
logger.info(
|
||||
f"Memory audit complete: {before_count} -> {after_count} entries "
|
||||
f"({before_count - after_count} removed/merged)"
|
||||
)
|
||||
|
||||
# Rebuild vector index
|
||||
# Rebuild vector index from the full saved set, not just this owner's
|
||||
# slice — otherwise the shared collection is wiped of every other
|
||||
# owner's entries until they happen to run their own audit.
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.rebuild(final_entries)
|
||||
memory_vector.rebuild(saved_entries)
|
||||
|
||||
# Persist the post-tidy fingerprint so the next call short-circuits
|
||||
# if nothing has changed in the meantime.
|
||||
|
||||
@@ -103,6 +103,7 @@ class MemoryService:
|
||||
metadata=r.get("metadata", {}),
|
||||
)
|
||||
for r in results
|
||||
if isinstance(r, dict)
|
||||
]
|
||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
||||
|
||||
|
||||
@@ -48,6 +48,21 @@ MIN_CONFIDENCE = 0.6
|
||||
CONTEXT_WINDOW = 12
|
||||
|
||||
|
||||
def _skill_dicts(skills):
|
||||
for skill in skills or []:
|
||||
if isinstance(skill, dict):
|
||||
yield skill
|
||||
|
||||
|
||||
def _has_duplicate_title(skills, title: str) -> bool:
|
||||
wanted = title.lower()
|
||||
for skill in _skill_dicts(skills):
|
||||
existing = skill.get("title", "")
|
||||
if isinstance(existing, str) and existing.lower() == wanted:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def maybe_extract_skill(
|
||||
session,
|
||||
skills_manager,
|
||||
@@ -59,6 +74,10 @@ async def maybe_extract_skill(
|
||||
owner: Optional[str] = None,
|
||||
):
|
||||
"""Extract a skill if the agent run was complex enough."""
|
||||
if not model:
|
||||
logger.debug("[skill-extract] No model provided, skipping")
|
||||
return None
|
||||
|
||||
# Quiet by default; flip to DEBUG when chasing extractor issues.
|
||||
logger.debug(
|
||||
"[skill-extract] start: rounds=%d tools=%d model=%s owner=%s",
|
||||
@@ -78,9 +97,23 @@ async def maybe_extract_skill(
|
||||
logger.debug("[skill-extract] no recent messages, skipping")
|
||||
return None
|
||||
|
||||
# Strip media (images/audio) from messages
|
||||
stripped_recent = []
|
||||
for msg in recent:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
text_only = [b for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
if not text_only and content:
|
||||
continue
|
||||
content = text_only
|
||||
stripped_recent.append({"role": msg.get("role"), "content": content})
|
||||
|
||||
if not stripped_recent:
|
||||
return None
|
||||
|
||||
# Build conversation summary for extraction
|
||||
conv_lines = []
|
||||
for msg in recent:
|
||||
for msg in stripped_recent:
|
||||
role = msg.get("role", "?")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
@@ -173,10 +206,9 @@ async def maybe_extract_skill(
|
||||
|
||||
# Check for duplicate skills
|
||||
existing = skills_manager.load(owner=owner)
|
||||
for sk in existing:
|
||||
if sk.get("title", "").lower() == title.lower():
|
||||
logger.debug("[skill-extract] '%s' already exists — dropped as duplicate", title)
|
||||
return None
|
||||
if _has_duplicate_title(existing, title):
|
||||
logger.debug("[skill-extract] '%s' already exists — dropped as duplicate", title)
|
||||
return None
|
||||
|
||||
entry = skills_manager.add_skill(
|
||||
title=title,
|
||||
|
||||
+53
-24
@@ -6,8 +6,8 @@ YAML frontmatter and a structured markdown body (When to Use / Procedure /
|
||||
Pitfalls / Verification). See `skill_format.py` for the format.
|
||||
|
||||
Usage counters (`uses`, `last_used`) live in a sidecar
|
||||
`data/skills/_usage.json` keyed by skill name so the SKILL.md content
|
||||
doesn't churn on every retrieval.
|
||||
`data/skills/_usage.json` keyed by owner plus skill name so the SKILL.md
|
||||
content doesn't churn on every retrieval.
|
||||
|
||||
Ownership: skills declare `owner: <username>` in frontmatter. Single-user
|
||||
deployments can leave that blank.
|
||||
@@ -105,14 +105,29 @@ class SkillsManager:
|
||||
json.dump(usage, f, indent=2)
|
||||
os.replace(tmp, self.usage_file)
|
||||
|
||||
@staticmethod
|
||||
def _usage_key(name: str, owner: Optional[str] = None) -> str:
|
||||
# Skill names are not globally unique once multiple owners are present.
|
||||
# Keep the usage sidecar keyed the same way the skill file is scoped.
|
||||
return f"{owner}::{name}" if owner else name
|
||||
|
||||
def _usage_entry(self, usage: Dict[str, Dict], name: str, owner: Optional[str] = None) -> Dict:
|
||||
key = self._usage_key(name, owner)
|
||||
entry = usage.get(key)
|
||||
if isinstance(entry, dict):
|
||||
return entry
|
||||
return {}
|
||||
|
||||
def set_audit(self, name: str, verdict: str, by_teacher: bool = False,
|
||||
worker_model: str = "", teacher_model: str = "") -> None:
|
||||
worker_model: str = "", teacher_model: str = "",
|
||||
owner: Optional[str] = None) -> None:
|
||||
"""Record the last test/audit result for a skill in the usage sidecar
|
||||
(so it surfaces in load() without touching SKILL.md). Drives the
|
||||
'verified' check + teacher mark on the card."""
|
||||
import time as _t
|
||||
usage = self._load_usage()
|
||||
e = usage.setdefault(name, {"uses": 0, "last_used": None})
|
||||
key = self._usage_key(name, owner)
|
||||
e = usage.setdefault(key, {"uses": 0, "last_used": None})
|
||||
e["audit_verdict"] = verdict
|
||||
e["audit_by_teacher"] = bool(by_teacher)
|
||||
if worker_model:
|
||||
@@ -123,11 +138,13 @@ class SkillsManager:
|
||||
self._save_usage(usage)
|
||||
|
||||
def set_necessity(self, name: str, necessary: bool,
|
||||
redundant_with=None, reason: str = "") -> None:
|
||||
redundant_with=None, reason: str = "",
|
||||
owner: Optional[str] = None) -> None:
|
||||
"""Record the advisory 'is this skill necessary?' judgment in the usage
|
||||
sidecar. Surfaced on the card as a flag; never acts on the skill."""
|
||||
usage = self._load_usage()
|
||||
e = usage.setdefault(name, {"uses": 0, "last_used": None})
|
||||
key = self._usage_key(name, owner)
|
||||
e = usage.setdefault(key, {"uses": 0, "last_used": None})
|
||||
e["necessity"] = {
|
||||
"necessary": bool(necessary),
|
||||
"redundant_with": list(redundant_with or []),
|
||||
@@ -207,7 +224,7 @@ class SkillsManager:
|
||||
if not sk:
|
||||
continue
|
||||
d = sk.to_dict()
|
||||
u = usage.get(sk.name) or {}
|
||||
u = self._usage_entry(usage, sk.name, sk.owner)
|
||||
d["uses"] = int(u.get("uses", 0))
|
||||
d["last_used"] = u.get("last_used")
|
||||
d["audit_verdict"] = u.get("audit_verdict")
|
||||
@@ -308,6 +325,7 @@ class SkillsManager:
|
||||
# never auto-skipped — a human asked for it. The every-X AI audit
|
||||
# handles the fuzzier near-duplicates this cheap check won't catch.
|
||||
_all = self.load_all()
|
||||
_dedup_pool = _all if owner is None else [s for s in _all if s.get("owner") == owner]
|
||||
if source != "user":
|
||||
cand = _tokenize(" ".join([
|
||||
nm, (description or title or ""),
|
||||
@@ -315,7 +333,7 @@ class SkillsManager:
|
||||
" ".join(procedure if procedure is not None else (steps or [])),
|
||||
]))
|
||||
if cand:
|
||||
for s in _all:
|
||||
for s in _dedup_pool:
|
||||
ex = _tokenize(" ".join([
|
||||
s.get("name", ""), s.get("description", ""),
|
||||
s.get("when_to_use", ""),
|
||||
@@ -326,7 +344,7 @@ class SkillsManager:
|
||||
# existing skill's usage and return it so the caller
|
||||
# knows it already exists.
|
||||
try:
|
||||
self.record_use(s["name"])
|
||||
self.record_use(s["name"], owner=s.get("owner"))
|
||||
except Exception:
|
||||
pass
|
||||
return {**s, "_deduped": True, "_duplicate_of": s.get("name")}
|
||||
@@ -428,8 +446,9 @@ class SkillsManager:
|
||||
os.rename(old_dir, new_dir)
|
||||
# Also rename usage key
|
||||
usage = self._load_usage()
|
||||
if skill_id in usage:
|
||||
usage[sk.name] = usage.pop(skill_id)
|
||||
old_usage_key = self._usage_key(skill_id, sk.owner)
|
||||
if old_usage_key in usage:
|
||||
usage[self._usage_key(sk.name, sk.owner)] = usage.pop(old_usage_key)
|
||||
self._save_usage(usage)
|
||||
self._write_skill(sk)
|
||||
return True
|
||||
@@ -455,15 +474,17 @@ class SkillsManager:
|
||||
logger.warning(f"Failed to remove skill dir {skill_dir}: {e}")
|
||||
return False
|
||||
usage = self._load_usage()
|
||||
if skill_id in usage:
|
||||
del usage[skill_id]
|
||||
usage_key = self._usage_key(skill_id, sk.owner)
|
||||
if usage_key in usage:
|
||||
del usage[usage_key]
|
||||
self._save_usage(usage)
|
||||
return True
|
||||
return False
|
||||
|
||||
def record_use(self, skill_id: str) -> None:
|
||||
def record_use(self, skill_id: str, owner: Optional[str] = None) -> None:
|
||||
usage = self._load_usage()
|
||||
entry = usage.setdefault(skill_id, {"uses": 0, "last_used": None})
|
||||
key = self._usage_key(skill_id, owner)
|
||||
entry = usage.setdefault(key, {"uses": 0, "last_used": None})
|
||||
entry["uses"] = int(entry.get("uses", 0)) + 1
|
||||
entry["last_used"] = int(time.time())
|
||||
self._save_usage(usage)
|
||||
@@ -472,24 +493,29 @@ class SkillsManager:
|
||||
# Reading a single skill (used by the skill_view tool)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def read_skill_md(self, name: str) -> Optional[str]:
|
||||
def read_skill_md(self, name: str, owner: Optional[str] = None) -> Optional[str]:
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if sk and sk.name == name:
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
if not sk or sk.name != name:
|
||||
continue
|
||||
if (sk.owner or "") != (owner or ""):
|
||||
continue
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def read_skill_reference(self, name: str, ref_path: str) -> Optional[str]:
|
||||
def read_skill_reference(self, name: str, ref_path: str, owner: Optional[str] = None) -> Optional[str]:
|
||||
"""Read a sub-file under the skill's directory (references/, etc).
|
||||
Refuses path traversal."""
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk or sk.name != name:
|
||||
continue
|
||||
if (sk.owner or "") != (owner or ""):
|
||||
continue
|
||||
base = os.path.realpath(os.path.dirname(path))
|
||||
target = os.path.realpath(os.path.join(base, ref_path))
|
||||
if os.path.commonpath([base, target]) != base or target == os.path.dirname(path):
|
||||
@@ -624,7 +650,10 @@ class SkillsManager:
|
||||
])
|
||||
score = _jaccard(query_tokens, _tokenize(text))
|
||||
for tag in sk.get("tags", []) or []:
|
||||
if tag and tag in query.lower():
|
||||
# Match tags as whole tokens, not substrings: `tag in query`
|
||||
# boosted e.g. a "ai" tag for any query containing "email".
|
||||
tag_tokens = _tokenize(tag)
|
||||
if tag_tokens and tag_tokens <= query_tokens:
|
||||
score = max(score, 0.3) * 1.3
|
||||
if query.lower() in (sk.get("description") or "").lower():
|
||||
score = max(score, 0.6)
|
||||
|
||||
@@ -14,6 +14,8 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.research_utils import is_low_quality
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RESEARCH_DATA_DIR = Path("data/deep_research")
|
||||
@@ -179,13 +181,14 @@ class ResearchHandler:
|
||||
|
||||
@staticmethod
|
||||
def _extract_sources(findings: list) -> list:
|
||||
"""Extract deduplicated [{url, title}] from findings."""
|
||||
"""Extract deduplicated [{url, title}] from findings, filtering low-quality ones."""
|
||||
seen = set()
|
||||
sources = []
|
||||
for f in findings:
|
||||
url = f.get("url", "")
|
||||
title = f.get("title", "") or url
|
||||
if url and url not in seen:
|
||||
summary = f.get("summary", "") or f.get("evidence", "")
|
||||
if url and url not in seen and not is_low_quality(summary):
|
||||
seen.add(url)
|
||||
sources.append({"url": url, "title": title})
|
||||
return sources
|
||||
@@ -346,7 +349,8 @@ class ResearchHandler:
|
||||
for f in findings:
|
||||
url = f.get("url", "")
|
||||
title = f.get("title", "") or url
|
||||
if url and url not in seen_urls:
|
||||
summary = f.get("summary", "") or f.get("evidence", "")
|
||||
if url and url not in seen_urls and not is_low_quality(summary):
|
||||
seen_urls.add(url)
|
||||
source_lines.append(f"- [{title}]({url})")
|
||||
if source_lines:
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
# services/research/service.py
|
||||
"""Research service — deep research with LLM-in-the-loop."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
from .research_handler import ResearchHandler
|
||||
|
||||
# Markdown source links emitted by ResearchHandler._format_research_report,
|
||||
# e.g. "- [Some Title](https://example.com/page)".
|
||||
_SOURCE_LINK_RE = re.compile(r"^\s*-\s*\[(?P<title>[^\]]*)\]\((?P<url>[^)]+)\)\s*$")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResearchSource:
|
||||
@@ -75,26 +80,71 @@ class ResearchService:
|
||||
|
||||
duration = time.time() - start
|
||||
|
||||
# Parse result into structured format
|
||||
sources = [
|
||||
ResearchSource(
|
||||
url=s.get("url", ""),
|
||||
title=s.get("title", ""),
|
||||
snippet=s.get("snippet", ""),
|
||||
relevance=s.get("relevance", 0.0),
|
||||
# call_research_service returns a formatted markdown report string
|
||||
# (see ResearchHandler.call_research_service -> _format_research_report),
|
||||
# not a dict. Treat it as such; tolerate an unexpected dict/None defensively.
|
||||
if isinstance(result, dict):
|
||||
sources = [
|
||||
ResearchSource(
|
||||
url=s.get("url", ""),
|
||||
title=s.get("title", ""),
|
||||
snippet=s.get("snippet", ""),
|
||||
relevance=s.get("relevance", 0.0),
|
||||
)
|
||||
for s in result.get("sources", [])
|
||||
if isinstance(s, dict)
|
||||
]
|
||||
return ResearchResult(
|
||||
query=topic,
|
||||
summary=result.get("summary", result.get("answer", "")),
|
||||
sources=sources,
|
||||
sections=result.get("sections", []),
|
||||
tokens_used=result.get("tokens_used", 0),
|
||||
duration_seconds=duration,
|
||||
)
|
||||
for s in result.get("sources", [])
|
||||
]
|
||||
|
||||
report = result if isinstance(result, str) else ""
|
||||
return ResearchResult(
|
||||
query=topic,
|
||||
summary=result.get("summary", result.get("answer", "")),
|
||||
sources=sources,
|
||||
sections=result.get("sections", []),
|
||||
tokens_used=result.get("tokens_used", 0),
|
||||
summary=report,
|
||||
sources=self._parse_sources(report),
|
||||
duration_seconds=duration,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_sources(report: str) -> List[ResearchSource]:
|
||||
"""Extract sources from the markdown ### Sources section of a report.
|
||||
|
||||
ResearchHandler emits one ``- [title](url)`` link per deduplicated
|
||||
finding under a ``### Sources`` heading. Parse only that section so
|
||||
inline links elsewhere in the body are not mistaken for sources.
|
||||
"""
|
||||
if not report:
|
||||
return []
|
||||
sources: List[ResearchSource] = []
|
||||
seen = set()
|
||||
in_sources = False
|
||||
for line in report.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("###") or stripped.startswith("##"):
|
||||
in_sources = stripped.lower().lstrip("#").strip() == "sources"
|
||||
continue
|
||||
if not in_sources:
|
||||
continue
|
||||
match = _SOURCE_LINK_RE.match(line)
|
||||
if not match:
|
||||
continue
|
||||
url = match.group("url").strip()
|
||||
if not url or url in seen:
|
||||
continue
|
||||
seen.add(url)
|
||||
sources.append(
|
||||
# snippet is required on ResearchSource; markdown source links
|
||||
# carry no snippet, so default to empty (matches the dict path).
|
||||
ResearchSource(url=url, title=match.group("title").strip(), snippet="")
|
||||
)
|
||||
return sources
|
||||
|
||||
def start_background(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -45,32 +45,36 @@ class RateLimitError(SearchEngineError):
|
||||
# ----------------------------------------------------------------------
|
||||
# Analytics helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _default_analytics() -> Dict[str, Any]:
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
|
||||
|
||||
def _load_analytics() -> Dict[str, Any]:
|
||||
"""Load analytics data from the JSON file, creating defaults if missing."""
|
||||
if not ANALYTICS_FILE.exists():
|
||||
default = {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
default = _default_analytics()
|
||||
_save_analytics(default)
|
||||
return default
|
||||
try:
|
||||
with open(ANALYTICS_FILE, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
data = json.load(f)
|
||||
# Merge over defaults so a file written by an older schema (or a
|
||||
# partial write) still has every counter — _record_query indexes
|
||||
# these keys directly and would otherwise raise KeyError.
|
||||
merged = _default_analytics()
|
||||
if isinstance(data, dict):
|
||||
merged.update(data)
|
||||
return merged
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load analytics file: {e}")
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
return _default_analytics()
|
||||
|
||||
|
||||
def _save_analytics(data: Dict[str, Any]) -> None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
|
||||
|
||||
import copy
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
@@ -115,6 +116,28 @@ def _extract_meta(soup: BeautifulSoup) -> dict:
|
||||
return {"description": description, "keywords": keywords}
|
||||
|
||||
|
||||
def _extract_og_image(soup: BeautifulSoup) -> str:
|
||||
"""Extract the best representative image URL from meta tags.
|
||||
|
||||
Only returns absolute http(s) URLs -- skips relative paths and data URIs.
|
||||
"""
|
||||
candidates = []
|
||||
for prop in ("og:image", "og:image:url", "og:image:secure_url"):
|
||||
tag = soup.find("meta", attrs={"property": prop})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
tag = soup.find("meta", attrs={"name": "twitter:image"})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
tag = soup.find("meta", attrs={"name": "thumbnail"})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
for url in candidates:
|
||||
if url.startswith(("https://", "http://")) and not url.endswith((".svg", ".ico")):
|
||||
return url
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
|
||||
"""Return a list of lists, each inner list representing a <ul>/<ol>."""
|
||||
all_lists = []
|
||||
@@ -275,10 +298,12 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
title_tag = soup.find("title")
|
||||
title_text = title_tag.get_text(strip=True) if title_tag else ""
|
||||
meta_info = _extract_meta(soup)
|
||||
og_image = _extract_og_image(soup)
|
||||
js_rendered = _detect_js_frameworks(soup)
|
||||
js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""
|
||||
|
||||
# Main textual content (heuristic)
|
||||
# Main textual content (heuristic): prefer semantic / "content"-classed
|
||||
# containers to skip nav/footer/boilerplate; tuned for article pages.
|
||||
main_content = ""
|
||||
content_areas = soup.find_all(
|
||||
["main", "article", "section", "div"],
|
||||
@@ -287,12 +312,23 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
if content_areas:
|
||||
for area in content_areas[:3]:
|
||||
main_content += area.get_text(separator=" ", strip=True) + " "
|
||||
if not main_content:
|
||||
main_content = re.sub(r"\s+", " ", main_content).strip()
|
||||
|
||||
# If the heuristic finds only a tiny wrapper, fall back to body text with
|
||||
# obvious boilerplate stripped so UI/deep-research search results do not
|
||||
# look empty for app/landing pages.
|
||||
THIN_CONTENT_CHARS = 600
|
||||
if len(main_content) < THIN_CONTENT_CHARS:
|
||||
body = soup.find("body")
|
||||
if body:
|
||||
main_content = body.get_text(separator=" ", strip=True)
|
||||
|
||||
main_content = re.sub(r"\s+", " ", main_content).strip()[:8000]
|
||||
body_copy = copy.copy(body)
|
||||
for noise in body_copy.find_all(
|
||||
["script", "style", "noscript", "template", "nav", "header", "footer", "aside"]
|
||||
):
|
||||
noise.extract()
|
||||
body_text = re.sub(r"\s+", " ", body_copy.get_text(separator=" ", strip=True)).strip()
|
||||
if len(body_text) > len(main_content):
|
||||
main_content = body_text
|
||||
|
||||
result = {
|
||||
"url": url,
|
||||
@@ -303,6 +339,7 @@ def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) ->
|
||||
"code_blocks": _extract_code_blocks(soup),
|
||||
"meta_description": meta_info.get("description", ""),
|
||||
"meta_keywords": meta_info.get("keywords", ""),
|
||||
"og_image": og_image,
|
||||
"js_rendered": js_rendered,
|
||||
"js_message": js_message,
|
||||
"success": True,
|
||||
@@ -348,13 +385,18 @@ def get_tldr(text: str, max_sentences: int = 3) -> str:
|
||||
|
||||
def extract_quotes(text: str) -> List[str]:
|
||||
"""Return quoted excerpts that are at least 15 characters long."""
|
||||
return [m.group(1).strip() for m in re.finditer(r'["\']([^"\']{15,}?)["\']', text)]
|
||||
# Backreference the opening quote so the closing quote must match it —
|
||||
# otherwise `"text'` (open double, close single) is treated as a quote.
|
||||
return [m.group(2).strip() for m in re.finditer(r'(["\'])([^"\']{15,}?)\1', text)]
|
||||
|
||||
|
||||
def extract_statistics(text: str) -> List[str]:
|
||||
"""Find numbers, percentages, dates and simple measurements."""
|
||||
# Match a comma-grouped number (1,000,000) OR a plain digit run (50000) —
|
||||
# the old `\d{1,3}(?:,\d{3})*` matched only the first 3 digits of a
|
||||
# comma-less number, and the trailing `\b` dropped a closing `%`.
|
||||
pattern = re.compile(
|
||||
r"\b\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?\b",
|
||||
r"\b(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return [m.group(0).strip() for m in pattern.finditer(text)]
|
||||
|
||||
+54
-9
@@ -30,6 +30,7 @@ from .providers import (
|
||||
tavily_search,
|
||||
serper_search,
|
||||
_get_search_settings,
|
||||
_get_provider_key,
|
||||
_get_result_count,
|
||||
)
|
||||
from .content import (
|
||||
@@ -48,24 +49,48 @@ SEARCH_CONFIG: Dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def _is_secret_key(name: str) -> bool:
|
||||
"""True for config keys that hold a credential (e.g. ``brave_api_key``)."""
|
||||
return name.endswith(("_api_key", "_key", "_token", "_secret"))
|
||||
|
||||
|
||||
def get_search_config() -> Dict[str, Any]:
|
||||
"""Get current search configuration including active provider info."""
|
||||
"""Get current search configuration including active provider info.
|
||||
|
||||
Never returns stored API keys: callers — including the unauthenticated
|
||||
``GET /api/search/config`` route — only need key *presence* via
|
||||
``has_api_key``, not the secret itself (#1661).
|
||||
"""
|
||||
config = SEARCH_CONFIG.copy()
|
||||
settings = _get_search_settings()
|
||||
provider = settings.get("search_provider", "searxng")
|
||||
config["active_provider"] = provider
|
||||
config["has_api_key"] = bool((settings.get("search_api_key") or "").strip())
|
||||
config["has_api_key"] = bool(_get_provider_key(provider))
|
||||
config["result_count"] = _get_result_count()
|
||||
if provider == "searxng":
|
||||
from .providers import _get_search_instance
|
||||
config["search_url"] = _get_search_instance()
|
||||
return config
|
||||
# Strip any string-valued credential so secrets never reach the response;
|
||||
# the boolean has_api_key flag (presence only) is preserved.
|
||||
return {
|
||||
k: v for k, v in config.items()
|
||||
if not (isinstance(v, str) and _is_secret_key(k))
|
||||
}
|
||||
|
||||
|
||||
def update_search_config(api_key: str = None, **kwargs):
|
||||
"""Update search configuration (e.g. Brave API key)."""
|
||||
if api_key:
|
||||
SEARCH_CONFIG["brave_api_key"] = api_key
|
||||
"""Merge non-secret search config into SEARCH_CONFIG.
|
||||
|
||||
Provider API keys are intentionally NOT cached here. They are read on demand
|
||||
from settings/env via ``_get_provider_key`` (e.g. ``brave_search``), so the
|
||||
previous ``SEARCH_CONFIG["brave_api_key"] = api_key`` cache was never used
|
||||
for search and only leaked the decrypted key through ``get_search_config`` /
|
||||
``GET /api/search/config`` (#1661). ``api_key`` is accepted for backward
|
||||
compatibility but no longer stored.
|
||||
"""
|
||||
for k, v in kwargs.items():
|
||||
if not _is_secret_key(k):
|
||||
SEARCH_CONFIG[k] = v
|
||||
|
||||
|
||||
def _call_provider(provider_name: str, query: str, count: int, time_filter: str = None) -> List[dict]:
|
||||
@@ -203,7 +228,10 @@ def invalidate_search_cache(query: Optional[str] = None) -> None:
|
||||
search_cache_index.clear()
|
||||
logger.info("All search cache entries have been cleared.")
|
||||
else:
|
||||
cache_key = generate_cache_key(f"{query}|10|None")
|
||||
# Match the key the write path stores: searxng_search_results replaces
|
||||
# the caller's default count with the configured _get_result_count()
|
||||
# (default 5), so a hardcoded "|10|None" never matched a real entry.
|
||||
cache_key = generate_cache_key(f"{query}|{_get_result_count()}|None")
|
||||
cache_file = SEARCH_CACHE_DIR / f"{cache_key}.cache"
|
||||
if cache_file.exists():
|
||||
try:
|
||||
@@ -328,6 +356,12 @@ def comprehensive_web_search(
|
||||
for r in search_results if r.get("url")
|
||||
]
|
||||
|
||||
# Map each URL to its [i] number in the sources list so fetched content
|
||||
# blocks can be labeled with the SAME index the model cites.
|
||||
_url_index = {
|
||||
r["url"]: i for i, r in enumerate(search_results, 1) if r.get("url")
|
||||
}
|
||||
|
||||
# Fetch content in parallel
|
||||
fetched_content = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
@@ -340,6 +374,10 @@ def comprehensive_web_search(
|
||||
try:
|
||||
result = future.result()
|
||||
if result["success"] and result["content"] and len(result["content"]) >= min_content_length:
|
||||
# Remember which source this fetch belongs to: redirects
|
||||
# can change result["url"] and completion order is
|
||||
# arbitrary, so the block label cannot be recomputed later.
|
||||
result["source_index"] = _url_index.get(url)
|
||||
fetched_content.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception while fetching {url}: {str(e)}")
|
||||
@@ -380,8 +418,15 @@ def comprehensive_web_search(
|
||||
output_parts.append("FETCHED PAGE CONTENT:")
|
||||
output_parts.append("-" * 50)
|
||||
|
||||
for i, content in enumerate(fetched_content, 1):
|
||||
output_parts.append(f"\n[CONTENT {i}] From: {content['url']}")
|
||||
# Emit blocks in source order, numbered with the same [i] as the
|
||||
# sources list, so [CONTENT 2] really is content from source [2].
|
||||
# Before this, blocks were numbered 1..N in fetch COMPLETION order,
|
||||
# which matched neither the sources list nor each other run to run.
|
||||
fetched_content.sort(key=lambda c: c.get("source_index") or len(search_results) + 1)
|
||||
for content in fetched_content:
|
||||
_idx = content.get("source_index")
|
||||
_label = f"[CONTENT {_idx}]" if _idx else "[CONTENT]"
|
||||
output_parts.append(f"\n{_label} From: {content['url']}")
|
||||
output_parts.append(f"Title: {content['title']}")
|
||||
output_parts.append("-" * 30)
|
||||
|
||||
|
||||
+121
-10
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urljoin, urlparse, parse_qs
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -63,7 +64,17 @@ def _get_provider_key(provider: str) -> str:
|
||||
if val:
|
||||
return val
|
||||
# Legacy fallback: old shared search_api_key field
|
||||
return (settings.get("search_api_key") or "").strip()
|
||||
legacy = (settings.get("search_api_key") or "").strip()
|
||||
if legacy:
|
||||
return legacy
|
||||
env_map = {
|
||||
"brave": "DATA_BRAVE_API_KEY",
|
||||
"google_pse": "GOOGLE_API_KEY",
|
||||
"tavily": "TAVILY_API_KEY",
|
||||
"serper": "SERPER_API_KEY",
|
||||
}
|
||||
env_name = env_map.get(provider, "")
|
||||
return (os.environ.get(env_name) or "").strip() if env_name else ""
|
||||
|
||||
|
||||
def _get_result_count() -> int:
|
||||
@@ -75,6 +86,43 @@ def _get_result_count() -> int:
|
||||
return 5
|
||||
|
||||
|
||||
# Canonical SafeSearch levels: "strict" (default), "moderate", "off".
|
||||
# Each provider has its own knob name and value space -- see _safesearch_for(...).
|
||||
_SAFESEARCH_LEVELS = ("strict", "moderate", "off")
|
||||
|
||||
|
||||
def _get_safesearch_level() -> str:
|
||||
"""Return configured SafeSearch level normalized to a canonical value."""
|
||||
settings = _get_search_settings()
|
||||
raw = (settings.get("search_safesearch") or "strict").strip().lower()
|
||||
if raw in _SAFESEARCH_LEVELS:
|
||||
return raw
|
||||
aliases = {
|
||||
"on": "strict", "high": "strict", "2": "strict",
|
||||
"medium": "moderate", "1": "moderate", "default": "moderate",
|
||||
"none": "off", "disabled": "off", "0": "off",
|
||||
}
|
||||
return aliases.get(raw, "strict")
|
||||
|
||||
|
||||
def _safesearch_for(provider: str) -> Optional[str]:
|
||||
"""Translate the canonical SafeSearch level into provider-specific values."""
|
||||
level = _get_safesearch_level()
|
||||
if provider == "searxng":
|
||||
return {"strict": "2", "moderate": "1", "off": "0"}[level]
|
||||
if provider == "brave":
|
||||
return level
|
||||
if provider == "duckduckgo_lib":
|
||||
return {"strict": "on", "moderate": "moderate", "off": "off"}[level]
|
||||
if provider == "duckduckgo_html":
|
||||
return {"strict": "1", "moderate": "-1", "off": "-2"}[level]
|
||||
if provider == "google_pse":
|
||||
return None if level == "off" else "active"
|
||||
if provider == "serper":
|
||||
return None if level == "off" else "active"
|
||||
return None
|
||||
|
||||
|
||||
# ── SearXNG ──
|
||||
|
||||
_NEWS_HINTS = ("news", "nyheter", "headlines", "breaking", "latest", "today", "idag")
|
||||
@@ -104,7 +152,12 @@ def searxng_search_api(query: str, count: int = 10, categories: str = "general",
|
||||
# languages and brand-ambiguous terms bleed in foreign SEO pages (e.g.
|
||||
# "Odyssey" → Honda Japan, "Trojan" → Japanese malware blogs, "Polyphemus"
|
||||
# → Chinese math forums). The news path already did this; general didn't.
|
||||
params = {"q": query, "format": "json", "language": "en"}
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"language": "en",
|
||||
"safesearch": _safesearch_for("searxng"),
|
||||
}
|
||||
q_lc = query.lower()
|
||||
is_news = time_filter is not None or any(h in q_lc for h in _NEWS_HINTS)
|
||||
if is_news and categories == "general":
|
||||
@@ -153,6 +206,7 @@ def searxng_search_api(query: str, count: int = 10, categories: str = "general",
|
||||
"format": "json",
|
||||
"language": "en",
|
||||
"categories": "general",
|
||||
"safesearch": _safesearch_for("searxng"),
|
||||
}
|
||||
if _GENERAL_ENGINES:
|
||||
fallback["engines"] = _GENERAL_ENGINES
|
||||
@@ -203,7 +257,7 @@ def searxng_search(query, max_results=10):
|
||||
try:
|
||||
response = httpx.get(
|
||||
f"{instance}/search",
|
||||
params={"q": query},
|
||||
params={"q": query, "safesearch": _safesearch_for("searxng")},
|
||||
headers=req_headers,
|
||||
timeout=10,
|
||||
)
|
||||
@@ -248,7 +302,11 @@ def _brave_search_impl(query: str, count: int, time_filter: Optional[str] = None
|
||||
return []
|
||||
|
||||
headers = {"X-Subscription-Token": brave_api_key, "Accept": "application/json"}
|
||||
params = {"q": enhanced_query, "count": count}
|
||||
params = {
|
||||
"q": enhanced_query,
|
||||
"count": count,
|
||||
"safesearch": _safesearch_for("brave"),
|
||||
}
|
||||
if time_filter:
|
||||
time_map = {"day": "day", "week": "week", "month": "month", "year": "year"}
|
||||
if time_filter in time_map:
|
||||
@@ -297,13 +355,40 @@ def _brave_search_impl(query: str, count: int, time_filter: Optional[str] = None
|
||||
|
||||
# ── DuckDuckGo (free, no key) ──
|
||||
|
||||
def _is_duckduckgo_host(host: str) -> bool:
|
||||
"""True only for duckduckgo.com and its subdomains."""
|
||||
host = (host or "").lower()
|
||||
return host == "duckduckgo.com" or host.endswith(".duckduckgo.com")
|
||||
|
||||
|
||||
def _resolve_ddg_redirect(raw: str) -> str:
|
||||
"""Resolve a DuckDuckGo /l/?uddg= redirect URL to its destination."""
|
||||
if not raw:
|
||||
return raw
|
||||
resolved = raw
|
||||
if resolved.startswith("//"):
|
||||
resolved = "https:" + resolved
|
||||
elif resolved.startswith("/"):
|
||||
resolved = urljoin("https://html.duckduckgo.com", resolved)
|
||||
try:
|
||||
parsed = urlparse(resolved)
|
||||
if _is_duckduckgo_host(parsed.hostname) and parsed.path.rstrip("/") == "/l":
|
||||
qs = parse_qs(parsed.query)
|
||||
if "uddg" in qs:
|
||||
return qs["uddg"][0]
|
||||
except Exception:
|
||||
pass
|
||||
return resolved
|
||||
|
||||
|
||||
def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using DuckDuckGo via the duckduckgo-search library. No API key needed."""
|
||||
|
||||
def _html_fallback() -> List[dict]:
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://html.duckduckgo.com/html/",
|
||||
params={"q": query},
|
||||
params={"q": query, "kp": _safesearch_for("duckduckgo_html")},
|
||||
headers={"User-Agent": "Mozilla/5.0"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
@@ -314,7 +399,7 @@ def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] =
|
||||
link = result.select_one(".result__a")
|
||||
if not link:
|
||||
continue
|
||||
url = link.get("href", "")
|
||||
url = _resolve_ddg_redirect(link.get("href", ""))
|
||||
if not url:
|
||||
continue
|
||||
snippet_el = result.select_one(".result__snippet")
|
||||
@@ -342,7 +427,12 @@ def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] =
|
||||
|
||||
try:
|
||||
ddgs = DDGS()
|
||||
raw = ddgs.text(query, max_results=count, timelimit=timelimit)
|
||||
raw = ddgs.text(
|
||||
query,
|
||||
max_results=count,
|
||||
timelimit=timelimit,
|
||||
safesearch=_safesearch_for("duckduckgo_lib"),
|
||||
)
|
||||
results = []
|
||||
for item in raw:
|
||||
url = item.get("href", "")
|
||||
@@ -384,6 +474,9 @@ def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] =
|
||||
"q": query,
|
||||
"num": min(count, 10), # Google PSE max is 10 per request
|
||||
}
|
||||
safe = _safesearch_for("google_pse")
|
||||
if safe:
|
||||
params["safe"] = safe
|
||||
if time_filter:
|
||||
# dateRestrict: d[number], w[number], m[number], y[number]
|
||||
time_map = {"day": "d1", "week": "w1", "month": "m1", "year": "y1"}
|
||||
@@ -399,7 +492,6 @@ def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] =
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Google PSE rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Google PSE search failed: {e}")
|
||||
return []
|
||||
@@ -407,6 +499,12 @@ def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] =
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
error_logger.error(f"Google PSE returned invalid JSON: {e}")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("items", [])[:count]:
|
||||
url = item.get("link", "")
|
||||
@@ -451,7 +549,6 @@ def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Tavily rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Tavily search failed: {e}")
|
||||
return []
|
||||
@@ -459,6 +556,12 @@ def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
error_logger.error(f"Tavily returned invalid JSON: {e}")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("results", [])[:count]:
|
||||
url = item.get("url", "")
|
||||
@@ -488,6 +591,9 @@ def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None
|
||||
"q": query,
|
||||
"num": count,
|
||||
}
|
||||
safe = _safesearch_for("serper")
|
||||
if safe:
|
||||
payload["safe"] = safe
|
||||
if time_filter:
|
||||
time_map = {"day": "qdr:d", "week": "qdr:w", "month": "qdr:m", "year": "qdr:y"}
|
||||
if time_filter in time_map:
|
||||
@@ -503,7 +609,6 @@ def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Serper rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Serper search failed: {e}")
|
||||
return []
|
||||
@@ -511,6 +616,12 @@ def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
error_logger.error(f"Serper returned invalid JSON: {e}")
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("organic", [])[:count]:
|
||||
url = item.get("link", "")
|
||||
|
||||
@@ -13,15 +13,22 @@ logger = logging.getLogger(__name__)
|
||||
# ----------------------------------------------------------------------
|
||||
def _detect_question_type(query: str) -> Optional[str]:
|
||||
"""Return the leading question word if present (who, what, when, where, why, how)."""
|
||||
if not isinstance(query, str):
|
||||
return None
|
||||
q = query.strip().lower()
|
||||
for word in ("who", "what", "when", "where", "why", "how"):
|
||||
if q.startswith(word):
|
||||
# Require a whole-word match: a bare prefix mis-flags ordinary queries
|
||||
# like "whatsapp pricing" (-> what) or "however ..." (-> how), which
|
||||
# then get spurious boost terms OR-appended in enhance_query.
|
||||
if q == word or q.startswith(word + " "):
|
||||
return word
|
||||
return None
|
||||
|
||||
|
||||
def _extract_entities(query: str) -> Dict[str, List[str]]:
|
||||
"""Lightweight entity extraction: capitalized words and date patterns."""
|
||||
if not isinstance(query, str):
|
||||
return {"names": [], "dates": []}
|
||||
entities: Dict[str, List[str]] = {"names": [], "dates": []}
|
||||
qtype = _detect_question_type(query)
|
||||
cleaned = query
|
||||
@@ -42,12 +49,16 @@ def _extract_entities(query: str) -> Dict[str, List[str]]:
|
||||
|
||||
def _split_multi_part(query: str) -> List[str]:
|
||||
"""Split a query into sub-queries on common conjunctions."""
|
||||
if not isinstance(query, str):
|
||||
return []
|
||||
parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
|
||||
if not isinstance(query, str):
|
||||
return "", None
|
||||
match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
|
||||
if match:
|
||||
site = match.group(1)
|
||||
@@ -68,6 +79,8 @@ def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) ->
|
||||
|
||||
def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Process the original query: site filter, question type boosts, entity extraction."""
|
||||
if not isinstance(original_query, str):
|
||||
original_query = ""
|
||||
query_without_site, site = _extract_site_filter(original_query)
|
||||
sub_queries = _split_multi_part(query_without_site)
|
||||
|
||||
@@ -117,6 +130,8 @@ def build_enhanced_query(query: str, time_filter: str = None) -> str:
|
||||
def _is_news_query(query: str) -> bool:
|
||||
"""Lightweight heuristic to decide if a query is news-oriented."""
|
||||
news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
|
||||
if not isinstance(query, str):
|
||||
return False
|
||||
tokens = set(re.findall(r"\b\w+\b", query.lower()))
|
||||
return bool(tokens & news_terms)
|
||||
|
||||
|
||||
@@ -13,6 +13,11 @@ _SPORTS_HINTS = {
|
||||
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
|
||||
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
|
||||
}
|
||||
# Word-boundary match so "sport" does not fire inside "transport"/"passport"
|
||||
# and a domain like "transport.gov" is not mistaken for a sports site.
|
||||
_SPORTS_HINT_RE = re.compile(
|
||||
r"\b(?:" + "|".join(re.escape(h) for h in _SPORTS_HINTS) + r")\b"
|
||||
)
|
||||
_LOW_VALUE_NEWS_DOMAINS = {
|
||||
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
|
||||
"www.yahoo.com", "msn.com", "www.msn.com",
|
||||
@@ -39,7 +44,7 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||
query_lc = query.lower()
|
||||
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
|
||||
is_sports_query = any(hint in query_lc for hint in _SPORTS_HINTS)
|
||||
is_sports_query = bool(_SPORTS_HINT_RE.search(query_lc))
|
||||
|
||||
def title_score(title: str) -> float:
|
||||
if not title:
|
||||
@@ -98,7 +103,7 @@ def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
adjustment += 0.4
|
||||
if netloc in _LOW_VALUE_NEWS_DOMAINS:
|
||||
adjustment -= 0.8
|
||||
if not is_sports_query and any(hint in text or hint in netloc for hint in _SPORTS_HINTS):
|
||||
if not is_sports_query and (_SPORTS_HINT_RE.search(text) or _SPORTS_HINT_RE.search(netloc)):
|
||||
adjustment -= 1.5
|
||||
# A country/news query should not rank a page whose title/snippet barely
|
||||
# mentions the country above actual news pages for that country.
|
||||
|
||||
@@ -62,17 +62,24 @@ class SearchService:
|
||||
SearchResponse with results
|
||||
"""
|
||||
depth = depth or self.default_depth
|
||||
fetch_content = fetch_content if fetch_content is not None else self.fetch_content
|
||||
|
||||
# Use existing search implementation
|
||||
raw_results = await comprehensive_web_search(
|
||||
# comprehensive_web_search is synchronous and, with return_sources=True,
|
||||
# returns (context_str, [{"url", "title"}, ...]). Run it off the event
|
||||
# loop so we don't block it, and use the source list as the result rows.
|
||||
# `fetch_content` is accepted for API compatibility; the comprehensive
|
||||
# search always fetches page content.
|
||||
import asyncio
|
||||
_context, raw_results = await asyncio.to_thread(
|
||||
comprehensive_web_search,
|
||||
query,
|
||||
max_results=10 * depth,
|
||||
fetch_content=fetch_content,
|
||||
max_pages=10 * depth,
|
||||
return_sources=True,
|
||||
)
|
||||
|
||||
results = []
|
||||
for r in raw_results:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
results.append(SearchResult(
|
||||
url=r.get("url", ""),
|
||||
title=r.get("title", ""),
|
||||
|
||||
@@ -125,10 +125,11 @@ class ShellService:
|
||||
asyncio.create_task(_reader(proc.stderr, "stderr")),
|
||||
]
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
finished = 0
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
deadline = loop.time() + timeout
|
||||
while finished < 2:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
remaining = deadline - loop.time()
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError()
|
||||
|
||||
|
||||
+28
-11
@@ -40,6 +40,8 @@ class STTService:
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
settings = self._load_settings()
|
||||
if settings.get("stt_enabled") is False:
|
||||
return False
|
||||
provider = settings["stt_provider"]
|
||||
if provider == "disabled":
|
||||
return False
|
||||
@@ -57,17 +59,29 @@ class STTService:
|
||||
if self._whisper_model is None:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
settings = self._load_settings()
|
||||
model_size = settings.get("stt_model", "base")
|
||||
# Use CPU by default; will use CUDA if available
|
||||
import torch
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
||||
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
|
||||
except ImportError:
|
||||
logger.warning("faster-whisper not installed. Install with: pip install faster-whisper")
|
||||
return None
|
||||
try:
|
||||
settings = self._load_settings()
|
||||
model_size = settings.get("stt_model", "base")
|
||||
# faster-whisper runs on CTranslate2, not torch. torch is only
|
||||
# used (optionally) to detect a CUDA device for acceleration —
|
||||
# if it's missing or unusable we just run on CPU. Keeping this
|
||||
# probe separate (and tolerant of any failure, e.g. a broken
|
||||
# CUDA/torch install that raises OSError on import) means a
|
||||
# torch-less or torch-broken machine still does CPU
|
||||
# transcription instead of failing with a misleading
|
||||
# "faster-whisper not installed" error.
|
||||
try:
|
||||
import torch
|
||||
use_cuda = torch.cuda.is_available()
|
||||
except Exception:
|
||||
use_cuda = False
|
||||
device = "cuda" if use_cuda else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
||||
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load whisper model: {e}")
|
||||
return None
|
||||
@@ -77,6 +91,7 @@ class STTService:
|
||||
model = self._get_whisper()
|
||||
if not model:
|
||||
return None
|
||||
tmp_path = None
|
||||
try:
|
||||
# Write to temp file (faster-whisper needs a file path or file-like)
|
||||
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
|
||||
@@ -90,14 +105,14 @@ class STTService:
|
||||
segments, info = model.transcribe(tmp_path, **kwargs)
|
||||
text = " ".join(seg.text.strip() for seg in segments)
|
||||
|
||||
# Cleanup
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
logger.info(f"Local STT: {len(text)} chars, lang={info.language}, prob={info.language_probability:.2f}")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error(f"Local STT transcription failed: {e}", exc_info=True)
|
||||
return None
|
||||
finally:
|
||||
if tmp_path:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
# ── API endpoint ──
|
||||
|
||||
@@ -140,6 +155,8 @@ class STTService:
|
||||
|
||||
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
|
||||
settings = self._load_settings()
|
||||
if settings.get("stt_enabled") is False:
|
||||
return None
|
||||
provider = settings["stt_provider"]
|
||||
model = settings["stt_model"]
|
||||
language = settings.get("stt_language", "")
|
||||
|
||||
@@ -12,6 +12,18 @@ from typing import Optional, Dict, Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_speed(value, default: float = 1.0) -> float:
|
||||
"""Parse the stored tts_speed defensively. The settings layer tolerates
|
||||
corrupt/agent-written config, so a non-numeric or empty value (e.g. an agent
|
||||
setting "speech speed" = "fast", or a hand-edited settings.json) must not
|
||||
crash synthesis or the stats endpoint with a ValueError."""
|
||||
try:
|
||||
speed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return speed if speed > 0 else default
|
||||
|
||||
|
||||
class TTSService:
|
||||
"""Multi-provider TTS service.
|
||||
|
||||
@@ -34,6 +46,7 @@ class TTSService:
|
||||
from src.settings import load_settings
|
||||
saved = load_settings()
|
||||
return {
|
||||
"tts_enabled": saved.get("tts_enabled", True),
|
||||
"tts_provider": saved.get("tts_provider", "disabled"),
|
||||
"tts_model": saved.get("tts_model", "tts-1"),
|
||||
"tts_voice": saved.get("tts_voice", "alloy"),
|
||||
@@ -43,6 +56,8 @@ class TTSService:
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
settings = self._load_settings()
|
||||
if settings.get("tts_enabled") is False:
|
||||
return False
|
||||
provider = settings["tts_provider"]
|
||||
if provider == "disabled":
|
||||
return False
|
||||
@@ -128,10 +143,12 @@ class TTSService:
|
||||
|
||||
def synthesize(self, text: str, use_cache: bool = True) -> Optional[bytes]:
|
||||
settings = self._load_settings()
|
||||
if settings.get("tts_enabled") is False:
|
||||
return None
|
||||
provider = settings["tts_provider"]
|
||||
model = settings["tts_model"]
|
||||
voice = settings["tts_voice"]
|
||||
speed = float(settings.get("tts_speed", "1"))
|
||||
speed = _safe_speed(settings.get("tts_speed", "1"))
|
||||
|
||||
if provider in ("disabled", "browser"):
|
||||
return None
|
||||
@@ -183,7 +200,7 @@ class TTSService:
|
||||
provider = settings["tts_provider"]
|
||||
tts_enabled = settings.get("tts_enabled", True)
|
||||
|
||||
cache_files = list(self.cache_dir.glob("*.wav"))
|
||||
cache_files = list(self.cache_dir.glob("*.wav")) + list(self.cache_dir.glob("*.mp3"))
|
||||
cache_size = sum(f.stat().st_size for f in cache_files)
|
||||
|
||||
is_available = self.available and tts_enabled
|
||||
@@ -193,7 +210,7 @@ class TTSService:
|
||||
"provider": provider,
|
||||
"model": settings["tts_model"],
|
||||
"voice": settings["tts_voice"],
|
||||
"speed": float(settings.get("tts_speed", "1")),
|
||||
"speed": _safe_speed(settings.get("tts_speed", "1")),
|
||||
"cache_entries": len(cache_files),
|
||||
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
|
||||
}
|
||||
|
||||
@@ -59,11 +59,15 @@ def init_youtube():
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
if not isinstance(url, str):
|
||||
return False
|
||||
return "youtube.com" in url or "youtu.be" in url
|
||||
|
||||
|
||||
def extract_youtube_id(url: str) -> Optional[str]:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
if not isinstance(url, str):
|
||||
return None
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname in ("www.youtube.com", "youtube.com", "m.youtube.com"):
|
||||
if parsed.path == "/watch":
|
||||
@@ -254,6 +258,8 @@ def format_comments_for_context(comments_data: Dict[str, Any], url: str) -> str:
|
||||
ctx += f"URL: {url}\n\n"
|
||||
|
||||
for i, c in enumerate(comments, 1):
|
||||
if not isinstance(c, dict):
|
||||
continue
|
||||
likes = c.get("likes", 0)
|
||||
likes_str = f" [{likes} likes]" if likes else ""
|
||||
ctx += f"{i}. @{c['author']}{likes_str}: {c['text']}\n\n"
|
||||
|
||||
Reference in New Issue
Block a user