Skip to content

Commit

Permalink
adding easyconfigs: jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb and patches:…
Browse files Browse the repository at this point in the history
… jax-0.4.35_easyblock_compat.patch, jax-0.4.35_fix-pybind11-systemlib_cupti.patch
  • Loading branch information
ThomasHoffmann77 committed Nov 28, 2024
1 parent 91c8df6 commit 9164c61
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 0 deletions.
162 changes: 162 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.35'
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

builddependencies = [
# ('Bazel', '7.4.1'), TODO: problems with py. 6.5.0 works.
('pybind11', '2.13.6'), # 2.12.0 ? SciPy-bundle has pybind/2.12.0. Fix: change to builddependency in SciPy-bundle?
('pytest-xdist', '3.6.1'),
('git', '2.45.1'), # bazel uses git to fetch repositories
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py
('poetry', '1.8.3'),
('Clang', '18.1.8')
]

dependencies = [
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ?
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
('NCCL', '2.22.3', versionsuffix),
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'), # 2024.11 ?
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('zlib', '1.3.1'),
('pybind11', '2.13.6'), # override 2.12.0. SciPy-bundle has pybind/2.12.0. Fix:
# change to builddependency in SciPy-bundle? (TODO)
]

# downloading xla and other tarballs to avoid that Bazel downloads it during the build
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
# note: following commits *must* be the exact same onces used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = '76da730179313b3bebad6dea6861768421b7358c'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required?
# TODO: add other downloads

# Use sources downloaded by EasyBuild
_jaxlib_buildopts = '--bazel_options="--distdir=%(builddir)s/archives" '
# Use dependencies from EasyBuild
_jaxlib_buildopts += '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
_jaxlib_buildopts += '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
# Avoid warning (treated as error) in upb/table.c
_jaxlib_buildopts += '--bazel_options="--copt=-Wno-maybe-uninitialized" ' # TODO: still required?
# _jaxlib_buildopts += '--nouse_clang ' #TODO: avoid clang (?)
_jaxlib_buildopts += '--cuda_version=%(cudaver)s '
_jaxlib_buildopts += '--python_bin_path=$EBROOTPYTHON/bin/python3 '
# Do not use hermetic CUDA/cuDNN/NCCL: (requires action_env=CPATH=$EBROOTCUDA/extras/CUPTI/include";
# requires patch of external/xla/xla/tsl/cuda/cupti_stub.cc and jaxlib/gpu/vendor.h (#include <cupti.h>):
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" """
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" """
_jaxlib_buildopts += """--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" """
_jaxlib_buildopts += """--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" """

# get rid of .devDate versionsuffix: TODO: find a better way
# _no_devtag = """ export JAX_RELEASE && export JAXLIB_RELEASE && """ does not work (?)
_no_devtag = """ sed -i "s/version=__version__/version='%(version)s'/g" setup.py && """
_jaxlib_buildopts += """--bazel_options="--action_env=JAXLIB_RELEASE=1" """ # required?

components = [
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
'extract_cmd': local_extract_cmd,
},
{
'source_urls': ['https://github.com/tensorflow/runtime/archive'],
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
'extract_cmd': local_extract_cmd,
},
],
'patches': [
'jax-0.4.35_easyblock_compat.patch',
'jax-0.4.35_fix-pybind11-systemlib_cupti.patch',
'jax-v0.4.35_version.patch',
],
'checksums': [
{'jax-v0.4.35.tar.gz':
'65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
{'xla-76da7301.tar.gz':
'd67ced09b69ab8d7b26fa4cd5f48b22db57eb330294a35f6e1d462ee17066757'},
{'tf_runtime-0aeefb16.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.35_easyblock_compat.patch':
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
{'jax-0.4.35_fix-pybind11-systemlib_cupti.patch':
'78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'},
],
'start_dir': 'jax-jax-v%(version)s',
'buildopts': _jaxlib_buildopts,
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
_no_devtag
}),
]

