diff --git a/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb b/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb deleted file mode 100644 index 63f290a028c..00000000000 --- a/easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb +++ /dev/null @@ -1,170 +0,0 @@ -# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild -# Author: Denis Kristak -# Updated by: Alex Domingo (Vrije Universiteit Brussel) -# Updated by: Pavel Tománek (INUITS) -# Updated by: Thomas Hoffmann (EMBL Heidelberg) -easyblock = 'PythonBundle' - -name = 'jax' -version = '0.4.35' -versionsuffix = '-CUDA-%(cudaver)s' - -homepage = 'https://jax.readthedocs.io/' -description = """Composable transformations of Python+NumPy programs: -differentiate, vectorize, JIT to GPU/TPU, and more""" - -toolchain = {'name': 'gfbf', 'version': '2024a'} -cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] - -builddependencies = [ - # ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime: - # Error in fail: interpreter_path must be an absolute path - # Bazel 6.5.0 (download) works. - ('pybind11', '2.13.6'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0. - # Fix: change to builddependency in SciPy-bundle? - # tmporarily mv to dependencies (TODO: mv back) - ('pytest-xdist', '3.6.1'), - ('git', '2.45.1'), # bazel uses git to fetch repositories - ('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py - ('poetry', '1.8.3'), - ('Clang', '18.1.8') -] - -dependencies = [ - ('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ? - ('cuDNN', '9.5.0.50', versionsuffix, SYSTEM), - ('NCCL', '2.22.3', versionsuffix), - ('Python', '3.12.3'), - ('SciPy-bundle', '2024.05'), # 2024.11 ? - ('absl-py', '2.1.0'), - ('flatbuffers-python', '24.3.25'), - ('ml_dtypes', '0.5.0'), - ('zlib', '1.3.1'), - # ('pybind11', '2.13.6'), # override 2.12.0. SciPy-bundle has pybind/2.12.0. Fix: - # change to builddependency in SciPy-bundle? (TODO) -] - -# downloading xla and other tarballs to avoid that Bazel downloads it during the build -local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' -# note: following commits *must* be the exact same onces used upstream -# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl -local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c' -# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl -local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required? -# TODO: add other downloads - -# Use sources downloaded by EasyBuild -_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" ' -# Use dependencies from EasyBuild -_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' -_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" ' -# Avoid warning (treated as error) in upb/table.c -_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" ' # TODO: still required? -# _jaxlib_buildopts += '--nouse_clang ' #TODO: avoid clang (?) -_jaxlib_buildopts += '--cuda_version=%(cudaver)s ' -_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 ' -# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include"; -# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include ): -_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """ -_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """ -_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """ -_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """ - -# get rid of .devDate versionsuffix: TODO: find a better way -# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?) -_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """ -_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """ # required? - -components = [ - ('jaxlib', version, { - 'sources': [ - { - 'source_urls': ['https://github.com/google/jax/archive/'], - 'filename': 'jax-v%(version)s.tar.gz', - }, - { - 'source_urls': ['https://github.com/openxla/xla/archive'], - 'download_filename': '%s.tar.gz' % local_xla_commit, - 'filename': 'xla-%s.tar.gz' % local_xla_commit[:8], - 'extract_cmd': local_extract_cmd, - }, - { - 'source_urls': ['https://github.com/tensorflow/runtime/archive'], - 'download_filename': '%s.tar.gz' % local_tfrt_commit, - 'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8], - 'extract_cmd': local_extract_cmd, - }, - ], - 'patches': [ - 'jax-0.4.35_easyblock_compat.patch', - 'jax-0.4.35_fix-pybind11-systemlib_cupti.patch', - 'jax-0.4.35_version.patch', - ], - 'checksums': [ - {'jax-v0.4.35.tar.gz': - '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'}, - {'xla-76da7301.tar.gz': - 'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'}, - {'tf_runtime-0aeefb16.tar.gz': - 'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, - {'jax-0.4.35_easyblock_compat.patch': - 'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'}, - {'jax-0.4.35_fix-pybind11-systemlib_cupti.patch': - '78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'}, - {'jax-0.4.35_version.patch': - 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'}, - ], - 'start_dir': 'jax-jax-v%(version)s', - 'buildopts': _jaxlib_buildopts, - 'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' + - 'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' + - _no_devtag - }), -] - -# Some tests require an isolated run: TODO: still required? -local_isolated_tests = [ - 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', - 'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', - 'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + - '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', -] -# deliberately not testing in parallel, as that results in (additional) failing tests; -# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, -# see https://github.com/google/jax/issues/7323 and -# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; -# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; -# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores; -local_test_exports = [ - "NVIDIA_TF32_OVERRIDE=0", - "CUDA_VISIBLE_DEVICES=0", - "XLA_PYTHON_CLIENT_ALLOCATOR=platform", - "JAX_ENABLE_X64=true", -] -local_test = ''.join(['export %s;' % x for x in local_test_exports]) -# run all tests at once except for local_isolated_tests: -local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) -# run remaining local_isolated_tests separately: -local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) - -use_pip = True - -exts_list = [ - (name, version, { - 'source_tmpl': '%(name)s-v%(version)s.tar.gz', - 'source_urls': ['https://github.com/google/jax/archive/'], - # 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], # TODO: still required? update? - 'patches': ['jax-0.4.35_version.patch'], - 'checksums': [ - {'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'}, - {'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'}, - ], - # 'runtest': local_test, - 'runtest': False, # tmp - 'preinstallopts': _no_devtag - }), -] - -sanity_pip_check = True - -moduleclass = 'ai'