# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.34'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
    # ('Bazel', '7.4.1'),
    # TODO: Bazel 7.4.1 has problems with @@local_config_python//:py3_runtime:
    #       "Error in fail: interpreter_path must be an absolute path".
    #       Bazel 6.5.0 (download) works.
    # pybind11: 2.12.0? SciPy-bundle has pybind/2.12.0.
    # Fix: change to builddependency in SciPy-bundle?
    # Temporarily moved to dependencies (TODO: move back).
    ('pybind11', '2.12.0'),
    ('pytest-xdist', '3.6.1'),
    ('git', '2.45.1'),  # Bazel uses git to fetch repositories
    ('matplotlib', '3.9.2'),  # required for tests/lobpcg_test.py
    ('poetry', '1.8.3'),
    ('Clang', '18.1.8'),
]

dependencies = [
    ('CUDA', '12.6.0', '', SYSTEM),  # 12.6.2 ?
    ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
    ('NCCL', '2.22.3', versionsuffix),
    ('Python', '3.12.3'),
    ('SciPy-bundle', '2024.05'),  # 2024.11 ?
    ('absl-py', '2.1.0'),
    ('flatbuffers-python', '24.3.25'),
    ('ml_dtypes', '0.5.0'),
    ('zlib', '1.3.1'),
    # ('pybind11', '2.13.6'),  # would override 2.12.0; SciPy-bundle has pybind/2.12.0.
    # Fix: change to builddependency in SciPy-bundle? (TODO)
]

# XLA and the other tarballs are downloaded up front so that Bazel does not fetch them during the build;
# "extracting" such a tarball only copies it into the archives directory handed to Bazel as distdir.
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
# NOTE: the commits below *must* be exactly the same ones used upstream.
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = 'cd6e808c59f53b40a99df1f1b860db9a3e598bff'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25'  # TODO: still required?
# TODO: add other downloads

# Replace the version lookup in setup.py with the plain release version to get rid of the
# .devDate version suffix. TODO: find a better way.
# Exporting the variables instead does not seem to work (?):
#   """ export JAX_RELEASE && export JAXLIB_RELEASE && """
local_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """

# Options passed to the jaxlib build, one entry per option, in command-line order.
local_jaxlib_opts = [
    # use the sources downloaded by EasyBuild
    '--bazel_options="--distdir=%(builddir)s/archives"',
    # use dependencies provided by EasyBuild
    '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11"',
    '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include"',
    # avoid a warning (treated as error) in upb/table.c; TODO: still required?
    '--bazel_options="--copt=-Wno-maybe-uninitialized"',
    # '--nouse_clang',  # TODO: avoid clang (?)
    '--cuda_version=%(cudaver)s',
    '--python_bin_path=$EBROOTPYTHON/bin/python3',
    # Do not use hermetic CUDA/cuDNN/NCCL. This requires the CUPTI include directory in CPATH (see above)
    # and a patch of the #include directives in external/xla/xla/tsl/cuda/cupti_stub.cc and
    # jaxlib/gpu/vendor.h.
    '--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN"',
    '--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL"',
    '--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA"',
    '--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include"',
    # part of dropping the .devDate version suffix; required?
    '--bazel_options="--action_env=JAXLIB_RELEASE=1"',
]
# every option is followed by exactly one space, including the last one
local_jaxlib_buildopts = ''.join(opt + ' ' for opt in local_jaxlib_opts)

components = [
    ('jaxlib', version, {
        'sources': [
            {
                'source_urls': ['https://github.com/google/jax/archive/'],
                'filename': 'jax-v%(version)s.tar.gz',
            },
            {
                'source_urls': ['https://github.com/openxla/xla/archive'],
                'download_filename': local_xla_commit + '.tar.gz',
                'filename': 'xla-' + local_xla_commit[:8] + '.tar.gz',
                'extract_cmd': local_extract_cmd,
            },
            {
                'source_urls': ['https://github.com/tensorflow/runtime/archive'],
                'download_filename': local_tfrt_commit + '.tar.gz',
                'filename': 'tf_runtime-' + local_tfrt_commit[:8] + '.tar.gz',
                'extract_cmd': local_extract_cmd,
            },
        ],
        'patches': [
            'jax-0.4.35_easyblock_compat.patch',
            'jax-0.4.35_fix-pybind11-systemlib_cupti.patch',
            'jax-0.4.35_version.patch',
        ],
        'checksums': [
            {'jax-v0.4.34.tar.gz':
             'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
            {'xla-cd6e808c.tar.gz':
             '65cb6d63ef4083b35775052636cb9c629f86db6947c8b91711923ba31dbdcde8'},
            {'tf_runtime-0aeefb16.tar.gz':
             'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
            {'jax-0.4.35_easyblock_compat.patch':
             'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
            {'jax-0.4.35_fix-pybind11-systemlib_cupti.patch':
             '78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'},
            {'jax-0.4.35_version.patch':
             'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
        ],
        'start_dir': 'jax-jax-v%(version)s',
        'buildopts': local_jaxlib_buildopts,
        # link the CUPTI directory of the CUDA installation into the source tree (matches the
        # -Ithird_party/gpus/cuda/extras/CUPTI/include option above), then pin the version in setup.py
        'prebuildopts': (
            ' mkdir third_party/gpus/cuda/extras/ -p && '
            'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&'
            + local_no_devtag
        ),
    }),
]

# Some tests require an isolated run. TODO: still required?
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest'
    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# Tests are deliberately not run in parallel, as that results in (additional) failing tests.
# Environment used for testing:
# - XLA_PYTHON_CLIENT_ALLOCATOR=platform: allocate and deallocate GPU memory during testing,
#   see https://github.com/google/jax/issues/7323 and
#   https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst
# - CUDA_VISIBLE_DEVICES=0: avoid failing tests on systems with multiple GPUs
# - NVIDIA_TF32_OVERRIDE=0: avoid losing numerical precision by disabling TF32 Tensor Cores
local_test_exports = [
    "NVIDIA_TF32_OVERRIDE=0",
    "CUDA_VISIBLE_DEVICES=0",
    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
    "JAX_ENABLE_X64=true",
]
local_export_cmds = ''.join('export %s;' % setting for setting in local_test_exports)
local_deselect_opts = ' '.join('--deselect %s' % test for test in local_isolated_tests)
local_isolated_cmds = ' && '.join('pytest -vv %s' % test for test in local_isolated_tests)
# export the settings, run all tests at once except for local_isolated_tests,
# then run the remaining local_isolated_tests separately
local_test = local_export_cmds + "pytest -vv tests %s && " % local_deselect_opts + local_isolated_cmds

use_pip = True

exts_list = [
    (name, version, {
        'patches': ['jax-0.4.35_version.patch'],
        'preinstallopts': local_no_devtag,
        # NOTE(review): local_test is assembled above but not used here, so no tests are run
        # for this extension -- confirm whether this should become 'runtest': local_test.
        'runtest': False,
        'source_tmpl': '%(name)s-v%(version)s.tar.gz',
        'source_urls': ['https://github.com/google/jax/archive/'],
        'checksums': [
            {'jax-v0.4.34.tar.gz': 'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
            {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
        ],
    }),
]

sanity_pip_check = True

moduleclass = 'ai'