# Some tests require an isolated run: TODO: still required?
local_isolated_tests = [
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest' +
'::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]
# deliberately not testing in parallel, as that results in (additional) failing tests;
# use XLA_PYTHON_CLIENT_ALLOCATOR=platform to allocate and deallocate GPU memory during testing,
# see https://github.com/google/jax/issues/7323 and
# https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst;
# use CUDA_VISIBLE_DEVICES=0 to avoid failing tests on systems with multiple GPUs;
# use NVIDIA_TF32_OVERRIDE=0 to avoid loosing numerical precision by disabling TF32 Tensor Cores;
local_test_exports = [
"NVIDIA_TF32_OVERRIDE=0",
"CUDA_VISIBLE_DEVICES=0",
"XLA_PYTHON_CLIENT_ALLOCATOR=platform",
"JAX_ENABLE_X64=true",
]
local_test = ''.join(['export %s;' % x for x in local_test_exports])
# run all tests at once except for local_isolated_tests:
local_test += "pytest -vv tests %s && " % ' '.join(['--deselect %s' % x for x in local_isolated_tests])
# run remaining local_isolated_tests separately:
local_test += ' && '.join(['pytest -vv %s' % x for x in local_isolated_tests])

use_pip = True

exts_list = [
(name, version, {
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
# 'patches': ['jax-0.4.25_fix_env_test_no_log_spam.patch'], # TODO: still required? update?
'patches': ['jax-v0.4.35_version.patch'],
'checksums': [
{'jax-v0.4.35.tar.gz': '65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212'},
],
'runtest': local_test,
'preinstallopts': _no_devtag
}),
]

sanity_pip_check = True

moduleclass = 'ai'
21 changes: 21 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.35_easyblock_compat.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Thomas Hoffmann, EMBL Heidelberg, [email protected], 2024/11
# add dummy parameters to build/build.py for cudnn_path and cuda_path, which are set by default by the jaxlib easyblock.
diff -ru jax-jax-v0.4.35/build/build.py jax-jax-v0.4.35_easyblockcompat/build/build.py
--- jax-jax-v0.4.35/build/build.py 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_easyblockcompat/build/build.py 2024-11-19 12:35:46.524479324 +0100
@@ -549,6 +549,15 @@
help_str="Same as update_requirements, but will consider dev, nightly "
"and pre-release versions of packages.")

+ parser.add_argument(
+ "--cuda_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+ parser.add_argument(
+ "--cudnn_path",
+ default="dummy",
+ help="compatibility with jaxlib.py easyblock")
+
args = parser.parse_args()

logging.basicConfig()
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
jax-0.4.25_fix-pybind11-systemlib.patch: Add missing value for System Pybind11 Bazel config
jax-0.4.25_fix-pybind11-systemlib.patch: Author: Alexander Grund (TU Dresden)

THEMBL: fix cupti include path.

diff --git a/third_party/xla/fix-pybind11-systemlib.patch b/third_party/xla/fix-pybind11-systemlib.patch
new file mode 100644
index 000000000..68bd2063d
--- /dev/null
+++ b/third_party/xla/fix-pybind11-systemlib.patch
@@ -0,0 +1,13 @@
+--- xla-orig/third_party/tsl/third_party/systemlibs/pybind11.BUILD
++++ xla-4ccfe33c71665ddcbca5b127fefe8baa3ed632d4/third_party/tsl/third_party/systemlibs/pybind11.BUILD
+@@ -6,3 +6,10 @@
+ "@tsl//third_party/python_runtime:headers",
+ ],
+ )
++
++# Needed by pybind11_bazel.
++config_setting(
++ name = "osx",
++ constraint_values = ["@platforms//os:osx"],
++)
++
diff -ruN jax-jax-v0.4.35/jaxlib/gpu/vendor.h jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h
--- jax-jax-v0.4.35/jaxlib/gpu/vendor.h 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/jaxlib/gpu/vendor.h 2024-11-26 10:56:20.396087442 +0100
@@ -23,7 +23,7 @@
#if defined(JAX_GPU_CUDA)

// IWYU pragma: begin_exports
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include <cupti.h>
#include "third_party/gpus/cuda/include/cooperative_groups.h"
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cublas_v2.h"
diff -ruN jax-jax-v0.4.35/third_party/xla/workspace.bzl jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl
--- jax-jax-v0.4.35/third_party/xla/workspace.bzl 2024-10-22 21:00:23.000000000 +0200
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/workspace.bzl 2024-11-27 12:17:37.913466273 +0100
@@ -30,6 +30,11 @@
sha256 = XLA_SHA256,
strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+ patch_file = [
+ "//third_party/xla:xla-76da73_cupti.patch",
+ "//third_party/xla:fix-pybind11-systemlib.patch",
+ ],
+
)

# For development, one often wants to make changes to the TF repository as well
diff -ruN jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch
--- jax-jax-v0.4.35/third_party/xla/xla-76da73_cupti.patch 1970-01-01 01:00:00.000000000 +0100
+++ jax-jax-v0.4.35_jaxlib_cupti__fix-pybind11-systemlib/third_party/xla/xla-76da73_cupti.patch 2024-11-27 12:18:26.668582799 +0100
@@ -0,0 +1,12 @@
+diff -ru xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc
+--- xla-76da730179313b3bebad6dea6861768421b7358c/xla/tsl/cuda/cupti_stub.cc 2024-10-21 20:29:31.000000000 +0200
++++ xla-76da730179313b3bebad6dea6861768421b7358c_cupti/xla/tsl/cuda/cupti_stub.cc 2024-11-26 12:04:11.695539146 +0100
+@@ -13,7 +13,7 @@
+ limitations under the License.
+ ==============================================================================*/
+
+-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
++#include <cupti.h>
+ #include "third_party/gpus/cuda/include/cuda.h"
+ #include "tsl/platform/dso_loader.h"
+ #include "tsl/platform/load_library.h"

0 comments on commit 9164c61

Please sign in to comment.