diff --git a/setup.py b/setup.py index 915c6960c..c08436378 100644 --- a/setup.py +++ b/setup.py @@ -47,10 +47,9 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): print(raw_output + "from " + cuda_dir + "/bin\n") if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): - # TODO: make this a hard error? - print("\nWarning: Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries.\n") - print("Pytorch binaries were compiled with Cuda {}\n".format(torch.version.cuda)) + raise RuntimeError("Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)) if "--cuda_ext" in sys.argv: from torch.utils.cpp_extension import CUDAExtension