fix(hwfit): normalize CPU arch for fallback estimates (#4441)

This commit is contained in:
RaresKeY
2026-06-18 20:26:22 +02:00
committed by GitHub
parent b51d83b16d
commit 16e660ad09
5 changed files with 119 additions and 10 deletions
+25
View File
@@ -47,6 +47,12 @@ ARM64_SYSTEM = {
"gpu_vram_gb": 0,
}
# GPU-less system fixture reporting the plain "arm" backend. Used by the test
# that checks plain arm is not treated as the ARM64-class CPU fallback.
ARM32_SYSTEM = {
    "backend": "arm",
    "gpu_name": None,
    "gpu_vram_gb": 0,
}
AARCH64_SYSTEM = {
"backend": "aarch64",
"gpu_name": None,
@@ -79,6 +85,16 @@ def test_cpu_only_on_metal_apple_silicon_uses_cpu_arm_fallback():
assert metal_tps > 0
def test_cpu_only_on_gpu_backend_uses_detected_arm64_cpu_arch():
    """cpu_only on a GPU backend with an aarch64 host CPU matches the ARM CPU estimate.

    The CUDA fixture is overlaid with ``cpu_arch="aarch64"`` and the resulting
    cpu_only speed is compared against the one computed for ``CPU_ARM_SYSTEM``.
    """
    # Same base CUDA system, but reporting an ARM64 host CPU.
    gpu_host_on_arm64 = {
        **CUDA_SYSTEM,
        "cpu_arch": "aarch64",
        "cpu_name": "Ampere Altra",
    }

    observed_tps = _estimate_speed(DENSE_MODEL, QUANT, "cpu_only", gpu_host_on_arm64)
    reference_tps = _estimate_speed(DENSE_MODEL, QUANT, "cpu_only", CPU_ARM_SYSTEM)

    expected = pytest.approx(reference_tps, rel=1e-9, abs=1e-9)
    assert observed_tps == expected
    assert observed_tps > 0
@pytest.mark.parametrize(
"arm_alias_system",
[ARM64_SYSTEM, AARCH64_SYSTEM, CPU_ARM_SYSTEM],
@@ -93,6 +109,15 @@ def test_cpu_only_preserves_arm_backends(arm_alias_system):
assert alias_tps > 0
def test_cpu_only_does_not_treat_plain_arm_as_arm64_fallback():
    """Plain ``arm`` (as reported by Docker/OCI) is not the ARM64-class fallback.

    The cpu_only estimate for the plain-arm fixture must equal the one
    computed for ``CPU_X86_SYSTEM`` rather than the ARM64-class fallback
    used for Apple Silicon.
    """
    plain_arm_tps = _estimate_speed(DENSE_MODEL, QUANT, "cpu_only", ARM32_SYSTEM)
    baseline_tps = _estimate_speed(DENSE_MODEL, QUANT, "cpu_only", CPU_X86_SYSTEM)

    expected = pytest.approx(baseline_tps, rel=1e-9, abs=1e-9)
    assert plain_arm_tps == expected
    assert plain_arm_tps > 0
def test_cpu_only_preserves_known_cpu_backends():
"""Known CPU backends should be preserved, not rewritten to cpu_x86."""
for system in (CPU_X86_SYSTEM, CPU_ARM_SYSTEM):