Skip to content

Commit 7e7db04

Browse files
authored
[CI] Retry flaky fp8 cutlass mla tests (#24536)
Signed-off-by: Nick Hill <[email protected]>
1 parent 41f160b commit 7e7db04

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/kernels/test_cutlass_mla_decode.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,13 @@ def cal_diff(x: torch.Tensor,
49 | 49 |   | @pytest.mark.parametrize("block_size", [64])
50 | 50 |   | @pytest.mark.parametrize("causal", [True])
51 | 51 |   | @pytest.mark.parametrize("varlen", [False, True])
52 |    | - | @pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn])
   | 52 | + | @pytest.mark.parametrize(
   | 53 | + |     "torch_dtype",
   | 54 | + |     [
   | 55 | + |         torch.bfloat16,
   | 56 | + |         # fp8 can have occasional precision-related failures.
   | 57 | + |         pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2))
   | 58 | + |     ])
53 | 59 |   | @torch.inference_mode()
54 | 60 |   | def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
55 | 61 |   |                             causal, varlen, torch_dtype):

0 commit comments

Comments
 (0)