mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
fix: add threading lock to AuthManager config mutations (#1226)
This commit is contained in:
@@ -76,6 +76,10 @@ class AuthManager:
|
||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||
self._sessions_lock = threading.RLock()
|
||||
# Guards all mutations of self._config and the on-disk auth.json so
|
||||
# concurrent create/delete/rename/privilege operations don't interleave
|
||||
# and corrupt the user database.
|
||||
self._config_lock = threading.Lock()
|
||||
# Guards the first-run setup check-and-write so concurrent requests
|
||||
# cannot both observe is_configured==False and both create admin accounts.
|
||||
self._setup_lock = threading.Lock()
|
||||
@@ -172,6 +176,7 @@ class AuthManager:
|
||||
|
||||
@signup_enabled.setter
|
||||
def signup_enabled(self, value: bool):
|
||||
with self._config_lock:
|
||||
self._config["signup_enabled"] = value
|
||||
self._save()
|
||||
|
||||
@@ -198,6 +203,7 @@ class AuthManager:
|
||||
if username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to create reserved username '%s'", username)
|
||||
return False
|
||||
with self._config_lock:
|
||||
if username in self.users:
|
||||
return False
|
||||
if "users" not in self._config:
|
||||
@@ -221,6 +227,7 @@ class AuthManager:
|
||||
their cookie expired naturally (default ~30 days).
|
||||
"""
|
||||
username = username.strip().lower()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if username == requesting_user:
|
||||
@@ -266,6 +273,7 @@ class AuthManager:
|
||||
if new_username in RESERVED_USERNAMES:
|
||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||
return False
|
||||
with self._config_lock:
|
||||
if old_username not in self.users:
|
||||
return False
|
||||
if new_username in self.users:
|
||||
@@ -311,6 +319,7 @@ class AuthManager:
|
||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
||||
"""Update privileges for a user. Can't modify admin privileges."""
|
||||
username = username.strip().lower()
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
return False
|
||||
if self.users[username].get("is_admin"):
|
||||
@@ -331,6 +340,7 @@ class AuthManager:
|
||||
return False
|
||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||
return False
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||
self._save()
|
||||
return True
|
||||
@@ -350,6 +360,7 @@ class AuthManager:
|
||||
if username not in self.users:
|
||||
return None
|
||||
secret = pyotp.random_base32()
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret_pending"] = secret
|
||||
self._save()
|
||||
return secret
|
||||
@@ -370,6 +381,7 @@ class AuthManager:
|
||||
if not totp.verify(code, valid_window=1):
|
||||
return False
|
||||
# Enable 2FA
|
||||
with self._config_lock:
|
||||
self._config["users"][username]["totp_secret"] = secret
|
||||
self._config["users"][username]["totp_enabled"] = True
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
@@ -395,6 +407,7 @@ class AuthManager:
|
||||
# Check backup codes first
|
||||
backup = user.get("totp_backup_codes", [])
|
||||
if code in backup:
|
||||
with self._config_lock:
|
||||
backup.remove(code)
|
||||
self._config["users"][username]["totp_backup_codes"] = backup
|
||||
self._save()
|
||||
@@ -408,6 +421,7 @@ class AuthManager:
|
||||
username = username.strip().lower()
|
||||
if not self.verify_password(username, password):
|
||||
return False
|
||||
with self._config_lock:
|
||||
self._config["users"][username].pop("totp_secret", None)
|
||||
self._config["users"][username].pop("totp_secret_pending", None)
|
||||
self._config["users"][username].pop("totp_backup_codes", None)
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
"""Concurrency stress tests for AuthManager._config_lock.
|
||||
|
||||
Verifies that concurrent create/delete/rename operations don't lose data
|
||||
or corrupt auth.json. If someone removes the lock, these tests should fail
|
||||
with missing users or assertion errors.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _fresh_auth_manager(tmp_path):
|
||||
sys.modules.pop("core.auth", None)
|
||||
if "core" in sys.modules and hasattr(sys.modules["core"], "auth"):
|
||||
delattr(sys.modules["core"], "auth")
|
||||
from core.auth import AuthManager
|
||||
|
||||
return AuthManager(str(tmp_path / "auth.json"))
|
||||
|
||||
|
||||
class TestConcurrentCreateUser:
|
||||
"""Concurrent create_user calls must not lose accounts."""
|
||||
|
||||
def test_parallel_creates_no_lost_users(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
num_users = 50
|
||||
|
||||
def create(i):
|
||||
return mgr.create_user(f"user{i}", f"password{i}")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(create, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results), "Some create_user calls returned False unexpectedly"
|
||||
assert len(mgr.users) == num_users
|
||||
|
||||
mgr2 = _fresh_auth_manager(tmp_path)
|
||||
mgr2.auth_path = mgr.auth_path
|
||||
mgr2._load()
|
||||
assert len(mgr2.users) == num_users
|
||||
|
||||
def test_parallel_creates_same_username_only_one_wins(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
num_attempts = 20
|
||||
|
||||
def create(_):
|
||||
return mgr.create_user("contested", "password123")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(create, i) for i in range(num_attempts)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert results.count(True) == 1
|
||||
assert results.count(False) == num_attempts - 1
|
||||
assert len(mgr.users) == 1
|
||||
|
||||
|
||||
class TestConcurrentDeleteUser:
|
||||
"""Concurrent deletes must not corrupt state."""
|
||||
|
||||
def test_parallel_deletes_no_corruption(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
num_users = 30
|
||||
for i in range(num_users):
|
||||
mgr.create_user(f"target{i}", f"pw{i}")
|
||||
|
||||
assert len(mgr.users) == num_users + 1
|
||||
|
||||
def delete(i):
|
||||
return mgr.delete_user(f"target{i}", "admin")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(delete, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results)
|
||||
assert len(mgr.users) == 1
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
data = json.load(f)
|
||||
assert len(data["users"]) == 1
|
||||
assert "admin" in data["users"]
|
||||
|
||||
|
||||
class TestConcurrentRenameUser:
|
||||
"""Concurrent renames must not lose or duplicate users."""
|
||||
|
||||
def test_parallel_renames_no_lost_users(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
num_users = 20
|
||||
for i in range(num_users):
|
||||
mgr.create_user(f"old{i}", f"pw{i}")
|
||||
|
||||
def rename(i):
|
||||
return mgr.rename_user(f"old{i}", f"new{i}", "admin")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||
futures = [pool.submit(rename, i) for i in range(num_users)]
|
||||
results = [f.result() for f in as_completed(futures)]
|
||||
|
||||
assert all(results)
|
||||
for i in range(num_users):
|
||||
assert f"new{i}" in mgr.users
|
||||
assert f"old{i}" not in mgr.users
|
||||
|
||||
assert len(mgr.users) == num_users + 1
|
||||
|
||||
|
||||
class TestConcurrentMixedOperations:
|
||||
"""Mixed create/delete/rename at the same time."""
|
||||
|
||||
def test_mixed_operations_no_corruption(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
|
||||
for i in range(20):
|
||||
mgr.create_user(f"existing{i}", f"pw{i}")
|
||||
|
||||
def create_batch():
|
||||
for i in range(20):
|
||||
mgr.create_user(f"newuser{i}", f"pw{i}")
|
||||
|
||||
def delete_batch():
|
||||
for i in range(10):
|
||||
mgr.delete_user(f"existing{i}", "admin")
|
||||
|
||||
def rename_batch():
|
||||
for i in range(10, 20):
|
||||
mgr.rename_user(f"existing{i}", f"renamed{i}", "admin")
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=create_batch),
|
||||
threading.Thread(target=delete_batch),
|
||||
threading.Thread(target=rename_batch),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert "admin" in mgr.users
|
||||
for i in range(10):
|
||||
assert f"existing{i}" not in mgr.users
|
||||
for i in range(10, 20):
|
||||
assert f"renamed{i}" in mgr.users
|
||||
assert f"existing{i}" not in mgr.users
|
||||
for i in range(20):
|
||||
assert f"newuser{i}" in mgr.users
|
||||
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
data = json.load(f)
|
||||
assert set(data["users"].keys()) == set(mgr.users.keys())
|
||||
|
||||
|
||||
class TestDiskConsistency:
|
||||
"""Verify auth.json is never in a corrupt state during concurrent writes."""
|
||||
|
||||
def test_file_always_valid_json_during_concurrent_ops(self, tmp_path):
|
||||
mgr = _fresh_auth_manager(tmp_path)
|
||||
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||
|
||||
stop_event = threading.Event()
|
||||
corruption_found = []
|
||||
|
||||
def reader():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
with open(mgr.auth_path, "r") as f:
|
||||
content = f.read()
|
||||
json.loads(content)
|
||||
except json.JSONDecodeError as e:
|
||||
corruption_found.append(str(e))
|
||||
break
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
time.sleep(0.001)
|
||||
|
||||
def writer():
|
||||
for i in range(50):
|
||||
mgr.create_user(f"stress{i}", f"pw{i}")
|
||||
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
|
||||
reader_thread.start()
|
||||
writer_thread.start()
|
||||
writer_thread.join()
|
||||
stop_event.set()
|
||||
reader_thread.join()
|
||||
|
||||
assert not corruption_found, f"Corrupt JSON detected: {corruption_found[0]}"
|
||||
Reference in New Issue
Block a user