diff --git a/services/hwfit/fit.py b/services/hwfit/fit.py index 10ab286e0..7a3d4c4f2 100644 --- a/services/hwfit/fit.py +++ b/services/hwfit/fit.py @@ -109,10 +109,15 @@ def _lookup_bandwidth(system): if not isinstance(gpu_name, str) or not gpu_name: return None - if isinstance(system, dict): - bw = _lookup_apple_bandwidth(system) - if bw is not None: - return bw + # Apple tiers live only in the Apple-specific table now (#2564), so route + # BOTH dict and bare-string callers through it. A bare string carries no + # gpu_cores, so the helper falls back to the conservative (lowest) tier for + # that model -- before #2564 the generic table answered string lookups, and + # dropping that made _lookup_bandwidth("Apple M3 Max") return None. + apple_input = system if isinstance(system, dict) else {"gpu_name": gpu_name} + bw = _lookup_apple_bandwidth(apple_input) + if bw is not None: + return bw gn = gpu_name.lower() for key in _BW_KEYS_SORTED: diff --git a/tests/test_hwfit_apple_bandwidth.py b/tests/test_hwfit_apple_bandwidth.py index f5b6df3d4..0977ba517 100644 --- a/tests/test_hwfit_apple_bandwidth.py +++ b/tests/test_hwfit_apple_bandwidth.py @@ -1,4 +1,4 @@ -from services.hwfit.fit import _lookup_bandwidth +from services.hwfit.fit import _lookup_apple_bandwidth, _lookup_bandwidth def test_m3_max_bandwidth_uses_gpu_cores(): @@ -35,6 +35,25 @@ def test_non_apple_gpu_does_not_match_apple_bandwidth(): def test_non_apple_gpu_with_cores_does_not_match(): - """NVIDIA GPU with core count should not match Apple bandwidth.""" - assert _lookup_bandwidth({"gpu_name": "NVIDIA GeForce RTX 4090", "gpu_cores": 128}) is None - assert _lookup_bandwidth({"gpu_name": "AMD Radeon RX 9070 XT", "gpu_cores": 64}) is None + """A non-Apple GPU that happens to carry a gpu_cores count must not be + matched by the APPLE bandwidth path. This asserts the Apple-specific + matcher directly: _lookup_bandwidth would (correctly) return these cards' + real bandwidth from the general GPU table (e.g. the RTX 4090's 1008 GB/s), + which is a different code path and not what this guard is about. + """ + assert _lookup_apple_bandwidth({"gpu_name": "NVIDIA GeForce RTX 4090", "gpu_cores": 128}) is None + assert _lookup_apple_bandwidth({"gpu_name": "AMD Radeon RX 9070 XT", "gpu_cores": 64}) is None + + +def test_apple_string_input_resolves_conservative_tier(): + """Bare-string callers must still get Apple bandwidth. #2564 moved the + Apple tiers out of the generic GPU table into the dict-only Apple helper, + so _lookup_bandwidth("Apple M3 Max") (no gpu_cores) regressed to None; + string inputs now route through the Apple helper and get the conservative + (lowest) tier for the model.""" + assert _lookup_bandwidth("Apple M3 Max") == 300 + assert _lookup_bandwidth("Apple M4 Max") == 410 + assert _lookup_bandwidth("Apple M5 Max") == 460 + # Non-Apple strings still fall through to the generic table. + assert _lookup_bandwidth("NVIDIA GeForce RTX 4090") == 1008 + assert _lookup_bandwidth("Totally Unknown GPU") is None