mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
Make LLM host health maps thread-safe
The synchronous llm_call() runs in FastAPI's threadpool (sync route handlers such as POST /sessions/auto-sort), while llm_call_async() runs on the event loop. Both mutate the module-level _response_cache, _host_fails and _dead_hosts dicts, so these are touched from multiple OS threads concurrently. Two races result: - _set_cached_response() snapshots 64 keys then deletes them with `del _response_cache[key]`; if another thread evicts the same key first, the del raises KeyError mid-eviction. Switched to pop(key, None). - _mark_host_dead() does get()+1+set() on _host_fails with no lock, so concurrent connect failures lose increments and a genuinely dead host can stay under its cooldown threshold. Guarded the host-health maps with a threading.Lock (also applied to _is_host_dead / _clear_host_dead for consistent reads). Adds tests/test_llm_core_concurrency.py with deterministic regression tests (phantom snapshot key for the eviction race; a slow-read dict that forces the lost-update window for the counter). Both fail on the unpatched code and pass with the fix.
This commit is contained in:
@@ -0,0 +1,79 @@
|
||||
"""Regression tests for thread-safe access to llm_core's shared maps (issue #659).
|
||||
|
||||
The synchronous llm_call() runs inside FastAPI's threadpool (sync route handlers
|
||||
such as POST /sessions/auto-sort), while llm_call_async() runs on the event
|
||||
loop. Both mutate the module-level _response_cache / _host_fails / _dead_hosts
|
||||
dicts, so those mutations must tolerate concurrent access from multiple OS
|
||||
threads.
|
||||
|
||||
Plain thread stress can't reliably reproduce these races (CPython's GIL rarely
|
||||
preempts the short critical sections), so each test deterministically widens the
|
||||
vulnerable window: one injects a phantom snapshot key, the other forces every
|
||||
thread to read the counter before any writes it back.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
from src import llm_core
|
||||
|
||||
|
||||
def test_cache_eviction_tolerates_already_removed_key():
|
||||
"""Eviction must not raise when a snapshotted key is gone by delete time.
|
||||
|
||||
Models a concurrent evictor removing the same key: the old `del` raised
|
||||
KeyError mid-loop, `pop(key, None)` does not.
|
||||
"""
|
||||
class PhantomKeysCache(dict):
|
||||
def keys(self):
|
||||
# First key is absent from the dict — as if another thread evicted
|
||||
# it between the snapshot and the delete.
|
||||
return ["__phantom_removed__", *super().keys()]
|
||||
|
||||
original = llm_core._response_cache
|
||||
cache = PhantomKeysCache()
|
||||
for i in range(130): # exceed the 128 cap so the eviction branch runs
|
||||
cache[f"k{i}"] = "x"
|
||||
llm_core._response_cache = cache
|
||||
try:
|
||||
llm_core._set_cached_response("new-key", "y") # must not raise
|
||||
assert dict.get(cache, "new-key") == "y"
|
||||
finally:
|
||||
llm_core._response_cache = original
|
||||
|
||||
|
||||
def test_host_fail_counter_has_no_lost_updates():
|
||||
"""Concurrent _mark_host_dead calls must each count exactly once.
|
||||
|
||||
A SlowGetDict widens the read-modify-write window so the unguarded
|
||||
get()+1+set() loses every update but one; the lock serializes them.
|
||||
"""
|
||||
url = "http://race.example:1234/v1/chat/completions"
|
||||
key = llm_core._host_key(url)
|
||||
|
||||
class SlowGetDict(dict):
|
||||
def get(self, *args, **kwargs):
|
||||
value = super().get(*args, **kwargs)
|
||||
time.sleep(0.01) # widen the gap between the read and the caller's write
|
||||
return value
|
||||
|
||||
n_threads = 8
|
||||
barrier = threading.Barrier(n_threads)
|
||||
original_fails = llm_core._host_fails
|
||||
original_threshold = llm_core._HOST_FAIL_THRESHOLD
|
||||
llm_core._host_fails = SlowGetDict()
|
||||
llm_core._HOST_FAIL_THRESHOLD = 10 ** 9 # never cool: every call is a pure +1
|
||||
try:
|
||||
def worker():
|
||||
barrier.wait() # all threads enter the read window together
|
||||
llm_core._mark_host_dead(url)
|
||||
|
||||
threads = [threading.Thread(target=worker) for _ in range(n_threads)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert dict.get(llm_core._host_fails, key) == n_threads
|
||||
finally:
|
||||
llm_core._host_fails = original_fails
|
||||
llm_core._HOST_FAIL_THRESHOLD = original_threshold
|
||||
Reference in New Issue
Block a user