mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Add support for EMBEDDING_API_KEY (#2691)
* feat: support for embedding API key * feat: encrypt and decrypt embedding API key * test: add unit tests for EmbeddingClient authorization header behavior
This commit is contained in:
@@ -112,6 +112,9 @@ SEARXNG_INSTANCE=http://localhost:8080
|
|||||||
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
|
||||||
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
|
||||||
|
|
||||||
|
# Embedding API key (if there's one)
|
||||||
|
# EMBEDDING_API_KEY=embedding_api_key_here
|
||||||
|
|
||||||
# Embedding model name (must be available at the endpoint above)
|
# Embedding model name (must be available at the endpoint above)
|
||||||
# EMBEDDING_MODEL=all-minilm:l6-v2
|
# EMBEDDING_MODEL=all-minilm:l6-v2
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ services:
|
|||||||
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
- SECURE_COOKIES=${SECURE_COOKIES:-false}
|
||||||
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
- EMBEDDING_URL=${EMBEDDING_URL:-}
|
||||||
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
|
||||||
|
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
|
||||||
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
||||||
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
|
||||||
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ def setup_embedding_routes():
|
|||||||
}
|
}
|
||||||
|
|
||||||
@router.post("/endpoint")
|
@router.post("/endpoint")
|
||||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
|
||||||
"""Save a custom embedding endpoint URL."""
|
"""Save a custom embedding endpoint URL."""
|
||||||
url = url.strip()
|
url = url.strip()
|
||||||
if not url:
|
if not url:
|
||||||
@@ -282,6 +282,7 @@ def setup_embedding_routes():
|
|||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
url,
|
url,
|
||||||
json={"input": ["test"], "model": model or "test"},
|
json={"input": ["test"], "model": model or "test"},
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
@@ -292,10 +293,16 @@ def setup_embedding_routes():
|
|||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
if model:
|
if model:
|
||||||
data["model"] = model
|
data["model"] = model
|
||||||
|
if api_key:
|
||||||
|
from src.secret_storage import encrypt
|
||||||
|
data["api_key"] = encrypt(api_key)
|
||||||
|
|
||||||
_save_custom_endpoint(data)
|
_save_custom_endpoint(data)
|
||||||
os.environ["EMBEDDING_URL"] = url
|
os.environ["EMBEDDING_URL"] = url
|
||||||
if model:
|
if model:
|
||||||
os.environ["EMBEDDING_MODEL"] = model
|
os.environ["EMBEDDING_MODEL"] = model
|
||||||
|
if api_key:
|
||||||
|
os.environ["EMBEDDING_API_KEY"] = api_key
|
||||||
|
|
||||||
# Reset the RAG singleton so it picks up the new endpoint
|
# Reset the RAG singleton so it picks up the new endpoint
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
@@ -329,6 +336,7 @@ def setup_embedding_routes():
|
|||||||
# Remove from environment
|
# Remove from environment
|
||||||
os.environ.pop("EMBEDDING_URL", None)
|
os.environ.pop("EMBEDDING_URL", None)
|
||||||
os.environ.pop("EMBEDDING_MODEL", None)
|
os.environ.pop("EMBEDDING_MODEL", None)
|
||||||
|
os.environ.pop("EMBEDDING_API_KEY", None)
|
||||||
|
|
||||||
# Reset the RAG singleton so it falls back to fastembed
|
# Reset the RAG singleton so it falls back to fastembed
|
||||||
import src.rag_singleton as _rs
|
import src.rag_singleton as _rs
|
||||||
|
|||||||
+7
-2
@@ -38,12 +38,13 @@ _DEFAULT_FASTEMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
|||||||
class EmbeddingClient:
|
class EmbeddingClient:
|
||||||
"""Drop-in replacement for SentenceTransformer.encode() using an HTTP API."""
|
"""Drop-in replacement for SentenceTransformer.encode() using an HTTP API."""
|
||||||
|
|
||||||
def __init__(self, url: Optional[str] = None, model: Optional[str] = None):
|
def __init__(self, url: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
self.url = url or os.getenv(
|
self.url = url or os.getenv(
|
||||||
"EMBEDDING_URL",
|
"EMBEDDING_URL",
|
||||||
f"http://{os.getenv('LLM_HOST', 'localhost')}:11434/v1/embeddings",
|
f"http://{os.getenv('LLM_HOST', 'localhost')}:11434/v1/embeddings",
|
||||||
)
|
)
|
||||||
self.model = model or os.getenv("EMBEDDING_MODEL", _DEFAULT_MODEL)
|
self.model = model or os.getenv("EMBEDDING_MODEL", _DEFAULT_MODEL)
|
||||||
|
self.api_key = api_key or os.getenv("EMBEDDING_API_KEY")
|
||||||
self._dim: Optional[int] = None
|
self._dim: Optional[int] = None
|
||||||
# Short connect timeout so a DOWN embedding endpoint (e.g. Ollama not
|
# Short connect timeout so a DOWN embedding endpoint (e.g. Ollama not
|
||||||
# running on :11434) fast-fails to the local FastEmbed fallback instead
|
# running on :11434) fast-fails to the local FastEmbed fallback instead
|
||||||
@@ -74,6 +75,7 @@ class EmbeddingClient:
|
|||||||
batch = texts[i : i + 64]
|
batch = texts[i : i + 64]
|
||||||
resp = self._client.post(
|
resp = self._client.post(
|
||||||
self.url,
|
self.url,
|
||||||
|
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||||
json={"input": batch, "model": self.model},
|
json={"input": batch, "model": self.model},
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
@@ -222,11 +224,14 @@ def get_embedding_client():
|
|||||||
if persisted.get("url"):
|
if persisted.get("url"):
|
||||||
url = persisted["url"]
|
url = persisted["url"]
|
||||||
model = persisted.get("model", "")
|
model = persisted.get("model", "")
|
||||||
|
api_key = persisted.get("api_key", "")
|
||||||
# Also set in env so other code sees it
|
# Also set in env so other code sees it
|
||||||
os.environ["EMBEDDING_URL"] = url
|
os.environ["EMBEDDING_URL"] = url
|
||||||
if model:
|
if model:
|
||||||
os.environ["EMBEDDING_MODEL"] = model
|
os.environ["EMBEDDING_MODEL"] = model
|
||||||
|
if api_key:
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
os.environ["EMBEDDING_API_KEY"] = decrypt(api_key)
|
||||||
# Try the HTTP embedding API — unless we already found it down this process
|
# Try the HTTP embedding API — unless we already found it down this process
|
||||||
# (avoids paying the connect timeout again on every RAG/memory/tool probe).
|
# (avoids paying the connect timeout again on every RAG/memory/tool probe).
|
||||||
if not _http_embed_down:
|
if not _http_embed_down:
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Tests for embeddings.py"""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from src.embeddings import EmbeddingClient
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingClient:
|
||||||
|
_MOCK_RESPONSE = {
|
||||||
|
"data": [{"embedding": [0.1], "index": 0}],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _make_mock_resp(self):
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = 200
|
||||||
|
resp.json.return_value = self._MOCK_RESPONSE
|
||||||
|
resp.raise_for_status = MagicMock()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@patch("src.embeddings.httpx.Client")
|
||||||
|
def test_bearer_header_sent_when_api_key_set(self, mock_httpx):
|
||||||
|
"""
|
||||||
|
Test that the EmbeddingClient sends the Authorization header with the correct value when api_key is set.
|
||||||
|
"""
|
||||||
|
mock_httpx.return_value.post.return_value = self._make_mock_resp()
|
||||||
|
|
||||||
|
client = EmbeddingClient(
|
||||||
|
url="http://test:11434/v1/embeddings",
|
||||||
|
model="all-minilm:l6-v2",
|
||||||
|
api_key="secret-key",
|
||||||
|
)
|
||||||
|
client.encode(["x"])
|
||||||
|
|
||||||
|
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
|
||||||
|
assert headers.get("Authorization") == "Bearer secret-key"
|
||||||
|
|
||||||
|
@patch("src.embeddings.httpx.Client")
|
||||||
|
def test_no_bearer_header_when_api_key_none(self, mock_httpx):
|
||||||
|
"""
|
||||||
|
Test that the EmbeddingClient does not send the Authorization header when api_key is None.
|
||||||
|
"""
|
||||||
|
mock_httpx.return_value.post.return_value = self._make_mock_resp()
|
||||||
|
|
||||||
|
client = EmbeddingClient(url="http://test:11434/v1/embeddings")
|
||||||
|
client.encode(["x"])
|
||||||
|
|
||||||
|
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
|
||||||
|
assert "Authorization" not in headers
|
||||||
Reference in New Issue
Block a user