Skip to content

Commit

Permalink
update: use sdpa_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
jla524 committed Jan 1, 2025
1 parent 919220d commit 6bae9bb
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4255,7 +4255,7 @@ def test_sdpa_can_dispatch_on_flash(self):
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)

@require_non_xpu
Expand Down Expand Up @@ -4347,11 +4347,7 @@ def test_sdpa_matches_eager_sliding_window(self):
model_sdpa = model_sdpa.eval()

with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=False,
):
with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
res_eager = model_eager(**inputs_dict, return_dict=False)[0]
res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]

Expand Down

0 comments on commit 6bae9bb

Please sign in to comment.