mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
Odysseus v1.0
This commit is contained in:
@@ -0,0 +1,191 @@
|
||||
# services/stt/stt_service.py
|
||||
"""Multi-provider Speech-to-Text service — dispatches to local Whisper, OpenAI-compatible API, or browser."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import httpx
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class STTService:
|
||||
"""Multi-provider STT service.
|
||||
|
||||
Reads provider config from data/settings.json on each call.
|
||||
Providers:
|
||||
"disabled" — no STT
|
||||
"browser" — client-side Web Speech API (no server transcription)
|
||||
"local" — faster-whisper on CPU/GPU
|
||||
"endpoint:<id>" — OpenAI-compatible /audio/transcriptions via ModelEndpoint
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._whisper_model = None # lazy-init
|
||||
|
||||
# ── Settings ──
|
||||
|
||||
def _load_settings(self) -> dict:
|
||||
from src.settings import load_settings
|
||||
saved = load_settings()
|
||||
return {
|
||||
"stt_enabled": saved.get("stt_enabled", False),
|
||||
"stt_provider": saved.get("stt_provider", "disabled"),
|
||||
"stt_model": saved.get("stt_model", "base"),
|
||||
"stt_language": saved.get("stt_language", ""),
|
||||
}
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
if provider == "disabled":
|
||||
return False
|
||||
if provider == "browser":
|
||||
return True # handled client-side
|
||||
if provider == "local":
|
||||
return self._get_whisper() is not None
|
||||
if provider.startswith("endpoint:"):
|
||||
return True # assume reachable
|
||||
return False
|
||||
|
||||
# ── Local Whisper ──
|
||||
|
||||
def _get_whisper(self):
|
||||
if self._whisper_model is None:
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
settings = self._load_settings()
|
||||
model_size = settings.get("stt_model", "base")
|
||||
# Use CPU by default; will use CUDA if available
|
||||
import torch
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
self._whisper_model = WhisperModel(model_size, device=device, compute_type=compute_type)
|
||||
logger.info(f"faster-whisper model '{model_size}' loaded on {device}")
|
||||
except ImportError:
|
||||
logger.warning("faster-whisper not installed. Install with: pip install faster-whisper")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load whisper model: {e}")
|
||||
return None
|
||||
return self._whisper_model
|
||||
|
||||
def _transcribe_local(self, audio_bytes: bytes, language: str = "") -> Optional[str]:
|
||||
model = self._get_whisper()
|
||||
if not model:
|
||||
return None
|
||||
try:
|
||||
# Write to temp file (faster-whisper needs a file path or file-like)
|
||||
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
kwargs = {}
|
||||
if language:
|
||||
kwargs["language"] = language
|
||||
|
||||
segments, info = model.transcribe(tmp_path, **kwargs)
|
||||
text = " ".join(seg.text.strip() for seg in segments)
|
||||
|
||||
# Cleanup
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
logger.info(f"Local STT: {len(text)} chars, lang={info.language}, prob={info.language_probability:.2f}")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error(f"Local STT transcription failed: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# ── API endpoint ──
|
||||
|
||||
def _transcribe_api(self, audio_bytes: bytes, endpoint_id: str, model: str, language: str = "") -> Optional[str]:
|
||||
from src.database import SessionLocal, ModelEndpoint
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if not ep:
|
||||
logger.error(f"STT endpoint {endpoint_id} not found")
|
||||
return None
|
||||
base_url = ep.base_url.rstrip("/")
|
||||
api_key = ep.api_key
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
url = base_url + "/audio/transcriptions"
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
files = {"file": ("audio.webm", io.BytesIO(audio_bytes), "audio/webm")}
|
||||
data = {"model": model or "whisper-1"}
|
||||
if language:
|
||||
data["language"] = language
|
||||
|
||||
try:
|
||||
r = httpx.post(url, headers=headers, files=files, data=data, timeout=60)
|
||||
r.raise_for_status()
|
||||
result = r.json()
|
||||
text = result.get("text", "")
|
||||
logger.info(f"API STT: {len(text)} chars from {base_url}")
|
||||
return text
|
||||
except Exception as e:
|
||||
logger.error(f"API STT transcription failed: {e}")
|
||||
return None
|
||||
|
||||
# ── Public interface ──
|
||||
|
||||
def transcribe(self, audio_bytes: bytes) -> Optional[str]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
model = settings["stt_model"]
|
||||
language = settings.get("stt_language", "")
|
||||
|
||||
if provider in ("disabled", "browser"):
|
||||
return None
|
||||
|
||||
if provider == "local":
|
||||
return self._transcribe_local(audio_bytes, language)
|
||||
elif provider.startswith("endpoint:"):
|
||||
endpoint_id = provider.split(":", 1)[1]
|
||||
return self._transcribe_api(audio_bytes, endpoint_id, model, language)
|
||||
else:
|
||||
logger.error(f"Unknown STT provider: {provider}")
|
||||
return None
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
settings = self._load_settings()
|
||||
provider = settings["stt_provider"]
|
||||
stt_enabled = settings.get("stt_enabled", False)
|
||||
# If toggle is off, report as disabled
|
||||
effective_provider = provider if stt_enabled else "disabled"
|
||||
|
||||
stats = {
|
||||
"available": self.available and stt_enabled,
|
||||
"provider": effective_provider,
|
||||
"model": settings["stt_model"],
|
||||
"language": settings.get("stt_language", ""),
|
||||
}
|
||||
|
||||
if provider == "local":
|
||||
whisper = self._get_whisper()
|
||||
stats["model_loaded"] = whisper is not None
|
||||
elif provider == "browser":
|
||||
stats["model"] = "Browser (Web Speech API)"
|
||||
elif provider.startswith("endpoint:"):
|
||||
stats["endpoint_id"] = provider.split(":", 1)[1]
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
_stt_service = None
|
||||
|
||||
def get_stt_service() -> STTService:
|
||||
global _stt_service
|
||||
if _stt_service is None:
|
||||
_stt_service = STTService()
|
||||
return _stt_service
|
||||
Reference in New Issue
Block a user