Files
odysseus/src/api_key_manager.py
T
Sushanth Reddy eee2167502 Stop API key save() from writing other providers' keys as plaintext (#1944)
save() called load(), which DECRYPTS every stored key, then re-encrypted
only the key being saved and wrote the whole dict back. The other
providers' keys were thus persisted in plaintext; on the next load()
Fernet raised InvalidToken on them and they were silently dropped.

Add _load_raw() that returns the still-encrypted on-disk dict (reusing the
existing missing/corrupt-file guards) and have save() build on that, so
untouched providers keep their ciphertext. load() now also goes through
_load_raw(), keeping its behavior identical.

Fixes #1914

Co-authored-by: EkaTantra Dev <dev@ekatantra.com>
2026-06-04 04:47:13 +01:00

86 lines
3.2 KiB
Python

import os
import json
import logging
from typing import Dict
from cryptography.fernet import Fernet, InvalidToken
logger = logging.getLogger(__name__)


class APIKeyManager:
    """Persist per-provider API keys on disk, encrypted with a local Fernet key.

    Two files live under ``data_dir``:

    * ``api_keys.json`` -- a JSON object mapping provider name to ciphertext.
    * ``.key`` -- the Fernet key used to encrypt/decrypt those values.
    """

    def __init__(self, data_dir: str):
        """Record the file locations; nothing is read or created here."""
        self.data_dir = data_dir
        self.api_keys_file = os.path.join(data_dir, "api_keys.json")
        self.key_file = os.path.join(data_dir, ".key")

    def get_or_create_key(self) -> bytes:
        """Return the Fernet key, generating and persisting one if none exists.

        Raises:
            OSError: if the key file cannot be read or created (for example
                when ``data_dir`` does not exist).
        """
        # EAFP: open directly rather than exists()+open(), which can race with
        # another process creating or removing the file in between.
        try:
            with open(self.key_file, 'rb') as f:
                return f.read()
        except FileNotFoundError:
            pass

        key = Fernet.generate_key()
        # The key protects every stored API key, so create the file readable
        # by the owner only instead of relying on the process umask.
        # O_BINARY exists only on Windows; it is a no-op elsewhere.
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
        fd = os.open(self.key_file, flags, 0o600)
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        return key

    def encrypt_api_key(self, api_key: str) -> str:
        """Encrypt an API key; an empty key is returned as ``""`` unencrypted."""
        if not api_key:
            return ""
        f = Fernet(self.get_or_create_key())
        return f.encrypt(api_key.encode()).decode()

    def decrypt_api_key(self, encrypted_key: str) -> str:
        """Decrypt an API key; an empty/falsy value is returned as ``""``.

        Raises:
            InvalidToken: if the value is not valid ciphertext for the
                current Fernet key.
        """
        if not encrypted_key:
            return ""
        f = Fernet(self.get_or_create_key())
        return f.decrypt(encrypted_key.encode()).decode()

    def _load_raw(self) -> Dict[str, str]:
        """Load the raw, still-encrypted keys dict from disk.

        Tolerates a missing/corrupt/wrong-shaped file by returning {} -- the
        same robustness load() relies on at startup.
        """
        try:
            with open(self.api_keys_file, 'r', encoding="utf-8") as f:
                encrypted_keys = json.load(f)
        except FileNotFoundError:
            # No keys stored yet: the normal first-run case, not worth a warning.
            return {}
        except (ValueError, OSError) as e:
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for non-UTF-8 bytes. A corrupt or
            # truncated api_keys.json must not crash load() -- treat it as
            # no stored keys.
            logger.warning("Failed to read API keys file: %s", e)
            return {}
        if not isinstance(encrypted_keys, dict):
            # Legacy/wrong shape (e.g. a list) -- .items() would raise. Ignore it.
            logger.warning("API keys file has unexpected shape (%s); ignoring", type(encrypted_keys).__name__)
            return {}
        return encrypted_keys

    def save(self, provider: str, api_key: str):
        """Encrypt ``api_key`` and store it under ``provider``.

        Operates on the raw (still-encrypted) on-disk dict so other providers'
        keys keep their ciphertext; going through load() would decrypt them
        and write them back as plaintext.

        The file is replaced atomically (temp file + ``os.replace``) so an
        interrupted write cannot truncate it and lose every stored key.

        Raises:
            OSError: if the keys file cannot be written.
        """
        keys = self._load_raw()
        keys[provider] = self.encrypt_api_key(api_key)

        tmp_path = self.api_keys_file + ".tmp"
        try:
            with open(tmp_path, 'w', encoding="utf-8") as f:
                json.dump(keys, f)
                # Make sure the bytes are on disk before the rename publishes them.
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.api_keys_file)
        except BaseException:
            # Do not leave a half-written temp file behind; the previous
            # api_keys.json (if any) is still intact at this point.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(self) -> Dict[str, str]:
        """Load and decrypt all stored API keys.

        Entries that cannot be decrypted, or whose stored value is not a
        string, are logged and skipped rather than aborting the whole load.

        Returns:
            Mapping of provider name to plaintext API key.
        """
        encrypted_keys = self._load_raw()
        decrypted = {}
        for provider, key in encrypted_keys.items():
            if key and not isinstance(key, str):
                # A hand-edited or damaged file can hold numbers/lists/objects;
                # .encode() on those would raise AttributeError and take down
                # startup. Falsy values fall through and decrypt to "".
                logger.warning(
                    "Ignoring API key for %s: expected a string, got %s",
                    provider, type(key).__name__,
                )
                continue
            try:
                decrypted[provider] = self.decrypt_api_key(key)
            except (InvalidToken, ValueError) as e:
                logger.warning("Failed to decrypt API key for %s: %s", provider, e)
        return decrypted