Files
odysseus/src/api_key_manager.py
T
Tom 2857723e47 fix(security): restrict API-key encryption key file to 0o600
Lock the API key encryption key file to owner-only permissions on creation and when reading existing keys, with regression coverage for permissions and encryption roundtrip.
2026-06-15 15:00:11 +09:00

100 lines
3.9 KiB
Python

import os
import json
import logging
from typing import Dict
from cryptography.fernet import Fernet, InvalidToken
from core.platform_compat import safe_chmod
# Module-level logger; APIKeyManager reports unreadable key files and
# undecryptable entries through it as warnings.
logger = logging.getLogger(__name__)
class APIKeyManager:
    """Persist provider API keys on disk, encrypted with a local Fernet key.

    Two files live under ``data_dir``:

    * ``.key`` -- the Fernet key. It decrypts every stored credential, so it
      is kept owner-only (0o600).
    * ``api_keys.json`` -- a JSON object mapping provider name to the
      Fernet-encrypted API key.
    """

    def __init__(self, data_dir: str):
        """Record the paths of the key file and the encrypted-keys file.

        Args:
            data_dir: Directory holding ``.key`` and ``api_keys.json``. It is
                not created here.
        """
        self.data_dir = data_dir
        self.api_keys_file = os.path.join(data_dir, "api_keys.json")
        self.key_file = os.path.join(data_dir, ".key")

    def _read_key_file(self) -> bytes:
        """Return the bytes of the existing key file, re-restricting it first."""
        # Older versions wrote .key with the process umask (often 0o644,
        # i.e. group/world-readable). Re-restrict on read so existing
        # installs heal without needing the key to be regenerated.
        safe_chmod(self.key_file, 0o600)
        with open(self.key_file, 'rb') as f:
            return f.read()

    def get_or_create_key(self) -> bytes:
        """Get or create encryption key for API keys.

        Returns:
            The raw Fernet key bytes, read from ``.key`` or freshly generated
            and written there.

        Raises:
            OSError: If the key file cannot be read or created.
        """
        if os.path.exists(self.key_file):
            return self._read_key_file()

        key = Fernet.generate_key()
        # Create the file with owner-only permissions from the very first
        # moment instead of writing it under the umask and chmod-ing
        # afterwards, so the key is never briefly group/world-readable.
        # O_EXCL makes creation fail rather than truncate a key that appeared
        # after the exists() check. O_BINARY only exists on Windows.
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0)
        try:
            fd = os.open(self.key_file, flags, 0o600)
        except FileExistsError:
            # Someone else created the key in the meantime. Use theirs:
            # overwriting it would make their stored keys undecryptable.
            return self._read_key_file()
        with os.fdopen(fd, 'wb') as f:
            f.write(key)
        # This key decrypts every stored provider credential, so restrict it
        # to the owner (0o600) -- it must not be group/world-readable. No-op
        # on Windows (files there are ACL-restricted to the user already).
        safe_chmod(self.key_file, 0o600)
        return key

    def encrypt_api_key(self, api_key: str) -> str:
        """Encrypt an API key.

        Args:
            api_key: Plaintext key. An empty value is returned as ``""``
                without touching the key file.

        Returns:
            The Fernet token as text, or ``""`` for empty input.
        """
        if not api_key:
            return ""
        f = Fernet(self.get_or_create_key())
        return f.encrypt(api_key.encode()).decode()

    def decrypt_api_key(self, encrypted_key: str) -> str:
        """Decrypt an API key.

        Args:
            encrypted_key: Fernet token as text. An empty value is returned
                as ``""`` without touching the key file.

        Returns:
            The plaintext key, or ``""`` for empty input.

        Raises:
            InvalidToken: If the token was not produced with the current key.
        """
        if not encrypted_key:
            return ""
        f = Fernet(self.get_or_create_key())
        return f.decrypt(encrypted_key.encode()).decode()

    def _load_raw(self) -> Dict[str, str]:
        """Load the raw, still-encrypted keys dict from disk.

        Tolerates a missing/corrupt/wrong-shaped file by returning {} -- the
        same robustness load() relies on at startup. Entries whose value is
        not a string are dropped; provider names are coerced to ``str``.
        """
        try:
            with open(self.api_keys_file, 'r', encoding="utf-8") as f:
                encrypted_keys = json.load(f)
        except FileNotFoundError:
            # No keys have been saved yet; not worth a warning.
            return {}
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
            # A corrupt/truncated api_keys.json must not crash load() (called on
            # startup via app_initializer) -- treat it as no stored keys.
            # UnicodeDecodeError covers binary garbage, which fails while the
            # file is decoded as UTF-8, before JSON parsing even starts.
            logger.warning("Failed to read API keys file: %s", e)
            return {}
        if not isinstance(encrypted_keys, dict):
            # Legacy/wrong shape (e.g. a list) -- .items() would raise. Ignore it.
            logger.warning("API keys file has unexpected shape (%s); ignoring", type(encrypted_keys).__name__)
            return {}
        return {
            str(provider): key
            for provider, key in encrypted_keys.items()
            if isinstance(key, str)
        }

    def save(self, provider: str, api_key: str) -> None:
        """Save encrypted API key to file.

        Operates on the raw (still-encrypted) on-disk dict so other providers'
        keys stay encrypted. Loading via load() first would decrypt them and
        write them back as plaintext, which then fails to decrypt on the next
        load() and silently drops those providers.

        The file is replaced atomically: an interrupted write leaves the
        previous api_keys.json intact instead of a truncated file that
        _load_raw() would read as "no keys".

        Args:
            provider: Name the key is stored under.
            api_key: Plaintext key; an empty value is stored as ``""``.

        Raises:
            OSError: If the keys file cannot be written.
        """
        keys = self._load_raw()
        keys[provider] = self.encrypt_api_key(api_key)

        tmp_path = self.api_keys_file + ".tmp"
        replaced = False
        try:
            with open(tmp_path, 'w', encoding="utf-8") as f:
                json.dump(keys, f)
                # Push the data to disk before the rename makes it visible.
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.api_keys_file)
            replaced = True
        finally:
            if not replaced:
                # Best-effort cleanup; the original error is what matters.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def load(self) -> Dict[str, str]:
        """Load and decrypt API keys.

        Returns:
            Mapping of provider name to plaintext key. Entries that fail to
            decrypt are logged and left out rather than raising.
        """
        encrypted_keys = self._load_raw()
        decrypted = {}
        for provider, key in encrypted_keys.items():
            try:
                decrypted[provider] = self.decrypt_api_key(key)
            except (InvalidToken, ValueError) as e:
                logger.warning("Failed to decrypt API key for %s: %s", provider, e)
        return decrypted