diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index c334feeefee9..d070f6ea33e3 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -81,9 +81,13 @@ def test_only_cross_attention(self):
 
 
 class DeprecatedAttentionBlockTests(unittest.TestCase):
+    @pytest.fixture(scope="session")
+    def is_dist_enabled(pytestconfig):
+        return pytestconfig.getoption("dist") == "loadfile"
+
     @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails on our GPU CI because of `disfile`.",
+        condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
+        reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
         strict=True,
     )
     def test_conversion_when_using_device_map(self):