-
Notifications
You must be signed in to change notification settings - Fork 706
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test v0.4.34 with pybind11/2.12.0 builddep
- Loading branch information
1 parent
fc5b969
commit 849f9fb
Showing
1 changed file
with
168 additions
and
0 deletions.
There are no files selected for viewing
168 changes: 168 additions & 0 deletions
168
easybuild/easyconfigs/j/jax/jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild | ||
# Author: Denis Kristak | ||
# Updated by: Alex Domingo (Vrije Universiteit Brussel) | ||
# Updated by: Pavel Tománek (INUITS) | ||
# Updated by: Thomas Hoffmann (EMBL Heidelberg) | ||
# This easyconfig bundles the 'jaxlib' component and the 'jax' extension defined further down. | ||
easyblock = 'PythonBundle' | ||
|
||
name = 'jax' | ||
version = '0.4.34' | ||
# Suffix is derived from the %(cudaver)s template; it is also passed to the cuDNN and NCCL dependency entries. | ||
versionsuffix = '-CUDA-%(cudaver)s' | ||
|
||
homepage = 'https://jax.readthedocs.io/' | ||
description = """Composable transformations of Python+NumPy programs: | ||
differentiate, vectorize, JIT to GPU/TPU, and more""" | ||
|
||
toolchain = {'name': 'gfbf', 'version': '2024a'} | ||
# Fixed list of CUDA compute capabilities (5.0 up to 9.0) set for this easyconfig. | ||
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] | ||
|
||
# Build-time only requirements; Bazel itself is currently commented out (see TODO below). | ||
builddependencies = [ | ||
# ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime: | ||
# Error in fail: interpreter_path must be an absolute path | ||
# Bazel 6.5.0 (download) works. | ||
('pybind11', '2.12.0'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0. | ||
# Fix: change to builddependency in SciPy-bundle? | ||
# temporarily moved to dependencies (TODO: move back) | ||
# NOTE(review): pybind11 is listed here in builddependencies, so the note above looks stale - confirm. | ||
('pytest-xdist', '3.6.1'), | ||
('git', '2.45.1'), # bazel uses git to fetch repositories | ||
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py | ||
('poetry', '1.8.3'), | ||
('Clang', '18.1.8') | ||
] | ||
|
||
# Runtime requirements; CUDA and cuDNN are taken from the SYSTEM toolchain, | ||
# while cuDNN and NCCL carry the CUDA versionsuffix of this easyconfig. | ||
dependencies = [ | ||
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ? | ||
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM), | ||
('NCCL', '2.22.3', versionsuffix), | ||
('Python', '3.12.3'), | ||
('SciPy-bundle', '2024.05'), # 2024.11 ? | ||
('absl-py', '2.1.0'), | ||
('flatbuffers-python', '24.3.25'), | ||
('ml_dtypes', '0.5.0'), | ||
('zlib', '1.3.1'), | ||
# ('pybind11', '2.13.6'), # override 2.12.0. SciPy-bundle has pybind/2.12.0. Fix: | ||
# change to builddependency in SciPy-bundle? (TODO) | ||
] | ||
|
||
# Download XLA and other tarballs up front so that Bazel does not download them during the build. | ||
# The "extract" command only copies the tarball (no unpacking) into %(builddir)s/archives. | ||
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' | ||
# note: the following commits *must* be the exact same ones used upstream | ||
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl | ||
local_xla_commit = 'cd6e808c59f53b40a99df1f1b860db9a3e598bff' | ||
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl | ||
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required? | ||
# TODO: add other downloads | ||
|
||
# Build options for the jaxlib component, accumulated as one space-separated string. | ||
# Use sources downloaded by EasyBuild (--distdir is the archives directory filled by local_extract_cmd) | ||
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" ' | ||
# Use dependencies from EasyBuild | ||
_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' | ||
_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" ' | ||
# Avoid warning (treated as error) in upb/table.c | ||
_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" ' # TODO: still required? | ||
# _jaxlib_buildopts += '--nouse_clang ' #TODO: avoid clang (?) | ||
_jaxlib_buildopts += '--cuda_version=%(cudaver)s ' | ||
_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 ' | ||
# Do not use hermetic CUDA/cuDNN/NCCL. This requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include | ||
# and a patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>): | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """ | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """ | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """ | ||
_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """ | ||
|
||
# Get rid of the .devDate version suffix. TODO: find a better way | ||
# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?) | ||
# Current workaround: sed rewrites "version=__version__" in setup.py to the literal %(version)s. | ||
_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """ | ||
_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """ # required? | ||
|
||
# Single bundle component: jaxlib, built from the jax source tarball with the options collected above. | ||
components = [ | ||
('jaxlib', version, { | ||
# jax sources plus the pinned XLA and tf_runtime tarballs; the latter two use | ||
# local_extract_cmd, which copies them to the archives directory instead of unpacking. | ||
'sources': [ | ||
{ | ||
'source_urls': ['https://github.com/google/jax/archive/'], | ||
'filename': 'jax-v%(version)s.tar.gz', | ||
}, | ||
{ | ||
'source_urls': ['https://github.com/openxla/xla/archive'], | ||
'download_filename': '%s.tar.gz' % local_xla_commit, | ||
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8], | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
{ | ||
'source_urls': ['https://github.com/tensorflow/runtime/archive'], | ||
'download_filename': '%s.tar.gz' % local_tfrt_commit, | ||
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8], | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
], | ||
# NOTE(review): patch file names carry 0.4.35 while this easyconfig is for 0.4.34 - confirm they apply. | ||
'patches': [ | ||
'jax-0.4.35_easyblock_compat.patch', | ||
'jax-0.4.35_fix-pybind11-systemlib_cupti.patch', | ||
'jax-0.4.35_version.patch', | ||
], | ||
# One checksum per file, keyed by file name: the three sources first, then the three patches. | ||
'checksums': [ | ||
{'jax-v0.4.34.tar.gz': | ||
'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'}, | ||
{'xla-cd6e808c.tar.gz': | ||
'65cb6d63ef4083b35775052636cb9c629f86db6947c8b91711923ba31dbdcde8'}, | ||
{'tf_runtime-0aeefb16.tar.gz': | ||
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, | ||
{'jax-0.4.35_easyblock_compat.patch': | ||
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'}, | ||
{'jax-0.4.35_fix-pybind11-systemlib_cupti.patch': | ||
'78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'}, | ||
{'jax-0.4.35_version.patch': | ||
'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'}, | ||
], | ||
'start_dir': 'jax-jax-v%(version)s', | ||
'buildopts': _jaxlib_buildopts, | ||
# Before building: create third_party/gpus/cuda/extras, symlink the CUPTI directory from | ||
# $EBROOTCUDA into it, then run the setup.py version rewrite from _no_devtag. | ||
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' + | ||
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' + | ||
_no_devtag | ||
}), | ||
] | ||
|
||
# Some tests require an isolated run: TODO: still required? | ||
local_isolated_tests = [ | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', | ||
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + | ||
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', | ||
] | ||
# deliberately not testing in parallel, as that results in (additional) failing tests; | ||
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, | ||
# see https://github.com/google/jax/issues/7323 and | ||
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; | ||
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; | ||
# use NVIDIA_TF32_OVERRIDE=0 to avoid losing numerical precision by disabling TF32 Tensor Cores; | ||
local_test_exports = [ | ||
"NVIDIA_TF32_OVERRIDE=0", | ||
"CUDA_VISIBLE_DEVICES=0", | ||
"XLA_PYTHON_CLIENT_ALLOCATOR=platform", | ||
"JAX_ENABLE_X64=true", | ||
] | ||
# Build one shell command line: the exports first, each terminated by ';' | ||
local_test = ''.join(['export %s;' % x for x in local_test_exports]) | ||
# run all tests at once except for local_isolated_tests: | ||
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) | ||
# run remaining local_isolated_tests separately: | ||
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) | ||
# NOTE(review): local_test is not referenced anywhere else in this file as shown | ||
# ('runtest' is False for the jax extension) - confirm whether it should be wired in. | ||
|
||
use_pip = True | ||
|
||
# The jax package itself is installed as an extension from the same jax-v%(version)s tarball | ||
# (same URL and checksum as the jaxlib component source); its tests are disabled via 'runtest'. | ||
exts_list = [ | ||
(name, version, { | ||
'patches': ['jax-0.4.35_version.patch'], | ||
# same setup.py version rewrite as used for the jaxlib component | ||
'preinstallopts': _no_devtag, | ||
'runtest': False, | ||
'source_tmpl': '%(name)s-v%(version)s.tar.gz', | ||
'source_urls': ['https://github.com/google/jax/archive/'], | ||
'checksums': [ | ||
{'jax-v0.4.34.tar.gz': 'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'}, | ||
{'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'}, | ||
], | ||
}), | ||
] | ||
|
||
sanity_pip_check = True | ||
|
||
moduleclass = 'ai' |