mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
Add support for EMBEDDING_API_KEY (#2691)
* feat: support for embedding API key * feat: encrypt and decrypt embedding API key * test: add unit tests for EmbeddingClient authorization header behavior
This commit is contained in:
@@ -112,6 +112,9 @@ SEARXNG_INSTANCE=http://localhost:8080
|
||||
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
||||
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
||||
|
||||
# Embedding API key (if there's one)
|
||||
# EMBEDDING_API_KEY=embedding_api_key_here
|
||||
|
||||
# Embedding model name (must be available at the endpoint above)
|
||||
# EMBEDDING_MODEL=all-minilm:l6-v2
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
|
||||
@@ -51,6 +51,7 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
|
||||
@@ -40,6 +40,7 @@ services:
|
||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||
|
||||
@@ -258,7 +258,7 @@ def setup_embedding_routes():
|
||||
}
|
||||
|
||||
@router.post("/endpoint")
|
||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
||||
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||
"""Save a custom embedding endpoint URL."""
|
||||
url = url.strip()
|
||||
if not url:
|
||||
@@ -282,6 +282,7 @@ def setup_embedding_routes():
|
||||
resp = httpx.post(
|
||||
url,
|
||||
json={"input": ["test"], "model": model or "test"},
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -292,10 +293,16 @@ def setup_embedding_routes():
|
||||
data = {"url": url}
|
||||
if model:
|
||||
data["model"] = model
|
||||
if api_key:
|
||||
from src.secret_storage import encrypt
|
||||
data["api_key"] = encrypt(api_key)
|
||||
|
||||
_save_custom_endpoint(data)
|
||||
os.environ["EMBEDDING_URL"] = url
|
||||
if model:
|
||||
os.environ["EMBEDDING_MODEL"] = model
|
||||
if api_key:
|
||||
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||
|
||||
# Reset the RAG singleton so it picks up the new endpoint
|
||||
import src.rag_singleton as _rs
|
||||
@@ -329,6 +336,7 @@ def setup_embedding_routes():
|
||||
# Remove from environment
|
||||
os.environ.pop("EMBEDDING_URL", None)
|
||||
os.environ.pop("EMBEDDING_MODEL", None)
|
||||
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||
|
||||
# Reset the RAG singleton so it falls back to fastembed
|
||||
import src.rag_singleton as _rs
|
||||
|
||||
+7
-2
@@ -38,12 +38,13 @@ _DEFAULT_FASTEMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
class EmbeddingClient:
|
||||
"""Drop-in replacement for SentenceTransformer.encode() using an HTTP API."""
|
||||
|
||||
def __init__(self, url: Optional[str] = None, model: Optional[str] = None):
|
||||
def __init__(self, url: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||
self.url = url or os.getenv(
|
||||
"EMBEDDING_URL",
|
||||
f"http://{os.getenv('LLM_HOST', 'localhost')}:11434/v1/embeddings",
|
||||
)
|
||||
self.model = model or os.getenv("EMBEDDING_MODEL", _DEFAULT_MODEL)
|
||||
self.api_key = api_key or os.getenv("EMBEDDING_API_KEY")
|
||||
self._dim: Optional[int] = None
|
||||
# Short connect timeout so a DOWN embedding endpoint (e.g. Ollama not
|
||||
# running on :11434) fast-fails to the local FastEmbed fallback instead
|
||||
@@ -74,6 +75,7 @@ class EmbeddingClient:
|
||||
batch = texts[i : i + 64]
|
||||
resp = self._client.post(
|
||||
self.url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||
json={"input": batch, "model": self.model},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -222,11 +224,14 @@ def get_embedding_client():
|
||||
if persisted.get("url"):
|
||||
url = persisted["url"]
|
||||
model = persisted.get("model", "")
|
||||
api_key = persisted.get("api_key", "")
|
||||
# Also set in env so other code sees it
|
||||
os.environ["EMBEDDING_URL"] = url
|
||||
if model:
|
||||
os.environ["EMBEDDING_MODEL"] = model
|
||||
|
||||
if api_key:
|
||||
from src.secret_storage import decrypt
|
||||
os.environ["EMBEDDING_API_KEY"] = decrypt(api_key)
|
||||
# Try the HTTP embedding API — unless we already found it down this process
|
||||
# (avoids paying the connect timeout again on every RAG/memory/tool probe).
|
||||
if not _http_embed_down:
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Tests for embeddings.py"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from src.embeddings import EmbeddingClient
|
||||
|
||||
|
||||
class TestEmbeddingClient:
|
||||
_MOCK_RESPONSE = {
|
||||
"data": [{"embedding": [0.1], "index": 0}],
|
||||
}
|
||||
|
||||
def _make_mock_resp(self):
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = self._MOCK_RESPONSE
|
||||
resp.raise_for_status = MagicMock()
|
||||
return resp
|
||||
|
||||
@patch("src.embeddings.httpx.Client")
|
||||
def test_bearer_header_sent_when_api_key_set(self, mock_httpx):
|
||||
"""
|
||||
Test that the EmbeddingClient sends the Authorization header with the correct value when api_key is set.
|
||||
"""
|
||||
mock_httpx.return_value.post.return_value = self._make_mock_resp()
|
||||
|
||||
client = EmbeddingClient(
|
||||
url="http://test:11434/v1/embeddings",
|
||||
model="all-minilm:l6-v2",
|
||||
api_key="secret-key",
|
||||
)
|
||||
client.encode(["x"])
|
||||
|
||||
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
|
||||
assert headers.get("Authorization") == "Bearer secret-key"
|
||||
|
||||
@patch("src.embeddings.httpx.Client")
|
||||
def test_no_bearer_header_when_api_key_none(self, mock_httpx):
|
||||
"""
|
||||
Test that the EmbeddingClient does not send the Authorization header when api_key is None.
|
||||
"""
|
||||
mock_httpx.return_value.post.return_value = self._make_mock_resp()
|
||||
|
||||
client = EmbeddingClient(url="http://test:11434/v1/embeddings")
|
||||
client.encode(["x"])
|
||||
|
||||
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
|
||||
assert "Authorization" not in headers
|
||||
Reference in New Issue
Block a user