Skip to content

Commit

Permalink
test v0.4.34 with pybind11/2.12.0 builddep
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHoffmann77 committed Dec 20, 2024
1 parent fc5b969 commit 849f9fb
Showing 1 changed file with 168 additions and 0 deletions.
168 changes: 168 additions & 0 deletions easybuild/easyconfigs/j/jax/jax-0.4.34-gfbf-2024a-CUDA-12.6.0.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# This file is an EasyBuild reciPY as per https://github.com/easybuilders/easybuild
# Author: Denis Kristak
# Updated by: Alex Domingo (Vrije Universiteit Brussel)
# Updated by: Pavel Tománek (INUITS)
# Updated by: Thomas Hoffmann (EMBL Heidelberg)
# PythonBundle: this easyconfig installs a bundle of Python packages, i.e. the
# jaxlib component and the jax extension declared further down in this file.
easyblock = 'PythonBundle'

name = 'jax'
version = '0.4.34'
# %(cudaver)s is a template value filled in by EasyBuild
# (presumably the version of the CUDA dependency - confirm).
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://jax.readthedocs.io/'
description = """Composable transformations of Python+NumPy programs:
differentiate, vectorize, JIT to GPU/TPU, and more"""

toolchain = {'name': 'gfbf', 'version': '2024a'}
# CUDA compute capabilities (GPU architectures) to build for
cuda_compute_capabilities = ["5.0", "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "9.0"]

# Software that is only required to build and test this installation.
builddependencies = [
# ('Bazel', '7.4.1'), TODO: problems with @@local_config_python//:py3_runtime:
#   "Error in fail: interpreter_path must be an absolute path".
# Bazel 6.5.0 (download) works.
('pybind11', '2.12.0'), # 2.12.0 ? SciPy-bundle has pybind11/2.12.0.
# Possible fix: change pybind11 to a builddependency in SciPy-bundle?
# NOTE(review): an earlier note here said pybind11 was temporarily moved to
# dependencies (TODO: move back); it is listed as a build dependency now - confirm.
('pytest-xdist', '3.6.1'),
('git', '2.45.1'), # bazel uses git to fetch repositories
('matplotlib', '3.9.2'), # required for tests/lobpcg_test.py
('poetry', '1.8.3'),
('Clang', '18.1.8')
]

# Runtime dependencies. CUDA and cuDNN are taken from the SYSTEM toolchain;
# cuDNN and NCCL reuse this easyconfig's '-CUDA-%(cudaver)s' version suffix.
dependencies = [
('CUDA', '12.6.0', '', SYSTEM), # 12.6.2 ?
('cuDNN', '9.5.0.50', versionsuffix, SYSTEM),
('NCCL', '2.22.3', versionsuffix),
('Python', '3.12.3'),
('SciPy-bundle', '2024.05'), # 2024.11 ?
('absl-py', '2.1.0'),
('flatbuffers-python', '24.3.25'),
('ml_dtypes', '0.5.0'),
('zlib', '1.3.1'),
# ('pybind11', '2.13.6'), # would override 2.12.0; SciPy-bundle has pybind11/2.12.0.
# Possible fix: change pybind11 to a builddependency in SciPy-bundle? (TODO)
]

# Download xla and other tarballs with EasyBuild, to avoid that Bazel downloads them
# during the build. The "extract" command does not unpack anything: it creates
# %(builddir)s/archives and copies the file given as %s into it
# (%s is presumably the path of the downloaded source file - confirm).
local_extract_cmd = 'mkdir -p %(builddir)s/archives && cp %s %(builddir)s/archives'
# note: the following commits *must* be exactly the same ones as used upstream
# XLA_COMMIT from jax-jaxlib: third_party/xla/workspace.bzl
local_xla_commit = 'cd6e808c59f53b40a99df1f1b860db9a3e598bff'
# TFRT_COMMIT from xla: third_party/tsl/third_party/tf_runtime/workspace.bzl
local_tfrt_commit = '0aeefb1660d7e37964b2bb71b1f518096bda9a25' # TODO: still required?
# TODO: add other downloads

# Shell snippet (note the trailing '&&') that removes the .devDate suffix from the
# package version by hard-coding the release version in setup.py.
# TODO: find a better way; exporting JAX_RELEASE / JAXLIB_RELEASE instead
# does not seem to work (?)
_no_devtag = ' sed -i "s/version=__version__/version=\'%(version)s\'/g" setup.py && '

# Options for the jaxlib build. Every fragment ends with a space, so the adjacent
# string literals below concatenate into a single, properly separated option string.
_jaxlib_buildopts = (
    # use the sources downloaded by EasyBuild
    '--bazel_options="--distdir=%(builddir)s/archives" '
    # use dependencies provided by EasyBuild
    '--bazel_options="--action_env=TF_SYSTEM_LIBS=pybind11" '
    '--bazel_options="--action_env=CPATH=$EBROOTPYBIND11/include:$EBROOTCUDA/extras/CUPTI/include" '
    # avoid warning (treated as error) in upb/table.c (TODO: still required?)
    '--bazel_options="--copt=-Wno-maybe-uninitialized" '
    # TODO: avoid clang (?) by also passing '--nouse_clang '
    '--cuda_version=%(cudaver)s '
    '--python_bin_path=$EBROOTPYTHON/bin/python3 '
    # do not use hermetic CUDA/cuDNN/NCCL; this requires CUPTI's include directory in
    # CPATH (see above) and a patch of external/xla/xla/tsl/cuda/cupti_stub.cc and
    # jaxlib/gpu/vendor.h (#include <cupti.h>)
    '--bazel_options=--repo_env=LOCAL_CUDNN_PATH="$EBROOTCUDNN" '
    '--bazel_options=--repo_env=LOCAL_NCCL_PATH="$EBROOTNCCL" '
    '--bazel_options=--repo_env=LOCAL_CUDA_PATH="$EBROOTCUDA" '
    '--bazel_options="--copt=-Ithird_party/gpus/cuda/extras/CUPTI/include" '
    # part of getting rid of the .devDate version suffix (required?)
    '--bazel_options="--action_env=JAXLIB_RELEASE=1" '
)

# jaxlib is built as a bundle component from the jax source tarball; the xla and
# tf_runtime tarballs are only copied into %(builddir)s/archives (see local_extract_cmd).
components = [
('jaxlib', version, {
'sources': [
{
'source_urls': ['https://github.com/google/jax/archive/'],
'filename': 'jax-v%(version)s.tar.gz',
},
{
# downloaded under the full commit hash, stored under the first 8 characters of it
'source_urls': ['https://github.com/openxla/xla/archive'],
'download_filename': '%s.tar.gz' % local_xla_commit,
'filename': 'xla-%s.tar.gz' % local_xla_commit[:8],
'extract_cmd': local_extract_cmd,
},
{
'source_urls': ['https://github.com/tensorflow/runtime/archive'],
'download_filename': '%s.tar.gz' % local_tfrt_commit,
'filename': 'tf_runtime-%s.tar.gz' % local_tfrt_commit[:8],
'extract_cmd': local_extract_cmd,
},
],
# NOTE(review): the patches are named after jax 0.4.35 but are used for 0.4.34 here;
# presumably they apply cleanly to both versions - confirm.
'patches': [
'jax-0.4.35_easyblock_compat.patch',
'jax-0.4.35_fix-pybind11-systemlib_cupti.patch',
'jax-0.4.35_version.patch',
],
# one checksum per file, keyed by filename: the three sources first, then the three patches
'checksums': [
{'jax-v0.4.34.tar.gz':
'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
{'xla-cd6e808c.tar.gz':
'65cb6d63ef4083b35775052636cb9c629f86db6947c8b91711923ba31dbdcde8'},
{'tf_runtime-0aeefb16.tar.gz':
'a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3'},
{'jax-0.4.35_easyblock_compat.patch':
'cbf4ad92b8438c4ce2a975efce1c47c57d4c3b117bceee071ab660f964057223'},
{'jax-0.4.35_fix-pybind11-systemlib_cupti.patch':
'78efe6b5108a5da1935258286c94dea8438fd03651533c34023eeba27f514130'},
{'jax-0.4.35_version.patch':
'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
'start_dir': 'jax-jax-v%(version)s',
'buildopts': _jaxlib_buildopts,
# before building: symlink CUDA's CUPTI directory into third_party/gpus/cuda/extras
# (matches the -Ithird_party/gpus/cuda/extras/CUPTI/include build option), then
# hard-code the release version in setup.py via _no_devtag
'prebuildopts': ' mkdir third_party/gpus/cuda/extras/ -p && ' +
'ln -s $EBROOTCUDA/extras/CUPTI third_party/gpus/cuda/extras --relative &&' +
_no_devtag
}),
]

# Tests that have to run on their own, outside of the main pytest session
# (TODO: still required?).
# NOTE(review): the 'Spcial' spelling is kept as-is; presumably it matches the class
# name in the upstream test file - verify before "fixing" it.
local_isolated_tests = [
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_scan_custom_jvp',
    'tests/host_callback_test.py::HostCallbackTapTest::test_tap_transforms_doc',
    'tests/lax_scipy_special_functions_test.py::LaxScipySpcialFunctionsTest'
    '::testScipySpecialFun_gammainc_s_2x1x4_float32_float32',
]

# Environment for the test run (tests are deliberately not run in parallel, as that
# results in additional failing tests):
local_test_exports = [
    # avoid losing numerical precision by disabling TF32 Tensor Cores
    "NVIDIA_TF32_OVERRIDE=0",
    # avoid failing tests on systems with multiple GPUs
    "CUDA_VISIBLE_DEVICES=0",
    # allocate and deallocate GPU memory on demand during testing,
    # see https://github.com/google/jax/issues/7323 and
    # https://github.com/google/jax/blob/main/docs/gpu_memory_allocation.rst
    "XLA_PYTHON_CLIENT_ALLOCATOR=platform",
    "JAX_ENABLE_X64=true",
]

# Shell fragments of the test command; comprehensions only reference their own
# iterable and use single-letter loop variables.
local_export_cmds = ['export {};'.format(e) for e in local_test_exports]
local_deselect_flags = ['--deselect {}'.format(t) for t in local_isolated_tests]
local_isolated_runs = ['pytest -vv {}'.format(t) for t in local_isolated_tests]

# first run everything except the isolated tests in one go ...
local_bulk_run = 'pytest -vv tests ' + ' '.join(local_deselect_flags)
# ... then run each isolated test separately, chaining all runs with '&&'
local_test = ''.join(local_export_cmds) + ' && '.join([local_bulk_run] + local_isolated_runs)

# install the extensions with pip
use_pip = True

# jax itself is installed as an extension, from the same source tarball as the
# jaxlib component.
exts_list = [
(name, version, {
'patches': ['jax-0.4.35_version.patch'],
# hard-code the release version in setup.py before installing (see _no_devtag)
'preinstallopts': _no_devtag,
# NOTE(review): tests are disabled, so the local_test command assembled earlier in
# this file is not used anywhere - confirm this is intended beyond this test build.
'runtest': False,
'source_tmpl': '%(name)s-v%(version)s.tar.gz',
'source_urls': ['https://github.com/google/jax/archive/'],
'checksums': [
{'jax-v0.4.34.tar.gz': 'd3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb'},
{'jax-0.4.35_version.patch': 'cd2139a7802abf14b4b2cecee331aed80fff2ef91e16fa105093aea0795455e8'},
],
}),
]

# run 'pip check' as part of the sanity check
sanity_pip_check = True

moduleclass = 'ai'

0 comments on commit 849f9fb

Please sign in to comment.