Skip to content

Commit

Permalink
move sm80 code inside MHA (#937)
Browse files: browse the repository at this point in the history
Co-authored-by: pbialecki <[email protected]>
  • Loading branch information
ptrblck authored Aug 10, 2020
1 parent 85b1783 commit 5d9b5cb
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
generator_flag = ['-DOLD_GENERATOR']

# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')

if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
Expand All @@ -295,6 +289,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')

subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout',
Expand Down

0 comments on commit 5d9b5cb

Please sign in to comment.