Add support for EMBEDDING_API_KEY (#2691)

* feat: support for embedding API key

* feat: encrypt and decrypt embedding API key

* test: add unit tests for EmbeddingClient authorization header behavior
This commit is contained in:
Yiğit Egemen
2026-06-05 14:47:24 +02:00
committed by GitHub
parent b5c45326e4
commit ec8fbf5d8f
7 changed files with 68 additions and 3 deletions
+3
View File
@@ -112,6 +112,9 @@ SEARXNG_INSTANCE=http://localhost:8080
# Default: http://{LLM_HOST}:11434/v1/embeddings (ollama)
# EMBEDDING_URL=http://localhost:11434/v1/embeddings
# Embedding API key (if there's one)
# EMBEDDING_API_KEY=embedding_api_key_here
# Embedding model name (must be available at the endpoint above)
# EMBEDDING_MODEL=all-minilm:l6-v2
+1
View File
@@ -52,6 +52,7 @@ services:
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
+1
View File
@@ -51,6 +51,7 @@ services:
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
+1
View File
@@ -40,6 +40,7 @@ services:
- SECURE_COOKIES=${SECURE_COOKIES:-false}
- EMBEDDING_URL=${EMBEDDING_URL:-}
- EMBEDDING_MODEL=${EMBEDDING_MODEL:-}
- EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-}
- FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
- FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-}
- CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24}
+9 -1
View File
@@ -258,7 +258,7 @@ def setup_embedding_routes():
}
@router.post("/endpoint")
def set_endpoint(url: str = Form(...), model: str = Form("")):
def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")):
"""Save a custom embedding endpoint URL."""
url = url.strip()
if not url:
@@ -282,6 +282,7 @@ def setup_embedding_routes():
resp = httpx.post(
url,
json={"input": ["test"], "model": model or "test"},
headers={"Authorization": f"Bearer {api_key}"} if api_key else {},
timeout=10,
)
resp.raise_for_status()
@@ -292,10 +293,16 @@ def setup_embedding_routes():
data = {"url": url}
if model:
data["model"] = model
if api_key:
from src.secret_storage import encrypt
data["api_key"] = encrypt(api_key)
_save_custom_endpoint(data)
os.environ["EMBEDDING_URL"] = url
if model:
os.environ["EMBEDDING_MODEL"] = model
if api_key:
os.environ["EMBEDDING_API_KEY"] = api_key
# Reset the RAG singleton so it picks up the new endpoint
import src.rag_singleton as _rs
@@ -329,6 +336,7 @@ def setup_embedding_routes():
# Remove from environment
os.environ.pop("EMBEDDING_URL", None)
os.environ.pop("EMBEDDING_MODEL", None)
os.environ.pop("EMBEDDING_API_KEY", None)
# Reset the RAG singleton so it falls back to fastembed
import src.rag_singleton as _rs
+7 -2
View File
@@ -38,12 +38,13 @@ _DEFAULT_FASTEMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
class EmbeddingClient:
"""Drop-in replacement for SentenceTransformer.encode() using an HTTP API."""
def __init__(self, url: Optional[str] = None, model: Optional[str] = None):
def __init__(self, url: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None):
self.url = url or os.getenv(
"EMBEDDING_URL",
f"http://{os.getenv('LLM_HOST', 'localhost')}:11434/v1/embeddings",
)
self.model = model or os.getenv("EMBEDDING_MODEL", _DEFAULT_MODEL)
self.api_key = api_key or os.getenv("EMBEDDING_API_KEY")
self._dim: Optional[int] = None
# Short connect timeout so a DOWN embedding endpoint (e.g. Ollama not
# running on :11434) fast-fails to the local FastEmbed fallback instead
@@ -74,6 +75,7 @@ class EmbeddingClient:
batch = texts[i : i + 64]
resp = self._client.post(
self.url,
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
json={"input": batch, "model": self.model},
)
resp.raise_for_status()
@@ -222,11 +224,14 @@ def get_embedding_client():
if persisted.get("url"):
url = persisted["url"]
model = persisted.get("model", "")
api_key = persisted.get("api_key", "")
# Also set in env so other code sees it
os.environ["EMBEDDING_URL"] = url
if model:
os.environ["EMBEDDING_MODEL"] = model
if api_key:
from src.secret_storage import decrypt
os.environ["EMBEDDING_API_KEY"] = decrypt(api_key)
# Try the HTTP embedding API — unless we already found it down this process
# (avoids paying the connect timeout again on every RAG/memory/tool probe).
if not _http_embed_down:
+46
View File
@@ -0,0 +1,46 @@
"""Tests for embeddings.py"""
from unittest.mock import MagicMock, patch
from src.embeddings import EmbeddingClient
class TestEmbeddingClient:
_MOCK_RESPONSE = {
"data": [{"embedding": [0.1], "index": 0}],
}
def _make_mock_resp(self):
resp = MagicMock()
resp.status_code = 200
resp.json.return_value = self._MOCK_RESPONSE
resp.raise_for_status = MagicMock()
return resp
@patch("src.embeddings.httpx.Client")
def test_bearer_header_sent_when_api_key_set(self, mock_httpx):
"""
Test that the EmbeddingClient sends the Authorization header with the correct value when api_key is set.
"""
mock_httpx.return_value.post.return_value = self._make_mock_resp()
client = EmbeddingClient(
url="http://test:11434/v1/embeddings",
model="all-minilm:l6-v2",
api_key="secret-key",
)
client.encode(["x"])
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
assert headers.get("Authorization") == "Bearer secret-key"
@patch("src.embeddings.httpx.Client")
def test_no_bearer_header_when_api_key_none(self, mock_httpx):
"""
Test that the EmbeddingClient does not send the Authorization header when api_key is None.
"""
mock_httpx.return_value.post.return_value = self._make_mock_resp()
client = EmbeddingClient(url="http://test:11434/v1/embeddings")
client.encode(["x"])
headers = mock_httpx.return_value.post.call_args.kwargs["headers"]
assert "Authorization" not in headers