forked from easybuilders/easybuild-easyconfigs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding easyconfigs: jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb and patches:…
… jax-0.4.35_easyblock_compat.patch, jax-0.4.35_fix-pybind11-systemlib_cupti.patch
- Loading branch information
1 parent
91c8df6
commit 9164c61
Showing
3 changed files
with
250 additions
and
0 deletions.
There are no files selected for viewing
162 changes: 162 additions & 0 deletions
162
easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild | ||
# Author: Denis Kristak | ||
# Updated by: Alex Domingo (Vrije Universiteit Brussel) | ||
# Updated by: Pavel Tománek (INUITS) | ||
# Updated by: Thomas Hoffmann (EMBL Heidelberg) | ||
easyblock = 'PythonBundle' | ||
|
||
name = 'jax' | ||
version = '0.4.35' | ||
versionsuffix = '-CUDA-%(cudaver)s' | ||
|
||
homepage = 'https://jax.readthedocs.io/' | ||
description = """Composable transformations of Python+NumPy programs: | ||
differentiate, vectorize, JIT to GPU/TPU, and more""" | ||
|
||
toolchain = {'name': 'gfbf', 'version': '2024a'} | ||
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"] | ||
|
||
builddependencies = [ | ||
# ('Bazel', '7.4.1'), TODO: problems with py. 6.5.0 works. | ||
('pybind11', '2.13.6'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0. Fix: change to builddependency in SciPy-bundle? | ||
('pytest-xdist', '3.6.1'), | ||
('git', '2.45.1'), # bazel uses git to fetch repositories | ||
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py | ||
('poetry', '1.8.3'), | ||
('Clang', '18.1.8') | ||
] | ||
|
||
dependencies = [ | ||
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ? | ||
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM), | ||
('NCCL', '2.22.3', versionsuffix), | ||
('Python', '3.12.3'), | ||
('SciPy-bundle', '2024.05'), # 2024.11 ? | ||
('absl-py', '2.1.0'), | ||
('flatbuffers-python', '24.3.25'), | ||
('ml_dtypes', '0.5.0'), | ||
('zlib', '1.3.1'), | ||
('pybind11', '2.13.6'), # override 2.12.0. SciPy-bundle has pybind/2.12.0. Fix: | ||
# change to builddependency in SciPy-bundle? (TODO) | ||
] | ||
|
||
# downloading xla and other tarballs to avoid that Bazel downloads it during the build | ||
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives' | ||
# note: following commits *must* be the exact same onces used upstream | ||
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl | ||
local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c' | ||
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl | ||
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required? | ||
# TODO: add other downloads | ||
|
||
# Use sources downloaded by EasyBuild | ||
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" ' | ||
# Use dependencies from EasyBuild | ||
_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" ' | ||
_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" ' | ||
# Avoid warning (treated as error) in upb/table.c | ||
_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" ' # TODO: still required? | ||
# _jaxlib_buildopts += '--nouse_clang ' #TODO: avoid clang (?) | ||
_jaxlib_buildopts += '--cuda_version=%(cudaver)s ' | ||
_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 ' | ||
# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include"; | ||
# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>): | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """ | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """ | ||
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """ | ||
_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """ | ||
|
||
# get rid of .devDate versionsuffix: TODO: find a better way | ||
# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?) | ||
_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """ | ||
_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """ # required? | ||
|
||
components = [ | ||
('jaxlib', version, { | ||
'sources': [ | ||
{ | ||
'source_urls': ['https://github.com/google/jax/archive/'], | ||
'filename': 'jax-v%(version)s.tar.gz', | ||
}, | ||
{ | ||
'source_urls': ['https://github.com/openxla/xla/archive'], | ||
'download_filename': '%s.tar.gz' % local_xla_commit, | ||
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8], | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
{ | ||
'source_urls': ['https://github.com/tensorflow/runtime/archive'], | ||
'download_filename': '%s.tar.gz' % local_tfrt_commit, | ||
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8], | ||
'extract_cmd': local_extract_cmd, | ||
}, | ||
], | ||
'patches': [ | ||
'jax-0.4.35_easyblock_compat.patch', | ||
'jax-0.4.35_fix-pybind11-systemlib_cupti.patch', | ||
'jax-v0.4.35_version.patch', | ||
], | ||
'checksums': [ | ||
{'jax-v0.4.35.tar.gz': | ||
'65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'}, | ||
{'xla-76da7301.tar.gz': | ||
'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'}, | ||
{'tf_runtime-0aeefb16.tar.gz': | ||
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'}, | ||
{'jax-0.4.35_easyblock_compat.patch': | ||
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'}, | ||
{'jax-0.4.35_fix-pybind11-systemlib_cupti.patch': | ||
'78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'}, | ||
], | ||
'start_dir': 'jax-jax-v%(version)s', | ||
'buildopts': _jaxlib_buildopts, | ||
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' + | ||
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' + | ||
_no_devtag | ||
}), | ||
] | ||
|
||
# Some tests require an isolated run: TODO: still required? | ||
local_isolated_tests = [ | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp', | ||
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc', | ||
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' + | ||
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32', | ||
] | ||
# deliberately not testing in parallel, as that results in (additional) failing tests; | ||
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing, | ||
# see https://github.com/google/jax/issues/7323 and | ||
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst; | ||
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs; | ||
# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores; | ||
local_test_exports = [ | ||
"NVIDIA_TF32_OVERRIDE=0", | ||
"CUDA_VISIBLE_DEVICES=0", | ||
"XLA_PYTHON_CLIENT_ALLOCATOR=platform", | ||
"JAX_ENABLE_X64=true", | ||
] | ||
local_test = ''.join(['export %s;' % x for x in local_test_exports]) | ||
# run all tests at once except for local_isolated_tests: | ||
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests]) | ||
# run remaining local_isolated_tests separately: | ||
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests]) | ||
|
||
use_pip = True | ||
|
||
exts_list = [ | ||
(name, version, { | ||
'source_tmpl': '%(name)s-v%(version)s.tar.gz', | ||
'source_urls': ['https://github.com/google/jax/archive/'], | ||
# 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], # TODO: still required? update? | ||
'patches': ['jax-v0.4.35_version.patch'], | ||
'checksums': [ | ||
{'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'}, | ||
], | ||
'runtest': local_test, | ||
'preinstallopts': _no_devtag | ||
}), | ||
] | ||
|
||
sanity_pip_check = True | ||
|
||
moduleclass = 'ai' |
21 changes: 21 additions & 0 deletions
21
easybuild/easyconfigs/j/jax/jax-0.4.35_easyblock_compat.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/11 | ||
# add dummy parameters to build/build.py for cudnn_path and cuda_path, which are set by default by the jaxlib easyblock. | ||
diff -ru jax-jax-v0.4.35/build/build.py jax-jax-v0.4.35_easyblockcompat/build/build.py | ||
--- jax-jax-v0.4.35/build/build.py 2024-10-22 21:00:23.000000000 +0200 | ||
+++ jax-jax-v0.4.35_easyblockcompat/build/build.py 2024-11-19 12:35:46.524479324 +0100 | ||
@@ -549,6 +549,15 @@ | ||
help_str="Same as update_requirements, but will consider dev, nightly " | ||
"and pre-release versions of packages.") | ||
|
||
+ parser.add_argument( | ||
+ "--cuda_path", | ||
+ default="dummy", | ||
+ help="compatibility with jaxlib.py easyblock") | ||
+ parser.add_argument( | ||
+ "--cudnn_path", | ||
+ default="dummy", | ||
+ help="compatibility with jaxlib.py easyblock") | ||
+ | ||
args = parser.parse_args() | ||
|
||
logging.basicConfig() |
67 changes: 67 additions & 0 deletions
67
easybuild/easyconfigs/j/jax/jax-0.4.35_fix-pybind11-systemlib_cupti.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
jax-0.4.25_fix-pybind11-systemlib.patch: Add missing value for System Pybind11 Bazel config | ||
jax-0.4.25_fix-pybind11-systemlib.patch: Author: Alexander Grund (TU Dresden) | ||
|
||
THEMBL: fix cupti include path. | ||
|
||
diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch | ||
new file mode 100644 | ||
index 000000000..68bd2063d | ||
--- /dev/null | ||
+++ b/third_party/xla/fix-pybind11-systemlib.patch | ||
@@ -0,0 +1,13 @@ | ||
+--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD | ||
++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD | ||
+@@ -6,3 +6,10 @@ | ||
+ "@tsl//third_party/python_runtime:headers", | ||
+ ], | ||
+ ) | ||
++ | ||
++# Needed by pybind11_bazel. | ||
++config_setting( | ||
++ name = "osx", | ||
++ constraint_values = ["@platforms//os:osx"], | ||
++) | ||
++ | ||
diff -ruN jax-jax-v0.4.35/jaxlib/gpu/vendor.h jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h | ||
--- jax-jax-v0.4.35/jaxlib/gpu/vendor.h 2024-10-22 21:00:23.000000000 +0200 | ||
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h 2024-11-26 10:56:20.396087442 +0100 | ||
@@ -23,7 +23,7 @@ | ||
#if defined(JAX_GPU_CUDA) | ||
|
||
// IWYU pragma: begin_exports | ||
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" | ||
+#include <cupti.h> | ||
#include "third_party/gpus/cuda/include/cooperative_groups.h" | ||
#include "third_party/gpus/cuda/include/cuComplex.h" | ||
#include "third_party/gpus/cuda/include/cublas_v2.h" | ||
diff -ruN jax-jax-v0.4.35/third_party/xla/workspace.bzl jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl | ||
--- jax-jax-v0.4.35/third_party/xla/workspace.bzl 2024-10-22 21:00:23.000000000 +0200 | ||
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl 2024-11-27 12:17:37.913466273 +0100 | ||
@@ -30,6 +30,11 @@ | ||
sha256 = XLA_SHA256, | ||
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT), | ||
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)), | ||
+ patch_file = [ | ||
+ "//third_party/xla:xla-76da73_cupti.patch", | ||
+ "//third_party/xla:fix-pybind11-systemlib.patch", | ||
+ ], | ||
+ | ||
) | ||
|
||
# For development, one often wants to make changes to the TF repository as well | ||
diff -ruN jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch | ||
--- jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch 1970-01-01 01:00:00.000000000 +0100 | ||
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch 2024-11-27 12:18:26.668582799 +0100 | ||
@@ -0,0 +1,12 @@ | ||
+diff -ru xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc | ||
+--- xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc 2024-10-21 20:29:31.000000000 +0200 | ||
++++ xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc 2024-11-26 12:04:11.695539146 +0100 | ||
+@@ -13,7 +13,7 @@ | ||
+ limitations under the License. | ||
+ ==============================================================================*/ | ||
+ | ||
+-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" | ||
++#include <cupti.h> | ||
+ #include "third_party/gpus/cuda/include/cuda.h" | ||
+ #include "tsl/platform/dso_loader.h" | ||
+ #include "tsl/platform/load_library.h" |