mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
fix: add threading lock to AuthManager config mutations (#1226)
This commit is contained in:
@@ -76,6 +76,10 @@ class AuthManager:
|
|||||||
# Guards mutations of self._sessions and the on-disk sessions.json.
|
# Guards mutations of self._sessions and the on-disk sessions.json.
|
||||||
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
# Validate/create/revoke run concurrently from the FastAPI threadpool.
|
||||||
self._sessions_lock = threading.RLock()
|
self._sessions_lock = threading.RLock()
|
||||||
|
# Guards all mutations of self._config and the on-disk auth.json so
|
||||||
|
# concurrent create/delete/rename/privilege operations don't interleave
|
||||||
|
# and corrupt the user database.
|
||||||
|
self._config_lock = threading.Lock()
|
||||||
# Guards the first-run setup check-and-write so concurrent requests
|
# Guards the first-run setup check-and-write so concurrent requests
|
||||||
# cannot both observe is_configured==False and both create admin accounts.
|
# cannot both observe is_configured==False and both create admin accounts.
|
||||||
self._setup_lock = threading.Lock()
|
self._setup_lock = threading.Lock()
|
||||||
@@ -172,6 +176,7 @@ class AuthManager:
|
|||||||
|
|
||||||
@signup_enabled.setter
|
@signup_enabled.setter
|
||||||
def signup_enabled(self, value: bool):
|
def signup_enabled(self, value: bool):
|
||||||
|
with self._config_lock:
|
||||||
self._config["signup_enabled"] = value
|
self._config["signup_enabled"] = value
|
||||||
self._save()
|
self._save()
|
||||||
|
|
||||||
@@ -198,6 +203,7 @@ class AuthManager:
|
|||||||
if username in RESERVED_USERNAMES:
|
if username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to create reserved username '%s'", username)
|
logger.warning("Refused to create reserved username '%s'", username)
|
||||||
return False
|
return False
|
||||||
|
with self._config_lock:
|
||||||
if username in self.users:
|
if username in self.users:
|
||||||
return False
|
return False
|
||||||
if "users" not in self._config:
|
if "users" not in self._config:
|
||||||
@@ -221,6 +227,7 @@ class AuthManager:
|
|||||||
their cookie expired naturally (default ~30 days).
|
their cookie expired naturally (default ~30 days).
|
||||||
"""
|
"""
|
||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
|
with self._config_lock:
|
||||||
if username not in self.users:
|
if username not in self.users:
|
||||||
return False
|
return False
|
||||||
if username == requesting_user:
|
if username == requesting_user:
|
||||||
@@ -266,6 +273,7 @@ class AuthManager:
|
|||||||
if new_username in RESERVED_USERNAMES:
|
if new_username in RESERVED_USERNAMES:
|
||||||
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
logger.warning("Refused to rename '%s' into reserved username '%s'", old_username, new_username)
|
||||||
return False
|
return False
|
||||||
|
with self._config_lock:
|
||||||
if old_username not in self.users:
|
if old_username not in self.users:
|
||||||
return False
|
return False
|
||||||
if new_username in self.users:
|
if new_username in self.users:
|
||||||
@@ -311,6 +319,7 @@ class AuthManager:
|
|||||||
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
def set_privileges(self, username: str, privileges: Dict[str, Any]) -> bool:
|
||||||
"""Update privileges for a user. Can't modify admin privileges."""
|
"""Update privileges for a user. Can't modify admin privileges."""
|
||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
|
with self._config_lock:
|
||||||
if username not in self.users:
|
if username not in self.users:
|
||||||
return False
|
return False
|
||||||
if self.users[username].get("is_admin"):
|
if self.users[username].get("is_admin"):
|
||||||
@@ -331,6 +340,7 @@ class AuthManager:
|
|||||||
return False
|
return False
|
||||||
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
if not _verify_password(current_password, self.users[username]["password_hash"]):
|
||||||
return False
|
return False
|
||||||
|
with self._config_lock:
|
||||||
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
self._config["users"][username]["password_hash"] = _hash_password(new_password)
|
||||||
self._save()
|
self._save()
|
||||||
return True
|
return True
|
||||||
@@ -350,6 +360,7 @@ class AuthManager:
|
|||||||
if username not in self.users:
|
if username not in self.users:
|
||||||
return None
|
return None
|
||||||
secret = pyotp.random_base32()
|
secret = pyotp.random_base32()
|
||||||
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_secret_pending"] = secret
|
self._config["users"][username]["totp_secret_pending"] = secret
|
||||||
self._save()
|
self._save()
|
||||||
return secret
|
return secret
|
||||||
@@ -370,6 +381,7 @@ class AuthManager:
|
|||||||
if not totp.verify(code, valid_window=1):
|
if not totp.verify(code, valid_window=1):
|
||||||
return False
|
return False
|
||||||
# Enable 2FA
|
# Enable 2FA
|
||||||
|
with self._config_lock:
|
||||||
self._config["users"][username]["totp_secret"] = secret
|
self._config["users"][username]["totp_secret"] = secret
|
||||||
self._config["users"][username]["totp_enabled"] = True
|
self._config["users"][username]["totp_enabled"] = True
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
@@ -395,6 +407,7 @@ class AuthManager:
|
|||||||
# Check backup codes first
|
# Check backup codes first
|
||||||
backup = user.get("totp_backup_codes", [])
|
backup = user.get("totp_backup_codes", [])
|
||||||
if code in backup:
|
if code in backup:
|
||||||
|
with self._config_lock:
|
||||||
backup.remove(code)
|
backup.remove(code)
|
||||||
self._config["users"][username]["totp_backup_codes"] = backup
|
self._config["users"][username]["totp_backup_codes"] = backup
|
||||||
self._save()
|
self._save()
|
||||||
@@ -408,6 +421,7 @@ class AuthManager:
|
|||||||
username = username.strip().lower()
|
username = username.strip().lower()
|
||||||
if not self.verify_password(username, password):
|
if not self.verify_password(username, password):
|
||||||
return False
|
return False
|
||||||
|
with self._config_lock:
|
||||||
self._config["users"][username].pop("totp_secret", None)
|
self._config["users"][username].pop("totp_secret", None)
|
||||||
self._config["users"][username].pop("totp_secret_pending", None)
|
self._config["users"][username].pop("totp_secret_pending", None)
|
||||||
self._config["users"][username].pop("totp_backup_codes", None)
|
self._config["users"][username].pop("totp_backup_codes", None)
|
||||||
|
|||||||
@@ -0,0 +1,198 @@
|
|||||||
|
"""Concurrency stress tests for AuthManager._config_lock.
|
||||||
|
|
||||||
|
Verifies that concurrent create/delete/rename operations don't lose data
|
||||||
|
or corrupt auth.json. If someone removes the lock, these tests should fail
|
||||||
|
with missing users or assertion errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _fresh_auth_manager(tmp_path):
|
||||||
|
sys.modules.pop("core.auth", None)
|
||||||
|
if "core" in sys.modules and hasattr(sys.modules["core"], "auth"):
|
||||||
|
delattr(sys.modules["core"], "auth")
|
||||||
|
from core.auth import AuthManager
|
||||||
|
|
||||||
|
return AuthManager(str(tmp_path / "auth.json"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentCreateUser:
|
||||||
|
"""Concurrent create_user calls must not lose accounts."""
|
||||||
|
|
||||||
|
def test_parallel_creates_no_lost_users(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
num_users = 50
|
||||||
|
|
||||||
|
def create(i):
|
||||||
|
return mgr.create_user(f"user{i}", f"password{i}")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||||
|
futures = [pool.submit(create, i) for i in range(num_users)]
|
||||||
|
results = [f.result() for f in as_completed(futures)]
|
||||||
|
|
||||||
|
assert all(results), "Some create_user calls returned False unexpectedly"
|
||||||
|
assert len(mgr.users) == num_users
|
||||||
|
|
||||||
|
mgr2 = _fresh_auth_manager(tmp_path)
|
||||||
|
mgr2.auth_path = mgr.auth_path
|
||||||
|
mgr2._load()
|
||||||
|
assert len(mgr2.users) == num_users
|
||||||
|
|
||||||
|
def test_parallel_creates_same_username_only_one_wins(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
num_attempts = 20
|
||||||
|
|
||||||
|
def create(_):
|
||||||
|
return mgr.create_user("contested", "password123")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||||
|
futures = [pool.submit(create, i) for i in range(num_attempts)]
|
||||||
|
results = [f.result() for f in as_completed(futures)]
|
||||||
|
|
||||||
|
assert results.count(True) == 1
|
||||||
|
assert results.count(False) == num_attempts - 1
|
||||||
|
assert len(mgr.users) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentDeleteUser:
|
||||||
|
"""Concurrent deletes must not corrupt state."""
|
||||||
|
|
||||||
|
def test_parallel_deletes_no_corruption(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||||
|
num_users = 30
|
||||||
|
for i in range(num_users):
|
||||||
|
mgr.create_user(f"target{i}", f"pw{i}")
|
||||||
|
|
||||||
|
assert len(mgr.users) == num_users + 1
|
||||||
|
|
||||||
|
def delete(i):
|
||||||
|
return mgr.delete_user(f"target{i}", "admin")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||||
|
futures = [pool.submit(delete, i) for i in range(num_users)]
|
||||||
|
results = [f.result() for f in as_completed(futures)]
|
||||||
|
|
||||||
|
assert all(results)
|
||||||
|
assert len(mgr.users) == 1
|
||||||
|
with open(mgr.auth_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert len(data["users"]) == 1
|
||||||
|
assert "admin" in data["users"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentRenameUser:
|
||||||
|
"""Concurrent renames must not lose or duplicate users."""
|
||||||
|
|
||||||
|
def test_parallel_renames_no_lost_users(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||||
|
num_users = 20
|
||||||
|
for i in range(num_users):
|
||||||
|
mgr.create_user(f"old{i}", f"pw{i}")
|
||||||
|
|
||||||
|
def rename(i):
|
||||||
|
return mgr.rename_user(f"old{i}", f"new{i}", "admin")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=10) as pool:
|
||||||
|
futures = [pool.submit(rename, i) for i in range(num_users)]
|
||||||
|
results = [f.result() for f in as_completed(futures)]
|
||||||
|
|
||||||
|
assert all(results)
|
||||||
|
for i in range(num_users):
|
||||||
|
assert f"new{i}" in mgr.users
|
||||||
|
assert f"old{i}" not in mgr.users
|
||||||
|
|
||||||
|
assert len(mgr.users) == num_users + 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrentMixedOperations:
|
||||||
|
"""Mixed create/delete/rename at the same time."""
|
||||||
|
|
||||||
|
def test_mixed_operations_no_corruption(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||||
|
|
||||||
|
for i in range(20):
|
||||||
|
mgr.create_user(f"existing{i}", f"pw{i}")
|
||||||
|
|
||||||
|
def create_batch():
|
||||||
|
for i in range(20):
|
||||||
|
mgr.create_user(f"newuser{i}", f"pw{i}")
|
||||||
|
|
||||||
|
def delete_batch():
|
||||||
|
for i in range(10):
|
||||||
|
mgr.delete_user(f"existing{i}", "admin")
|
||||||
|
|
||||||
|
def rename_batch():
|
||||||
|
for i in range(10, 20):
|
||||||
|
mgr.rename_user(f"existing{i}", f"renamed{i}", "admin")
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=create_batch),
|
||||||
|
threading.Thread(target=delete_batch),
|
||||||
|
threading.Thread(target=rename_batch),
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert "admin" in mgr.users
|
||||||
|
for i in range(10):
|
||||||
|
assert f"existing{i}" not in mgr.users
|
||||||
|
for i in range(10, 20):
|
||||||
|
assert f"renamed{i}" in mgr.users
|
||||||
|
assert f"existing{i}" not in mgr.users
|
||||||
|
for i in range(20):
|
||||||
|
assert f"newuser{i}" in mgr.users
|
||||||
|
|
||||||
|
with open(mgr.auth_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
assert set(data["users"].keys()) == set(mgr.users.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiskConsistency:
|
||||||
|
"""Verify auth.json is never in a corrupt state during concurrent writes."""
|
||||||
|
|
||||||
|
def test_file_always_valid_json_during_concurrent_ops(self, tmp_path):
|
||||||
|
mgr = _fresh_auth_manager(tmp_path)
|
||||||
|
mgr.create_user("admin", "adminpw", is_admin=True)
|
||||||
|
|
||||||
|
stop_event = threading.Event()
|
||||||
|
corruption_found = []
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
while not stop_event.is_set():
|
||||||
|
try:
|
||||||
|
with open(mgr.auth_path, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
json.loads(content)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
corruption_found.append(str(e))
|
||||||
|
break
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
def writer():
|
||||||
|
for i in range(50):
|
||||||
|
mgr.create_user(f"stress{i}", f"pw{i}")
|
||||||
|
|
||||||
|
reader_thread = threading.Thread(target=reader)
|
||||||
|
writer_thread = threading.Thread(target=writer)
|
||||||
|
|
||||||
|
reader_thread.start()
|
||||||
|
writer_thread.start()
|
||||||
|
writer_thread.join()
|
||||||
|
stop_event.set()
|
||||||
|
reader_thread.join()
|
||||||
|
|
||||||
|
assert not corruption_found, f"Corrupt JSON detected: {corruption_found[0]}"
|
||||||
Reference in New Issue
Block a user