fix(hwfit): distinguish Apple Silicon bandwidth variants (#2564)

* fix: resolve Apple Silicon bandwidth variants

* fix(hwfit): preserve string lookup path in _lookup_bandwidth

* fix(hwfit): guard Apple bandwidth lookup against false GPU matches

Add "apple" not in gn check to _lookup_apple_bandwidth() so that
non-Apple GPUs with "m3"/"m4"/"m5" in their names (e.g. NVIDIA
Quadro M4 000) don't incorrectly match Apple bandwidth tiers.

Addresses @o3LL review comment on PR #2564.
This commit is contained in:
Ahmad Naalweh
2026-06-15 15:13:03 +02:00
committed by GitHub
parent 514d345334
commit f7aa2de410
4 changed files with 184 additions and 17 deletions
+40
View File
@@ -0,0 +1,40 @@
from services.hwfit.fit import _lookup_bandwidth
def test_m3_max_bandwidth_uses_gpu_cores():
assert _lookup_bandwidth({"gpu_name": "Apple M3 Max", "gpu_cores": 30}) == 300
assert _lookup_bandwidth({"gpu_name": "Apple M3 Max", "gpu_cores": 40}) == 400
def test_m4_max_bandwidth_uses_gpu_cores():
assert _lookup_bandwidth({"gpu_name": "Apple M4 Max", "gpu_cores": 32}) == 410
assert _lookup_bandwidth({"gpu_name": "Apple M4 Max", "gpu_cores": 40}) == 546
def test_m5_max_bandwidth_uses_gpu_cores():
assert _lookup_bandwidth({"gpu_name": "Apple M5 Max", "gpu_cores": 32}) == 460
assert _lookup_bandwidth({"gpu_name": "Apple M5 Max", "gpu_cores": 40}) == 614
def test_apple_max_bandwidth_falls_back_conservatively_without_gpu_cores():
assert _lookup_bandwidth({"gpu_name": "Apple M3 Max"}) == 300
assert _lookup_bandwidth({"gpu_name": "Apple M4 Max"}) == 410
assert _lookup_bandwidth({"gpu_name": "Apple M5 Max"}) == 460
def test_fixed_apple_bandwidth_entries_include_updated_m5_values():
assert _lookup_bandwidth({"gpu_name": "Apple M5 Pro"}) == 307
assert _lookup_bandwidth({"gpu_name": "Apple M5"}) == 153
def test_non_apple_gpu_does_not_match_apple_bandwidth():
"""NVIDIA Quadro M4 000 should NOT match Apple bandwidth lookup."""
assert _lookup_bandwidth({"gpu_name": "NVIDIA Quadro M4 000"}) is None
assert _lookup_bandwidth({"gpu_name": "NVIDIA Quadro M3 000"}) is None
assert _lookup_bandwidth({"gpu_name": "NVIDIA Quadro M5 000"}) is None
def test_non_apple_gpu_with_cores_does_not_match():
"""NVIDIA GPU with core count should not match Apple bandwidth."""
assert _lookup_bandwidth({"gpu_name": "NVIDIA GeForce RTX 4090", "gpu_cores": 128}) is None
assert _lookup_bandwidth({"gpu_name": "AMD Radeon RX 9070 XT", "gpu_cores": 64}) is None
+43 -3
View File
@@ -4,6 +4,8 @@ Covers the Metal-specific behavior added for Apple Silicon and locks in the
guarantee that non-macOS (Linux/Windows) detection is unchanged.
"""
import json
from services.hwfit import hardware
from services.hwfit.fit import rank_models
from services.hwfit.models import get_models
@@ -22,7 +24,7 @@ def _metal_system(ram_gb=16.0, vram_gb=10.7):
}
def _fake_sysctl(brand="Apple M2 Pro", memsize_gb=32, wired_mb=None):
def _fake_sysctl(brand="Apple M2 Pro", memsize_gb=32, wired_mb=None, display_json=None, display_text=None):
def run(cmd):
joined = " ".join(cmd)
if "machdep.cpu.brand_string" in joined:
@@ -31,6 +33,12 @@ def _fake_sysctl(brand="Apple M2 Pro", memsize_gb=32, wired_mb=None):
return str(int(memsize_gb * 1024**3))
if "iogpu.wired_limit_mb" in joined:
return str(wired_mb) if wired_mb is not None else None
if "system_profiler SPDisplaysDataType -json" in joined:
if isinstance(display_json, (dict, list)):
return json.dumps(display_json)
return display_json
if "system_profiler SPDisplaysDataType" in joined:
return display_text
return None
return run
@@ -98,16 +106,47 @@ def test_apple_silicon_detected_as_metal(monkeypatch):
monkeypatch.setattr(hardware, "_remote_host", None)
monkeypatch.setattr(hardware.platform, "system", lambda: "Darwin")
monkeypatch.setattr(hardware.platform, "machine", lambda: "arm64")
monkeypatch.setattr(hardware, "_run", _fake_sysctl(memsize_gb=32))
monkeypatch.setattr(hardware, "_run", _fake_sysctl(
memsize_gb=32,
display_json={"SPDisplaysDataType": [{"sppci_model": "Apple M2 Pro", "sppci_cores": "19"}]},
))
info = hardware._detect_apple_silicon()
assert info is not None
assert info["backend"] == "metal"
assert info["gpu_name"] == "Apple M2 Pro"
assert info["unified_memory"] is True
assert info["gpu_cores"] == 19
assert info["gpu_vram_gb"] == 24.0 # 32GB * 0.75
def test_apple_silicon_gpu_cores_fall_back_to_plain_text(monkeypatch):
monkeypatch.setattr(hardware, "_remote_host", None)
monkeypatch.setattr(hardware.platform, "system", lambda: "Darwin")
monkeypatch.setattr(hardware.platform, "machine", lambda: "arm64")
monkeypatch.setattr(hardware, "_run", _fake_sysctl(
brand="Apple M4 Max",
memsize_gb=64,
display_json="{not-json",
display_text="Graphics/Displays:\n\nApple M4 Max:\n Total Number of Cores: 32\n",
))
info = hardware._detect_apple_silicon()
assert info is not None
assert info["gpu_cores"] == 32
def test_apple_silicon_gpu_cores_are_optional(monkeypatch):
monkeypatch.setattr(hardware, "_remote_host", None)
monkeypatch.setattr(hardware.platform, "system", lambda: "Darwin")
monkeypatch.setattr(hardware.platform, "machine", lambda: "arm64")
monkeypatch.setattr(hardware, "_run", _fake_sysctl(memsize_gb=32))
info = hardware._detect_apple_silicon()
assert info is not None
assert "gpu_cores" not in info
def test_apple_silicon_skipped_on_linux(monkeypatch):
"""Guarantee Linux detection is untouched: the Metal probe bails immediately."""
monkeypatch.setattr(hardware, "_remote_host", None)
@@ -132,7 +171,7 @@ def test_detect_system_propagates_unified_memory(monkeypatch):
monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: {
"gpu_name": "Apple M4", "gpu_vram_gb": 10.7, "gpu_count": 1,
"gpus": [], "gpu_groups": [], "homogeneous": True,
"backend": "metal", "unified_memory": True,
"backend": "metal", "unified_memory": True, "gpu_cores": 10,
})
monkeypatch.setattr(hardware, "_get_ram_gb", lambda: 16.0)
monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 11.0)
@@ -142,3 +181,4 @@ def test_detect_system_propagates_unified_memory(monkeypatch):
s = hardware.detect_system(fresh=True)
assert s["backend"] == "metal"
assert s.get("unified_memory") is True
assert s["gpu_cores"] == 10