diff --git a/.env.example b/.env.example index f282880bc..39c90b30d 100644 --- a/.env.example +++ b/.env.example @@ -112,6 +112,9 @@ SEARXNG_INSTANCE=http://localhost:8080 # Default: http://{LLM_HOST}:11434/v1/embeddings (ollama) # EMBEDDING_URL=http://localhost:11434/v1/embeddings +# Embedding API key (if there's one) +# EMBEDDING_API_KEY=embedding_api_key_here + # Embedding model name (must be available at the endpoint above) # EMBEDDING_MODEL=all-minilm:l6-v2 diff --git a/docker-compose.gpu-amd.yml b/docker-compose.gpu-amd.yml index 47e0c8550..6d87cb6e3 100644 --- a/docker-compose.gpu-amd.yml +++ b/docker-compose.gpu-amd.yml @@ -52,6 +52,7 @@ services: - SECURE_COOKIES=${SECURE_COOKIES:-false} - EMBEDDING_URL=${EMBEDDING_URL:-} - EMBEDDING_MODEL=${EMBEDDING_MODEL:-} + - EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-} - FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2} - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-} - CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24} diff --git a/docker-compose.gpu-nvidia.yml b/docker-compose.gpu-nvidia.yml index 36ca10efe..f61d22a4b 100644 --- a/docker-compose.gpu-nvidia.yml +++ b/docker-compose.gpu-nvidia.yml @@ -51,6 +51,7 @@ services: - SECURE_COOKIES=${SECURE_COOKIES:-false} - EMBEDDING_URL=${EMBEDDING_URL:-} - EMBEDDING_MODEL=${EMBEDDING_MODEL:-} + - EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-} - FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2} - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-} - CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24} diff --git a/docker-compose.yml b/docker-compose.yml index f3a8dcc49..b5b3fd93d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -40,6 +40,7 @@ services: - SECURE_COOKIES=${SECURE_COOKIES:-false} - EMBEDDING_URL=${EMBEDDING_URL:-} - EMBEDDING_MODEL=${EMBEDDING_MODEL:-} + - EMBEDDING_API_KEY=${EMBEDDING_API_KEY:-} - FASTEMBED_MODEL=${FASTEMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2} - FASTEMBED_CACHE_PATH=${FASTEMBED_CACHE_PATH:-} - CLEANUP_INTERVAL_HOURS=${CLEANUP_INTERVAL_HOURS:-24} diff --git a/routes/embedding_routes.py b/routes/embedding_routes.py index a5ef4c084..c6f0645a7 100644 --- a/routes/embedding_routes.py +++ b/routes/embedding_routes.py @@ -258,7 +258,7 @@ def setup_embedding_routes(): } @router.post("/endpoint") - def set_endpoint(url: str = Form(...), model: str = Form("")): + def set_endpoint(url: str = Form(...), model: str = Form(""), api_key: str = Form("")): """Save a custom embedding endpoint URL.""" url = url.strip() if not url: @@ -282,6 +282,7 @@ def setup_embedding_routes(): resp = httpx.post( url, json={"input": ["test"], "model": model or "test"}, + headers={"Authorization": f"Bearer {api_key}"} if api_key else {}, timeout=10, ) resp.raise_for_status() @@ -292,10 +293,16 @@ def setup_embedding_routes(): data = {"url": url} if model: data["model"] = model + if api_key: + from src.secret_storage import encrypt + data["api_key"] = encrypt(api_key) + _save_custom_endpoint(data) os.environ["EMBEDDING_URL"] = url if model: os.environ["EMBEDDING_MODEL"] = model + if api_key: + os.environ["EMBEDDING_API_KEY"] = api_key # Reset the RAG singleton so it picks up the new endpoint import src.rag_singleton as _rs @@ -329,6 +336,7 @@ def setup_embedding_routes(): # Remove from environment os.environ.pop("EMBEDDING_URL", None) os.environ.pop("EMBEDDING_MODEL", None) + os.environ.pop("EMBEDDING_API_KEY", None) # Reset the RAG singleton so it falls back to fastembed import src.rag_singleton as _rs diff --git a/src/embeddings.py b/src/embeddings.py index 67cfd86ad..f2d0c5934 100644 --- a/src/embeddings.py +++ b/src/embeddings.py @@ -38,12 +38,13 @@ _DEFAULT_FASTEMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" class EmbeddingClient: """Drop-in replacement for SentenceTransformer.encode() using an HTTP API.""" - def __init__(self, url: Optional[str] = None, model: Optional[str] = None): + def __init__(self, url: Optional[str] = None, model: Optional[str] = None, api_key: Optional[str] = None): self.url = url or os.getenv( "EMBEDDING_URL", f"http://{os.getenv('LLM_HOST', 'localhost')}:11434/v1/embeddings", ) self.model = model or os.getenv("EMBEDDING_MODEL", _DEFAULT_MODEL) + self.api_key = api_key or os.getenv("EMBEDDING_API_KEY") self._dim: Optional[int] = None # Short connect timeout so a DOWN embedding endpoint (e.g. Ollama not # running on :11434) fast-fails to the local FastEmbed fallback instead @@ -74,6 +75,7 @@ class EmbeddingClient: batch = texts[i : i + 64] resp = self._client.post( self.url, + headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}, json={"input": batch, "model": self.model}, ) resp.raise_for_status() @@ -222,11 +224,14 @@ def get_embedding_client(): if persisted.get("url"): url = persisted["url"] model = persisted.get("model", "") + api_key = persisted.get("api_key", "") # Also set in env so other code sees it os.environ["EMBEDDING_URL"] = url if model: os.environ["EMBEDDING_MODEL"] = model - + if api_key: + from src.secret_storage import decrypt + os.environ["EMBEDDING_API_KEY"] = decrypt(api_key) # Try the HTTP embedding API — unless we already found it down this process # (avoids paying the connect timeout again on every RAG/memory/tool probe). if not _http_embed_down: diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 000000000..a32fb1edc --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,46 @@ +"""Tests for embeddings.py""" +from unittest.mock import MagicMock, patch +from src.embeddings import EmbeddingClient + + +class TestEmbeddingClient: + _MOCK_RESPONSE = { + "data": [{"embedding": [0.1], "index": 0}], + } + + def _make_mock_resp(self): + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = self._MOCK_RESPONSE + resp.raise_for_status = MagicMock() + return resp + + @patch("src.embeddings.httpx.Client") + def test_bearer_header_sent_when_api_key_set(self, mock_httpx): + """ + Test that the EmbeddingClient sends the Authorization header with the correct value when api_key is set. + """ + mock_httpx.return_value.post.return_value = self._make_mock_resp() + + client = EmbeddingClient( + url="http://test:11434/v1/embeddings", + model="all-minilm:l6-v2", + api_key="secret-key", + ) + client.encode(["x"]) + + headers = mock_httpx.return_value.post.call_args.kwargs["headers"] + assert headers.get("Authorization") == "Bearer secret-key" + + @patch("src.embeddings.httpx.Client") + def test_no_bearer_header_when_api_key_none(self, mock_httpx): + """ + Test that the EmbeddingClient does not send the Authorization header when api_key is None. + """ + mock_httpx.return_value.post.return_value = self._make_mock_resp() + + client = EmbeddingClient(url="http://test:11434/v1/embeddings") + client.encode(["x"]) + + headers = mock_httpx.return_value.post.call_args.kwargs["headers"] + assert "Authorization" not in headers