Odysseus v1.0

This commit is contained in:
pewdiepie-archdaemon
2026-05-31 23:58:26 +09:00
commit e5c99a5eee
421 changed files with 271349 additions and 0 deletions
+37
View File
@@ -0,0 +1,37 @@
# services/__init__.py
"""
Service layer — plug-in capabilities for the chat core.
Each service:
- Does one thing well
- Exposes a clean async interface
- Can run in-process or as a standalone HTTP service
"""
from .search import SearchService, SearchResult, SearchResponse
from .docs import DocsService, DocChunk, IndexResult
from .research import ResearchService, ResearchResult, ResearchSource
from .memory import MemoryService, Memory, MemorySearchResult
from .shell import ShellService, ShellResult
__all__ = [
# Search
"SearchService",
"SearchResult",
"SearchResponse",
# Docs
"DocsService",
"DocChunk",
"IndexResult",
# Research
"ResearchService",
"ResearchResult",
"ResearchSource",
# Memory
"MemoryService",
"Memory",
"MemorySearchResult",
# Shell
"ShellService",
"ShellResult",
]
+18
View File
@@ -0,0 +1,18 @@
# services/docs/__init__.py
"""Docs service — personal document RAG with ChromaDB.
Thin facade: DocsService lives here, RAGManager/VectorRAG are re-exported
from the canonical implementations in src/.
"""
from .service import DocsService, DocChunk, IndexResult
from src.rag_manager import RAGManager
from src.rag_vector import VectorRAG
__all__ = [
"DocsService",
"DocChunk",
"IndexResult",
"RAGManager",
"VectorRAG",
]
+89
View File
@@ -0,0 +1,89 @@
# services/docs/service.py
"""Docs service — personal document RAG."""
from dataclasses import dataclass
from typing import List, Dict, Any
from src.rag_manager import RAGManager
@dataclass
class DocChunk:
"""A retrieved document chunk."""
text: str
source: str
score: float
metadata: Dict[str, Any] = None
@dataclass
class IndexResult:
"""Result of indexing documents."""
indexed: int
failed: int
errors: List[str]
class DocsService:
"""
Document RAG service.
Usage:
service = DocsService()
await service.index("/path/to/docs")
results = await service.query("what is async await?")
"""
def __init__(self, persist_dir: str = "data/chroma"):
self.rag = RAGManager(persist_directory=persist_dir)
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
"""
Query the document index.
Args:
query: Search query
top_k: Number of results
Returns:
List of DocChunk objects
"""
results = self.rag.search(query, k=top_k)
return [
DocChunk(
text=r.get("text", r.get("content", "")),
source=r.get("source", r.get("metadata", {}).get("source", "unknown")),
score=r.get("score", 0.0),
metadata=r.get("metadata"),
)
for r in results
]
async def index(self, directory: str) -> IndexResult:
"""
Index documents from a directory.
Args:
directory: Path to documents
Returns:
IndexResult with stats
"""
result = self.rag.index_personal_documents(directory)
return IndexResult(
indexed=result.get("indexed", 0),
failed=result.get("failed", 0),
errors=result.get("errors", []),
)
async def add_document(self, text: str, metadata: Dict[str, Any]) -> bool:
"""Add a single document to the index."""
return self.rag.add_document(text, metadata)
def get_stats(self) -> Dict[str, Any]:
"""Get index statistics."""
return self.rag.get_stats()
def rebuild_index(self) -> bool:
"""Rebuild the entire index."""
return self.rag.rebuild_index()
+1
View File
@@ -0,0 +1 @@
"""Face detection + embedding service (standalone worker + helpers)."""
View File
File diff suppressed because it is too large Load Diff
+463
View File
@@ -0,0 +1,463 @@
import re
from services.hwfit.models import (
params_b, estimate_memory_gb, infer_use_case,
get_models, is_prequantized, _active_params_b, QUANT_BYTES_PER_PARAM,
QUANT_SPEED_MULT, QUANT_QUALITY_PENALTY,
)
GPU_BANDWIDTH = {
"5090": 1792, "5080": 960, "5070 ti": 896, "5070": 672, "5060 ti": 448, "5060": 256,
"4090": 1008, "4080 super": 736, "4080": 717, "4070 ti super": 672, "4070 ti": 504, "4070 super": 504, "4070": 504, "4060 ti": 288, "4060": 272,
"3090 ti": 1008, "3090": 936, "3080 ti": 912, "3080": 760, "3070 ti": 608, "3070": 448, "3060 ti": 448, "3060": 360,
"2080 ti": 616, "2080 super": 496, "2080": 448, "2070 super": 448, "2070": 448, "2060 super": 448, "2060": 336,
"1660 ti": 288, "1660 super": 336, "1660": 192, "1650 super": 192, "1650": 128,
"h100 sxm": 3350, "h100": 2039, "h200": 4800, "a100 sxm": 2039, "a100": 1555,
"l40s": 864, "l40": 864, "l4": 300, "a10g": 600, "a10": 600, "t4": 320,
"v100 sxm": 900, "v100": 897, "a6000": 768, "a5000": 768, "a4000": 448,
"7900 xtx": 960, "7900 xt": 800, "7900 gre": 576, "7800 xt": 624, "7700 xt": 432, "7600": 288,
"6950 xt": 576, "6900 xt": 512, "6800 xt": 512, "6800": 512, "6700 xt": 384, "6600 xt": 256, "6600": 224,
"mi300x": 5300, "mi300": 5300, "mi250x": 3277, "mi250": 3277, "mi210": 1638, "mi100": 1229,
"9070 xt": 624, "9070": 488,
}
# Pre-sort keys by length descending for correct substring matching
_BW_KEYS_SORTED = sorted(GPU_BANDWIDTH.keys(), key=len, reverse=True)
FALLBACK_K = {"cuda": 220, "rocm": 180, "cpu_x86": 70, "cpu_arm": 90}
USE_CASE_WEIGHTS = {
"general": (0.45, 0.30, 0.15, 0.10),
"coding": (0.50, 0.20, 0.15, 0.15),
"reasoning": (0.55, 0.15, 0.15, 0.15),
"chat": (0.40, 0.35, 0.15, 0.10),
"multimodal": (0.50, 0.20, 0.15, 0.15),
"embedding": (0.30, 0.40, 0.20, 0.10),
"tts": (0.40, 0.35, 0.15, 0.10),
"stt": (0.40, 0.35, 0.15, 0.10),
}
SPEED_TARGET = {
"general": 40, "coding": 40, "multimodal": 40, "chat": 40,
"reasoning": 25, "embedding": 200, "tts": 40, "stt": 40,
}
CONTEXT_TARGET = {
"general": 4096, "chat": 4096, "coding": 8192,
"reasoning": 8192, "multimodal": 4096, "embedding": 512,
"tts": 2048, "stt": 2048,
}
def _lookup_bandwidth(gpu_name):
if not gpu_name:
return None
gn = gpu_name.lower()
for key in _BW_KEYS_SORTED:
if key in gn:
return GPU_BANDWIDTH[key]
return None
def _estimate_speed(model, quant, run_mode, system):
"""Estimate tok/s. Uses active params for MoE (only active experts run per token)."""
pb = _active_params_b(model)
is_moe = model.get("is_moe", False)
bw = _lookup_bandwidth(system.get("gpu_name"))
backend = system.get("backend", "cpu_x86")
if bw and run_mode in ("gpu", "cpu_offload"):
bpp = QUANT_BYTES_PER_PARAM.get(quant, 0.5)
model_gb = pb * bpp
if model_gb <= 0:
return 0.0
efficiency = 0.55
raw_tps = (bw / model_gb) * efficiency
if run_mode == "cpu_offload":
mode_factor = 0.5
elif is_moe:
mode_factor = 0.8
else:
mode_factor = 1.0
return raw_tps * mode_factor
k = FALLBACK_K.get(backend, 70)
if pb <= 0:
return 0.0
sm = QUANT_SPEED_MULT.get(quant, 1.0)
return k / pb * sm
def _quality_score(model, quant, use_case):
pb = params_b(model)
if pb < 1:
base = 30
elif pb < 3:
base = 45
elif pb < 7:
base = 60
elif pb < 10:
base = 75
elif pb < 20:
base = 82
elif pb < 40:
base = 89
else:
base = 95
name_lower = model.get("name", "").lower()
if "qwen" in name_lower:
base += 2
if "deepseek" in name_lower:
base += 3
if "llama" in name_lower:
base += 2
if "mistral" in name_lower or "mixtral" in name_lower:
base += 1
if "gemma" in name_lower:
base += 1
base += QUANT_QUALITY_PENALTY.get(quant, 0)
model_uc = infer_use_case(model)
if model_uc == "coding" and use_case == "coding":
base += 6
if model_uc == "reasoning" and use_case == "reasoning" and pb >= 13:
base += 5
if model_uc == "multimodal" and use_case == "multimodal":
base += 6
return max(0, min(100, base))
def _speed_score(tps, use_case):
target = SPEED_TARGET.get(use_case, 40)
return max(0, min(100, (tps / target) * 100))
def _fit_score(required, available):
if required > available:
return 0
if available <= 0:
return 0
ratio = required / available
if ratio <= 0.5:
return 60 + (ratio / 0.5) * 40
if ratio <= 0.8:
return 100
if ratio <= 0.9:
return 70
return 50
def _context_score(ctx, use_case):
target = CONTEXT_TARGET.get(use_case, 4096)
if ctx >= target:
return 100
if ctx >= target / 2:
return 70
return 30
def _try_quant_at(model, quant, ctx, gpu_vram, available_ram):
"""Try a specific quant at a given context. Returns (run_mode, quant, ctx, mem) or None."""
mem = estimate_memory_gb(model, quant, ctx)
if gpu_vram > 0 and mem <= gpu_vram:
return "gpu", quant, ctx, mem
if gpu_vram > 0 and mem <= available_ram:
return "cpu_offload", quant, ctx, mem
if gpu_vram <= 0 and mem <= available_ram:
return "cpu_only", quant, ctx, mem
# Try halving context
cur_ctx = ctx // 2
while cur_ctx >= 1024:
mem = estimate_memory_gb(model, quant, cur_ctx)
if gpu_vram > 0 and mem <= gpu_vram:
return "gpu", quant, cur_ctx, mem
if mem <= available_ram:
return ("cpu_offload" if gpu_vram > 0 else "cpu_only"), quant, cur_ctx, mem
cur_ctx //= 2
return None
def _quant_bits(q):
"""Approximate bit-width of a quant label so GGUF quant tiers (Q4/Q8/…) can
be matched against prequantized formats (AWQ 4, AWQ-8bit, FP8, GPTQ-4bit…).
Returns 0 when unknown (caller treats unknown as "don't filter")."""
qu = (q or "").upper().replace("-", "").replace("_", "").replace(" ", "")
# GGUF k-quants + float formats
if qu.startswith("Q8") or "FP8" in qu:
return 8
if qu.startswith("Q4") or qu.startswith("IQ4"):
return 4
if qu.startswith("Q2") or qu.startswith("IQ2"):
return 2
if qu.startswith("Q3") or qu.startswith("IQ3"):
return 3
if qu.startswith("Q5"):
return 5
if qu.startswith("Q6"):
return 6
if qu.startswith("F16") or qu.startswith("BF16") or qu.startswith("F32"):
return 16
# Prequantized formats: pull the bit-width digit (AWQ4 / AWQ4BIT / GPTQ8 / 4BIT / INT8 …)
m = re.search(r"(?:AWQ|GPTQ|MLX|EXL2|BNB|INT|W)(\d{1,2})", qu) or re.search(r"(\d{1,2})BIT", qu)
if m:
b = int(m.group(1))
if 2 <= b <= 16:
return b
return 0
def analyze_model(model, system, target_quant=None):
pb = params_b(model)
if pb <= 0:
return None
use_case = infer_use_case(model)
has_gpu = system.get("has_gpu", False)
gpu_vram = (system.get("gpu_vram_gb") or 0) if has_gpu else 0
gpu_count = system.get("gpu_count", 1) or 1
single_gpu_vram = gpu_vram / gpu_count if gpu_count > 1 else gpu_vram
available_ram = system.get("available_ram_gb", 0)
# When the user has explicitly picked a GPU config (not RAM mode), they want
# to see what runs ON the GPU(s) — not big models that only "fit" by spilling
# most layers to system RAM. Zeroing the offload budget makes _try_quant_at
# take only its GPU branches (fit on VRAM, shrinking context if needed),
# otherwise return None. Fixes "96 GB GPU still lists a 175 GB model".
gpu_only = bool(system.get("gpu_only")) and has_gpu and gpu_vram > 0
eff_ram = 0 if gpu_only else available_ram
is_moe = model.get("is_moe", False)
ctx = model.get("context_length", 4096) or 4096
native_quant = model.get("quantization", "Q4_K_M")
preq = is_prequantized(model)
# GGUF models can't be sharded across GPUs — use single GPU VRAM
is_gguf = bool(model.get("gguf_sources"))
quant_upper = (native_quant or "").upper()
is_gguf_quant = any(quant_upper.startswith(p) for p in ("Q2", "Q3", "Q4", "Q5", "Q6", "Q8", "IQ", "F16", "F32"))
# Single-GPU VRAM only applies to GGUF/dense builds (llama.cpp can't shard
# across GPUs). Prequantized formats (AWQ/GPTQ/FP8) are served sharded by
# vLLM across all GPUs, so they get the FULL multi-GPU VRAM — even when the
# model also lists a GGUF alternate download (gguf_sources).
if (is_gguf or is_gguf_quant) and not preq:
effective_vram = single_gpu_vram
else:
effective_vram = gpu_vram
# Determine which quant to evaluate at
if preq:
# AWQ/GPTQ/FP8/MLX come at a fixed bit-width. If the user picked a
# specific quant tier (e.g. Q8 → 8-bit), only keep prequant models whose
# native bit-width matches — otherwise selecting Q8 would still surface
# AWQ-4bit models, mixing 4- and 8-bit in one view.
if target_quant:
_tb, _nb = _quant_bits(target_quant), _quant_bits(native_quant)
if _tb and _nb and _tb != _nb:
return None
quant_to_try = native_quant
elif target_quant:
# User picked a specific quant
quant_to_try = target_quant
else:
# Default: Q4_K_M (user's stated preference)
quant_to_try = "Q4_K_M"
result = _try_quant_at(model, quant_to_try, ctx, effective_vram, eff_ram)
# If target quant doesn't fit and it's not pre-quantized, try lower quants
if result is None and not preq and target_quant:
from services.hwfit.models import QUANT_HIERARCHY
idx = QUANT_HIERARCHY.index(target_quant) if target_quant in QUANT_HIERARCHY else -1
for q in QUANT_HIERARCHY[idx + 1:]:
result = _try_quant_at(model, q, ctx, effective_vram, eff_ram)
if result:
break
if result is None:
# Model doesn't fit on the user's current hardware. Surface it
# anyway with a "too_tight" badge instead of silently dropping
# it — without this, editing the hardware config to try LARGER
# tiers never revealed the bigger models, because they were
# filtered out before the user could see what would fit. The
# client already knows how to render too_tight (red row).
oversized_required = estimate_memory_gb(model, quant_to_try, ctx)
return {
"name": model.get("name"),
"provider": model.get("provider"),
"parameter_count": model.get("parameter_count"),
"params_b": round(pb, 1),
"is_moe": is_moe,
"use_case": use_case,
"fit_level": "too_tight",
"run_mode": "no_fit",
"quant": quant_to_try,
"context": ctx,
"required_gb": round(oversized_required, 1),
"speed_tps": 0,
"score": 0,
"scores": {"quality": 0, "speed": 0, "fit": 0, "context": 0},
"gguf_sources": model.get("gguf_sources", []),
"context_length": model.get("context_length", 4096),
}
run_mode, quant, fit_ctx, required_gb = result
# Determine fit level
budget = effective_vram if run_mode == "gpu" else available_ram
if required_gb > budget:
return None
if run_mode == "gpu":
rec = model.get("recommended_ram_gb") or required_gb
if rec <= gpu_vram:
fit_level = "perfect"
elif gpu_vram >= required_gb * 1.2:
fit_level = "good"
else:
fit_level = "marginal"
elif run_mode == "cpu_offload":
fit_level = "good" if available_ram >= required_gb * 1.2 else "marginal"
else:
fit_level = "marginal"
tps = _estimate_speed(model, quant, run_mode, system)
q_score = _quality_score(model, quant, use_case)
s_score = _speed_score(tps, use_case)
f_score = _fit_score(required_gb, budget)
c_score = _context_score(fit_ctx, use_case)
wq, ws, wf, wc = USE_CASE_WEIGHTS.get(use_case, (0.45, 0.30, 0.15, 0.10))
composite = q_score * wq + s_score * ws + f_score * wf + c_score * wc
return {
"name": model.get("name"),
"provider": model.get("provider"),
"parameter_count": model.get("parameter_count"),
"params_b": round(pb, 1),
"is_moe": is_moe,
"use_case": use_case,
"fit_level": fit_level,
"run_mode": run_mode,
"quant": quant,
"context": fit_ctx,
"required_gb": round(required_gb, 1),
"speed_tps": round(tps, 1),
"score": round(composite, 1),
"scores": {
"quality": round(q_score, 1),
"speed": round(s_score, 1),
"fit": round(f_score, 1),
"context": round(c_score, 1),
},
"gguf_sources": model.get("gguf_sources", []),
"context_length": model.get("context_length", 4096),
}
SORT_KEYS = {
"score": lambda r: r["score"],
"speed": lambda r: r["speed_tps"],
"vram": lambda r: r["required_gb"],
"params": lambda r: r["params_b"],
"context": lambda r: r["context"],
}
def rank_models(system, use_case=None, limit=50, search=None, sort="score", quant=None):
"""Rank all models against detected hardware. Returns sorted list of fit results."""
models = get_models()
results = []
# Include image gen models only when explicitly filtered
if use_case == "image_gen":
try:
from services.hwfit.image_models import rank_image_models
except ImportError:
rank_image_models = None
if rank_image_models:
img_results = rank_image_models(system, search=search)
else:
img_results = []
for im in img_results:
fit_map = {"perfect": "perfect", "good": "good", "tight": "marginal", "no_fit": "too_tight", "no_gpu": "too_tight"}
results.append({
"name": im["id"],
"provider": im["provider"],
"parameter_count": f"{im['params_b']}B",
"params_b": im["params_b"],
"is_moe": False,
"use_case": "image_gen",
"fit_level": fit_map.get(im["fit"], "too_tight"),
"run_mode": "gpu" if im["fits"] else "no_fit",
"quant": im.get("quant", "BF16"),
"context": 0,
"context_length": 0,
"required_gb": round(im.get("vram_needed") or 0, 1),
"speed_tps": 0,
"score": float(im["score"]),
"scores": {"quality": float(im["quality"]), "speed": float(im["speed"]), "fit": 0, "context": 0},
"gguf_sources": [],
"is_image_gen": True,
"capabilities": im.get("capabilities", []),
"description": im.get("description", ""),
})
if use_case == "image_gen":
sort_fn = SORT_KEYS.get(sort, SORT_KEYS["score"])
results.sort(key=sort_fn, reverse=(sort != "vram"))
return results[:limit]
# If user picked a prequantized format (AWQ/FP8/GPTQ), filter to only those models
filter_native = quant and any(quant.startswith(p) for p in ("AWQ-", "GPTQ-", "FP8"))
# MLX-quantized models only run on Apple Silicon (Metal). Exclude them on
# every other backend (CUDA / ROCm / CPU) so Linux/Windows users don't see
# unrunnable suggestions.
system_backend = (system.get("backend") or "").lower()
apple_silicon = system_backend in ("mps", "metal", "apple")
for m in models:
native_q = m.get("quantization", "")
# Drop MLX models on non-Apple hardware
if not apple_silicon and native_q.startswith("mlx-"):
continue
# Format filter: AWQ tab → only AWQ models, FP8 tab → only FP8 models
if filter_native:
if quant == "FP8" and native_q != "FP8":
continue
if quant.startswith("AWQ") and not native_q.startswith("AWQ"):
continue
if quant.startswith("GPTQ") and not native_q.startswith("GPTQ"):
continue
if search:
name = m.get("name", "").lower()
provider = m.get("provider", "").lower()
if search.lower() not in name and search.lower() not in provider:
continue
result = analyze_model(m, system, target_quant=quant)
if result is None:
continue
if use_case:
model_uc = infer_use_case(m)
if use_case != model_uc and use_case != "general":
continue
results.append(result)
# Pick the visible SET by best fit (score) first, so it stays the same no
# matter which column the user sorts by — otherwise sorting by params would
# truncate to the N biggest models (huge ones that don't even fit) while
# sorting by vram showed the N smallest. Only AFTER choosing the set do we
# order it by the requested column.
results.sort(key=SORT_KEYS["score"], reverse=True)
results = results[:limit]
sort_fn = SORT_KEYS.get(sort, SORT_KEYS["score"])
# vram ascending (smallest first), everything else descending (biggest first)
results.sort(key=sort_fn, reverse=(sort != "vram"))
return results
+457
View File
@@ -0,0 +1,457 @@
import os
import platform
import subprocess
import time
CACHE_TTL = 1800 # 30 min — hardware rarely changes; use the Rescan button to force a re-probe
_remote_host = None # set by detect_system(host=...)
_remote_port = None # set by detect_system(ssh_port=...)
_remote_platform = None # set by detect_system(platform=...): "windows", "linux", "termux"
_last_gpu_error = None # set by _detect_nvidia() when nvidia-smi errors (driver mismatch, etc.)
def _run(cmd):
try:
if _remote_host:
# Run command on remote host via SSH
if isinstance(cmd, list):
cmd_str = " ".join(cmd)
else:
cmd_str = cmd
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
if _remote_port and _remote_port != "22":
ssh_cmd += ["-p", _remote_port]
ssh_cmd += [_remote_host, cmd_str]
r = subprocess.run(
ssh_cmd,
capture_output=True, text=True, timeout=15,
)
else:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if r.returncode == 0:
return r.stdout.strip()
except Exception:
pass
return None
def _group_gpus(gpus):
"""Group identical GPUs by (name, rounded VRAM).
vLLM tensor-parallel only works across IDENTICAL GPUs, so a mixed box must
be split into homogeneous pools. Each group carries the device indices so a
serve command can pin CUDA_VISIBLE_DEVICES to exactly one pool. Biggest pool
(by total VRAM) first — that's the sensible auto-default serving target.
"""
groups = {}
order = []
for g in gpus:
key = (g["name"], round(g["vram_gb"]))
if key not in groups:
groups[key] = {
"name": g["name"],
"vram_each": round(g["vram_gb"], 1),
"count": 0,
"indices": [],
}
order.append(key)
groups[key]["count"] += 1
groups[key]["indices"].append(g.get("index"))
out = []
for key in order:
grp = groups[key]
grp["vram_total"] = round(grp["vram_each"] * grp["count"], 1)
out.append(grp)
out.sort(key=lambda x: x["vram_total"], reverse=True)
return out
def _detect_nvidia():
global _last_gpu_error
_last_gpu_error = None
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
# Remote fallback: a non-interactive SSH shell often has a minimal PATH
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin), so the
# first call silently returns nothing → "No GPU" on hosts that DO have GPUs.
# Retry through a login shell with the common CUDA bin dirs on PATH.
if not out and _remote_host:
out = _run(
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin\"; "
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
)
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
# shell that isn't bash (or a profile that errors), so the bash -lc retry
# above still comes back empty even though the binary is right there.
if not out and _remote_host:
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi"):
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
if out:
break
if not out:
return None
# nvidia-smi present but unable to talk to the driver (e.g. it was updated
# without a reboot). It prints an error and no GPU rows — surface that as a
# driver error rather than the misleading "No GPU".
_low = out.lower()
if ("nvml" in _low or "driver/library version mismatch" in _low
or "couldn't communicate" in _low or "no devices were found" in _low
or "failed to initialize" in _low):
_last_gpu_error = out.strip().split("\n")[0][:140] or "NVIDIA driver error"
return None
gpus = []
# nvidia-smi lists GPUs in index order (0,1,2,...), so the row position is
# the CUDA device index we'd pass to CUDA_VISIBLE_DEVICES.
for idx, line in enumerate(out.strip().split("\n")):
parts = [p.strip() for p in line.split(",")]
if len(parts) >= 2:
try:
vram_mb = float(parts[0])
gpus.append({"index": idx, "name": parts[1], "vram_gb": vram_mb / 1024.0})
except ValueError:
continue
if not gpus:
return None
total_vram = sum(g["vram_gb"] for g in gpus)
groups = _group_gpus(gpus)
return {
"gpu_name": gpus[0]["name"],
"gpu_vram_gb": round(total_vram, 1),
"gpu_count": len(gpus),
"gpus": gpus,
"gpu_groups": groups,
"homogeneous": len(groups) <= 1,
"backend": "cuda",
}
def _detect_amd():
"""Detect AMD GPUs. Handles both discrete cards (with mem_info_vram_total)
and APUs / unified-memory SoCs like Strix Halo (which expose
mem_info_vis_vram_total instead, or only mem_info_gtt_total)."""
def _read(path):
if _remote_host:
val = _run(["cat", path])
return val.strip() if val else None
try:
with open(path) as f:
return f.read().strip()
except Exception:
return None
def _list_drm_cards():
if _remote_host:
out = _run(["ls", "/sys/class/drm"])
if not out:
return []
return [e for e in out.split() if e.startswith("card") and "-" not in e]
try:
return [e for e in os.listdir("/sys/class/drm") if e.startswith("card") and "-" not in e]
except Exception:
return []
try:
cards = []
is_apu = False
for _cidx, entry in enumerate(_list_drm_cards()):
base = f"/sys/class/drm/{entry}/device"
vendor = _read(f"{base}/vendor")
if vendor != "0x1002":
continue
# Discrete cards usually report real VRAM in mem_info_vram_total,
# while some AMD APUs / Docker views expose a tiny vram_total and
# the usable pool in vis_vram_total. Use the larger of those two;
# only fall back to GTT if neither VRAM field is available.
vram_raw = _read(f"{base}/mem_info_vram_total")
vis_raw = _read(f"{base}/mem_info_vis_vram_total")
gtt_raw = _read(f"{base}/mem_info_gtt_total")
vram_val = int(vram_raw) if vram_raw and vram_raw.isdigit() else 0
vis_val = int(vis_raw) if vis_raw and vis_raw.isdigit() else 0
gtt_val = int(gtt_raw) if gtt_raw and gtt_raw.isdigit() else 0
vram_bytes = max(vram_val, vis_val)
if vram_bytes <= 0:
vram_bytes = gtt_val
if vis_val and vis_val >= vram_val:
is_apu = True
if vram_bytes <= 0:
continue
name = _read(f"{base}/product_name") or f"AMD GPU ({entry})"
cards.append({"index": _cidx, "name": name, "vram_gb": vram_bytes / (1024**3)})
if not cards:
return None
total_vram = sum(c["vram_gb"] for c in cards)
groups = _group_gpus(cards)
# NOTE: for APUs with BIOS UMA carveout (e.g. Strix Halo), vis_vram_total
# is the real usable GPU memory — it's physically backed but reserved
# by BIOS so it doesn't appear in /proc/meminfo. Don't cap it at system
# RAM: the two pools are separate from the OS's perspective.
return {
"gpu_name": cards[0]["name"],
"gpu_vram_gb": round(total_vram, 1),
"gpu_count": len(cards),
"gpus": cards,
"gpu_groups": groups,
"homogeneous": len(groups) <= 1,
"backend": "rocm",
"unified_memory": is_apu,
}
except Exception:
return None
def _read_file(path):
"""Read a file, locally or via SSH."""
if _remote_host:
return _run(["cat", path])
try:
with open(path) as f:
return f.read()
except Exception:
return None
def _parse_meminfo():
"""Parse /proc/meminfo into a dict of key -> KB values."""
text = _read_file("/proc/meminfo")
if not text:
return {}
result = {}
for line in text.split("\n"):
if ":" in line:
key, val = line.split(":", 1)
parts = val.strip().split()
if parts:
try:
result[key.strip()] = int(parts[0])
except ValueError:
pass
return result
def _get_ram_gb():
meminfo = _parse_meminfo()
if "MemTotal" in meminfo:
return meminfo["MemTotal"] / (1024**2)
if not _remote_host:
try:
pages = os.sysconf("SC_PHYS_PAGES")
page_size = os.sysconf("SC_PAGE_SIZE")
if pages and page_size:
return (pages * page_size) / (1024**3)
except Exception:
pass
return 0.0
def _get_available_ram_gb():
meminfo = _parse_meminfo()
if "MemAvailable" in meminfo:
return meminfo["MemAvailable"] / (1024**2)
return _get_ram_gb() * 0.7
def _get_cpu_name():
text = _read_file("/proc/cpuinfo")
if text:
for line in text.split("\n"):
if line.startswith("model name"):
return line.split(":", 1)[1].strip()
if not _remote_host:
return platform.processor() or "unknown"
return "unknown"
def _get_cpu_count():
if _remote_host:
out = _run(["nproc"])
if out:
try:
return int(out.strip())
except ValueError:
pass
# fallback: count "processor" lines in /proc/cpuinfo
text = _read_file("/proc/cpuinfo")
if text:
return sum(1 for line in text.split("\n") if line.startswith("processor"))
return os.cpu_count() or 1
def _detect_windows():
"""Detect Windows hardware in a single SSH call using PowerShell."""
# Single PowerShell command that gathers all hardware info at once
ps_cmd = (
"$r = @{}; "
"$os = Get-CimInstance Win32_OperatingSystem; "
"$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1); "
"$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1); "
"$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1; "
"$r.cpu_name = $cpu.Name; "
"$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum; "
"$r.arch = $cpu.AddressWidth; "
# GPU detection via nvidia-smi (fastest) or WMI fallback
"try { "
" $nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null; "
" if ($LASTEXITCODE -eq 0 -and $nv) { "
" $gpus = @(); "
" foreach ($line in $nv -split \"`n\") { "
" $p = $line -split ','; "
" if ($p.Count -ge 2) { $gpus += @{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
" }; "
" $r.gpu_name = $gpus[0].name; "
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
" $r.gpu_count = $gpus.Count; "
" $r.gpu_backend = 'cuda'; "
" } "
"} catch {}; "
"if (-not $r.gpu_name) { "
" $wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1; "
" if ($wmiGpu) { "
" $r.gpu_name = $wmiGpu.Name; "
" $r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1); "
" $r.gpu_count = 1; "
" $r.gpu_backend = 'cpu_x86'; " # WMI doesn't tell us CUDA/ROCm
" } "
"}; "
"$r | ConvertTo-Json -Compress"
)
out = _run(f'powershell -Command "{ps_cmd}"')
if not out:
return None
import json as _json
try:
d = _json.loads(out)
result = {
"total_ram_gb": d.get("ram_gb", 0),
"available_ram_gb": d.get("avail_gb", 0),
"cpu_cores": d.get("cpu_cores", 1),
"cpu_name": d.get("cpu_name", "unknown"),
"has_gpu": bool(d.get("gpu_name")),
"gpu_name": d.get("gpu_name"),
"gpu_vram_gb": d.get("gpu_vram_gb"),
"gpu_count": d.get("gpu_count", 0),
"backend": d.get("gpu_backend", "cpu_x86"),
}
# PowerShell only reports aggregate GPU info, not per-card detail, so we
# can't tell a mixed box from a uniform one here — assume one homogeneous
# pool spanning all reported GPUs (the common Windows case).
_n = result["gpu_count"] or 0
if result["has_gpu"] and _n > 0:
_each = round((result["gpu_vram_gb"] or 0) / _n, 1)
result["gpus"] = [
{"index": i, "name": result["gpu_name"], "vram_gb": _each} for i in range(_n)
]
result["gpu_groups"] = [{
"name": result["gpu_name"],
"vram_each": _each,
"count": _n,
"indices": list(range(_n)),
"vram_total": result["gpu_vram_gb"],
}]
result["homogeneous"] = True
return result
except Exception:
return None
_cache_by_host = {} # host -> (timestamp, result)
def detect_system(host="", ssh_port="", platform="", fresh=False):
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
changes, and probing a remote host over SSH is slow). Pass fresh=True to
bypass the cache and re-probe (the "Rescan" button).
If host is set (e.g. 'user@server'), runs detection commands over SSH.
platform: "windows", "linux", "termux", or "" (auto-detect).
"""
global _remote_host, _remote_port, _remote_platform
cache_key = host or "_local"
now = time.time()
if not fresh and cache_key in _cache_by_host:
ts, cached = _cache_by_host[cache_key]
if (now - ts) < CACHE_TTL:
return cached
_remote_host = host or None
_remote_port = ssh_port or None
_remote_platform = platform or None
# Windows: single PowerShell command for all hardware info
if _remote_platform == "windows" and _remote_host:
result = _detect_windows()
if result:
_remote_host = None
_remote_platform = None
_cache_by_host[cache_key] = (now, result)
return result
# If Windows detection failed, return error
result = {"error": f"Cannot connect to {host}", "host": host}
_remote_host = None
_remote_platform = None
_cache_by_host[cache_key] = (now, result)
return result
# Linux/Termux: existing multi-command detection
total_ram = round(_get_ram_gb(), 1)
# If remote host returns 0 RAM, connection likely failed
if _remote_host and total_ram <= 0:
result = {"error": f"Cannot connect to {host}", "host": host}
_cache_by_host[cache_key] = (now, result)
_remote_host = None
_remote_platform = None
return result
available_ram = round(_get_available_ram_gb(), 1)
cpu_cores = _get_cpu_count()
cpu_name = _get_cpu_name()
gpu_info = _detect_nvidia() or _detect_amd()
if gpu_info:
result = {
"total_ram_gb": total_ram,
"available_ram_gb": available_ram,
"cpu_cores": cpu_cores,
"cpu_name": cpu_name,
"has_gpu": True,
"gpu_name": gpu_info["gpu_name"],
"gpu_vram_gb": gpu_info["gpu_vram_gb"],
"gpu_count": gpu_info["gpu_count"],
"gpus": gpu_info.get("gpus", []),
"gpu_groups": gpu_info.get("gpu_groups", []),
"homogeneous": gpu_info.get("homogeneous", True),
"backend": gpu_info["backend"],
}
else:
if _remote_host:
arch_out = _run(["uname", "-m"]) or ""
else:
import platform as _platform
arch_out = _platform.machine().lower()
backend = "cpu_arm" if "aarch64" in arch_out or "arm" in arch_out else "cpu_x86"
result = {
"total_ram_gb": total_ram,
"available_ram_gb": available_ram,
"cpu_cores": cpu_cores,
"cpu_name": cpu_name,
"has_gpu": False,
"gpu_name": None,
"gpu_vram_gb": None,
"gpu_count": 0,
"backend": backend,
# Set when nvidia-smi exists but failed (e.g. driver/library
# version mismatch) — lets the UI say "GPU driver error" instead
# of the misleading "No GPU".
"gpu_error": _last_gpu_error,
}
_remote_host = None
_remote_platform = None
_cache_by_host[cache_key] = (now, result)
return result
+374
View File
@@ -0,0 +1,374 @@
"""Image generation model registry and VRAM fitting for Cookbook."""
# Curated registry of image generation models supported by diffusers.
# ONLY verified HuggingFace repo IDs.
# VRAM estimates are for inference (single image generation).
IMAGE_MODEL_REGISTRY = [
# ── Z-Image (Alibaba Tongyi) ──
{
"id": "Tongyi-MAI/Z-Image-Turbo",
"name": "Z-Image Turbo",
"provider": "Tongyi",
"params_b": 6.0,
"vram_bf16": 19.0,
"vram_fp8": 10.0,
"vram_q4": 6.0,
"default_quant": "BF16",
"quant_repos": {
"FP8": "drbaph/Z-Image-Turbo-FP8",
},
"capabilities": ["text-to-image"],
"description": "6B distilled, 8-step. Sub-second on H800. Apache 2.0.",
"quality": 92,
"speed": 95,
"released": "2025-12",
},
{
"id": "Tongyi-MAI/Z-Image",
"name": "Z-Image",
"provider": "Tongyi",
"params_b": 6.0,
"vram_bf16": 19.0,
"vram_fp8": 10.0,
"vram_q4": 6.0,
"default_quant": "BF16",
"quant_repos": {
"FP8": "drbaph/Z-Image-fp8",
},
"capabilities": ["text-to-image"],
"description": "Full undistilled model. Highest creative freedom. Apache 2.0.",
"quality": 93,
"speed": 70,
"released": "2025-12",
},
# ── Qwen Image ──
{
"id": "Qwen/Qwen-Image-2512",
"name": "Qwen Image 2512",
"provider": "Qwen",
"params_b": 20.0,
"vram_bf16": 42.0,
"vram_fp8": 22.0,
"vram_q4": 14.0,
"default_quant": "FP8",
"quant_repos": {},
"capabilities": ["text-to-image", "text-rendering"],
"description": "Dec 2025 update. Better humans, finer detail, strong text. Apache 2.0.",
"quality": 95,
"speed": 50,
"released": "2025-12",
},
{
"id": "Qwen/Qwen-Image",
"name": "Qwen Image",
"provider": "Qwen",
"params_b": 20.0,
"vram_bf16": 42.0,
"vram_fp8": 22.0,
"vram_q4": 14.0,
"default_quant": "FP8",
"quant_repos": {},
"capabilities": ["text-to-image", "text-rendering"],
"description": "20B foundation. Best text rendering in images. Apache 2.0.",
"quality": 94,
"speed": 50,
"released": "2025-08",
},
{
"id": "Qwen/Qwen-Image-Edit-2511",
"name": "Qwen Image Edit",
"provider": "Qwen",
"params_b": 20.0,
"vram_bf16": 42.0,
"vram_fp8": 22.0,
"vram_q4": 14.0,
"default_quant": "FP8",
"quant_repos": {},
"capabilities": ["image-editing", "inpainting"],
"description": "Dedicated editing. Style transfer, object removal. Apache 2.0.",
"quality": 92,
"speed": 50,
"released": "2025-11",
},
# ── Stable Diffusion (dedicated inpainting) ──
{
"id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"name": "SDXL Inpainting",
"provider": "Stability AI",
"params_b": 3.5,
"vram_bf16": 12.0,
"vram_fp8": 8.0,
"vram_q4": 6.0,
"default_quant": "BF16",
"quant_repos": {},
"capabilities": ["inpainting", "image-editing"],
"description": "SDXL fine-tuned for inpainting (9-channel UNet). Best SD-family fill quality; fits a 24GB card comfortably.",
"quality": 86,
"speed": 68,
"released": "2023-11",
},
{
"id": "stable-diffusion-v1-5/stable-diffusion-inpainting",
"name": "SD 1.5 Inpainting",
"provider": "Stability AI",
"params_b": 1.1,
"vram_bf16": 4.0,
"vram_fp8": 3.0,
"vram_q4": 2.5,
"default_quant": "BF16",
"quant_repos": {},
"capabilities": ["inpainting"],
"description": "Classic SD 1.5 inpaint. Very light and fast; lower fidelity than SDXL.",
"quality": 70,
"speed": 92,
"released": "2022-10",
},
# ── FLUX ──
{
"id": "black-forest-labs/FLUX.1-dev",
"name": "FLUX.1 Dev",
"provider": "Black Forest Labs",
"params_b": 12.0,
"vram_bf16": 33.0,
"vram_fp8": 17.0,
"vram_q4": 10.0,
"default_quant": "FP8",
"quant_repos": {
"FP8": "diffusers/FLUX.1-dev-torchao-fp8",
},
"capabilities": ["text-to-image"],
"description": "High quality, detailed. Popular community model. Non-commercial.",
"quality": 92,
"speed": 55,
"released": "2024-08",
},
{
"id": "black-forest-labs/FLUX.1-schnell",
"name": "FLUX.1 Schnell",
"provider": "Black Forest Labs",
"params_b": 12.0,
"vram_bf16": 33.0,
"vram_fp8": 17.0,
"vram_q4": 10.0,
"default_quant": "FP8",
"quant_repos": {
"FP8": "Kijai/flux-fp8",
},
"capabilities": ["text-to-image"],
"description": "Fast 4-step variant. Apache 2.0 license.",
"quality": 85,
"speed": 90,
"released": "2024-08",
},
# ── Stable Diffusion ──
{
"id": "stabilityai/stable-diffusion-3.5-medium",
"name": "SD 3.5 Medium",
"provider": "Stability AI",
"params_b": 2.5,
"vram_bf16": 12.0,
"vram_fp8": 7.0,
"vram_q4": None,
"default_quant": "BF16",
"quant_repos": {
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
},
"capabilities": ["text-to-image"],
"description": "2.5B lightweight, fast. Fits almost any GPU.",
"quality": 75,
"speed": 95,
"released": "2024-10",
},
{
"id": "stabilityai/stable-diffusion-3.5-large",
"name": "SD 3.5 Large",
"provider": "Stability AI",
"params_b": 8.1,
"vram_bf16": 22.0,
"vram_fp8": 12.0,
"vram_q4": None,
"default_quant": "BF16",
"quant_repos": {
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
},
"capabilities": ["text-to-image"],
"description": "8B high quality. Good balance of speed and quality.",
"quality": 85,
"speed": 70,
"released": "2024-10",
},
{
"id": "stabilityai/stable-diffusion-3.5-large-turbo",
"name": "SD 3.5 Large Turbo",
"provider": "Stability AI",
"params_b": 8.1,
"vram_bf16": 22.0,
"vram_fp8": 12.0,
"vram_q4": None,
"default_quant": "BF16",
"quant_repos": {
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
},
"capabilities": ["text-to-image"],
"description": "Distilled for few-step inference. Fastest large SD.",
"quality": 80,
"speed": 92,
"released": "2024-10",
},
{
"id": "stabilityai/stable-diffusion-xl-base-1.0",
"name": "SDXL",
"provider": "Stability AI",
"params_b": 3.5,
"vram_bf16": 10.0,
"vram_fp8": 6.0,
"vram_q4": None,
"default_quant": "BF16",
"quant_repos": {},
"capabilities": ["text-to-image"],
"description": "Classic workhorse. Huge LoRA ecosystem. Fits 8GB+.",
"quality": 72,
"speed": 90,
"released": "2023-07",
},
# ── Hunyuan ──
{
"id": "tencent/HunyuanImage-3.0",
"name": "HunyuanImage 3.0",
"provider": "Tencent",
"params_b": 13.0,
"vram_bf16": 30.0,
"vram_fp8": 16.0,
"vram_q4": 9.0,
"default_quant": "FP8",
"quant_repos": {
"Q4": "wikeeyang/Hunyuan-Image-30-Qint4",
"NF4": "EricRollei/HunyuanImage-3.0-Instruct-NF4",
},
"capabilities": ["text-to-image", "text-rendering"],
"description": "Strong text rendering. Bilingual Chinese/English. 13B activated per token.",
"quality": 88,
"speed": 60,
"released": "2025-09",
},
{
"id": "tencent/HunyuanImage-3.0-Instruct-Distil",
"name": "HunyuanImage 3.0 Distil",
"provider": "Tencent",
"params_b": 13.0,
"vram_bf16": 30.0,
"vram_fp8": 16.0,
"vram_q4": 9.0,
"default_quant": "FP8",
"quant_repos": {},
"capabilities": ["text-to-image", "text-rendering"],
"description": "Distilled variant, fewer steps. Faster with comparable quality.",
"quality": 85,
"speed": 80,
"released": "2026-01",
},
]
def get_image_models():
"""Return the image model registry."""
return IMAGE_MODEL_REGISTRY
def rank_image_models(system, search=None, sort="fit"):
"""Score and rank image models against detected hardware.
Returns list of models with fit info (vram needed, fits, recommended quant).
"""
gpu_vram = system.get("gpu_vram_gb", 0) or 0
has_gpu = system.get("has_gpu", False)
results = []
for model in IMAGE_MODEL_REGISTRY:
# Filter by search
if search:
s = search.lower()
if s not in model["name"].lower() and s not in model["id"].lower() and s not in model.get("description", "").lower():
continue
# Determine best quant that fits
quant = None
vram_needed = None
fits = False
quant_repo = None
if has_gpu and gpu_vram > 0:
# Try BF16 first, then FP8, then Q4
for q, vram_key in [("BF16", "vram_bf16"), ("FP8", "vram_fp8"), ("Q4", "vram_q4")]:
v = model.get(vram_key)
if v is not None and v <= gpu_vram * 0.90: # 10% headroom
quant = q
vram_needed = v
fits = True
quant_repo = model.get("quant_repos", {}).get(q)
break
# If nothing fits, show what it needs
if not fits:
quant = model["default_quant"]
vram_needed = model.get("vram_bf16", 0)
# Fit label
if not has_gpu:
fit = "no_gpu"
fit_label = "No GPU"
elif fits:
headroom = gpu_vram - vram_needed
if headroom > gpu_vram * 0.3:
fit = "perfect"
fit_label = "Perfect"
elif headroom > gpu_vram * 0.1:
fit = "good"
fit_label = "Good"
else:
fit = "tight"
fit_label = "Tight"
else:
fit = "no_fit"
fit_label = "Too large"
# Score: quality * speed * fit bonus
score = model["quality"] * 0.6 + model["speed"] * 0.2
if fit == "perfect":
score += 20
elif fit == "good":
score += 10
elif fit == "tight":
score += 5
elif fit == "no_fit":
score -= 30
results.append({
"id": model["id"],
"name": model["name"],
"provider": model["provider"],
"params_b": model["params_b"],
"vram_needed": vram_needed,
"quant": quant,
"quant_repo": quant_repo,
"fits": fits,
"fit": fit,
"fit_label": fit_label,
"quality": model["quality"],
"speed": model["speed"],
"score": round(score, 1),
"capabilities": model["capabilities"],
"description": model["description"],
"released": model.get("released", ""),
})
# Sort
if sort == "quality":
results.sort(key=lambda x: (-x["quality"], -x["score"]))
elif sort == "speed":
results.sort(key=lambda x: (-x["speed"], -x["score"]))
elif sort == "vram":
results.sort(key=lambda x: (x["vram_needed"] or 999, -x["score"]))
else: # fit (default)
results.sort(key=lambda x: (-x["score"],))
return results
+177
View File
@@ -0,0 +1,177 @@
import json
import os
import re
QUANT_HIERARCHY = ["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"]
QUANT_BPP = {
"F32": 4.0, "F16": 2.0, "BF16": 2.0, "FP8": 1.0,
"Q8_0": 1.05, "Q6_K": 0.80, "Q5_K_M": 0.68,
"Q4_K_M": 0.58, "Q4_0": 0.58, "Q3_K_M": 0.48, "Q2_K": 0.37,
"AWQ-4bit": 0.50, "AWQ-8bit": 1.0,
"GPTQ-Int4": 0.50, "GPTQ-Int8": 1.0,
"mlx-4bit": 0.55, "mlx-8bit": 1.0, "mlx-6bit": 0.75,
}
QUANT_SPEED_MULT = {
"F16": 0.6, "BF16": 0.6, "FP8": 0.85,
"Q8_0": 0.8, "Q6_K": 0.95, "Q5_K_M": 1.0,
"Q4_K_M": 1.15, "Q4_0": 1.15, "Q3_K_M": 1.25, "Q2_K": 1.35,
"AWQ-4bit": 1.2, "AWQ-8bit": 0.85,
"GPTQ-Int4": 1.2, "GPTQ-Int8": 0.85,
"mlx-4bit": 1.15, "mlx-8bit": 0.85, "mlx-6bit": 1.0,
}
QUANT_QUALITY_PENALTY = {
"F16": 0.0, "BF16": 0.0, "FP8": 0.0,
"Q8_0": 0.0, "Q6_K": -1.0, "Q5_K_M": -2.0,
"Q4_K_M": -5.0, "Q4_0": -5.0, "Q3_K_M": -8.0, "Q2_K": -12.0,
"AWQ-4bit": -3.0, "AWQ-8bit": 0.0,
"GPTQ-Int4": -3.0, "GPTQ-Int8": 0.0,
"mlx-4bit": -4.0, "mlx-8bit": 0.0, "mlx-6bit": -1.0,
}
QUANT_BYTES_PER_PARAM = {
"F16": 2.0, "BF16": 2.0, "FP8": 1.0,
"Q8_0": 1.0, "Q6_K": 0.75, "Q5_K_M": 0.625,
"Q4_K_M": 0.5, "Q4_0": 0.5, "Q3_K_M": 0.375, "Q2_K": 0.25,
"AWQ-4bit": 0.5, "AWQ-8bit": 1.0,
"GPTQ-Int4": 0.5, "GPTQ-Int8": 1.0,
"mlx-4bit": 0.5, "mlx-8bit": 1.0, "mlx-6bit": 0.75,
}
# Pre-quantized formats that should NOT go through the GGUF quant hierarchy
PREQUANTIZED_PREFIXES = ("AWQ-", "GPTQ-", "mlx-", "FP8")
def is_prequantized(model):
q = model.get("quantization", "")
return any(q.startswith(p) for p in PREQUANTIZED_PREFIXES)
def params_b(model):
raw = model.get("parameters_raw")
if raw and raw > 0:
return raw / 1_000_000_000.0
pc = model.get("parameter_count", "")
if pc:
pc = pc.strip().upper()
m = re.match(r"^([\d.]+)\s*([BKMGT]?)$", pc)
if m:
val = float(m.group(1))
suffix = m.group(2)
if suffix == "B":
return val
elif suffix == "M":
return val / 1000.0
elif suffix == "K":
return val / 1_000_000.0
elif suffix == "T":
return val * 1000.0
else:
# No unit. A bare number this size is conventionally a millions
# count (e.g. "355" = 355M), NOT billions — otherwise a 355M
# model would sort as 355B and leap above every 7B/70B model.
# A genuine billions figure carries a "B" suffix and is handled
# above; very large bare values are raw parameter counts.
if val >= 1_000_000:
return val / 1_000_000_000.0 # raw count
if val >= 1000:
return val / 1000.0 # thousands of millions? treat as millions
return val / 1000.0 # e.g. "355" → 0.355B
return 0.0
def estimate_memory_gb(model, quant, ctx):
"""Estimate VRAM needed to serve a model. All weights must be loaded,
even for MoE (all experts live in memory, only active ones compute per token).
KV cache scales with active params for MoE (only active experts have KV state)."""
pb = params_b(model)
bpp = QUANT_BPP.get(quant, 0.58)
kv_params = _active_params_b(model)
return pb * bpp + 0.000008 * kv_params * ctx + 0.5
def _active_params_b(model):
"""For MoE: active params per token (affects KV cache and speed, not total VRAM).
For dense: same as total params."""
if model.get("is_moe") and model.get("active_parameters"):
return model["active_parameters"] / 1_000_000_000.0
return params_b(model)
def best_quant_for_budget(model, budget_gb, ctx):
"""Find best quant that fits in budget_gb of VRAM.
Pre-quantized models (AWQ/GPTQ/MLX) use their native quant only.
Returns (quant, ctx, mem_gb) or (None, None, None).
"""
if is_prequantized(model):
q = model.get("quantization", "Q4_K_M")
mem = estimate_memory_gb(model, q, ctx)
if mem <= budget_gb:
return q, ctx, mem
# Try halving context
cur_ctx = ctx // 2
while cur_ctx >= 1024:
mem = estimate_memory_gb(model, q, cur_ctx)
if mem <= budget_gb:
return q, cur_ctx, mem
cur_ctx //= 2
return None, None, None
# GGUF: try best quality first, then fall back
for q in QUANT_HIERARCHY:
mem = estimate_memory_gb(model, q, ctx)
if mem <= budget_gb:
return q, ctx, mem
cur_ctx = ctx // 2
while cur_ctx >= 1024:
for q in QUANT_HIERARCHY:
mem = estimate_memory_gb(model, q, cur_ctx)
if mem <= budget_gb:
return q, cur_ctx, mem
cur_ctx //= 2
return None, None, None
def infer_use_case(model):
name = model.get("name", "").lower()
uc = model.get("use_case", "").lower()
combined = name + " " + uc
if any(k in combined for k in ("embedding", "embed", "bge")):
return "embedding"
if any(k in combined for k in ("tts", "text-to-speech", "speech-synthesis", "cosyvoice", "parler")):
return "tts"
if any(k in combined for k in ("stt", "speech-to-text", "whisper", "transcri", "asr")):
return "stt"
if "code" in combined:
return "coding"
if any(k in combined for k in ("vision", "multimodal", "vlm", "vl-")):
return "multimodal"
if any(k in combined for k in ("reason", "chain-of-thought", "deepseek-r1")):
return "reasoning"
if any(k in combined for k in ("chat", "instruction")):
return "chat"
return "general"
_models_cache = None
def get_models():
global _models_cache
if _models_cache is None:
data_path = os.path.join(os.path.dirname(__file__), "data", "hf_models.json")
try:
with open(data_path) as f:
_models_cache = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
_models_cache = []
return _models_cache
def model_catalog_path():
return os.path.join(os.path.dirname(__file__), "data", "hf_models.json")
+14
View File
@@ -0,0 +1,14 @@
# services/memory/__init__.py
"""Memory service — persistent memory storage and retrieval."""
from .service import MemoryService, Memory, MemorySearchResult
from .memory import MemoryManager
from .memory_vector import MemoryVectorStore
__all__ = [
"MemoryService",
"Memory",
"MemorySearchResult",
"MemoryManager",
"MemoryVectorStore",
]
+359
View File
@@ -0,0 +1,359 @@
import json
import logging
import os
import time
import uuid
import re
from typing import List, Dict, Tuple
from datetime import datetime
logger = logging.getLogger(__name__)
def tokenize(text: str) -> List[str]:
"""Simple tokenizer that splits on whitespace and removes punctuation."""
return [word.strip('.,!?";') for word in text.split()]
def get_text_similarity(text1: str, text2: str) -> float:
"""Calculate Jaccard similarity between two texts."""
if not text1 or not text2:
return 0.0
tokens1 = set(tokenize(text1.lower()))
tokens2 = set(tokenize(text2.lower()))
if not tokens1 and not tokens2:
return 1.0
if not tokens1 or not tokens2:
return 0.0
intersection = tokens1.intersection(tokens2)
union = tokens1.union(tokens2)
return len(intersection) / len(union)
class MemoryManager:
def __init__(self, data_dir: str):
self.memory_file = os.path.join(data_dir, "memory.json")
self.ensure_file_exists()
def extract_memory_from_chat(self, chat_history: List[Dict], session_id: str = None) -> List[Dict]:
"""
Extract memory entries from chat history as a fallback when LLM fails.
Args:
chat_history: List of chat messages with 'role' and 'content' keys
session_id: Optional session ID to associate with extracted memories
Returns:
List of memory entries with text, timestamp, and optional session_id
"""
memories = []
for msg in chat_history:
if msg.get("role") == "assistant":
content = str(msg.get("content", ""))
lines = content.split('\n')
for line in lines:
line = line.strip()
# Look for bullet points or numbered lists that might contain memories
if re.match(r'^[-*•]|\d+\.', line):
# Extract the text after the bullet/number
text_match = re.match(r'^[-*•]|\d+\.\s*(.*)', line)
if text_match:
text = text_match.group(1).strip()
if text:
memories.append({
"text": text,
"timestamp": int(datetime.now().timestamp()),
"session_id": session_id
})
# If we see a heading that suggests memories
elif re.search(r'memory|fact|note|remember', line, re.I):
pass
# If we see a clear separator or end
elif re.match(r'^={3,}|-{3,}|_{3,}', line):
pass
return memories
def process_inline_memory_command(self, message: str) -> Tuple[bool, str]:
"""
Check if a message is an inline memory command (e.g. "remember: X").
Args:
message: The user message to check
Returns:
Tuple of (is_command, extracted_text) where is_command is True if
the message matches the memory command pattern
"""
# Pattern for memory commands: "remember: X", "memorize: X", "save: X", etc.
pattern = r'^(?:remember|memorize|save|note|store)[:\-]?\s+(.+)$'
match = re.match(pattern, message.strip(), re.IGNORECASE)
if match:
return True, match.group(1).strip()
else:
return False, ""
def ensure_file_exists(self):
"""Create memory file if it doesn't exist."""
if not os.path.exists(self.memory_file):
with open(self.memory_file, 'w', encoding='utf-8') as f:
json.dump([], f, ensure_ascii=False, indent=2)
def load_all(self) -> List[Dict]:
"""Load all memory entries from JSON file (unfiltered)."""
if not os.path.exists(self.memory_file):
return []
try:
with open(self.memory_file, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return self._validate_entries(data)
except (json.JSONDecodeError, PermissionError) as e:
logger.error("Error loading memory.json: %s", e)
return self._migrate_from_legacy()
return []
def load(self, owner: str = None) -> List[Dict]:
"""Load memory entries, filtered by owner."""
entries = self.load_all()
if owner is None:
return entries
return [e for e in entries if e.get("owner") == owner]
def claim_ownerless(self, owner: str):
"""Assign all ownerless memory entries to the given owner. Run once to migrate."""
entries = self.load_all()
changed = False
for e in entries:
if not e.get("owner"):
e["owner"] = owner
changed = True
if changed:
self.save(entries)
logger.info("Claimed %d ownerless memories for %s", sum(1 for e in entries if e.get("owner") == owner), owner)
def _validate_entries(self, entries: List[Dict]) -> List[Dict]:
"""Ensure all entries have required fields."""
validated = []
for entry in entries:
if "id" not in entry:
entry["id"] = str(uuid.uuid4())
if "timestamp" not in entry:
entry["timestamp"] = int(time.time())
if "source" not in entry:
entry["source"] = "unknown"
if "category" not in entry:
entry["category"] = "fact"
validated.append(entry)
return validated
def _migrate_from_legacy(self) -> List[Dict]:
"""Migrate from old text format to JSON if needed."""
legacy_path = os.path.join(os.path.dirname(self.memory_file), "memory.txt")
if not os.path.exists(legacy_path):
return []
logger.info("Converting legacy memory.txt to new JSON format")
try:
with open(legacy_path, "r", encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines() if ln.strip()]
entries = []
for line in lines:
entries.append({
"id": str(uuid.uuid4()),
"text": line,
"timestamp": int(time.time()),
"source": "user",
"category": "fact"
})
self.save(entries)
return entries
except Exception as e:
logger.error("Failed to convert legacy memory: %s", e)
return []
def save(self, entries: List[Dict]):
"""Save memory entries to JSON file."""
# Validate entries before saving
for entry in entries:
if "id" not in entry:
entry["id"] = str(uuid.uuid4())
if "timestamp" not in entry:
entry["timestamp"] = int(time.time())
if "source" not in entry:
entry["source"] = "user"
if "category" not in entry:
entry["category"] = "fact"
# Use atomic write
tmp_file = self.memory_file + ".tmp"
with open(tmp_file, "w", encoding="utf-8") as f:
json.dump(entries, f, ensure_ascii=False, indent=2)
os.replace(tmp_file, self.memory_file)
def add_entry(self, text: str, source: str = "user", category: str = "fact", owner: str = None) -> Dict:
"""Add a new memory entry."""
if not text.strip():
raise ValueError("Memory text cannot be empty")
entry = {
"id": str(uuid.uuid4()),
"text": text.strip(),
"timestamp": int(time.time()),
"source": source,
"category": category
}
if owner:
entry["owner"] = owner
return entry
def find_duplicates(self, text: str, entries: List[Dict] = None) -> List[Dict]:
"""Find duplicate memory entries based on text content."""
if entries is None:
entries = self.load()
text_lower = text.strip().lower()
return [entry for entry in entries if entry["text"].lower() == text_lower]
def categorize_memory_by_relevance(self, message: str, memories: list):
"""Categorize memories by type and relevance"""
categories = {
"contacts": [],
"preferences": [],
"facts": [],
"tasks": []
}
msg_lower = message.lower()
for mem in memories:
text_lower = mem["text"].lower()
# Contact info
if any(word in text_lower for word in ["phone", "email", "address", "lives", "works"]):
if any(word in msg_lower for word in ["contact", "phone", "address", "email"]):
categories["contacts"].append(mem)
# Personal preferences
elif any(word in text_lower for word in ["likes", "dislikes", "prefers", "favorite"]):
if any(word in msg_lower for word in ["like", "prefer", "favorite", "want"]):
categories["preferences"].append(mem)
# Tasks and todos
elif any(word in text_lower for word in ["todo", "task", "remind", "meeting"]):
if any(word in msg_lower for word in ["todo", "task", "schedule", "remind"]):
categories["tasks"].append(mem)
# General facts - only if very relevant
else:
if get_text_similarity(message, mem["text"]) > 0.4:
categories["facts"].append(mem)
return categories
def get_relevant_memories(self, query: str, memories: list, threshold: float = 0.05, max_items: int = 8):
"""Get memories that are relevant to the query based on text similarity and semantic keyword matching."""
if not memories or not query.strip():
return []
# Define keyword categories for semantic matching
identity_words = ["name", "who", "i", "am", "called", "identity", "myself", "me", "my"]
contact_words = ["phone", "email", "address", "contact", "number", "where", "located", "reach"]
preference_words = ["like", "prefer", "favorite", "want", "love", "hate", "dislike", "enjoy", "interested"]
task_words = ["todo", "task", "remind", "meeting", "appointment", "schedule", "deadline"]
fact_words = ["what", "when", "where", "how", "why", "explain", "describe", "information", "know"]
query_lower = query.lower()
# Determine query type based on keywords
query_type = None
if any(word in query_lower for word in identity_words):
query_type = "identity"
elif any(word in query_lower for word in contact_words):
query_type = "contact"
elif any(word in query_lower for word in preference_words):
query_type = "preference"
elif any(word in query_lower for word in task_words):
query_type = "task"
elif any(word in query_lower for word in fact_words):
query_type = "fact"
relevant = []
identity_memories = []
other_memories = []
# Separate identity memories from others
for memory in memories:
memory_text = memory["text"].lower()
# Check if this is an identity memory (contains name patterns or identity indicators)
is_identity = any([
re.search(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', memory["text"]),
any(word in memory_text for word in ["name is", "i'm", "i am", "called", "my name", "named", "call me"])
])
if is_identity:
identity_memories.append(memory)
else:
other_memories.append(memory)
# For identity queries, include all identity memories regardless of similarity
if query_type == "identity" and identity_memories:
# Give them high scores to ensure they're included first
for memory in identity_memories:
relevant.append((0.9, memory)) # High score for identity memories in identity queries
# Process other memories with similarity scoring
for memory in other_memories:
memory_text = memory["text"].lower()
memory_tokens = set(tokenize(memory_text))
query_tokens = set(tokenize(query_lower))
# Calculate base Jaccard similarity
if not query_tokens or not memory_tokens:
continue
base_similarity = len(query_tokens & memory_tokens) / len(query_tokens | memory_tokens)
final_score = base_similarity
# Apply boosts based on semantic matching
if query_type == "contact":
# Boost memories with contact information
has_contact_info = any(word in memory_text for word in ["@gmail.com", "@", ".com",
"phone", "number", "address",
"http", "www", "tel:"])
if has_contact_info:
final_score *= 1.4 # 40% boost for contact-related memories
elif query_type == "preference":
# Boost memories with preference indicators
has_preference = any(word in memory_text for word in ["like", "love", "hate", "dislike",
"prefer", "favorite", "enjoy", "interested"])
if has_preference:
final_score *= 1.3 # 30% boost for preference-related memories
elif query_type == "task":
# Boost memories with task indicators
has_task = any(word in memory_text for word in ["todo", "task", "remind", "meeting",
"appointment", "schedule", "deadline", "need to"])
if has_task:
final_score *= 1.3 # 30% boost for task-related memories
# Always consider exact phrase matches as highly relevant
if query.lower() in memory["text"].lower():
final_score = max(final_score, 0.8) # Ensure high relevance for exact matches
# Include memory if it meets threshold after boosts
if final_score >= threshold:
relevant.append((final_score, memory))
# Sort by final score (descending) and return top matches
relevant.sort(key=lambda x: x[0], reverse=True)
return [mem for _, mem in relevant[:max_items]]
+533
View File
@@ -0,0 +1,533 @@
"""
memory_extractor.py
Background auto-extraction of facts from chat conversations.
After each LLM response, this module sends the last few messages to the LLM
asking it to extract memorable facts, then stores them in both memory.json
and the FAISS vector index.
Periodically audits all memories via LLM to consolidate duplicates,
rewrite vague entries, and remove junk.
"""
import hashlib
import json
import logging
import os
import re
from typing import Optional
logger = logging.getLogger(__name__)
def _tidy_state_path(memory_manager) -> str:
"""Sidecar JSON next to memory.json that remembers the fingerprint of
the last successfully-audited state per owner. Lets the audit short-
circuit when nothing has changed since the previous tidy — running
the LLM again on an already-clean list was wasting 30-120s per call
and occasionally timing out on the second pass."""
return os.path.join(os.path.dirname(memory_manager.memory_file), "memory_tidy_state.json")
def _fingerprint_entries(entries) -> str:
"""Stable hash of an owner's memories — order-independent, depends
only on id+text+category. Any add/edit/delete invalidates it."""
items = sorted(
(str(e.get("id", "")), e.get("text", ""), e.get("category", ""))
for e in entries
)
h = hashlib.sha256()
for triple in items:
h.update(("\x1f".join(triple) + "\x1e").encode("utf-8"))
return h.hexdigest()
def _load_tidy_state(memory_manager) -> dict:
path = _tidy_state_path(memory_manager)
try:
with open(path, "r") as f:
data = json.load(f)
return data if isinstance(data, dict) else {}
except (FileNotFoundError, json.JSONDecodeError):
return {}
def _save_tidy_state(memory_manager, owner: Optional[str], fingerprint: str) -> None:
path = _tidy_state_path(memory_manager)
state = _load_tidy_state(memory_manager)
state[owner or ""] = {"fingerprint": fingerprint}
try:
with open(path, "w") as f:
json.dump(state, f, indent=2)
except OSError as e:
logger.warning(f"Could not persist tidy fingerprint: {e}")
EXTRACT_SYSTEM_PROMPT = (
"You are a memory extraction assistant. Analyze the conversation and extract ONLY "
"durable personal facts about the user that would be useful across many future conversations.\n\n"
"Good examples: name, job title, city, family members, long-term projects, strong preferences.\n"
"Bad examples: what they asked about today, temporary moods, generic statements, "
"things the assistant said, one-off tasks, opinions on the current topic.\n\n"
"Rules:\n"
"- MAX 2 facts per conversation — only the most important\n"
"- Only extract facts the USER stated or clearly implied\n"
"- Each fact must be a single short sentence (under 15 words)\n"
"- If a fact is similar to something likely already known, skip it\n"
"- If nothing durable was revealed, return []\n\n"
"Return a JSON array of objects with 'text' and 'category' fields.\n"
"Categories: 'identity', 'preference', 'fact', 'contact', 'project', 'goal'\n\n"
"Return ONLY valid JSON, no markdown fences."
)
# How many recent messages to include for extraction
CONTEXT_WINDOW = 6
AUDIT_SYSTEM_PROMPT = (
"You are a memory database curator. Be CONSERVATIVE: remove only TRUE "
"duplicates and clearly useless entries. Every distinct fact must survive. "
"When in doubt, KEEP the entry. Return the cleaned list.\n\n"
"Rules:\n"
"1. MERGE only entries that state the SAME fact in different words. If you "
"are not sure two entries are the same fact, KEEP BOTH.\n"
" Merge: 'User's name is Sam' + 'The user is called Sam' -> one.\n"
" Do NOT merge related-but-distinct facts: 'Likes Python' and 'Uses "
"Python at work' are DIFFERENT — keep both.\n"
"2. REMOVE only entries that are genuinely worthless: about what the AI did "
"(not the user), empty, or meaningless. Do NOT drop a real fact just "
"because it seems minor or niche.\n"
"3. Keep the original wording. Only lightly trim obvious redundancy — do "
"NOT aggressively rewrite or shorten.\n"
"4. Preserve the 'id' of the entry you keep when merging.\n"
"5. Never invent facts. When unsure, KEEP.\n\n"
"Return a JSON array of objects with fields: id, text, category.\n"
"Return ONLY valid JSON, no markdown fences."
)
AUDIT_INTERVAL = 5 # audit every N new memories added
_extractions_since_audit = 0
def _message_text(message) -> str:
content = getattr(message, "content", None)
if content is None and isinstance(message, dict):
content = message.get("content")
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
parts.append(str(item.get("text") or item.get("content") or ""))
else:
parts.append(str(item))
return " ".join(p for p in parts if p).strip()
return ""
def _message_role(message) -> str:
role = getattr(message, "role", None)
if role is None and isinstance(message, dict):
role = message.get("role")
return str(role or "").lower()
def _clean_memory_value(value: str, max_len: int = 80) -> str:
value = re.sub(r"\s+", " ", value or "").strip(" .,!?:;\"'`“”‘’")
value = re.sub(r"^(?:the|a|an)\s+", "", value, flags=re.I)
if not value or len(value) > max_len:
return ""
if re.search(r"https?://|@|[{}<>]", value):
return ""
return value
def _fallback_memory_candidates(messages) -> list[dict]:
"""Extract obvious durable facts without relying on the LLM.
This is deliberately narrow. The LLM remains the main extractor, but
simple identity/preference/goal statements should not silently vanish just
because the background model judged them too conversational.
"""
candidates = []
seen = set()
def add(text: str, category: str):
text = _clean_memory_value(text, 120)
if not text:
return
key = text.lower()
if key in seen:
return
seen.add(key)
candidates.append({"text": text, "category": category})
for msg in messages:
if _message_role(msg) != "user":
continue
text = _message_text(msg)
if not text:
continue
m = re.search(r"\bmy name is\s+([A-Za-z][A-Za-z0-9 .'\-]{1,50})\b", text, re.I)
if m:
name = _clean_memory_value(m.group(1), 50)
if name:
add(f"User's name is {name}.", "identity")
m = re.search(r"\bcall me\s+([A-Za-z][A-Za-z0-9 .'\-]{1,50})\b", text, re.I)
if m:
name = _clean_memory_value(m.group(1), 50)
if name:
add(f"User wants to be called {name}.", "identity")
m = re.search(r"\bi (?:live in|am from|'m from)\s+([^.!?\n]{2,80})", text, re.I)
if m:
place = _clean_memory_value(m.group(1), 80)
if place:
add(f"User lives in {place}.", "identity")
m = re.search(r"\bi (?:prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
if m:
preference = _clean_memory_value(m.group(1), 100)
if preference:
add(f"User prefers {preference}.", "preference")
m = re.search(
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
r"(?:go|travel|move|visit) to\s+([^.!?\n]{2,80})",
text,
re.I,
)
if m:
destination = _clean_memory_value(m.group(1), 80)
if destination:
add(f"User wants to visit {destination}.", "goal")
return candidates[:2]
def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) -> bool:
"""Check if new_text is too similar to any existing memory (Jaccard similarity)."""
new_tokens = set(new_text.lower().split())
if not new_tokens:
return False
for entry in existing:
old_tokens = set(entry.get("text", "").lower().split())
if not old_tokens:
continue
intersection = new_tokens & old_tokens
union = new_tokens | old_tokens
if len(intersection) / len(union) >= threshold:
return True
return False
async def extract_and_store(
session,
memory_manager,
memory_vector,
endpoint_url: str,
model: str,
headers: Optional[dict] = None,
):
"""Extract facts from recent conversation and store them.
Designed to run as a background task (asyncio.create_task).
Errors are logged, never raised.
"""
try:
from src.llm_core import llm_call_async
# Get last N messages from session
messages = session.get_context_messages()
recent = messages[-CONTEXT_WINDOW:] if len(messages) > CONTEXT_WINDOW else messages
if len(recent) < 2:
return # Need at least a user message and assistant response
fallback_facts = _fallback_memory_candidates(recent)
extraction_messages = [
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
] + recent
facts = []
try:
raw = await llm_call_async(
endpoint_url,
model,
extraction_messages,
temperature=0.1,
max_tokens=500,
headers=headers,
)
# Parse JSON from response (handle markdown fences if model wraps them)
text = raw.strip()
if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
try:
facts = json.loads(text)
except json.JSONDecodeError:
logger.debug("Memory extraction returned non-JSON")
except Exception as e:
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
if not isinstance(facts, list):
facts = []
if fallback_facts:
facts = list(facts) + fallback_facts
if not facts:
logger.info("Auto memory extraction ran: 0 candidates")
return
# Get owner from session
_owner = getattr(session, 'owner', None)
existing = memory_manager.load_all()
added = 0
for fact in facts:
if isinstance(fact, str):
fact_text = fact
category = "fact"
elif isinstance(fact, dict):
fact_text = fact.get("text", "").strip()
category = fact.get("category", "fact")
else:
continue
if not fact_text or len(fact_text) < 5:
continue
# Dedup: check vector similarity first (fast), then exact text match
if memory_vector and memory_vector.healthy:
existing_id = memory_vector.find_similar(fact_text, threshold=0.72)
if existing_id:
logger.debug(f"Memory dedup (vector): '{fact_text[:50]}' matches {existing_id}")
continue
# Text dedup fallback: exact match + fuzzy similarity
user_existing = [e for e in existing if e.get("owner") == _owner or e.get("owner") is None] if _owner else existing
if memory_manager.find_duplicates(fact_text, user_existing):
continue
# Fuzzy text similarity check (catches rephrased duplicates when vector index is unavailable)
if _is_text_duplicate(fact_text, user_existing):
logger.debug(f"Memory dedup (fuzzy): '{fact_text[:50]}' too similar to existing")
continue
entry = memory_manager.add_entry(fact_text, source="auto", category=category, owner=_owner)
# Auto-pin identity facts (name, job, location) — core context
if category == "identity":
entry["pinned"] = True
if hasattr(session, "session_id"):
entry["session_id"] = session.session_id
elif hasattr(session, "name"):
entry["session_id"] = session.name
existing.append(entry)
# Add to vector index
if memory_vector and memory_vector.healthy:
memory_vector.add(entry["id"], fact_text)
added += 1
if added > 0:
memory_manager.save(existing)
try:
from src.event_bus import fire_event
for _ in range(added):
fire_event("memory_added", _owner)
except Exception:
logger.debug("memory_added event dispatch failed", exc_info=True)
logger.info(f"Auto-extracted {added} memories from session")
global _extractions_since_audit
_extractions_since_audit += added
if _extractions_since_audit >= AUDIT_INTERVAL:
_extractions_since_audit = 0
logger.info("Audit threshold reached, running memory audit")
await audit_memories(
memory_manager, memory_vector, endpoint_url, model, headers, owner=_owner
)
else:
logger.info("Auto memory extraction ran: 0 added")
except Exception as e:
logger.error(f"Memory extraction failed: {e}")
async def audit_memories(
memory_manager,
memory_vector,
endpoint_url: str,
model: str,
headers: Optional[dict] = None,
owner: Optional[str] = None,
):
"""Send all memories to the LLM for deduplication and consolidation.
- Merges near-duplicate entries
- Rewrites vague entries to be concise
- Removes junk / non-personal entries
- Rebuilds the vector index afterwards
Safe to call manually or from the automatic trigger in extract_and_store.
Errors are logged, never raised.
"""
try:
from src.llm_core import llm_call_async
existing = memory_manager.load(owner=owner)
if not existing:
logger.info("Memory audit: nothing to audit")
return {"before": 0, "after": 0}
before_count = len(existing)
# Skip the LLM call entirely when this exact set of memories was
# already audited — the previous tidy left them in a clean state
# and nothing has changed since. Returns instantly so the UI shows
# "Already clean" without spending 30-120s on a wasted LLM round.
# The fingerprint includes id+text+category; any add/edit/delete
# invalidates it and the audit runs normally.
current_fp = _fingerprint_entries(existing)
last_state = _load_tidy_state(memory_manager).get(owner or "") or {}
if last_state.get("fingerprint") == current_fp:
logger.info("Memory audit: state unchanged since last tidy — skipping LLM")
return {
"before": before_count,
"after": before_count,
"already_tidy": True,
}
# Build payload: list of {id, text, category} for the LLM
memory_payload = [
{"id": m["id"], "text": m["text"], "category": m.get("category", "fact")}
for m in existing
]
audit_messages = [
{"role": "system", "content": AUDIT_SYSTEM_PROMPT},
{"role": "user", "content": json.dumps(memory_payload, ensure_ascii=False)},
]
raw = await llm_call_async(
endpoint_url,
model,
audit_messages,
temperature=0.1,
# 16384 (was 2000): the deduped list of all memories can be large,
# and a reasoning model spends tokens thinking first — 2000 truncated
# the JSON so it never parsed ("bad_json").
max_tokens=16384,
headers=headers,
# Bound the call so the Tidy whirlpool can't spin indefinitely on a
# slow/large generation.
timeout=120,
)
# Parse the JSON list, tolerating reasoning-model noise: <think> blocks,
# markdown fences, leading prose, and trailing commas.
import re as _re
text = (raw or "").strip()
text = _re.sub(r'<think(?:ing)?>[\s\S]*?</think(?:ing)?>', '', text, flags=_re.I).strip()
def _loads_list(s):
if not s:
return None
for cand in (s, _re.sub(r',(\s*[}\]])', r'\1', s)):
try:
v = json.loads(cand)
if isinstance(v, list):
return v
except Exception:
continue
return None
cleaned = _loads_list(text)
if cleaned is None:
_m = _re.search(r'```(?:json)?\s*\n?([\s\S]*?)```', text)
if _m:
cleaned = _loads_list(_m.group(1).strip())
if cleaned is None:
_a, _b = text.find('['), text.rfind(']')
if _a >= 0 and _b > _a:
cleaned = _loads_list(text[_a:_b + 1])
if cleaned is None:
logger.error(f"Memory audit returned non-JSON: {text[:300]}")
return {"before": before_count, "after": before_count, "error": "bad_json"}
# Build lookup of original entries by ID so we can preserve metadata
originals = {m["id"]: m for m in existing}
final_entries = []
for item in cleaned:
if not isinstance(item, dict):
continue
mid = item.get("id", "")
new_text = item.get("text", "").strip()
if not new_text:
continue
if mid in originals:
# Preserve original metadata, update text + category
entry = originals[mid].copy()
entry["text"] = new_text
if item.get("category"):
entry["category"] = item["category"]
else:
# ID not found — skip to avoid inventing entries
logger.debug(f"Audit returned unknown id {mid}, skipping")
continue
final_entries.append(entry)
after_count = len(final_entries)
# Safety net against catastrophic over-deletion. A conservative tidy
# should never wipe out half the store in one pass — if the model
# returned far fewer entries than it was given (over-consolidation, a
# dropped/truncated list, or it ignored ids), treat it as a misfire and
# DON'T save. Better to no-op than to silently lose memories.
if before_count >= 8 and after_count < before_count * 0.5:
logger.warning(
f"Memory audit would cut {before_count} -> {after_count} "
f"(>50% removed) — refusing as unsafe, keeping originals"
)
return {"before": before_count, "after": before_count, "error": "unsafe_removal"}
# Merge audited entries back with other users' entries
if owner:
all_entries = memory_manager.load_all()
audited_ids = {e["id"] for e in final_entries}
other_entries = [e for e in all_entries if e.get("owner") != owner and (e.get("owner") is not None)]
# Also keep legacy entries that weren't part of this audit
for e in all_entries:
if e.get("owner") is None and e["id"] not in audited_ids and e["id"] not in {o["id"] for o in other_entries}:
other_entries.append(e)
memory_manager.save(final_entries + other_entries)
else:
memory_manager.save(final_entries)
logger.info(
f"Memory audit complete: {before_count} -> {after_count} entries "
f"({before_count - after_count} removed/merged)"
)
# Rebuild vector index
if memory_vector and memory_vector.healthy:
memory_vector.rebuild(final_entries)
# Persist the post-tidy fingerprint so the next call short-circuits
# if nothing has changed in the meantime.
_save_tidy_state(memory_manager, owner, _fingerprint_entries(final_entries))
return {"before": before_count, "after": after_count}
except Exception as e:
logger.error(f"Memory audit failed: {e}")
return {"error": str(e)}
+175
View File
@@ -0,0 +1,175 @@
"""
memory_vector.py
ChromaDB-backed vector store for memory entries.
Shares the EmbeddingClient with RAG to save memory.
Stores pre-computed embeddings (ChromaDB does not manage embedding).
"""
import logging
from typing import List, Dict, Optional
logger = logging.getLogger(__name__)
class MemoryVectorStore:
"""Vector index over memory entries for semantic retrieval."""
COLLECTION_NAME = "odysseus_memories"
def __init__(self, data_dir: str, embedding_model=None):
self._model = embedding_model
self._collection = None
self._healthy = False
self._initialize()
def _initialize(self):
try:
from src.chroma_client import get_chroma_client
if self._model is None:
from src.embeddings import get_embedding_client
self._model = get_embedding_client()
if self._model is None:
raise RuntimeError("No embedding backend available")
logger.info(f"MemoryVectorStore using embeddings: {self._model.url}")
client = get_chroma_client()
self._collection = client.get_or_create_collection(
name=self.COLLECTION_NAME,
metadata={"hnsw:space": "cosine"},
)
self._healthy = True
count = self._collection.count()
logger.info(f"MemoryVectorStore ready (entries={count})")
except Exception as e:
logger.error(f"MemoryVectorStore init failed: {e}")
@property
def healthy(self) -> bool:
return self._healthy
def _embed(self, texts: List[str]) -> List[List[float]]:
vecs = self._model.encode(texts, normalize_embeddings=True)
return vecs.tolist()
def count(self) -> int:
"""Return the number of stored vectors."""
if not self._healthy:
return 0
return self._collection.count()
def add(self, memory_id: str, text: str):
"""Add a single memory entry to the vector index."""
if not self._healthy:
return
# Skip if already exists
existing = self._collection.get(ids=[memory_id])
if existing["ids"]:
return
embeddings = self._embed([text])
self._collection.add(
ids=[memory_id],
embeddings=embeddings,
documents=[text],
metadatas=[{"source": "memory"}],
)
def remove(self, memory_id: str):
"""Remove a memory entry. O(1) — no rebuild needed."""
if not self._healthy:
return
try:
self._collection.delete(ids=[memory_id])
except Exception as e:
logger.warning(f"memory remove {memory_id}: {e}")
def search(self, query: str, k: int = 8) -> List[Dict]:
"""Search for the most relevant memory IDs by semantic similarity.
Returns list of {"memory_id": str, "score": float}.
ChromaDB cosine distance = 1 - cosine_similarity.
We convert back: similarity = 1.0 - distance.
"""
if not self._healthy or self._collection.count() == 0:
return []
embeddings = self._embed([query])
actual_k = min(k, self._collection.count())
results = self._collection.query(
query_embeddings=embeddings,
n_results=actual_k,
)
out = []
for idx, mid in enumerate(results["ids"][0]):
distance = results["distances"][0][idx]
out.append({
"memory_id": mid,
"score": round(1.0 - distance, 4),
})
return out
def find_similar(self, text: str, threshold: float = 0.92) -> Optional[str]:
"""Check if a near-duplicate exists. Returns memory_id if found, else None."""
if not self._healthy or self._collection.count() == 0:
return None
embeddings = self._embed([text])
results = self._collection.query(
query_embeddings=embeddings,
n_results=1,
)
if results["ids"][0]:
distance = results["distances"][0][0]
similarity = 1.0 - distance
if similarity >= threshold:
return results["ids"][0][0]
return None
def rebuild(self, memories: List[Dict]):
"""Rebuild the entire index from a list of memory entries.
Each entry must have 'id' and 'text' keys."""
if not self._healthy:
return
from src.chroma_client import get_chroma_client
# Delete and recreate collection for a clean rebuild
client = get_chroma_client()
try:
client.delete_collection(self.COLLECTION_NAME)
except Exception:
pass
self._collection = client.get_or_create_collection(
name=self.COLLECTION_NAME,
metadata={"hnsw:space": "cosine"},
)
texts = []
ids = []
for mem in memories:
text = mem.get("text", "").strip()
mid = mem.get("id", "")
if text and mid:
texts.append(text)
ids.append(mid)
if texts:
# Batch in chunks of 100 to avoid oversized requests
for i in range(0, len(texts), 100):
batch_texts = texts[i:i + 100]
batch_ids = ids[i:i + 100]
embeddings = self._embed(batch_texts)
self._collection.add(
ids=batch_ids,
embeddings=embeddings,
documents=batch_texts,
metadatas=[{"source": "memory"}] * len(batch_ids),
)
logger.info(f"MemoryVectorStore rebuilt with {len(ids)} entries")
+137
View File
@@ -0,0 +1,137 @@
# services/memory/service.py
"""Memory service — persistent memory storage and retrieval."""
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
import os
from .memory import MemoryManager
from .memory_vector import MemoryVectorStore
@dataclass
class Memory:
"""A stored memory."""
id: str
text: str
timestamp: int
session_id: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MemorySearchResult:
"""Result of memory search."""
memories: List[Memory]
query: str
total: int
class MemoryService:
"""
Memory storage and retrieval service.
Usage:
service = MemoryService()
await service.remember("User prefers dark mode")
results = await service.recall("preferences")
"""
def __init__(self, data_dir: str = "data"):
self.manager = MemoryManager(data_dir)
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
os.path.join(data_dir, "memory_vectors")
) else None
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
"""
Store a new memory.
Args:
text: Memory content
session_id: Optional session association
Returns:
Created Memory object
"""
import uuid
import time
memory_id = str(uuid.uuid4())[:8]
timestamp = int(time.time())
entry = {
"id": memory_id,
"text": text,
"timestamp": timestamp,
"session_id": session_id,
}
self.manager.add_memory(entry)
# Also add to vector store if available
if self.vector_store:
self.vector_store.add(text, {"id": memory_id, "session_id": session_id})
return Memory(
id=memory_id,
text=text,
timestamp=timestamp,
session_id=session_id,
)
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
"""
Search memories.
Args:
query: Search query
top_k: Max results
Returns:
MemorySearchResult with matching memories
"""
# Try vector search first
if self.vector_store:
results = self.vector_store.search(query, k=top_k)
memories = [
Memory(
id=r.get("id", ""),
text=r.get("text", ""),
timestamp=r.get("timestamp", 0),
session_id=r.get("session_id"),
metadata=r.get("metadata", {}),
)
for r in results
]
return MemorySearchResult(memories=memories, query=query, total=len(memories))
# Fallback to keyword search
results = self.manager.search_memories(query, limit=top_k)
memories = [
Memory(
id=m.get("id", ""),
text=m.get("text", ""),
timestamp=m.get("timestamp", 0),
session_id=m.get("session_id"),
)
for m in results
]
return MemorySearchResult(memories=memories, query=query, total=len(memories))
def get_all(self, limit: int = 100) -> List[Memory]:
"""Get all memories."""
memories = self.manager.get_memories(limit=limit)
return [
Memory(
id=m.get("id", ""),
text=m.get("text", ""),
timestamp=m.get("timestamp", 0),
session_id=m.get("session_id"),
)
for m in memories
]
def delete(self, memory_id: str) -> bool:
"""Delete a memory by ID."""
return self.manager.delete_memory(memory_id)
+209
View File
@@ -0,0 +1,209 @@
"""
skill_extractor.py
Background auto-extraction of skills from complex agent runs.
When the agent takes >= 2 rounds or >= 2 tool calls to complete a task,
we ask the LLM to distill the approach into a reusable skill.
"""
import json
import logging
from typing import Optional
logger = logging.getLogger(__name__)
SKILL_EXTRACT_PROMPT = (
"You are analyzing an AI agent's work session. The agent took {rounds} rounds "
"and {tool_count} tool calls to complete the task.\n\n"
"Extract a reusable 'skill' ONLY IF the session contains a concrete, "
"repeatable procedure the agent could follow to solve a similar problem "
"ON THE COMPUTER next time (e.g. a sequence of shell commands, code, file "
"edits, API calls, or tool usage).\n\n"
"Return null (the bare word, no JSON) when the session is NOT a reusable "
"computer procedure, including:\n"
"- The real work happened OUTSIDE the computer (the user did something "
"physically, in person, on another device, or by hand) and the agent only "
"discussed or advised it.\n"
"- A one-off, personal, or context-specific task that won't recur "
"(personal errands, a specific person/place/date, casual conversation).\n"
"- A pure question/answer or explanation with no transferable method.\n"
"- The agent failed, gave up, or the approach is not worth repeating.\n\n"
"When (and only when) a genuine reusable procedure exists, return a JSON "
"object with:\n"
'- "title": short name (under 10 words)\n'
'- "problem": what was the challenge (1-2 sentences)\n'
'- "solution": what worked (1-2 sentences)\n'
'- "steps": array of step-by-step instructions (3-7 short steps)\n'
'- "tags": array of relevant keywords (3-5 tags)\n'
'- "confidence": 0.0-1.0 how reliable AND reusable this procedure is\n\n'
"Be conservative: if in doubt, return null.\n"
"Return ONLY valid JSON (or the bare word null), no markdown fences."
)
# Skills the model is unsure about (or that read as one-offs) add clutter —
# drop anything below this confidence.
MIN_CONFIDENCE = 0.6
# How many recent messages to include
CONTEXT_WINDOW = 12
async def maybe_extract_skill(
session,
skills_manager,
endpoint_url: str,
model: str,
headers: dict,
round_count: int,
tool_count: int,
owner: Optional[str] = None,
):
"""Extract a skill if the agent run was complex enough."""
# Quiet by default; flip to DEBUG when chasing extractor issues.
logger.debug(
"[skill-extract] start: rounds=%d tools=%d model=%s owner=%s",
round_count, tool_count, model, owner,
)
if round_count < 2 and tool_count < 2:
logger.debug("[skill-extract] BELOW threshold (need rounds>=2 or tools>=2)")
return None
try:
from src.llm_core import llm_call_async
# Get recent messages
history = session.get_context_messages()
recent = history[-CONTEXT_WINDOW:] if len(history) > CONTEXT_WINDOW else history
if not recent:
logger.debug("[skill-extract] no recent messages, skipping")
return None
# Build conversation summary for extraction
conv_lines = []
for msg in recent:
role = msg.get("role", "?")
content = msg.get("content", "")
if isinstance(content, list):
content = " ".join(
b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"
)
# Truncate long messages
if len(content) > 500:
content = content[:500] + "..."
conv_lines.append(f"[{role}] {content}")
conversation = "\n".join(conv_lines)
prompt = SKILL_EXTRACT_PROMPT.format(rounds=round_count, tool_count=tool_count)
import time as _time
_t0 = _time.monotonic()
logger.debug(
"[skill-extract] calling LLM (endpoint=%s, ctx=%d msgs, timeout=30s)",
endpoint_url, len(recent),
)
response = await llm_call_async(
endpoint_url,
model,
[
{"role": "system", "content": prompt},
{"role": "user", "content": f"Conversation:\n{conversation}"},
],
headers=headers,
timeout=30,
)
logger.debug(
"[skill-extract] LLM returned in %.1fs (len=%d, head=%r)",
_time.monotonic() - _t0, len(response or ""), (response or "")[:80],
)
if not response or response.strip().lower() == "null":
logger.debug(
"[skill-extract] LLM declined (returned null/empty) — "
"session deemed not a reusable procedure"
)
return None
# Some models (MiniMax, Qwen-Thinker, DeepSeek-R1) emit their
# chain-of-thought BEFORE the JSON output even when asked for
# raw JSON. `strip_think(prose=True, prompt_echo=True)` removes
# <think>…</think> tags AND prose-style "Let me analyze this…"
# preambles. Without it, json.loads bombed on character 0 every
# time and the silent-bail looked like "extractor doesn't work".
try:
from src.text_helpers import strip_think as _strip_think
response = _strip_think(response, prose=True, prompt_echo=True)
except Exception:
pass
# Parse JSON
text = response.strip()
if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
# After strip_think, the JSON may still be embedded inside surrounding
# commentary — slice from the first '{' to the matching last '}'.
if text and text[0] != "{":
_start = text.find("{")
_end = text.rfind("}")
if 0 <= _start < _end:
text = text[_start : _end + 1]
data = json.loads(text)
if not data or not isinstance(data, dict):
logger.debug("[skill-extract] parsed JSON not a dict, dropping")
return None
title = data.get("title", "").strip()
if not title:
logger.debug("[skill-extract] LLM returned object with no title, dropping")
return None
# Honour the model's own reliability/reusability estimate — low-
# confidence extractions are usually one-offs or shaky procedures.
try:
_conf = float(data.get("confidence", 0.7))
except (TypeError, ValueError):
_conf = 0.7
if _conf < MIN_CONFIDENCE:
logger.debug(
"[skill-extract] '%s' below confidence floor (%.2f < %.2f) — dropped",
title, _conf, MIN_CONFIDENCE,
)
return None
# Check for duplicate skills
existing = skills_manager.load(owner=owner)
for sk in existing:
if sk.get("title", "").lower() == title.lower():
logger.debug("[skill-extract] '%s' already exists — dropped as duplicate", title)
return None
entry = skills_manager.add_skill(
title=title,
problem=data.get("problem", ""),
solution=data.get("solution", ""),
steps=data.get("steps", []),
tags=data.get("tags", []),
source="learned",
confidence=data.get("confidence", 0.7),
session_id=getattr(session, "session_id", None),
owner=owner,
)
try:
from src.event_bus import fire_event
fire_event("skill_added", owner)
except Exception:
logger.debug("skill_added event dispatch failed", exc_info=True)
logger.info("Auto-extracted skill: %s (id=%s)", title, entry["id"])
return entry
except json.JSONDecodeError as e:
logger.debug("[skill-extract] non-JSON LLM response, dropping: %s", e)
return None
except Exception as e:
# Real exceptions stay INFO+warning so they don't get lost when
# users only have default log level. `exc_info=True` ships the
# full traceback so timeouts vs auth vs import errors are
# distinguishable from outside.
logger.warning("[skill-extract] FAILED: %s", e, exc_info=True)
return None
+444
View File
@@ -0,0 +1,444 @@
"""SKILL.md parser & writer.
Reads/writes a single skill from a `SKILL.md` file with YAML frontmatter
and a structured markdown body. Inspired by Hermes' skills format
(https://hermes-agent.nousresearch.com/docs/user-guide/features/skills).
Frontmatter shape (YAML):
---
name: open-pr-from-branch
description: One-line summary surfaced in the skills index.
version: 1.0.0
category: dev
tags: [git, github]
platforms: [linux, macos] # optional
requires_toolsets: [] # optional
fallback_for_toolsets: [] # optional
status: published # draft | published
confidence: 0.8 # 0..1
source: learned # learned | taught | imported
teacher_model: claude-opus-4-7 # optional
created: 2026-05-09T21:43:00Z
---
Body sections (any subset; rendered as headings):
## When to Use
Trigger conditions in plain English.
## Procedure
1. First step
2. Second step
## Pitfalls
- Common failure mode + how to recover
## Verification
- How to confirm success
Anything else (raw paragraphs after the last known section) is preserved
in `body_extra` and round-trips on save.
Usage counters (`uses`, `last_used`) live in a sidecar `_usage.json` keyed
by skill name, so the SKILL.md file doesn't churn on every retrieval.
"""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Slugify
# ---------------------------------------------------------------------------
_SLUG_RE = re.compile(r"[^a-z0-9]+")
def slugify(text: str, fallback: str = "skill") -> str:
"""Convert a free-form title to a kebab-case slug suitable for a directory
name. Strips non-alphanumerics, collapses runs, trims leading/trailing
dashes. Caps at 60 chars."""
s = str(text or "").strip().lower()
s = _SLUG_RE.sub("-", s)
s = s.strip("-")
return (s or fallback)[:60]
# ---------------------------------------------------------------------------
# Frontmatter (minimal YAML — we don't pull in PyYAML for one feature)
# ---------------------------------------------------------------------------
# We accept a tiny subset of YAML: scalar `key: value`, inline lists `[a, b]`,
# and block lists with `-`. That covers everything in our schema and avoids
# a new dependency.
_FM_KEY_RE = re.compile(r"^([a-z_][a-z0-9_]*):\s*(.*)$", re.IGNORECASE)
_FM_BLOCK_LIST_RE = re.compile(r"^\s*-\s*(.*)$")
def _parse_scalar(raw: str) -> Any:
raw = raw.strip()
if raw == "":
return ""
if raw.startswith("[") and raw.endswith("]"):
inner = raw[1:-1].strip()
if not inner:
return []
return [_parse_scalar(p) for p in _split_top_level(inner, ",")]
if raw.lower() in ("true", "yes"):
return True
if raw.lower() in ("false", "no"):
return False
if raw.lower() in ("null", "none", "~"):
return None
if (raw[0] == raw[-1]) and raw[0] in ("'", '"'):
return raw[1:-1]
# Try number
try:
if "." in raw:
return float(raw)
return int(raw)
except ValueError:
pass
return raw
def _split_top_level(s: str, sep: str) -> List[str]:
"""Split `s` on `sep` ignoring separators inside [] or quotes."""
out, buf, depth, quote = [], [], 0, None
for ch in s:
if quote:
buf.append(ch)
if ch == quote:
quote = None
continue
if ch in ("'", '"'):
quote = ch
buf.append(ch)
continue
if ch == "[":
depth += 1
elif ch == "]":
depth = max(0, depth - 1)
if ch == sep and depth == 0:
out.append("".join(buf).strip())
buf = []
continue
buf.append(ch)
if buf:
out.append("".join(buf).strip())
return out
def parse_frontmatter(text: str) -> tuple[Dict[str, Any], str]:
"""Pull the YAML frontmatter out of a SKILL.md and return (fm, body)."""
if not text.startswith("---"):
return {}, text
end = text.find("\n---", 3)
if end < 0:
return {}, text
fm_text = text[3:end].lstrip("\n")
body = text[end + 4:].lstrip("\n")
fm: Dict[str, Any] = {}
pending_key: Optional[str] = None
for line in fm_text.splitlines():
if not line.strip() or line.lstrip().startswith("#"):
continue
m = _FM_KEY_RE.match(line)
if m:
key, val = m.group(1), m.group(2)
if val.strip() == "":
pending_key = key
fm[key] = []
else:
fm[key] = _parse_scalar(val)
pending_key = None
continue
m2 = _FM_BLOCK_LIST_RE.match(line)
if m2 and pending_key:
existing = fm.get(pending_key)
if not isinstance(existing, list):
fm[pending_key] = []
fm[pending_key].append(_parse_scalar(m2.group(1)))
return fm, body
def _emit_scalar(v: Any) -> str:
if v is None:
return "null"
if isinstance(v, bool):
return "true" if v else "false"
if isinstance(v, (int, float)):
return str(v)
if isinstance(v, list):
return "[" + ", ".join(_emit_scalar(x) for x in v) + "]"
s = str(v)
if any(c in s for c in (":", "#", "\n", "[", "]", "{", "}", ",", "&", "*", "!", "|", ">", "'", '"', "%", "@")):
return json.dumps(s)
return s
def _as_list(v: Any) -> List[str]:
if v is None:
return []
if isinstance(v, list):
return [str(x) for x in v if x not in (None, "")]
return [str(v)]
def _as_float(v: Any, default: float = 0.8) -> float:
try:
return float(v)
except (TypeError, ValueError):
return default
def emit_frontmatter(fm: Dict[str, Any]) -> str:
lines = []
for k, v in fm.items():
if v is None or v == [] or v == "":
continue
lines.append(f"{k}: {_emit_scalar(v)}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Skill body sections
# ---------------------------------------------------------------------------
_KNOWN_SECTIONS = ("when_to_use", "procedure", "pitfalls", "verification")
_HEADING_TO_KEY = {
"when to use": "when_to_use",
"procedure": "procedure",
"steps": "procedure",
"pitfalls": "pitfalls",
"verification": "verification",
}
_KEY_TO_HEADING = {
"when_to_use": "When to Use",
"procedure": "Procedure",
"pitfalls": "Pitfalls",
"verification": "Verification",
}
def parse_body(body: str) -> Dict[str, Any]:
"""Split a SKILL.md body into known sections.
Returns:
{
"when_to_use": str,
"procedure": list[str], # numbered/bulleted lines
"pitfalls": list[str],
"verification": list[str],
"body_extra": str, # anything not under a known heading
}
"""
out = {k: ([] if k != "when_to_use" else "") for k in _KNOWN_SECTIONS}
out["body_extra"] = ""
if not body or not body.strip():
return out
sections: List[tuple[Optional[str], List[str]]] = [(None, [])]
for line in body.splitlines():
m = re.match(r"^##\s+(.*?)\s*$", line)
if m:
heading = m.group(1).strip().lower()
key = _HEADING_TO_KEY.get(heading)
sections.append((key, []))
continue
sections[-1][1].append(line)
for key, lines in sections:
text = "\n".join(lines).strip("\n")
if key is None:
extras = text.strip()
if extras:
out["body_extra"] = (out["body_extra"] + "\n\n" + extras).strip()
continue
if key == "when_to_use":
out["when_to_use"] = text.strip()
else:
out[key] = _parse_list_lines(text)
return out
def _parse_list_lines(text: str) -> List[str]:
"""Pull bullet/numbered lines out of a section body. Plain paragraphs are
treated as a single entry."""
items: List[str] = []
for line in (text or "").splitlines():
s = line.strip()
if not s:
continue
m = re.match(r"^(?:[-*]|\d+[.)])\s+(.*)$", s)
if m:
items.append(m.group(1).strip())
elif items:
# continuation of previous bullet
items[-1] = items[-1] + " " + s
else:
items.append(s)
return items
def emit_body(sections: Dict[str, Any]) -> str:
parts: List[str] = []
when = (sections.get("when_to_use") or "").strip()
if when:
parts.append(f"## {_KEY_TO_HEADING['when_to_use']}\n\n{when}")
for key in ("procedure", "pitfalls", "verification"):
items = sections.get(key) or []
if not items:
continue
heading = _KEY_TO_HEADING[key]
if key == "procedure":
body = "\n".join(f"{i + 1}. {x}" for i, x in enumerate(items))
else:
body = "\n".join(f"- {x}" for x in items)
parts.append(f"## {heading}\n\n{body}")
extra = (sections.get("body_extra") or "").strip()
if extra:
parts.append(extra)
return "\n\n".join(parts) + ("\n" if parts else "")
# ---------------------------------------------------------------------------
# Skill record
# ---------------------------------------------------------------------------
@dataclass
class Skill:
name: str # slug, dir name
description: str = ""
version: str = "1.0.0"
category: str = "general"
tags: List[str] = field(default_factory=list)
platforms: List[str] = field(default_factory=list)
requires_toolsets: List[str] = field(default_factory=list)
fallback_for_toolsets: List[str] = field(default_factory=list)
status: str = "draft" # draft | published
confidence: float = 0.8
source: str = "learned"
teacher_model: Optional[str] = None
owner: Optional[str] = None
created: str = "" # ISO8601
when_to_use: str = ""
procedure: List[str] = field(default_factory=list)
pitfalls: List[str] = field(default_factory=list)
verification: List[str] = field(default_factory=list)
body_extra: str = ""
# Sidecar (not persisted in SKILL.md)
uses: int = 0
last_used: Optional[int] = None
# File path on disk (set when read)
path: Optional[str] = None
# ----------------------------------------------------------------------
# Serialization
# ----------------------------------------------------------------------
def to_frontmatter(self) -> Dict[str, Any]:
fm: Dict[str, Any] = {
"name": self.name,
"description": self.description,
"version": self.version,
"category": self.category,
}
if self.tags: fm["tags"] = list(self.tags)
if self.platforms: fm["platforms"] = list(self.platforms)
if self.requires_toolsets: fm["requires_toolsets"] = list(self.requires_toolsets)
if self.fallback_for_toolsets: fm["fallback_for_toolsets"] = list(self.fallback_for_toolsets)
fm["status"] = self.status
fm["confidence"] = round(float(self.confidence), 3)
fm["source"] = self.source
if self.teacher_model: fm["teacher_model"] = self.teacher_model
if self.owner: fm["owner"] = self.owner
fm["created"] = self.created or _now_iso()
return fm
def to_dict(self) -> Dict[str, Any]:
d = {
"id": self.name, # slug doubles as id
"name": self.name,
"description": self.description,
"version": self.version,
"category": self.category,
"tags": list(self.tags),
"platforms": list(self.platforms),
"requires_toolsets": list(self.requires_toolsets),
"fallback_for_toolsets": list(self.fallback_for_toolsets),
"status": self.status,
"confidence": round(float(self.confidence), 3),
"source": self.source,
"teacher_model": self.teacher_model,
"owner": self.owner,
"created": self.created,
"when_to_use": self.when_to_use,
"procedure": list(self.procedure),
"pitfalls": list(self.pitfalls),
"verification": list(self.verification),
"body_extra": self.body_extra,
"uses": int(self.uses or 0),
"last_used": self.last_used,
"path": self.path,
}
# Back-compat aliases for the old API/UI
d["title"] = self.description or self.name.replace("-", " ").title()
d["problem"] = self.when_to_use
d["solution"] = (self.procedure[0] if self.procedure else "") if not self.body_extra else self.body_extra
d["steps"] = list(self.procedure)
return d
@classmethod
def from_markdown(cls, text: str, *, path: Optional[str] = None) -> "Skill":
fm, body = parse_frontmatter(text)
sections = parse_body(body)
raw_name = fm.get("name")
name = slugify(raw_name if raw_name not in (None, "") else fm.get("description", ""), fallback="skill")
return cls(
name=name,
description=str(fm.get("description", "") or ""),
version=str(fm.get("version", "1.0.0") or "1.0.0"),
category=str(fm.get("category", "general") or "general"),
tags=_as_list(fm.get("tags")),
platforms=_as_list(fm.get("platforms")),
requires_toolsets=_as_list(fm.get("requires_toolsets")),
fallback_for_toolsets=_as_list(fm.get("fallback_for_toolsets")),
status=str(fm.get("status", "draft") or "draft"),
confidence=_as_float(fm.get("confidence", 0.8), 0.8),
source=str(fm.get("source", "learned") or "learned"),
teacher_model=str(fm.get("teacher_model")) if fm.get("teacher_model") else None,
owner=str(fm.get("owner")) if fm.get("owner") else None,
created=str(fm.get("created") or _now_iso()),
when_to_use=sections["when_to_use"],
procedure=list(sections["procedure"]),
pitfalls=list(sections["pitfalls"]),
verification=list(sections["verification"]),
body_extra=sections["body_extra"],
path=path,
)
def to_markdown(self) -> str:
fm = emit_frontmatter(self.to_frontmatter())
body = emit_body({
"when_to_use": self.when_to_use,
"procedure": self.procedure,
"pitfalls": self.pitfalls,
"verification": self.verification,
"body_extra": self.body_extra,
})
return f"---\n{fm}\n---\n\n{body}"
def _now_iso() -> str:
return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
+610
View File
@@ -0,0 +1,610 @@
# services/memory/skills.py
"""Skills storage layer.
Skills live on disk as `data/skills/<category>/<name>/SKILL.md` files with
YAML frontmatter and a structured markdown body (When to Use / Procedure /
Pitfalls / Verification). See `skill_format.py` for the format.
Usage counters (`uses`, `last_used`) live in a sidecar
`data/skills/_usage.json` keyed by skill name so the SKILL.md content
doesn't churn on every retrieval.
Ownership: skills declare `owner: <username>` in frontmatter. Single-user
deployments can leave that blank.
This module also retains a JSON fallback for any legacy `data/skills.json`
entries — they're surfaced as read-only `Skill` objects so old data still
loads while a user migrates them to disk.
"""
from __future__ import annotations
import json
import logging
import os
import time
from typing import Dict, Iterable, List, Optional
from .skill_format import Skill, slugify
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Token / similarity helpers (kept for the relevance fallback)
# ---------------------------------------------------------------------------
def _tokenize(text: str) -> set:
return {w.strip('.,!?";:()[]') for w in (text or "").lower().split() if len(w) > 1}
def _jaccard(a: set, b: set) -> float:
if not a or not b:
return 0.0
return len(a & b) / len(a | b)
def _to_float(x, default: float = 0.0) -> float:
"""Coerce a possibly hand-edited frontmatter value to float without
raising — a blank or non-numeric `confidence:` in a SKILL.md must not
blow up retrieval or eviction."""
try:
return float(x)
except (TypeError, ValueError):
return default
# ---------------------------------------------------------------------------
# SkillsManager
# ---------------------------------------------------------------------------
class SkillsManager:
"""Read/write SKILL.md files under <data_dir>/skills/."""
def __init__(self, data_dir: str):
self.data_dir = data_dir
self.skills_root = os.path.join(data_dir, "skills")
self.usage_file = os.path.join(self.skills_root, "_usage.json")
self.legacy_file = os.path.join(data_dir, "skills.json") # back-compat
os.makedirs(self.skills_root, exist_ok=True)
# ----------------------------------------------------------------------
# Path helpers
# ----------------------------------------------------------------------
def _skill_dir(self, category: str, name: str) -> str:
cat = slugify(category or "general", fallback="general")
nm = slugify(name, fallback="skill")
return os.path.join(self.skills_root, cat, nm)
def _skill_file(self, category: str, name: str) -> str:
return os.path.join(self._skill_dir(category, name), "SKILL.md")
# ----------------------------------------------------------------------
# Usage sidecar
# ----------------------------------------------------------------------
def _load_usage(self) -> Dict[str, Dict]:
if not os.path.exists(self.usage_file):
return {}
try:
with open(self.usage_file) as f:
d = json.load(f)
return d if isinstance(d, dict) else {}
except Exception:
return {}
def _save_usage(self, usage: Dict[str, Dict]) -> None:
try:
from core.atomic_io import atomic_write_json
atomic_write_json(self.usage_file, usage, indent=2)
except Exception:
tmp = self.usage_file + ".tmp"
with open(tmp, "w") as f:
json.dump(usage, f, indent=2)
os.replace(tmp, self.usage_file)
def set_audit(self, name: str, verdict: str, by_teacher: bool = False,
worker_model: str = "", teacher_model: str = "") -> None:
"""Record the last test/audit result for a skill in the usage sidecar
(so it surfaces in load() without touching SKILL.md). Drives the
'verified' check + teacher mark on the card."""
import time as _t
usage = self._load_usage()
e = usage.setdefault(name, {"uses": 0, "last_used": None})
e["audit_verdict"] = verdict
e["audit_by_teacher"] = bool(by_teacher)
if worker_model:
e["audit_worker_model"] = worker_model
if teacher_model:
e["audit_teacher_model"] = teacher_model
e["audited_at"] = _t.time()
self._save_usage(usage)
def set_necessity(self, name: str, necessary: bool,
redundant_with=None, reason: str = "") -> None:
"""Record the advisory 'is this skill necessary?' judgment in the usage
sidecar. Surfaced on the card as a flag; never acts on the skill."""
usage = self._load_usage()
e = usage.setdefault(name, {"uses": 0, "last_used": None})
e["necessity"] = {
"necessary": bool(necessary),
"redundant_with": list(redundant_with or []),
"reason": str(reason or ""),
}
self._save_usage(usage)
# ----------------------------------------------------------------------
# Disk scan
# ----------------------------------------------------------------------
def _iter_skill_files(self) -> Iterable[str]:
if not os.path.isdir(self.skills_root):
return
for root, _dirs, files in os.walk(self.skills_root, followlinks=False):
if "SKILL.md" in files:
yield os.path.join(root, "SKILL.md")
def _read_skill(self, path: str) -> Optional[Skill]:
try:
with open(path) as f:
text = f.read()
return Skill.from_markdown(text, path=path)
except Exception as e:
logger.warning(f"Failed to parse {path}: {e}")
return None
def _write_skill(self, sk: Skill) -> str:
path = self._skill_file(sk.category or "general", sk.name)
os.makedirs(os.path.dirname(path), exist_ok=True)
from core.atomic_io import atomic_write_text
atomic_write_text(path, sk.to_markdown())
sk.path = path
return path
def backfill_owner(self, primary_owner: str, valid_owners: Optional[set[str]] = None) -> int:
"""Assign legacy/unclaimed skill files to the primary owner.
Skills are disk-backed, so the DB legacy-owner migration cannot fix
them. If strict owner filtering is enabled and SKILL.md files have no
owner or an owner from a deleted/test account, the UI appears empty even
though files still exist. This mirrors the DB legacy-owner sweep.
"""
primary_owner = (primary_owner or "").strip()
if not primary_owner:
return 0
valid_owners = set(valid_owners or [])
changed = 0
for path in self._iter_skill_files():
sk = self._read_skill(path)
if not sk:
continue
owner = (sk.owner or "").strip()
if owner == primary_owner:
continue
if owner and owner in valid_owners:
continue
sk.owner = primary_owner
try:
self._write_skill(sk)
changed += 1
except Exception as e:
logger.warning("Failed to backfill owner for skill %s: %s", sk.name, e)
return changed
# ----------------------------------------------------------------------
# Public API — keeps the old method names so callers don't break
# ----------------------------------------------------------------------
def load_all(self) -> List[Dict]:
"""Return every skill as a plain dict, plus any legacy JSON entries."""
usage = self._load_usage()
out: List[Dict] = []
seen_names: set[str] = set()
for path in self._iter_skill_files():
sk = self._read_skill(path)
if not sk:
continue
d = sk.to_dict()
u = usage.get(sk.name) or {}
d["uses"] = int(u.get("uses", 0))
d["last_used"] = u.get("last_used")
d["audit_verdict"] = u.get("audit_verdict")
d["audit_by_teacher"] = bool(u.get("audit_by_teacher"))
d["audit_worker_model"] = u.get("audit_worker_model")
d["audit_teacher_model"] = u.get("audit_teacher_model")
d["audited_at"] = u.get("audited_at")
d["necessity"] = u.get("necessity")
out.append(d)
seen_names.add(sk.name)
# Legacy JSON entries — surfaced as draft, not editable from new flow
if os.path.exists(self.legacy_file):
try:
with open(self.legacy_file) as f:
legacy = json.load(f)
if isinstance(legacy, list):
for row in legacy:
if not isinstance(row, dict):
continue
name = slugify(row.get("title") or row.get("id") or "skill")
if name in seen_names:
continue
out.append({
"id": row.get("id") or name,
"name": name,
"description": row.get("title", ""),
"version": "0.0.1",
"category": "legacy",
"tags": row.get("tags") or [],
"status": row.get("status") or "draft",
"confidence": row.get("confidence", 0.5),
"source": row.get("source", "imported"),
"owner": row.get("owner"),
"when_to_use": row.get("problem", ""),
"procedure": row.get("steps") or [],
"pitfalls": [],
"verification": [],
"body_extra": row.get("solution", ""),
"title": row.get("title", ""),
"problem": row.get("problem", ""),
"solution": row.get("solution", ""),
"steps": row.get("steps") or [],
"uses": row.get("uses", 0),
"last_used": row.get("last_used"),
"_legacy": True,
})
except Exception:
pass
return out
def load(self, owner: Optional[str] = None) -> List[Dict]:
entries = self.load_all()
if owner is None:
return entries
# SECURITY: strict ownership filter. The previous predicate also
# included skills with NO owner field (`not s.get("owner")`), which
# leaked legacy / un-stamped skills to every authenticated user.
# Hide them now; the owner needs to be backfilled on disk if those
# skills should be visible to a specific user.
return [s for s in entries if s.get("owner") == owner]
# ----------------------------------------------------------------------
# CRUD — disk-backed
# ----------------------------------------------------------------------
def add_skill(
self,
title: str = "",
problem: str = "",
solution: str = "",
steps: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
source: str = "learned",
teacher_model: Optional[str] = None,
confidence: float = 0.8,
session_id: Optional[str] = None,
owner: Optional[str] = None,
# New-schema fields (optional; fall back to old shape if absent)
name: Optional[str] = None,
description: Optional[str] = None,
category: str = "general",
when_to_use: Optional[str] = None,
procedure: Optional[List[str]] = None,
pitfalls: Optional[List[str]] = None,
verification: Optional[List[str]] = None,
platforms: Optional[List[str]] = None,
requires_toolsets: Optional[List[str]] = None,
fallback_for_toolsets: Optional[List[str]] = None,
status: str = "draft",
version: str = "1.0.0",
) -> Dict:
# Normalize name
nm = slugify(name or title or description or "skill")
# Free dedup-at-creation (always, no API): for LLM-authored skills,
# skip if a near-identical skill already exists (Jaccard over
# name+description+when_to_use+procedure). User-authored skills are
# never auto-skipped — a human asked for it. The every-X AI audit
# handles the fuzzier near-duplicates this cheap check won't catch.
_all = self.load_all()
if source != "user":
cand = _tokenize(" ".join([
nm, (description or title or ""),
(when_to_use if when_to_use is not None else (problem or "")),
" ".join(procedure if procedure is not None else (steps or [])),
]))
if cand:
for s in _all:
ex = _tokenize(" ".join([
s.get("name", ""), s.get("description", ""),
s.get("when_to_use", ""),
" ".join(s.get("procedure", []) or []),
]))
if _jaccard(cand, ex) >= 0.82:
# Near-identical — don't grow the library; bump the
# existing skill's usage and return it so the caller
# knows it already exists.
try:
self.record_use(s["name"])
except Exception:
pass
return {**s, "_deduped": True, "_duplicate_of": s.get("name")}
# Avoid clobbering an existing skill with the same name
existing = {s["name"] for s in _all}
base = nm
i = 2
while nm in existing:
nm = f"{base}-{i}"
i += 1
sk = Skill(
name=nm,
description=(description or title or "").strip(),
version=version,
category=category or "general",
tags=list(tags or []),
platforms=list(platforms or []),
requires_toolsets=list(requires_toolsets or []),
fallback_for_toolsets=list(fallback_for_toolsets or []),
status=status or "draft",
confidence=float(confidence),
source=source,
teacher_model=teacher_model,
owner=owner,
when_to_use=(when_to_use if when_to_use is not None else (problem or "")),
procedure=list(procedure if procedure is not None else (steps or [])),
pitfalls=list(pitfalls or []),
verification=list(verification or []),
body_extra=(solution if solution and not procedure else ""),
)
self._write_skill(sk)
return sk.to_dict()
def update_skill(self, skill_id: str, updates: Dict) -> bool:
"""`skill_id` is the slug name. Allows updating any field plus
renames if `name` changes (file is moved on disk)."""
for path in self._iter_skill_files():
sk = self._read_skill(path)
if not sk or sk.name != skill_id:
continue
old_dir = os.path.dirname(path)
# Apply updates in a Skill-shape friendly way
scalar_keys = (
"description", "version", "category", "status", "confidence",
"source", "teacher_model", "owner", "when_to_use",
"body_extra",
)
for k in scalar_keys:
if k in updates:
setattr(sk, k, updates[k])
list_keys = ("tags", "procedure", "pitfalls", "verification",
"platforms", "requires_toolsets", "fallback_for_toolsets")
for k in list_keys:
if k in updates:
setattr(sk, k, list(updates[k] or []))
# Old-schema field aliases
if "title" in updates and "description" not in updates:
sk.description = updates["title"]
if "problem" in updates and "when_to_use" not in updates:
sk.when_to_use = updates["problem"]
if "solution" in updates and "body_extra" not in updates and not sk.procedure:
sk.body_extra = updates["solution"]
if "steps" in updates and "procedure" not in updates:
sk.procedure = list(updates["steps"] or [])
# Rename
new_name = slugify(updates.get("name") or sk.name)
if new_name != sk.name:
sk.name = new_name
# Write to potentially new path
new_path = self._skill_file(sk.category, sk.name)
if new_path != path:
# Move the whole skill directory if rename or recategorize
new_dir = os.path.dirname(new_path)
if os.path.isdir(new_dir):
logger.warning(f"Skill rename target exists: {new_dir}")
return False
os.makedirs(os.path.dirname(new_dir), exist_ok=True)
os.rename(old_dir, new_dir)
# Also rename usage key
usage = self._load_usage()
if skill_id in usage:
usage[sk.name] = usage.pop(skill_id)
self._save_usage(usage)
self._write_skill(sk)
return True
return False
def delete_skill(self, skill_id: str) -> bool:
for path in self._iter_skill_files():
sk = self._read_skill(path)
if not sk or sk.name != skill_id:
continue
skill_dir = os.path.dirname(path)
try:
# Remove the whole skill dir
for root, dirs, files in os.walk(skill_dir, topdown=False):
for f in files:
os.remove(os.path.join(root, f))
for d in dirs:
os.rmdir(os.path.join(root, d))
os.rmdir(skill_dir)
except Exception as e:
logger.warning(f"Failed to remove skill dir {skill_dir}: {e}")
return False
usage = self._load_usage()
if skill_id in usage:
del usage[skill_id]
self._save_usage(usage)
return True
return False
def record_use(self, skill_id: str) -> None:
usage = self._load_usage()
entry = usage.setdefault(skill_id, {"uses": 0, "last_used": None})
entry["uses"] = int(entry.get("uses", 0)) + 1
entry["last_used"] = int(time.time())
self._save_usage(usage)
# ----------------------------------------------------------------------
# Reading a single skill (used by the skill_view tool)
# ----------------------------------------------------------------------
def read_skill_md(self, name: str) -> Optional[str]:
for path in self._iter_skill_files():
sk = self._read_skill(path)
if sk and sk.name == name:
try:
with open(path) as f:
return f.read()
except Exception:
return None
return None
def read_skill_reference(self, name: str, ref_path: str) -> Optional[str]:
"""Read a sub-file under the skill's directory (references/, etc).
Refuses path traversal."""
for path in self._iter_skill_files():
sk = self._read_skill(path)
if not sk or sk.name != name:
continue
base = os.path.realpath(os.path.dirname(path))
target = os.path.realpath(os.path.join(base, ref_path))
if os.path.commonpath([base, target]) != base or target == os.path.dirname(path):
return None
if not os.path.isfile(target):
return None
try:
with open(target) as f:
return f.read()
except Exception:
return None
return None
# ----------------------------------------------------------------------
# Index — the lightweight summary injected into the system prompt
# ----------------------------------------------------------------------
def index_for(
self,
owner: Optional[str] = None,
*,
active_toolsets: Optional[List[str]] = None,
platform: Optional[str] = None,
) -> List[Dict]:
"""Return the `[{name, description, category, status}]` list the
agent sees in its system prompt.
Includes:
- All published skills.
- Drafts written by the teacher-escalation loop
(`source == "teacher-escalation"`). The whole point of
the teacher loop is for the student to find the new
procedure on the very next turn — waiting for a manual
publish click defeats the loop.
Excludes user-created drafts (status=draft, source != teacher-
escalation) — those are work-in-progress and pollute the
prompt with half-finished procedures.
"""
active_toolsets = active_toolsets or []
out = []
for s in self.load(owner=owner):
status = s.get("status")
# Published + None (pre-status legacy) always included.
# Drafts only if the teacher wrote them.
if status not in ("published", None):
if status == "draft" and s.get("source") == "teacher-escalation":
pass # let it through
else:
continue
# Platform gating
if platform and s.get("platforms") and platform not in s["platforms"]:
continue
# requires_toolsets: hide unless every required toolset is active
req = s.get("requires_toolsets") or []
if req and not all(t in active_toolsets for t in req):
continue
# fallback_for_toolsets: hide when any of those toolsets is active
fb = s.get("fallback_for_toolsets") or []
if fb and any(t in active_toolsets for t in fb):
continue
out.append({
"name": s["name"],
"description": s.get("description") or s.get("title", ""),
"category": s.get("category", "general"),
"status": status or "published",
})
out.sort(key=lambda x: (x["category"], x["name"]))
return out
# ----------------------------------------------------------------------
# Relevance search (kept for the existing /api/skills/search endpoint
# and the `manage_skills` action="search"). Now operates on the new
# field set.
# ----------------------------------------------------------------------
def get_relevant_skills(
self,
query: str,
skills: Optional[List[Dict]] = None,
threshold: float = 0.3,
max_items: int = 5,
min_confidence: float = 0.0,
) -> List[Dict]:
if skills is None:
skills = self.load_all()
if not skills or not query.strip():
return []
# Consider published AND draft skills for relevance retrieval.
# The teacher-escalation loop writes new skills as drafts; the
# whole point is for the student to find them on the next try
# without a manual publish click. The UI flags teacher-written
# entries with a 🎓 badge so users can demote / delete bad
# ones when they spot them.
skills = [s for s in skills if s.get("status") in ("published", "draft")]
# Confidence gate (used by prompt-injection, NOT by search): a DRAFT
# skill must clear the bar to be injected. Published skills are already
# vetted, so they always qualify. Missing confidence = treat as 1.0
# (legacy skills shouldn't silently vanish). 0 disables the gate.
if min_confidence > 0:
def _passes(s):
if s.get("status") == "published":
return True
c = s.get("confidence")
if c is None:
return True # unset → don't filter (legacy)
return _to_float(c, 1.0) >= min_confidence # unparseable → pass
skills = [s for s in skills if _passes(s)]
if not skills:
return []
query_tokens = _tokenize(query)
scored = []
for sk in skills:
text = " ".join([
sk.get("name", ""),
sk.get("description", ""),
sk.get("when_to_use", ""),
" ".join(sk.get("tags", []) or []),
" ".join(sk.get("procedure", []) or []),
])
score = _jaccard(query_tokens, _tokenize(text))
for tag in sk.get("tags", []) or []:
if tag and tag in query.lower():
score = max(score, 0.3) * 1.3
if query.lower() in (sk.get("description") or "").lower():
score = max(score, 0.6)
score *= 1.0 + _to_float(sk.get("confidence"), 0.5) * 0.1
if sk.get("uses", 0) > 0:
score *= 1.05
if score >= threshold:
scored.append((score, sk))
scored.sort(key=lambda x: x[0], reverse=True)
return [sk for _, sk in scored[:max_items]]
+12
View File
@@ -0,0 +1,12 @@
# services/research/__init__.py
"""Research service — deep research with LLM-in-the-loop."""
from .service import ResearchService, ResearchResult, ResearchSource
from .research_handler import ResearchHandler
__all__ = [
"ResearchService",
"ResearchResult",
"ResearchSource",
"ResearchHandler",
]
+463
View File
@@ -0,0 +1,463 @@
# src/research_handler.py
"""Handler for research service integration with expandable UI support.
Uses the IterResearch-style DeepResearcher (LLM-in-the-loop) as the primary
engine, falling back to the legacy ResearchOrchestrator or basic web search
if needed.
Includes a task registry so research survives page refreshes and can be cancelled.
"""
import asyncio
import json
import logging
import time
from pathlib import Path
from typing import Optional, Dict
logger = logging.getLogger(__name__)
RESEARCH_DATA_DIR = Path("data/deep_research")
class ResearchHandler:
"""Handles research service operations with iterative deep research."""
def __init__(self):
self._legacy_engine = None
self._active_tasks: Dict[str, dict] = {}
self._initialize_legacy_engine()
RESEARCH_DATA_DIR.mkdir(parents=True, exist_ok=True)
def _initialize_legacy_engine(self):
"""Initialize the legacy research engine as a fallback."""
try:
from research_engine import ResearchOrchestrator, Config
config = Config(max_searches=12, max_content_per_page=15000)
self._legacy_engine = ResearchOrchestrator(config)
logger.info("Legacy ResearchOrchestrator initialized (fallback)")
except ImportError:
logger.info("Legacy research_engine.py not found — DeepResearcher only")
self._legacy_engine = None
except Exception as e:
logger.warning(f"Legacy research engine init failed: {e}")
self._legacy_engine = None
# ------------------------------------------------------------------
# Task registry — background research with persistence
# ------------------------------------------------------------------
def start_research(
self,
session_id: str,
query: str,
llm_endpoint: str,
llm_model: str,
max_time: int = 300,
llm_headers: dict = None,
) -> dict:
"""Start research as a background task. Returns task info dict."""
# Cancel any existing research for this session
if session_id in self._active_tasks:
existing = self._active_tasks[session_id]
if existing.get("status") == "running":
self.cancel_research(session_id)
entry = {
"task": None,
"researcher": None,
"query": query,
"status": "running",
"progress": {},
"result": None,
"started_at": time.time(),
}
self._active_tasks[session_id] = entry
def on_progress(event):
entry["progress"] = event
async def _run():
try:
result = await self.call_research_service(
query, llm_endpoint, llm_model,
max_time=max_time,
progress_callback=on_progress,
_task_entry=entry,
llm_headers=llm_headers,
)
entry["result"] = result
entry["status"] = "done"
self._save_result(session_id, entry)
except asyncio.CancelledError:
entry["status"] = "cancelled"
raise
except Exception as e:
logger.error(f"Background research failed: {e}", exc_info=True)
entry["result"] = str(e)
entry["status"] = "error"
task = asyncio.create_task(_run())
entry["task"] = task
return {"session_id": session_id, "status": "running", "query": query}
def get_status(self, session_id: str) -> Optional[dict]:
"""Get current research status for a session."""
if session_id in self._active_tasks:
entry = self._active_tasks[session_id]
return {
"status": entry["status"],
"progress": entry["progress"],
"query": entry["query"],
"started_at": entry["started_at"],
}
# Check disk for completed research
path = RESEARCH_DATA_DIR / f"{session_id}.json"
if path.exists():
try:
data = json.loads(path.read_text())
return {
"status": data.get("status", "done"),
"progress": {},
"query": data.get("query", ""),
"started_at": data.get("started_at", 0),
}
except Exception:
pass
return None
def cancel_research(self, session_id: str) -> bool:
"""Cancel running research for a session."""
if session_id not in self._active_tasks:
return False
entry = self._active_tasks[session_id]
if entry["status"] != "running":
return False
researcher = entry.get("researcher")
if researcher:
researcher.cancel()
task = entry.get("task")
if task and not task.done():
task.cancel()
entry["status"] = "cancelled"
return True
def get_result(self, session_id: str) -> Optional[str]:
"""Get the completed research result."""
if session_id in self._active_tasks:
entry = self._active_tasks[session_id]
if entry["status"] in ("done", "error", "cancelled"):
return entry.get("result")
# Check disk
path = RESEARCH_DATA_DIR / f"{session_id}.json"
if path.exists():
try:
data = json.loads(path.read_text())
return data.get("result")
except Exception:
pass
return None
def get_sources(self, session_id: str) -> Optional[list]:
"""Get deduplicated source list from research findings."""
# Check in-memory first
if session_id in self._active_tasks:
entry = self._active_tasks[session_id]
if entry.get("sources"):
return entry["sources"]
researcher = entry.get("researcher")
if researcher and researcher.findings:
return self._extract_sources(researcher.findings)
# Check disk
path = RESEARCH_DATA_DIR / f"{session_id}.json"
if path.exists():
try:
data = json.loads(path.read_text())
return data.get("sources")
except Exception:
pass
return None
@staticmethod
def _extract_sources(findings: list) -> list:
"""Extract deduplicated [{url, title}] from findings."""
seen = set()
sources = []
for f in findings:
url = f.get("url", "")
title = f.get("title", "") or url
if url and url not in seen:
seen.add(url)
sources.append({"url": url, "title": title})
return sources
def clear_result(self, session_id: str):
"""Remove persisted result after it's been consumed."""
self._active_tasks.pop(session_id, None)
path = RESEARCH_DATA_DIR / f"{session_id}.json"
if path.exists():
try:
path.unlink()
except Exception:
pass
def _save_result(self, session_id: str, entry: dict):
"""Persist completed research result to disk."""
try:
# Extract and cache sources
sources = []
researcher = entry.get("researcher")
if researcher and researcher.findings:
sources = self._extract_sources(researcher.findings)
entry["sources"] = sources
path = RESEARCH_DATA_DIR / f"{session_id}.json"
data = {
"query": entry["query"],
"status": entry["status"],
"result": entry["result"],
"sources": sources,
"started_at": entry["started_at"],
"completed_at": time.time(),
}
path.write_text(json.dumps(data))
logger.info(f"Research result saved to {path}")
except Exception as e:
logger.error(f"Failed to save research result: {e}")
async def call_research_service(
self,
query: str,
llm_endpoint: str,
llm_model: str,
max_time: int = 300,
progress_callback=None,
_task_entry: dict = None,
llm_headers: dict = None,
) -> str:
"""
Run iterative deep research using the LLM-in-the-loop DeepResearcher.
Args:
query: Research question
llm_endpoint: LLM endpoint URL for chat completions
llm_model: Model name/ID
max_time: Maximum research time in seconds (default 5 minutes)
_task_entry: Internal - registry entry to store researcher ref
Returns:
Formatted research report with expandable section and summary
"""
logger.info("Starting IterResearch Deep Research")
logger.info(f"Query: {query}")
logger.info(f"LLM: {llm_endpoint} / {llm_model}")
logger.info(f"Max time: {max_time}s")
try:
from src.deep_research import DeepResearcher
from src.settings import get_setting
researcher = DeepResearcher(
llm_endpoint=llm_endpoint,
llm_model=llm_model,
llm_headers=llm_headers,
max_rounds=8,
max_time=max_time,
max_report_tokens=int(get_setting("research_max_tokens", 8192)),
progress_callback=progress_callback,
)
if _task_entry is not None:
_task_entry["researcher"] = researcher
start_time = time.time()
report = await researcher.research(query)
elapsed = time.time() - start_time
stats = researcher.get_stats()
logger.info("IterResearch completed successfully")
for key, value in stats.items():
logger.info(f" {key}: {value}")
return self._format_research_report(
query, report, stats, elapsed,
findings=researcher.findings,
evolving_report=researcher.evolving_report,
)
except Exception as e:
logger.error(f"DeepResearcher failed: {e}", exc_info=True)
return await self._fallback_research(query, llm_endpoint, llm_model, max_time, str(e))
async def _fallback_research(
self, query: str, llm_endpoint: str, llm_model: str,
max_time: int, primary_error: str,
) -> str:
"""Fall back to legacy engine, then to basic web search."""
# Try legacy orchestrator
if self._legacy_engine:
try:
import asyncio
logger.info("Falling back to legacy ResearchOrchestrator...")
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None, self._legacy_engine.start_research, query, max_time
)
stats = self._get_legacy_stats()
elapsed = float(stats.get("Duration", "0").rstrip("s") or 0)
return self._format_research_report(query, result, stats, elapsed)
except Exception as e:
logger.error(f"Legacy engine also failed: {e}")
# Fall back to basic web search
return self._handle_research_failure(query, primary_error)
def _get_legacy_stats(self) -> dict:
"""Get statistics from the legacy research engine."""
if not self._legacy_engine:
return {}
try:
tracker = self._legacy_engine.progress_tracker
return {
"Findings": len(self._legacy_engine.findings),
"Sources": len(self._legacy_engine.source_reports),
"Searches": tracker.counters['searches_executed'],
"URLs": tracker.counters['urls_processed'],
}
except Exception:
return {}
def _format_research_report(
self, query: str, full_report: str, stats: dict, elapsed: float,
findings: list = None, evolving_report: str = None,
) -> str:
"""Format research report with sources list and expandable raw findings."""
summary_lines = [
f"**Duration:** {elapsed:.1f}s",
f"**Rounds:** {stats.get('Rounds', stats.get('Findings', '?'))}",
f"**Queries:** {stats.get('Queries', stats.get('Searches', '?'))}",
f"**URLs Analyzed:** {stats.get('URLs', '?')}",
]
summary_text = " | ".join(summary_lines)
# Build sources list with clickable links
sources_section = ""
if findings:
seen_urls = set()
source_lines = []
for f in findings:
url = f.get("url", "")
title = f.get("title", "") or url
if url and url not in seen_urls:
seen_urls.add(url)
source_lines.append(f"- [{title}]({url})")
if source_lines:
sources_section = "\n### Sources\n\n" + "\n".join(source_lines) + "\n"
# Build raw findings section (individual extractions per source)
raw_findings_section = ""
if findings:
parts = []
for i, f in enumerate(findings, 1):
url = f.get("url", "")
title = f.get("title", "") or "Untitled"
summary = f.get("summary", "")
evidence = f.get("evidence", "")
content = summary if summary else (evidence[:2000] if evidence else "(no content)")
parts.append(f"**{i}. [{title}]({url})**\n\n{content}")
raw_findings_section = "\n\n".join(parts)
# Build expandable collected info section
collected_section = ""
if evolving_report or raw_findings_section:
collected_section = "\n<details>\n<summary><strong>Raw collected findings ({} sources)</strong></summary>\n\n".format(
len(findings) if findings else 0
)
if raw_findings_section:
collected_section += raw_findings_section + "\n"
collected_section += "\n</details>\n"
formatted = f"""---
## Research Summary
{summary_text}
---
{full_report}
{sources_section}
{collected_section}
---
**The AI has analyzed all research findings above. Ask me anything about: "{query}"**
"""
return formatted
def _format_error_response(self, error_msg: str, query: str) -> str:
"""Format error response in a user-friendly way."""
return f"""## Research Engine Unavailable
**Query:** {query}
**Error:** {error_msg}
**Please check:**
1. LLM endpoint is reachable
2. SearXNG is running at the configured instance
3. Application logs for detailed error information
**Troubleshooting:**
- Test basic search: Try the web search toggle first
- Check search config: `/api/search/config`
- Review logs for initialization errors
"""
def _handle_research_failure(self, query: str, error: str) -> str:
"""Handle research failure with fallback to basic search."""
try:
logger.info("Attempting fallback to basic web search...")
from src.search import comprehensive_web_search
search_result = comprehensive_web_search(query)
return f"""## Research Failed - Basic Search Fallback
**Query:** {query}
**Error:** {error}
**Note:** The deep research engine encountered an error. Here are basic search results instead:
---
### Basic Web Search Results
{search_result}
---
**To fix deep research:**
1. Check that your LLM endpoint and search provider are properly configured
2. Verify network connectivity
3. Review application logs for detailed error information
Try the web search toggle for simpler queries, or fix the research engine for comprehensive analysis.
"""
except Exception as e2:
logger.error(f"Fallback search also failed: {e2}", exc_info=True)
return f"""## Complete Research Failure
**Primary Error:** {error}
**Fallback Error:** {str(e2)}
**Please check:**
1. Search provider configuration in Settings -> Search Settings
2. Network connectivity to search APIs
3. Application logs for detailed error information
4. That SearXNG is running (if using SearXNG)
**Debug Info:**
- Search config endpoint: `/api/search/config`
- Test basic search toggle with a simple query first
"""
+117
View File
@@ -0,0 +1,117 @@
# services/research/service.py
"""Research service — deep research with LLM-in-the-loop."""
from dataclasses import dataclass, field
from typing import List, Optional, Callable
from .research_handler import ResearchHandler
@dataclass
class ResearchSource:
"""A source found during research."""
url: str
title: str
snippet: str
relevance: float = 0.0
@dataclass
class ResearchResult:
"""Result of a deep research query."""
query: str
summary: str
sources: List[ResearchSource] = field(default_factory=list)
sections: List[str] = field(default_factory=list)
tokens_used: int = 0
duration_seconds: float = 0.0
class ResearchService:
"""
Deep research service.
Usage:
service = ResearchService()
result = await service.research("quantum computing advances 2024")
print(result.summary)
"""
def __init__(self):
self.handler = ResearchHandler()
self._active: dict = {}
async def research(
self,
topic: str,
llm_endpoint: str,
llm_model: str,
max_time: int = 300,
on_progress: Optional[Callable[[dict], None]] = None,
) -> ResearchResult:
"""
Perform deep research on a topic.
Args:
topic: Research topic/question
llm_endpoint: LLM API endpoint
llm_model: Model to use
max_time: Maximum time in seconds
on_progress: Optional progress callback
Returns:
ResearchResult with findings
"""
import time
start = time.time()
result = await self.handler.call_research_service(
topic,
llm_endpoint,
llm_model,
max_time=max_time,
progress_callback=on_progress,
)
duration = time.time() - start
# Parse result into structured format
sources = [
ResearchSource(
url=s.get("url", ""),
title=s.get("title", ""),
snippet=s.get("snippet", ""),
relevance=s.get("relevance", 0.0),
)
for s in result.get("sources", [])
]
return ResearchResult(
query=topic,
summary=result.get("summary", result.get("answer", "")),
sources=sources,
sections=result.get("sections", []),
tokens_used=result.get("tokens_used", 0),
duration_seconds=duration,
)
def start_background(
self,
session_id: str,
topic: str,
llm_endpoint: str,
llm_model: str,
max_time: int = 300,
) -> dict:
"""Start research in background. Returns task info."""
return self.handler.start_research(
session_id, topic, llm_endpoint, llm_model, max_time
)
def get_status(self, session_id: str) -> Optional[dict]:
"""Get status of background research."""
return self.handler.get_status(session_id)
def cancel(self, session_id: str) -> bool:
"""Cancel background research."""
return self.handler.cancel_research(session_id)
+35
View File
@@ -0,0 +1,35 @@
"""Search service — web search with SearXNG."""
from .core import (
comprehensive_web_search,
get_search_config,
invalidate_search_cache,
searxng_search_results,
update_search_config,
)
from .content import fetch_webpage_content
from .providers import searxng_search, searxng_search_api, PROVIDER_INFO
from .analytics import get_search_stats, SearchEngineError, NetworkError, ParseError, RateLimitError
from .service import SearchService, SearchResult, SearchResponse
__all__ = [
# Service interface (preferred)
"SearchService",
"SearchResult",
"SearchResponse",
# Low-level functions (for backwards compat)
"comprehensive_web_search",
"fetch_webpage_content",
"get_search_config",
"get_search_stats",
"invalidate_search_cache",
"searxng_search",
"searxng_search_api",
"searxng_search_results",
"update_search_config",
"PROVIDER_INFO",
"SearchEngineError",
"NetworkError",
"ParseError",
"RateLimitError",
]
+136
View File
@@ -0,0 +1,136 @@
"""Search analytics, metrics tracking, and exception hierarchy."""
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, Any
from .cache import cache_metrics
logger = logging.getLogger(__name__)
# Dedicated error logger with file handler
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
_error_handler.setLevel(logging.WARNING)
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
error_logger = logging.getLogger("search_engine_error")
error_logger.addHandler(_error_handler)
error_logger.propagate = False
# Analytics file
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
# ----------------------------------------------------------------------
# Custom exception hierarchy
# ----------------------------------------------------------------------
class SearchEngineError(Exception):
"""Base class for all search-engine related errors."""
class NetworkError(SearchEngineError):
"""Raised when a network request fails (e.g., timeout, DNS error)."""
class ParseError(SearchEngineError):
"""Raised when HTML or other content cannot be parsed."""
class RateLimitError(SearchEngineError):
"""Raised when the remote service returns a rate-limit (HTTP 429)."""
# ----------------------------------------------------------------------
# Analytics helpers
# ----------------------------------------------------------------------
def _load_analytics() -> Dict[str, Any]:
"""Load analytics data from the JSON file, creating defaults if missing."""
if not ANALYTICS_FILE.exists():
default = {
"total_queries": 0,
"successful_queries": 0,
"failed_queries": 0,
"cache_hits": 0,
"cache_misses": 0,
"query_patterns": {},
}
_save_analytics(default)
return default
try:
with open(ANALYTICS_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load analytics file: {e}")
return {
"total_queries": 0,
"successful_queries": 0,
"failed_queries": 0,
"cache_hits": 0,
"cache_misses": 0,
"query_patterns": {},
}
def _save_analytics(data: Dict[str, Any]) -> None:
"""Persist analytics data to the JSON file."""
try:
with open(ANALYTICS_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.warning(f"Failed to write analytics file: {e}")
def _record_query(query: str, success: bool, cache_hit: bool) -> None:
"""Update analytics for a single query execution."""
analytics = _load_analytics()
analytics["total_queries"] += 1
if success:
analytics["successful_queries"] += 1
else:
analytics["failed_queries"] += 1
if cache_hit:
analytics["cache_hits"] += 1
cache_metrics["hits"] += 1
else:
analytics["cache_misses"] += 1
cache_metrics["misses"] += 1
patterns = analytics["query_patterns"]
entry = patterns.get(query, {"count": 0, "successes": 0})
entry["count"] += 1
if success:
entry["successes"] += 1
patterns[query] = entry
_save_analytics(analytics)
def get_search_stats() -> Dict[str, Any]:
"""Return aggregated search analytics."""
analytics = _load_analytics()
total = analytics.get("total_queries", 0) or 1
success_rate = analytics.get("successful_queries", 0) / total
cache_total = analytics.get("cache_hits", 0) + analytics.get("cache_misses", 0) or 1
cache_hit_rate = analytics.get("cache_hits", 0) / cache_total
pattern_counter = Counter({
q: data["count"] for q, data in analytics.get("query_patterns", {}).items()
})
most_common = [q for q, _ in pattern_counter.most_common(5)]
return {
"most_common_queries": most_common,
"success_rate": success_rate,
"cache_hit_rate": cache_hit_rate,
"total_queries": analytics.get("total_queries", 0),
"successful_queries": analytics.get("successful_queries", 0),
"failed_queries": analytics.get("failed_queries", 0),
"cache_hits": analytics.get("cache_hits", 0),
"cache_misses": analytics.get("cache_misses", 0),
"cache_evictions": cache_metrics["evictions"],
"runtime_cache_hits": cache_metrics["hits"],
"runtime_cache_misses": cache_metrics["misses"],
}
+57
View File
@@ -0,0 +1,57 @@
"""Search and content caching with LRU eviction."""
import hashlib
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict
logger = logging.getLogger(__name__)
# Cache directories
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
SEARCH_CACHE_DIR = CACHE_DIR / "search"
CONTENT_CACHE_DIR = CACHE_DIR / "content"
CACHE_MAX_ENTRIES = 1000
# Create cache directories
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Track cache size for LRU eviction
search_cache_index: Dict[str, datetime] = {}
content_cache_index: Dict[str, datetime] = {}
# Cache metrics (shared across modules)
cache_metrics = {"hits": 0, "misses": 0, "evictions": 0}
def generate_cache_key(data: str) -> str:
"""Generate a unique cache key using SHA-256 hash."""
return hashlib.sha256(data.encode("utf-8")).hexdigest()
def cleanup_cache(cache_dir: Path, cache_index: Dict[str, datetime], max_age: timedelta):
"""Remove expired cache entries and enforce LRU policy."""
current_time = datetime.now()
files_in_dir = {f.name.split(".")[0]: f for f in cache_dir.glob("*.cache")}
to_remove = []
for key, timestamp in list(cache_index.items()):
if current_time - timestamp > max_age or key not in files_in_dir:
to_remove.append(key)
if key in files_in_dir:
files_in_dir[key].unlink(missing_ok=True)
for key in to_remove:
cache_index.pop(key, None)
cache_metrics["evictions"] += 1
if len(cache_index) > CACHE_MAX_ENTRIES:
sorted_items = sorted(cache_index.items(), key=lambda x: x[1])
excess_count = len(cache_index) - CACHE_MAX_ENTRIES
for key, _ in sorted_items[:excess_count]:
cache_index.pop(key, None)
cache_file = cache_dir / f"{key}.cache"
cache_file.unlink(missing_ok=True)
cache_metrics["evictions"] += 1
+360
View File
@@ -0,0 +1,360 @@
"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
import io
import ipaddress
import json
import os
import re
import logging
import socket
from datetime import datetime, timedelta
from typing import List
from urllib.parse import urljoin, urlparse
import httpx
from bs4 import BeautifulSoup
from .analytics import RateLimitError, error_logger
from .cache import (
CONTENT_CACHE_DIR,
content_cache_index,
generate_cache_key,
cleanup_cache,
)
logger = logging.getLogger(__name__)
_PRIVATE_NETWORKS = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
)
def _is_private_address(addr: ipaddress._BaseAddress) -> bool:
return addr.is_private or addr.is_loopback or addr.is_link_local or any(addr in net for net in _PRIVATE_NETWORKS)
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
try:
infos = socket.getaddrinfo(hostname, None)
except Exception:
return []
out = []
for info in infos:
try:
out.append(ipaddress.ip_address(info[4][0]))
except Exception:
continue
return out
def _public_http_url(url: str) -> bool:
try:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
return False
host = (parsed.hostname or "").strip()
if not host:
return False
lower = host.lower()
if lower in ("localhost", "metadata", "metadata.google.internal"):
return False
if lower.endswith((".local", ".localhost", ".internal", ".lan", ".intranet")):
return False
try:
return not _is_private_address(ipaddress.ip_address(host))
except ValueError:
pass
addrs = _resolve_hostname_ips(host)
return bool(addrs) and not any(_is_private_address(a) for a in addrs)
except Exception:
return False
def _get_public_url(url: str, headers: dict, timeout: int, max_redirects: int = 5) -> httpx.Response:
current = url
for _ in range(max_redirects + 1):
if not _public_http_url(current):
raise httpx.RequestError("Blocked private/internal URL", request=httpx.Request("GET", current))
response = httpx.get(current, headers=headers, timeout=timeout, follow_redirects=False)
if response.status_code not in (301, 302, 303, 307, 308):
return response
location = response.headers.get("location")
if not location:
return response
current = urljoin(str(response.url), location)
raise httpx.RequestError("Too many redirects", request=httpx.Request("GET", current))
# PDF extraction (optional dependency)
try:
from pdfminer.high_level import extract_text as pdf_extract_text
except ImportError:
pdf_extract_text = None # type: ignore
# ----------------------------------------------------------------------
# HTML extraction helpers
# ----------------------------------------------------------------------
def _extract_meta(soup: BeautifulSoup) -> dict:
"""Pull meta description and keywords if present."""
description = ""
keywords = ""
desc_tag = soup.find("meta", attrs={"name": re.compile("description", re.I)})
if desc_tag and desc_tag.get("content"):
description = desc_tag["content"].strip()
kw_tag = soup.find("meta", attrs={"name": re.compile("keywords", re.I)})
if kw_tag and kw_tag.get("content"):
keywords = kw_tag["content"].strip()
return {"description": description, "keywords": keywords}
def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
"""Return a list of lists, each inner list representing a <ul>/<ol>."""
all_lists = []
for lst in soup.find_all(["ul", "ol"]):
items = [li.get_text(separator=" ", strip=True) for li in lst.find_all("li")]
if items:
all_lists.append(items)
return all_lists
def _extract_tables(soup: BeautifulSoup) -> List[List[List[str]]]:
"""Return a list of tables, each table is a list of rows, each row a list of cell texts."""
tables_data = []
for table in soup.find_all("table"):
rows = []
for tr in table.find_all("tr"):
cells = [td.get_text(separator=" ", strip=True) for td in tr.find_all(["td", "th"])]
if cells:
rows.append(cells)
if rows:
tables_data.append(rows)
return tables_data
def _extract_code_blocks(soup: BeautifulSoup) -> List[str]:
"""Collect text from <pre> and <code> blocks."""
blocks = []
for tag in soup.find_all(["pre", "code"]):
txt = tag.get_text(separator=" ", strip=True)
if txt:
blocks.append(txt)
return blocks
def _detect_js_frameworks(soup: BeautifulSoup) -> bool:
"""Very naive detection of common JS frameworks."""
js_indicators = [
"react", "angular", "vue", "svelte", "next", "nuxt",
"ember", "backbone", "jquery", "polymer", "mithril",
]
for script in soup.find_all("script"):
src = script.get("src", "").lower()
if any(fr in src for fr in js_indicators):
return True
if script.string:
content = script.string.lower()
if any(fr in content for fr in js_indicators):
return True
if soup.find(attrs={"data-reactroot": True}) or soup.find(attrs={"ng-app": True}):
return True
return False
def _empty_result(url: str, error: str = "") -> dict:
"""Build a standard failure result dict."""
return {
"url": url,
"title": "",
"content": "",
"lists": [],
"tables": [],
"code_blocks": [],
"meta_description": "",
"meta_keywords": "",
"js_rendered": False,
"js_message": "",
"success": False,
"error": error,
}
# ----------------------------------------------------------------------
# Main content fetcher
# ----------------------------------------------------------------------
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
"""Fetch and extract meaningful content from a webpage with caching."""
cache_key = generate_cache_key(url)
cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"
# Check cache
if cache_file.exists():
try:
with open(cache_file, "r", encoding="utf-8") as f:
cached_data = json.load(f)
timestamp = datetime.fromisoformat(cached_data["timestamp"])
if datetime.now() - timestamp < timedelta(hours=2):
logger.debug(f"Content cache hit for URL: {url}")
return cached_data["data"]
else:
cache_file.unlink(missing_ok=True)
content_cache_index.pop(cache_key, None)
except Exception as e:
logger.warning(f"Failed to read content cache for {url}: {e}")
cache_file.unlink(missing_ok=True)
content_cache_index.pop(cache_key, None)
# Fetch
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"Accept-Encoding": "gzip, deflate",
"Connection": "keep-alive",
}
response = _get_public_url(url, headers=headers, timeout=timeout)
if response.status_code == 429:
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
response.raise_for_status()
except httpx.RequestError as e:
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
return _empty_result(url, f"NetworkError: {e}")
except RateLimitError as e:
error_logger.error(str(e))
return _empty_result(url, str(e))
# PDF handling
content_type = response.headers.get("Content-Type", "").lower()
if "application/pdf" in content_type or url.lower().endswith(".pdf"):
if pdf_extract_text is None:
logger.error("pdfminer.six is not installed; cannot extract PDF text.")
pdf_text = ""
else:
try:
pdf_bytes = io.BytesIO(response.content)
pdf_text = pdf_extract_text(pdf_bytes)
except Exception as e:
logger.warning(f"PDF extraction failed for {url}: {e}")
pdf_text = ""
result = {
"url": url,
"title": os.path.basename(url),
"content": pdf_text,
"lists": [],
"tables": [],
"code_blocks": [],
"meta_description": "",
"meta_keywords": "",
"js_rendered": False,
"js_message": "",
"success": bool(pdf_text),
"error": "" if pdf_text else "Failed to extract PDF text",
}
_cache_result(cache_file, cache_key, result, url)
return result
# HTML handling
try:
soup = BeautifulSoup(response.text, "html.parser")
except Exception as e:
error_logger.error(f"ParseError parsing HTML from {url} (attempt {retry_attempt}): {e}")
result = _empty_result(url, f"ParseError: {e}")
_cache_result(cache_file, cache_key, result, url)
return result
title_tag = soup.find("title")
title_text = title_tag.get_text(strip=True) if title_tag else ""
meta_info = _extract_meta(soup)
js_rendered = _detect_js_frameworks(soup)
js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""
# Main textual content (heuristic)
main_content = ""
content_areas = soup.find_all(
["main", "article", "section", "div"],
class_=re.compile("content|main|body|article|post|entry|text", re.I),
)
if content_areas:
for area in content_areas[:3]:
main_content += area.get_text(separator=" ", strip=True) + " "
if not main_content:
body = soup.find("body")
if body:
main_content = body.get_text(separator=" ", strip=True)
main_content = re.sub(r"\s+", " ", main_content).strip()[:8000]
result = {
"url": url,
"title": title_text,
"content": main_content,
"lists": _extract_lists(soup),
"tables": _extract_tables(soup),
"code_blocks": _extract_code_blocks(soup),
"meta_description": meta_info.get("description", ""),
"meta_keywords": meta_info.get("keywords", ""),
"js_rendered": js_rendered,
"js_message": js_message,
"success": True,
"error": "",
}
_cache_result(cache_file, cache_key, result, url)
return result
def _cache_result(cache_file, cache_key: str, result: dict, url: str):
"""Write a result to the content cache."""
try:
cache_data = {"timestamp": datetime.now().isoformat(), "data": result}
with open(cache_file, "w", encoding="utf-8") as f:
json.dump(cache_data, f)
content_cache_index[cache_key] = datetime.now()
cleanup_cache(CONTENT_CACHE_DIR, content_cache_index, timedelta(hours=2))
except Exception as e:
logger.warning(f"Failed to write content cache for {url}: {e}")
# ----------------------------------------------------------------------
# Content summarization helpers
# ----------------------------------------------------------------------
def extract_key_points(text: str) -> List[str]:
"""Pull out bullet-style key points from a block of text."""
points: List[str] = []
bullet_pat = re.compile(r"^\s*[-*•]\s+(.*)")
numbered_pat = re.compile(r"^\s*\d+[\.\)]\s+(.*)")
for line in text.splitlines():
m = bullet_pat.match(line) or numbered_pat.match(line)
if m:
points.append(m.group(1).strip())
return points
def get_tldr(text: str, max_sentences: int = 3) -> str:
"""Produce a very short TL;DR by taking the first few sentences."""
sentences = re.split(r"(?<=[.!?])\s+", text)
selected = [s.strip() for s in sentences if s][:max_sentences]
return " ".join(selected)
def extract_quotes(text: str) -> List[str]:
"""Return quoted excerpts that are at least 15 characters long."""
return [m.group(1).strip() for m in re.finditer(r'["\']([^"\']{15,}?)["\']', text)]
def extract_statistics(text: str) -> List[str]:
"""Find numbers, percentages, dates and simple measurements."""
pattern = re.compile(
r"\b\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?\b",
re.IGNORECASE,
)
return [m.group(0).strip() for m in pattern.finditer(text)]
+433
View File
@@ -0,0 +1,433 @@
"""Core search orchestrators: searxng_search_results, comprehensive_web_search, config, cache invalidation."""
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from typing import Dict, Any, Optional, List, Set
from urllib.parse import urlparse
from .analytics import (
NetworkError,
ParseError,
RateLimitError,
error_logger,
_record_query,
)
from .cache import (
SEARCH_CACHE_DIR,
search_cache_index,
generate_cache_key,
cleanup_cache,
)
from .query import _cache_duration_for_query
from .ranking import rank_search_results
from .providers import (
searxng_search_api,
brave_search,
duckduckgo_search,
google_pse_search,
tavily_search,
serper_search,
_get_search_settings,
_get_result_count,
)
from .content import (
fetch_webpage_content,
extract_key_points,
get_tldr,
extract_quotes,
extract_statistics,
)
logger = logging.getLogger(__name__)
# ========= CONFIG =========
SEARCH_CONFIG: Dict[str, Any] = {
"primary_provider": "searxng",
}
def get_search_config() -> Dict[str, Any]:
"""Get current search configuration including active provider info."""
config = SEARCH_CONFIG.copy()
settings = _get_search_settings()
provider = settings.get("search_provider", "searxng")
config["active_provider"] = provider
config["has_api_key"] = bool((settings.get("search_api_key") or "").strip())
config["result_count"] = _get_result_count()
if provider == "searxng":
from .providers import _get_search_instance
config["search_url"] = _get_search_instance()
return config
def update_search_config(api_key: str = None, **kwargs):
"""Update search configuration (e.g. Brave API key)."""
if api_key:
SEARCH_CONFIG["brave_api_key"] = api_key
def _call_provider(provider_name: str, query: str, count: int, time_filter: str = None) -> List[dict]:
"""Call a search provider by name. Returns list of results or empty list."""
if provider_name == "searxng":
return searxng_search_api(query, count, time_filter=time_filter)
elif provider_name == "brave":
return brave_search(query, count, time_filter)
elif provider_name == "duckduckgo":
return duckduckgo_search(query, count, time_filter)
elif provider_name == "google_pse":
return google_pse_search(query, count, time_filter)
elif provider_name == "tavily":
return tavily_search(query, count, time_filter)
elif provider_name == "serper":
return serper_search(query, count, time_filter)
return []
# If the self-hosted SearXNG instance is up but all enabled engines return
# empty, fall back to the no-key provider so "search X" still works on fresh
# installs. Users can override/disable with `search_fallback_chain`.
_FALLBACK_ORDER = ["duckduckgo"]
def _build_provider_chain(primary: str) -> List[str]:
"""Build ordered list: primary first, then configured/default fallbacks."""
chain = [primary]
settings = _get_search_settings()
user_chain = settings.get("search_fallback_chain") or []
if isinstance(user_chain, str):
user_chain = [s.strip() for s in user_chain.split(",") if s.strip()]
fallbacks = user_chain if user_chain else _FALLBACK_ORDER
for fb in fallbacks:
if fb and fb != primary and fb not in chain and fb != "disabled":
chain.append(fb)
return chain
# ----------------------------------------------------------------------
# Unified search with caching and retry
# ----------------------------------------------------------------------
def searxng_search_results(query: str, count: int = 10, time_filter: str = None) -> list[dict]:
"""Perform a web search using configured provider with caching and retry."""
settings = _get_search_settings()
search_provider = settings.get("search_provider", "searxng")
result_count = _get_result_count()
# Use configured count if caller used default
if count == 10:
count = result_count
cache_key = generate_cache_key(f"{query}|{count}|{time_filter}")
cache_file = SEARCH_CACHE_DIR / f"{cache_key}.cache"
# Check cache
if cache_file.exists():
try:
with open(cache_file, "r", encoding="utf-8") as f:
cached_data = json.load(f)
expiry_raw = cached_data.get("expiry")
expiry = datetime.fromisoformat(expiry_raw) if expiry_raw else None
if expiry and datetime.now() < expiry:
logger.debug(f"Search cache hit for query: {query}")
results = cached_data["data"]
_record_query(query, bool(results), cache_hit=True)
return results
else:
cache_file.unlink(missing_ok=True)
search_cache_index.pop(cache_key, None)
except Exception as e:
logger.warning(f"Failed to read search cache for {query}: {e}")
cache_file.unlink(missing_ok=True)
search_cache_index.pop(cache_key, None)
logger.debug(f"Search cache miss for query: {query}")
if search_provider == "disabled":
logger.info("Search is disabled via admin settings")
return []
provider_chain = _build_provider_chain(search_provider)
results: List[dict] = []
for provider_name in provider_chain:
for attempt in range(2):
try:
logger.info(f"Attempting {provider_name} search (attempt {attempt + 1})")
results = _call_provider(provider_name, query, count, time_filter)
if results:
logger.info(f"{provider_name} search succeeded with {len(results)} results")
break
except (NetworkError, ParseError, RateLimitError) as e:
error_logger.error(f"{provider_name} search error (attempt {attempt + 1}): {e}")
except Exception as e:
error_logger.error(f"Unexpected error during {provider_name} search (attempt {attempt + 1}): {e}")
if results:
break
success = bool(results)
_record_query(query, success, cache_hit=False)
if success:
results = rank_search_results(query, results)
try:
expiry = datetime.now() + _cache_duration_for_query(query)
cache_data = {
"timestamp": datetime.now().isoformat(),
"expiry": expiry.isoformat(),
"data": results,
}
with open(cache_file, "w", encoding="utf-8") as f:
json.dump(cache_data, f)
search_cache_index[cache_key] = datetime.now()
cleanup_cache(SEARCH_CACHE_DIR, search_cache_index, timedelta(hours=1))
except Exception as e:
logger.warning(f"Failed to write search cache for {query}: {e}")
if not success:
logger.error(f"All search providers failed for query: {query}")
return results
# ----------------------------------------------------------------------
# Cache invalidation
# ----------------------------------------------------------------------
def invalidate_search_cache(query: Optional[str] = None) -> None:
"""Invalidate cached search results. None clears all, otherwise just the given query."""
if query is None:
for file in SEARCH_CACHE_DIR.glob("*.cache"):
try:
file.unlink(missing_ok=True)
except Exception as e:
error_logger.warning(f"Failed to delete cache file {file}: {e}")
search_cache_index.clear()
logger.info("All search cache entries have been cleared.")
else:
cache_key = generate_cache_key(f"{query}|10|None")
cache_file = SEARCH_CACHE_DIR / f"{cache_key}.cache"
if cache_file.exists():
try:
cache_file.unlink(missing_ok=True)
search_cache_index.pop(cache_key, None)
logger.info(f"Cache entry for query '{query}' has been invalidated.")
except Exception as e:
error_logger.warning(f"Failed to delete cache file for query '{query}': {e}")
else:
logger.info(f"No cache entry found for query '{query}'.")
# ----------------------------------------------------------------------
# Comprehensive web search (with advanced filtering)
# ----------------------------------------------------------------------
def comprehensive_web_search(
query: str,
max_pages: int = 3,
max_workers: int = 4,
time_filter: str = None,
domain_whitelist: Optional[Set[str]] = None,
domain_blacklist: Optional[Set[str]] = None,
content_type: Optional[str] = None,
language: Optional[str] = None,
min_content_length: int = 0,
return_sources: bool = False,
):
"""Perform comprehensive web search with content fetching and advanced filtering."""
logger.info(f"Starting comprehensive search for: {query}")
if time_filter:
logger.info(f"Applying time filter: {time_filter}")
settings = _get_search_settings()
search_provider = settings.get("search_provider", "searxng")
result_count = _get_result_count()
if search_provider == "disabled":
logger.info("Search is disabled via admin settings")
msg = "Web search is disabled by the administrator."
return (msg, []) if return_sources else msg
# Use configured result count (at least max_pages for content fetching)
fetch_count = max(result_count, max_pages)
provider_chain = _build_provider_chain(search_provider)
search_results = []
provider_attempts = {}
for provider_name in provider_chain:
last_err = None
empty = False
for attempt in range(2):
try:
search_results = _call_provider(provider_name, query, fetch_count, time_filter)
if search_results:
provider_attempts[provider_name] = f"ok ({len(search_results)})"
logger.info(f"Comprehensive search: {provider_name} returned {len(search_results)} results")
break
empty = True
except Exception as e:
last_err = e
logger.warning(f"Comprehensive search: {provider_name} attempt {attempt + 1} failed: {e}")
if search_results:
break
if last_err is not None:
provider_attempts[provider_name] = f"error: {last_err}"
elif empty:
provider_attempts[provider_name] = "empty"
if not search_results:
tally = ", ".join(f"{p}:{r}" for p, r in provider_attempts.items()) or "no providers configured"
any_errors = any(r.startswith("error") for r in provider_attempts.values())
if any_errors:
msg = f"Web search failed — all providers errored or returned empty. Tried: {tally}"
else:
msg = (
f"No search results found. Tried: {tally}. "
"All providers returned empty — possibly a niche query or upstream rate-limiting; "
"rephrasing or using the browser tool for a specific URL may help."
)
logger.warning(msg)
return (msg, []) if return_sources else msg
search_results = rank_search_results(query, search_results)
# URL filter helper
def url_passes_filters(url: str) -> bool:
try:
netloc = urlparse(url).netloc.lower()
except Exception:
return False
if domain_whitelist is not None and netloc not in domain_whitelist:
return False
if domain_blacklist is not None and netloc in domain_blacklist:
return False
if content_type:
ct = content_type.lower()
if ct == "article":
if not any(k in url.lower() for k in ("article", "blog", "news", "post")):
return False
elif ct == "forum":
if not any(k in url.lower() for k in ("forum", "discussion", "thread", "topic")):
return False
elif ct == "academic":
if not any(k in url.lower() for k in ("pdf", "doi", "scholar", "arxiv", "journal", "research")):
return False
if language:
lang_pat = language.lower()
if not (f"/{lang_pat}/" in url.lower() or f"?lang={lang_pat}" in url.lower() or f"&lang={lang_pat}" in url.lower()):
return False
return True
filtered_urls = [r["url"] for r in search_results[:max_pages] if url_passes_filters(r["url"])]
if not filtered_urls:
logger.warning("All URLs filtered out by advanced criteria")
msg = "No suitable results after applying filters."
return (msg, []) if return_sources else msg
# Build sources list for the frontend (before content fetching)
_source_list = [
{"url": r.get("url", ""), "title": r.get("title", "")}
for r in search_results if r.get("url")
]
# Fetch content in parallel
fetched_content = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_url = {
executor.submit(fetch_webpage_content, url, 8, retry_attempt=0): url
for url in filtered_urls
}
for future in as_completed(future_to_url):
url = future_to_url[future]
try:
result = future.result()
if result["success"] and result["content"] and len(result["content"]) >= min_content_length:
fetched_content.append(result)
except Exception as e:
logger.error(f"Exception while fetching {url}: {str(e)}")
logger.info(f"Successfully fetched content from {len(fetched_content)} pages")
# Format results
output_parts = []
if search_results:
output_parts.append("```sources")
for i, result in enumerate(search_results, 1):
output_parts.append(f"[{i}] {result['title']}")
output_parts.append(f" {result['url']}")
if result.get("age"):
output_parts.append(f" {result['age']}")
output_parts.append("```")
output_parts.append("")
output_parts.append("=" * 70)
output_parts.append("WEB SEARCH RESULTS AND FETCHED CONTENT")
output_parts.append(f"Query: {query}")
output_parts.append(f"Searched {len(search_results)} results, fetched {len(fetched_content)} pages")
output_parts.append("=" * 70)
output_parts.append("")
output_parts.append("SEARCH RESULTS SUMMARY:")
output_parts.append("-" * 50)
for i, result in enumerate(search_results, 1):
output_parts.append(f"\n[{i}] {result['title']}")
output_parts.append(f" URL: {result['url']}")
output_parts.append(f" Snippet: {result['snippet'][:200]}...")
if result.get("age"):
output_parts.append(f" Age: {result['age']}")
if fetched_content:
output_parts.append("\n" + "=" * 70)
output_parts.append("FETCHED PAGE CONTENT:")
output_parts.append("-" * 50)
for i, content in enumerate(fetched_content, 1):
output_parts.append(f"\n[CONTENT {i}] From: {content['url']}")
output_parts.append(f"Title: {content['title']}")
output_parts.append("-" * 30)
text = content["content"][:3000]
if len(content["content"]) > 3000:
text += "... [truncated]"
output_parts.append(text)
key_points = extract_key_points(content["content"])
if key_points:
output_parts.append("\nKey Points:")
for pt in key_points[:5]:
output_parts.append(f"- {pt}")
tldr = get_tldr(content["content"])
if tldr:
output_parts.append("\nTL;DR:")
output_parts.append(tldr)
quotes = extract_quotes(content["content"])
if quotes:
output_parts.append("\nImportant Quotes:")
for q in quotes[:3]:
output_parts.append(f"\u201c{q}\u201d")
stats = extract_statistics(content["content"])
if stats:
output_parts.append("\nData / Statistics:")
for s in stats[:5]:
output_parts.append(f"- {s}")
output_parts.append("")
output_parts.append("=" * 70)
output_parts.append("END OF WEB SEARCH RESULTS")
output_parts.append("=" * 70)
instructions = (
"\n\nIMPORTANT INSTRUCTIONS:\n"
"1. Use the above web search results and fetched content to answer the user's question\n"
"2. Prioritize information from the FETCHED PAGE CONTENT section as it contains actual page data\n"
"3. Cross-reference multiple sources when possible\n"
"4. If the information is time-sensitive, pay attention to the age of the results\n"
"5. Be explicit if the search results don't contain sufficient information to fully answer the question"
)
output_parts.append(instructions)
result = "\n".join(output_parts)
return (result, _source_list) if return_sources else result
+527
View File
@@ -0,0 +1,527 @@
"""Search provider implementations: SearXNG, Brave, DuckDuckGo, Google PSE, Tavily, Serper."""
import json
import logging
import os
from typing import List, Optional
import httpx
from bs4 import BeautifulSoup
from src.constants import SEARXNG_INSTANCE
from .analytics import RateLimitError, error_logger
from .query import build_enhanced_query
logger = logging.getLogger(__name__)
REQUEST_TIMEOUT = 20
# Provider registry — maps setting value to (label, needs_key, needs_url)
PROVIDER_INFO = {
"searxng": ("SearXNG", False, True),
"brave": ("Brave Search", True, False),
"duckduckgo": ("DuckDuckGo", False, False),
"google_pse": ("Google PSE", True, False),
"tavily": ("Tavily", True, False),
"serper": ("Serper", True, False),
"disabled": ("Disabled", False, False),
}
# ── Settings helpers ──
def _get_search_settings() -> dict:
"""Return search settings from admin config, falling back to env defaults."""
try:
from src.settings import load_settings
return load_settings()
except Exception:
return {}
def _get_search_instance() -> str:
"""Return the active search API URL from admin settings, falling back to env var."""
settings = _get_search_settings()
url = (settings.get("search_url") or "").strip()
if url:
return url.rstrip("/")
return SEARXNG_INSTANCE
def _get_provider_key(provider: str) -> str:
"""Return the API key for a specific provider, with legacy fallback."""
settings = _get_search_settings()
key_map = {
"brave": "brave_api_key",
"google_pse": "google_pse_key",
"tavily": "tavily_api_key",
"serper": "serper_api_key",
}
field = key_map.get(provider, "")
if field:
val = (settings.get(field) or "").strip()
if val:
return val
# Legacy fallback: old shared search_api_key field
return (settings.get("search_api_key") or "").strip()
def _get_result_count() -> int:
"""Return configured result count, default 5."""
settings = _get_search_settings()
try:
return int(settings.get("search_result_count", 5))
except (ValueError, TypeError):
return 5
# ── SearXNG ──
_NEWS_HINTS = ("news", "nyheter", "headlines", "breaking", "latest", "today", "idag")
# Default general engines (google/duckduckgo/brave/startpage/wikipedia) are
# routinely rate-limited / CAPTCHA-blocked on this instance and return nothing.
# Pin engines that actually respond so non-news queries get results without any
# third-party API fallback. Override via SEARXNG_GENERAL_ENGINES.
_GENERAL_ENGINES = os.environ.get("SEARXNG_GENERAL_ENGINES", "bing,mojeek,presearch")
def searxng_search_api(query: str, count: int = 10, categories: str = "general",
time_filter: Optional[str] = None) -> List[dict]:
"""Search using SearXNG JSON API. Returns list of {title, url, snippet}."""
instance = _get_search_instance()
api_key = ""
headers = {"User-Agent": "Mozilla/5.0"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# News/fresh queries do badly in the 'general' category — it favours
# encyclopedic/tourism pages, ignores recency, and (with no language pin)
# bleeds in foreign-language results. When the agent layer detected
# freshness (time_filter) or the query reads like a news lookup, switch to
# the 'news' category, constrain recency, and pin language to English so a
# search like "Canada latest news" returns actual news instead of Wikipedia.
# Pin English for ALL searches — without it, SearXNG geolocates / mixes
# languages and brand-ambiguous terms bleed in foreign SEO pages (e.g.
# "Odyssey" → Honda Japan, "Trojan" → Japanese malware blogs, "Polyphemus"
# → Chinese math forums). The news path already did this; general didn't.
params = {"q": query, "format": "json", "language": "en"}
q_lc = query.lower()
is_news = time_filter is not None or any(h in q_lc for h in _NEWS_HINTS)
if is_news and categories == "general":
params["categories"] = "news"
if time_filter in ("day", "week", "month", "year"):
# 'day' is too sparse on most SearXNG news engines — widen to a week
# so there's enough volume; the news category already biases recent.
params["time_range"] = "week" if time_filter in ("day", "week") else time_filter
else:
params["categories"] = categories
# Route general queries to engines that aren't blocked (default general
# set returns 0 on this instance — see _GENERAL_ENGINES).
if categories == "general" and _GENERAL_ENGINES:
params["engines"] = _GENERAL_ENGINES
try:
def _parse_results(results):
return [
{
"title": r.get("title", ""),
"url": r.get("url", ""),
"snippet": r.get("content", ""),
}
for r in results[:count]
if r.get("url")
]
def _run(search_params):
response = httpx.get(
f"{instance}/search",
params=search_params,
headers=headers or None,
timeout=15,
)
response.raise_for_status()
data = response.json()
return _parse_results(data.get("results", [])), data
active_params = params
parsed, data = _run(active_params)
if not parsed and is_news and categories == "general":
# Some self-hosted SearXNG configs have no working news engines.
# Fall back to the known-good general engines before reporting an
# empty search, otherwise common queries like "Canada news" fail.
fallback = {
"q": query,
"format": "json",
"language": "en",
"categories": "general",
}
if _GENERAL_ENGINES:
fallback["engines"] = _GENERAL_ENGINES
logger.info(
"SearXNG news search returned 0 results for %r; retrying general engines",
query,
)
active_params = fallback
parsed, data = _run(active_params)
if not parsed and active_params.get("language"):
fallback = dict(active_params)
fallback.pop("language", None)
logger.info(
"SearXNG language-pinned search returned 0 results for %r; retrying without language",
query,
)
active_params = fallback
parsed, data = _run(active_params)
if not parsed and active_params.get("engines"):
fallback = dict(active_params)
fallback.pop("engines", None)
logger.info(
"SearXNG pinned engines returned 0 results for %r; retrying default engines",
query,
)
parsed, data = _run(fallback)
logger.info(f"SearXNG JSON API returned {len(parsed)} results for: {query}")
if not parsed:
unresponsive = data.get("unresponsive_engines") if isinstance(data, dict) else None
if unresponsive:
logger.info(f"SearXNG unresponsive engines for {query!r}: {unresponsive}")
return parsed
except Exception as e:
logger.warning(f"SearXNG JSON API search failed: {e}")
html_results = searxng_search(query, max_results=count)
if html_results:
logger.info(f"SearXNG HTML fallback returned {len(html_results)} results for: {query}")
return html_results
def searxng_search(query, max_results=10):
"""Search using SearXNG instance - parsing HTML."""
instance = _get_search_instance()
api_key = ""
req_headers = {"User-Agent": "Mozilla/5.0"}
if api_key:
req_headers["Authorization"] = f"Bearer {api_key}"
try:
response = httpx.get(
f"{instance}/search",
params={"q": query},
headers=req_headers,
timeout=10,
)
if response.is_success:
soup = BeautifulSoup(response.text, "html.parser")
results = []
for article in soup.select("article.result")[:max_results]:
title_elem = article.select_one("h3 a")
if not title_elem:
continue
title = title_elem.get_text(strip=True)
url = title_elem.get("href", "")
snippet_elem = article.select_one("p.content")
snippet = snippet_elem.get_text(strip=True) if snippet_elem else ""
results.append({"title": title, "url": url, "snippet": snippet})
logger.info(f"SearXNG search (HTML) returned {len(results)} results")
return results
except Exception as e:
logger.error(f"SearXNG search failed: {e}")
return []
# ── Brave ──
def brave_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
"""Search using Brave API with key from admin settings or env var."""
api_key = _get_provider_key("brave") or os.environ.get("DATA_BRAVE_API_KEY") or ""
return _brave_search_impl(query, count, time_filter, search_config={"brave_api_key": api_key})
def _brave_search_impl(query: str, count: int, time_filter: Optional[str] = None, search_config: dict = None) -> List[dict]:
"""Core Brave API call. Returns a list of result dicts or an empty list on failure."""
enhanced_query = build_enhanced_query(query, time_filter)
config = search_config or {}
brave_api_key = config.get("brave_api_key")
if not brave_api_key:
brave_api_key = os.environ.get("DATA_BRAVE_API_KEY")
if not brave_api_key:
logger.warning("Brave API key not found, returning empty results for fallback")
return []
headers = {"X-Subscription-Token": brave_api_key, "Accept": "application/json"}
params = {"q": enhanced_query, "count": count}
if time_filter:
time_map = {"day": "day", "week": "week", "month": "month", "year": "year"}
if time_filter in time_map:
params["freshness"] = time_map[time_filter]
logger.info(f"Executing Brave search with query: {enhanced_query}")
try:
response = httpx.get(
"https://api.search.brave.com/res/v1/web/search",
headers=headers,
params=params,
timeout=REQUEST_TIMEOUT,
)
if response.status_code == 429:
raise RateLimitError("Brave rate limit hit")
response.raise_for_status()
except httpx.RequestError as e:
error_logger.error(f"NetworkError during Brave search: {e}")
return []
except RateLimitError as e:
error_logger.error(str(e))
return []
try:
data = response.json()
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Brave API response: {e}")
return []
results = []
if "web" in data and "results" in data["web"]:
for item in data["web"]["results"][:count]:
url = item.get("url", "")
if not url:
continue
results.append({
"title": item.get("title", ""),
"url": url,
"snippet": item.get("description", "") or item.get("content", ""),
"age": item.get("date", "") if item.get("date") else "",
})
logger.info(f"Brave search returned {len(results)} results")
return results
# ── DuckDuckGo (free, no key) ──
def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
"""Search using DuckDuckGo via the duckduckgo-search library. No API key needed."""
def _html_fallback() -> List[dict]:
try:
response = httpx.get(
"https://html.duckduckgo.com/html/",
params={"q": query},
headers={"User-Agent": "Mozilla/5.0"},
timeout=REQUEST_TIMEOUT,
)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
parsed = []
for result in soup.select(".result")[:count]:
link = result.select_one(".result__a")
if not link:
continue
url = link.get("href", "")
if not url:
continue
snippet_el = result.select_one(".result__snippet")
parsed.append({
"title": link.get_text(" ", strip=True),
"url": url,
"snippet": snippet_el.get_text(" ", strip=True) if snippet_el else "",
})
logger.info(f"DuckDuckGo HTML search returned {len(parsed)} results")
return parsed
except Exception as e:
logger.warning(f"DuckDuckGo HTML search failed: {e}")
return []
try:
from duckduckgo_search import DDGS
except ImportError:
logger.warning("duckduckgo-search package not installed; using HTML fallback")
return _html_fallback()
timelimit = None
if time_filter:
time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
timelimit = time_map.get(time_filter)
try:
ddgs = DDGS()
raw = ddgs.text(query, max_results=count, timelimit=timelimit)
results = []
for item in raw:
url = item.get("href", "")
if not url:
continue
results.append({
"title": item.get("title", ""),
"url": url,
"snippet": item.get("body", ""),
})
logger.info(f"DuckDuckGo search returned {len(results)} results")
return results or _html_fallback()
except Exception as e:
logger.warning(f"DuckDuckGo search failed: {e}")
return _html_fallback()
# ── Google Programmable Search Engine ──
def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
"""Search using Google PSE (Custom Search JSON API).
Requires two keys in settings:
- search_api_key: Google API key
- google_pse_cx: Programmable Search Engine ID (cx)
Or env vars GOOGLE_API_KEY and GOOGLE_PSE_CX.
"""
settings = _get_search_settings()
api_key = _get_provider_key("google_pse") or os.environ.get("GOOGLE_API_KEY", "")
cx = (settings.get("google_pse_cx") or "").strip() or os.environ.get("GOOGLE_PSE_CX", "")
if not api_key or not cx:
logger.warning("Google PSE: missing API key or CX ID")
return []
params = {
"key": api_key,
"cx": cx,
"q": query,
"num": min(count, 10), # Google PSE max is 10 per request
}
if time_filter:
# dateRestrict: d[number], w[number], m[number], y[number]
time_map = {"day": "d1", "week": "w1", "month": "m1", "year": "y1"}
if time_filter in time_map:
params["dateRestrict"] = time_map[time_filter]
try:
response = httpx.get(
"https://www.googleapis.com/customsearch/v1",
params=params,
timeout=REQUEST_TIMEOUT,
)
if response.status_code == 429:
raise RateLimitError("Google PSE rate limit hit")
response.raise_for_status()
data = response.json()
except httpx.RequestError as e:
error_logger.error(f"Google PSE search failed: {e}")
return []
except RateLimitError as e:
error_logger.error(str(e))
return []
results = []
for item in data.get("items", [])[:count]:
url = item.get("link", "")
if not url:
continue
results.append({
"title": item.get("title", ""),
"url": url,
"snippet": item.get("snippet", ""),
})
logger.info(f"Google PSE returned {len(results)} results")
return results
# ── Tavily ──
def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
"""Search using Tavily API. Requires search_api_key or TAVILY_API_KEY env var."""
api_key = _get_provider_key("tavily") or os.environ.get("TAVILY_API_KEY", "")
if not api_key:
logger.warning("Tavily: no API key configured")
return []
payload = {
"query": query,
"max_results": count,
"include_answer": False,
}
if time_filter:
time_map = {"day": "day", "week": "week", "month": "month", "year": "year"}
if time_filter in time_map:
payload["days"] = {"day": 1, "week": 7, "month": 30, "year": 365}[time_filter]
try:
response = httpx.post(
"https://api.tavily.com/search",
json=payload,
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
timeout=REQUEST_TIMEOUT,
)
if response.status_code == 429:
raise RateLimitError("Tavily rate limit hit")
response.raise_for_status()
data = response.json()
except httpx.RequestError as e:
error_logger.error(f"Tavily search failed: {e}")
return []
except RateLimitError as e:
error_logger.error(str(e))
return []
results = []
for item in data.get("results", [])[:count]:
url = item.get("url", "")
if not url:
continue
results.append({
"title": item.get("title", ""),
"url": url,
"snippet": item.get("content", ""),
"age": item.get("published_date", ""),
})
logger.info(f"Tavily returned {len(results)} results")
return results
# ── Serper.dev ──
def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
"""Search using Serper.dev API. Requires search_api_key or SERPER_API_KEY env var."""
api_key = _get_provider_key("serper") or os.environ.get("SERPER_API_KEY", "")
if not api_key:
logger.warning("Serper: no API key configured")
return []
payload = {
"q": query,
"num": count,
}
if time_filter:
time_map = {"day": "qdr:d", "week": "qdr:w", "month": "qdr:m", "year": "qdr:y"}
if time_filter in time_map:
payload["tbs"] = time_map[time_filter]
try:
response = httpx.post(
"https://google.serper.dev/search",
json=payload,
headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
timeout=REQUEST_TIMEOUT,
)
if response.status_code == 429:
raise RateLimitError("Serper rate limit hit")
response.raise_for_status()
data = response.json()
except httpx.RequestError as e:
error_logger.error(f"Serper search failed: {e}")
return []
except RateLimitError as e:
error_logger.error(str(e))
return []
results = []
for item in data.get("organic", [])[:count]:
url = item.get("link", "")
if not url:
continue
results.append({
"title": item.get("title", ""),
"url": url,
"snippet": item.get("snippet", ""),
"age": item.get("date", ""),
})
logger.info(f"Serper returned {len(results)} results")
return results
+128
View File
@@ -0,0 +1,128 @@
"""Query enhancement, entity extraction, and cache duration helpers."""
import re
import logging
from datetime import timedelta
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------
# Query processing helpers
# ----------------------------------------------------------------------
def _detect_question_type(query: str) -> Optional[str]:
"""Return the leading question word if present (who, what, when, where, why, how)."""
q = query.strip().lower()
for word in ("who", "what", "when", "where", "why", "how"):
if q.startswith(word):
return word
return None
def _extract_entities(query: str) -> Dict[str, List[str]]:
"""Lightweight entity extraction: capitalized words and date patterns."""
entities: Dict[str, List[str]] = {"names": [], "dates": []}
qtype = _detect_question_type(query)
cleaned = query
if qtype:
cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip()
for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned):
entities["names"].append(token)
for year in re.findall(r"\b(19|20)\d{2}\b", cleaned):
entities["dates"].append(year)
month_day_year = re.findall(
r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b",
cleaned,
flags=re.I,
)
entities["dates"].extend(month_day_year)
return entities
def _split_multi_part(query: str) -> List[str]:
"""Split a query into sub-queries on common conjunctions."""
parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
return [p.strip() for p in parts if p.strip()]
def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
"""Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
if match:
site = match.group(1)
new_query = re.sub(r"\bsite:[^\s]+", "", query, flags=re.I).strip()
return new_query, site
return query, None
def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) -> str:
"""Append extracted entities to the query using OR to increase relevance."""
parts = [base_query]
if entities.get("names"):
parts.append(" OR ".join(f'"{n}"' for n in entities["names"]))
if entities.get("dates"):
parts.append(" OR ".join(f'"{d}"' for d in entities["dates"]))
return " ".join(parts)
def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
"""Process the original query: site filter, question type boosts, entity extraction."""
query_without_site, site = _extract_site_filter(original_query)
sub_queries = _split_multi_part(query_without_site)
enhanced_subs: List[str] = []
for sub in sub_queries:
qtype = _detect_question_type(sub)
boost_keywords = []
if qtype == "who":
boost_keywords.append("person")
elif qtype == "when":
boost_keywords.append("date")
elif qtype == "where":
boost_keywords.append("location")
elif qtype == "why":
boost_keywords.append("reason")
elif qtype == "how":
boost_keywords.append("method")
entities = _extract_entities(sub)
boosted = _boost_entities_in_query(sub, entities)
if boost_keywords:
boosted = f'({boosted}) OR ({" OR ".join(boost_keywords)})'
enhanced_subs.append(boosted)
final_query = " AND ".join(f"({s})" for s in enhanced_subs)
if site:
final_query = f"{final_query} site:{site}"
return final_query, site
def build_enhanced_query(query: str, time_filter: str = None) -> str:
"""Build an enhanced search query with optional time filtering."""
enhanced_query, _ = enhance_query(query)
if time_filter:
time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
if time_filter in time_map:
enhanced_query = f"{enhanced_query} after:{time_map[time_filter]}"
logger.info(f"Added time filter '{time_filter}' to query")
logger.info(f"Enhanced query: '{query}' -> '{enhanced_query}'")
return enhanced_query
# ----------------------------------------------------------------------
# Cache duration helpers
# ----------------------------------------------------------------------
def _is_news_query(query: str) -> bool:
"""Lightweight heuristic to decide if a query is news-oriented."""
news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
tokens = set(re.findall(r"\b\w+\b", query.lower()))
return bool(tokens & news_terms)
def _cache_duration_for_query(query: str) -> timedelta:
"""News queries -> 30 minutes, reference queries -> 24 hours."""
if _is_news_query(query):
return timedelta(minutes=30)
return timedelta(hours=24)
+127
View File
@@ -0,0 +1,127 @@
"""Search result ranking based on relevance, source quality, and recency."""
import re
import logging
from datetime import datetime
from typing import List, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
_SPORTS_HINTS = {
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
}
_LOW_VALUE_NEWS_DOMAINS = {
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
"www.yahoo.com", "msn.com", "www.msn.com",
}
_TRUSTED_NEWS_DOMAINS = {
"apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com",
"bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca",
"ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca",
"theguardian.com",
"www.theguardian.com", "euronews.com", "www.euronews.com",
"dw.com", "www.dw.com", "government.se", "www.government.se",
}
def _domain(url: str) -> str:
try:
return urlparse(url).netloc.lower()
except Exception:
return ""
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
query_lc = query.lower()
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
is_sports_query = any(hint in query_lc for hint in _SPORTS_HINTS)
def title_score(title: str) -> float:
if not title:
return 0.0
title_lc = title.lower()
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
return matches / len(query_terms) if query_terms else 0.0
def snippet_score(snippet: str) -> float:
if not snippet:
return 0.0
length_factor = min(len(snippet), 200) / 200
term_hits = sum(1 for term in query_terms if term in snippet.lower())
term_factor = term_hits / len(query_terms) if query_terms else 0.0
return (length_factor + term_factor) / 2
def domain_score(url: str) -> float:
netloc = _domain(url)
if not netloc:
return 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
return 1.0
if netloc.endswith(".edu") or netloc.endswith(".gov"):
return 1.0
if netloc.endswith(".org"):
return 0.7
return 0.4
def recency_score(age_str: Optional[str]) -> float:
if not age_str:
return 0.0
for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
try:
dt = datetime.strptime(age_str, fmt)
break
except Exception:
dt = None
if not dt:
return 0.0
days_old = (datetime.now() - dt).days
if days_old <= 7:
return 1.0
if days_old >= 30:
return 0.0
return (30 - days_old) / 23
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
if not is_news_query:
return 0.0
text = f"{title} {snippet}".lower()
netloc = _domain(url)
adjustment = 0.0
if netloc in _TRUSTED_NEWS_DOMAINS:
adjustment += 1.2
if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")):
adjustment += 0.4
if netloc in _LOW_VALUE_NEWS_DOMAINS:
adjustment -= 0.8
if not is_sports_query and any(hint in text or hint in netloc for hint in _SPORTS_HINTS):
adjustment -= 1.5
# A country/news query should not rank a page whose title/snippet barely
# mentions the country above actual news pages for that country.
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
adjustment -= 1.0
return adjustment
ranked = []
for result in results:
title = result.get("title", "")
snippet = result.get("snippet", "")
url = result.get("url", "")
age = result.get("age", None)
score = (
2.0 * title_score(title)
+ 1.0 * snippet_score(snippet)
+ 1.5 * domain_score(url)
+ 1.0 * recency_score(age)
+ news_quality_adjustment(title, snippet, url)
)
ranked.append((score, result))
ranked.sort(key=lambda x: x[0], reverse=True)
return [r for _, r in ranked]
+95
View File
@@ -0,0 +1,95 @@
# services/search/service.py
"""Search service — clean interface for web search."""
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
from . import (
comprehensive_web_search,
fetch_webpage_content,
get_search_config,
)
@dataclass
class SearchResult:
"""A single search result."""
url: str
title: str
snippet: str
content: Optional[str] = None
@dataclass
class SearchResponse:
"""Response from a search query."""
query: str
results: List[SearchResult]
total: int
cached: bool = False
class SearchService:
"""
Web search service.
Usage:
service = SearchService()
result = await service.search("python async patterns")
for r in result.results:
print(f"{r.title}: {r.url}")
"""
def __init__(self, default_depth: int = 1, fetch_content: bool = True):
self.default_depth = default_depth
self.fetch_content = fetch_content
async def search(
self,
query: str,
depth: Optional[int] = None,
fetch_content: Optional[bool] = None,
) -> SearchResponse:
"""
Search the web.
Args:
query: Search query
depth: Search depth (1=quick, 2=thorough, 3=comprehensive)
fetch_content: Whether to fetch full page content
Returns:
SearchResponse with results
"""
depth = depth or self.default_depth
fetch_content = fetch_content if fetch_content is not None else self.fetch_content
# Use existing search implementation
raw_results = await comprehensive_web_search(
query,
max_results=10 * depth,
fetch_content=fetch_content,
)
results = []
for r in raw_results:
results.append(SearchResult(
url=r.get("url", ""),
title=r.get("title", ""),
snippet=r.get("snippet", ""),
content=r.get("content"),
))
return SearchResponse(
query=query,
results=results,
total=len(results),
)
async def fetch_content(self, url: str) -> Optional[str]:
"""Fetch content from a URL."""
return await fetch_webpage_content(url)
def get_config(self) -> Dict[str, Any]:
"""Get current search configuration."""
return get_search_config()
+6
View File
@@ -0,0 +1,6 @@
# services/shell/__init__.py
"""Shell service — safe command execution."""
from .service import ShellService, ShellResult
__all__ = ["ShellService", "ShellResult"]
+162
View File
@@ -0,0 +1,162 @@
# services/shell/service.py
"""Shell service — safe command execution."""
from dataclasses import dataclass
from typing import Optional, AsyncIterator
import asyncio
from pathlib import Path
@dataclass
class ShellResult:
"""Result of a shell command."""
stdout: str
stderr: str
exit_code: int
timed_out: bool = False
class ShellService:
"""
Shell execution service.
Usage:
service = ShellService()
result = await service.execute("ls -la")
print(result.stdout)
"""
def __init__(self, timeout: int = 30, max_output: int = 200_000):
self.timeout = timeout
self.max_output = max_output
self.cwd = str(Path.home())
async def execute(
self,
command: str,
timeout: Optional[int] = None,
cwd: Optional[str] = None,
) -> ShellResult:
"""
Execute a shell command.
Args:
command: Shell command to run
timeout: Timeout in seconds (default: self.timeout)
cwd: Working directory (default: home)
Returns:
ShellResult with stdout, stderr, exit_code
"""
timeout = timeout or self.timeout
cwd = cwd or self.cwd
proc = None
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
stdout_b, stderr_b = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout = stdout_b.decode(errors="replace")[:self.max_output]
stderr = stderr_b.decode(errors="replace")[:self.max_output]
return ShellResult(
stdout=stdout,
stderr=stderr,
exit_code=proc.returncode,
)
except asyncio.TimeoutError:
if proc:
try:
proc.kill()
await proc.wait()
except ProcessLookupError:
pass
return ShellResult(
stdout="",
stderr=f"Command timed out after {timeout}s",
exit_code=-1,
timed_out=True,
)
except Exception as e:
return ShellResult(stdout="", stderr=str(e), exit_code=-1)
async def stream(
self,
command: str,
timeout: int = 120,
) -> AsyncIterator[dict]:
"""
Execute a command and stream output.
Yields:
{"stream": "stdout"|"stderr", "data": line}
{"exit_code": int}
"""
proc = None
reader_tasks = []
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.cwd,
)
q: asyncio.Queue = asyncio.Queue()
async def _reader(stream, name):
try:
while True:
line = await stream.readline()
if not line:
break
await q.put((name, line.decode(errors="replace").rstrip("\n")))
finally:
await q.put((name, None))
reader_tasks = [
asyncio.create_task(_reader(proc.stdout, "stdout")),
asyncio.create_task(_reader(proc.stderr, "stderr")),
]
finished = 0
deadline = asyncio.get_event_loop().time() + timeout
while finished < 2:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
raise asyncio.TimeoutError()
try:
name, text = await asyncio.wait_for(q.get(), timeout=min(remaining, 2.0))
except asyncio.TimeoutError:
continue
if text is None:
finished += 1
continue
yield {"stream": name, "data": text}
await proc.wait()
yield {"exit_code": proc.returncode}
except asyncio.TimeoutError:
if proc:
try:
proc.kill()
await proc.wait()
except ProcessLookupError:
pass
yield {"stream": "stderr", "data": f"Command timed out after {timeout}s"}
yield {"exit_code": -1}
except Exception as e:
yield {"stream": "stderr", "data": str(e)}
yield {"exit_code": -1}
finally:
for t in reader_tasks:
t.cancel()
+3
View File
@@ -0,0 +1,3 @@
from services.stt.stt_service import get_stt_service
__all__ = ["get_stt_service"]
+191
View File
@@ -0,0 +1,191 @@
# services/stt/stt_service.py
"""Multi-provider Speech-to-Text service — dispatches to local Whisper, OpenAI-compatible API, or browser."""
import io
import logging
import httpx
import tempfile
from pathlib import Path
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class STTService:
"""Multi-provider STT service.
Reads provider config from data/settings.json on each call.
Providers:
"disabled" — no STT
"browser" — client-side Web Speech API (no server transcription)
"local" — faster-whisper on CPU/GPU
"endpoint:<id>" — OpenAI-compatible /audio/transcriptions via ModelEndpoint
"""
def __init__(self):
self._whisper_model = None # lazy-init
# ── Settings ──
def _load_settings(self) -> dict:
from src.settings import load_settings
saved = load_settings()
return {
"stt_enabled": saved.get("stt_enabled", False),
"stt_provider": saved.get("stt_provider", "disabled"),
"stt_model": saved.get("stt_model", "base"),
"stt_language": saved.get("stt_language", ""),
}
@property
def available(self) -> bool:
settings = self._load_settings()
provider = settings["stt_provider"]
if provider == "disabled":
return False
if provider == "browser":
return True # handled client-side
if provider == "local":
return self._get_whisper() is not None
if provider.startswith("endpoint:"):
return True # assume reachable
return False
# ── Local Whisper ──
def _get_whisper(self):
if self._whisper_model is None:
try:
from faster_whisper import WhisperModel
settings = self._load_settings()
model_size = settings.get("stt_model", "base")
# Use CPU by default; will use CUDA if available
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
except ImportError:
logger.warning("faster-whisper not installed. Install with: pip install faster-whisper")
return None
except Exception as e:
logger.error(f"Failed to load whisper model: {e}")
return None
return self._whisper_model
def _transcribe_local(self, audio_bytes: bytes, language: str = "") -> Optional[str]:
model = self._get_whisper()
if not model:
return None
try:
# Write to temp file (faster-whisper needs a file path or file-like)
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
kwargs = {}
if language:
kwargs["language"] = language
segments, info = model.transcribe(tmp_path, **kwargs)
text = " ".join(seg.text.strip() for seg in segments)
# Cleanup
Path(tmp_path).unlink(missing_ok=True)
logger.info(f"Local STT: {len(text)} chars, lang={info.language}, prob={info.language_probability:.2f}")
return text
except Exception as e:
logger.error(f"Local STT transcription failed: {e}", exc_info=True)
return None
# ── API endpoint ──
def _transcribe_api(self, audio_bytes: bytes, endpoint_id: str, model: str, language: str = "") -> Optional[str]:
from src.database import SessionLocal, ModelEndpoint
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
if not ep:
logger.error(f"STT endpoint {endpoint_id} not found")
return None
base_url = ep.base_url.rstrip("/")
api_key = ep.api_key
finally:
db.close()
url = base_url + "/audio/transcriptions"
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
files = {"file": ("audio.webm", io.BytesIO(audio_bytes), "audio/webm")}
data = {"model": model or "whisper-1"}
if language:
data["language"] = language
try:
r = httpx.post(url, headers=headers, files=files, data=data, timeout=60)
r.raise_for_status()
result = r.json()
text = result.get("text", "")
logger.info(f"API STT: {len(text)} chars from {base_url}")
return text
except Exception as e:
logger.error(f"API STT transcription failed: {e}")
return None
# ── Public interface ──
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
settings = self._load_settings()
provider = settings["stt_provider"]
model = settings["stt_model"]
language = settings.get("stt_language", "")
if provider in ("disabled", "browser"):
return None
if provider == "local":
return self._transcribe_local(audio_bytes, language)
elif provider.startswith("endpoint:"):
endpoint_id = provider.split(":", 1)[1]
return self._transcribe_api(audio_bytes, endpoint_id, model, language)
else:
logger.error(f"Unknown STT provider: {provider}")
return None
def get_stats(self) -> Dict[str, Any]:
settings = self._load_settings()
provider = settings["stt_provider"]
stt_enabled = settings.get("stt_enabled", False)
# If toggle is off, report as disabled
effective_provider = provider if stt_enabled else "disabled"
stats = {
"available": self.available and stt_enabled,
"provider": effective_provider,
"model": settings["stt_model"],
"language": settings.get("stt_language", ""),
}
if provider == "local":
whisper = self._get_whisper()
stats["model_loaded"] = whisper is not None
elif provider == "browser":
stats["model"] = "Browser (Web Speech API)"
elif provider.startswith("endpoint:"):
stats["endpoint_id"] = provider.split(":", 1)[1]
return stats
# Module-level singleton
_stt_service = None
def get_stt_service() -> STTService:
global _stt_service
if _stt_service is None:
_stt_service = STTService()
return _stt_service
+9
View File
@@ -0,0 +1,9 @@
# services/tts/__init__.py
"""TTS service — text-to-speech."""
from .tts_service import (
TTSService,
get_tts_service,
)
__all__ = ["TTSService", "get_tts_service"]
+278
View File
@@ -0,0 +1,278 @@
# src/tts_service.py
"""Multi-provider TTS service — dispatches to local Kokoro, OpenAI-compatible API, or browser."""
import io
import wave
import logging
import hashlib
import httpx
from pathlib import Path
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
class TTSService:
"""Multi-provider TTS service.
Reads provider config from data/settings.json on each call.
Providers:
"disabled" — no TTS
"browser" — client-side Web Speech API (no server synthesis)
"local" — Kokoro-82M on GPU
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
"""
def __init__(self, cache_dir: str = "data/tts_cache"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self._kokoro = None # lazy-init
# ── Settings ──
def _load_settings(self) -> dict:
from src.settings import load_settings
saved = load_settings()
return {
"tts_provider": saved.get("tts_provider", "disabled"),
"tts_model": saved.get("tts_model", "tts-1"),
"tts_voice": saved.get("tts_voice", "alloy"),
"tts_speed": saved.get("tts_speed", "1"),
}
@property
def available(self) -> bool:
settings = self._load_settings()
provider = settings["tts_provider"]
if provider == "disabled":
return False
if provider == "browser":
return True # handled client-side
if provider == "local":
kokoro = self._get_kokoro()
return kokoro is not None and kokoro.available
if provider.startswith("endpoint:"):
return True # assume reachable; errors surface at synthesis time
return False
# ── Cache ──
def _cache_key(self, text: str, provider: str, model: str, voice: str, speed: float = 1.0) -> str:
raw = f"{provider}|{model}|{voice}|{speed}|{text}"
return hashlib.sha256(raw.encode()).hexdigest()
def _get_cached(self, key: str) -> Optional[bytes]:
for ext in (".mp3", ".wav"):
path = self.cache_dir / f"{key}{ext}"
if path.exists():
return path.read_bytes()
return None
def _put_cache(self, key: str, data: bytes):
ext = ".mp3" if (len(data) >= 3 and (data[:3] == b'ID3' or (data[0] == 0xff and (data[1] & 0xe0) == 0xe0))) else ".wav"
(self.cache_dir / f"{key}{ext}").write_bytes(data)
def clear_cache(self):
count = 0
for f in self.cache_dir.glob("*.*"):
f.unlink()
count += 1
logger.info(f"Cleared {count} cached TTS files")
# ── Kokoro (local) ──
def _get_kokoro(self):
if self._kokoro is None:
self._kokoro = _KokoroPipeline()
return self._kokoro
# ── API endpoint ──
def _synthesize_api(self, text: str, endpoint_id: str, model: str, voice: str, speed: float = 1.0) -> Optional[bytes]:
from src.database import SessionLocal, ModelEndpoint
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
if not ep:
logger.error(f"TTS endpoint {endpoint_id} not found")
return None
base_url = ep.base_url.rstrip("/")
api_key = ep.api_key
finally:
db.close()
url = base_url + "/audio/speech"
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
payload = {
"model": model,
"input": text,
"voice": voice,
"response_format": "mp3",
"speed": speed,
}
try:
r = httpx.post(url, json=payload, headers=headers, timeout=60)
r.raise_for_status()
logger.info(f"API TTS: {len(r.content)} bytes from {base_url}")
return r.content
except Exception as e:
logger.error(f"API TTS synthesis failed: {e}")
return None
# ── Public interface ──
def synthesize(self, text: str, use_cache: bool = True) -> Optional[bytes]:
settings = self._load_settings()
provider = settings["tts_provider"]
model = settings["tts_model"]
voice = settings["tts_voice"]
speed = float(settings.get("tts_speed", "1"))
if provider in ("disabled", "browser"):
return None
if len(text) > 5000:
text = text[:5000]
if use_cache:
key = self._cache_key(text, provider, model, voice, speed)
cached = self._get_cached(key)
if cached:
logger.info(f"TTS cache hit ({len(text)} chars)")
return cached
audio_data = None
if provider == "local":
kokoro = self._get_kokoro()
if kokoro and kokoro.available:
audio_data = kokoro.synthesize_raw(text, voice)
else:
logger.warning("Kokoro TTS not available")
return None
elif provider.startswith("endpoint:"):
endpoint_id = provider.split(":", 1)[1]
audio_data = self._synthesize_api(text, endpoint_id, model, voice, speed)
else:
logger.error(f"Unknown TTS provider: {provider}")
return None
if audio_data and use_cache:
key = self._cache_key(text, provider, model, voice, speed)
self._put_cache(key, audio_data)
return audio_data
def synthesize_to_base64(self, text: str) -> Optional[str]:
import base64
audio = self.synthesize(text)
if audio:
return base64.b64encode(audio).decode("utf-8")
return None
def set_voice(self, voice: str):
"""Legacy no-op — voice is now managed via admin settings."""
def get_stats(self) -> Dict[str, Any]:
settings = self._load_settings()
provider = settings["tts_provider"]
tts_enabled = settings.get("tts_enabled", True)
cache_files = list(self.cache_dir.glob("*.wav"))
cache_size = sum(f.stat().st_size for f in cache_files)
is_available = self.available and tts_enabled
stats = {
"available": is_available,
"ready": is_available,
"provider": provider,
"model": settings["tts_model"],
"voice": settings["tts_voice"],
"speed": float(settings.get("tts_speed", "1")),
"cache_entries": len(cache_files),
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
}
if provider == "local":
kokoro = self._get_kokoro()
stats["model"] = "Kokoro-82M (GPU)" if (kokoro and kokoro.available) else "Kokoro (not loaded)"
elif provider == "browser":
stats["model"] = "Browser (Web Speech API)"
elif provider.startswith("endpoint:"):
stats["endpoint_id"] = provider.split(":", 1)[1]
return stats
class _KokoroPipeline:
"""Encapsulates the Kokoro-82M local GPU pipeline."""
def __init__(self):
self.pipeline = None
self.available = False
self.device = None
self._init()
def _init(self):
try:
import torch
from kokoro import KPipeline
if not torch.cuda.is_available():
logger.warning("CUDA not available for Kokoro TTS")
return
self.device = torch.device("cuda:0")
with torch.cuda.device(0):
self.pipeline = KPipeline(lang_code="a")
if hasattr(self.pipeline, "model"):
self.pipeline.model = self.pipeline.model.to(self.device)
self.available = True
logger.info("Kokoro-82M TTS pipeline loaded")
except ImportError as e:
logger.warning(f"Kokoro TTS not available: {e}")
logger.warning("Install with: pip install kokoro soundfile")
except Exception as e:
logger.error(f"Kokoro init failed: {e}", exc_info=True)
def synthesize_raw(self, text: str, voice: str = "af_heart") -> Optional[bytes]:
if not self.available:
return None
try:
import torch
import numpy as np
with torch.cuda.device(self.device):
chunks = []
for _, _, audio in self.pipeline(text, voice=voice):
chunks.append(audio)
if not chunks:
return None
full = np.concatenate(chunks)
buf = io.BytesIO()
with wave.open(buf, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(24000)
wf.writeframes((full * 32767).astype(np.int16).tobytes())
return buf.getvalue()
except Exception as e:
logger.error(f"Kokoro synthesis failed: {e}", exc_info=True)
return None
# Module-level singleton
_tts_service = None
def get_tts_service() -> TTSService:
global _tts_service
if _tts_service is None:
_tts_service = TTSService()
return _tts_service
+22
View File
@@ -0,0 +1,22 @@
# services/youtube/__init__.py
"""YouTube service — transcript extraction."""
from .youtube_handler import (
init_youtube,
is_youtube_url,
extract_youtube_id,
extract_transcript_async,
format_transcript_for_context,
fetch_youtube_comments,
format_comments_for_context,
)
__all__ = [
"init_youtube",
"is_youtube_url",
"extract_youtube_id",
"extract_transcript_async",
"format_transcript_for_context",
"fetch_youtube_comments",
"format_comments_for_context",
]
+265
View File
@@ -0,0 +1,265 @@
"""
YouTube handling — transcript extraction, comment fetching (yt-dlp),
and context formatting for LLM injection. Used by chat_handler.py.
"""
import asyncio
import json
import logging
import shutil
import sys
import urllib.parse
from pathlib import Path
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
YOUTUBE_INSTRUCTION_PROMPT = """When the user shares a YouTube video, respond with a structured breakdown:
1. **Summary** — Concise overview of the video's content and main thesis (2-4 sentences)
2. **Key Points** — Bullet list of the most important topics, arguments, or moments
3. **Notable Timestamps** — If timestamps are available from the transcript, highlight 3-5 interesting moments with their approximate timestamps (e.g. "03:45 — discusses X")
4. **Audience Reception** — If comments are available, summarize what viewers think: general sentiment, top reactions, any debate or controversy
Keep it conversational and concise. Do NOT web search for this video — use only the transcript and comments provided."""
# ---------------------------------------------------------------------------
# Init / helpers
# ---------------------------------------------------------------------------
# Will be set at startup by init_youtube()
YouTubeTranscriptApi = None
YOUTUBE_AVAILABLE = False
def _find_ytdlp() -> str:
"""Find the yt-dlp binary: venv bin first, then system PATH."""
venv_bin = Path(sys.executable).parent / "yt-dlp"
if venv_bin.exists():
return str(venv_bin)
found = shutil.which("yt-dlp")
return found or "yt-dlp"
def init_youtube():
"""Import and cache the YouTube transcript API."""
global YouTubeTranscriptApi, YOUTUBE_AVAILABLE
try:
from youtube_transcript_api import YouTubeTranscriptApi as _Api
YouTubeTranscriptApi = _Api
YOUTUBE_AVAILABLE = True
logger.info("YouTube transcript API available")
except ImportError as e:
logger.warning(f"youtube-transcript-api not installed: {e}")
YOUTUBE_AVAILABLE = False
def is_youtube_url(url: str) -> bool:
return "youtube.com" in url or "youtu.be" in url
def extract_youtube_id(url: str) -> Optional[str]:
"""Extract YouTube video ID from various URL formats."""
parsed = urllib.parse.urlparse(url)
if parsed.hostname in ("www.youtube.com", "youtube.com", "m.youtube.com"):
if parsed.path == "/watch":
params = urllib.parse.parse_qs(parsed.query)
if "v" in params:
return params["v"][0]
elif parsed.path.startswith("/embed/"):
return parsed.path.split("/")[-1]
elif parsed.hostname == "youtu.be":
return parsed.path[1:]
return None
async def extract_transcript_async(
url: str, video_id: str, max_retries: int = 3
) -> Dict[str, Any]:
"""
Async YouTube transcript extraction with retries.
Args:
url: Full YouTube URL
video_id: Extracted video ID
max_retries: Number of attempts
Returns:
Dict with success/error/transcript keys
"""
if not YOUTUBE_AVAILABLE or YouTubeTranscriptApi is None:
return {"success": False, "error": "YouTube transcript API not available", "transcript": None}
for attempt in range(max_retries):
try:
api = YouTubeTranscriptApi()
transcript = api.fetch(video_id)
transcript_list = list(transcript)
formatted = []
for snippet in transcript_list:
text = snippet.text.strip()
if not text:
continue
start = snippet.start
formatted.append({
"text": text,
"start": start,
"duration": snippet.duration,
"timestamp": f"{int(start // 60):02d}:{int(start % 60):02d}",
})
full_text = " ".join(e["text"] for e in formatted)
max_len = 8000
if len(full_text) > max_len:
full_text = full_text[:max_len] + "... [transcript truncated]"
return {
"success": True,
"transcript": full_text,
"video_id": video_id,
"language": "en",
"is_generated": False,
"segments": formatted,
}
except Exception as e:
logger.warning(f"Transcript attempt {attempt + 1} failed: {e}")
if attempt < max_retries - 1:
await asyncio.sleep(1 * (attempt + 1))
return {"success": False, "error": f"Failed after {max_retries} attempts", "transcript": None}
def format_transcript_for_context(
transcript_data: Dict[str, Any], url: str,
title: str = "", channel: str = ""
) -> str:
"""Format transcript data for inclusion in LLM context."""
if not transcript_data.get("success"):
header = ""
if title:
header = f" \"{title}\""
if channel:
header += f" by {channel}"
return f"\n[YouTube Video{header}: Transcript unavailable ({transcript_data.get('error', 'Unknown error')}). Use the comments below if available, do NOT web search for this video.]"
transcript = transcript_data.get("transcript", "")
video_id = transcript_data.get("video_id", "")
language = transcript_data.get("language", "unknown")
is_generated = transcript_data.get("is_generated", False)
segments = transcript_data.get("segments", [])
ctx = "\n[YOUTUBE VIDEO TRANSCRIPT]\n"
if title:
ctx += f"Title: {title}\n"
if channel:
ctx += f"Channel: {channel}\n"
ctx += f"Video ID: {video_id}\n"
ctx += f"Language: {language}\n"
ctx += f"Source: {'Auto-generated' if is_generated else 'Manual'}\n"
ctx += f"URL: {url}\n\n"
# Include timestamped segments for the LLM to reference
if segments:
ctx += "Timestamped Transcript:\n"
for seg in segments:
ctx += f"[{seg['timestamp']}] {seg['text']}\n"
# Check length — fall back to plain text if too long
if len(ctx) > 12000:
ctx = ctx[:ctx.index("Timestamped Transcript:\n")]
ctx += "Transcript:\n"
ctx += transcript
else:
ctx += "Transcript:\n"
ctx += transcript
ctx += "\n[END TRANSCRIPT]\n"
return ctx
async def fetch_youtube_comments(
video_id: str, max_comments: int = 25, timeout: int = 30
) -> Dict[str, Any]:
"""Fetch top comments for a YouTube video using yt-dlp.
Returns dict with 'success', 'comments' list, 'error'.
"""
try:
cmd = [
_find_ytdlp(),
"--skip-download",
"--write-comments",
"--extractor-args", f"youtube:max_comments={max_comments},all,100,0",
"--dump-json",
"--js-runtimes", "node",
"--remote-components", "ejs:github",
f"https://www.youtube.com/watch?v={video_id}",
]
proc = await asyncio.wait_for(
asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
),
timeout=timeout,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
return {"success": False, "error": f"yt-dlp failed: {stderr.decode()[:200]}", "comments": []}
data = json.loads(stdout.decode())
title = data.get("title", "")
channel = data.get("channel", "") or data.get("uploader", "")
raw_comments = data.get("comments", [])
comments = []
for c in raw_comments[:max_comments]:
text = (c.get("text") or "").strip()
if not text:
continue
comments.append({
"author": c.get("author", "Unknown"),
"text": text,
"likes": c.get("like_count", 0),
})
# Sort by likes descending — most popular comments first
comments.sort(key=lambda x: x.get("likes", 0), reverse=True)
return {"success": True, "comments": comments, "count": len(comments),
"title": title, "channel": channel}
except asyncio.TimeoutError:
logger.warning(f"Comment fetch timed out for {video_id}")
return {"success": False, "error": "Comment fetch timed out", "comments": []}
except FileNotFoundError:
logger.warning("yt-dlp not installed — cannot fetch comments")
return {"success": False, "error": "yt-dlp not installed", "comments": []}
except Exception as e:
logger.warning(f"Failed to fetch comments for {video_id}: {e}")
return {"success": False, "error": str(e), "comments": []}
def format_comments_for_context(comments_data: Dict[str, Any], url: str) -> str:
"""Format YouTube comments for inclusion in LLM context."""
if not comments_data.get("success") or not comments_data.get("comments"):
return ""
comments = comments_data["comments"]
ctx = f"\n[YOUTUBE VIDEO COMMENTS — Top {len(comments)} by popularity]\n"
ctx += f"URL: {url}\n\n"
for i, c in enumerate(comments, 1):
likes = c.get("likes", 0)
likes_str = f" [{likes} likes]" if likes else ""
ctx += f"{i}. @{c['author']}{likes_str}: {c['text']}\n\n"
if len(ctx) > 4000:
ctx = ctx[:4000] + "\n[Comments truncated]\n"
ctx += "[END COMMENTS]\n"
return ctx