mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 09:45:24 -04:00
Odysseus v1.0
This commit is contained in:
@@ -0,0 +1,37 @@
|
||||
# services/__init__.py
|
||||
"""
|
||||
Service layer — plug-in capabilities for the chat core.
|
||||
|
||||
Each service:
|
||||
- Does one thing well
|
||||
- Exposes a clean async interface
|
||||
- Can run in-process or as a standalone HTTP service
|
||||
"""
|
||||
|
||||
from .search import SearchService, SearchResult, SearchResponse
|
||||
from .docs import DocsService, DocChunk, IndexResult
|
||||
from .research import ResearchService, ResearchResult, ResearchSource
|
||||
from .memory import MemoryService, Memory, MemorySearchResult
|
||||
from .shell import ShellService, ShellResult
|
||||
|
||||
__all__ = [
|
||||
# Search
|
||||
"SearchService",
|
||||
"SearchResult",
|
||||
"SearchResponse",
|
||||
# Docs
|
||||
"DocsService",
|
||||
"DocChunk",
|
||||
"IndexResult",
|
||||
# Research
|
||||
"ResearchService",
|
||||
"ResearchResult",
|
||||
"ResearchSource",
|
||||
# Memory
|
||||
"MemoryService",
|
||||
"Memory",
|
||||
"MemorySearchResult",
|
||||
# Shell
|
||||
"ShellService",
|
||||
"ShellResult",
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
# services/docs/__init__.py
|
||||
"""Docs service — personal document RAG with ChromaDB.
|
||||
|
||||
Thin facade: DocsService lives here, RAGManager/VectorRAG are re-exported
|
||||
from the canonical implementations in src/.
|
||||
"""
|
||||
|
||||
from .service import DocsService, DocChunk, IndexResult
|
||||
from src.rag_manager import RAGManager
|
||||
from src.rag_vector import VectorRAG
|
||||
|
||||
__all__ = [
|
||||
"DocsService",
|
||||
"DocChunk",
|
||||
"IndexResult",
|
||||
"RAGManager",
|
||||
"VectorRAG",
|
||||
]
|
||||
@@ -0,0 +1,89 @@
|
||||
# services/docs/service.py
|
||||
"""Docs service — personal document RAG."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from src.rag_manager import RAGManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocChunk:
|
||||
"""A retrieved document chunk."""
|
||||
text: str
|
||||
source: str
|
||||
score: float
|
||||
metadata: Dict[str, Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
"""Result of indexing documents."""
|
||||
indexed: int
|
||||
failed: int
|
||||
errors: List[str]
|
||||
|
||||
|
||||
class DocsService:
|
||||
"""
|
||||
Document RAG service.
|
||||
|
||||
Usage:
|
||||
service = DocsService()
|
||||
await service.index("/path/to/docs")
|
||||
results = await service.query("what is async await?")
|
||||
"""
|
||||
|
||||
def __init__(self, persist_dir: str = "data/chroma"):
|
||||
self.rag = RAGManager(persist_directory=persist_dir)
|
||||
|
||||
async def query(self, query: str, top_k: int = 5) -> List[DocChunk]:
|
||||
"""
|
||||
Query the document index.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
top_k: Number of results
|
||||
|
||||
Returns:
|
||||
List of DocChunk objects
|
||||
"""
|
||||
results = self.rag.search(query, k=top_k)
|
||||
return [
|
||||
DocChunk(
|
||||
text=r.get("text", r.get("content", "")),
|
||||
source=r.get("source", r.get("metadata", {}).get("source", "unknown")),
|
||||
score=r.get("score", 0.0),
|
||||
metadata=r.get("metadata"),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
async def index(self, directory: str) -> IndexResult:
|
||||
"""
|
||||
Index documents from a directory.
|
||||
|
||||
Args:
|
||||
directory: Path to documents
|
||||
|
||||
Returns:
|
||||
IndexResult with stats
|
||||
"""
|
||||
result = self.rag.index_personal_documents(directory)
|
||||
return IndexResult(
|
||||
indexed=result.get("indexed", 0),
|
||||
failed=result.get("failed", 0),
|
||||
errors=result.get("errors", []),
|
||||
)
|
||||
|
||||
async def add_document(self, text: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Add a single document to the index."""
|
||||
return self.rag.add_document(text, metadata)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get index statistics."""
|
||||
return self.rag.get_stats()
|
||||
|
||||
def rebuild_index(self) -> bool:
|
||||
"""Rebuild the entire index."""
|
||||
return self.rag.rebuild_index()
|
||||
@@ -0,0 +1 @@
|
||||
"""Face detection + embedding service (standalone worker + helpers)."""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,463 @@
|
||||
import re
|
||||
|
||||
from services.hwfit.models import (
|
||||
params_b, estimate_memory_gb, infer_use_case,
|
||||
get_models, is_prequantized, _active_params_b, QUANT_BYTES_PER_PARAM,
|
||||
QUANT_SPEED_MULT, QUANT_QUALITY_PENALTY,
|
||||
)
|
||||
|
||||
GPU_BANDWIDTH = {
|
||||
"5090": 1792, "5080": 960, "5070 ti": 896, "5070": 672, "5060 ti": 448, "5060": 256,
|
||||
"4090": 1008, "4080 super": 736, "4080": 717, "4070 ti super": 672, "4070 ti": 504, "4070 super": 504, "4070": 504, "4060 ti": 288, "4060": 272,
|
||||
"3090 ti": 1008, "3090": 936, "3080 ti": 912, "3080": 760, "3070 ti": 608, "3070": 448, "3060 ti": 448, "3060": 360,
|
||||
"2080 ti": 616, "2080 super": 496, "2080": 448, "2070 super": 448, "2070": 448, "2060 super": 448, "2060": 336,
|
||||
"1660 ti": 288, "1660 super": 336, "1660": 192, "1650 super": 192, "1650": 128,
|
||||
"h100 sxm": 3350, "h100": 2039, "h200": 4800, "a100 sxm": 2039, "a100": 1555,
|
||||
"l40s": 864, "l40": 864, "l4": 300, "a10g": 600, "a10": 600, "t4": 320,
|
||||
"v100 sxm": 900, "v100": 897, "a6000": 768, "a5000": 768, "a4000": 448,
|
||||
"7900 xtx": 960, "7900 xt": 800, "7900 gre": 576, "7800 xt": 624, "7700 xt": 432, "7600": 288,
|
||||
"6950 xt": 576, "6900 xt": 512, "6800 xt": 512, "6800": 512, "6700 xt": 384, "6600 xt": 256, "6600": 224,
|
||||
"mi300x": 5300, "mi300": 5300, "mi250x": 3277, "mi250": 3277, "mi210": 1638, "mi100": 1229,
|
||||
"9070 xt": 624, "9070": 488,
|
||||
}
|
||||
|
||||
# Pre-sort keys by length descending for correct substring matching
|
||||
_BW_KEYS_SORTED = sorted(GPU_BANDWIDTH.keys(), key=len, reverse=True)
|
||||
|
||||
FALLBACK_K = {"cuda": 220, "rocm": 180, "cpu_x86": 70, "cpu_arm": 90}
|
||||
|
||||
USE_CASE_WEIGHTS = {
|
||||
"general": (0.45, 0.30, 0.15, 0.10),
|
||||
"coding": (0.50, 0.20, 0.15, 0.15),
|
||||
"reasoning": (0.55, 0.15, 0.15, 0.15),
|
||||
"chat": (0.40, 0.35, 0.15, 0.10),
|
||||
"multimodal": (0.50, 0.20, 0.15, 0.15),
|
||||
"embedding": (0.30, 0.40, 0.20, 0.10),
|
||||
"tts": (0.40, 0.35, 0.15, 0.10),
|
||||
"stt": (0.40, 0.35, 0.15, 0.10),
|
||||
}
|
||||
|
||||
SPEED_TARGET = {
|
||||
"general": 40, "coding": 40, "multimodal": 40, "chat": 40,
|
||||
"reasoning": 25, "embedding": 200, "tts": 40, "stt": 40,
|
||||
}
|
||||
|
||||
CONTEXT_TARGET = {
|
||||
"general": 4096, "chat": 4096, "coding": 8192,
|
||||
"reasoning": 8192, "multimodal": 4096, "embedding": 512,
|
||||
"tts": 2048, "stt": 2048,
|
||||
}
|
||||
|
||||
|
||||
def _lookup_bandwidth(gpu_name):
|
||||
if not gpu_name:
|
||||
return None
|
||||
gn = gpu_name.lower()
|
||||
for key in _BW_KEYS_SORTED:
|
||||
if key in gn:
|
||||
return GPU_BANDWIDTH[key]
|
||||
return None
|
||||
|
||||
|
||||
def _estimate_speed(model, quant, run_mode, system):
|
||||
"""Estimate tok/s. Uses active params for MoE (only active experts run per token)."""
|
||||
pb = _active_params_b(model)
|
||||
is_moe = model.get("is_moe", False)
|
||||
bw = _lookup_bandwidth(system.get("gpu_name"))
|
||||
backend = system.get("backend", "cpu_x86")
|
||||
|
||||
if bw and run_mode in ("gpu", "cpu_offload"):
|
||||
bpp = QUANT_BYTES_PER_PARAM.get(quant, 0.5)
|
||||
model_gb = pb * bpp
|
||||
if model_gb <= 0:
|
||||
return 0.0
|
||||
efficiency = 0.55
|
||||
raw_tps = (bw / model_gb) * efficiency
|
||||
if run_mode == "cpu_offload":
|
||||
mode_factor = 0.5
|
||||
elif is_moe:
|
||||
mode_factor = 0.8
|
||||
else:
|
||||
mode_factor = 1.0
|
||||
return raw_tps * mode_factor
|
||||
|
||||
k = FALLBACK_K.get(backend, 70)
|
||||
if pb <= 0:
|
||||
return 0.0
|
||||
sm = QUANT_SPEED_MULT.get(quant, 1.0)
|
||||
return k / pb * sm
|
||||
|
||||
|
||||
def _quality_score(model, quant, use_case):
|
||||
pb = params_b(model)
|
||||
if pb < 1:
|
||||
base = 30
|
||||
elif pb < 3:
|
||||
base = 45
|
||||
elif pb < 7:
|
||||
base = 60
|
||||
elif pb < 10:
|
||||
base = 75
|
||||
elif pb < 20:
|
||||
base = 82
|
||||
elif pb < 40:
|
||||
base = 89
|
||||
else:
|
||||
base = 95
|
||||
|
||||
name_lower = model.get("name", "").lower()
|
||||
if "qwen" in name_lower:
|
||||
base += 2
|
||||
if "deepseek" in name_lower:
|
||||
base += 3
|
||||
if "llama" in name_lower:
|
||||
base += 2
|
||||
if "mistral" in name_lower or "mixtral" in name_lower:
|
||||
base += 1
|
||||
if "gemma" in name_lower:
|
||||
base += 1
|
||||
|
||||
base += QUANT_QUALITY_PENALTY.get(quant, 0)
|
||||
|
||||
model_uc = infer_use_case(model)
|
||||
if model_uc == "coding" and use_case == "coding":
|
||||
base += 6
|
||||
if model_uc == "reasoning" and use_case == "reasoning" and pb >= 13:
|
||||
base += 5
|
||||
if model_uc == "multimodal" and use_case == "multimodal":
|
||||
base += 6
|
||||
|
||||
return max(0, min(100, base))
|
||||
|
||||
|
||||
def _speed_score(tps, use_case):
|
||||
target = SPEED_TARGET.get(use_case, 40)
|
||||
return max(0, min(100, (tps / target) * 100))
|
||||
|
||||
|
||||
def _fit_score(required, available):
|
||||
if required > available:
|
||||
return 0
|
||||
if available <= 0:
|
||||
return 0
|
||||
ratio = required / available
|
||||
if ratio <= 0.5:
|
||||
return 60 + (ratio / 0.5) * 40
|
||||
if ratio <= 0.8:
|
||||
return 100
|
||||
if ratio <= 0.9:
|
||||
return 70
|
||||
return 50
|
||||
|
||||
|
||||
def _context_score(ctx, use_case):
|
||||
target = CONTEXT_TARGET.get(use_case, 4096)
|
||||
if ctx >= target:
|
||||
return 100
|
||||
if ctx >= target / 2:
|
||||
return 70
|
||||
return 30
|
||||
|
||||
|
||||
def _try_quant_at(model, quant, ctx, gpu_vram, available_ram):
|
||||
"""Try a specific quant at a given context. Returns (run_mode, quant, ctx, mem) or None."""
|
||||
mem = estimate_memory_gb(model, quant, ctx)
|
||||
if gpu_vram > 0 and mem <= gpu_vram:
|
||||
return "gpu", quant, ctx, mem
|
||||
if gpu_vram > 0 and mem <= available_ram:
|
||||
return "cpu_offload", quant, ctx, mem
|
||||
if gpu_vram <= 0 and mem <= available_ram:
|
||||
return "cpu_only", quant, ctx, mem
|
||||
# Try halving context
|
||||
cur_ctx = ctx // 2
|
||||
while cur_ctx >= 1024:
|
||||
mem = estimate_memory_gb(model, quant, cur_ctx)
|
||||
if gpu_vram > 0 and mem <= gpu_vram:
|
||||
return "gpu", quant, cur_ctx, mem
|
||||
if mem <= available_ram:
|
||||
return ("cpu_offload" if gpu_vram > 0 else "cpu_only"), quant, cur_ctx, mem
|
||||
cur_ctx //= 2
|
||||
return None
|
||||
|
||||
|
||||
def _quant_bits(q):
|
||||
"""Approximate bit-width of a quant label so GGUF quant tiers (Q4/Q8/…) can
|
||||
be matched against prequantized formats (AWQ 4, AWQ-8bit, FP8, GPTQ-4bit…).
|
||||
Returns 0 when unknown (caller treats unknown as "don't filter")."""
|
||||
qu = (q or "").upper().replace("-", "").replace("_", "").replace(" ", "")
|
||||
# GGUF k-quants + float formats
|
||||
if qu.startswith("Q8") or "FP8" in qu:
|
||||
return 8
|
||||
if qu.startswith("Q4") or qu.startswith("IQ4"):
|
||||
return 4
|
||||
if qu.startswith("Q2") or qu.startswith("IQ2"):
|
||||
return 2
|
||||
if qu.startswith("Q3") or qu.startswith("IQ3"):
|
||||
return 3
|
||||
if qu.startswith("Q5"):
|
||||
return 5
|
||||
if qu.startswith("Q6"):
|
||||
return 6
|
||||
if qu.startswith("F16") or qu.startswith("BF16") or qu.startswith("F32"):
|
||||
return 16
|
||||
# Prequantized formats: pull the bit-width digit (AWQ4 / AWQ4BIT / GPTQ8 / 4BIT / INT8 …)
|
||||
m = re.search(r"(?:AWQ|GPTQ|MLX|EXL2|BNB|INT|W)(\d{1,2})", qu) or re.search(r"(\d{1,2})BIT", qu)
|
||||
if m:
|
||||
b = int(m.group(1))
|
||||
if 2 <= b <= 16:
|
||||
return b
|
||||
return 0
|
||||
|
||||
|
||||
def analyze_model(model, system, target_quant=None):
|
||||
pb = params_b(model)
|
||||
if pb <= 0:
|
||||
return None
|
||||
|
||||
use_case = infer_use_case(model)
|
||||
has_gpu = system.get("has_gpu", False)
|
||||
gpu_vram = (system.get("gpu_vram_gb") or 0) if has_gpu else 0
|
||||
gpu_count = system.get("gpu_count", 1) or 1
|
||||
single_gpu_vram = gpu_vram / gpu_count if gpu_count > 1 else gpu_vram
|
||||
available_ram = system.get("available_ram_gb", 0)
|
||||
# When the user has explicitly picked a GPU config (not RAM mode), they want
|
||||
# to see what runs ON the GPU(s) — not big models that only "fit" by spilling
|
||||
# most layers to system RAM. Zeroing the offload budget makes _try_quant_at
|
||||
# take only its GPU branches (fit on VRAM, shrinking context if needed),
|
||||
# otherwise return None. Fixes "96 GB GPU still lists a 175 GB model".
|
||||
gpu_only = bool(system.get("gpu_only")) and has_gpu and gpu_vram > 0
|
||||
eff_ram = 0 if gpu_only else available_ram
|
||||
is_moe = model.get("is_moe", False)
|
||||
ctx = model.get("context_length", 4096) or 4096
|
||||
|
||||
native_quant = model.get("quantization", "Q4_K_M")
|
||||
preq = is_prequantized(model)
|
||||
|
||||
# GGUF models can't be sharded across GPUs — use single GPU VRAM
|
||||
is_gguf = bool(model.get("gguf_sources"))
|
||||
quant_upper = (native_quant or "").upper()
|
||||
is_gguf_quant = any(quant_upper.startswith(p) for p in ("Q2", "Q3", "Q4", "Q5", "Q6", "Q8", "IQ", "F16", "F32"))
|
||||
# Single-GPU VRAM only applies to GGUF/dense builds (llama.cpp can't shard
|
||||
# across GPUs). Prequantized formats (AWQ/GPTQ/FP8) are served sharded by
|
||||
# vLLM across all GPUs, so they get the FULL multi-GPU VRAM — even when the
|
||||
# model also lists a GGUF alternate download (gguf_sources).
|
||||
if (is_gguf or is_gguf_quant) and not preq:
|
||||
effective_vram = single_gpu_vram
|
||||
else:
|
||||
effective_vram = gpu_vram
|
||||
|
||||
# Determine which quant to evaluate at
|
||||
if preq:
|
||||
# AWQ/GPTQ/FP8/MLX come at a fixed bit-width. If the user picked a
|
||||
# specific quant tier (e.g. Q8 → 8-bit), only keep prequant models whose
|
||||
# native bit-width matches — otherwise selecting Q8 would still surface
|
||||
# AWQ-4bit models, mixing 4- and 8-bit in one view.
|
||||
if target_quant:
|
||||
_tb, _nb = _quant_bits(target_quant), _quant_bits(native_quant)
|
||||
if _tb and _nb and _tb != _nb:
|
||||
return None
|
||||
quant_to_try = native_quant
|
||||
elif target_quant:
|
||||
# User picked a specific quant
|
||||
quant_to_try = target_quant
|
||||
else:
|
||||
# Default: Q4_K_M (user's stated preference)
|
||||
quant_to_try = "Q4_K_M"
|
||||
|
||||
result = _try_quant_at(model, quant_to_try, ctx, effective_vram, eff_ram)
|
||||
|
||||
# If target quant doesn't fit and it's not pre-quantized, try lower quants
|
||||
if result is None and not preq and target_quant:
|
||||
from services.hwfit.models import QUANT_HIERARCHY
|
||||
idx = QUANT_HIERARCHY.index(target_quant) if target_quant in QUANT_HIERARCHY else -1
|
||||
for q in QUANT_HIERARCHY[idx + 1:]:
|
||||
result = _try_quant_at(model, q, ctx, effective_vram, eff_ram)
|
||||
if result:
|
||||
break
|
||||
|
||||
if result is None:
|
||||
# Model doesn't fit on the user's current hardware. Surface it
|
||||
# anyway with a "too_tight" badge instead of silently dropping
|
||||
# it — without this, editing the hardware config to try LARGER
|
||||
# tiers never revealed the bigger models, because they were
|
||||
# filtered out before the user could see what would fit. The
|
||||
# client already knows how to render too_tight (red row).
|
||||
oversized_required = estimate_memory_gb(model, quant_to_try, ctx)
|
||||
return {
|
||||
"name": model.get("name"),
|
||||
"provider": model.get("provider"),
|
||||
"parameter_count": model.get("parameter_count"),
|
||||
"params_b": round(pb, 1),
|
||||
"is_moe": is_moe,
|
||||
"use_case": use_case,
|
||||
"fit_level": "too_tight",
|
||||
"run_mode": "no_fit",
|
||||
"quant": quant_to_try,
|
||||
"context": ctx,
|
||||
"required_gb": round(oversized_required, 1),
|
||||
"speed_tps": 0,
|
||||
"score": 0,
|
||||
"scores": {"quality": 0, "speed": 0, "fit": 0, "context": 0},
|
||||
"gguf_sources": model.get("gguf_sources", []),
|
||||
"context_length": model.get("context_length", 4096),
|
||||
}
|
||||
|
||||
run_mode, quant, fit_ctx, required_gb = result
|
||||
|
||||
# Determine fit level
|
||||
budget = effective_vram if run_mode == "gpu" else available_ram
|
||||
if required_gb > budget:
|
||||
return None
|
||||
if run_mode == "gpu":
|
||||
rec = model.get("recommended_ram_gb") or required_gb
|
||||
if rec <= gpu_vram:
|
||||
fit_level = "perfect"
|
||||
elif gpu_vram >= required_gb * 1.2:
|
||||
fit_level = "good"
|
||||
else:
|
||||
fit_level = "marginal"
|
||||
elif run_mode == "cpu_offload":
|
||||
fit_level = "good" if available_ram >= required_gb * 1.2 else "marginal"
|
||||
else:
|
||||
fit_level = "marginal"
|
||||
|
||||
tps = _estimate_speed(model, quant, run_mode, system)
|
||||
|
||||
q_score = _quality_score(model, quant, use_case)
|
||||
s_score = _speed_score(tps, use_case)
|
||||
f_score = _fit_score(required_gb, budget)
|
||||
c_score = _context_score(fit_ctx, use_case)
|
||||
|
||||
wq, ws, wf, wc = USE_CASE_WEIGHTS.get(use_case, (0.45, 0.30, 0.15, 0.10))
|
||||
composite = q_score * wq + s_score * ws + f_score * wf + c_score * wc
|
||||
|
||||
return {
|
||||
"name": model.get("name"),
|
||||
"provider": model.get("provider"),
|
||||
"parameter_count": model.get("parameter_count"),
|
||||
"params_b": round(pb, 1),
|
||||
"is_moe": is_moe,
|
||||
"use_case": use_case,
|
||||
"fit_level": fit_level,
|
||||
"run_mode": run_mode,
|
||||
"quant": quant,
|
||||
"context": fit_ctx,
|
||||
"required_gb": round(required_gb, 1),
|
||||
"speed_tps": round(tps, 1),
|
||||
"score": round(composite, 1),
|
||||
"scores": {
|
||||
"quality": round(q_score, 1),
|
||||
"speed": round(s_score, 1),
|
||||
"fit": round(f_score, 1),
|
||||
"context": round(c_score, 1),
|
||||
},
|
||||
"gguf_sources": model.get("gguf_sources", []),
|
||||
"context_length": model.get("context_length", 4096),
|
||||
}
|
||||
|
||||
|
||||
SORT_KEYS = {
|
||||
"score": lambda r: r["score"],
|
||||
"speed": lambda r: r["speed_tps"],
|
||||
"vram": lambda r: r["required_gb"],
|
||||
"params": lambda r: r["params_b"],
|
||||
"context": lambda r: r["context"],
|
||||
}
|
||||
|
||||
|
||||
def rank_models(system, use_case=None, limit=50, search=None, sort="score", quant=None):
|
||||
"""Rank all models against detected hardware. Returns sorted list of fit results."""
|
||||
models = get_models()
|
||||
results = []
|
||||
|
||||
# Include image gen models only when explicitly filtered
|
||||
if use_case == "image_gen":
|
||||
try:
|
||||
from services.hwfit.image_models import rank_image_models
|
||||
except ImportError:
|
||||
rank_image_models = None
|
||||
if rank_image_models:
|
||||
img_results = rank_image_models(system, search=search)
|
||||
else:
|
||||
img_results = []
|
||||
for im in img_results:
|
||||
fit_map = {"perfect": "perfect", "good": "good", "tight": "marginal", "no_fit": "too_tight", "no_gpu": "too_tight"}
|
||||
results.append({
|
||||
"name": im["id"],
|
||||
"provider": im["provider"],
|
||||
"parameter_count": f"{im['params_b']}B",
|
||||
"params_b": im["params_b"],
|
||||
"is_moe": False,
|
||||
"use_case": "image_gen",
|
||||
"fit_level": fit_map.get(im["fit"], "too_tight"),
|
||||
"run_mode": "gpu" if im["fits"] else "no_fit",
|
||||
"quant": im.get("quant", "BF16"),
|
||||
"context": 0,
|
||||
"context_length": 0,
|
||||
"required_gb": round(im.get("vram_needed") or 0, 1),
|
||||
"speed_tps": 0,
|
||||
"score": float(im["score"]),
|
||||
"scores": {"quality": float(im["quality"]), "speed": float(im["speed"]), "fit": 0, "context": 0},
|
||||
"gguf_sources": [],
|
||||
"is_image_gen": True,
|
||||
"capabilities": im.get("capabilities", []),
|
||||
"description": im.get("description", ""),
|
||||
})
|
||||
if use_case == "image_gen":
|
||||
sort_fn = SORT_KEYS.get(sort, SORT_KEYS["score"])
|
||||
results.sort(key=sort_fn, reverse=(sort != "vram"))
|
||||
return results[:limit]
|
||||
|
||||
# If user picked a prequantized format (AWQ/FP8/GPTQ), filter to only those models
|
||||
filter_native = quant and any(quant.startswith(p) for p in ("AWQ-", "GPTQ-", "FP8"))
|
||||
|
||||
# MLX-quantized models only run on Apple Silicon (Metal). Exclude them on
|
||||
# every other backend (CUDA / ROCm / CPU) so Linux/Windows users don't see
|
||||
# unrunnable suggestions.
|
||||
system_backend = (system.get("backend") or "").lower()
|
||||
apple_silicon = system_backend in ("mps", "metal", "apple")
|
||||
|
||||
for m in models:
|
||||
native_q = m.get("quantization", "")
|
||||
|
||||
# Drop MLX models on non-Apple hardware
|
||||
if not apple_silicon and native_q.startswith("mlx-"):
|
||||
continue
|
||||
|
||||
# Format filter: AWQ tab → only AWQ models, FP8 tab → only FP8 models
|
||||
if filter_native:
|
||||
if quant == "FP8" and native_q != "FP8":
|
||||
continue
|
||||
if quant.startswith("AWQ") and not native_q.startswith("AWQ"):
|
||||
continue
|
||||
if quant.startswith("GPTQ") and not native_q.startswith("GPTQ"):
|
||||
continue
|
||||
|
||||
if search:
|
||||
name = m.get("name", "").lower()
|
||||
provider = m.get("provider", "").lower()
|
||||
if search.lower() not in name and search.lower() not in provider:
|
||||
continue
|
||||
|
||||
result = analyze_model(m, system, target_quant=quant)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
if use_case:
|
||||
model_uc = infer_use_case(m)
|
||||
if use_case != model_uc and use_case != "general":
|
||||
continue
|
||||
|
||||
results.append(result)
|
||||
|
||||
# Pick the visible SET by best fit (score) first, so it stays the same no
|
||||
# matter which column the user sorts by — otherwise sorting by params would
|
||||
# truncate to the N biggest models (huge ones that don't even fit) while
|
||||
# sorting by vram showed the N smallest. Only AFTER choosing the set do we
|
||||
# order it by the requested column.
|
||||
results.sort(key=SORT_KEYS["score"], reverse=True)
|
||||
results = results[:limit]
|
||||
sort_fn = SORT_KEYS.get(sort, SORT_KEYS["score"])
|
||||
# vram ascending (smallest first), everything else descending (biggest first)
|
||||
results.sort(key=sort_fn, reverse=(sort != "vram"))
|
||||
return results
|
||||
@@ -0,0 +1,457 @@
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
CACHE_TTL = 1800 # 30 min — hardware rarely changes; use the Rescan button to force a re-probe
|
||||
|
||||
|
||||
_remote_host = None # set by detect_system(host=...)
|
||||
_remote_port = None # set by detect_system(ssh_port=...)
|
||||
_remote_platform = None # set by detect_system(platform=...): "windows", "linux", "termux"
|
||||
_last_gpu_error = None # set by _detect_nvidia() when nvidia-smi errors (driver mismatch, etc.)
|
||||
|
||||
|
||||
def _run(cmd):
|
||||
try:
|
||||
if _remote_host:
|
||||
# Run command on remote host via SSH
|
||||
if isinstance(cmd, list):
|
||||
cmd_str = " ".join(cmd)
|
||||
else:
|
||||
cmd_str = cmd
|
||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
||||
if _remote_port and _remote_port != "22":
|
||||
ssh_cmd += ["-p", _remote_port]
|
||||
ssh_cmd += [_remote_host, cmd_str]
|
||||
r = subprocess.run(
|
||||
ssh_cmd,
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
else:
|
||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
if r.returncode == 0:
|
||||
return r.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _group_gpus(gpus):
|
||||
"""Group identical GPUs by (name, rounded VRAM).
|
||||
|
||||
vLLM tensor-parallel only works across IDENTICAL GPUs, so a mixed box must
|
||||
be split into homogeneous pools. Each group carries the device indices so a
|
||||
serve command can pin CUDA_VISIBLE_DEVICES to exactly one pool. Biggest pool
|
||||
(by total VRAM) first — that's the sensible auto-default serving target.
|
||||
"""
|
||||
groups = {}
|
||||
order = []
|
||||
for g in gpus:
|
||||
key = (g["name"], round(g["vram_gb"]))
|
||||
if key not in groups:
|
||||
groups[key] = {
|
||||
"name": g["name"],
|
||||
"vram_each": round(g["vram_gb"], 1),
|
||||
"count": 0,
|
||||
"indices": [],
|
||||
}
|
||||
order.append(key)
|
||||
groups[key]["count"] += 1
|
||||
groups[key]["indices"].append(g.get("index"))
|
||||
out = []
|
||||
for key in order:
|
||||
grp = groups[key]
|
||||
grp["vram_total"] = round(grp["vram_each"] * grp["count"], 1)
|
||||
out.append(grp)
|
||||
out.sort(key=lambda x: x["vram_total"], reverse=True)
|
||||
return out
|
||||
|
||||
|
||||
def _detect_nvidia():
|
||||
global _last_gpu_error
|
||||
_last_gpu_error = None
|
||||
out = _run(["nvidia-smi", "--query-gpu=memory.total,name", "--format=csv,noheader,nounits"])
|
||||
# Remote fallback: a non-interactive SSH shell often has a minimal PATH
|
||||
# that omits where nvidia-smi lives (/usr/bin, /usr/local/cuda/bin), so the
|
||||
# first call silently returns nothing → "No GPU" on hosts that DO have GPUs.
|
||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||
if not out and _remote_host:
|
||||
out = _run(
|
||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin\"; "
|
||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||
)
|
||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||
# shell that isn't bash (or a profile that errors), so the bash -lc retry
|
||||
# above still comes back empty even though the binary is right there.
|
||||
if not out and _remote_host:
|
||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi"):
|
||||
out = _run(f"{_p} --query-gpu=memory.total,name --format=csv,noheader,nounits")
|
||||
if out:
|
||||
break
|
||||
if not out:
|
||||
return None
|
||||
|
||||
# nvidia-smi present but unable to talk to the driver (e.g. it was updated
|
||||
# without a reboot). It prints an error and no GPU rows — surface that as a
|
||||
# driver error rather than the misleading "No GPU".
|
||||
_low = out.lower()
|
||||
if ("nvml" in _low or "driver/library version mismatch" in _low
|
||||
or "couldn't communicate" in _low or "no devices were found" in _low
|
||||
or "failed to initialize" in _low):
|
||||
_last_gpu_error = out.strip().split("\n")[0][:140] or "NVIDIA driver error"
|
||||
return None
|
||||
|
||||
gpus = []
|
||||
# nvidia-smi lists GPUs in index order (0,1,2,...), so the row position is
|
||||
# the CUDA device index we'd pass to CUDA_VISIBLE_DEVICES.
|
||||
for idx, line in enumerate(out.strip().split("\n")):
|
||||
parts = [p.strip() for p in line.split(",")]
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
vram_mb = float(parts[0])
|
||||
gpus.append({"index": idx, "name": parts[1], "vram_gb": vram_mb / 1024.0})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not gpus:
|
||||
return None
|
||||
total_vram = sum(g["vram_gb"] for g in gpus)
|
||||
groups = _group_gpus(gpus)
|
||||
return {
|
||||
"gpu_name": gpus[0]["name"],
|
||||
"gpu_vram_gb": round(total_vram, 1),
|
||||
"gpu_count": len(gpus),
|
||||
"gpus": gpus,
|
||||
"gpu_groups": groups,
|
||||
"homogeneous": len(groups) <= 1,
|
||||
"backend": "cuda",
|
||||
}
|
||||
|
||||
|
||||
def _detect_amd():
|
||||
"""Detect AMD GPUs. Handles both discrete cards (with mem_info_vram_total)
|
||||
and APUs / unified-memory SoCs like Strix Halo (which expose
|
||||
mem_info_vis_vram_total instead, or only mem_info_gtt_total)."""
|
||||
def _read(path):
|
||||
if _remote_host:
|
||||
val = _run(["cat", path])
|
||||
return val.strip() if val else None
|
||||
try:
|
||||
with open(path) as f:
|
||||
return f.read().strip()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _list_drm_cards():
|
||||
if _remote_host:
|
||||
out = _run(["ls", "/sys/class/drm"])
|
||||
if not out:
|
||||
return []
|
||||
return [e for e in out.split() if e.startswith("card") and "-" not in e]
|
||||
try:
|
||||
return [e for e in os.listdir("/sys/class/drm") if e.startswith("card") and "-" not in e]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
try:
|
||||
cards = []
|
||||
is_apu = False
|
||||
for _cidx, entry in enumerate(_list_drm_cards()):
|
||||
base = f"/sys/class/drm/{entry}/device"
|
||||
vendor = _read(f"{base}/vendor")
|
||||
if vendor != "0x1002":
|
||||
continue
|
||||
# Discrete cards usually report real VRAM in mem_info_vram_total,
|
||||
# while some AMD APUs / Docker views expose a tiny vram_total and
|
||||
# the usable pool in vis_vram_total. Use the larger of those two;
|
||||
# only fall back to GTT if neither VRAM field is available.
|
||||
vram_raw = _read(f"{base}/mem_info_vram_total")
|
||||
vis_raw = _read(f"{base}/mem_info_vis_vram_total")
|
||||
gtt_raw = _read(f"{base}/mem_info_gtt_total")
|
||||
vram_val = int(vram_raw) if vram_raw and vram_raw.isdigit() else 0
|
||||
vis_val = int(vis_raw) if vis_raw and vis_raw.isdigit() else 0
|
||||
gtt_val = int(gtt_raw) if gtt_raw and gtt_raw.isdigit() else 0
|
||||
vram_bytes = max(vram_val, vis_val)
|
||||
if vram_bytes <= 0:
|
||||
vram_bytes = gtt_val
|
||||
if vis_val and vis_val >= vram_val:
|
||||
is_apu = True
|
||||
if vram_bytes <= 0:
|
||||
continue
|
||||
name = _read(f"{base}/product_name") or f"AMD GPU ({entry})"
|
||||
cards.append({"index": _cidx, "name": name, "vram_gb": vram_bytes / (1024**3)})
|
||||
|
||||
if not cards:
|
||||
return None
|
||||
total_vram = sum(c["vram_gb"] for c in cards)
|
||||
groups = _group_gpus(cards)
|
||||
# NOTE: for APUs with BIOS UMA carveout (e.g. Strix Halo), vis_vram_total
|
||||
# is the real usable GPU memory — it's physically backed but reserved
|
||||
# by BIOS so it doesn't appear in /proc/meminfo. Don't cap it at system
|
||||
# RAM: the two pools are separate from the OS's perspective.
|
||||
return {
|
||||
"gpu_name": cards[0]["name"],
|
||||
"gpu_vram_gb": round(total_vram, 1),
|
||||
"gpu_count": len(cards),
|
||||
"gpus": cards,
|
||||
"gpu_groups": groups,
|
||||
"homogeneous": len(groups) <= 1,
|
||||
"backend": "rocm",
|
||||
"unified_memory": is_apu,
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _read_file(path):
|
||||
"""Read a file, locally or via SSH."""
|
||||
if _remote_host:
|
||||
return _run(["cat", path])
|
||||
try:
|
||||
with open(path) as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _parse_meminfo():
|
||||
"""Parse /proc/meminfo into a dict of key -> KB values."""
|
||||
text = _read_file("/proc/meminfo")
|
||||
if not text:
|
||||
return {}
|
||||
result = {}
|
||||
for line in text.split("\n"):
|
||||
if ":" in line:
|
||||
key, val = line.split(":", 1)
|
||||
parts = val.strip().split()
|
||||
if parts:
|
||||
try:
|
||||
result[key.strip()] = int(parts[0])
|
||||
except ValueError:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _get_ram_gb():
|
||||
meminfo = _parse_meminfo()
|
||||
if "MemTotal" in meminfo:
|
||||
return meminfo["MemTotal"] / (1024**2)
|
||||
|
||||
if not _remote_host:
|
||||
try:
|
||||
pages = os.sysconf("SC_PHYS_PAGES")
|
||||
page_size = os.sysconf("SC_PAGE_SIZE")
|
||||
if pages and page_size:
|
||||
return (pages * page_size) / (1024**3)
|
||||
except Exception:
|
||||
pass
|
||||
return 0.0
|
||||
|
||||
|
||||
def _get_available_ram_gb():
|
||||
meminfo = _parse_meminfo()
|
||||
if "MemAvailable" in meminfo:
|
||||
return meminfo["MemAvailable"] / (1024**2)
|
||||
return _get_ram_gb() * 0.7
|
||||
|
||||
|
||||
def _get_cpu_name():
|
||||
text = _read_file("/proc/cpuinfo")
|
||||
if text:
|
||||
for line in text.split("\n"):
|
||||
if line.startswith("model name"):
|
||||
return line.split(":", 1)[1].strip()
|
||||
|
||||
if not _remote_host:
|
||||
return platform.processor() or "unknown"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _get_cpu_count():
|
||||
if _remote_host:
|
||||
out = _run(["nproc"])
|
||||
if out:
|
||||
try:
|
||||
return int(out.strip())
|
||||
except ValueError:
|
||||
pass
|
||||
# fallback: count "processor" lines in /proc/cpuinfo
|
||||
text = _read_file("/proc/cpuinfo")
|
||||
if text:
|
||||
return sum(1 for line in text.split("\n") if line.startswith("processor"))
|
||||
return os.cpu_count() or 1
|
||||
|
||||
|
||||
def _detect_windows():
|
||||
"""Detect Windows hardware in a single SSH call using PowerShell."""
|
||||
# Single PowerShell command that gathers all hardware info at once
|
||||
ps_cmd = (
|
||||
"$r = @{}; "
|
||||
"$os = Get-CimInstance Win32_OperatingSystem; "
|
||||
"$r.ram_gb = [math]::Round($os.TotalVisibleMemorySize / 1048576, 1); "
|
||||
"$r.avail_gb = [math]::Round($os.FreePhysicalMemory / 1048576, 1); "
|
||||
"$cpu = Get-CimInstance Win32_Processor | Select-Object -First 1; "
|
||||
"$r.cpu_name = $cpu.Name; "
|
||||
"$r.cpu_cores = (Get-CimInstance Win32_Processor | Measure-Object -Property NumberOfLogicalProcessors -Sum).Sum; "
|
||||
"$r.arch = $cpu.AddressWidth; "
|
||||
# GPU detection via nvidia-smi (fastest) or WMI fallback
|
||||
"try { "
|
||||
" $nv = nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits 2>$null; "
|
||||
" if ($LASTEXITCODE -eq 0 -and $nv) { "
|
||||
" $gpus = @(); "
|
||||
" foreach ($line in $nv -split \"`n\") { "
|
||||
" $p = $line -split ','; "
|
||||
" if ($p.Count -ge 2) { $gpus += @{name=$p[1].Trim(); vram_mb=[double]$p[0].Trim()} } "
|
||||
" }; "
|
||||
" $r.gpu_name = $gpus[0].name; "
|
||||
" $r.gpu_vram_gb = [math]::Round(($gpus | Measure-Object -Property vram_mb -Sum).Sum / 1024, 1); "
|
||||
" $r.gpu_count = $gpus.Count; "
|
||||
" $r.gpu_backend = 'cuda'; "
|
||||
" } "
|
||||
"} catch {}; "
|
||||
"if (-not $r.gpu_name) { "
|
||||
" $wmiGpu = Get-CimInstance Win32_VideoController | Where-Object { $_.AdapterRAM -gt 0 } | Select-Object -First 1; "
|
||||
" if ($wmiGpu) { "
|
||||
" $r.gpu_name = $wmiGpu.Name; "
|
||||
" $r.gpu_vram_gb = [math]::Round($wmiGpu.AdapterRAM / 1073741824, 1); "
|
||||
" $r.gpu_count = 1; "
|
||||
" $r.gpu_backend = 'cpu_x86'; " # WMI doesn't tell us CUDA/ROCm
|
||||
" } "
|
||||
"}; "
|
||||
"$r | ConvertTo-Json -Compress"
|
||||
)
|
||||
out = _run(f'powershell -Command "{ps_cmd}"')
|
||||
if not out:
|
||||
return None
|
||||
import json as _json
|
||||
try:
|
||||
d = _json.loads(out)
|
||||
result = {
|
||||
"total_ram_gb": d.get("ram_gb", 0),
|
||||
"available_ram_gb": d.get("avail_gb", 0),
|
||||
"cpu_cores": d.get("cpu_cores", 1),
|
||||
"cpu_name": d.get("cpu_name", "unknown"),
|
||||
"has_gpu": bool(d.get("gpu_name")),
|
||||
"gpu_name": d.get("gpu_name"),
|
||||
"gpu_vram_gb": d.get("gpu_vram_gb"),
|
||||
"gpu_count": d.get("gpu_count", 0),
|
||||
"backend": d.get("gpu_backend", "cpu_x86"),
|
||||
}
|
||||
# PowerShell only reports aggregate GPU info, not per-card detail, so we
|
||||
# can't tell a mixed box from a uniform one here — assume one homogeneous
|
||||
# pool spanning all reported GPUs (the common Windows case).
|
||||
_n = result["gpu_count"] or 0
|
||||
if result["has_gpu"] and _n > 0:
|
||||
_each = round((result["gpu_vram_gb"] or 0) / _n, 1)
|
||||
result["gpus"] = [
|
||||
{"index": i, "name": result["gpu_name"], "vram_gb": _each} for i in range(_n)
|
||||
]
|
||||
result["gpu_groups"] = [{
|
||||
"name": result["gpu_name"],
|
||||
"vram_each": _each,
|
||||
"count": _n,
|
||||
"indices": list(range(_n)),
|
||||
"vram_total": result["gpu_vram_gb"],
|
||||
}]
|
||||
result["homogeneous"] = True
|
||||
return result
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
_cache_by_host = {} # host -> (timestamp, result)
|
||||
|
||||
|
||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||
bypass the cache and re-probe (the "Rescan" button).
|
||||
If host is set (e.g. 'user@server'), runs detection commands over SSH.
|
||||
platform: "windows", "linux", "termux", or "" (auto-detect).
|
||||
"""
|
||||
global _remote_host, _remote_port, _remote_platform
|
||||
|
||||
cache_key = host or "_local"
|
||||
now = time.time()
|
||||
if not fresh and cache_key in _cache_by_host:
|
||||
ts, cached = _cache_by_host[cache_key]
|
||||
if (now - ts) < CACHE_TTL:
|
||||
return cached
|
||||
|
||||
_remote_host = host or None
|
||||
_remote_port = ssh_port or None
|
||||
_remote_platform = platform or None
|
||||
|
||||
# Windows: single PowerShell command for all hardware info
|
||||
if _remote_platform == "windows" and _remote_host:
|
||||
result = _detect_windows()
|
||||
if result:
|
||||
_remote_host = None
|
||||
_remote_platform = None
|
||||
_cache_by_host[cache_key] = (now, result)
|
||||
return result
|
||||
# If Windows detection failed, return error
|
||||
result = {"error": f"Cannot connect to {host}", "host": host}
|
||||
_remote_host = None
|
||||
_remote_platform = None
|
||||
_cache_by_host[cache_key] = (now, result)
|
||||
return result
|
||||
|
||||
# Linux/Termux: existing multi-command detection
|
||||
total_ram = round(_get_ram_gb(), 1)
|
||||
# If remote host returns 0 RAM, connection likely failed
|
||||
if _remote_host and total_ram <= 0:
|
||||
result = {"error": f"Cannot connect to {host}", "host": host}
|
||||
_cache_by_host[cache_key] = (now, result)
|
||||
_remote_host = None
|
||||
_remote_platform = None
|
||||
return result
|
||||
available_ram = round(_get_available_ram_gb(), 1)
|
||||
cpu_cores = _get_cpu_count()
|
||||
cpu_name = _get_cpu_name()
|
||||
|
||||
gpu_info = _detect_nvidia() or _detect_amd()
|
||||
|
||||
if gpu_info:
|
||||
result = {
|
||||
"total_ram_gb": total_ram,
|
||||
"available_ram_gb": available_ram,
|
||||
"cpu_cores": cpu_cores,
|
||||
"cpu_name": cpu_name,
|
||||
"has_gpu": True,
|
||||
"gpu_name": gpu_info["gpu_name"],
|
||||
"gpu_vram_gb": gpu_info["gpu_vram_gb"],
|
||||
"gpu_count": gpu_info["gpu_count"],
|
||||
"gpus": gpu_info.get("gpus", []),
|
||||
"gpu_groups": gpu_info.get("gpu_groups", []),
|
||||
"homogeneous": gpu_info.get("homogeneous", True),
|
||||
"backend": gpu_info["backend"],
|
||||
}
|
||||
else:
|
||||
if _remote_host:
|
||||
arch_out = _run(["uname", "-m"]) or ""
|
||||
else:
|
||||
import platform as _platform
|
||||
arch_out = _platform.machine().lower()
|
||||
backend = "cpu_arm" if "aarch64" in arch_out or "arm" in arch_out else "cpu_x86"
|
||||
result = {
|
||||
"total_ram_gb": total_ram,
|
||||
"available_ram_gb": available_ram,
|
||||
"cpu_cores": cpu_cores,
|
||||
"cpu_name": cpu_name,
|
||||
"has_gpu": False,
|
||||
"gpu_name": None,
|
||||
"gpu_vram_gb": None,
|
||||
"gpu_count": 0,
|
||||
"backend": backend,
|
||||
# Set when nvidia-smi exists but failed (e.g. driver/library
|
||||
# version mismatch) — lets the UI say "GPU driver error" instead
|
||||
# of the misleading "No GPU".
|
||||
"gpu_error": _last_gpu_error,
|
||||
}
|
||||
|
||||
_remote_host = None
|
||||
_remote_platform = None
|
||||
_cache_by_host[cache_key] = (now, result)
|
||||
return result
|
||||
@@ -0,0 +1,374 @@
|
||||
"""Image generation model registry and VRAM fitting for Cookbook."""
|
||||
|
||||
# Curated registry of image generation models supported by diffusers.
|
||||
# ONLY verified HuggingFace repo IDs.
|
||||
# VRAM estimates are for inference (single image generation).
|
||||
IMAGE_MODEL_REGISTRY = [
|
||||
# ── Z-Image (Alibaba Tongyi) ──
|
||||
{
|
||||
"id": "Tongyi-MAI/Z-Image-Turbo",
|
||||
"name": "Z-Image Turbo",
|
||||
"provider": "Tongyi",
|
||||
"params_b": 6.0,
|
||||
"vram_bf16": 19.0,
|
||||
"vram_fp8": 10.0,
|
||||
"vram_q4": 6.0,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {
|
||||
"FP8": "drbaph/Z-Image-Turbo-FP8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "6B distilled, 8-step. Sub-second on H800. Apache 2.0.",
|
||||
"quality": 92,
|
||||
"speed": 95,
|
||||
"released": "2025-12",
|
||||
},
|
||||
{
|
||||
"id": "Tongyi-MAI/Z-Image",
|
||||
"name": "Z-Image",
|
||||
"provider": "Tongyi",
|
||||
"params_b": 6.0,
|
||||
"vram_bf16": 19.0,
|
||||
"vram_fp8": 10.0,
|
||||
"vram_q4": 6.0,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {
|
||||
"FP8": "drbaph/Z-Image-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "Full undistilled model. Highest creative freedom. Apache 2.0.",
|
||||
"quality": 93,
|
||||
"speed": 70,
|
||||
"released": "2025-12",
|
||||
},
|
||||
# ── Qwen Image ──
|
||||
{
|
||||
"id": "Qwen/Qwen-Image-2512",
|
||||
"name": "Qwen Image 2512",
|
||||
"provider": "Qwen",
|
||||
"params_b": 20.0,
|
||||
"vram_bf16": 42.0,
|
||||
"vram_fp8": 22.0,
|
||||
"vram_q4": 14.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["text-to-image", "text-rendering"],
|
||||
"description": "Dec 2025 update. Better humans, finer detail, strong text. Apache 2.0.",
|
||||
"quality": 95,
|
||||
"speed": 50,
|
||||
"released": "2025-12",
|
||||
},
|
||||
{
|
||||
"id": "Qwen/Qwen-Image",
|
||||
"name": "Qwen Image",
|
||||
"provider": "Qwen",
|
||||
"params_b": 20.0,
|
||||
"vram_bf16": 42.0,
|
||||
"vram_fp8": 22.0,
|
||||
"vram_q4": 14.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["text-to-image", "text-rendering"],
|
||||
"description": "20B foundation. Best text rendering in images. Apache 2.0.",
|
||||
"quality": 94,
|
||||
"speed": 50,
|
||||
"released": "2025-08",
|
||||
},
|
||||
{
|
||||
"id": "Qwen/Qwen-Image-Edit-2511",
|
||||
"name": "Qwen Image Edit",
|
||||
"provider": "Qwen",
|
||||
"params_b": 20.0,
|
||||
"vram_bf16": 42.0,
|
||||
"vram_fp8": 22.0,
|
||||
"vram_q4": 14.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["image-editing", "inpainting"],
|
||||
"description": "Dedicated editing. Style transfer, object removal. Apache 2.0.",
|
||||
"quality": 92,
|
||||
"speed": 50,
|
||||
"released": "2025-11",
|
||||
},
|
||||
# ── Stable Diffusion (dedicated inpainting) ──
|
||||
{
|
||||
"id": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
"name": "SDXL Inpainting",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 3.5,
|
||||
"vram_bf16": 12.0,
|
||||
"vram_fp8": 8.0,
|
||||
"vram_q4": 6.0,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["inpainting", "image-editing"],
|
||||
"description": "SDXL fine-tuned for inpainting (9-channel UNet). Best SD-family fill quality; fits a 24GB card comfortably.",
|
||||
"quality": 86,
|
||||
"speed": 68,
|
||||
"released": "2023-11",
|
||||
},
|
||||
{
|
||||
"id": "stable-diffusion-v1-5/stable-diffusion-inpainting",
|
||||
"name": "SD 1.5 Inpainting",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 1.1,
|
||||
"vram_bf16": 4.0,
|
||||
"vram_fp8": 3.0,
|
||||
"vram_q4": 2.5,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["inpainting"],
|
||||
"description": "Classic SD 1.5 inpaint. Very light and fast; lower fidelity than SDXL.",
|
||||
"quality": 70,
|
||||
"speed": 92,
|
||||
"released": "2022-10",
|
||||
},
|
||||
# ── FLUX ──
|
||||
{
|
||||
"id": "black-forest-labs/FLUX.1-dev",
|
||||
"name": "FLUX.1 Dev",
|
||||
"provider": "Black Forest Labs",
|
||||
"params_b": 12.0,
|
||||
"vram_bf16": 33.0,
|
||||
"vram_fp8": 17.0,
|
||||
"vram_q4": 10.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {
|
||||
"FP8": "diffusers/FLUX.1-dev-torchao-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "High quality, detailed. Popular community model. Non-commercial.",
|
||||
"quality": 92,
|
||||
"speed": 55,
|
||||
"released": "2024-08",
|
||||
},
|
||||
{
|
||||
"id": "black-forest-labs/FLUX.1-schnell",
|
||||
"name": "FLUX.1 Schnell",
|
||||
"provider": "Black Forest Labs",
|
||||
"params_b": 12.0,
|
||||
"vram_bf16": 33.0,
|
||||
"vram_fp8": 17.0,
|
||||
"vram_q4": 10.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {
|
||||
"FP8": "Kijai/flux-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "Fast 4-step variant. Apache 2.0 license.",
|
||||
"quality": 85,
|
||||
"speed": 90,
|
||||
"released": "2024-08",
|
||||
},
|
||||
# ── Stable Diffusion ──
|
||||
{
|
||||
"id": "stabilityai/stable-diffusion-3.5-medium",
|
||||
"name": "SD 3.5 Medium",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 2.5,
|
||||
"vram_bf16": 12.0,
|
||||
"vram_fp8": 7.0,
|
||||
"vram_q4": None,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {
|
||||
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "2.5B lightweight, fast. Fits almost any GPU.",
|
||||
"quality": 75,
|
||||
"speed": 95,
|
||||
"released": "2024-10",
|
||||
},
|
||||
{
|
||||
"id": "stabilityai/stable-diffusion-3.5-large",
|
||||
"name": "SD 3.5 Large",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 8.1,
|
||||
"vram_bf16": 22.0,
|
||||
"vram_fp8": 12.0,
|
||||
"vram_q4": None,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {
|
||||
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "8B high quality. Good balance of speed and quality.",
|
||||
"quality": 85,
|
||||
"speed": 70,
|
||||
"released": "2024-10",
|
||||
},
|
||||
{
|
||||
"id": "stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
"name": "SD 3.5 Large Turbo",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 8.1,
|
||||
"vram_bf16": 22.0,
|
||||
"vram_fp8": 12.0,
|
||||
"vram_q4": None,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {
|
||||
"FP8": "Comfy-Org/stable-diffusion-3.5-fp8",
|
||||
},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "Distilled for few-step inference. Fastest large SD.",
|
||||
"quality": 80,
|
||||
"speed": 92,
|
||||
"released": "2024-10",
|
||||
},
|
||||
{
|
||||
"id": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"name": "SDXL",
|
||||
"provider": "Stability AI",
|
||||
"params_b": 3.5,
|
||||
"vram_bf16": 10.0,
|
||||
"vram_fp8": 6.0,
|
||||
"vram_q4": None,
|
||||
"default_quant": "BF16",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["text-to-image"],
|
||||
"description": "Classic workhorse. Huge LoRA ecosystem. Fits 8GB+.",
|
||||
"quality": 72,
|
||||
"speed": 90,
|
||||
"released": "2023-07",
|
||||
},
|
||||
# ── Hunyuan ──
|
||||
{
|
||||
"id": "tencent/HunyuanImage-3.0",
|
||||
"name": "HunyuanImage 3.0",
|
||||
"provider": "Tencent",
|
||||
"params_b": 13.0,
|
||||
"vram_bf16": 30.0,
|
||||
"vram_fp8": 16.0,
|
||||
"vram_q4": 9.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {
|
||||
"Q4": "wikeeyang/Hunyuan-Image-30-Qint4",
|
||||
"NF4": "EricRollei/HunyuanImage-3.0-Instruct-NF4",
|
||||
},
|
||||
"capabilities": ["text-to-image", "text-rendering"],
|
||||
"description": "Strong text rendering. Bilingual Chinese/English. 13B activated per token.",
|
||||
"quality": 88,
|
||||
"speed": 60,
|
||||
"released": "2025-09",
|
||||
},
|
||||
{
|
||||
"id": "tencent/HunyuanImage-3.0-Instruct-Distil",
|
||||
"name": "HunyuanImage 3.0 Distil",
|
||||
"provider": "Tencent",
|
||||
"params_b": 13.0,
|
||||
"vram_bf16": 30.0,
|
||||
"vram_fp8": 16.0,
|
||||
"vram_q4": 9.0,
|
||||
"default_quant": "FP8",
|
||||
"quant_repos": {},
|
||||
"capabilities": ["text-to-image", "text-rendering"],
|
||||
"description": "Distilled variant, fewer steps. Faster with comparable quality.",
|
||||
"quality": 85,
|
||||
"speed": 80,
|
||||
"released": "2026-01",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def get_image_models():
|
||||
"""Return the image model registry."""
|
||||
return IMAGE_MODEL_REGISTRY
|
||||
|
||||
|
||||
def rank_image_models(system, search=None, sort="fit"):
|
||||
"""Score and rank image models against detected hardware.
|
||||
|
||||
Returns list of models with fit info (vram needed, fits, recommended quant).
|
||||
"""
|
||||
gpu_vram = system.get("gpu_vram_gb", 0) or 0
|
||||
has_gpu = system.get("has_gpu", False)
|
||||
results = []
|
||||
|
||||
for model in IMAGE_MODEL_REGISTRY:
|
||||
# Filter by search
|
||||
if search:
|
||||
s = search.lower()
|
||||
if s not in model["name"].lower() and s not in model["id"].lower() and s not in model.get("description", "").lower():
|
||||
continue
|
||||
|
||||
# Determine best quant that fits
|
||||
quant = None
|
||||
vram_needed = None
|
||||
fits = False
|
||||
quant_repo = None
|
||||
|
||||
if has_gpu and gpu_vram > 0:
|
||||
# Try BF16 first, then FP8, then Q4
|
||||
for q, vram_key in [("BF16", "vram_bf16"), ("FP8", "vram_fp8"), ("Q4", "vram_q4")]:
|
||||
v = model.get(vram_key)
|
||||
if v is not None and v <= gpu_vram * 0.90: # 10% headroom
|
||||
quant = q
|
||||
vram_needed = v
|
||||
fits = True
|
||||
quant_repo = model.get("quant_repos", {}).get(q)
|
||||
break
|
||||
# If nothing fits, show what it needs
|
||||
if not fits:
|
||||
quant = model["default_quant"]
|
||||
vram_needed = model.get("vram_bf16", 0)
|
||||
|
||||
# Fit label
|
||||
if not has_gpu:
|
||||
fit = "no_gpu"
|
||||
fit_label = "No GPU"
|
||||
elif fits:
|
||||
headroom = gpu_vram - vram_needed
|
||||
if headroom > gpu_vram * 0.3:
|
||||
fit = "perfect"
|
||||
fit_label = "Perfect"
|
||||
elif headroom > gpu_vram * 0.1:
|
||||
fit = "good"
|
||||
fit_label = "Good"
|
||||
else:
|
||||
fit = "tight"
|
||||
fit_label = "Tight"
|
||||
else:
|
||||
fit = "no_fit"
|
||||
fit_label = "Too large"
|
||||
|
||||
# Score: quality * speed * fit bonus
|
||||
score = model["quality"] * 0.6 + model["speed"] * 0.2
|
||||
if fit == "perfect":
|
||||
score += 20
|
||||
elif fit == "good":
|
||||
score += 10
|
||||
elif fit == "tight":
|
||||
score += 5
|
||||
elif fit == "no_fit":
|
||||
score -= 30
|
||||
|
||||
results.append({
|
||||
"id": model["id"],
|
||||
"name": model["name"],
|
||||
"provider": model["provider"],
|
||||
"params_b": model["params_b"],
|
||||
"vram_needed": vram_needed,
|
||||
"quant": quant,
|
||||
"quant_repo": quant_repo,
|
||||
"fits": fits,
|
||||
"fit": fit,
|
||||
"fit_label": fit_label,
|
||||
"quality": model["quality"],
|
||||
"speed": model["speed"],
|
||||
"score": round(score, 1),
|
||||
"capabilities": model["capabilities"],
|
||||
"description": model["description"],
|
||||
"released": model.get("released", ""),
|
||||
})
|
||||
|
||||
# Sort
|
||||
if sort == "quality":
|
||||
results.sort(key=lambda x: (-x["quality"], -x["score"]))
|
||||
elif sort == "speed":
|
||||
results.sort(key=lambda x: (-x["speed"], -x["score"]))
|
||||
elif sort == "vram":
|
||||
results.sort(key=lambda x: (x["vram_needed"] or 999, -x["score"]))
|
||||
else: # fit (default)
|
||||
results.sort(key=lambda x: (-x["score"],))
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
QUANT_HIERARCHY = ["Q8_0", "Q6_K", "Q5_K_M", "Q4_K_M", "Q3_K_M", "Q2_K"]
|
||||
|
||||
QUANT_BPP = {
|
||||
"F32": 4.0, "F16": 2.0, "BF16": 2.0, "FP8": 1.0,
|
||||
"Q8_0": 1.05, "Q6_K": 0.80, "Q5_K_M": 0.68,
|
||||
"Q4_K_M": 0.58, "Q4_0": 0.58, "Q3_K_M": 0.48, "Q2_K": 0.37,
|
||||
"AWQ-4bit": 0.50, "AWQ-8bit": 1.0,
|
||||
"GPTQ-Int4": 0.50, "GPTQ-Int8": 1.0,
|
||||
"mlx-4bit": 0.55, "mlx-8bit": 1.0, "mlx-6bit": 0.75,
|
||||
}
|
||||
|
||||
QUANT_SPEED_MULT = {
|
||||
"F16": 0.6, "BF16": 0.6, "FP8": 0.85,
|
||||
"Q8_0": 0.8, "Q6_K": 0.95, "Q5_K_M": 1.0,
|
||||
"Q4_K_M": 1.15, "Q4_0": 1.15, "Q3_K_M": 1.25, "Q2_K": 1.35,
|
||||
"AWQ-4bit": 1.2, "AWQ-8bit": 0.85,
|
||||
"GPTQ-Int4": 1.2, "GPTQ-Int8": 0.85,
|
||||
"mlx-4bit": 1.15, "mlx-8bit": 0.85, "mlx-6bit": 1.0,
|
||||
}
|
||||
|
||||
QUANT_QUALITY_PENALTY = {
|
||||
"F16": 0.0, "BF16": 0.0, "FP8": 0.0,
|
||||
"Q8_0": 0.0, "Q6_K": -1.0, "Q5_K_M": -2.0,
|
||||
"Q4_K_M": -5.0, "Q4_0": -5.0, "Q3_K_M": -8.0, "Q2_K": -12.0,
|
||||
"AWQ-4bit": -3.0, "AWQ-8bit": 0.0,
|
||||
"GPTQ-Int4": -3.0, "GPTQ-Int8": 0.0,
|
||||
"mlx-4bit": -4.0, "mlx-8bit": 0.0, "mlx-6bit": -1.0,
|
||||
}
|
||||
|
||||
QUANT_BYTES_PER_PARAM = {
|
||||
"F16": 2.0, "BF16": 2.0, "FP8": 1.0,
|
||||
"Q8_0": 1.0, "Q6_K": 0.75, "Q5_K_M": 0.625,
|
||||
"Q4_K_M": 0.5, "Q4_0": 0.5, "Q3_K_M": 0.375, "Q2_K": 0.25,
|
||||
"AWQ-4bit": 0.5, "AWQ-8bit": 1.0,
|
||||
"GPTQ-Int4": 0.5, "GPTQ-Int8": 1.0,
|
||||
"mlx-4bit": 0.5, "mlx-8bit": 1.0, "mlx-6bit": 0.75,
|
||||
}
|
||||
|
||||
# Pre-quantized formats that should NOT go through the GGUF quant hierarchy
|
||||
PREQUANTIZED_PREFIXES = ("AWQ-", "GPTQ-", "mlx-", "FP8")
|
||||
|
||||
|
||||
def is_prequantized(model):
|
||||
q = model.get("quantization", "")
|
||||
return any(q.startswith(p) for p in PREQUANTIZED_PREFIXES)
|
||||
|
||||
|
||||
def params_b(model):
|
||||
raw = model.get("parameters_raw")
|
||||
if raw and raw > 0:
|
||||
return raw / 1_000_000_000.0
|
||||
|
||||
pc = model.get("parameter_count", "")
|
||||
if pc:
|
||||
pc = pc.strip().upper()
|
||||
m = re.match(r"^([\d.]+)\s*([BKMGT]?)$", pc)
|
||||
if m:
|
||||
val = float(m.group(1))
|
||||
suffix = m.group(2)
|
||||
if suffix == "B":
|
||||
return val
|
||||
elif suffix == "M":
|
||||
return val / 1000.0
|
||||
elif suffix == "K":
|
||||
return val / 1_000_000.0
|
||||
elif suffix == "T":
|
||||
return val * 1000.0
|
||||
else:
|
||||
# No unit. A bare number this size is conventionally a millions
|
||||
# count (e.g. "355" = 355M), NOT billions — otherwise a 355M
|
||||
# model would sort as 355B and leap above every 7B/70B model.
|
||||
# A genuine billions figure carries a "B" suffix and is handled
|
||||
# above; very large bare values are raw parameter counts.
|
||||
if val >= 1_000_000:
|
||||
return val / 1_000_000_000.0 # raw count
|
||||
if val >= 1000:
|
||||
return val / 1000.0 # thousands of millions? treat as millions
|
||||
return val / 1000.0 # e.g. "355" → 0.355B
|
||||
return 0.0
|
||||
|
||||
|
||||
def estimate_memory_gb(model, quant, ctx):
|
||||
"""Estimate VRAM needed to serve a model. All weights must be loaded,
|
||||
even for MoE (all experts live in memory, only active ones compute per token).
|
||||
KV cache scales with active params for MoE (only active experts have KV state)."""
|
||||
pb = params_b(model)
|
||||
bpp = QUANT_BPP.get(quant, 0.58)
|
||||
kv_params = _active_params_b(model)
|
||||
return pb * bpp + 0.000008 * kv_params * ctx + 0.5
|
||||
|
||||
|
||||
def _active_params_b(model):
|
||||
"""For MoE: active params per token (affects KV cache and speed, not total VRAM).
|
||||
For dense: same as total params."""
|
||||
if model.get("is_moe") and model.get("active_parameters"):
|
||||
return model["active_parameters"] / 1_000_000_000.0
|
||||
return params_b(model)
|
||||
|
||||
|
||||
def best_quant_for_budget(model, budget_gb, ctx):
|
||||
"""Find best quant that fits in budget_gb of VRAM.
|
||||
Pre-quantized models (AWQ/GPTQ/MLX) use their native quant only.
|
||||
Returns (quant, ctx, mem_gb) or (None, None, None).
|
||||
"""
|
||||
if is_prequantized(model):
|
||||
q = model.get("quantization", "Q4_K_M")
|
||||
mem = estimate_memory_gb(model, q, ctx)
|
||||
if mem <= budget_gb:
|
||||
return q, ctx, mem
|
||||
# Try halving context
|
||||
cur_ctx = ctx // 2
|
||||
while cur_ctx >= 1024:
|
||||
mem = estimate_memory_gb(model, q, cur_ctx)
|
||||
if mem <= budget_gb:
|
||||
return q, cur_ctx, mem
|
||||
cur_ctx //= 2
|
||||
return None, None, None
|
||||
|
||||
# GGUF: try best quality first, then fall back
|
||||
for q in QUANT_HIERARCHY:
|
||||
mem = estimate_memory_gb(model, q, ctx)
|
||||
if mem <= budget_gb:
|
||||
return q, ctx, mem
|
||||
|
||||
cur_ctx = ctx // 2
|
||||
while cur_ctx >= 1024:
|
||||
for q in QUANT_HIERARCHY:
|
||||
mem = estimate_memory_gb(model, q, cur_ctx)
|
||||
if mem <= budget_gb:
|
||||
return q, cur_ctx, mem
|
||||
cur_ctx //= 2
|
||||
|
||||
return None, None, None
|
||||
|
||||
|
||||
def infer_use_case(model):
|
||||
name = model.get("name", "").lower()
|
||||
uc = model.get("use_case", "").lower()
|
||||
combined = name + " " + uc
|
||||
|
||||
if any(k in combined for k in ("embedding", "embed", "bge")):
|
||||
return "embedding"
|
||||
if any(k in combined for k in ("tts", "text-to-speech", "speech-synthesis", "cosyvoice", "parler")):
|
||||
return "tts"
|
||||
if any(k in combined for k in ("stt", "speech-to-text", "whisper", "transcri", "asr")):
|
||||
return "stt"
|
||||
if "code" in combined:
|
||||
return "coding"
|
||||
if any(k in combined for k in ("vision", "multimodal", "vlm", "vl-")):
|
||||
return "multimodal"
|
||||
if any(k in combined for k in ("reason", "chain-of-thought", "deepseek-r1")):
|
||||
return "reasoning"
|
||||
if any(k in combined for k in ("chat", "instruction")):
|
||||
return "chat"
|
||||
return "general"
|
||||
|
||||
|
||||
_models_cache = None
|
||||
|
||||
def get_models():
|
||||
global _models_cache
|
||||
if _models_cache is None:
|
||||
data_path = os.path.join(os.path.dirname(__file__), "data", "hf_models.json")
|
||||
try:
|
||||
with open(data_path) as f:
|
||||
_models_cache = json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
_models_cache = []
|
||||
return _models_cache
|
||||
|
||||
|
||||
def model_catalog_path():
|
||||
return os.path.join(os.path.dirname(__file__), "data", "hf_models.json")
|
||||
@@ -0,0 +1,14 @@
|
||||
# services/memory/__init__.py
|
||||
"""Memory service — persistent memory storage and retrieval."""
|
||||
|
||||
from .service import MemoryService, Memory, MemorySearchResult
|
||||
from .memory import MemoryManager
|
||||
from .memory_vector import MemoryVectorStore
|
||||
|
||||
__all__ = [
|
||||
"MemoryService",
|
||||
"Memory",
|
||||
"MemorySearchResult",
|
||||
"MemoryManager",
|
||||
"MemoryVectorStore",
|
||||
]
|
||||
@@ -0,0 +1,359 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
from typing import List, Dict, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def tokenize(text: str) -> List[str]:
|
||||
"""Simple tokenizer that splits on whitespace and removes punctuation."""
|
||||
return [word.strip('.,!?";') for word in text.split()]
|
||||
|
||||
def get_text_similarity(text1: str, text2: str) -> float:
|
||||
"""Calculate Jaccard similarity between two texts."""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
tokens1 = set(tokenize(text1.lower()))
|
||||
tokens2 = set(tokenize(text2.lower()))
|
||||
|
||||
if not tokens1 and not tokens2:
|
||||
return 1.0
|
||||
if not tokens1 or not tokens2:
|
||||
return 0.0
|
||||
|
||||
intersection = tokens1.intersection(tokens2)
|
||||
union = tokens1.union(tokens2)
|
||||
|
||||
return len(intersection) / len(union)
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, data_dir: str):
|
||||
self.memory_file = os.path.join(data_dir, "memory.json")
|
||||
self.ensure_file_exists()
|
||||
|
||||
def extract_memory_from_chat(self, chat_history: List[Dict], session_id: str = None) -> List[Dict]:
|
||||
"""
|
||||
Extract memory entries from chat history as a fallback when LLM fails.
|
||||
|
||||
Args:
|
||||
chat_history: List of chat messages with 'role' and 'content' keys
|
||||
session_id: Optional session ID to associate with extracted memories
|
||||
|
||||
Returns:
|
||||
List of memory entries with text, timestamp, and optional session_id
|
||||
"""
|
||||
memories = []
|
||||
|
||||
for msg in chat_history:
|
||||
if msg.get("role") == "assistant":
|
||||
content = str(msg.get("content", ""))
|
||||
lines = content.split('\n')
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
# Look for bullet points or numbered lists that might contain memories
|
||||
if re.match(r'^[-*•]|\d+\.', line):
|
||||
# Extract the text after the bullet/number
|
||||
text_match = re.match(r'^[-*•]|\d+\.\s*(.*)', line)
|
||||
if text_match:
|
||||
text = text_match.group(1).strip()
|
||||
if text:
|
||||
memories.append({
|
||||
"text": text,
|
||||
"timestamp": int(datetime.now().timestamp()),
|
||||
"session_id": session_id
|
||||
})
|
||||
# If we see a heading that suggests memories
|
||||
elif re.search(r'memory|fact|note|remember', line, re.I):
|
||||
pass
|
||||
# If we see a clear separator or end
|
||||
elif re.match(r'^={3,}|-{3,}|_{3,}', line):
|
||||
pass
|
||||
|
||||
return memories
|
||||
|
||||
def process_inline_memory_command(self, message: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if a message is an inline memory command (e.g. "remember: X").
|
||||
|
||||
Args:
|
||||
message: The user message to check
|
||||
|
||||
Returns:
|
||||
Tuple of (is_command, extracted_text) where is_command is True if
|
||||
the message matches the memory command pattern
|
||||
"""
|
||||
# Pattern for memory commands: "remember: X", "memorize: X", "save: X", etc.
|
||||
pattern = r'^(?:remember|memorize|save|note|store)[:\-]?\s+(.+)$'
|
||||
match = re.match(pattern, message.strip(), re.IGNORECASE)
|
||||
|
||||
if match:
|
||||
return True, match.group(1).strip()
|
||||
else:
|
||||
return False, ""
|
||||
|
||||
def ensure_file_exists(self):
|
||||
"""Create memory file if it doesn't exist."""
|
||||
if not os.path.exists(self.memory_file):
|
||||
with open(self.memory_file, 'w', encoding='utf-8') as f:
|
||||
json.dump([], f, ensure_ascii=False, indent=2)
|
||||
|
||||
def load_all(self) -> List[Dict]:
|
||||
"""Load all memory entries from JSON file (unfiltered)."""
|
||||
if not os.path.exists(self.memory_file):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(self.memory_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
return self._validate_entries(data)
|
||||
except (json.JSONDecodeError, PermissionError) as e:
|
||||
logger.error("Error loading memory.json: %s", e)
|
||||
return self._migrate_from_legacy()
|
||||
|
||||
return []
|
||||
|
||||
def load(self, owner: str = None) -> List[Dict]:
|
||||
"""Load memory entries, filtered by owner."""
|
||||
entries = self.load_all()
|
||||
if owner is None:
|
||||
return entries
|
||||
return [e for e in entries if e.get("owner") == owner]
|
||||
|
||||
def claim_ownerless(self, owner: str):
|
||||
"""Assign all ownerless memory entries to the given owner. Run once to migrate."""
|
||||
entries = self.load_all()
|
||||
changed = False
|
||||
for e in entries:
|
||||
if not e.get("owner"):
|
||||
e["owner"] = owner
|
||||
changed = True
|
||||
if changed:
|
||||
self.save(entries)
|
||||
logger.info("Claimed %d ownerless memories for %s", sum(1 for e in entries if e.get("owner") == owner), owner)
|
||||
|
||||
def _validate_entries(self, entries: List[Dict]) -> List[Dict]:
|
||||
"""Ensure all entries have required fields."""
|
||||
validated = []
|
||||
for entry in entries:
|
||||
if "id" not in entry:
|
||||
entry["id"] = str(uuid.uuid4())
|
||||
if "timestamp" not in entry:
|
||||
entry["timestamp"] = int(time.time())
|
||||
if "source" not in entry:
|
||||
entry["source"] = "unknown"
|
||||
if "category" not in entry:
|
||||
entry["category"] = "fact"
|
||||
validated.append(entry)
|
||||
return validated
|
||||
|
||||
def _migrate_from_legacy(self) -> List[Dict]:
|
||||
"""Migrate from old text format to JSON if needed."""
|
||||
legacy_path = os.path.join(os.path.dirname(self.memory_file), "memory.txt")
|
||||
if not os.path.exists(legacy_path):
|
||||
return []
|
||||
|
||||
logger.info("Converting legacy memory.txt to new JSON format")
|
||||
try:
|
||||
with open(legacy_path, "r", encoding="utf-8") as f:
|
||||
lines = [ln.strip() for ln in f.readlines() if ln.strip()]
|
||||
|
||||
entries = []
|
||||
for line in lines:
|
||||
entries.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": line,
|
||||
"timestamp": int(time.time()),
|
||||
"source": "user",
|
||||
"category": "fact"
|
||||
})
|
||||
|
||||
self.save(entries)
|
||||
return entries
|
||||
except Exception as e:
|
||||
logger.error("Failed to convert legacy memory: %s", e)
|
||||
return []
|
||||
|
||||
def save(self, entries: List[Dict]):
|
||||
"""Save memory entries to JSON file."""
|
||||
# Validate entries before saving
|
||||
for entry in entries:
|
||||
if "id" not in entry:
|
||||
entry["id"] = str(uuid.uuid4())
|
||||
if "timestamp" not in entry:
|
||||
entry["timestamp"] = int(time.time())
|
||||
if "source" not in entry:
|
||||
entry["source"] = "user"
|
||||
if "category" not in entry:
|
||||
entry["category"] = "fact"
|
||||
|
||||
# Use atomic write
|
||||
tmp_file = self.memory_file + ".tmp"
|
||||
with open(tmp_file, "w", encoding="utf-8") as f:
|
||||
json.dump(entries, f, ensure_ascii=False, indent=2)
|
||||
os.replace(tmp_file, self.memory_file)
|
||||
|
||||
def add_entry(self, text: str, source: str = "user", category: str = "fact", owner: str = None) -> Dict:
|
||||
"""Add a new memory entry."""
|
||||
if not text.strip():
|
||||
raise ValueError("Memory text cannot be empty")
|
||||
|
||||
entry = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"text": text.strip(),
|
||||
"timestamp": int(time.time()),
|
||||
"source": source,
|
||||
"category": category
|
||||
}
|
||||
if owner:
|
||||
entry["owner"] = owner
|
||||
return entry
|
||||
|
||||
def find_duplicates(self, text: str, entries: List[Dict] = None) -> List[Dict]:
|
||||
"""Find duplicate memory entries based on text content."""
|
||||
if entries is None:
|
||||
entries = self.load()
|
||||
|
||||
text_lower = text.strip().lower()
|
||||
return [entry for entry in entries if entry["text"].lower() == text_lower]
|
||||
|
||||
def categorize_memory_by_relevance(self, message: str, memories: list):
|
||||
"""Categorize memories by type and relevance"""
|
||||
categories = {
|
||||
"contacts": [],
|
||||
"preferences": [],
|
||||
"facts": [],
|
||||
"tasks": []
|
||||
}
|
||||
|
||||
msg_lower = message.lower()
|
||||
|
||||
for mem in memories:
|
||||
text_lower = mem["text"].lower()
|
||||
|
||||
# Contact info
|
||||
if any(word in text_lower for word in ["phone", "email", "address", "lives", "works"]):
|
||||
if any(word in msg_lower for word in ["contact", "phone", "address", "email"]):
|
||||
categories["contacts"].append(mem)
|
||||
|
||||
# Personal preferences
|
||||
elif any(word in text_lower for word in ["likes", "dislikes", "prefers", "favorite"]):
|
||||
if any(word in msg_lower for word in ["like", "prefer", "favorite", "want"]):
|
||||
categories["preferences"].append(mem)
|
||||
|
||||
# Tasks and todos
|
||||
elif any(word in text_lower for word in ["todo", "task", "remind", "meeting"]):
|
||||
if any(word in msg_lower for word in ["todo", "task", "schedule", "remind"]):
|
||||
categories["tasks"].append(mem)
|
||||
|
||||
# General facts - only if very relevant
|
||||
else:
|
||||
if get_text_similarity(message, mem["text"]) > 0.4:
|
||||
categories["facts"].append(mem)
|
||||
|
||||
return categories
|
||||
|
||||
def get_relevant_memories(self, query: str, memories: list, threshold: float = 0.05, max_items: int = 8):
|
||||
"""Get memories that are relevant to the query based on text similarity and semantic keyword matching."""
|
||||
if not memories or not query.strip():
|
||||
return []
|
||||
|
||||
# Define keyword categories for semantic matching
|
||||
identity_words = ["name", "who", "i", "am", "called", "identity", "myself", "me", "my"]
|
||||
contact_words = ["phone", "email", "address", "contact", "number", "where", "located", "reach"]
|
||||
preference_words = ["like", "prefer", "favorite", "want", "love", "hate", "dislike", "enjoy", "interested"]
|
||||
task_words = ["todo", "task", "remind", "meeting", "appointment", "schedule", "deadline"]
|
||||
fact_words = ["what", "when", "where", "how", "why", "explain", "describe", "information", "know"]
|
||||
|
||||
query_lower = query.lower()
|
||||
|
||||
# Determine query type based on keywords
|
||||
query_type = None
|
||||
if any(word in query_lower for word in identity_words):
|
||||
query_type = "identity"
|
||||
elif any(word in query_lower for word in contact_words):
|
||||
query_type = "contact"
|
||||
elif any(word in query_lower for word in preference_words):
|
||||
query_type = "preference"
|
||||
elif any(word in query_lower for word in task_words):
|
||||
query_type = "task"
|
||||
elif any(word in query_lower for word in fact_words):
|
||||
query_type = "fact"
|
||||
|
||||
relevant = []
|
||||
identity_memories = []
|
||||
other_memories = []
|
||||
|
||||
# Separate identity memories from others
|
||||
for memory in memories:
|
||||
memory_text = memory["text"].lower()
|
||||
# Check if this is an identity memory (contains name patterns or identity indicators)
|
||||
is_identity = any([
|
||||
re.search(r'\b[A-Z][a-z]+ [A-Z][a-z]+\b', memory["text"]),
|
||||
any(word in memory_text for word in ["name is", "i'm", "i am", "called", "my name", "named", "call me"])
|
||||
])
|
||||
if is_identity:
|
||||
identity_memories.append(memory)
|
||||
else:
|
||||
other_memories.append(memory)
|
||||
|
||||
# For identity queries, include all identity memories regardless of similarity
|
||||
if query_type == "identity" and identity_memories:
|
||||
# Give them high scores to ensure they're included first
|
||||
for memory in identity_memories:
|
||||
relevant.append((0.9, memory)) # High score for identity memories in identity queries
|
||||
|
||||
# Process other memories with similarity scoring
|
||||
for memory in other_memories:
|
||||
memory_text = memory["text"].lower()
|
||||
memory_tokens = set(tokenize(memory_text))
|
||||
query_tokens = set(tokenize(query_lower))
|
||||
|
||||
# Calculate base Jaccard similarity
|
||||
if not query_tokens or not memory_tokens:
|
||||
continue
|
||||
|
||||
base_similarity = len(query_tokens & memory_tokens) / len(query_tokens | memory_tokens)
|
||||
final_score = base_similarity
|
||||
|
||||
# Apply boosts based on semantic matching
|
||||
if query_type == "contact":
|
||||
# Boost memories with contact information
|
||||
has_contact_info = any(word in memory_text for word in ["@gmail.com", "@", ".com",
|
||||
"phone", "number", "address",
|
||||
"http", "www", "tel:"])
|
||||
if has_contact_info:
|
||||
final_score *= 1.4 # 40% boost for contact-related memories
|
||||
|
||||
elif query_type == "preference":
|
||||
# Boost memories with preference indicators
|
||||
has_preference = any(word in memory_text for word in ["like", "love", "hate", "dislike",
|
||||
"prefer", "favorite", "enjoy", "interested"])
|
||||
if has_preference:
|
||||
final_score *= 1.3 # 30% boost for preference-related memories
|
||||
|
||||
elif query_type == "task":
|
||||
# Boost memories with task indicators
|
||||
has_task = any(word in memory_text for word in ["todo", "task", "remind", "meeting",
|
||||
"appointment", "schedule", "deadline", "need to"])
|
||||
if has_task:
|
||||
final_score *= 1.3 # 30% boost for task-related memories
|
||||
|
||||
# Always consider exact phrase matches as highly relevant
|
||||
if query.lower() in memory["text"].lower():
|
||||
final_score = max(final_score, 0.8) # Ensure high relevance for exact matches
|
||||
|
||||
# Include memory if it meets threshold after boosts
|
||||
if final_score >= threshold:
|
||||
relevant.append((final_score, memory))
|
||||
|
||||
# Sort by final score (descending) and return top matches
|
||||
relevant.sort(key=lambda x: x[0], reverse=True)
|
||||
return [mem for _, mem in relevant[:max_items]]
|
||||
@@ -0,0 +1,533 @@
|
||||
"""
|
||||
memory_extractor.py
|
||||
|
||||
Background auto-extraction of facts from chat conversations.
|
||||
After each LLM response, this module sends the last few messages to the LLM
|
||||
asking it to extract memorable facts, then stores them in both memory.json
|
||||
and the FAISS vector index.
|
||||
|
||||
Periodically audits all memories via LLM to consolidate duplicates,
|
||||
rewrite vague entries, and remove junk.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tidy_state_path(memory_manager) -> str:
|
||||
"""Sidecar JSON next to memory.json that remembers the fingerprint of
|
||||
the last successfully-audited state per owner. Lets the audit short-
|
||||
circuit when nothing has changed since the previous tidy — running
|
||||
the LLM again on an already-clean list was wasting 30-120s per call
|
||||
and occasionally timing out on the second pass."""
|
||||
return os.path.join(os.path.dirname(memory_manager.memory_file), "memory_tidy_state.json")
|
||||
|
||||
|
||||
def _fingerprint_entries(entries) -> str:
|
||||
"""Stable hash of an owner's memories — order-independent, depends
|
||||
only on id+text+category. Any add/edit/delete invalidates it."""
|
||||
items = sorted(
|
||||
(str(e.get("id", "")), e.get("text", ""), e.get("category", ""))
|
||||
for e in entries
|
||||
)
|
||||
h = hashlib.sha256()
|
||||
for triple in items:
|
||||
h.update(("\x1f".join(triple) + "\x1e").encode("utf-8"))
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _load_tidy_state(memory_manager) -> dict:
|
||||
path = _tidy_state_path(memory_manager)
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data if isinstance(data, dict) else {}
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
|
||||
def _save_tidy_state(memory_manager, owner: Optional[str], fingerprint: str) -> None:
|
||||
path = _tidy_state_path(memory_manager)
|
||||
state = _load_tidy_state(memory_manager)
|
||||
state[owner or ""] = {"fingerprint": fingerprint}
|
||||
try:
|
||||
with open(path, "w") as f:
|
||||
json.dump(state, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not persist tidy fingerprint: {e}")
|
||||
|
||||
EXTRACT_SYSTEM_PROMPT = (
|
||||
"You are a memory extraction assistant. Analyze the conversation and extract ONLY "
|
||||
"durable personal facts about the user that would be useful across many future conversations.\n\n"
|
||||
"Good examples: name, job title, city, family members, long-term projects, strong preferences.\n"
|
||||
"Bad examples: what they asked about today, temporary moods, generic statements, "
|
||||
"things the assistant said, one-off tasks, opinions on the current topic.\n\n"
|
||||
"Rules:\n"
|
||||
"- MAX 2 facts per conversation — only the most important\n"
|
||||
"- Only extract facts the USER stated or clearly implied\n"
|
||||
"- Each fact must be a single short sentence (under 15 words)\n"
|
||||
"- If a fact is similar to something likely already known, skip it\n"
|
||||
"- If nothing durable was revealed, return []\n\n"
|
||||
"Return a JSON array of objects with 'text' and 'category' fields.\n"
|
||||
"Categories: 'identity', 'preference', 'fact', 'contact', 'project', 'goal'\n\n"
|
||||
"Return ONLY valid JSON, no markdown fences."
|
||||
)
|
||||
|
||||
# How many recent messages to include for extraction
|
||||
CONTEXT_WINDOW = 6
|
||||
|
||||
AUDIT_SYSTEM_PROMPT = (
|
||||
"You are a memory database curator. Be CONSERVATIVE: remove only TRUE "
|
||||
"duplicates and clearly useless entries. Every distinct fact must survive. "
|
||||
"When in doubt, KEEP the entry. Return the cleaned list.\n\n"
|
||||
"Rules:\n"
|
||||
"1. MERGE only entries that state the SAME fact in different words. If you "
|
||||
"are not sure two entries are the same fact, KEEP BOTH.\n"
|
||||
" Merge: 'User's name is Sam' + 'The user is called Sam' -> one.\n"
|
||||
" Do NOT merge related-but-distinct facts: 'Likes Python' and 'Uses "
|
||||
"Python at work' are DIFFERENT — keep both.\n"
|
||||
"2. REMOVE only entries that are genuinely worthless: about what the AI did "
|
||||
"(not the user), empty, or meaningless. Do NOT drop a real fact just "
|
||||
"because it seems minor or niche.\n"
|
||||
"3. Keep the original wording. Only lightly trim obvious redundancy — do "
|
||||
"NOT aggressively rewrite or shorten.\n"
|
||||
"4. Preserve the 'id' of the entry you keep when merging.\n"
|
||||
"5. Never invent facts. When unsure, KEEP.\n\n"
|
||||
"Return a JSON array of objects with fields: id, text, category.\n"
|
||||
"Return ONLY valid JSON, no markdown fences."
|
||||
)
|
||||
|
||||
AUDIT_INTERVAL = 5 # audit every N new memories added
|
||||
_extractions_since_audit = 0
|
||||
|
||||
|
||||
def _message_text(message) -> str:
|
||||
content = getattr(message, "content", None)
|
||||
if content is None and isinstance(message, dict):
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
parts.append(str(item.get("text") or item.get("content") or ""))
|
||||
else:
|
||||
parts.append(str(item))
|
||||
return " ".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _message_role(message) -> str:
|
||||
role = getattr(message, "role", None)
|
||||
if role is None and isinstance(message, dict):
|
||||
role = message.get("role")
|
||||
return str(role or "").lower()
|
||||
|
||||
|
||||
def _clean_memory_value(value: str, max_len: int = 80) -> str:
|
||||
value = re.sub(r"\s+", " ", value or "").strip(" .,!?:;\"'`“”‘’")
|
||||
value = re.sub(r"^(?:the|a|an)\s+", "", value, flags=re.I)
|
||||
if not value or len(value) > max_len:
|
||||
return ""
|
||||
if re.search(r"https?://|@|[{}<>]", value):
|
||||
return ""
|
||||
return value
|
||||
|
||||
|
||||
def _fallback_memory_candidates(messages) -> list[dict]:
|
||||
"""Extract obvious durable facts without relying on the LLM.
|
||||
|
||||
This is deliberately narrow. The LLM remains the main extractor, but
|
||||
simple identity/preference/goal statements should not silently vanish just
|
||||
because the background model judged them too conversational.
|
||||
"""
|
||||
candidates = []
|
||||
seen = set()
|
||||
|
||||
def add(text: str, category: str):
|
||||
text = _clean_memory_value(text, 120)
|
||||
if not text:
|
||||
return
|
||||
key = text.lower()
|
||||
if key in seen:
|
||||
return
|
||||
seen.add(key)
|
||||
candidates.append({"text": text, "category": category})
|
||||
|
||||
for msg in messages:
|
||||
if _message_role(msg) != "user":
|
||||
continue
|
||||
text = _message_text(msg)
|
||||
if not text:
|
||||
continue
|
||||
|
||||
m = re.search(r"\bmy name is\s+([A-Za-z][A-Za-z0-9 .'\-]{1,50})\b", text, re.I)
|
||||
if m:
|
||||
name = _clean_memory_value(m.group(1), 50)
|
||||
if name:
|
||||
add(f"User's name is {name}.", "identity")
|
||||
|
||||
m = re.search(r"\bcall me\s+([A-Za-z][A-Za-z0-9 .'\-]{1,50})\b", text, re.I)
|
||||
if m:
|
||||
name = _clean_memory_value(m.group(1), 50)
|
||||
if name:
|
||||
add(f"User wants to be called {name}.", "identity")
|
||||
|
||||
m = re.search(r"\bi (?:live in|am from|'m from)\s+([^.!?\n]{2,80})", text, re.I)
|
||||
if m:
|
||||
place = _clean_memory_value(m.group(1), 80)
|
||||
if place:
|
||||
add(f"User lives in {place}.", "identity")
|
||||
|
||||
m = re.search(r"\bi (?:prefer|like|love|hate|do not like|don't like)\s+([^.!?\n]{4,100})", text, re.I)
|
||||
if m:
|
||||
preference = _clean_memory_value(m.group(1), 100)
|
||||
if preference:
|
||||
add(f"User prefers {preference}.", "preference")
|
||||
|
||||
m = re.search(
|
||||
r"\bi (?:(?:want|would like|plan|hope) to|wanna) "
|
||||
r"(?:go|travel|move|visit) to\s+([^.!?\n]{2,80})",
|
||||
text,
|
||||
re.I,
|
||||
)
|
||||
if m:
|
||||
destination = _clean_memory_value(m.group(1), 80)
|
||||
if destination:
|
||||
add(f"User wants to visit {destination}.", "goal")
|
||||
|
||||
return candidates[:2]
|
||||
|
||||
|
||||
def _is_text_duplicate(new_text: str, existing: list, threshold: float = 0.6) -> bool:
|
||||
"""Check if new_text is too similar to any existing memory (Jaccard similarity)."""
|
||||
new_tokens = set(new_text.lower().split())
|
||||
if not new_tokens:
|
||||
return False
|
||||
for entry in existing:
|
||||
old_tokens = set(entry.get("text", "").lower().split())
|
||||
if not old_tokens:
|
||||
continue
|
||||
intersection = new_tokens & old_tokens
|
||||
union = new_tokens | old_tokens
|
||||
if len(intersection) / len(union) >= threshold:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def extract_and_store(
|
||||
session,
|
||||
memory_manager,
|
||||
memory_vector,
|
||||
endpoint_url: str,
|
||||
model: str,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
"""Extract facts from recent conversation and store them.
|
||||
|
||||
Designed to run as a background task (asyncio.create_task).
|
||||
Errors are logged, never raised.
|
||||
"""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
# Get last N messages from session
|
||||
messages = session.get_context_messages()
|
||||
recent = messages[-CONTEXT_WINDOW:] if len(messages) > CONTEXT_WINDOW else messages
|
||||
|
||||
if len(recent) < 2:
|
||||
return # Need at least a user message and assistant response
|
||||
|
||||
fallback_facts = _fallback_memory_candidates(recent)
|
||||
|
||||
extraction_messages = [
|
||||
{"role": "system", "content": EXTRACT_SYSTEM_PROMPT},
|
||||
] + recent
|
||||
|
||||
facts = []
|
||||
try:
|
||||
raw = await llm_call_async(
|
||||
endpoint_url,
|
||||
model,
|
||||
extraction_messages,
|
||||
temperature=0.1,
|
||||
max_tokens=500,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Parse JSON from response (handle markdown fences if model wraps them)
|
||||
text = raw.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
|
||||
try:
|
||||
facts = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Memory extraction returned non-JSON")
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM memory extraction failed; using fallback candidates if available: {e}")
|
||||
|
||||
if not isinstance(facts, list):
|
||||
facts = []
|
||||
|
||||
if fallback_facts:
|
||||
facts = list(facts) + fallback_facts
|
||||
|
||||
if not facts:
|
||||
logger.info("Auto memory extraction ran: 0 candidates")
|
||||
return
|
||||
|
||||
# Get owner from session
|
||||
_owner = getattr(session, 'owner', None)
|
||||
|
||||
existing = memory_manager.load_all()
|
||||
added = 0
|
||||
|
||||
for fact in facts:
|
||||
if isinstance(fact, str):
|
||||
fact_text = fact
|
||||
category = "fact"
|
||||
elif isinstance(fact, dict):
|
||||
fact_text = fact.get("text", "").strip()
|
||||
category = fact.get("category", "fact")
|
||||
else:
|
||||
continue
|
||||
|
||||
if not fact_text or len(fact_text) < 5:
|
||||
continue
|
||||
|
||||
# Dedup: check vector similarity first (fast), then exact text match
|
||||
if memory_vector and memory_vector.healthy:
|
||||
existing_id = memory_vector.find_similar(fact_text, threshold=0.72)
|
||||
if existing_id:
|
||||
logger.debug(f"Memory dedup (vector): '{fact_text[:50]}' matches {existing_id}")
|
||||
continue
|
||||
|
||||
# Text dedup fallback: exact match + fuzzy similarity
|
||||
user_existing = [e for e in existing if e.get("owner") == _owner or e.get("owner") is None] if _owner else existing
|
||||
if memory_manager.find_duplicates(fact_text, user_existing):
|
||||
continue
|
||||
# Fuzzy text similarity check (catches rephrased duplicates when vector index is unavailable)
|
||||
if _is_text_duplicate(fact_text, user_existing):
|
||||
logger.debug(f"Memory dedup (fuzzy): '{fact_text[:50]}' too similar to existing")
|
||||
continue
|
||||
|
||||
entry = memory_manager.add_entry(fact_text, source="auto", category=category, owner=_owner)
|
||||
# Auto-pin identity facts (name, job, location) — core context
|
||||
if category == "identity":
|
||||
entry["pinned"] = True
|
||||
if hasattr(session, "session_id"):
|
||||
entry["session_id"] = session.session_id
|
||||
elif hasattr(session, "name"):
|
||||
entry["session_id"] = session.name
|
||||
|
||||
existing.append(entry)
|
||||
|
||||
# Add to vector index
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.add(entry["id"], fact_text)
|
||||
|
||||
added += 1
|
||||
|
||||
if added > 0:
|
||||
memory_manager.save(existing)
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
for _ in range(added):
|
||||
fire_event("memory_added", _owner)
|
||||
except Exception:
|
||||
logger.debug("memory_added event dispatch failed", exc_info=True)
|
||||
logger.info(f"Auto-extracted {added} memories from session")
|
||||
|
||||
global _extractions_since_audit
|
||||
_extractions_since_audit += added
|
||||
if _extractions_since_audit >= AUDIT_INTERVAL:
|
||||
_extractions_since_audit = 0
|
||||
logger.info("Audit threshold reached, running memory audit")
|
||||
await audit_memories(
|
||||
memory_manager, memory_vector, endpoint_url, model, headers, owner=_owner
|
||||
)
|
||||
else:
|
||||
logger.info("Auto memory extraction ran: 0 added")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Memory extraction failed: {e}")
|
||||
|
||||
|
||||
async def audit_memories(
|
||||
memory_manager,
|
||||
memory_vector,
|
||||
endpoint_url: str,
|
||||
model: str,
|
||||
headers: Optional[dict] = None,
|
||||
owner: Optional[str] = None,
|
||||
):
|
||||
"""Send all memories to the LLM for deduplication and consolidation.
|
||||
|
||||
- Merges near-duplicate entries
|
||||
- Rewrites vague entries to be concise
|
||||
- Removes junk / non-personal entries
|
||||
- Rebuilds the vector index afterwards
|
||||
|
||||
Safe to call manually or from the automatic trigger in extract_and_store.
|
||||
Errors are logged, never raised.
|
||||
"""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
existing = memory_manager.load(owner=owner)
|
||||
if not existing:
|
||||
logger.info("Memory audit: nothing to audit")
|
||||
return {"before": 0, "after": 0}
|
||||
|
||||
before_count = len(existing)
|
||||
|
||||
# Skip the LLM call entirely when this exact set of memories was
|
||||
# already audited — the previous tidy left them in a clean state
|
||||
# and nothing has changed since. Returns instantly so the UI shows
|
||||
# "Already clean" without spending 30-120s on a wasted LLM round.
|
||||
# The fingerprint includes id+text+category; any add/edit/delete
|
||||
# invalidates it and the audit runs normally.
|
||||
current_fp = _fingerprint_entries(existing)
|
||||
last_state = _load_tidy_state(memory_manager).get(owner or "") or {}
|
||||
if last_state.get("fingerprint") == current_fp:
|
||||
logger.info("Memory audit: state unchanged since last tidy — skipping LLM")
|
||||
return {
|
||||
"before": before_count,
|
||||
"after": before_count,
|
||||
"already_tidy": True,
|
||||
}
|
||||
|
||||
# Build payload: list of {id, text, category} for the LLM
|
||||
memory_payload = [
|
||||
{"id": m["id"], "text": m["text"], "category": m.get("category", "fact")}
|
||||
for m in existing
|
||||
]
|
||||
|
||||
audit_messages = [
|
||||
{"role": "system", "content": AUDIT_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": json.dumps(memory_payload, ensure_ascii=False)},
|
||||
]
|
||||
|
||||
raw = await llm_call_async(
|
||||
endpoint_url,
|
||||
model,
|
||||
audit_messages,
|
||||
temperature=0.1,
|
||||
# 16384 (was 2000): the deduped list of all memories can be large,
|
||||
# and a reasoning model spends tokens thinking first — 2000 truncated
|
||||
# the JSON so it never parsed ("bad_json").
|
||||
max_tokens=16384,
|
||||
headers=headers,
|
||||
# Bound the call so the Tidy whirlpool can't spin indefinitely on a
|
||||
# slow/large generation.
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
# Parse the JSON list, tolerating reasoning-model noise: <think> blocks,
|
||||
# markdown fences, leading prose, and trailing commas.
|
||||
import re as _re
|
||||
text = (raw or "").strip()
|
||||
text = _re.sub(r'<think(?:ing)?>[\s\S]*?</think(?:ing)?>', '', text, flags=_re.I).strip()
|
||||
|
||||
def _loads_list(s):
|
||||
if not s:
|
||||
return None
|
||||
for cand in (s, _re.sub(r',(\s*[}\]])', r'\1', s)):
|
||||
try:
|
||||
v = json.loads(cand)
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
cleaned = _loads_list(text)
|
||||
if cleaned is None:
|
||||
_m = _re.search(r'```(?:json)?\s*\n?([\s\S]*?)```', text)
|
||||
if _m:
|
||||
cleaned = _loads_list(_m.group(1).strip())
|
||||
if cleaned is None:
|
||||
_a, _b = text.find('['), text.rfind(']')
|
||||
if _a >= 0 and _b > _a:
|
||||
cleaned = _loads_list(text[_a:_b + 1])
|
||||
if cleaned is None:
|
||||
logger.error(f"Memory audit returned non-JSON: {text[:300]}")
|
||||
return {"before": before_count, "after": before_count, "error": "bad_json"}
|
||||
|
||||
# Build lookup of original entries by ID so we can preserve metadata
|
||||
originals = {m["id"]: m for m in existing}
|
||||
|
||||
final_entries = []
|
||||
for item in cleaned:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mid = item.get("id", "")
|
||||
new_text = item.get("text", "").strip()
|
||||
if not new_text:
|
||||
continue
|
||||
|
||||
if mid in originals:
|
||||
# Preserve original metadata, update text + category
|
||||
entry = originals[mid].copy()
|
||||
entry["text"] = new_text
|
||||
if item.get("category"):
|
||||
entry["category"] = item["category"]
|
||||
else:
|
||||
# ID not found — skip to avoid inventing entries
|
||||
logger.debug(f"Audit returned unknown id {mid}, skipping")
|
||||
continue
|
||||
|
||||
final_entries.append(entry)
|
||||
|
||||
after_count = len(final_entries)
|
||||
|
||||
# Safety net against catastrophic over-deletion. A conservative tidy
|
||||
# should never wipe out half the store in one pass — if the model
|
||||
# returned far fewer entries than it was given (over-consolidation, a
|
||||
# dropped/truncated list, or it ignored ids), treat it as a misfire and
|
||||
# DON'T save. Better to no-op than to silently lose memories.
|
||||
if before_count >= 8 and after_count < before_count * 0.5:
|
||||
logger.warning(
|
||||
f"Memory audit would cut {before_count} -> {after_count} "
|
||||
f"(>50% removed) — refusing as unsafe, keeping originals"
|
||||
)
|
||||
return {"before": before_count, "after": before_count, "error": "unsafe_removal"}
|
||||
|
||||
# Merge audited entries back with other users' entries
|
||||
if owner:
|
||||
all_entries = memory_manager.load_all()
|
||||
audited_ids = {e["id"] for e in final_entries}
|
||||
other_entries = [e for e in all_entries if e.get("owner") != owner and (e.get("owner") is not None)]
|
||||
# Also keep legacy entries that weren't part of this audit
|
||||
for e in all_entries:
|
||||
if e.get("owner") is None and e["id"] not in audited_ids and e["id"] not in {o["id"] for o in other_entries}:
|
||||
other_entries.append(e)
|
||||
memory_manager.save(final_entries + other_entries)
|
||||
else:
|
||||
memory_manager.save(final_entries)
|
||||
logger.info(
|
||||
f"Memory audit complete: {before_count} -> {after_count} entries "
|
||||
f"({before_count - after_count} removed/merged)"
|
||||
)
|
||||
|
||||
# Rebuild vector index
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.rebuild(final_entries)
|
||||
|
||||
# Persist the post-tidy fingerprint so the next call short-circuits
|
||||
# if nothing has changed in the meantime.
|
||||
_save_tidy_state(memory_manager, owner, _fingerprint_entries(final_entries))
|
||||
|
||||
return {"before": before_count, "after": after_count}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Memory audit failed: {e}")
|
||||
return {"error": str(e)}
|
||||
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
memory_vector.py
|
||||
|
||||
ChromaDB-backed vector store for memory entries.
|
||||
Shares the EmbeddingClient with RAG to save memory.
|
||||
Stores pre-computed embeddings (ChromaDB does not manage embedding).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryVectorStore:
|
||||
"""Vector index over memory entries for semantic retrieval."""
|
||||
|
||||
COLLECTION_NAME = "odysseus_memories"
|
||||
|
||||
def __init__(self, data_dir: str, embedding_model=None):
|
||||
self._model = embedding_model
|
||||
self._collection = None
|
||||
self._healthy = False
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
try:
|
||||
from src.chroma_client import get_chroma_client
|
||||
|
||||
if self._model is None:
|
||||
from src.embeddings import get_embedding_client
|
||||
self._model = get_embedding_client()
|
||||
if self._model is None:
|
||||
raise RuntimeError("No embedding backend available")
|
||||
logger.info(f"MemoryVectorStore using embeddings: {self._model.url}")
|
||||
|
||||
client = get_chroma_client()
|
||||
self._collection = client.get_or_create_collection(
|
||||
name=self.COLLECTION_NAME,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
|
||||
self._healthy = True
|
||||
count = self._collection.count()
|
||||
logger.info(f"MemoryVectorStore ready (entries={count})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MemoryVectorStore init failed: {e}")
|
||||
|
||||
@property
|
||||
def healthy(self) -> bool:
|
||||
return self._healthy
|
||||
|
||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||
vecs = self._model.encode(texts, normalize_embeddings=True)
|
||||
return vecs.tolist()
|
||||
|
||||
def count(self) -> int:
|
||||
"""Return the number of stored vectors."""
|
||||
if not self._healthy:
|
||||
return 0
|
||||
return self._collection.count()
|
||||
|
||||
def add(self, memory_id: str, text: str):
|
||||
"""Add a single memory entry to the vector index."""
|
||||
if not self._healthy:
|
||||
return
|
||||
# Skip if already exists
|
||||
existing = self._collection.get(ids=[memory_id])
|
||||
if existing["ids"]:
|
||||
return
|
||||
embeddings = self._embed([text])
|
||||
self._collection.add(
|
||||
ids=[memory_id],
|
||||
embeddings=embeddings,
|
||||
documents=[text],
|
||||
metadatas=[{"source": "memory"}],
|
||||
)
|
||||
|
||||
def remove(self, memory_id: str):
|
||||
"""Remove a memory entry. O(1) — no rebuild needed."""
|
||||
if not self._healthy:
|
||||
return
|
||||
try:
|
||||
self._collection.delete(ids=[memory_id])
|
||||
except Exception as e:
|
||||
logger.warning(f"memory remove {memory_id}: {e}")
|
||||
|
||||
def search(self, query: str, k: int = 8) -> List[Dict]:
|
||||
"""Search for the most relevant memory IDs by semantic similarity.
|
||||
Returns list of {"memory_id": str, "score": float}.
|
||||
|
||||
ChromaDB cosine distance = 1 - cosine_similarity.
|
||||
We convert back: similarity = 1.0 - distance.
|
||||
"""
|
||||
if not self._healthy or self._collection.count() == 0:
|
||||
return []
|
||||
|
||||
embeddings = self._embed([query])
|
||||
actual_k = min(k, self._collection.count())
|
||||
results = self._collection.query(
|
||||
query_embeddings=embeddings,
|
||||
n_results=actual_k,
|
||||
)
|
||||
|
||||
out = []
|
||||
for idx, mid in enumerate(results["ids"][0]):
|
||||
distance = results["distances"][0][idx]
|
||||
out.append({
|
||||
"memory_id": mid,
|
||||
"score": round(1.0 - distance, 4),
|
||||
})
|
||||
return out
|
||||
|
||||
def find_similar(self, text: str, threshold: float = 0.92) -> Optional[str]:
|
||||
"""Check if a near-duplicate exists. Returns memory_id if found, else None."""
|
||||
if not self._healthy or self._collection.count() == 0:
|
||||
return None
|
||||
|
||||
embeddings = self._embed([text])
|
||||
results = self._collection.query(
|
||||
query_embeddings=embeddings,
|
||||
n_results=1,
|
||||
)
|
||||
|
||||
if results["ids"][0]:
|
||||
distance = results["distances"][0][0]
|
||||
similarity = 1.0 - distance
|
||||
if similarity >= threshold:
|
||||
return results["ids"][0][0]
|
||||
return None
|
||||
|
||||
def rebuild(self, memories: List[Dict]):
|
||||
"""Rebuild the entire index from a list of memory entries.
|
||||
Each entry must have 'id' and 'text' keys."""
|
||||
if not self._healthy:
|
||||
return
|
||||
|
||||
from src.chroma_client import get_chroma_client
|
||||
|
||||
# Delete and recreate collection for a clean rebuild
|
||||
client = get_chroma_client()
|
||||
try:
|
||||
client.delete_collection(self.COLLECTION_NAME)
|
||||
except Exception:
|
||||
pass
|
||||
self._collection = client.get_or_create_collection(
|
||||
name=self.COLLECTION_NAME,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
|
||||
texts = []
|
||||
ids = []
|
||||
for mem in memories:
|
||||
text = mem.get("text", "").strip()
|
||||
mid = mem.get("id", "")
|
||||
if text and mid:
|
||||
texts.append(text)
|
||||
ids.append(mid)
|
||||
|
||||
if texts:
|
||||
# Batch in chunks of 100 to avoid oversized requests
|
||||
for i in range(0, len(texts), 100):
|
||||
batch_texts = texts[i:i + 100]
|
||||
batch_ids = ids[i:i + 100]
|
||||
embeddings = self._embed(batch_texts)
|
||||
self._collection.add(
|
||||
ids=batch_ids,
|
||||
embeddings=embeddings,
|
||||
documents=batch_texts,
|
||||
metadatas=[{"source": "memory"}] * len(batch_ids),
|
||||
)
|
||||
|
||||
logger.info(f"MemoryVectorStore rebuilt with {len(ids)} entries")
|
||||
@@ -0,0 +1,137 @@
|
||||
# services/memory/service.py
|
||||
"""Memory service — persistent memory storage and retrieval."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any
|
||||
import os
|
||||
|
||||
from .memory import MemoryManager
|
||||
from .memory_vector import MemoryVectorStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class Memory:
|
||||
"""A stored memory."""
|
||||
id: str
|
||||
text: str
|
||||
timestamp: int
|
||||
session_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemorySearchResult:
|
||||
"""Result of memory search."""
|
||||
memories: List[Memory]
|
||||
query: str
|
||||
total: int
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
Memory storage and retrieval service.
|
||||
|
||||
Usage:
|
||||
service = MemoryService()
|
||||
await service.remember("User prefers dark mode")
|
||||
results = await service.recall("preferences")
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.manager = MemoryManager(data_dir)
|
||||
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
|
||||
os.path.join(data_dir, "memory_vectors")
|
||||
) else None
|
||||
|
||||
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
|
||||
"""
|
||||
Store a new memory.
|
||||
|
||||
Args:
|
||||
text: Memory content
|
||||
session_id: Optional session association
|
||||
|
||||
Returns:
|
||||
Created Memory object
|
||||
"""
|
||||
import uuid
|
||||
import time
|
||||
|
||||
memory_id = str(uuid.uuid4())[:8]
|
||||
timestamp = int(time.time())
|
||||
|
||||
entry = {
|
||||
"id": memory_id,
|
||||
"text": text,
|
||||
"timestamp": timestamp,
|
||||
"session_id": session_id,
|
||||
}
|
||||
|
||||
self.manager.add_memory(entry)
|
||||
|
||||
# Also add to vector store if available
|
||||
if self.vector_store:
|
||||
self.vector_store.add(text, {"id": memory_id, "session_id": session_id})
|
||||
|
||||
return Memory(
|
||||
id=memory_id,
|
||||
text=text,
|
||||
timestamp=timestamp,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
|
||||
"""
|
||||
Search memories.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
top_k: Max results
|
||||
|
||||
Returns:
|
||||
MemorySearchResult with matching memories
|
||||
"""
|
||||
# Try vector search first
|
||||
if self.vector_store:
|
||||
results = self.vector_store.search(query, k=top_k)
|
||||
memories = [
|
||||
Memory(
|
||||
id=r.get("id", ""),
|
||||
text=r.get("text", ""),
|
||||
timestamp=r.get("timestamp", 0),
|
||||
session_id=r.get("session_id"),
|
||||
metadata=r.get("metadata", {}),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
||||
|
||||
# Fallback to keyword search
|
||||
results = self.manager.search_memories(query, limit=top_k)
|
||||
memories = [
|
||||
Memory(
|
||||
id=m.get("id", ""),
|
||||
text=m.get("text", ""),
|
||||
timestamp=m.get("timestamp", 0),
|
||||
session_id=m.get("session_id"),
|
||||
)
|
||||
for m in results
|
||||
]
|
||||
return MemorySearchResult(memories=memories, query=query, total=len(memories))
|
||||
|
||||
def get_all(self, limit: int = 100) -> List[Memory]:
|
||||
"""Get all memories."""
|
||||
memories = self.manager.get_memories(limit=limit)
|
||||
return [
|
||||
Memory(
|
||||
id=m.get("id", ""),
|
||||
text=m.get("text", ""),
|
||||
timestamp=m.get("timestamp", 0),
|
||||
session_id=m.get("session_id"),
|
||||
)
|
||||
for m in memories
|
||||
]
|
||||
|
||||
def delete(self, memory_id: str) -> bool:
|
||||
"""Delete a memory by ID."""
|
||||
return self.manager.delete_memory(memory_id)
|
||||
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
skill_extractor.py
|
||||
|
||||
Background auto-extraction of skills from complex agent runs.
|
||||
When the agent takes >= 2 rounds or >= 2 tool calls to complete a task,
|
||||
we ask the LLM to distill the approach into a reusable skill.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SKILL_EXTRACT_PROMPT = (
|
||||
"You are analyzing an AI agent's work session. The agent took {rounds} rounds "
|
||||
"and {tool_count} tool calls to complete the task.\n\n"
|
||||
"Extract a reusable 'skill' ONLY IF the session contains a concrete, "
|
||||
"repeatable procedure the agent could follow to solve a similar problem "
|
||||
"ON THE COMPUTER next time (e.g. a sequence of shell commands, code, file "
|
||||
"edits, API calls, or tool usage).\n\n"
|
||||
"Return null (the bare word, no JSON) when the session is NOT a reusable "
|
||||
"computer procedure, including:\n"
|
||||
"- The real work happened OUTSIDE the computer (the user did something "
|
||||
"physically, in person, on another device, or by hand) and the agent only "
|
||||
"discussed or advised it.\n"
|
||||
"- A one-off, personal, or context-specific task that won't recur "
|
||||
"(personal errands, a specific person/place/date, casual conversation).\n"
|
||||
"- A pure question/answer or explanation with no transferable method.\n"
|
||||
"- The agent failed, gave up, or the approach is not worth repeating.\n\n"
|
||||
"When (and only when) a genuine reusable procedure exists, return a JSON "
|
||||
"object with:\n"
|
||||
'- "title": short name (under 10 words)\n'
|
||||
'- "problem": what was the challenge (1-2 sentences)\n'
|
||||
'- "solution": what worked (1-2 sentences)\n'
|
||||
'- "steps": array of step-by-step instructions (3-7 short steps)\n'
|
||||
'- "tags": array of relevant keywords (3-5 tags)\n'
|
||||
'- "confidence": 0.0-1.0 how reliable AND reusable this procedure is\n\n'
|
||||
"Be conservative: if in doubt, return null.\n"
|
||||
"Return ONLY valid JSON (or the bare word null), no markdown fences."
|
||||
)
|
||||
|
||||
# Skills the model is unsure about (or that read as one-offs) add clutter —
|
||||
# drop anything below this confidence.
|
||||
MIN_CONFIDENCE = 0.6
|
||||
|
||||
# How many recent messages to include
|
||||
CONTEXT_WINDOW = 12
|
||||
|
||||
|
||||
async def maybe_extract_skill(
|
||||
session,
|
||||
skills_manager,
|
||||
endpoint_url: str,
|
||||
model: str,
|
||||
headers: dict,
|
||||
round_count: int,
|
||||
tool_count: int,
|
||||
owner: Optional[str] = None,
|
||||
):
|
||||
"""Extract a skill if the agent run was complex enough."""
|
||||
# Quiet by default; flip to DEBUG when chasing extractor issues.
|
||||
logger.debug(
|
||||
"[skill-extract] start: rounds=%d tools=%d model=%s owner=%s",
|
||||
round_count, tool_count, model, owner,
|
||||
)
|
||||
if round_count < 2 and tool_count < 2:
|
||||
logger.debug("[skill-extract] BELOW threshold (need rounds>=2 or tools>=2)")
|
||||
return None
|
||||
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
# Get recent messages
|
||||
history = session.get_context_messages()
|
||||
recent = history[-CONTEXT_WINDOW:] if len(history) > CONTEXT_WINDOW else history
|
||||
if not recent:
|
||||
logger.debug("[skill-extract] no recent messages, skipping")
|
||||
return None
|
||||
|
||||
# Build conversation summary for extraction
|
||||
conv_lines = []
|
||||
for msg in recent:
|
||||
role = msg.get("role", "?")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
content = " ".join(
|
||||
b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"
|
||||
)
|
||||
# Truncate long messages
|
||||
if len(content) > 500:
|
||||
content = content[:500] + "..."
|
||||
conv_lines.append(f"[{role}] {content}")
|
||||
|
||||
conversation = "\n".join(conv_lines)
|
||||
|
||||
prompt = SKILL_EXTRACT_PROMPT.format(rounds=round_count, tool_count=tool_count)
|
||||
|
||||
import time as _time
|
||||
_t0 = _time.monotonic()
|
||||
logger.debug(
|
||||
"[skill-extract] calling LLM (endpoint=%s, ctx=%d msgs, timeout=30s)",
|
||||
endpoint_url, len(recent),
|
||||
)
|
||||
response = await llm_call_async(
|
||||
endpoint_url,
|
||||
model,
|
||||
[
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": f"Conversation:\n{conversation}"},
|
||||
],
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
)
|
||||
logger.debug(
|
||||
"[skill-extract] LLM returned in %.1fs (len=%d, head=%r)",
|
||||
_time.monotonic() - _t0, len(response or ""), (response or "")[:80],
|
||||
)
|
||||
|
||||
if not response or response.strip().lower() == "null":
|
||||
logger.debug(
|
||||
"[skill-extract] LLM declined (returned null/empty) — "
|
||||
"session deemed not a reusable procedure"
|
||||
)
|
||||
return None
|
||||
|
||||
# Some models (MiniMax, Qwen-Thinker, DeepSeek-R1) emit their
|
||||
# chain-of-thought BEFORE the JSON output even when asked for
|
||||
# raw JSON. `strip_think(prose=True, prompt_echo=True)` removes
|
||||
# <think>…</think> tags AND prose-style "Let me analyze this…"
|
||||
# preambles. Without it, json.loads bombed on character 0 every
|
||||
# time and the silent-bail looked like "extractor doesn't work".
|
||||
try:
|
||||
from src.text_helpers import strip_think as _strip_think
|
||||
response = _strip_think(response, prose=True, prompt_echo=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Parse JSON
|
||||
text = response.strip()
|
||||
if text.startswith("```"):
|
||||
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
# After strip_think, the JSON may still be embedded inside surrounding
|
||||
# commentary — slice from the first '{' to the matching last '}'.
|
||||
if text and text[0] != "{":
|
||||
_start = text.find("{")
|
||||
_end = text.rfind("}")
|
||||
if 0 <= _start < _end:
|
||||
text = text[_start : _end + 1]
|
||||
|
||||
data = json.loads(text)
|
||||
if not data or not isinstance(data, dict):
|
||||
logger.debug("[skill-extract] parsed JSON not a dict, dropping")
|
||||
return None
|
||||
|
||||
title = data.get("title", "").strip()
|
||||
if not title:
|
||||
logger.debug("[skill-extract] LLM returned object with no title, dropping")
|
||||
return None
|
||||
|
||||
# Honour the model's own reliability/reusability estimate — low-
|
||||
# confidence extractions are usually one-offs or shaky procedures.
|
||||
try:
|
||||
_conf = float(data.get("confidence", 0.7))
|
||||
except (TypeError, ValueError):
|
||||
_conf = 0.7
|
||||
if _conf < MIN_CONFIDENCE:
|
||||
logger.debug(
|
||||
"[skill-extract] '%s' below confidence floor (%.2f < %.2f) — dropped",
|
||||
title, _conf, MIN_CONFIDENCE,
|
||||
)
|
||||
return None
|
||||
|
||||
# Check for duplicate skills
|
||||
existing = skills_manager.load(owner=owner)
|
||||
for sk in existing:
|
||||
if sk.get("title", "").lower() == title.lower():
|
||||
logger.debug("[skill-extract] '%s' already exists — dropped as duplicate", title)
|
||||
return None
|
||||
|
||||
entry = skills_manager.add_skill(
|
||||
title=title,
|
||||
problem=data.get("problem", ""),
|
||||
solution=data.get("solution", ""),
|
||||
steps=data.get("steps", []),
|
||||
tags=data.get("tags", []),
|
||||
source="learned",
|
||||
confidence=data.get("confidence", 0.7),
|
||||
session_id=getattr(session, "session_id", None),
|
||||
owner=owner,
|
||||
)
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("skill_added", owner)
|
||||
except Exception:
|
||||
logger.debug("skill_added event dispatch failed", exc_info=True)
|
||||
logger.info("Auto-extracted skill: %s (id=%s)", title, entry["id"])
|
||||
return entry
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug("[skill-extract] non-JSON LLM response, dropping: %s", e)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Real exceptions stay INFO+warning so they don't get lost when
|
||||
# users only have default log level. `exc_info=True` ships the
|
||||
# full traceback so timeouts vs auth vs import errors are
|
||||
# distinguishable from outside.
|
||||
logger.warning("[skill-extract] FAILED: %s", e, exc_info=True)
|
||||
return None
|
||||
@@ -0,0 +1,444 @@
|
||||
"""SKILL.md parser & writer.
|
||||
|
||||
Reads/writes a single skill from a `SKILL.md` file with YAML frontmatter
|
||||
and a structured markdown body. Inspired by Hermes' skills format
|
||||
(https://hermes-agent.nousresearch.com/docs/user-guide/features/skills).
|
||||
|
||||
Frontmatter shape (YAML):
|
||||
|
||||
---
|
||||
name: open-pr-from-branch
|
||||
description: One-line summary surfaced in the skills index.
|
||||
version: 1.0.0
|
||||
category: dev
|
||||
tags: [git, github]
|
||||
platforms: [linux, macos] # optional
|
||||
requires_toolsets: [] # optional
|
||||
fallback_for_toolsets: [] # optional
|
||||
status: published # draft | published
|
||||
confidence: 0.8 # 0..1
|
||||
source: learned # learned | taught | imported
|
||||
teacher_model: claude-opus-4-7 # optional
|
||||
created: 2026-05-09T21:43:00Z
|
||||
---
|
||||
|
||||
Body sections (any subset; rendered as headings):
|
||||
|
||||
## When to Use
|
||||
Trigger conditions in plain English.
|
||||
|
||||
## Procedure
|
||||
1. First step
|
||||
2. Second step
|
||||
|
||||
## Pitfalls
|
||||
- Common failure mode + how to recover
|
||||
|
||||
## Verification
|
||||
- How to confirm success
|
||||
|
||||
Anything else (raw paragraphs after the last known section) is preserved
|
||||
in `body_extra` and round-trips on save.
|
||||
|
||||
Usage counters (`uses`, `last_used`) live in a sidecar `_usage.json` keyed
|
||||
by skill name, so the SKILL.md file doesn't churn on every retrieval.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slugify
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SLUG_RE = re.compile(r"[^a-z0-9]+")
|
||||
|
||||
|
||||
def slugify(text: str, fallback: str = "skill") -> str:
|
||||
"""Convert a free-form title to a kebab-case slug suitable for a directory
|
||||
name. Strips non-alphanumerics, collapses runs, trims leading/trailing
|
||||
dashes. Caps at 60 chars."""
|
||||
s = str(text or "").strip().lower()
|
||||
s = _SLUG_RE.sub("-", s)
|
||||
s = s.strip("-")
|
||||
return (s or fallback)[:60]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frontmatter (minimal YAML — we don't pull in PyYAML for one feature)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# We accept a tiny subset of YAML: scalar `key: value`, inline lists `[a, b]`,
|
||||
# and block lists with `-`. That covers everything in our schema and avoids
|
||||
# a new dependency.
|
||||
|
||||
_FM_KEY_RE = re.compile(r"^([a-z_][a-z0-9_]*):\s*(.*)$", re.IGNORECASE)
|
||||
_FM_BLOCK_LIST_RE = re.compile(r"^\s*-\s*(.*)$")
|
||||
|
||||
|
||||
def _parse_scalar(raw: str) -> Any:
|
||||
raw = raw.strip()
|
||||
if raw == "":
|
||||
return ""
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
inner = raw[1:-1].strip()
|
||||
if not inner:
|
||||
return []
|
||||
return [_parse_scalar(p) for p in _split_top_level(inner, ",")]
|
||||
if raw.lower() in ("true", "yes"):
|
||||
return True
|
||||
if raw.lower() in ("false", "no"):
|
||||
return False
|
||||
if raw.lower() in ("null", "none", "~"):
|
||||
return None
|
||||
if (raw[0] == raw[-1]) and raw[0] in ("'", '"'):
|
||||
return raw[1:-1]
|
||||
# Try number
|
||||
try:
|
||||
if "." in raw:
|
||||
return float(raw)
|
||||
return int(raw)
|
||||
except ValueError:
|
||||
pass
|
||||
return raw
|
||||
|
||||
|
||||
def _split_top_level(s: str, sep: str) -> List[str]:
|
||||
"""Split `s` on `sep` ignoring separators inside [] or quotes."""
|
||||
out, buf, depth, quote = [], [], 0, None
|
||||
for ch in s:
|
||||
if quote:
|
||||
buf.append(ch)
|
||||
if ch == quote:
|
||||
quote = None
|
||||
continue
|
||||
if ch in ("'", '"'):
|
||||
quote = ch
|
||||
buf.append(ch)
|
||||
continue
|
||||
if ch == "[":
|
||||
depth += 1
|
||||
elif ch == "]":
|
||||
depth = max(0, depth - 1)
|
||||
if ch == sep and depth == 0:
|
||||
out.append("".join(buf).strip())
|
||||
buf = []
|
||||
continue
|
||||
buf.append(ch)
|
||||
if buf:
|
||||
out.append("".join(buf).strip())
|
||||
return out
|
||||
|
||||
|
||||
def parse_frontmatter(text: str) -> tuple[Dict[str, Any], str]:
|
||||
"""Pull the YAML frontmatter out of a SKILL.md and return (fm, body)."""
|
||||
if not text.startswith("---"):
|
||||
return {}, text
|
||||
end = text.find("\n---", 3)
|
||||
if end < 0:
|
||||
return {}, text
|
||||
fm_text = text[3:end].lstrip("\n")
|
||||
body = text[end + 4:].lstrip("\n")
|
||||
fm: Dict[str, Any] = {}
|
||||
pending_key: Optional[str] = None
|
||||
for line in fm_text.splitlines():
|
||||
if not line.strip() or line.lstrip().startswith("#"):
|
||||
continue
|
||||
m = _FM_KEY_RE.match(line)
|
||||
if m:
|
||||
key, val = m.group(1), m.group(2)
|
||||
if val.strip() == "":
|
||||
pending_key = key
|
||||
fm[key] = []
|
||||
else:
|
||||
fm[key] = _parse_scalar(val)
|
||||
pending_key = None
|
||||
continue
|
||||
m2 = _FM_BLOCK_LIST_RE.match(line)
|
||||
if m2 and pending_key:
|
||||
existing = fm.get(pending_key)
|
||||
if not isinstance(existing, list):
|
||||
fm[pending_key] = []
|
||||
fm[pending_key].append(_parse_scalar(m2.group(1)))
|
||||
return fm, body
|
||||
|
||||
|
||||
def _emit_scalar(v: Any) -> str:
|
||||
if v is None:
|
||||
return "null"
|
||||
if isinstance(v, bool):
|
||||
return "true" if v else "false"
|
||||
if isinstance(v, (int, float)):
|
||||
return str(v)
|
||||
if isinstance(v, list):
|
||||
return "[" + ", ".join(_emit_scalar(x) for x in v) + "]"
|
||||
s = str(v)
|
||||
if any(c in s for c in (":", "#", "\n", "[", "]", "{", "}", ",", "&", "*", "!", "|", ">", "'", '"', "%", "@")):
|
||||
return json.dumps(s)
|
||||
return s
|
||||
|
||||
|
||||
def _as_list(v: Any) -> List[str]:
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
return [str(x) for x in v if x not in (None, "")]
|
||||
return [str(v)]
|
||||
|
||||
|
||||
def _as_float(v: Any, default: float = 0.8) -> float:
|
||||
try:
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def emit_frontmatter(fm: Dict[str, Any]) -> str:
|
||||
lines = []
|
||||
for k, v in fm.items():
|
||||
if v is None or v == [] or v == "":
|
||||
continue
|
||||
lines.append(f"{k}: {_emit_scalar(v)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill body sections
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_KNOWN_SECTIONS = ("when_to_use", "procedure", "pitfalls", "verification")
|
||||
_HEADING_TO_KEY = {
|
||||
"when to use": "when_to_use",
|
||||
"procedure": "procedure",
|
||||
"steps": "procedure",
|
||||
"pitfalls": "pitfalls",
|
||||
"verification": "verification",
|
||||
}
|
||||
_KEY_TO_HEADING = {
|
||||
"when_to_use": "When to Use",
|
||||
"procedure": "Procedure",
|
||||
"pitfalls": "Pitfalls",
|
||||
"verification": "Verification",
|
||||
}
|
||||
|
||||
|
||||
def parse_body(body: str) -> Dict[str, Any]:
|
||||
"""Split a SKILL.md body into known sections.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"when_to_use": str,
|
||||
"procedure": list[str], # numbered/bulleted lines
|
||||
"pitfalls": list[str],
|
||||
"verification": list[str],
|
||||
"body_extra": str, # anything not under a known heading
|
||||
}
|
||||
"""
|
||||
out = {k: ([] if k != "when_to_use" else "") for k in _KNOWN_SECTIONS}
|
||||
out["body_extra"] = ""
|
||||
if not body or not body.strip():
|
||||
return out
|
||||
|
||||
sections: List[tuple[Optional[str], List[str]]] = [(None, [])]
|
||||
for line in body.splitlines():
|
||||
m = re.match(r"^##\s+(.*?)\s*$", line)
|
||||
if m:
|
||||
heading = m.group(1).strip().lower()
|
||||
key = _HEADING_TO_KEY.get(heading)
|
||||
sections.append((key, []))
|
||||
continue
|
||||
sections[-1][1].append(line)
|
||||
|
||||
for key, lines in sections:
|
||||
text = "\n".join(lines).strip("\n")
|
||||
if key is None:
|
||||
extras = text.strip()
|
||||
if extras:
|
||||
out["body_extra"] = (out["body_extra"] + "\n\n" + extras).strip()
|
||||
continue
|
||||
if key == "when_to_use":
|
||||
out["when_to_use"] = text.strip()
|
||||
else:
|
||||
out[key] = _parse_list_lines(text)
|
||||
return out
|
||||
|
||||
|
||||
def _parse_list_lines(text: str) -> List[str]:
|
||||
"""Pull bullet/numbered lines out of a section body. Plain paragraphs are
|
||||
treated as a single entry."""
|
||||
items: List[str] = []
|
||||
for line in (text or "").splitlines():
|
||||
s = line.strip()
|
||||
if not s:
|
||||
continue
|
||||
m = re.match(r"^(?:[-*]|\d+[.)])\s+(.*)$", s)
|
||||
if m:
|
||||
items.append(m.group(1).strip())
|
||||
elif items:
|
||||
# continuation of previous bullet
|
||||
items[-1] = items[-1] + " " + s
|
||||
else:
|
||||
items.append(s)
|
||||
return items
|
||||
|
||||
|
||||
def emit_body(sections: Dict[str, Any]) -> str:
|
||||
parts: List[str] = []
|
||||
when = (sections.get("when_to_use") or "").strip()
|
||||
if when:
|
||||
parts.append(f"## {_KEY_TO_HEADING['when_to_use']}\n\n{when}")
|
||||
for key in ("procedure", "pitfalls", "verification"):
|
||||
items = sections.get(key) or []
|
||||
if not items:
|
||||
continue
|
||||
heading = _KEY_TO_HEADING[key]
|
||||
if key == "procedure":
|
||||
body = "\n".join(f"{i + 1}. {x}" for i, x in enumerate(items))
|
||||
else:
|
||||
body = "\n".join(f"- {x}" for x in items)
|
||||
parts.append(f"## {heading}\n\n{body}")
|
||||
extra = (sections.get("body_extra") or "").strip()
|
||||
if extra:
|
||||
parts.append(extra)
|
||||
return "\n\n".join(parts) + ("\n" if parts else "")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
name: str # slug, dir name
|
||||
description: str = ""
|
||||
version: str = "1.0.0"
|
||||
category: str = "general"
|
||||
tags: List[str] = field(default_factory=list)
|
||||
platforms: List[str] = field(default_factory=list)
|
||||
requires_toolsets: List[str] = field(default_factory=list)
|
||||
fallback_for_toolsets: List[str] = field(default_factory=list)
|
||||
status: str = "draft" # draft | published
|
||||
confidence: float = 0.8
|
||||
source: str = "learned"
|
||||
teacher_model: Optional[str] = None
|
||||
owner: Optional[str] = None
|
||||
created: str = "" # ISO8601
|
||||
when_to_use: str = ""
|
||||
procedure: List[str] = field(default_factory=list)
|
||||
pitfalls: List[str] = field(default_factory=list)
|
||||
verification: List[str] = field(default_factory=list)
|
||||
body_extra: str = ""
|
||||
# Sidecar (not persisted in SKILL.md)
|
||||
uses: int = 0
|
||||
last_used: Optional[int] = None
|
||||
# File path on disk (set when read)
|
||||
path: Optional[str] = None
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def to_frontmatter(self) -> Dict[str, Any]:
|
||||
fm: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"category": self.category,
|
||||
}
|
||||
if self.tags: fm["tags"] = list(self.tags)
|
||||
if self.platforms: fm["platforms"] = list(self.platforms)
|
||||
if self.requires_toolsets: fm["requires_toolsets"] = list(self.requires_toolsets)
|
||||
if self.fallback_for_toolsets: fm["fallback_for_toolsets"] = list(self.fallback_for_toolsets)
|
||||
fm["status"] = self.status
|
||||
fm["confidence"] = round(float(self.confidence), 3)
|
||||
fm["source"] = self.source
|
||||
if self.teacher_model: fm["teacher_model"] = self.teacher_model
|
||||
if self.owner: fm["owner"] = self.owner
|
||||
fm["created"] = self.created or _now_iso()
|
||||
return fm
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = {
|
||||
"id": self.name, # slug doubles as id
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"version": self.version,
|
||||
"category": self.category,
|
||||
"tags": list(self.tags),
|
||||
"platforms": list(self.platforms),
|
||||
"requires_toolsets": list(self.requires_toolsets),
|
||||
"fallback_for_toolsets": list(self.fallback_for_toolsets),
|
||||
"status": self.status,
|
||||
"confidence": round(float(self.confidence), 3),
|
||||
"source": self.source,
|
||||
"teacher_model": self.teacher_model,
|
||||
"owner": self.owner,
|
||||
"created": self.created,
|
||||
"when_to_use": self.when_to_use,
|
||||
"procedure": list(self.procedure),
|
||||
"pitfalls": list(self.pitfalls),
|
||||
"verification": list(self.verification),
|
||||
"body_extra": self.body_extra,
|
||||
"uses": int(self.uses or 0),
|
||||
"last_used": self.last_used,
|
||||
"path": self.path,
|
||||
}
|
||||
# Back-compat aliases for the old API/UI
|
||||
d["title"] = self.description or self.name.replace("-", " ").title()
|
||||
d["problem"] = self.when_to_use
|
||||
d["solution"] = (self.procedure[0] if self.procedure else "") if not self.body_extra else self.body_extra
|
||||
d["steps"] = list(self.procedure)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_markdown(cls, text: str, *, path: Optional[str] = None) -> "Skill":
|
||||
fm, body = parse_frontmatter(text)
|
||||
sections = parse_body(body)
|
||||
raw_name = fm.get("name")
|
||||
name = slugify(raw_name if raw_name not in (None, "") else fm.get("description", ""), fallback="skill")
|
||||
return cls(
|
||||
name=name,
|
||||
description=str(fm.get("description", "") or ""),
|
||||
version=str(fm.get("version", "1.0.0") or "1.0.0"),
|
||||
category=str(fm.get("category", "general") or "general"),
|
||||
tags=_as_list(fm.get("tags")),
|
||||
platforms=_as_list(fm.get("platforms")),
|
||||
requires_toolsets=_as_list(fm.get("requires_toolsets")),
|
||||
fallback_for_toolsets=_as_list(fm.get("fallback_for_toolsets")),
|
||||
status=str(fm.get("status", "draft") or "draft"),
|
||||
confidence=_as_float(fm.get("confidence", 0.8), 0.8),
|
||||
source=str(fm.get("source", "learned") or "learned"),
|
||||
teacher_model=str(fm.get("teacher_model")) if fm.get("teacher_model") else None,
|
||||
owner=str(fm.get("owner")) if fm.get("owner") else None,
|
||||
created=str(fm.get("created") or _now_iso()),
|
||||
when_to_use=sections["when_to_use"],
|
||||
procedure=list(sections["procedure"]),
|
||||
pitfalls=list(sections["pitfalls"]),
|
||||
verification=list(sections["verification"]),
|
||||
body_extra=sections["body_extra"],
|
||||
path=path,
|
||||
)
|
||||
|
||||
def to_markdown(self) -> str:
|
||||
fm = emit_frontmatter(self.to_frontmatter())
|
||||
body = emit_body({
|
||||
"when_to_use": self.when_to_use,
|
||||
"procedure": self.procedure,
|
||||
"pitfalls": self.pitfalls,
|
||||
"verification": self.verification,
|
||||
"body_extra": self.body_extra,
|
||||
})
|
||||
return f"---\n{fm}\n---\n\n{body}"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
@@ -0,0 +1,610 @@
|
||||
# services/memory/skills.py
|
||||
"""Skills storage layer.
|
||||
|
||||
Skills live on disk as `data/skills/<category>/<name>/SKILL.md` files with
|
||||
YAML frontmatter and a structured markdown body (When to Use / Procedure /
|
||||
Pitfalls / Verification). See `skill_format.py` for the format.
|
||||
|
||||
Usage counters (`uses`, `last_used`) live in a sidecar
|
||||
`data/skills/_usage.json` keyed by skill name so the SKILL.md content
|
||||
doesn't churn on every retrieval.
|
||||
|
||||
Ownership: skills declare `owner: <username>` in frontmatter. Single-user
|
||||
deployments can leave that blank.
|
||||
|
||||
This module also retains a JSON fallback for any legacy `data/skills.json`
|
||||
entries — they're surfaced as read-only `Skill` objects so old data still
|
||||
loads while a user migrates them to disk.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from .skill_format import Skill, slugify
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token / similarity helpers (kept for the relevance fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _tokenize(text: str) -> set:
|
||||
return {w.strip('.,!?";:()[]') for w in (text or "").lower().split() if len(w) > 1}
|
||||
|
||||
|
||||
def _jaccard(a: set, b: set) -> float:
|
||||
if not a or not b:
|
||||
return 0.0
|
||||
return len(a & b) / len(a | b)
|
||||
|
||||
|
||||
def _to_float(x, default: float = 0.0) -> float:
|
||||
"""Coerce a possibly hand-edited frontmatter value to float without
|
||||
raising — a blank or non-numeric `confidence:` in a SKILL.md must not
|
||||
blow up retrieval or eviction."""
|
||||
try:
|
||||
return float(x)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillsManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Read/write SKILL.md files under <data_dir>/skills/."""
|
||||
|
||||
def __init__(self, data_dir: str):
|
||||
self.data_dir = data_dir
|
||||
self.skills_root = os.path.join(data_dir, "skills")
|
||||
self.usage_file = os.path.join(self.skills_root, "_usage.json")
|
||||
self.legacy_file = os.path.join(data_dir, "skills.json") # back-compat
|
||||
os.makedirs(self.skills_root, exist_ok=True)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Path helpers
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def _skill_dir(self, category: str, name: str) -> str:
|
||||
cat = slugify(category or "general", fallback="general")
|
||||
nm = slugify(name, fallback="skill")
|
||||
return os.path.join(self.skills_root, cat, nm)
|
||||
|
||||
def _skill_file(self, category: str, name: str) -> str:
|
||||
return os.path.join(self._skill_dir(category, name), "SKILL.md")
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Usage sidecar
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def _load_usage(self) -> Dict[str, Dict]:
|
||||
if not os.path.exists(self.usage_file):
|
||||
return {}
|
||||
try:
|
||||
with open(self.usage_file) as f:
|
||||
d = json.load(f)
|
||||
return d if isinstance(d, dict) else {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def _save_usage(self, usage: Dict[str, Dict]) -> None:
|
||||
try:
|
||||
from core.atomic_io import atomic_write_json
|
||||
atomic_write_json(self.usage_file, usage, indent=2)
|
||||
except Exception:
|
||||
tmp = self.usage_file + ".tmp"
|
||||
with open(tmp, "w") as f:
|
||||
json.dump(usage, f, indent=2)
|
||||
os.replace(tmp, self.usage_file)
|
||||
|
||||
def set_audit(self, name: str, verdict: str, by_teacher: bool = False,
|
||||
worker_model: str = "", teacher_model: str = "") -> None:
|
||||
"""Record the last test/audit result for a skill in the usage sidecar
|
||||
(so it surfaces in load() without touching SKILL.md). Drives the
|
||||
'verified' check + teacher mark on the card."""
|
||||
import time as _t
|
||||
usage = self._load_usage()
|
||||
e = usage.setdefault(name, {"uses": 0, "last_used": None})
|
||||
e["audit_verdict"] = verdict
|
||||
e["audit_by_teacher"] = bool(by_teacher)
|
||||
if worker_model:
|
||||
e["audit_worker_model"] = worker_model
|
||||
if teacher_model:
|
||||
e["audit_teacher_model"] = teacher_model
|
||||
e["audited_at"] = _t.time()
|
||||
self._save_usage(usage)
|
||||
|
||||
def set_necessity(self, name: str, necessary: bool,
|
||||
redundant_with=None, reason: str = "") -> None:
|
||||
"""Record the advisory 'is this skill necessary?' judgment in the usage
|
||||
sidecar. Surfaced on the card as a flag; never acts on the skill."""
|
||||
usage = self._load_usage()
|
||||
e = usage.setdefault(name, {"uses": 0, "last_used": None})
|
||||
e["necessity"] = {
|
||||
"necessary": bool(necessary),
|
||||
"redundant_with": list(redundant_with or []),
|
||||
"reason": str(reason or ""),
|
||||
}
|
||||
self._save_usage(usage)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Disk scan
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def _iter_skill_files(self) -> Iterable[str]:
|
||||
if not os.path.isdir(self.skills_root):
|
||||
return
|
||||
for root, _dirs, files in os.walk(self.skills_root, followlinks=False):
|
||||
if "SKILL.md" in files:
|
||||
yield os.path.join(root, "SKILL.md")
|
||||
|
||||
def _read_skill(self, path: str) -> Optional[Skill]:
|
||||
try:
|
||||
with open(path) as f:
|
||||
text = f.read()
|
||||
return Skill.from_markdown(text, path=path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse {path}: {e}")
|
||||
return None
|
||||
|
||||
def _write_skill(self, sk: Skill) -> str:
|
||||
path = self._skill_file(sk.category or "general", sk.name)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from core.atomic_io import atomic_write_text
|
||||
atomic_write_text(path, sk.to_markdown())
|
||||
sk.path = path
|
||||
return path
|
||||
|
||||
def backfill_owner(self, primary_owner: str, valid_owners: Optional[set[str]] = None) -> int:
|
||||
"""Assign legacy/unclaimed skill files to the primary owner.
|
||||
|
||||
Skills are disk-backed, so the DB legacy-owner migration cannot fix
|
||||
them. If strict owner filtering is enabled and SKILL.md files have no
|
||||
owner or an owner from a deleted/test account, the UI appears empty even
|
||||
though files still exist. This mirrors the DB legacy-owner sweep.
|
||||
"""
|
||||
primary_owner = (primary_owner or "").strip()
|
||||
if not primary_owner:
|
||||
return 0
|
||||
valid_owners = set(valid_owners or [])
|
||||
changed = 0
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk:
|
||||
continue
|
||||
owner = (sk.owner or "").strip()
|
||||
if owner == primary_owner:
|
||||
continue
|
||||
if owner and owner in valid_owners:
|
||||
continue
|
||||
sk.owner = primary_owner
|
||||
try:
|
||||
self._write_skill(sk)
|
||||
changed += 1
|
||||
except Exception as e:
|
||||
logger.warning("Failed to backfill owner for skill %s: %s", sk.name, e)
|
||||
return changed
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Public API — keeps the old method names so callers don't break
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def load_all(self) -> List[Dict]:
|
||||
"""Return every skill as a plain dict, plus any legacy JSON entries."""
|
||||
usage = self._load_usage()
|
||||
out: List[Dict] = []
|
||||
seen_names: set[str] = set()
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk:
|
||||
continue
|
||||
d = sk.to_dict()
|
||||
u = usage.get(sk.name) or {}
|
||||
d["uses"] = int(u.get("uses", 0))
|
||||
d["last_used"] = u.get("last_used")
|
||||
d["audit_verdict"] = u.get("audit_verdict")
|
||||
d["audit_by_teacher"] = bool(u.get("audit_by_teacher"))
|
||||
d["audit_worker_model"] = u.get("audit_worker_model")
|
||||
d["audit_teacher_model"] = u.get("audit_teacher_model")
|
||||
d["audited_at"] = u.get("audited_at")
|
||||
d["necessity"] = u.get("necessity")
|
||||
out.append(d)
|
||||
seen_names.add(sk.name)
|
||||
# Legacy JSON entries — surfaced as draft, not editable from new flow
|
||||
if os.path.exists(self.legacy_file):
|
||||
try:
|
||||
with open(self.legacy_file) as f:
|
||||
legacy = json.load(f)
|
||||
if isinstance(legacy, list):
|
||||
for row in legacy:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
name = slugify(row.get("title") or row.get("id") or "skill")
|
||||
if name in seen_names:
|
||||
continue
|
||||
out.append({
|
||||
"id": row.get("id") or name,
|
||||
"name": name,
|
||||
"description": row.get("title", ""),
|
||||
"version": "0.0.1",
|
||||
"category": "legacy",
|
||||
"tags": row.get("tags") or [],
|
||||
"status": row.get("status") or "draft",
|
||||
"confidence": row.get("confidence", 0.5),
|
||||
"source": row.get("source", "imported"),
|
||||
"owner": row.get("owner"),
|
||||
"when_to_use": row.get("problem", ""),
|
||||
"procedure": row.get("steps") or [],
|
||||
"pitfalls": [],
|
||||
"verification": [],
|
||||
"body_extra": row.get("solution", ""),
|
||||
"title": row.get("title", ""),
|
||||
"problem": row.get("problem", ""),
|
||||
"solution": row.get("solution", ""),
|
||||
"steps": row.get("steps") or [],
|
||||
"uses": row.get("uses", 0),
|
||||
"last_used": row.get("last_used"),
|
||||
"_legacy": True,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
def load(self, owner: Optional[str] = None) -> List[Dict]:
|
||||
entries = self.load_all()
|
||||
if owner is None:
|
||||
return entries
|
||||
# SECURITY: strict ownership filter. The previous predicate also
|
||||
# included skills with NO owner field (`not s.get("owner")`), which
|
||||
# leaked legacy / un-stamped skills to every authenticated user.
|
||||
# Hide them now; the owner needs to be backfilled on disk if those
|
||||
# skills should be visible to a specific user.
|
||||
return [s for s in entries if s.get("owner") == owner]
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# CRUD — disk-backed
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def add_skill(
|
||||
self,
|
||||
title: str = "",
|
||||
problem: str = "",
|
||||
solution: str = "",
|
||||
steps: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
source: str = "learned",
|
||||
teacher_model: Optional[str] = None,
|
||||
confidence: float = 0.8,
|
||||
session_id: Optional[str] = None,
|
||||
owner: Optional[str] = None,
|
||||
# New-schema fields (optional; fall back to old shape if absent)
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
category: str = "general",
|
||||
when_to_use: Optional[str] = None,
|
||||
procedure: Optional[List[str]] = None,
|
||||
pitfalls: Optional[List[str]] = None,
|
||||
verification: Optional[List[str]] = None,
|
||||
platforms: Optional[List[str]] = None,
|
||||
requires_toolsets: Optional[List[str]] = None,
|
||||
fallback_for_toolsets: Optional[List[str]] = None,
|
||||
status: str = "draft",
|
||||
version: str = "1.0.0",
|
||||
) -> Dict:
|
||||
# Normalize name
|
||||
nm = slugify(name or title or description or "skill")
|
||||
|
||||
# Free dedup-at-creation (always, no API): for LLM-authored skills,
|
||||
# skip if a near-identical skill already exists (Jaccard over
|
||||
# name+description+when_to_use+procedure). User-authored skills are
|
||||
# never auto-skipped — a human asked for it. The every-X AI audit
|
||||
# handles the fuzzier near-duplicates this cheap check won't catch.
|
||||
_all = self.load_all()
|
||||
if source != "user":
|
||||
cand = _tokenize(" ".join([
|
||||
nm, (description or title or ""),
|
||||
(when_to_use if when_to_use is not None else (problem or "")),
|
||||
" ".join(procedure if procedure is not None else (steps or [])),
|
||||
]))
|
||||
if cand:
|
||||
for s in _all:
|
||||
ex = _tokenize(" ".join([
|
||||
s.get("name", ""), s.get("description", ""),
|
||||
s.get("when_to_use", ""),
|
||||
" ".join(s.get("procedure", []) or []),
|
||||
]))
|
||||
if _jaccard(cand, ex) >= 0.82:
|
||||
# Near-identical — don't grow the library; bump the
|
||||
# existing skill's usage and return it so the caller
|
||||
# knows it already exists.
|
||||
try:
|
||||
self.record_use(s["name"])
|
||||
except Exception:
|
||||
pass
|
||||
return {**s, "_deduped": True, "_duplicate_of": s.get("name")}
|
||||
|
||||
# Avoid clobbering an existing skill with the same name
|
||||
existing = {s["name"] for s in _all}
|
||||
base = nm
|
||||
i = 2
|
||||
while nm in existing:
|
||||
nm = f"{base}-{i}"
|
||||
i += 1
|
||||
|
||||
sk = Skill(
|
||||
name=nm,
|
||||
description=(description or title or "").strip(),
|
||||
version=version,
|
||||
category=category or "general",
|
||||
tags=list(tags or []),
|
||||
platforms=list(platforms or []),
|
||||
requires_toolsets=list(requires_toolsets or []),
|
||||
fallback_for_toolsets=list(fallback_for_toolsets or []),
|
||||
status=status or "draft",
|
||||
confidence=float(confidence),
|
||||
source=source,
|
||||
teacher_model=teacher_model,
|
||||
owner=owner,
|
||||
when_to_use=(when_to_use if when_to_use is not None else (problem or "")),
|
||||
procedure=list(procedure if procedure is not None else (steps or [])),
|
||||
pitfalls=list(pitfalls or []),
|
||||
verification=list(verification or []),
|
||||
body_extra=(solution if solution and not procedure else ""),
|
||||
)
|
||||
self._write_skill(sk)
|
||||
|
||||
return sk.to_dict()
|
||||
|
||||
def update_skill(self, skill_id: str, updates: Dict) -> bool:
|
||||
"""`skill_id` is the slug name. Allows updating any field plus
|
||||
renames if `name` changes (file is moved on disk)."""
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk or sk.name != skill_id:
|
||||
continue
|
||||
old_dir = os.path.dirname(path)
|
||||
|
||||
# Apply updates in a Skill-shape friendly way
|
||||
scalar_keys = (
|
||||
"description", "version", "category", "status", "confidence",
|
||||
"source", "teacher_model", "owner", "when_to_use",
|
||||
"body_extra",
|
||||
)
|
||||
for k in scalar_keys:
|
||||
if k in updates:
|
||||
setattr(sk, k, updates[k])
|
||||
list_keys = ("tags", "procedure", "pitfalls", "verification",
|
||||
"platforms", "requires_toolsets", "fallback_for_toolsets")
|
||||
for k in list_keys:
|
||||
if k in updates:
|
||||
setattr(sk, k, list(updates[k] or []))
|
||||
|
||||
# Old-schema field aliases
|
||||
if "title" in updates and "description" not in updates:
|
||||
sk.description = updates["title"]
|
||||
if "problem" in updates and "when_to_use" not in updates:
|
||||
sk.when_to_use = updates["problem"]
|
||||
if "solution" in updates and "body_extra" not in updates and not sk.procedure:
|
||||
sk.body_extra = updates["solution"]
|
||||
if "steps" in updates and "procedure" not in updates:
|
||||
sk.procedure = list(updates["steps"] or [])
|
||||
|
||||
# Rename
|
||||
new_name = slugify(updates.get("name") or sk.name)
|
||||
if new_name != sk.name:
|
||||
sk.name = new_name
|
||||
|
||||
# Write to potentially new path
|
||||
new_path = self._skill_file(sk.category, sk.name)
|
||||
if new_path != path:
|
||||
# Move the whole skill directory if rename or recategorize
|
||||
new_dir = os.path.dirname(new_path)
|
||||
if os.path.isdir(new_dir):
|
||||
logger.warning(f"Skill rename target exists: {new_dir}")
|
||||
return False
|
||||
os.makedirs(os.path.dirname(new_dir), exist_ok=True)
|
||||
os.rename(old_dir, new_dir)
|
||||
# Also rename usage key
|
||||
usage = self._load_usage()
|
||||
if skill_id in usage:
|
||||
usage[sk.name] = usage.pop(skill_id)
|
||||
self._save_usage(usage)
|
||||
self._write_skill(sk)
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_skill(self, skill_id: str) -> bool:
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk or sk.name != skill_id:
|
||||
continue
|
||||
skill_dir = os.path.dirname(path)
|
||||
try:
|
||||
# Remove the whole skill dir
|
||||
for root, dirs, files in os.walk(skill_dir, topdown=False):
|
||||
for f in files:
|
||||
os.remove(os.path.join(root, f))
|
||||
for d in dirs:
|
||||
os.rmdir(os.path.join(root, d))
|
||||
os.rmdir(skill_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove skill dir {skill_dir}: {e}")
|
||||
return False
|
||||
usage = self._load_usage()
|
||||
if skill_id in usage:
|
||||
del usage[skill_id]
|
||||
self._save_usage(usage)
|
||||
return True
|
||||
return False
|
||||
|
||||
def record_use(self, skill_id: str) -> None:
|
||||
usage = self._load_usage()
|
||||
entry = usage.setdefault(skill_id, {"uses": 0, "last_used": None})
|
||||
entry["uses"] = int(entry.get("uses", 0)) + 1
|
||||
entry["last_used"] = int(time.time())
|
||||
self._save_usage(usage)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Reading a single skill (used by the skill_view tool)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def read_skill_md(self, name: str) -> Optional[str]:
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if sk and sk.name == name:
|
||||
try:
|
||||
with open(path) as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def read_skill_reference(self, name: str, ref_path: str) -> Optional[str]:
|
||||
"""Read a sub-file under the skill's directory (references/, etc).
|
||||
Refuses path traversal."""
|
||||
for path in self._iter_skill_files():
|
||||
sk = self._read_skill(path)
|
||||
if not sk or sk.name != name:
|
||||
continue
|
||||
base = os.path.realpath(os.path.dirname(path))
|
||||
target = os.path.realpath(os.path.join(base, ref_path))
|
||||
if os.path.commonpath([base, target]) != base or target == os.path.dirname(path):
|
||||
return None
|
||||
if not os.path.isfile(target):
|
||||
return None
|
||||
try:
|
||||
with open(target) as f:
|
||||
return f.read()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Index — the lightweight summary injected into the system prompt
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def index_for(
|
||||
self,
|
||||
owner: Optional[str] = None,
|
||||
*,
|
||||
active_toolsets: Optional[List[str]] = None,
|
||||
platform: Optional[str] = None,
|
||||
) -> List[Dict]:
|
||||
"""Return the `[{name, description, category, status}]` list the
|
||||
agent sees in its system prompt.
|
||||
|
||||
Includes:
|
||||
- All published skills.
|
||||
- Drafts written by the teacher-escalation loop
|
||||
(`source == "teacher-escalation"`). The whole point of
|
||||
the teacher loop is for the student to find the new
|
||||
procedure on the very next turn — waiting for a manual
|
||||
publish click defeats the loop.
|
||||
|
||||
Excludes user-created drafts (status=draft, source != teacher-
|
||||
escalation) — those are work-in-progress and pollute the
|
||||
prompt with half-finished procedures.
|
||||
"""
|
||||
active_toolsets = active_toolsets or []
|
||||
out = []
|
||||
for s in self.load(owner=owner):
|
||||
status = s.get("status")
|
||||
# Published + None (pre-status legacy) always included.
|
||||
# Drafts only if the teacher wrote them.
|
||||
if status not in ("published", None):
|
||||
if status == "draft" and s.get("source") == "teacher-escalation":
|
||||
pass # let it through
|
||||
else:
|
||||
continue
|
||||
# Platform gating
|
||||
if platform and s.get("platforms") and platform not in s["platforms"]:
|
||||
continue
|
||||
# requires_toolsets: hide unless every required toolset is active
|
||||
req = s.get("requires_toolsets") or []
|
||||
if req and not all(t in active_toolsets for t in req):
|
||||
continue
|
||||
# fallback_for_toolsets: hide when any of those toolsets is active
|
||||
fb = s.get("fallback_for_toolsets") or []
|
||||
if fb and any(t in active_toolsets for t in fb):
|
||||
continue
|
||||
out.append({
|
||||
"name": s["name"],
|
||||
"description": s.get("description") or s.get("title", ""),
|
||||
"category": s.get("category", "general"),
|
||||
"status": status or "published",
|
||||
})
|
||||
out.sort(key=lambda x: (x["category"], x["name"]))
|
||||
return out
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Relevance search (kept for the existing /api/skills/search endpoint
|
||||
# and the `manage_skills` action="search"). Now operates on the new
|
||||
# field set.
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def get_relevant_skills(
|
||||
self,
|
||||
query: str,
|
||||
skills: Optional[List[Dict]] = None,
|
||||
threshold: float = 0.3,
|
||||
max_items: int = 5,
|
||||
min_confidence: float = 0.0,
|
||||
) -> List[Dict]:
|
||||
if skills is None:
|
||||
skills = self.load_all()
|
||||
if not skills or not query.strip():
|
||||
return []
|
||||
# Consider published AND draft skills for relevance retrieval.
|
||||
# The teacher-escalation loop writes new skills as drafts; the
|
||||
# whole point is for the student to find them on the next try
|
||||
# without a manual publish click. The UI flags teacher-written
|
||||
# entries with a 🎓 badge so users can demote / delete bad
|
||||
# ones when they spot them.
|
||||
skills = [s for s in skills if s.get("status") in ("published", "draft")]
|
||||
# Confidence gate (used by prompt-injection, NOT by search): a DRAFT
|
||||
# skill must clear the bar to be injected. Published skills are already
|
||||
# vetted, so they always qualify. Missing confidence = treat as 1.0
|
||||
# (legacy skills shouldn't silently vanish). 0 disables the gate.
|
||||
if min_confidence > 0:
|
||||
def _passes(s):
|
||||
if s.get("status") == "published":
|
||||
return True
|
||||
c = s.get("confidence")
|
||||
if c is None:
|
||||
return True # unset → don't filter (legacy)
|
||||
return _to_float(c, 1.0) >= min_confidence # unparseable → pass
|
||||
skills = [s for s in skills if _passes(s)]
|
||||
if not skills:
|
||||
return []
|
||||
|
||||
query_tokens = _tokenize(query)
|
||||
scored = []
|
||||
for sk in skills:
|
||||
text = " ".join([
|
||||
sk.get("name", ""),
|
||||
sk.get("description", ""),
|
||||
sk.get("when_to_use", ""),
|
||||
" ".join(sk.get("tags", []) or []),
|
||||
" ".join(sk.get("procedure", []) or []),
|
||||
])
|
||||
score = _jaccard(query_tokens, _tokenize(text))
|
||||
for tag in sk.get("tags", []) or []:
|
||||
if tag and tag in query.lower():
|
||||
score = max(score, 0.3) * 1.3
|
||||
if query.lower() in (sk.get("description") or "").lower():
|
||||
score = max(score, 0.6)
|
||||
score *= 1.0 + _to_float(sk.get("confidence"), 0.5) * 0.1
|
||||
if sk.get("uses", 0) > 0:
|
||||
score *= 1.05
|
||||
if score >= threshold:
|
||||
scored.append((score, sk))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [sk for _, sk in scored[:max_items]]
|
||||
@@ -0,0 +1,12 @@
|
||||
# services/research/__init__.py
|
||||
"""Research service — deep research with LLM-in-the-loop."""
|
||||
|
||||
from .service import ResearchService, ResearchResult, ResearchSource
|
||||
from .research_handler import ResearchHandler
|
||||
|
||||
__all__ = [
|
||||
"ResearchService",
|
||||
"ResearchResult",
|
||||
"ResearchSource",
|
||||
"ResearchHandler",
|
||||
]
|
||||
@@ -0,0 +1,463 @@
|
||||
# src/research_handler.py
|
||||
"""Handler for research service integration with expandable UI support.
|
||||
|
||||
Uses the IterResearch-style DeepResearcher (LLM-in-the-loop) as the primary
|
||||
engine, falling back to the legacy ResearchOrchestrator or basic web search
|
||||
if needed.
|
||||
|
||||
Includes a task registry so research survives page refreshes and can be cancelled.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RESEARCH_DATA_DIR = Path("data/deep_research")
|
||||
|
||||
|
||||
class ResearchHandler:
|
||||
"""Handles research service operations with iterative deep research."""
|
||||
|
||||
def __init__(self):
|
||||
self._legacy_engine = None
|
||||
self._active_tasks: Dict[str, dict] = {}
|
||||
self._initialize_legacy_engine()
|
||||
RESEARCH_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _initialize_legacy_engine(self):
|
||||
"""Initialize the legacy research engine as a fallback."""
|
||||
try:
|
||||
from research_engine import ResearchOrchestrator, Config
|
||||
config = Config(max_searches=12, max_content_per_page=15000)
|
||||
self._legacy_engine = ResearchOrchestrator(config)
|
||||
logger.info("Legacy ResearchOrchestrator initialized (fallback)")
|
||||
except ImportError:
|
||||
logger.info("Legacy research_engine.py not found — DeepResearcher only")
|
||||
self._legacy_engine = None
|
||||
except Exception as e:
|
||||
logger.warning(f"Legacy research engine init failed: {e}")
|
||||
self._legacy_engine = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Task registry — background research with persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start_research(
|
||||
self,
|
||||
session_id: str,
|
||||
query: str,
|
||||
llm_endpoint: str,
|
||||
llm_model: str,
|
||||
max_time: int = 300,
|
||||
llm_headers: dict = None,
|
||||
) -> dict:
|
||||
"""Start research as a background task. Returns task info dict."""
|
||||
# Cancel any existing research for this session
|
||||
if session_id in self._active_tasks:
|
||||
existing = self._active_tasks[session_id]
|
||||
if existing.get("status") == "running":
|
||||
self.cancel_research(session_id)
|
||||
|
||||
entry = {
|
||||
"task": None,
|
||||
"researcher": None,
|
||||
"query": query,
|
||||
"status": "running",
|
||||
"progress": {},
|
||||
"result": None,
|
||||
"started_at": time.time(),
|
||||
}
|
||||
self._active_tasks[session_id] = entry
|
||||
|
||||
def on_progress(event):
|
||||
entry["progress"] = event
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
result = await self.call_research_service(
|
||||
query, llm_endpoint, llm_model,
|
||||
max_time=max_time,
|
||||
progress_callback=on_progress,
|
||||
_task_entry=entry,
|
||||
llm_headers=llm_headers,
|
||||
)
|
||||
entry["result"] = result
|
||||
entry["status"] = "done"
|
||||
self._save_result(session_id, entry)
|
||||
except asyncio.CancelledError:
|
||||
entry["status"] = "cancelled"
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Background research failed: {e}", exc_info=True)
|
||||
entry["result"] = str(e)
|
||||
entry["status"] = "error"
|
||||
|
||||
task = asyncio.create_task(_run())
|
||||
entry["task"] = task
|
||||
return {"session_id": session_id, "status": "running", "query": query}
|
||||
|
||||
def get_status(self, session_id: str) -> Optional[dict]:
|
||||
"""Get current research status for a session."""
|
||||
if session_id in self._active_tasks:
|
||||
entry = self._active_tasks[session_id]
|
||||
return {
|
||||
"status": entry["status"],
|
||||
"progress": entry["progress"],
|
||||
"query": entry["query"],
|
||||
"started_at": entry["started_at"],
|
||||
}
|
||||
# Check disk for completed research
|
||||
path = RESEARCH_DATA_DIR / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
return {
|
||||
"status": data.get("status", "done"),
|
||||
"progress": {},
|
||||
"query": data.get("query", ""),
|
||||
"started_at": data.get("started_at", 0),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def cancel_research(self, session_id: str) -> bool:
|
||||
"""Cancel running research for a session."""
|
||||
if session_id not in self._active_tasks:
|
||||
return False
|
||||
entry = self._active_tasks[session_id]
|
||||
if entry["status"] != "running":
|
||||
return False
|
||||
researcher = entry.get("researcher")
|
||||
if researcher:
|
||||
researcher.cancel()
|
||||
task = entry.get("task")
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
entry["status"] = "cancelled"
|
||||
return True
|
||||
|
||||
def get_result(self, session_id: str) -> Optional[str]:
|
||||
"""Get the completed research result."""
|
||||
if session_id in self._active_tasks:
|
||||
entry = self._active_tasks[session_id]
|
||||
if entry["status"] in ("done", "error", "cancelled"):
|
||||
return entry.get("result")
|
||||
# Check disk
|
||||
path = RESEARCH_DATA_DIR / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
return data.get("result")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_sources(self, session_id: str) -> Optional[list]:
|
||||
"""Get deduplicated source list from research findings."""
|
||||
# Check in-memory first
|
||||
if session_id in self._active_tasks:
|
||||
entry = self._active_tasks[session_id]
|
||||
if entry.get("sources"):
|
||||
return entry["sources"]
|
||||
researcher = entry.get("researcher")
|
||||
if researcher and researcher.findings:
|
||||
return self._extract_sources(researcher.findings)
|
||||
# Check disk
|
||||
path = RESEARCH_DATA_DIR / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
return data.get("sources")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_sources(findings: list) -> list:
|
||||
"""Extract deduplicated [{url, title}] from findings."""
|
||||
seen = set()
|
||||
sources = []
|
||||
for f in findings:
|
||||
url = f.get("url", "")
|
||||
title = f.get("title", "") or url
|
||||
if url and url not in seen:
|
||||
seen.add(url)
|
||||
sources.append({"url": url, "title": title})
|
||||
return sources
|
||||
|
||||
def clear_result(self, session_id: str):
|
||||
"""Remove persisted result after it's been consumed."""
|
||||
self._active_tasks.pop(session_id, None)
|
||||
path = RESEARCH_DATA_DIR / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _save_result(self, session_id: str, entry: dict):
|
||||
"""Persist completed research result to disk."""
|
||||
try:
|
||||
# Extract and cache sources
|
||||
sources = []
|
||||
researcher = entry.get("researcher")
|
||||
if researcher and researcher.findings:
|
||||
sources = self._extract_sources(researcher.findings)
|
||||
entry["sources"] = sources
|
||||
|
||||
path = RESEARCH_DATA_DIR / f"{session_id}.json"
|
||||
data = {
|
||||
"query": entry["query"],
|
||||
"status": entry["status"],
|
||||
"result": entry["result"],
|
||||
"sources": sources,
|
||||
"started_at": entry["started_at"],
|
||||
"completed_at": time.time(),
|
||||
}
|
||||
path.write_text(json.dumps(data))
|
||||
logger.info(f"Research result saved to {path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save research result: {e}")
|
||||
|
||||
async def call_research_service(
|
||||
self,
|
||||
query: str,
|
||||
llm_endpoint: str,
|
||||
llm_model: str,
|
||||
max_time: int = 300,
|
||||
progress_callback=None,
|
||||
_task_entry: dict = None,
|
||||
llm_headers: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Run iterative deep research using the LLM-in-the-loop DeepResearcher.
|
||||
|
||||
Args:
|
||||
query: Research question
|
||||
llm_endpoint: LLM endpoint URL for chat completions
|
||||
llm_model: Model name/ID
|
||||
max_time: Maximum research time in seconds (default 5 minutes)
|
||||
_task_entry: Internal - registry entry to store researcher ref
|
||||
|
||||
Returns:
|
||||
Formatted research report with expandable section and summary
|
||||
"""
|
||||
logger.info("Starting IterResearch Deep Research")
|
||||
logger.info(f"Query: {query}")
|
||||
logger.info(f"LLM: {llm_endpoint} / {llm_model}")
|
||||
logger.info(f"Max time: {max_time}s")
|
||||
|
||||
try:
|
||||
from src.deep_research import DeepResearcher
|
||||
from src.settings import get_setting
|
||||
|
||||
researcher = DeepResearcher(
|
||||
llm_endpoint=llm_endpoint,
|
||||
llm_model=llm_model,
|
||||
llm_headers=llm_headers,
|
||||
max_rounds=8,
|
||||
max_time=max_time,
|
||||
max_report_tokens=int(get_setting("research_max_tokens", 8192)),
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if _task_entry is not None:
|
||||
_task_entry["researcher"] = researcher
|
||||
|
||||
start_time = time.time()
|
||||
report = await researcher.research(query)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
stats = researcher.get_stats()
|
||||
logger.info("IterResearch completed successfully")
|
||||
for key, value in stats.items():
|
||||
logger.info(f" {key}: {value}")
|
||||
|
||||
return self._format_research_report(
|
||||
query, report, stats, elapsed,
|
||||
findings=researcher.findings,
|
||||
evolving_report=researcher.evolving_report,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepResearcher failed: {e}", exc_info=True)
|
||||
return await self._fallback_research(query, llm_endpoint, llm_model, max_time, str(e))
|
||||
|
||||
async def _fallback_research(
|
||||
self, query: str, llm_endpoint: str, llm_model: str,
|
||||
max_time: int, primary_error: str,
|
||||
) -> str:
|
||||
"""Fall back to legacy engine, then to basic web search."""
|
||||
# Try legacy orchestrator
|
||||
if self._legacy_engine:
|
||||
try:
|
||||
import asyncio
|
||||
logger.info("Falling back to legacy ResearchOrchestrator...")
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None, self._legacy_engine.start_research, query, max_time
|
||||
)
|
||||
stats = self._get_legacy_stats()
|
||||
elapsed = float(stats.get("Duration", "0").rstrip("s") or 0)
|
||||
return self._format_research_report(query, result, stats, elapsed)
|
||||
except Exception as e:
|
||||
logger.error(f"Legacy engine also failed: {e}")
|
||||
|
||||
# Fall back to basic web search
|
||||
return self._handle_research_failure(query, primary_error)
|
||||
|
||||
def _get_legacy_stats(self) -> dict:
|
||||
"""Get statistics from the legacy research engine."""
|
||||
if not self._legacy_engine:
|
||||
return {}
|
||||
try:
|
||||
tracker = self._legacy_engine.progress_tracker
|
||||
return {
|
||||
"Findings": len(self._legacy_engine.findings),
|
||||
"Sources": len(self._legacy_engine.source_reports),
|
||||
"Searches": tracker.counters['searches_executed'],
|
||||
"URLs": tracker.counters['urls_processed'],
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
def _format_research_report(
|
||||
self, query: str, full_report: str, stats: dict, elapsed: float,
|
||||
findings: list = None, evolving_report: str = None,
|
||||
) -> str:
|
||||
"""Format research report with sources list and expandable raw findings."""
|
||||
summary_lines = [
|
||||
f"**Duration:** {elapsed:.1f}s",
|
||||
f"**Rounds:** {stats.get('Rounds', stats.get('Findings', '?'))}",
|
||||
f"**Queries:** {stats.get('Queries', stats.get('Searches', '?'))}",
|
||||
f"**URLs Analyzed:** {stats.get('URLs', '?')}",
|
||||
]
|
||||
summary_text = " | ".join(summary_lines)
|
||||
|
||||
# Build sources list with clickable links
|
||||
sources_section = ""
|
||||
if findings:
|
||||
seen_urls = set()
|
||||
source_lines = []
|
||||
for f in findings:
|
||||
url = f.get("url", "")
|
||||
title = f.get("title", "") or url
|
||||
if url and url not in seen_urls:
|
||||
seen_urls.add(url)
|
||||
source_lines.append(f"- [{title}]({url})")
|
||||
if source_lines:
|
||||
sources_section = "\n### Sources\n\n" + "\n".join(source_lines) + "\n"
|
||||
|
||||
# Build raw findings section (individual extractions per source)
|
||||
raw_findings_section = ""
|
||||
if findings:
|
||||
parts = []
|
||||
for i, f in enumerate(findings, 1):
|
||||
url = f.get("url", "")
|
||||
title = f.get("title", "") or "Untitled"
|
||||
summary = f.get("summary", "")
|
||||
evidence = f.get("evidence", "")
|
||||
content = summary if summary else (evidence[:2000] if evidence else "(no content)")
|
||||
parts.append(f"**{i}. [{title}]({url})**\n\n{content}")
|
||||
raw_findings_section = "\n\n".join(parts)
|
||||
|
||||
# Build expandable collected info section
|
||||
collected_section = ""
|
||||
if evolving_report or raw_findings_section:
|
||||
collected_section = "\n<details>\n<summary><strong>Raw collected findings ({} sources)</strong></summary>\n\n".format(
|
||||
len(findings) if findings else 0
|
||||
)
|
||||
if raw_findings_section:
|
||||
collected_section += raw_findings_section + "\n"
|
||||
collected_section += "\n</details>\n"
|
||||
|
||||
formatted = f"""---
|
||||
|
||||
## Research Summary
|
||||
|
||||
{summary_text}
|
||||
|
||||
---
|
||||
|
||||
{full_report}
|
||||
|
||||
{sources_section}
|
||||
{collected_section}
|
||||
---
|
||||
|
||||
**The AI has analyzed all research findings above. Ask me anything about: "{query}"**
|
||||
"""
|
||||
return formatted
|
||||
|
||||
def _format_error_response(self, error_msg: str, query: str) -> str:
|
||||
"""Format error response in a user-friendly way."""
|
||||
return f"""## Research Engine Unavailable
|
||||
|
||||
**Query:** {query}
|
||||
|
||||
**Error:** {error_msg}
|
||||
|
||||
**Please check:**
|
||||
1. LLM endpoint is reachable
|
||||
2. SearXNG is running at the configured instance
|
||||
3. Application logs for detailed error information
|
||||
|
||||
**Troubleshooting:**
|
||||
- Test basic search: Try the web search toggle first
|
||||
- Check search config: `/api/search/config`
|
||||
- Review logs for initialization errors
|
||||
"""
|
||||
|
||||
def _handle_research_failure(self, query: str, error: str) -> str:
|
||||
"""Handle research failure with fallback to basic search."""
|
||||
try:
|
||||
logger.info("Attempting fallback to basic web search...")
|
||||
from src.search import comprehensive_web_search
|
||||
|
||||
search_result = comprehensive_web_search(query)
|
||||
|
||||
return f"""## Research Failed - Basic Search Fallback
|
||||
|
||||
**Query:** {query}
|
||||
|
||||
**Error:** {error}
|
||||
|
||||
**Note:** The deep research engine encountered an error. Here are basic search results instead:
|
||||
|
||||
---
|
||||
|
||||
### Basic Web Search Results
|
||||
|
||||
{search_result}
|
||||
|
||||
---
|
||||
|
||||
**To fix deep research:**
|
||||
1. Check that your LLM endpoint and search provider are properly configured
|
||||
2. Verify network connectivity
|
||||
3. Review application logs for detailed error information
|
||||
|
||||
Try the web search toggle for simpler queries, or fix the research engine for comprehensive analysis.
|
||||
"""
|
||||
|
||||
except Exception as e2:
|
||||
logger.error(f"Fallback search also failed: {e2}", exc_info=True)
|
||||
return f"""## Complete Research Failure
|
||||
|
||||
**Primary Error:** {error}
|
||||
**Fallback Error:** {str(e2)}
|
||||
|
||||
**Please check:**
|
||||
1. Search provider configuration in Settings -> Search Settings
|
||||
2. Network connectivity to search APIs
|
||||
3. Application logs for detailed error information
|
||||
4. That SearXNG is running (if using SearXNG)
|
||||
|
||||
**Debug Info:**
|
||||
- Search config endpoint: `/api/search/config`
|
||||
- Test basic search toggle with a simple query first
|
||||
"""
|
||||
@@ -0,0 +1,117 @@
|
||||
# services/research/service.py
|
||||
"""Research service — deep research with LLM-in-the-loop."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Callable
|
||||
|
||||
from .research_handler import ResearchHandler
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResearchSource:
|
||||
"""A source found during research."""
|
||||
url: str
|
||||
title: str
|
||||
snippet: str
|
||||
relevance: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResearchResult:
|
||||
"""Result of a deep research query."""
|
||||
query: str
|
||||
summary: str
|
||||
sources: List[ResearchSource] = field(default_factory=list)
|
||||
sections: List[str] = field(default_factory=list)
|
||||
tokens_used: int = 0
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
|
||||
class ResearchService:
|
||||
"""
|
||||
Deep research service.
|
||||
|
||||
Usage:
|
||||
service = ResearchService()
|
||||
result = await service.research("quantum computing advances 2024")
|
||||
print(result.summary)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.handler = ResearchHandler()
|
||||
self._active: dict = {}
|
||||
|
||||
async def research(
|
||||
self,
|
||||
topic: str,
|
||||
llm_endpoint: str,
|
||||
llm_model: str,
|
||||
max_time: int = 300,
|
||||
on_progress: Optional[Callable[[dict], None]] = None,
|
||||
) -> ResearchResult:
|
||||
"""
|
||||
Perform deep research on a topic.
|
||||
|
||||
Args:
|
||||
topic: Research topic/question
|
||||
llm_endpoint: LLM API endpoint
|
||||
llm_model: Model to use
|
||||
max_time: Maximum time in seconds
|
||||
on_progress: Optional progress callback
|
||||
|
||||
Returns:
|
||||
ResearchResult with findings
|
||||
"""
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await self.handler.call_research_service(
|
||||
topic,
|
||||
llm_endpoint,
|
||||
llm_model,
|
||||
max_time=max_time,
|
||||
progress_callback=on_progress,
|
||||
)
|
||||
|
||||
duration = time.time() - start
|
||||
|
||||
# Parse result into structured format
|
||||
sources = [
|
||||
ResearchSource(
|
||||
url=s.get("url", ""),
|
||||
title=s.get("title", ""),
|
||||
snippet=s.get("snippet", ""),
|
||||
relevance=s.get("relevance", 0.0),
|
||||
)
|
||||
for s in result.get("sources", [])
|
||||
]
|
||||
|
||||
return ResearchResult(
|
||||
query=topic,
|
||||
summary=result.get("summary", result.get("answer", "")),
|
||||
sources=sources,
|
||||
sections=result.get("sections", []),
|
||||
tokens_used=result.get("tokens_used", 0),
|
||||
duration_seconds=duration,
|
||||
)
|
||||
|
||||
def start_background(
|
||||
self,
|
||||
session_id: str,
|
||||
topic: str,
|
||||
llm_endpoint: str,
|
||||
llm_model: str,
|
||||
max_time: int = 300,
|
||||
) -> dict:
|
||||
"""Start research in background. Returns task info."""
|
||||
return self.handler.start_research(
|
||||
session_id, topic, llm_endpoint, llm_model, max_time
|
||||
)
|
||||
|
||||
def get_status(self, session_id: str) -> Optional[dict]:
|
||||
"""Get status of background research."""
|
||||
return self.handler.get_status(session_id)
|
||||
|
||||
def cancel(self, session_id: str) -> bool:
|
||||
"""Cancel background research."""
|
||||
return self.handler.cancel_research(session_id)
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Search service — web search with SearXNG."""
|
||||
|
||||
from .core import (
|
||||
comprehensive_web_search,
|
||||
get_search_config,
|
||||
invalidate_search_cache,
|
||||
searxng_search_results,
|
||||
update_search_config,
|
||||
)
|
||||
from .content import fetch_webpage_content
|
||||
from .providers import searxng_search, searxng_search_api, PROVIDER_INFO
|
||||
from .analytics import get_search_stats, SearchEngineError, NetworkError, ParseError, RateLimitError
|
||||
from .service import SearchService, SearchResult, SearchResponse
|
||||
|
||||
__all__ = [
|
||||
# Service interface (preferred)
|
||||
"SearchService",
|
||||
"SearchResult",
|
||||
"SearchResponse",
|
||||
# Low-level functions (for backwards compat)
|
||||
"comprehensive_web_search",
|
||||
"fetch_webpage_content",
|
||||
"get_search_config",
|
||||
"get_search_stats",
|
||||
"invalidate_search_cache",
|
||||
"searxng_search",
|
||||
"searxng_search_api",
|
||||
"searxng_search_results",
|
||||
"update_search_config",
|
||||
"PROVIDER_INFO",
|
||||
"SearchEngineError",
|
||||
"NetworkError",
|
||||
"ParseError",
|
||||
"RateLimitError",
|
||||
]
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Search analytics, metrics tracking, and exception hierarchy."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from .cache import cache_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dedicated error logger with file handler
|
||||
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
|
||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||
_error_handler.setLevel(logging.WARNING)
|
||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||
error_logger = logging.getLogger("search_engine_error")
|
||||
error_logger.addHandler(_error_handler)
|
||||
error_logger.propagate = False
|
||||
|
||||
# Analytics file
|
||||
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Custom exception hierarchy
|
||||
# ----------------------------------------------------------------------
|
||||
class SearchEngineError(Exception):
|
||||
"""Base class for all search-engine related errors."""
|
||||
|
||||
|
||||
class NetworkError(SearchEngineError):
|
||||
"""Raised when a network request fails (e.g., timeout, DNS error)."""
|
||||
|
||||
|
||||
class ParseError(SearchEngineError):
|
||||
"""Raised when HTML or other content cannot be parsed."""
|
||||
|
||||
|
||||
class RateLimitError(SearchEngineError):
|
||||
"""Raised when the remote service returns a rate-limit (HTTP 429)."""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Analytics helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _load_analytics() -> Dict[str, Any]:
|
||||
"""Load analytics data from the JSON file, creating defaults if missing."""
|
||||
if not ANALYTICS_FILE.exists():
|
||||
default = {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
_save_analytics(default)
|
||||
return default
|
||||
try:
|
||||
with open(ANALYTICS_FILE, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load analytics file: {e}")
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
|
||||
|
||||
def _save_analytics(data: Dict[str, Any]) -> None:
|
||||
"""Persist analytics data to the JSON file."""
|
||||
try:
|
||||
with open(ANALYTICS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write analytics file: {e}")
|
||||
|
||||
|
||||
def _record_query(query: str, success: bool, cache_hit: bool) -> None:
|
||||
"""Update analytics for a single query execution."""
|
||||
analytics = _load_analytics()
|
||||
analytics["total_queries"] += 1
|
||||
if success:
|
||||
analytics["successful_queries"] += 1
|
||||
else:
|
||||
analytics["failed_queries"] += 1
|
||||
|
||||
if cache_hit:
|
||||
analytics["cache_hits"] += 1
|
||||
cache_metrics["hits"] += 1
|
||||
else:
|
||||
analytics["cache_misses"] += 1
|
||||
cache_metrics["misses"] += 1
|
||||
|
||||
patterns = analytics["query_patterns"]
|
||||
entry = patterns.get(query, {"count": 0, "successes": 0})
|
||||
entry["count"] += 1
|
||||
if success:
|
||||
entry["successes"] += 1
|
||||
patterns[query] = entry
|
||||
|
||||
_save_analytics(analytics)
|
||||
|
||||
|
||||
def get_search_stats() -> Dict[str, Any]:
|
||||
"""Return aggregated search analytics."""
|
||||
analytics = _load_analytics()
|
||||
total = analytics.get("total_queries", 0) or 1
|
||||
success_rate = analytics.get("successful_queries", 0) / total
|
||||
cache_total = analytics.get("cache_hits", 0) + analytics.get("cache_misses", 0) or 1
|
||||
cache_hit_rate = analytics.get("cache_hits", 0) / cache_total
|
||||
|
||||
pattern_counter = Counter({
|
||||
q: data["count"] for q, data in analytics.get("query_patterns", {}).items()
|
||||
})
|
||||
most_common = [q for q, _ in pattern_counter.most_common(5)]
|
||||
|
||||
return {
|
||||
"most_common_queries": most_common,
|
||||
"success_rate": success_rate,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_queries": analytics.get("total_queries", 0),
|
||||
"successful_queries": analytics.get("successful_queries", 0),
|
||||
"failed_queries": analytics.get("failed_queries", 0),
|
||||
"cache_hits": analytics.get("cache_hits", 0),
|
||||
"cache_misses": analytics.get("cache_misses", 0),
|
||||
"cache_evictions": cache_metrics["evictions"],
|
||||
"runtime_cache_hits": cache_metrics["hits"],
|
||||
"runtime_cache_misses": cache_metrics["misses"],
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Search and content caching with LRU eviction."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache directories
|
||||
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
|
||||
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
||||
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
||||
CACHE_MAX_ENTRIES = 1000
|
||||
|
||||
# Create cache directories
|
||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Track cache size for LRU eviction
|
||||
search_cache_index: Dict[str, datetime] = {}
|
||||
content_cache_index: Dict[str, datetime] = {}
|
||||
|
||||
# Cache metrics (shared across modules)
|
||||
cache_metrics = {"hits": 0, "misses": 0, "evictions": 0}
|
||||
|
||||
|
||||
def generate_cache_key(data: str) -> str:
|
||||
"""Generate a unique cache key using SHA-256 hash."""
|
||||
return hashlib.sha256(data.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def cleanup_cache(cache_dir: Path, cache_index: Dict[str, datetime], max_age: timedelta):
|
||||
"""Remove expired cache entries and enforce LRU policy."""
|
||||
current_time = datetime.now()
|
||||
files_in_dir = {f.name.split(".")[0]: f for f in cache_dir.glob("*.cache")}
|
||||
|
||||
to_remove = []
|
||||
for key, timestamp in list(cache_index.items()):
|
||||
if current_time - timestamp > max_age or key not in files_in_dir:
|
||||
to_remove.append(key)
|
||||
if key in files_in_dir:
|
||||
files_in_dir[key].unlink(missing_ok=True)
|
||||
|
||||
for key in to_remove:
|
||||
cache_index.pop(key, None)
|
||||
cache_metrics["evictions"] += 1
|
||||
|
||||
if len(cache_index) > CACHE_MAX_ENTRIES:
|
||||
sorted_items = sorted(cache_index.items(), key=lambda x: x[1])
|
||||
excess_count = len(cache_index) - CACHE_MAX_ENTRIES
|
||||
for key, _ in sorted_items[:excess_count]:
|
||||
cache_index.pop(key, None)
|
||||
cache_file = cache_dir / f"{key}.cache"
|
||||
cache_file.unlink(missing_ok=True)
|
||||
cache_metrics["evictions"] += 1
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
|
||||
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from .analytics import RateLimitError, error_logger
|
||||
from .cache import (
|
||||
CONTENT_CACHE_DIR,
|
||||
content_cache_index,
|
||||
generate_cache_key,
|
||||
cleanup_cache,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PRIVATE_NETWORKS = (
|
||||
ipaddress.ip_network("0.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
)
|
||||
|
||||
|
||||
def _is_private_address(addr: ipaddress._BaseAddress) -> bool:
|
||||
return addr.is_private or addr.is_loopback or addr.is_link_local or any(addr in net for net in _PRIVATE_NETWORKS)
|
||||
|
||||
|
||||
def _resolve_hostname_ips(hostname: str) -> list[ipaddress._BaseAddress]:
|
||||
try:
|
||||
infos = socket.getaddrinfo(hostname, None)
|
||||
except Exception:
|
||||
return []
|
||||
out = []
|
||||
for info in infos:
|
||||
try:
|
||||
out.append(ipaddress.ip_address(info[4][0]))
|
||||
except Exception:
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def _public_http_url(url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
host = (parsed.hostname or "").strip()
|
||||
if not host:
|
||||
return False
|
||||
lower = host.lower()
|
||||
if lower in ("localhost", "metadata", "metadata.google.internal"):
|
||||
return False
|
||||
if lower.endswith((".local", ".localhost", ".internal", ".lan", ".intranet")):
|
||||
return False
|
||||
try:
|
||||
return not _is_private_address(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
pass
|
||||
addrs = _resolve_hostname_ips(host)
|
||||
return bool(addrs) and not any(_is_private_address(a) for a in addrs)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_public_url(url: str, headers: dict, timeout: int, max_redirects: int = 5) -> httpx.Response:
|
||||
current = url
|
||||
for _ in range(max_redirects + 1):
|
||||
if not _public_http_url(current):
|
||||
raise httpx.RequestError("Blocked private/internal URL", request=httpx.Request("GET", current))
|
||||
response = httpx.get(current, headers=headers, timeout=timeout, follow_redirects=False)
|
||||
if response.status_code not in (301, 302, 303, 307, 308):
|
||||
return response
|
||||
location = response.headers.get("location")
|
||||
if not location:
|
||||
return response
|
||||
current = urljoin(str(response.url), location)
|
||||
raise httpx.RequestError("Too many redirects", request=httpx.Request("GET", current))
|
||||
|
||||
# PDF extraction (optional dependency)
|
||||
try:
|
||||
from pdfminer.high_level import extract_text as pdf_extract_text
|
||||
except ImportError:
|
||||
pdf_extract_text = None # type: ignore
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# HTML extraction helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _extract_meta(soup: BeautifulSoup) -> dict:
|
||||
"""Pull meta description and keywords if present."""
|
||||
description = ""
|
||||
keywords = ""
|
||||
desc_tag = soup.find("meta", attrs={"name": re.compile("description", re.I)})
|
||||
if desc_tag and desc_tag.get("content"):
|
||||
description = desc_tag["content"].strip()
|
||||
kw_tag = soup.find("meta", attrs={"name": re.compile("keywords", re.I)})
|
||||
if kw_tag and kw_tag.get("content"):
|
||||
keywords = kw_tag["content"].strip()
|
||||
return {"description": description, "keywords": keywords}
|
||||
|
||||
|
||||
def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
|
||||
"""Return a list of lists, each inner list representing a <ul>/<ol>."""
|
||||
all_lists = []
|
||||
for lst in soup.find_all(["ul", "ol"]):
|
||||
items = [li.get_text(separator=" ", strip=True) for li in lst.find_all("li")]
|
||||
if items:
|
||||
all_lists.append(items)
|
||||
return all_lists
|
||||
|
||||
|
||||
def _extract_tables(soup: BeautifulSoup) -> List[List[List[str]]]:
|
||||
"""Return a list of tables, each table is a list of rows, each row a list of cell texts."""
|
||||
tables_data = []
|
||||
for table in soup.find_all("table"):
|
||||
rows = []
|
||||
for tr in table.find_all("tr"):
|
||||
cells = [td.get_text(separator=" ", strip=True) for td in tr.find_all(["td", "th"])]
|
||||
if cells:
|
||||
rows.append(cells)
|
||||
if rows:
|
||||
tables_data.append(rows)
|
||||
return tables_data
|
||||
|
||||
|
||||
def _extract_code_blocks(soup: BeautifulSoup) -> List[str]:
|
||||
"""Collect text from <pre> and <code> blocks."""
|
||||
blocks = []
|
||||
for tag in soup.find_all(["pre", "code"]):
|
||||
txt = tag.get_text(separator=" ", strip=True)
|
||||
if txt:
|
||||
blocks.append(txt)
|
||||
return blocks
|
||||
|
||||
|
||||
def _detect_js_frameworks(soup: BeautifulSoup) -> bool:
|
||||
"""Very naive detection of common JS frameworks."""
|
||||
js_indicators = [
|
||||
"react", "angular", "vue", "svelte", "next", "nuxt",
|
||||
"ember", "backbone", "jquery", "polymer", "mithril",
|
||||
]
|
||||
for script in soup.find_all("script"):
|
||||
src = script.get("src", "").lower()
|
||||
if any(fr in src for fr in js_indicators):
|
||||
return True
|
||||
if script.string:
|
||||
content = script.string.lower()
|
||||
if any(fr in content for fr in js_indicators):
|
||||
return True
|
||||
if soup.find(attrs={"data-reactroot": True}) or soup.find(attrs={"ng-app": True}):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _empty_result(url: str, error: str = "") -> dict:
|
||||
"""Build a standard failure result dict."""
|
||||
return {
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"lists": [],
|
||||
"tables": [],
|
||||
"code_blocks": [],
|
||||
"meta_description": "",
|
||||
"meta_keywords": "",
|
||||
"js_rendered": False,
|
||||
"js_message": "",
|
||||
"success": False,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Main content fetcher
|
||||
# ----------------------------------------------------------------------
|
||||
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
|
||||
"""Fetch and extract meaningful content from a webpage with caching."""
|
||||
cache_key = generate_cache_key(url)
|
||||
cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"
|
||||
|
||||
# Check cache
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cached_data = json.load(f)
|
||||
timestamp = datetime.fromisoformat(cached_data["timestamp"])
|
||||
if datetime.now() - timestamp < timedelta(hours=2):
|
||||
logger.debug(f"Content cache hit for URL: {url}")
|
||||
return cached_data["data"]
|
||||
else:
|
||||
cache_file.unlink(missing_ok=True)
|
||||
content_cache_index.pop(cache_key, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read content cache for {url}: {e}")
|
||||
cache_file.unlink(missing_ok=True)
|
||||
content_cache_index.pop(cache_key, None)
|
||||
|
||||
# Fetch
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
response = _get_public_url(url, headers=headers, timeout=timeout)
|
||||
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||
|
||||
response.raise_for_status()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
||||
return _empty_result(url, f"NetworkError: {e}")
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return _empty_result(url, str(e))
|
||||
|
||||
# PDF handling
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
if "application/pdf" in content_type or url.lower().endswith(".pdf"):
|
||||
if pdf_extract_text is None:
|
||||
logger.error("pdfminer.six is not installed; cannot extract PDF text.")
|
||||
pdf_text = ""
|
||||
else:
|
||||
try:
|
||||
pdf_bytes = io.BytesIO(response.content)
|
||||
pdf_text = pdf_extract_text(pdf_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"PDF extraction failed for {url}: {e}")
|
||||
pdf_text = ""
|
||||
result = {
|
||||
"url": url,
|
||||
"title": os.path.basename(url),
|
||||
"content": pdf_text,
|
||||
"lists": [],
|
||||
"tables": [],
|
||||
"code_blocks": [],
|
||||
"meta_description": "",
|
||||
"meta_keywords": "",
|
||||
"js_rendered": False,
|
||||
"js_message": "",
|
||||
"success": bool(pdf_text),
|
||||
"error": "" if pdf_text else "Failed to extract PDF text",
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
# HTML handling
|
||||
try:
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
except Exception as e:
|
||||
error_logger.error(f"ParseError parsing HTML from {url} (attempt {retry_attempt}): {e}")
|
||||
result = _empty_result(url, f"ParseError: {e}")
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
title_tag = soup.find("title")
|
||||
title_text = title_tag.get_text(strip=True) if title_tag else ""
|
||||
meta_info = _extract_meta(soup)
|
||||
js_rendered = _detect_js_frameworks(soup)
|
||||
js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""
|
||||
|
||||
# Main textual content (heuristic)
|
||||
main_content = ""
|
||||
content_areas = soup.find_all(
|
||||
["main", "article", "section", "div"],
|
||||
class_=re.compile("content|main|body|article|post|entry|text", re.I),
|
||||
)
|
||||
if content_areas:
|
||||
for area in content_areas[:3]:
|
||||
main_content += area.get_text(separator=" ", strip=True) + " "
|
||||
if not main_content:
|
||||
body = soup.find("body")
|
||||
if body:
|
||||
main_content = body.get_text(separator=" ", strip=True)
|
||||
|
||||
main_content = re.sub(r"\s+", " ", main_content).strip()[:8000]
|
||||
|
||||
result = {
|
||||
"url": url,
|
||||
"title": title_text,
|
||||
"content": main_content,
|
||||
"lists": _extract_lists(soup),
|
||||
"tables": _extract_tables(soup),
|
||||
"code_blocks": _extract_code_blocks(soup),
|
||||
"meta_description": meta_info.get("description", ""),
|
||||
"meta_keywords": meta_info.get("keywords", ""),
|
||||
"js_rendered": js_rendered,
|
||||
"js_message": js_message,
|
||||
"success": True,
|
||||
"error": "",
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
|
||||
def _cache_result(cache_file, cache_key: str, result: dict, url: str):
|
||||
"""Write a result to the content cache."""
|
||||
try:
|
||||
cache_data = {"timestamp": datetime.now().isoformat(), "data": result}
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f)
|
||||
content_cache_index[cache_key] = datetime.now()
|
||||
cleanup_cache(CONTENT_CACHE_DIR, content_cache_index, timedelta(hours=2))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write content cache for {url}: {e}")
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Content summarization helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def extract_key_points(text: str) -> List[str]:
|
||||
"""Pull out bullet-style key points from a block of text."""
|
||||
points: List[str] = []
|
||||
bullet_pat = re.compile(r"^\s*[-*•]\s+(.*)")
|
||||
numbered_pat = re.compile(r"^\s*\d+[\.\)]\s+(.*)")
|
||||
for line in text.splitlines():
|
||||
m = bullet_pat.match(line) or numbered_pat.match(line)
|
||||
if m:
|
||||
points.append(m.group(1).strip())
|
||||
return points
|
||||
|
||||
|
||||
def get_tldr(text: str, max_sentences: int = 3) -> str:
|
||||
"""Produce a very short TL;DR by taking the first few sentences."""
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text)
|
||||
selected = [s.strip() for s in sentences if s][:max_sentences]
|
||||
return " ".join(selected)
|
||||
|
||||
|
||||
def extract_quotes(text: str) -> List[str]:
|
||||
"""Return quoted excerpts that are at least 15 characters long."""
|
||||
return [m.group(1).strip() for m in re.finditer(r'["\']([^"\']{15,}?)["\']', text)]
|
||||
|
||||
|
||||
def extract_statistics(text: str) -> List[str]:
|
||||
"""Find numbers, percentages, dates and simple measurements."""
|
||||
pattern = re.compile(
|
||||
r"\b\d{1,3}(?:,\d{3})*(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return [m.group(0).strip() for m in pattern.finditer(text)]
|
||||
@@ -0,0 +1,433 @@
|
||||
"""Core search orchestrators: searxng_search_results, comprehensive_web_search, config, cache invalidation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional, List, Set
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .analytics import (
|
||||
NetworkError,
|
||||
ParseError,
|
||||
RateLimitError,
|
||||
error_logger,
|
||||
_record_query,
|
||||
)
|
||||
from .cache import (
|
||||
SEARCH_CACHE_DIR,
|
||||
search_cache_index,
|
||||
generate_cache_key,
|
||||
cleanup_cache,
|
||||
)
|
||||
from .query import _cache_duration_for_query
|
||||
from .ranking import rank_search_results
|
||||
from .providers import (
|
||||
searxng_search_api,
|
||||
brave_search,
|
||||
duckduckgo_search,
|
||||
google_pse_search,
|
||||
tavily_search,
|
||||
serper_search,
|
||||
_get_search_settings,
|
||||
_get_result_count,
|
||||
)
|
||||
from .content import (
|
||||
fetch_webpage_content,
|
||||
extract_key_points,
|
||||
get_tldr,
|
||||
extract_quotes,
|
||||
extract_statistics,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ========= CONFIG =========
|
||||
SEARCH_CONFIG: Dict[str, Any] = {
|
||||
"primary_provider": "searxng",
|
||||
}
|
||||
|
||||
|
||||
def get_search_config() -> Dict[str, Any]:
|
||||
"""Get current search configuration including active provider info."""
|
||||
config = SEARCH_CONFIG.copy()
|
||||
settings = _get_search_settings()
|
||||
provider = settings.get("search_provider", "searxng")
|
||||
config["active_provider"] = provider
|
||||
config["has_api_key"] = bool((settings.get("search_api_key") or "").strip())
|
||||
config["result_count"] = _get_result_count()
|
||||
if provider == "searxng":
|
||||
from .providers import _get_search_instance
|
||||
config["search_url"] = _get_search_instance()
|
||||
return config
|
||||
|
||||
|
||||
def update_search_config(api_key: str = None, **kwargs):
|
||||
"""Update search configuration (e.g. Brave API key)."""
|
||||
if api_key:
|
||||
SEARCH_CONFIG["brave_api_key"] = api_key
|
||||
|
||||
|
||||
def _call_provider(provider_name: str, query: str, count: int, time_filter: str = None) -> List[dict]:
|
||||
"""Call a search provider by name. Returns list of results or empty list."""
|
||||
if provider_name == "searxng":
|
||||
return searxng_search_api(query, count, time_filter=time_filter)
|
||||
elif provider_name == "brave":
|
||||
return brave_search(query, count, time_filter)
|
||||
elif provider_name == "duckduckgo":
|
||||
return duckduckgo_search(query, count, time_filter)
|
||||
elif provider_name == "google_pse":
|
||||
return google_pse_search(query, count, time_filter)
|
||||
elif provider_name == "tavily":
|
||||
return tavily_search(query, count, time_filter)
|
||||
elif provider_name == "serper":
|
||||
return serper_search(query, count, time_filter)
|
||||
return []
|
||||
|
||||
|
||||
# If the self-hosted SearXNG instance is up but all enabled engines return
|
||||
# empty, fall back to the no-key provider so "search X" still works on fresh
|
||||
# installs. Users can override/disable with `search_fallback_chain`.
|
||||
_FALLBACK_ORDER = ["duckduckgo"]
|
||||
|
||||
|
||||
def _build_provider_chain(primary: str) -> List[str]:
|
||||
"""Build ordered list: primary first, then configured/default fallbacks."""
|
||||
chain = [primary]
|
||||
settings = _get_search_settings()
|
||||
user_chain = settings.get("search_fallback_chain") or []
|
||||
if isinstance(user_chain, str):
|
||||
user_chain = [s.strip() for s in user_chain.split(",") if s.strip()]
|
||||
fallbacks = user_chain if user_chain else _FALLBACK_ORDER
|
||||
for fb in fallbacks:
|
||||
if fb and fb != primary and fb not in chain and fb != "disabled":
|
||||
chain.append(fb)
|
||||
return chain
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Unified search with caching and retry
|
||||
# ----------------------------------------------------------------------
|
||||
def searxng_search_results(query: str, count: int = 10, time_filter: str = None) -> list[dict]:
|
||||
"""Perform a web search using configured provider with caching and retry."""
|
||||
settings = _get_search_settings()
|
||||
search_provider = settings.get("search_provider", "searxng")
|
||||
result_count = _get_result_count()
|
||||
# Use configured count if caller used default
|
||||
if count == 10:
|
||||
count = result_count
|
||||
|
||||
cache_key = generate_cache_key(f"{query}|{count}|{time_filter}")
|
||||
cache_file = SEARCH_CACHE_DIR / f"{cache_key}.cache"
|
||||
|
||||
# Check cache
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cached_data = json.load(f)
|
||||
expiry_raw = cached_data.get("expiry")
|
||||
expiry = datetime.fromisoformat(expiry_raw) if expiry_raw else None
|
||||
if expiry and datetime.now() < expiry:
|
||||
logger.debug(f"Search cache hit for query: {query}")
|
||||
results = cached_data["data"]
|
||||
_record_query(query, bool(results), cache_hit=True)
|
||||
return results
|
||||
else:
|
||||
cache_file.unlink(missing_ok=True)
|
||||
search_cache_index.pop(cache_key, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read search cache for {query}: {e}")
|
||||
cache_file.unlink(missing_ok=True)
|
||||
search_cache_index.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Search cache miss for query: {query}")
|
||||
|
||||
if search_provider == "disabled":
|
||||
logger.info("Search is disabled via admin settings")
|
||||
return []
|
||||
|
||||
provider_chain = _build_provider_chain(search_provider)
|
||||
|
||||
results: List[dict] = []
|
||||
for provider_name in provider_chain:
|
||||
for attempt in range(2):
|
||||
try:
|
||||
logger.info(f"Attempting {provider_name} search (attempt {attempt + 1})")
|
||||
results = _call_provider(provider_name, query, count, time_filter)
|
||||
if results:
|
||||
logger.info(f"{provider_name} search succeeded with {len(results)} results")
|
||||
break
|
||||
except (NetworkError, ParseError, RateLimitError) as e:
|
||||
error_logger.error(f"{provider_name} search error (attempt {attempt + 1}): {e}")
|
||||
except Exception as e:
|
||||
error_logger.error(f"Unexpected error during {provider_name} search (attempt {attempt + 1}): {e}")
|
||||
if results:
|
||||
break
|
||||
|
||||
success = bool(results)
|
||||
_record_query(query, success, cache_hit=False)
|
||||
|
||||
if success:
|
||||
results = rank_search_results(query, results)
|
||||
try:
|
||||
expiry = datetime.now() + _cache_duration_for_query(query)
|
||||
cache_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"expiry": expiry.isoformat(),
|
||||
"data": results,
|
||||
}
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f)
|
||||
search_cache_index[cache_key] = datetime.now()
|
||||
cleanup_cache(SEARCH_CACHE_DIR, search_cache_index, timedelta(hours=1))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write search cache for {query}: {e}")
|
||||
|
||||
if not success:
|
||||
logger.error(f"All search providers failed for query: {query}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Cache invalidation
|
||||
# ----------------------------------------------------------------------
|
||||
def invalidate_search_cache(query: Optional[str] = None) -> None:
|
||||
"""Invalidate cached search results. None clears all, otherwise just the given query."""
|
||||
if query is None:
|
||||
for file in SEARCH_CACHE_DIR.glob("*.cache"):
|
||||
try:
|
||||
file.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
error_logger.warning(f"Failed to delete cache file {file}: {e}")
|
||||
search_cache_index.clear()
|
||||
logger.info("All search cache entries have been cleared.")
|
||||
else:
|
||||
cache_key = generate_cache_key(f"{query}|10|None")
|
||||
cache_file = SEARCH_CACHE_DIR / f"{cache_key}.cache"
|
||||
if cache_file.exists():
|
||||
try:
|
||||
cache_file.unlink(missing_ok=True)
|
||||
search_cache_index.pop(cache_key, None)
|
||||
logger.info(f"Cache entry for query '{query}' has been invalidated.")
|
||||
except Exception as e:
|
||||
error_logger.warning(f"Failed to delete cache file for query '{query}': {e}")
|
||||
else:
|
||||
logger.info(f"No cache entry found for query '{query}'.")
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Comprehensive web search (with advanced filtering)
|
||||
# ----------------------------------------------------------------------
|
||||
def comprehensive_web_search(
|
||||
query: str,
|
||||
max_pages: int = 3,
|
||||
max_workers: int = 4,
|
||||
time_filter: str = None,
|
||||
domain_whitelist: Optional[Set[str]] = None,
|
||||
domain_blacklist: Optional[Set[str]] = None,
|
||||
content_type: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
min_content_length: int = 0,
|
||||
return_sources: bool = False,
|
||||
):
|
||||
"""Perform comprehensive web search with content fetching and advanced filtering."""
|
||||
logger.info(f"Starting comprehensive search for: {query}")
|
||||
if time_filter:
|
||||
logger.info(f"Applying time filter: {time_filter}")
|
||||
|
||||
settings = _get_search_settings()
|
||||
search_provider = settings.get("search_provider", "searxng")
|
||||
result_count = _get_result_count()
|
||||
|
||||
if search_provider == "disabled":
|
||||
logger.info("Search is disabled via admin settings")
|
||||
msg = "Web search is disabled by the administrator."
|
||||
return (msg, []) if return_sources else msg
|
||||
|
||||
# Use configured result count (at least max_pages for content fetching)
|
||||
fetch_count = max(result_count, max_pages)
|
||||
|
||||
provider_chain = _build_provider_chain(search_provider)
|
||||
|
||||
search_results = []
|
||||
provider_attempts = {}
|
||||
for provider_name in provider_chain:
|
||||
last_err = None
|
||||
empty = False
|
||||
for attempt in range(2):
|
||||
try:
|
||||
search_results = _call_provider(provider_name, query, fetch_count, time_filter)
|
||||
if search_results:
|
||||
provider_attempts[provider_name] = f"ok ({len(search_results)})"
|
||||
logger.info(f"Comprehensive search: {provider_name} returned {len(search_results)} results")
|
||||
break
|
||||
empty = True
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
logger.warning(f"Comprehensive search: {provider_name} attempt {attempt + 1} failed: {e}")
|
||||
if search_results:
|
||||
break
|
||||
if last_err is not None:
|
||||
provider_attempts[provider_name] = f"error: {last_err}"
|
||||
elif empty:
|
||||
provider_attempts[provider_name] = "empty"
|
||||
|
||||
if not search_results:
|
||||
tally = ", ".join(f"{p}:{r}" for p, r in provider_attempts.items()) or "no providers configured"
|
||||
any_errors = any(r.startswith("error") for r in provider_attempts.values())
|
||||
if any_errors:
|
||||
msg = f"Web search failed — all providers errored or returned empty. Tried: {tally}"
|
||||
else:
|
||||
msg = (
|
||||
f"No search results found. Tried: {tally}. "
|
||||
"All providers returned empty — possibly a niche query or upstream rate-limiting; "
|
||||
"rephrasing or using the browser tool for a specific URL may help."
|
||||
)
|
||||
logger.warning(msg)
|
||||
return (msg, []) if return_sources else msg
|
||||
|
||||
search_results = rank_search_results(query, search_results)
|
||||
|
||||
# URL filter helper
|
||||
def url_passes_filters(url: str) -> bool:
|
||||
try:
|
||||
netloc = urlparse(url).netloc.lower()
|
||||
except Exception:
|
||||
return False
|
||||
if domain_whitelist is not None and netloc not in domain_whitelist:
|
||||
return False
|
||||
if domain_blacklist is not None and netloc in domain_blacklist:
|
||||
return False
|
||||
if content_type:
|
||||
ct = content_type.lower()
|
||||
if ct == "article":
|
||||
if not any(k in url.lower() for k in ("article", "blog", "news", "post")):
|
||||
return False
|
||||
elif ct == "forum":
|
||||
if not any(k in url.lower() for k in ("forum", "discussion", "thread", "topic")):
|
||||
return False
|
||||
elif ct == "academic":
|
||||
if not any(k in url.lower() for k in ("pdf", "doi", "scholar", "arxiv", "journal", "research")):
|
||||
return False
|
||||
if language:
|
||||
lang_pat = language.lower()
|
||||
if not (f"/{lang_pat}/" in url.lower() or f"?lang={lang_pat}" in url.lower() or f"&lang={lang_pat}" in url.lower()):
|
||||
return False
|
||||
return True
|
||||
|
||||
filtered_urls = [r["url"] for r in search_results[:max_pages] if url_passes_filters(r["url"])]
|
||||
if not filtered_urls:
|
||||
logger.warning("All URLs filtered out by advanced criteria")
|
||||
msg = "No suitable results after applying filters."
|
||||
return (msg, []) if return_sources else msg
|
||||
|
||||
# Build sources list for the frontend (before content fetching)
|
||||
_source_list = [
|
||||
{"url": r.get("url", ""), "title": r.get("title", "")}
|
||||
for r in search_results if r.get("url")
|
||||
]
|
||||
|
||||
# Fetch content in parallel
|
||||
fetched_content = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_url = {
|
||||
executor.submit(fetch_webpage_content, url, 8, retry_attempt=0): url
|
||||
for url in filtered_urls
|
||||
}
|
||||
for future in as_completed(future_to_url):
|
||||
url = future_to_url[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if result["success"] and result["content"] and len(result["content"]) >= min_content_length:
|
||||
fetched_content.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception while fetching {url}: {str(e)}")
|
||||
|
||||
logger.info(f"Successfully fetched content from {len(fetched_content)} pages")
|
||||
|
||||
# Format results
|
||||
output_parts = []
|
||||
|
||||
if search_results:
|
||||
output_parts.append("```sources")
|
||||
for i, result in enumerate(search_results, 1):
|
||||
output_parts.append(f"[{i}] {result['title']}")
|
||||
output_parts.append(f" {result['url']}")
|
||||
if result.get("age"):
|
||||
output_parts.append(f" {result['age']}")
|
||||
output_parts.append("```")
|
||||
output_parts.append("")
|
||||
|
||||
output_parts.append("=" * 70)
|
||||
output_parts.append("WEB SEARCH RESULTS AND FETCHED CONTENT")
|
||||
output_parts.append(f"Query: {query}")
|
||||
output_parts.append(f"Searched {len(search_results)} results, fetched {len(fetched_content)} pages")
|
||||
output_parts.append("=" * 70)
|
||||
output_parts.append("")
|
||||
|
||||
output_parts.append("SEARCH RESULTS SUMMARY:")
|
||||
output_parts.append("-" * 50)
|
||||
for i, result in enumerate(search_results, 1):
|
||||
output_parts.append(f"\n[{i}] {result['title']}")
|
||||
output_parts.append(f" URL: {result['url']}")
|
||||
output_parts.append(f" Snippet: {result['snippet'][:200]}...")
|
||||
if result.get("age"):
|
||||
output_parts.append(f" Age: {result['age']}")
|
||||
|
||||
if fetched_content:
|
||||
output_parts.append("\n" + "=" * 70)
|
||||
output_parts.append("FETCHED PAGE CONTENT:")
|
||||
output_parts.append("-" * 50)
|
||||
|
||||
for i, content in enumerate(fetched_content, 1):
|
||||
output_parts.append(f"\n[CONTENT {i}] From: {content['url']}")
|
||||
output_parts.append(f"Title: {content['title']}")
|
||||
output_parts.append("-" * 30)
|
||||
|
||||
text = content["content"][:3000]
|
||||
if len(content["content"]) > 3000:
|
||||
text += "... [truncated]"
|
||||
output_parts.append(text)
|
||||
|
||||
key_points = extract_key_points(content["content"])
|
||||
if key_points:
|
||||
output_parts.append("\nKey Points:")
|
||||
for pt in key_points[:5]:
|
||||
output_parts.append(f"- {pt}")
|
||||
|
||||
tldr = get_tldr(content["content"])
|
||||
if tldr:
|
||||
output_parts.append("\nTL;DR:")
|
||||
output_parts.append(tldr)
|
||||
|
||||
quotes = extract_quotes(content["content"])
|
||||
if quotes:
|
||||
output_parts.append("\nImportant Quotes:")
|
||||
for q in quotes[:3]:
|
||||
output_parts.append(f"\u201c{q}\u201d")
|
||||
|
||||
stats = extract_statistics(content["content"])
|
||||
if stats:
|
||||
output_parts.append("\nData / Statistics:")
|
||||
for s in stats[:5]:
|
||||
output_parts.append(f"- {s}")
|
||||
|
||||
output_parts.append("")
|
||||
|
||||
output_parts.append("=" * 70)
|
||||
output_parts.append("END OF WEB SEARCH RESULTS")
|
||||
output_parts.append("=" * 70)
|
||||
|
||||
instructions = (
|
||||
"\n\nIMPORTANT INSTRUCTIONS:\n"
|
||||
"1. Use the above web search results and fetched content to answer the user's question\n"
|
||||
"2. Prioritize information from the FETCHED PAGE CONTENT section as it contains actual page data\n"
|
||||
"3. Cross-reference multiple sources when possible\n"
|
||||
"4. If the information is time-sensitive, pay attention to the age of the results\n"
|
||||
"5. Be explicit if the search results don't contain sufficient information to fully answer the question"
|
||||
)
|
||||
output_parts.append(instructions)
|
||||
|
||||
result = "\n".join(output_parts)
|
||||
return (result, _source_list) if return_sources else result
|
||||
@@ -0,0 +1,527 @@
|
||||
"""Search provider implementations: SearXNG, Brave, DuckDuckGo, Google PSE, Tavily, Serper."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from src.constants import SEARXNG_INSTANCE
|
||||
from .analytics import RateLimitError, error_logger
|
||||
from .query import build_enhanced_query
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUEST_TIMEOUT = 20
|
||||
|
||||
# Provider registry — maps setting value to (label, needs_key, needs_url)
|
||||
PROVIDER_INFO = {
|
||||
"searxng": ("SearXNG", False, True),
|
||||
"brave": ("Brave Search", True, False),
|
||||
"duckduckgo": ("DuckDuckGo", False, False),
|
||||
"google_pse": ("Google PSE", True, False),
|
||||
"tavily": ("Tavily", True, False),
|
||||
"serper": ("Serper", True, False),
|
||||
"disabled": ("Disabled", False, False),
|
||||
}
|
||||
|
||||
|
||||
# ── Settings helpers ──
|
||||
|
||||
def _get_search_settings() -> dict:
|
||||
"""Return search settings from admin config, falling back to env defaults."""
|
||||
try:
|
||||
from src.settings import load_settings
|
||||
return load_settings()
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _get_search_instance() -> str:
|
||||
"""Return the active search API URL from admin settings, falling back to env var."""
|
||||
settings = _get_search_settings()
|
||||
url = (settings.get("search_url") or "").strip()
|
||||
if url:
|
||||
return url.rstrip("/")
|
||||
return SEARXNG_INSTANCE
|
||||
|
||||
|
||||
def _get_provider_key(provider: str) -> str:
|
||||
"""Return the API key for a specific provider, with legacy fallback."""
|
||||
settings = _get_search_settings()
|
||||
key_map = {
|
||||
"brave": "brave_api_key",
|
||||
"google_pse": "google_pse_key",
|
||||
"tavily": "tavily_api_key",
|
||||
"serper": "serper_api_key",
|
||||
}
|
||||
field = key_map.get(provider, "")
|
||||
if field:
|
||||
val = (settings.get(field) or "").strip()
|
||||
if val:
|
||||
return val
|
||||
# Legacy fallback: old shared search_api_key field
|
||||
return (settings.get("search_api_key") or "").strip()
|
||||
|
||||
|
||||
def _get_result_count() -> int:
|
||||
"""Return configured result count, default 5."""
|
||||
settings = _get_search_settings()
|
||||
try:
|
||||
return int(settings.get("search_result_count", 5))
|
||||
except (ValueError, TypeError):
|
||||
return 5
|
||||
|
||||
|
||||
# ── SearXNG ──
|
||||
|
||||
_NEWS_HINTS = ("news", "nyheter", "headlines", "breaking", "latest", "today", "idag")
|
||||
|
||||
# Default general engines (google/duckduckgo/brave/startpage/wikipedia) are
|
||||
# routinely rate-limited / CAPTCHA-blocked on this instance and return nothing.
|
||||
# Pin engines that actually respond so non-news queries get results without any
|
||||
# third-party API fallback. Override via SEARXNG_GENERAL_ENGINES.
|
||||
_GENERAL_ENGINES = os.environ.get("SEARXNG_GENERAL_ENGINES", "bing,mojeek,presearch")
|
||||
|
||||
|
||||
def searxng_search_api(query: str, count: int = 10, categories: str = "general",
|
||||
time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using SearXNG JSON API. Returns list of {title, url, snippet}."""
|
||||
instance = _get_search_instance()
|
||||
api_key = ""
|
||||
headers = {"User-Agent": "Mozilla/5.0"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
# News/fresh queries do badly in the 'general' category — it favours
|
||||
# encyclopedic/tourism pages, ignores recency, and (with no language pin)
|
||||
# bleeds in foreign-language results. When the agent layer detected
|
||||
# freshness (time_filter) or the query reads like a news lookup, switch to
|
||||
# the 'news' category, constrain recency, and pin language to English so a
|
||||
# search like "Canada latest news" returns actual news instead of Wikipedia.
|
||||
# Pin English for ALL searches — without it, SearXNG geolocates / mixes
|
||||
# languages and brand-ambiguous terms bleed in foreign SEO pages (e.g.
|
||||
# "Odyssey" → Honda Japan, "Trojan" → Japanese malware blogs, "Polyphemus"
|
||||
# → Chinese math forums). The news path already did this; general didn't.
|
||||
params = {"q": query, "format": "json", "language": "en"}
|
||||
q_lc = query.lower()
|
||||
is_news = time_filter is not None or any(h in q_lc for h in _NEWS_HINTS)
|
||||
if is_news and categories == "general":
|
||||
params["categories"] = "news"
|
||||
if time_filter in ("day", "week", "month", "year"):
|
||||
# 'day' is too sparse on most SearXNG news engines — widen to a week
|
||||
# so there's enough volume; the news category already biases recent.
|
||||
params["time_range"] = "week" if time_filter in ("day", "week") else time_filter
|
||||
else:
|
||||
params["categories"] = categories
|
||||
# Route general queries to engines that aren't blocked (default general
|
||||
# set returns 0 on this instance — see _GENERAL_ENGINES).
|
||||
if categories == "general" and _GENERAL_ENGINES:
|
||||
params["engines"] = _GENERAL_ENGINES
|
||||
try:
|
||||
def _parse_results(results):
|
||||
return [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("url", ""),
|
||||
"snippet": r.get("content", ""),
|
||||
}
|
||||
for r in results[:count]
|
||||
if r.get("url")
|
||||
]
|
||||
|
||||
def _run(search_params):
|
||||
response = httpx.get(
|
||||
f"{instance}/search",
|
||||
params=search_params,
|
||||
headers=headers or None,
|
||||
timeout=15,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return _parse_results(data.get("results", [])), data
|
||||
|
||||
active_params = params
|
||||
parsed, data = _run(active_params)
|
||||
if not parsed and is_news and categories == "general":
|
||||
# Some self-hosted SearXNG configs have no working news engines.
|
||||
# Fall back to the known-good general engines before reporting an
|
||||
# empty search, otherwise common queries like "Canada news" fail.
|
||||
fallback = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"language": "en",
|
||||
"categories": "general",
|
||||
}
|
||||
if _GENERAL_ENGINES:
|
||||
fallback["engines"] = _GENERAL_ENGINES
|
||||
logger.info(
|
||||
"SearXNG news search returned 0 results for %r; retrying general engines",
|
||||
query,
|
||||
)
|
||||
active_params = fallback
|
||||
parsed, data = _run(active_params)
|
||||
if not parsed and active_params.get("language"):
|
||||
fallback = dict(active_params)
|
||||
fallback.pop("language", None)
|
||||
logger.info(
|
||||
"SearXNG language-pinned search returned 0 results for %r; retrying without language",
|
||||
query,
|
||||
)
|
||||
active_params = fallback
|
||||
parsed, data = _run(active_params)
|
||||
if not parsed and active_params.get("engines"):
|
||||
fallback = dict(active_params)
|
||||
fallback.pop("engines", None)
|
||||
logger.info(
|
||||
"SearXNG pinned engines returned 0 results for %r; retrying default engines",
|
||||
query,
|
||||
)
|
||||
parsed, data = _run(fallback)
|
||||
logger.info(f"SearXNG JSON API returned {len(parsed)} results for: {query}")
|
||||
if not parsed:
|
||||
unresponsive = data.get("unresponsive_engines") if isinstance(data, dict) else None
|
||||
if unresponsive:
|
||||
logger.info(f"SearXNG unresponsive engines for {query!r}: {unresponsive}")
|
||||
return parsed
|
||||
except Exception as e:
|
||||
logger.warning(f"SearXNG JSON API search failed: {e}")
|
||||
html_results = searxng_search(query, max_results=count)
|
||||
if html_results:
|
||||
logger.info(f"SearXNG HTML fallback returned {len(html_results)} results for: {query}")
|
||||
return html_results
|
||||
|
||||
|
||||
def searxng_search(query, max_results=10):
|
||||
"""Search using SearXNG instance - parsing HTML."""
|
||||
instance = _get_search_instance()
|
||||
api_key = ""
|
||||
req_headers = {"User-Agent": "Mozilla/5.0"}
|
||||
if api_key:
|
||||
req_headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
response = httpx.get(
|
||||
f"{instance}/search",
|
||||
params={"q": query},
|
||||
headers=req_headers,
|
||||
timeout=10,
|
||||
)
|
||||
if response.is_success:
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
results = []
|
||||
for article in soup.select("article.result")[:max_results]:
|
||||
title_elem = article.select_one("h3 a")
|
||||
if not title_elem:
|
||||
continue
|
||||
title = title_elem.get_text(strip=True)
|
||||
url = title_elem.get("href", "")
|
||||
snippet_elem = article.select_one("p.content")
|
||||
snippet = snippet_elem.get_text(strip=True) if snippet_elem else ""
|
||||
results.append({"title": title, "url": url, "snippet": snippet})
|
||||
logger.info(f"SearXNG search (HTML) returned {len(results)} results")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"SearXNG search failed: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# ── Brave ──
|
||||
|
||||
def brave_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using Brave API with key from admin settings or env var."""
|
||||
api_key = _get_provider_key("brave") or os.environ.get("DATA_BRAVE_API_KEY") or ""
|
||||
return _brave_search_impl(query, count, time_filter, search_config={"brave_api_key": api_key})
|
||||
|
||||
|
||||
def _brave_search_impl(query: str, count: int, time_filter: Optional[str] = None, search_config: dict = None) -> List[dict]:
|
||||
"""Core Brave API call. Returns a list of result dicts or an empty list on failure."""
|
||||
enhanced_query = build_enhanced_query(query, time_filter)
|
||||
config = search_config or {}
|
||||
|
||||
brave_api_key = config.get("brave_api_key")
|
||||
if not brave_api_key:
|
||||
brave_api_key = os.environ.get("DATA_BRAVE_API_KEY")
|
||||
|
||||
if not brave_api_key:
|
||||
logger.warning("Brave API key not found, returning empty results for fallback")
|
||||
return []
|
||||
|
||||
headers = {"X-Subscription-Token": brave_api_key, "Accept": "application/json"}
|
||||
params = {"q": enhanced_query, "count": count}
|
||||
if time_filter:
|
||||
time_map = {"day": "day", "week": "week", "month": "month", "year": "year"}
|
||||
if time_filter in time_map:
|
||||
params["freshness"] = time_map[time_filter]
|
||||
|
||||
logger.info(f"Executing Brave search with query: {enhanced_query}")
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Brave rate limit hit")
|
||||
response.raise_for_status()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"NetworkError during Brave search: {e}")
|
||||
return []
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse Brave API response: {e}")
|
||||
return []
|
||||
|
||||
results = []
|
||||
if "web" in data and "results" in data["web"]:
|
||||
for item in data["web"]["results"][:count]:
|
||||
url = item.get("url", "")
|
||||
if not url:
|
||||
continue
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": url,
|
||||
"snippet": item.get("description", "") or item.get("content", ""),
|
||||
"age": item.get("date", "") if item.get("date") else "",
|
||||
})
|
||||
|
||||
logger.info(f"Brave search returned {len(results)} results")
|
||||
return results
|
||||
|
||||
|
||||
# ── DuckDuckGo (free, no key) ──
|
||||
|
||||
def duckduckgo_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using DuckDuckGo via the duckduckgo-search library. No API key needed."""
|
||||
def _html_fallback() -> List[dict]:
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://html.duckduckgo.com/html/",
|
||||
params={"q": query},
|
||||
headers={"User-Agent": "Mozilla/5.0"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
parsed = []
|
||||
for result in soup.select(".result")[:count]:
|
||||
link = result.select_one(".result__a")
|
||||
if not link:
|
||||
continue
|
||||
url = link.get("href", "")
|
||||
if not url:
|
||||
continue
|
||||
snippet_el = result.select_one(".result__snippet")
|
||||
parsed.append({
|
||||
"title": link.get_text(" ", strip=True),
|
||||
"url": url,
|
||||
"snippet": snippet_el.get_text(" ", strip=True) if snippet_el else "",
|
||||
})
|
||||
logger.info(f"DuckDuckGo HTML search returned {len(parsed)} results")
|
||||
return parsed
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo HTML search failed: {e}")
|
||||
return []
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
logger.warning("duckduckgo-search package not installed; using HTML fallback")
|
||||
return _html_fallback()
|
||||
|
||||
timelimit = None
|
||||
if time_filter:
|
||||
time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
|
||||
timelimit = time_map.get(time_filter)
|
||||
|
||||
try:
|
||||
ddgs = DDGS()
|
||||
raw = ddgs.text(query, max_results=count, timelimit=timelimit)
|
||||
results = []
|
||||
for item in raw:
|
||||
url = item.get("href", "")
|
||||
if not url:
|
||||
continue
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": url,
|
||||
"snippet": item.get("body", ""),
|
||||
})
|
||||
logger.info(f"DuckDuckGo search returned {len(results)} results")
|
||||
return results or _html_fallback()
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo search failed: {e}")
|
||||
return _html_fallback()
|
||||
|
||||
|
||||
# ── Google Programmable Search Engine ──
|
||||
|
||||
def google_pse_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using Google PSE (Custom Search JSON API).
|
||||
|
||||
Requires two keys in settings:
|
||||
- search_api_key: Google API key
|
||||
- google_pse_cx: Programmable Search Engine ID (cx)
|
||||
Or env vars GOOGLE_API_KEY and GOOGLE_PSE_CX.
|
||||
"""
|
||||
settings = _get_search_settings()
|
||||
api_key = _get_provider_key("google_pse") or os.environ.get("GOOGLE_API_KEY", "")
|
||||
cx = (settings.get("google_pse_cx") or "").strip() or os.environ.get("GOOGLE_PSE_CX", "")
|
||||
|
||||
if not api_key or not cx:
|
||||
logger.warning("Google PSE: missing API key or CX ID")
|
||||
return []
|
||||
|
||||
params = {
|
||||
"key": api_key,
|
||||
"cx": cx,
|
||||
"q": query,
|
||||
"num": min(count, 10), # Google PSE max is 10 per request
|
||||
}
|
||||
if time_filter:
|
||||
# dateRestrict: d[number], w[number], m[number], y[number]
|
||||
time_map = {"day": "d1", "week": "w1", "month": "m1", "year": "y1"}
|
||||
if time_filter in time_map:
|
||||
params["dateRestrict"] = time_map[time_filter]
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://www.googleapis.com/customsearch/v1",
|
||||
params=params,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Google PSE rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Google PSE search failed: {e}")
|
||||
return []
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("items", [])[:count]:
|
||||
url = item.get("link", "")
|
||||
if not url:
|
||||
continue
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": url,
|
||||
"snippet": item.get("snippet", ""),
|
||||
})
|
||||
|
||||
logger.info(f"Google PSE returned {len(results)} results")
|
||||
return results
|
||||
|
||||
|
||||
# ── Tavily ──
|
||||
|
||||
def tavily_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using Tavily API. Requires search_api_key or TAVILY_API_KEY env var."""
|
||||
api_key = _get_provider_key("tavily") or os.environ.get("TAVILY_API_KEY", "")
|
||||
if not api_key:
|
||||
logger.warning("Tavily: no API key configured")
|
||||
return []
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"max_results": count,
|
||||
"include_answer": False,
|
||||
}
|
||||
if time_filter:
|
||||
time_map = {"day": "day", "week": "week", "month": "month", "year": "year"}
|
||||
if time_filter in time_map:
|
||||
payload["days"] = {"day": 1, "week": 7, "month": 30, "year": 365}[time_filter]
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
"https://api.tavily.com/search",
|
||||
json=payload,
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Tavily rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Tavily search failed: {e}")
|
||||
return []
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("results", [])[:count]:
|
||||
url = item.get("url", "")
|
||||
if not url:
|
||||
continue
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": url,
|
||||
"snippet": item.get("content", ""),
|
||||
"age": item.get("published_date", ""),
|
||||
})
|
||||
|
||||
logger.info(f"Tavily returned {len(results)} results")
|
||||
return results
|
||||
|
||||
|
||||
# ── Serper.dev ──
|
||||
|
||||
def serper_search(query: str, count: int = 10, time_filter: Optional[str] = None) -> List[dict]:
|
||||
"""Search using Serper.dev API. Requires search_api_key or SERPER_API_KEY env var."""
|
||||
api_key = _get_provider_key("serper") or os.environ.get("SERPER_API_KEY", "")
|
||||
if not api_key:
|
||||
logger.warning("Serper: no API key configured")
|
||||
return []
|
||||
|
||||
payload = {
|
||||
"q": query,
|
||||
"num": count,
|
||||
}
|
||||
if time_filter:
|
||||
time_map = {"day": "qdr:d", "week": "qdr:w", "month": "qdr:m", "year": "qdr:y"}
|
||||
if time_filter in time_map:
|
||||
payload["tbs"] = time_map[time_filter]
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
"https://google.serper.dev/search",
|
||||
json=payload,
|
||||
headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError("Serper rate limit hit")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"Serper search failed: {e}")
|
||||
return []
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return []
|
||||
|
||||
results = []
|
||||
for item in data.get("organic", [])[:count]:
|
||||
url = item.get("link", "")
|
||||
if not url:
|
||||
continue
|
||||
results.append({
|
||||
"title": item.get("title", ""),
|
||||
"url": url,
|
||||
"snippet": item.get("snippet", ""),
|
||||
"age": item.get("date", ""),
|
||||
})
|
||||
|
||||
logger.info(f"Serper returned {len(results)} results")
|
||||
return results
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Query enhancement, entity extraction, and cache duration helpers."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Query processing helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _detect_question_type(query: str) -> Optional[str]:
|
||||
"""Return the leading question word if present (who, what, when, where, why, how)."""
|
||||
q = query.strip().lower()
|
||||
for word in ("who", "what", "when", "where", "why", "how"):
|
||||
if q.startswith(word):
|
||||
return word
|
||||
return None
|
||||
|
||||
|
||||
def _extract_entities(query: str) -> Dict[str, List[str]]:
|
||||
"""Lightweight entity extraction: capitalized words and date patterns."""
|
||||
entities: Dict[str, List[str]] = {"names": [], "dates": []}
|
||||
qtype = _detect_question_type(query)
|
||||
cleaned = query
|
||||
if qtype:
|
||||
cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip()
|
||||
for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned):
|
||||
entities["names"].append(token)
|
||||
for year in re.findall(r"\b(19|20)\d{2}\b", cleaned):
|
||||
entities["dates"].append(year)
|
||||
month_day_year = re.findall(
|
||||
r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b",
|
||||
cleaned,
|
||||
flags=re.I,
|
||||
)
|
||||
entities["dates"].extend(month_day_year)
|
||||
return entities
|
||||
|
||||
|
||||
def _split_multi_part(query: str) -> List[str]:
|
||||
"""Split a query into sub-queries on common conjunctions."""
|
||||
parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
|
||||
match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
|
||||
if match:
|
||||
site = match.group(1)
|
||||
new_query = re.sub(r"\bsite:[^\s]+", "", query, flags=re.I).strip()
|
||||
return new_query, site
|
||||
return query, None
|
||||
|
||||
|
||||
def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) -> str:
|
||||
"""Append extracted entities to the query using OR to increase relevance."""
|
||||
parts = [base_query]
|
||||
if entities.get("names"):
|
||||
parts.append(" OR ".join(f'"{n}"' for n in entities["names"]))
|
||||
if entities.get("dates"):
|
||||
parts.append(" OR ".join(f'"{d}"' for d in entities["dates"]))
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Process the original query: site filter, question type boosts, entity extraction."""
|
||||
query_without_site, site = _extract_site_filter(original_query)
|
||||
sub_queries = _split_multi_part(query_without_site)
|
||||
|
||||
enhanced_subs: List[str] = []
|
||||
for sub in sub_queries:
|
||||
qtype = _detect_question_type(sub)
|
||||
boost_keywords = []
|
||||
if qtype == "who":
|
||||
boost_keywords.append("person")
|
||||
elif qtype == "when":
|
||||
boost_keywords.append("date")
|
||||
elif qtype == "where":
|
||||
boost_keywords.append("location")
|
||||
elif qtype == "why":
|
||||
boost_keywords.append("reason")
|
||||
elif qtype == "how":
|
||||
boost_keywords.append("method")
|
||||
entities = _extract_entities(sub)
|
||||
boosted = _boost_entities_in_query(sub, entities)
|
||||
if boost_keywords:
|
||||
boosted = f'({boosted}) OR ({" OR ".join(boost_keywords)})'
|
||||
enhanced_subs.append(boosted)
|
||||
|
||||
final_query = " AND ".join(f"({s})" for s in enhanced_subs)
|
||||
if site:
|
||||
final_query = f"{final_query} site:{site}"
|
||||
return final_query, site
|
||||
|
||||
|
||||
def build_enhanced_query(query: str, time_filter: str = None) -> str:
|
||||
"""Build an enhanced search query with optional time filtering."""
|
||||
enhanced_query, _ = enhance_query(query)
|
||||
|
||||
if time_filter:
|
||||
time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
|
||||
if time_filter in time_map:
|
||||
enhanced_query = f"{enhanced_query} after:{time_map[time_filter]}"
|
||||
logger.info(f"Added time filter '{time_filter}' to query")
|
||||
|
||||
logger.info(f"Enhanced query: '{query}' -> '{enhanced_query}'")
|
||||
return enhanced_query
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Cache duration helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _is_news_query(query: str) -> bool:
|
||||
"""Lightweight heuristic to decide if a query is news-oriented."""
|
||||
news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
|
||||
tokens = set(re.findall(r"\b\w+\b", query.lower()))
|
||||
return bool(tokens & news_terms)
|
||||
|
||||
|
||||
def _cache_duration_for_query(query: str) -> timedelta:
|
||||
"""News queries -> 30 minutes, reference queries -> 24 hours."""
|
||||
if _is_news_query(query):
|
||||
return timedelta(minutes=30)
|
||||
return timedelta(hours=24)
|
||||
@@ -0,0 +1,127 @@
|
||||
"""Search result ranking based on relevance, source quality, and recency."""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NEWS_HINTS = {"news", "nyheter", "headlines", "breaking", "latest", "today", "idag"}
|
||||
_SPORTS_HINTS = {
|
||||
"sport", "sports", "soccer", "football", "hockey", "nba", "nfl", "mlb",
|
||||
"fifa", "world cup", "championship", "quarterfinal", "eliminates",
|
||||
}
|
||||
_LOW_VALUE_NEWS_DOMAINS = {
|
||||
"facebook.com", "www.facebook.com", "sports.yahoo.com", "yahoo.com",
|
||||
"www.yahoo.com", "msn.com", "www.msn.com",
|
||||
}
|
||||
_TRUSTED_NEWS_DOMAINS = {
|
||||
"apnews.com", "www.apnews.com", "reuters.com", "www.reuters.com",
|
||||
"bbc.com", "www.bbc.com", "cbc.ca", "www.cbc.ca",
|
||||
"ctvnews.ca", "www.ctvnews.ca", "globalnews.ca", "www.globalnews.ca",
|
||||
"theguardian.com",
|
||||
"www.theguardian.com", "euronews.com", "www.euronews.com",
|
||||
"dw.com", "www.dw.com", "government.se", "www.government.se",
|
||||
}
|
||||
|
||||
|
||||
def _domain(url: str) -> str:
|
||||
try:
|
||||
return urlparse(url).netloc.lower()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def rank_search_results(query: str, results: List[dict]) -> List[dict]:
|
||||
"""Rank search results by title relevance, snippet quality, domain authority, and recency."""
|
||||
query_terms = [t.lower() for t in re.findall(r"\b\w+\b", query)]
|
||||
query_lc = query.lower()
|
||||
is_news_query = any(term in _NEWS_HINTS for term in query_terms)
|
||||
is_sports_query = any(hint in query_lc for hint in _SPORTS_HINTS)
|
||||
|
||||
def title_score(title: str) -> float:
|
||||
if not title:
|
||||
return 0.0
|
||||
title_lc = title.lower()
|
||||
matches = sum(1 for term in query_terms if re.search(rf"\b{re.escape(term)}\b", title_lc))
|
||||
return matches / len(query_terms) if query_terms else 0.0
|
||||
|
||||
def snippet_score(snippet: str) -> float:
|
||||
if not snippet:
|
||||
return 0.0
|
||||
length_factor = min(len(snippet), 200) / 200
|
||||
term_hits = sum(1 for term in query_terms if term in snippet.lower())
|
||||
term_factor = term_hits / len(query_terms) if query_terms else 0.0
|
||||
return (length_factor + term_factor) / 2
|
||||
|
||||
def domain_score(url: str) -> float:
|
||||
netloc = _domain(url)
|
||||
if not netloc:
|
||||
return 0.0
|
||||
if netloc in _TRUSTED_NEWS_DOMAINS:
|
||||
return 1.0
|
||||
if netloc.endswith(".edu") or netloc.endswith(".gov"):
|
||||
return 1.0
|
||||
if netloc.endswith(".org"):
|
||||
return 0.7
|
||||
return 0.4
|
||||
|
||||
def recency_score(age_str: Optional[str]) -> float:
|
||||
if not age_str:
|
||||
return 0.0
|
||||
for fmt in ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"):
|
||||
try:
|
||||
dt = datetime.strptime(age_str, fmt)
|
||||
break
|
||||
except Exception:
|
||||
dt = None
|
||||
if not dt:
|
||||
return 0.0
|
||||
days_old = (datetime.now() - dt).days
|
||||
if days_old <= 7:
|
||||
return 1.0
|
||||
if days_old >= 30:
|
||||
return 0.0
|
||||
return (30 - days_old) / 23
|
||||
|
||||
def news_quality_adjustment(title: str, snippet: str, url: str) -> float:
|
||||
if not is_news_query:
|
||||
return 0.0
|
||||
text = f"{title} {snippet}".lower()
|
||||
netloc = _domain(url)
|
||||
adjustment = 0.0
|
||||
if netloc in _TRUSTED_NEWS_DOMAINS:
|
||||
adjustment += 1.2
|
||||
if any(term in text for term in ("latest news", "breaking news", "daily coverage", "news from")):
|
||||
adjustment += 0.4
|
||||
if netloc in _LOW_VALUE_NEWS_DOMAINS:
|
||||
adjustment -= 0.8
|
||||
if not is_sports_query and any(hint in text or hint in netloc for hint in _SPORTS_HINTS):
|
||||
adjustment -= 1.5
|
||||
# A country/news query should not rank a page whose title/snippet barely
|
||||
# mentions the country above actual news pages for that country.
|
||||
subject_terms = [t for t in query_terms if t not in _NEWS_HINTS]
|
||||
if subject_terms and not any(t in text or t in netloc for t in subject_terms):
|
||||
adjustment -= 1.0
|
||||
return adjustment
|
||||
|
||||
ranked = []
|
||||
for result in results:
|
||||
title = result.get("title", "")
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
age = result.get("age", None)
|
||||
|
||||
score = (
|
||||
2.0 * title_score(title)
|
||||
+ 1.0 * snippet_score(snippet)
|
||||
+ 1.5 * domain_score(url)
|
||||
+ 1.0 * recency_score(age)
|
||||
+ news_quality_adjustment(title, snippet, url)
|
||||
)
|
||||
ranked.append((score, result))
|
||||
|
||||
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||
return [r for _, r in ranked]
|
||||
@@ -0,0 +1,95 @@
|
||||
# services/search/service.py
|
||||
"""Search service — clean interface for web search."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from . import (
|
||||
comprehensive_web_search,
|
||||
fetch_webpage_content,
|
||||
get_search_config,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""A single search result."""
|
||||
url: str
|
||||
title: str
|
||||
snippet: str
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResponse:
|
||||
"""Response from a search query."""
|
||||
query: str
|
||||
results: List[SearchResult]
|
||||
total: int
|
||||
cached: bool = False
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""
|
||||
Web search service.
|
||||
|
||||
Usage:
|
||||
service = SearchService()
|
||||
result = await service.search("python async patterns")
|
||||
for r in result.results:
|
||||
print(f"{r.title}: {r.url}")
|
||||
"""
|
||||
|
||||
def __init__(self, default_depth: int = 1, fetch_content: bool = True):
|
||||
self.default_depth = default_depth
|
||||
self.fetch_content = fetch_content
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
depth: Optional[int] = None,
|
||||
fetch_content: Optional[bool] = None,
|
||||
) -> SearchResponse:
|
||||
"""
|
||||
Search the web.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
depth: Search depth (1=quick, 2=thorough, 3=comprehensive)
|
||||
fetch_content: Whether to fetch full page content
|
||||
|
||||
Returns:
|
||||
SearchResponse with results
|
||||
"""
|
||||
depth = depth or self.default_depth
|
||||
fetch_content = fetch_content if fetch_content is not None else self.fetch_content
|
||||
|
||||
# Use existing search implementation
|
||||
raw_results = await comprehensive_web_search(
|
||||
query,
|
||||
max_results=10 * depth,
|
||||
fetch_content=fetch_content,
|
||||
)
|
||||
|
||||
results = []
|
||||
for r in raw_results:
|
||||
results.append(SearchResult(
|
||||
url=r.get("url", ""),
|
||||
title=r.get("title", ""),
|
||||
snippet=r.get("snippet", ""),
|
||||
content=r.get("content"),
|
||||
))
|
||||
|
||||
return SearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total=len(results),
|
||||
)
|
||||
|
||||
async def fetch_content(self, url: str) -> Optional[str]:
|
||||
"""Fetch content from a URL."""
|
||||
return await fetch_webpage_content(url)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""Get current search configuration."""
|
||||
return get_search_config()
|
||||
@@ -0,0 +1,6 @@
|
||||
# services/shell/__init__.py
|
||||
"""Shell service — safe command execution."""
|
||||
|
||||
from .service import ShellService, ShellResult
|
||||
|
||||
__all__ = ["ShellService", "ShellResult"]
|
||||
@@ -0,0 +1,162 @@
|
||||
# services/shell/service.py
|
||||
"""Shell service — safe command execution."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, AsyncIterator
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShellResult:
|
||||
"""Result of a shell command."""
|
||||
stdout: str
|
||||
stderr: str
|
||||
exit_code: int
|
||||
timed_out: bool = False
|
||||
|
||||
|
||||
class ShellService:
|
||||
"""
|
||||
Shell execution service.
|
||||
|
||||
Usage:
|
||||
service = ShellService()
|
||||
result = await service.execute("ls -la")
|
||||
print(result.stdout)
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: int = 30, max_output: int = 200_000):
|
||||
self.timeout = timeout
|
||||
self.max_output = max_output
|
||||
self.cwd = str(Path.home())
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command: str,
|
||||
timeout: Optional[int] = None,
|
||||
cwd: Optional[str] = None,
|
||||
) -> ShellResult:
|
||||
"""
|
||||
Execute a shell command.
|
||||
|
||||
Args:
|
||||
command: Shell command to run
|
||||
timeout: Timeout in seconds (default: self.timeout)
|
||||
cwd: Working directory (default: home)
|
||||
|
||||
Returns:
|
||||
ShellResult with stdout, stderr, exit_code
|
||||
"""
|
||||
timeout = timeout or self.timeout
|
||||
cwd = cwd or self.cwd
|
||||
|
||||
proc = None
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
stdout = stdout_b.decode(errors="replace")[:self.max_output]
|
||||
stderr = stderr_b.decode(errors="replace")[:self.max_output]
|
||||
return ShellResult(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
exit_code=proc.returncode,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if proc:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return ShellResult(
|
||||
stdout="",
|
||||
stderr=f"Command timed out after {timeout}s",
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
)
|
||||
except Exception as e:
|
||||
return ShellResult(stdout="", stderr=str(e), exit_code=-1)
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
command: str,
|
||||
timeout: int = 120,
|
||||
) -> AsyncIterator[dict]:
|
||||
"""
|
||||
Execute a command and stream output.
|
||||
|
||||
Yields:
|
||||
{"stream": "stdout"|"stderr", "data": line}
|
||||
{"exit_code": int}
|
||||
"""
|
||||
|
||||
proc = None
|
||||
reader_tasks = []
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self.cwd,
|
||||
)
|
||||
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def _reader(stream, name):
|
||||
try:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
await q.put((name, line.decode(errors="replace").rstrip("\n")))
|
||||
finally:
|
||||
await q.put((name, None))
|
||||
|
||||
reader_tasks = [
|
||||
asyncio.create_task(_reader(proc.stdout, "stdout")),
|
||||
asyncio.create_task(_reader(proc.stderr, "stderr")),
|
||||
]
|
||||
|
||||
finished = 0
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while finished < 2:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError()
|
||||
|
||||
try:
|
||||
name, text = await asyncio.wait_for(q.get(), timeout=min(remaining, 2.0))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if text is None:
|
||||
finished += 1
|
||||
continue
|
||||
yield {"stream": name, "data": text}
|
||||
|
||||
await proc.wait()
|
||||
yield {"exit_code": proc.returncode}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if proc:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
yield {"stream": "stderr", "data": f"Command timed out after {timeout}s"}
|
||||
yield {"exit_code": -1}
|
||||
except Exception as e:
|
||||
yield {"stream": "stderr", "data": str(e)}
|
||||
yield {"exit_code": -1}
|
||||
finally:
|
||||
for t in reader_tasks:
|
||||
t.cancel()
|
||||
@@ -0,0 +1,3 @@
|
||||
from services.stt.stt_service import get_stt_service
|
||||
|
||||
__all__ = ["get_stt_service"]
|
||||
@@ -0,0 +1,191 @@
|
||||
# services/stt/stt_service.py
|
||||
"""Multi-provider Speech-to-Text service — dispatches to local Whisper, OpenAI-compatible API, or browser."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import httpx
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class STTService:
|
||||
"""Multi-provider STT service.
|
||||
|
||||
Reads provider config from data/settings.json on each call.
|
||||
Providers:
|
||||
"disabled" — no STT
|
||||
"browser" — client-side Web Speech API (no server transcription)
|
||||
"local" — faster-whisper on CPU/GPU
|
||||
"endpoint:<id>" — OpenAI-compatible /audio/transcriptions via ModelEndpoint
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._whisper_model = None # lazy-init
|
||||
|
||||
# ── Settings ──
|
||||
|
||||
def _load_settings(self) -> dict:
|
||||
from src.settings import load_settings
|
||||
saved = load_settings()
|
||||
return {
|
||||
"stt_enabled": saved.get("stt_enabled", False),
|
||||
"stt_provider": saved.get("stt_provider", "disabled"),
|
||||
"stt_model": saved.get("stt_model", "base"),
|
||||
"stt_language": saved.get("stt_language", ""),
|
||||
}
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
if provider == "disabled":
|
||||
return False
|
||||
if provider == "browser":
|
||||
return True # handled client-side
|
||||
if provider == "local":
|
||||
return self._get_whisper() is not None
|
||||
if provider.startswith("endpoint:"):
|
||||
return True # assume reachable
|
||||
return False
|
||||
|
||||
# ── Local Whisper ──
|
||||
|
||||
def _get_whisper(self):
|
||||
if self._whisper_model is None:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
settings = self._load_settings()
|
||||
model_size = settings.get("stt_model", "base")
|
||||
# Use CPU by default; will use CUDA if available
|
||||
import torch
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
||||
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
|
||||
except ImportError:
|
||||
logger.warning("faster-whisper not installed. Install with: pip install faster-whisper")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load whisper model: {e}")
|
||||
return None
|
||||
return self._whisper_model
|
||||
|
||||
def _transcribe_local(self, audio_bytes: bytes, language: str = "") -> Optional[str]:
|
||||
model = self._get_whisper()
|
||||
if not model:
|
||||
return None
|
||||
try:
|
||||
# Write to temp file (faster-whisper needs a file path or file-like)
|
||||
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
kwargs = {}
|
||||
if language:
|
||||
kwargs["language"] = language
|
||||
|
||||
segments, info = model.transcribe(tmp_path, **kwargs)
|
||||
text = " ".join(seg.text.strip() for seg in segments)
|
||||
|
||||
# Cleanup
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
logger.info(f"Local STT: {len(text)} chars, lang={info.language}, prob={info.language_probability:.2f}")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error(f"Local STT transcription failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# ── API endpoint ──
|
||||
|
||||
def _transcribe_api(self, audio_bytes: bytes, endpoint_id: str, model: str, language: str = "") -> Optional[str]:
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if not ep:
|
||||
logger.error(f"STT endpoint {endpoint_id} not found")
|
||||
return None
|
||||
base_url = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
url = base_url + "/audio/transcriptions"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
files = {"file": ("audio.webm", io.BytesIO(audio_bytes), "audio/webm")}
|
||||
data = {"model": model or "whisper-1"}
|
||||
if language:
|
||||
data["language"] = language
|
||||
|
||||
try:
|
||||
r = httpx.post(url, headers=headers, files=files, data=data, timeout=60)
|
||||
r.raise_for_status()
|
||||
result = r.json()
|
||||
text = result.get("text", "")
|
||||
logger.info(f"API STT: {len(text)} chars from {base_url}")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error(f"API STT transcription failed: {e}")
|
||||
return None
|
||||
|
||||
# ── Public interface ──
|
||||
|
||||
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
model = settings["stt_model"]
|
||||
language = settings.get("stt_language", "")
|
||||
|
||||
if provider in ("disabled", "browser"):
|
||||
return None
|
||||
|
||||
if provider == "local":
|
||||
return self._transcribe_local(audio_bytes, language)
|
||||
elif provider.startswith("endpoint:"):
|
||||
endpoint_id = provider.split(":", 1)[1]
|
||||
return self._transcribe_api(audio_bytes, endpoint_id, model, language)
|
||||
else:
|
||||
logger.error(f"Unknown STT provider: {provider}")
|
||||
return None
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
stt_enabled = settings.get("stt_enabled", False)
|
||||
# If toggle is off, report as disabled
|
||||
effective_provider = provider if stt_enabled else "disabled"
|
||||
|
||||
stats = {
|
||||
"available": self.available and stt_enabled,
|
||||
"provider": effective_provider,
|
||||
"model": settings["stt_model"],
|
||||
"language": settings.get("stt_language", ""),
|
||||
}
|
||||
|
||||
if provider == "local":
|
||||
whisper = self._get_whisper()
|
||||
stats["model_loaded"] = whisper is not None
|
||||
elif provider == "browser":
|
||||
stats["model"] = "Browser (Web Speech API)"
|
||||
elif provider.startswith("endpoint:"):
|
||||
stats["endpoint_id"] = provider.split(":", 1)[1]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_stt_service = None
|
||||
|
||||
def get_stt_service() -> STTService:
|
||||
global _stt_service
|
||||
if _stt_service is None:
|
||||
_stt_service = STTService()
|
||||
return _stt_service
|
||||
@@ -0,0 +1,9 @@
|
||||
# services/tts/__init__.py
|
||||
"""TTS service — text-to-speech."""
|
||||
|
||||
from .tts_service import (
|
||||
TTSService,
|
||||
get_tts_service,
|
||||
)
|
||||
|
||||
__all__ = ["TTSService", "get_tts_service"]
|
||||
@@ -0,0 +1,278 @@
|
||||
# src/tts_service.py
|
||||
"""Multi-provider TTS service — dispatches to local Kokoro, OpenAI-compatible API, or browser."""
|
||||
|
||||
import io
|
||||
import wave
|
||||
import logging
|
||||
import hashlib
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSService:
|
||||
"""Multi-provider TTS service.
|
||||
|
||||
Reads provider config from data/settings.json on each call.
|
||||
Providers:
|
||||
"disabled" — no TTS
|
||||
"browser" — client-side Web Speech API (no server synthesis)
|
||||
"local" — Kokoro-82M on GPU
|
||||
"endpoint:<id>" — OpenAI-compatible /audio/speech via ModelEndpoint
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data/tts_cache"):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._kokoro = None # lazy-init
|
||||
|
||||
# ── Settings ──
|
||||
|
||||
def _load_settings(self) -> dict:
|
||||
from src.settings import load_settings
|
||||
saved = load_settings()
|
||||
return {
|
||||
"tts_provider": saved.get("tts_provider", "disabled"),
|
||||
"tts_model": saved.get("tts_model", "tts-1"),
|
||||
"tts_voice": saved.get("tts_voice", "alloy"),
|
||||
"tts_speed": saved.get("tts_speed", "1"),
|
||||
}
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
settings = self._load_settings()
|
||||
provider = settings["tts_provider"]
|
||||
if provider == "disabled":
|
||||
return False
|
||||
if provider == "browser":
|
||||
return True # handled client-side
|
||||
if provider == "local":
|
||||
kokoro = self._get_kokoro()
|
||||
return kokoro is not None and kokoro.available
|
||||
if provider.startswith("endpoint:"):
|
||||
return True # assume reachable; errors surface at synthesis time
|
||||
return False
|
||||
|
||||
# ── Cache ──
|
||||
|
||||
def _cache_key(self, text: str, provider: str, model: str, voice: str, speed: float = 1.0) -> str:
|
||||
raw = f"{provider}|{model}|{voice}|{speed}|{text}"
|
||||
return hashlib.sha256(raw.encode()).hexdigest()
|
||||
|
||||
def _get_cached(self, key: str) -> Optional[bytes]:
|
||||
for ext in (".mp3", ".wav"):
|
||||
path = self.cache_dir / f"{key}{ext}"
|
||||
if path.exists():
|
||||
return path.read_bytes()
|
||||
return None
|
||||
|
||||
def _put_cache(self, key: str, data: bytes):
|
||||
ext = ".mp3" if (len(data) >= 3 and (data[:3] == b'ID3' or (data[0] == 0xff and (data[1] & 0xe0) == 0xe0))) else ".wav"
|
||||
(self.cache_dir / f"{key}{ext}").write_bytes(data)
|
||||
|
||||
def clear_cache(self):
|
||||
count = 0
|
||||
for f in self.cache_dir.glob("*.*"):
|
||||
f.unlink()
|
||||
count += 1
|
||||
logger.info(f"Cleared {count} cached TTS files")
|
||||
|
||||
# ── Kokoro (local) ──
|
||||
|
||||
def _get_kokoro(self):
|
||||
if self._kokoro is None:
|
||||
self._kokoro = _KokoroPipeline()
|
||||
return self._kokoro
|
||||
|
||||
# ── API endpoint ──
|
||||
|
||||
def _synthesize_api(self, text: str, endpoint_id: str, model: str, voice: str, speed: float = 1.0) -> Optional[bytes]:
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if not ep:
|
||||
logger.error(f"TTS endpoint {endpoint_id} not found")
|
||||
return None
|
||||
base_url = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
url = base_url + "/audio/speech"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
"response_format": "mp3",
|
||||
"speed": speed,
|
||||
}
|
||||
|
||||
try:
|
||||
r = httpx.post(url, json=payload, headers=headers, timeout=60)
|
||||
r.raise_for_status()
|
||||
logger.info(f"API TTS: {len(r.content)} bytes from {base_url}")
|
||||
return r.content
|
||||
except Exception as e:
|
||||
logger.error(f"API TTS synthesis failed: {e}")
|
||||
return None
|
||||
|
||||
# ── Public interface ──
|
||||
|
||||
def synthesize(self, text: str, use_cache: bool = True) -> Optional[bytes]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["tts_provider"]
|
||||
model = settings["tts_model"]
|
||||
voice = settings["tts_voice"]
|
||||
speed = float(settings.get("tts_speed", "1"))
|
||||
|
||||
if provider in ("disabled", "browser"):
|
||||
return None
|
||||
|
||||
if len(text) > 5000:
|
||||
text = text[:5000]
|
||||
|
||||
if use_cache:
|
||||
key = self._cache_key(text, provider, model, voice, speed)
|
||||
cached = self._get_cached(key)
|
||||
if cached:
|
||||
logger.info(f"TTS cache hit ({len(text)} chars)")
|
||||
return cached
|
||||
|
||||
audio_data = None
|
||||
|
||||
if provider == "local":
|
||||
kokoro = self._get_kokoro()
|
||||
if kokoro and kokoro.available:
|
||||
audio_data = kokoro.synthesize_raw(text, voice)
|
||||
else:
|
||||
logger.warning("Kokoro TTS not available")
|
||||
return None
|
||||
elif provider.startswith("endpoint:"):
|
||||
endpoint_id = provider.split(":", 1)[1]
|
||||
audio_data = self._synthesize_api(text, endpoint_id, model, voice, speed)
|
||||
else:
|
||||
logger.error(f"Unknown TTS provider: {provider}")
|
||||
return None
|
||||
|
||||
if audio_data and use_cache:
|
||||
key = self._cache_key(text, provider, model, voice, speed)
|
||||
self._put_cache(key, audio_data)
|
||||
|
||||
return audio_data
|
||||
|
||||
def synthesize_to_base64(self, text: str) -> Optional[str]:
|
||||
import base64
|
||||
audio = self.synthesize(text)
|
||||
if audio:
|
||||
return base64.b64encode(audio).decode("utf-8")
|
||||
return None
|
||||
|
||||
def set_voice(self, voice: str):
|
||||
"""Legacy no-op — voice is now managed via admin settings."""
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["tts_provider"]
|
||||
tts_enabled = settings.get("tts_enabled", True)
|
||||
|
||||
cache_files = list(self.cache_dir.glob("*.wav"))
|
||||
cache_size = sum(f.stat().st_size for f in cache_files)
|
||||
|
||||
is_available = self.available and tts_enabled
|
||||
stats = {
|
||||
"available": is_available,
|
||||
"ready": is_available,
|
||||
"provider": provider,
|
||||
"model": settings["tts_model"],
|
||||
"voice": settings["tts_voice"],
|
||||
"speed": float(settings.get("tts_speed", "1")),
|
||||
"cache_entries": len(cache_files),
|
||||
"cache_size_mb": round(cache_size / (1024 * 1024), 2),
|
||||
}
|
||||
|
||||
if provider == "local":
|
||||
kokoro = self._get_kokoro()
|
||||
stats["model"] = "Kokoro-82M (GPU)" if (kokoro and kokoro.available) else "Kokoro (not loaded)"
|
||||
elif provider == "browser":
|
||||
stats["model"] = "Browser (Web Speech API)"
|
||||
elif provider.startswith("endpoint:"):
|
||||
stats["endpoint_id"] = provider.split(":", 1)[1]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
class _KokoroPipeline:
|
||||
"""Encapsulates the Kokoro-82M local GPU pipeline."""
|
||||
|
||||
def __init__(self):
|
||||
self.pipeline = None
|
||||
self.available = False
|
||||
self.device = None
|
||||
self._init()
|
||||
|
||||
def _init(self):
|
||||
try:
|
||||
import torch
|
||||
from kokoro import KPipeline
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning("CUDA not available for Kokoro TTS")
|
||||
return
|
||||
|
||||
self.device = torch.device("cuda:0")
|
||||
with torch.cuda.device(0):
|
||||
self.pipeline = KPipeline(lang_code="a")
|
||||
if hasattr(self.pipeline, "model"):
|
||||
self.pipeline.model = self.pipeline.model.to(self.device)
|
||||
self.available = True
|
||||
logger.info("Kokoro-82M TTS pipeline loaded")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Kokoro TTS not available: {e}")
|
||||
logger.warning("Install with: pip install kokoro soundfile")
|
||||
except Exception as e:
|
||||
logger.error(f"Kokoro init failed: {e}", exc_info=True)
|
||||
|
||||
def synthesize_raw(self, text: str, voice: str = "af_heart") -> Optional[bytes]:
|
||||
if not self.available:
|
||||
return None
|
||||
try:
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
with torch.cuda.device(self.device):
|
||||
chunks = []
|
||||
for _, _, audio in self.pipeline(text, voice=voice):
|
||||
chunks.append(audio)
|
||||
|
||||
if not chunks:
|
||||
return None
|
||||
|
||||
full = np.concatenate(chunks)
|
||||
buf = io.BytesIO()
|
||||
with wave.open(buf, "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(24000)
|
||||
wf.writeframes((full * 32767).astype(np.int16).tobytes())
|
||||
return buf.getvalue()
|
||||
except Exception as e:
|
||||
logger.error(f"Kokoro synthesis failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_tts_service = None
|
||||
|
||||
def get_tts_service() -> TTSService:
|
||||
global _tts_service
|
||||
if _tts_service is None:
|
||||
_tts_service = TTSService()
|
||||
return _tts_service
|
||||
@@ -0,0 +1,22 @@
|
||||
# services/youtube/__init__.py
|
||||
"""YouTube service — transcript extraction."""
|
||||
|
||||
from .youtube_handler import (
|
||||
init_youtube,
|
||||
is_youtube_url,
|
||||
extract_youtube_id,
|
||||
extract_transcript_async,
|
||||
format_transcript_for_context,
|
||||
fetch_youtube_comments,
|
||||
format_comments_for_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"init_youtube",
|
||||
"is_youtube_url",
|
||||
"extract_youtube_id",
|
||||
"extract_transcript_async",
|
||||
"format_transcript_for_context",
|
||||
"fetch_youtube_comments",
|
||||
"format_comments_for_context",
|
||||
]
|
||||
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
YouTube handling — transcript extraction, comment fetching (yt-dlp),
|
||||
and context formatting for LLM injection. Used by chat_handler.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
YOUTUBE_INSTRUCTION_PROMPT = """When the user shares a YouTube video, respond with a structured breakdown:
|
||||
|
||||
1. **Summary** — Concise overview of the video's content and main thesis (2-4 sentences)
|
||||
2. **Key Points** — Bullet list of the most important topics, arguments, or moments
|
||||
3. **Notable Timestamps** — If timestamps are available from the transcript, highlight 3-5 interesting moments with their approximate timestamps (e.g. "03:45 — discusses X")
|
||||
4. **Audience Reception** — If comments are available, summarize what viewers think: general sentiment, top reactions, any debate or controversy
|
||||
|
||||
Keep it conversational and concise. Do NOT web search for this video — use only the transcript and comments provided."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Init / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Will be set at startup by init_youtube()
|
||||
YouTubeTranscriptApi = None
|
||||
YOUTUBE_AVAILABLE = False
|
||||
|
||||
|
||||
def _find_ytdlp() -> str:
|
||||
"""Find the yt-dlp binary: venv bin first, then system PATH."""
|
||||
venv_bin = Path(sys.executable).parent / "yt-dlp"
|
||||
if venv_bin.exists():
|
||||
return str(venv_bin)
|
||||
found = shutil.which("yt-dlp")
|
||||
return found or "yt-dlp"
|
||||
|
||||
|
||||
def init_youtube():
|
||||
"""Import and cache the YouTube transcript API."""
|
||||
global YouTubeTranscriptApi, YOUTUBE_AVAILABLE
|
||||
try:
|
||||
from youtube_transcript_api import YouTubeTranscriptApi as _Api
|
||||
YouTubeTranscriptApi = _Api
|
||||
YOUTUBE_AVAILABLE = True
|
||||
logger.info("YouTube transcript API available")
|
||||
except ImportError as e:
|
||||
logger.warning(f"youtube-transcript-api not installed: {e}")
|
||||
YOUTUBE_AVAILABLE = False
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
return "youtube.com" in url or "youtu.be" in url
|
||||
|
||||
|
||||
def extract_youtube_id(url: str) -> Optional[str]:
|
||||
"""Extract YouTube video ID from various URL formats."""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
if parsed.hostname in ("www.youtube.com", "youtube.com", "m.youtube.com"):
|
||||
if parsed.path == "/watch":
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
if "v" in params:
|
||||
return params["v"][0]
|
||||
elif parsed.path.startswith("/embed/"):
|
||||
return parsed.path.split("/")[-1]
|
||||
elif parsed.hostname == "youtu.be":
|
||||
return parsed.path[1:]
|
||||
return None
|
||||
|
||||
|
||||
async def extract_transcript_async(
|
||||
url: str, video_id: str, max_retries: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Async YouTube transcript extraction with retries.
|
||||
|
||||
Args:
|
||||
url: Full YouTube URL
|
||||
video_id: Extracted video ID
|
||||
max_retries: Number of attempts
|
||||
|
||||
Returns:
|
||||
Dict with success/error/transcript keys
|
||||
"""
|
||||
if not YOUTUBE_AVAILABLE or YouTubeTranscriptApi is None:
|
||||
return {"success": False, "error": "YouTube transcript API not available", "transcript": None}
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
api = YouTubeTranscriptApi()
|
||||
transcript = api.fetch(video_id)
|
||||
transcript_list = list(transcript)
|
||||
|
||||
formatted = []
|
||||
for snippet in transcript_list:
|
||||
text = snippet.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
start = snippet.start
|
||||
formatted.append({
|
||||
"text": text,
|
||||
"start": start,
|
||||
"duration": snippet.duration,
|
||||
"timestamp": f"{int(start // 60):02d}:{int(start % 60):02d}",
|
||||
})
|
||||
|
||||
full_text = " ".join(e["text"] for e in formatted)
|
||||
max_len = 8000
|
||||
if len(full_text) > max_len:
|
||||
full_text = full_text[:max_len] + "... [transcript truncated]"
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": full_text,
|
||||
"video_id": video_id,
|
||||
"language": "en",
|
||||
"is_generated": False,
|
||||
"segments": formatted,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Transcript attempt {attempt + 1} failed: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
|
||||
return {"success": False, "error": f"Failed after {max_retries} attempts", "transcript": None}
|
||||
|
||||
|
||||
def format_transcript_for_context(
|
||||
transcript_data: Dict[str, Any], url: str,
|
||||
title: str = "", channel: str = ""
|
||||
) -> str:
|
||||
"""Format transcript data for inclusion in LLM context."""
|
||||
if not transcript_data.get("success"):
|
||||
header = ""
|
||||
if title:
|
||||
header = f" \"{title}\""
|
||||
if channel:
|
||||
header += f" by {channel}"
|
||||
return f"\n[YouTube Video{header}: Transcript unavailable ({transcript_data.get('error', 'Unknown error')}). Use the comments below if available, do NOT web search for this video.]"
|
||||
|
||||
transcript = transcript_data.get("transcript", "")
|
||||
video_id = transcript_data.get("video_id", "")
|
||||
language = transcript_data.get("language", "unknown")
|
||||
is_generated = transcript_data.get("is_generated", False)
|
||||
segments = transcript_data.get("segments", [])
|
||||
|
||||
ctx = "\n[YOUTUBE VIDEO TRANSCRIPT]\n"
|
||||
if title:
|
||||
ctx += f"Title: {title}\n"
|
||||
if channel:
|
||||
ctx += f"Channel: {channel}\n"
|
||||
ctx += f"Video ID: {video_id}\n"
|
||||
ctx += f"Language: {language}\n"
|
||||
ctx += f"Source: {'Auto-generated' if is_generated else 'Manual'}\n"
|
||||
ctx += f"URL: {url}\n\n"
|
||||
# Include timestamped segments for the LLM to reference
|
||||
if segments:
|
||||
ctx += "Timestamped Transcript:\n"
|
||||
for seg in segments:
|
||||
ctx += f"[{seg['timestamp']}] {seg['text']}\n"
|
||||
# Check length — fall back to plain text if too long
|
||||
if len(ctx) > 12000:
|
||||
ctx = ctx[:ctx.index("Timestamped Transcript:\n")]
|
||||
ctx += "Transcript:\n"
|
||||
ctx += transcript
|
||||
else:
|
||||
ctx += "Transcript:\n"
|
||||
ctx += transcript
|
||||
ctx += "\n[END TRANSCRIPT]\n"
|
||||
return ctx
|
||||
|
||||
|
||||
async def fetch_youtube_comments(
|
||||
video_id: str, max_comments: int = 25, timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch top comments for a YouTube video using yt-dlp.
|
||||
|
||||
Returns dict with 'success', 'comments' list, 'error'.
|
||||
"""
|
||||
try:
|
||||
cmd = [
|
||||
_find_ytdlp(),
|
||||
"--skip-download",
|
||||
"--write-comments",
|
||||
"--extractor-args", f"youtube:max_comments={max_comments},all,100,0",
|
||||
"--dump-json",
|
||||
"--js-runtimes", "node",
|
||||
"--remote-components", "ejs:github",
|
||||
f"https://www.youtube.com/watch?v={video_id}",
|
||||
]
|
||||
|
||||
proc = await asyncio.wait_for(
|
||||
asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
return {"success": False, "error": f"yt-dlp failed: {stderr.decode()[:200]}", "comments": []}
|
||||
|
||||
data = json.loads(stdout.decode())
|
||||
title = data.get("title", "")
|
||||
channel = data.get("channel", "") or data.get("uploader", "")
|
||||
raw_comments = data.get("comments", [])
|
||||
|
||||
comments = []
|
||||
for c in raw_comments[:max_comments]:
|
||||
text = (c.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
comments.append({
|
||||
"author": c.get("author", "Unknown"),
|
||||
"text": text,
|
||||
"likes": c.get("like_count", 0),
|
||||
})
|
||||
|
||||
# Sort by likes descending — most popular comments first
|
||||
comments.sort(key=lambda x: x.get("likes", 0), reverse=True)
|
||||
|
||||
return {"success": True, "comments": comments, "count": len(comments),
|
||||
"title": title, "channel": channel}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Comment fetch timed out for {video_id}")
|
||||
return {"success": False, "error": "Comment fetch timed out", "comments": []}
|
||||
except FileNotFoundError:
|
||||
logger.warning("yt-dlp not installed — cannot fetch comments")
|
||||
return {"success": False, "error": "yt-dlp not installed", "comments": []}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch comments for {video_id}: {e}")
|
||||
return {"success": False, "error": str(e), "comments": []}
|
||||
|
||||
|
||||
def format_comments_for_context(comments_data: Dict[str, Any], url: str) -> str:
|
||||
"""Format YouTube comments for inclusion in LLM context."""
|
||||
if not comments_data.get("success") or not comments_data.get("comments"):
|
||||
return ""
|
||||
|
||||
comments = comments_data["comments"]
|
||||
ctx = f"\n[YOUTUBE VIDEO COMMENTS — Top {len(comments)} by popularity]\n"
|
||||
ctx += f"URL: {url}\n\n"
|
||||
|
||||
for i, c in enumerate(comments, 1):
|
||||
likes = c.get("likes", 0)
|
||||
likes_str = f" [{likes} likes]" if likes else ""
|
||||
ctx += f"{i}. @{c['author']}{likes_str}: {c['text']}\n\n"
|
||||
|
||||
if len(ctx) > 4000:
|
||||
ctx = ctx[:4000] + "\n[Comments truncated]\n"
|
||||
|
||||
ctx += "[END COMMENTS]\n"
|
||||
return ctx
|
||||
Reference in New Issue
Block a user