diff --git a/easybuild/easyconfigs/j/jax-triton/jax-triton-0.1.1_ignore_missing_torch.patch b/easybuild/easyconfigs/j/jax-triton/jax-triton-0.1.1_ignore_missing_torch.patch deleted file mode 100644 index bc166c8158a..00000000000 --- a/easybuild/easyconfigs/j/jax-triton/jax-triton-0.1.1_ignore_missing_torch.patch +++ /dev/null @@ -1,38 +0,0 @@ -# Thomas Hoffmann, EMBL Heidelberg, structures-it@embl.de, 2024/12 -# fix: import jax_triton sanity check fails, if PyTorch is not loaded. - -diff -ru jax-triton-0.1.1/jax_triton/triton_call.py jax-triton-0.1.1_ignore_missing_torch/jax_triton/triton_call.py ---- jax-triton-0.1.1/jax_triton/triton_call.py 2022-09-13 06:27:46.000000000 +0200 -+++ jax-triton-0.1.1_ignore_missing_torch/jax_triton/triton_call.py 2024-12-19 15:47:16.069571976 +0100 -@@ -26,11 +26,15 @@ - from jax.interpreters import mlir - from jax import tree_util - from jax._src import util --from jax._src.lib import xla_bridge as xb -+from jax.lib import xla_bridge as xb - from jax._src.lib.mlir import ir - from jax._src.lib.mlir.dialects import mhlo - import numpy as np --import torch -+_has_torch=True -+try: -+ import torch -+except: -+ _has_torch = False - import triton - import triton.language as tl - -@@ -109,8 +113,11 @@ - dump_binary_path=dump_binary_path, **metaparams) - return tree_util.tree_unflatten(out_tree, out_flat) - --table = {'float32': torch.float32, 'int32': torch.int32, 'float16': torch.float16, -- 'float64': torch.float64, 'int64': torch.int64 } -+table = {} -+ -+if _has_torch: -+ table = {'float32': torch.float32, 'int32': torch.int32, 'float16': torch.float16, -+ 'float64': torch.float64, 'int64': torch.int64 } - - @triton_call_p.def_impl - def triton_call_impl(*args, kernel, out_shapes, grid, dump_binary_path, **metaparams):