From 28687b08e4e714fef9d6b6c93d0754a57dfd4d44 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 20 Dec 2024 11:26:04 +0000 Subject: [PATCH] Move jex.ffi to jax.ffi. --- CHANGELOG.md | 2 + docs/ffi.ipynb | 63 ++-- docs/ffi.md | 63 ++-- docs/jax.extend.rst | 1 - docs/{jax.extend.ffi.rst => jax.ffi.rst} | 6 +- docs/jax.rst | 1 + .../ffi/src/jax_ffi_example/cpu_examples.py | 9 +- .../ffi/src/jax_ffi_example/cuda_examples.py | 9 +- examples/ffi/src/jax_ffi_example/rms_norm.py | 11 +- jax/BUILD | 2 +- jax/__init__.py | 1 + jax/_src/{extend => }/ffi.py | 10 +- jax/_src/lax/linalg.py | 2 +- jax/experimental/mosaic/gpu/profiler.py | 5 +- jax/experimental/shard_map.py | 2 +- jax/extend/ffi.py | 47 ++- jax/ffi.py | 24 ++ tests/BUILD | 7 + tests/extend_test.py | 298 ---------------- tests/ffi_test.py | 332 ++++++++++++++++++ 20 files changed, 493 insertions(+), 402 deletions(-) rename docs/{jax.extend.ffi.rst => jax.ffi.rst} (55%) rename jax/_src/{extend => }/ffi.py (98%) create mode 100644 jax/ffi.py create mode 100644 tests/ffi_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e86dece51013..8b87ccf8d4f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` are now deprecated, having been replaced by symbols of the same name in {mod}`jax.core`. + * The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the + previous import path is deprecated. * Deletions * `jax_enable_memories` flag has been deleted and the behavior of that flag diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 72a2a6914fc0..c80c83996cdb 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -21,7 +21,7 @@ "JAX's FFI support is provided in two parts:\n", "\n", "1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n", - "2. A Python front end, available in the `jax.extend.ffi` submodule.\n", + "2. 
A Python front end, available in the `jax.ffi` submodule.\n", "\n", "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", @@ -191,9 +191,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n", + "With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.\n", "This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n", - "JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:" + "JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:" ] }, { @@ -204,12 +204,11 @@ "source": [ "import ctypes\n", "from pathlib import Path\n", - "import jax.extend as jex\n", "\n", "path = next(Path(\"ffi\").glob(\"librms_norm*\"))\n", "rms_norm_lib = ctypes.cdll.LoadLibrary(path)\n", - "jex.ffi.register_ffi_target(\n", - " \"rms_norm\", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")" + "jax.ffi.register_ffi_target(\n", + " \"rms_norm\", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")" ] }, { @@ -217,7 +216,7 @@ "metadata": {}, "source": [ "```{tip}\n", - "If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n", + "If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. 
The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n", "```\n", "\n", "**An alternative approach**:\n", @@ -251,7 +250,7 @@ "# Assuming that we compiled a nanobind extension called `rms_norm`:\n", "import rms_norm as rms_norm_lib\n", "\n", - "jex.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n", + "jax.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n", "```" ] }, @@ -261,7 +260,7 @@ "source": [ "## Frontend code\n", "\n", - "Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:" + "Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:" ] }, { @@ -282,7 +281,7 @@ " if x.dtype != jnp.float32:\n", " raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n", "\n", - " call = jex.ffi.ffi_call(\n", + " call = jax.ffi.ffi_call(\n", " # The target name must be the same string as we used to register the target\n", " # above in `register_custom_call_target`\n", " \"rms_norm\",\n", @@ -314,25 +313,25 @@ "metadata": {}, "source": [ "This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n", - "Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n", - "It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n", + "Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n", + "It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n", "\n", - "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", + "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.\n", "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", "\n", - "The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", "\n", "```{tip}\n", - "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n", + "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) 
to {func}`~jax.ffi.ffi_call`.\n", "In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n", - "One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n", + "One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n", "```\n", "\n", "(ffi-call-vmap)=\n", "### Batching with `vmap`\n", "\n", - "{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n", - "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n", + "{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n", + "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n", "\n", "The simplest `vmap_method` is `\"sequential\"`.\n", "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n", @@ -395,7 +394,7 @@ "outputs": [], "source": [ "def rms_norm_sequential(x, eps=1e-5):\n", - " return jex.ffi.ffi_call(\n", + " return jax.ffi.ffi_call(\n", " \"rms_norm\",\n", " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", " vmap_method=\"sequential\",\n", @@ -418,9 +417,9 @@ "source": [ "### Differentiation\n", "\n", - "Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n", + "Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n", "As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n", - "Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", + "Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n", "\n", "More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n", "In this case, we actually define two new FFI calls:\n", @@ -429,7 +428,7 @@ "2. 
`rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", "\n", "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n", - "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", + "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", "\n", "This custom derivative rule can be wired in as follows:" ] @@ -440,16 +439,16 @@ "metadata": {}, "outputs": [], "source": [ - "jex.ffi.register_ffi_target(\n", - " \"rms_norm_fwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n", + "jax.ffi.register_ffi_target(\n", + " \"rms_norm_fwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n", ")\n", - "jex.ffi.register_ffi_target(\n", - " \"rms_norm_bwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n", + "jax.ffi.register_ffi_target(\n", + " \"rms_norm_bwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n", ")\n", "\n", "\n", "def rms_norm_fwd(x, eps=1e-5):\n", - " y, res = jex.ffi.ffi_call(\n", + " y, res = jax.ffi.ffi_call(\n", " \"rms_norm_fwd\",\n", " (\n", " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", @@ -466,7 +465,7 @@ " assert res.shape == ct.shape[:-1]\n", " assert x.shape == ct.shape\n", " return (\n", - " jex.ffi.ffi_call(\n", + " jax.ffi.ffi_call(\n", " \"rms_norm_bwd\",\n", " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n", " vmap_method=\"broadcast_all\",\n", @@ -533,7 +532,7 @@ "On the front end, the registration code would be updated to specify the appropriate platform:\n", "\n", "```python\n", - "jex.ffi.register_ffi_target(\n", + "jax.ffi.register_ffi_target(\n", " \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n", ")\n", "```\n", @@ -554,7 +553,7 @@ " out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", "\n", " def impl(target_name):\n", - " return lambda x: jex.ffi.ffi_call(\n", + " return lambda x: jax.ffi.ffi_call(\n", " target_name,\n", " out_type,\n", " vmap_method=\"broadcast_all\",\n", @@ -620,9 +619,9 @@ "This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n", "We will leave these topics to future tutorials, but here are some possibly useful references:\n", "\n", - "* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. 
Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n", + "* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n", "\n", - "* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n", + "* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n", "\n", "* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details." ] diff --git a/docs/ffi.md b/docs/ffi.md index 96b627675004..6485c9b8369d 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -29,7 +29,7 @@ We will discuss some possible approaches below, but it is important to call this JAX's FFI support is provided in two parts: 1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and -2. A Python front end, available in the `jax.extend.ffi` submodule. +2. A Python front end, available in the `jax.ffi` submodule. In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. @@ -171,23 +171,22 @@ To compile the shared library, we're using CMake here, but you should be able to !cmake --install ffi/_build ``` -With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function. +With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function. 
This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html). -JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this: +JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this: ```{code-cell} ipython3 import ctypes from pathlib import Path -import jax.extend as jex path = next(Path("ffi").glob("librms_norm*")) rms_norm_lib = ctypes.cdll.LoadLibrary(path) -jex.ffi.register_ffi_target( - "rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu") +jax.ffi.register_ffi_target( + "rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu") ``` ```{tip} -If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here. +If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here. ``` **An alternative approach**: @@ -221,14 +220,14 @@ Then, in Python we can register this handler using: # Assuming that we compiled a nanobind extension called `rms_norm`: import rms_norm as rms_norm_lib -jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu") +jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu") ``` +++ ## Frontend code -Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function: +Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function: ```{code-cell} ipython3 import numpy as np @@ -243,7 +242,7 @@ def rms_norm(x, eps=1e-5): if x.dtype != jnp.float32: raise ValueError("Only the float32 dtype is implemented by rms_norm") - call = jex.ffi.ffi_call( + call = jax.ffi.ffi_call( # The target name must be the same string as we used to register the target # above in `register_custom_call_target` "rms_norm", @@ -271,25 +270,25 @@ np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5) ``` This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting. -Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs. -It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above. +Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs. +It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above. 
-Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. +Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`. Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. -The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. +The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. ```{tip} -If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`. +If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`. In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering. -One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below. +One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below. ``` (ffi-call-vmap)= ### Batching with `vmap` -{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter. -The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`. +{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter. +The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`. The simplest `vmap_method` is `"sequential"`. In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body. @@ -326,7 +325,7 @@ Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {f ```{code-cell} ipython3 def rms_norm_sequential(x, eps=1e-5): - return jex.ffi.ffi_call( + return jax.ffi.ffi_call( "rms_norm", jax.ShapeDtypeStruct(x.shape, x.dtype), vmap_method="sequential", @@ -342,9 +341,9 @@ If your foreign function provides an efficient batching rule that isn't supporte ### Differentiation -Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions. +Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions. As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated. -Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule. +Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule. 
More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function. In this case, we actually define two new FFI calls: @@ -353,21 +352,21 @@ In this case, we actually define two new FFI calls: 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end. -The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. +The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. This custom derivative rule can be wired in as follows: ```{code-cell} ipython3 -jex.ffi.register_ffi_target( - "rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu" +jax.ffi.register_ffi_target( + "rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu" ) -jex.ffi.register_ffi_target( - "rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu" +jax.ffi.register_ffi_target( + "rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu" ) def rms_norm_fwd(x, eps=1e-5): - y, res = jex.ffi.ffi_call( + y, res = jax.ffi.ffi_call( "rms_norm_fwd", ( jax.ShapeDtypeStruct(x.shape, x.dtype), @@ -384,7 +383,7 @@ def rms_norm_bwd(eps, res, ct): assert res.shape == ct.shape[:-1] assert x.shape == ct.shape return ( - jex.ffi.ffi_call( + jax.ffi.ffi_call( "rms_norm_bwd", jax.ShapeDtypeStruct(ct.shape, ct.dtype), vmap_method="broadcast_all", @@ -447,7 +446,7 @@ Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels. On the front end, the registration code would be updated to specify the appropriate platform: ```python -jex.ffi.register_ffi_target( +jax.ffi.register_ffi_target( "rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA" ) ``` @@ -462,7 +461,7 @@ def rms_norm_cross_platform(x, eps=1e-5): out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) def impl(target_name): - return lambda x: jex.ffi.ffi_call( + return lambda x: jax.ffi.ffi_call( target_name, out_type, vmap_method="broadcast_all", @@ -499,8 +498,8 @@ and there will be no runtime overhead to using {func}`jax.lax.platform_dependent This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features. We will leave these topics to future tutorials, but here are some possibly useful references: -* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. 
But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend. +* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend. -* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`. +* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`. * **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details. diff --git a/docs/jax.extend.rst b/docs/jax.extend.rst index 0d68013c9261..3fb2b9d830c0 100644 --- a/docs/jax.extend.rst +++ b/docs/jax.extend.rst @@ -12,7 +12,6 @@ Modules :maxdepth: 1 jax.extend.core - jax.extend.ffi jax.extend.linear_util jax.extend.mlir jax.extend.random diff --git a/docs/jax.extend.ffi.rst b/docs/jax.ffi.rst similarity index 55% rename from docs/jax.extend.ffi.rst rename to docs/jax.ffi.rst index ac8e38c5e89a..aa652947a70e 100644 --- a/docs/jax.extend.ffi.rst +++ b/docs/jax.ffi.rst @@ -1,7 +1,7 @@ -``jax.extend.ffi`` module -========================= +``jax.ffi`` module +================== -.. automodule:: jax.extend.ffi +.. automodule:: jax.ffi .. 
autosummary:: :toctree: _autosummary diff --git a/docs/jax.rst b/docs/jax.rst index 042804792f8a..4a5c429abaff 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -18,6 +18,7 @@ Subpackages jax.dlpack jax.distributed jax.dtypes + jax.ffi jax.flatten_util jax.image jax.nn diff --git a/examples/ffi/src/jax_ffi_example/cpu_examples.py b/examples/ffi/src/jax_ffi_example/cpu_examples.py index 7771237e41d1..563e5a911b99 100644 --- a/examples/ffi/src/jax_ffi_example/cpu_examples.py +++ b/examples/ffi/src/jax_ffi_example/cpu_examples.py @@ -15,28 +15,27 @@ import numpy as np import jax -import jax.extend as jex from jax_ffi_example import _cpu_examples for name, target in _cpu_examples.registrations().items(): - jex.ffi.register_ffi_target(name, target) + jax.ffi.register_ffi_target(name, target) def array_attr(num: int): - return jex.ffi.ffi_call( + return jax.ffi.ffi_call( "array_attr", jax.ShapeDtypeStruct((), np.int32), )(array=np.arange(num, dtype=np.int32)) def dictionary_attr(**kwargs): - return jex.ffi.ffi_call( + return jax.ffi.ffi_call( "dictionary_attr", (jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)), )(**kwargs) def counter(index): - return jex.ffi.ffi_call( + return jax.ffi.ffi_call( "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/cuda_examples.py b/examples/ffi/src/jax_ffi_example/cuda_examples.py index b60b12af577e..ab660d8a3ca3 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_examples.py +++ b/examples/ffi/src/jax_ffi_example/cuda_examples.py @@ -24,15 +24,14 @@ import jax import jax.numpy as jnp -import jax.extend as jex # Load the shared library with the FFI target definitions SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so") library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY) -jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd), +jax.ffi.register_ffi_target("foo-fwd", jax.ffi.pycapsule(library.FooFwd), platform="CUDA") -jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd), +jax.ffi.register_ffi_target("foo-bwd", jax.ffi.pycapsule(library.FooBwd), platform="CUDA") @@ -42,7 +41,7 @@ def foo_fwd(a, b): assert a.dtype == b.dtype n = np.prod(a.shape).astype(np.uint64) out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n) + c, b_plus_1 = jax.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n) return c, (a, b_plus_1) @@ -55,7 +54,7 @@ def foo_bwd(res, c_grad): assert a.dtype == b_plus_1.dtype n = np.prod(a.shape).astype(np.uint64) out_type = jax.ShapeDtypeStruct(a.shape, a.dtype) - return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1, + return jax.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1, n=n) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index a2606e3d6002..851f1900ca3c 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -16,7 +16,7 @@ This example is exactly the same as the one in the `FFI tutorial `, so more details can be found on that page. But, the high level summary is that we implement our custom -extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in +extension in ``rms_norm.cc``, then call it usin ``jax.ffi.ffi_call`` in this module. The behavior under autodiff is implemented using ``jax.custom_vjp``. 
""" @@ -26,13 +26,12 @@ import numpy as np import jax -import jax.extend as jex import jax.numpy as jnp from jax_ffi_example import _rms_norm for name, target in _rms_norm.registrations().items(): - jex.ffi.register_ffi_target(name, target) + jax.ffi.register_ffi_target(name, target) @partial(jax.custom_vjp, nondiff_argnums=(1,)) @@ -53,7 +52,7 @@ def rms_norm(x, eps=1e-5): # the attribute `eps`. Our FFI function expects this to have the C++ `float` # type (which corresponds to numpy's `float32` type), and it must be a # static parameter (i.e. not a JAX array). - return jex.ffi.ffi_call( + return jax.ffi.ffi_call( # The target name must be the same string as we used to register the target # above in `register_ffi_target` "rms_norm", @@ -63,7 +62,7 @@ def rms_norm(x, eps=1e-5): def rms_norm_fwd(x, eps=1e-5): - y, res = jex.ffi.ffi_call( + y, res = jax.ffi.ffi_call( "rms_norm_fwd", ( jax.ShapeDtypeStruct(x.shape, x.dtype), @@ -80,7 +79,7 @@ def rms_norm_bwd(eps, res, ct): assert res.shape == ct.shape[:-1] assert x.shape == ct.shape return ( - jex.ffi.ffi_call( + jax.ffi.ffi_call( "rms_norm_bwd", jax.ShapeDtypeStruct(ct.shape, ct.dtype), vmap_method="broadcast_all", diff --git a/jax/BUILD b/jax/BUILD index 175f0f6fe1e8..bc31bea451eb 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -199,6 +199,7 @@ py_library_providing_imports_info( "_src/dispatch.py", "_src/dlpack.py", "_src/earray.py", + "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/interpreters/ad.py", @@ -730,7 +731,6 @@ py_library( ":jax", ":mlir", "//jax/_src/lib", - "//jax/extend:ffi", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:func_dialect", diff --git a/jax/__init__.py b/jax/__init__.py index 8ca7721da445..d24ec60e1057 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -160,6 +160,7 @@ from jax import dlpack as dlpack from jax import dtypes as dtypes from jax import errors as errors +from jax import ffi as ffi from jax import image as image from jax import lax as lax from jax import monitoring as monitoring diff --git a/jax/_src/extend/ffi.py b/jax/_src/ffi.py similarity index 98% rename from jax/_src/extend/ffi.py rename to jax/_src/ffi.py index 6459ff751ceb..eef8b1ca99d6 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/ffi.py @@ -76,8 +76,8 @@ def pycapsule(funcptr): """Wrap a ctypes function pointer in a PyCapsule. The primary use of this function, and the reason why it lives with in the - ``jax.extend.ffi`` submodule, is to wrap function calls from external - compiled libraries to be registered as XLA custom calls. + ``jax.ffi`` submodule, is to wrap function calls from external compiled + libraries to be registered as XLA custom calls. Example usage:: @@ -88,7 +88,7 @@ def pycapsule(funcptr): libfoo = ctypes.cdll.LoadLibrary('./foo.so') xla_client.register_custom_call_target( name="bar", - fn=jax.extend.ffi.pycapsule(libfoo.bar), + fn=jax.ffi.pycapsule(libfoo.bar), platform=PLATFORM, api_version=API_VERSION ) @@ -145,7 +145,7 @@ def ffi_lowering( Note that layouts passed to this function as tuples should be in minor-to-major order (as expected by XLA) rather than major-to-minor as used - by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``. + by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``. If keyword arguments are passed to the lowering rule, these are treated as attributes, and added to `backend_config`. 
@@ -310,7 +310,7 @@ def ffi_call( Args: target_name: the name of the XLA FFI custom call target that was registered - using :func:`~jax.extend.ffi.register_ffi_target`. + using :func:`~jax.ffi.register_ffi_target`. result_shape_dtypes: an object, or sequence of objects, with ``shape`` and ``dtype`` attributes which are expected to match the shape and dtype of the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 085ca2d0686c..64bb3017f780 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -34,7 +34,7 @@ from jax._src import util from jax._src.core import ( Primitive, ShapedArray, is_constant_dim, is_constant_shape) -from jax._src.extend import ffi +from jax._src import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index 358ed841e686..9b1403c9ad49 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -22,7 +22,6 @@ import jax from jax._src.lib import xla_client -from jax.extend import ffi import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -54,7 +53,7 @@ def _event_record(args, *, copy_before): flat_args, treedef = jax.tree.flatten(args) - event, *flat_outs = ffi.ffi_call( + event, *flat_outs = jax.ffi.ffi_call( "mgpu_event_record", result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args), input_output_aliases={i: i + 1 for i in range(len(flat_args))}, @@ -63,7 +62,7 @@ def _event_record(args, *, copy_before): def _event_elapsed(start_event, end_event): - return ffi.ffi_call( + return jax.ffi.ffi_call( "mgpu_event_elapsed", result_shape_dtypes=jax.core.ShapedArray((), jnp.float32), )(start_event, end_event) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 572440f486eb..62614c33a9f8 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -52,7 +52,7 @@ from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) -from jax._src.extend import ffi +from jax._src import ffi from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index b2d480adc7eb..21642055993b 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -12,13 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +from jax._src import ffi as _ffi -from jax._src.extend.ffi import ( - ffi_call as ffi_call, - ffi_lowering as ffi_lowering, - include_dir as include_dir, - pycapsule as pycapsule, - register_ffi_target as register_ffi_target, -) +_deprecations = { + # Added 2024-12-20 + "ffi_call": ( + "jax.extend.ffi.ffi_call is deprecated, use jax.ffi.ffi_call instead.", + _ffi.ffi_call, + ), + "ffi_lowering": ( + "jax.extend.ffi.ffi_lowering is deprecated, use jax.ffi.ffi_lowering instead.", + _ffi.ffi_lowering, + ), + "include_dir": ( + "jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.", + _ffi.include_dir, + ), + "pycapsule": ( + "jax.extend.ffi.pycapsule is deprecated, use jax.ffi.pycapsule instead.", + _ffi.pycapsule, + ), + "register_ffi_target": ( + "jax.extend.ffi.register_ffi_target is deprecated, use jax.ffi.register_ffi_target instead.", + _ffi.register_ffi_target, + ), +} + +import typing +if typing.TYPE_CHECKING: + ffi_call = _ffi.ffi_call + ffi_lowering = _ffi.ffi_lowering + include_dir = _ffi.include_dir + pycapsule = _ffi.pycapsule + register_ffi_target = _ffi.register_ffi_target +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing +del _ffi diff --git a/jax/ffi.py b/jax/ffi.py new file mode 100644 index 000000000000..529818ff59da --- /dev/null +++ b/jax/ffi.py @@ -0,0 +1,24 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 + +from jax._src.ffi import ( + ffi_call as ffi_call, + ffi_lowering as ffi_lowering, + include_dir as include_dir, + pycapsule as pycapsule, + register_ffi_target as register_ffi_target, +) diff --git a/tests/BUILD b/tests/BUILD index c25d10f460aa..e21f3e98d96b 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -142,6 +142,13 @@ jax_multiplatform_test( deps = ["//jax:extend"], ) +jax_multiplatform_test( + name = "ffi_test", + srcs = ["ffi_test.py"], + # TODO(dfm): Remove after removal of jex.ffi imports. + deps = ["//jax:extend"], +) + jax_multiplatform_test( name = "fft_test", srcs = ["fft_test.py"], diff --git a/tests/extend_test.py b/tests/extend_test.py index 3561e716f09c..e37bea42c3e6 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -12,34 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -import unittest -from functools import partial - -import numpy as np from absl.testing import absltest -from absl.testing import parameterized import jax -from jax import lax import jax.extend as jex import jax.numpy as jnp -import jax.sharding as shd from jax._src import abstract_arrays from jax._src import api -from jax._src import config -from jax._src import core from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir -from jax._src.layout import DeviceLocalLayout -from jax._src.lib import lapack -from jax._src.lib.mlir.dialects import hlo -from jax._src.lax import linalg as lax_linalg_internal -from jax.experimental.shard_map import shard_map jax.config.parse_flags_with_absl() @@ -109,289 +94,6 @@ def test_key_impl_is_spec(self): self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})") -class FfiTest(jtu.JaxTestCase): - - def find_custom_call_in_module(self, module): - for func in module.body.operations: - for block in func.body.blocks: - for op in block.operations: - if op.OPERATION_NAME == "stablehlo.custom_call": - return op - self.fail("No custom_call found in the lowered IR") - - def testHeadersExist(self): - base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api") - for header in ["c_api.h", "api.h", "ffi.h"]: - self.assertTrue(os.path.exists(os.path.join(base_dir, header))) - - @parameterized.parameters([ - (tuple(range(3)), tuple(range(3))), - (None, tuple(reversed(range(3)))), - (DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))), - ]) - def testLoweringLayouts(self, layout_spec, expected_layout): - # Regression test to ensure that the lowering rule properly captures - # layouts. - def lowering_rule(ctx, x): - aval, = ctx.avals_in - return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], - result_layouts=[layout_spec])(ctx, x) - prim = core.Primitive("test_ffi") - prim.def_impl(lambda x: x) - prim.def_abstract_eval(lambda x: x) - mlir.register_lowering(prim, lowering_rule) - - x = jnp.ones((3,) * len(expected_layout)) - lowered = jax.jit(prim.bind).lower(x) - module = lowered.compiler_ir("stablehlo") - op = self.find_custom_call_in_module(module) - self.assertIn("operand_layouts", op.attributes) - self.assertIn("result_layouts", op.attributes) - - text = lowered.as_text() - expected = ", ".join(map(str, expected_layout)) - pattern = rf"operand_layouts = \[dense<\[{expected}\]>" - self.assertRegex(text, pattern) - pattern = rf"result_layouts = \[dense<\[{expected}\]>" - self.assertRegex(text, pattern) - - @parameterized.parameters([ - (True, mlir.ir.BoolAttr.get), - (1, mlir.i64_attr), - (5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)), - ("param", mlir.ir.StringAttr.get), - (np.float32(0.5), - lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)), - ]) - def testParams(self, param, expected_builder): - def fun(x): - return jex.ffi.ffi_call("test_ffi", x)(x, param=param) - - # Here we inspect the lowered IR to test that the parameter has been - # serialized with the appropriate type. 
- module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") - op = self.find_custom_call_in_module(module) - config = op.attributes["mhlo.backend_config"] - self.assertIsInstance(config, mlir.ir.DictAttr) - self.assertIn("param", config) - with mlir.make_ir_context(), mlir.ir.Location.unknown(): - expected = expected_builder(param) - self.assertEqual(type(config["param"]), type(expected)) - self.assertTrue(expected.type.isinstance(config["param"].type)) - - def testToken(self): - def fun(): - token = lax.create_token() - return jex.ffi.ffi_call("test_ffi", core.abstract_token)(token) - - # Ensure that token inputs and outputs are translated to the correct type - module = jax.jit(fun).lower().compiler_ir("stablehlo") - op = self.find_custom_call_in_module(module) - self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) - self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) - - def testEffectsHlo(self): - # The target name must exist on the current platform, but we don't actually - # need to call it with the correct syntax, because we're only checking the - # compiled HLO. - if jtu.test_device_matches(["cpu"]): - target_name = "lapack_sgetrf_ffi" - elif jtu.test_device_matches(["rocm"]): - target_name = "hipsolver_getrf_ffi" - elif jtu.test_device_matches(["cuda", "gpu"]): - target_name = "cusolver_getrf_ffi" - else: - raise unittest.SkipTest("Unsupported device") - def fun(): - jex.ffi.ffi_call(target_name, (), has_side_effect=True)() - hlo = jax.jit(fun).lower() - self.assertIn(target_name, hlo.as_text()) - self.assertIn("has_side_effect = true", hlo.as_text()) - self.assertIn(target_name, hlo.compile().as_text()) - - def testJvpError(self): - def fun(x): - return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) - with self.assertRaisesRegex( - ValueError, "The FFI call to `.+` cannot be differentiated."): - jax.jvp(fun, (0.5,), (0.5,)) - - def testNonHashableAttributes(self): - def fun(x): - return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) - - self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) - hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() - self.assertIn("non_hashable_arg = {a = 1", hlo) - - # If non-hashable arguments aren't handled properly, this will raise a - # TypeError. We make sure it doesn't. 
- with self.assertRaises(Exception) as manager: - fun(jnp.ones(5)) - self.assertNotIsInstance(manager.exception, TypeError) - - def fun(x): - return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3)) - self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5)))) - hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() - self.assertIn("non_hashable_arg = array", hlo) - with self.assertRaises(Exception) as manager: - fun(jnp.ones(5)) - self.assertNotIsInstance(manager.exception, TypeError) - - @jtu.sample_product(shape=[(6, 5), (4, 5, 6)]) - @jtu.run_on_devices("gpu", "cpu") - def testFfiCall(self, shape): - x = self.rng().randn(*shape).astype(np.float32) - expected = lax_linalg_internal.geqrf(x) - actual = ffi_call_geqrf(x) - for a, b in zip(actual, expected): - self.assertArraysEqual(a, b) - - @jtu.sample_product( - shape=[(6, 5), (4, 5, 6)], - vmap_method=["expand_dims", "broadcast_all", "sequential"], - ) - @jtu.run_on_devices("gpu", "cpu") - def testFfiCallBatching(self, shape, vmap_method): - shape = (10,) + shape - x = self.rng().randn(*shape).astype(np.float32) - expected = lax_linalg_internal.geqrf(x) - actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x) - for a, b in zip(actual, expected): - if vmap_method == "sequential" and len(shape) == 3: - # On GPU, the batched FFI call to geqrf uses an algorithm with - # different numerics than the unbatched version (which is used when - # vmap_method="sequential"). Therefore, we need to include floating - # point tolerance for this check. - self.assertArraysAllClose(a, b) - else: - self.assertArraysEqual(a, b) - - @jtu.run_on_devices("gpu", "cpu") - def testVectorizedDeprecation(self): - x = self.rng().randn(3, 5, 4).astype(np.float32) - with self.assertWarns(DeprecationWarning): - ffi_call_geqrf(x, vectorized=True) - with self.assertWarns(DeprecationWarning): - jax.vmap(ffi_call_geqrf)(x) - - def testBackwardCompatSyntax(self): - def fun(x): - return jex.ffi.ffi_call("test_ffi", x, x, param=0.5) - with self.assertWarns(DeprecationWarning): - jax.jit(fun).lower(jnp.ones(5)) - - def testInputOutputAliases(self): - def fun(x): - return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) - hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() - self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]") - - def testInvalidInputOutputAliases(self): - def fun(x): - return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x) - with self.assertRaisesRegex(ValueError, "with input index"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def fun(x): - return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x) - with self.assertRaisesRegex(ValueError, "with output index"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def fun(x): - return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32), - input_output_aliases={0: 0})(x) - with self.assertRaisesRegex(ValueError, - "referring to an input with abstract value"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def fun(x): - return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape, - x.dtype), - input_output_aliases={0: 0})(x) - with self.assertRaisesRegex(ValueError, - "referring to an input with abstract value"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def testLegacyBackendConfig(self): - def fun(x): - return jex.ffi.ffi_call("test", x, custom_call_api_version=2, - legacy_backend_config="12345")(x) - hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() - self.assertRegex(hlo, 'backend_config = 
"12345"') - - def testInvalidBackendConfig(self): - def fun(x): - return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x) - with self.assertRaisesRegex(ValueError, - "The use of the legacy_backend_config"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def fun(x): - return jex.ffi.ffi_call("test", x, - custom_call_api_version=2)(x, attribute=1) - with self.assertRaisesRegex(ValueError, - "The use of ffi_call attributes requires"): - jax.jit(fun).lower(jnp.ones(5)).as_text() - - def testAllow64(self): - if config.enable_x64.value: - self.skipTest("Requires enable_x64=False") - def fun(): - return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() - self.assertIn("tensor", jax.jit(fun).lower().as_text()) - - def testInvalidResultType(self): - with self.assertRaisesRegex( - ValueError, "All elements of result_shape_dtypes.*position 0"): - jex.ffi.ffi_call("test", None)() - with self.assertRaisesRegex( - ValueError, "All elements of result_shape_dtypes.*position 1"): - jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() - - @jtu.run_on_devices("gpu", "cpu") - def testShardMap(self): - mesh = jtu.create_mesh((1,), ("i",)) - x = self.rng().randn(8, 4, 5).astype(np.float32) - - @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'), - out_specs=shd.PartitionSpec('i')) - def f(x): - return ffi_call_geqrf(x) - - f(x) # eager mode doesn't crash - jax.jit(f)(x) # neither does JIT - self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) - - -def ffi_call_geqrf(x, **kwargs): - if jtu.test_device_matches(["cpu"]): - lapack._lapack.initialize() - - assert x.dtype == np.float32 - ndim = x.ndim - x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) - output_types = [ - x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)] - - def call(platform, x): - target_name = dict( - cpu="lapack_sgeqrf_ffi", - rocm="hipsolver_geqrf_ffi", - cuda="cusolver_geqrf_ffi", - )[platform] - return jex.ffi.ffi_call( - target_name, output_types, input_output_aliases={0: 0}, - input_layouts=[x_major_to_minor], - output_layouts=[x_major_to_minor, None], - **kwargs)(x) - - return lax.platform_dependent( - x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"), - cuda=partial(call, "cuda")) - - class MlirRegisterLoweringTest(jtu.JaxTestCase): def test_unknown_platform_error(self): diff --git a/tests/ffi_test.py b/tests/ffi_test.py new file mode 100644 index 000000000000..510506d475ef --- /dev/null +++ b/tests/ffi_test.py @@ -0,0 +1,332 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest +from functools import partial + +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import jax +from jax import lax +import jax.extend as jex +import jax.numpy as jnp +import jax.sharding as shd + +from jax._src import config +from jax._src import core +from jax._src import test_util as jtu +from jax._src.interpreters import mlir +from jax._src.layout import DeviceLocalLayout +from jax._src.lib import lapack +from jax._src.lib.mlir.dialects import hlo +from jax._src.lax import linalg as lax_linalg_internal +from jax.experimental.shard_map import shard_map + +jax.config.parse_flags_with_absl() + + +class FfiTest(jtu.JaxTestCase): + + def find_custom_call_in_module(self, module): + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + return op + self.fail("No custom_call found in the lowered IR") + + def test_headers_exist(self): + base_dir = os.path.join(jax.ffi.include_dir(), "xla", "ffi", "api") + for header in ["c_api.h", "api.h", "ffi.h"]: + self.assertTrue(os.path.exists(os.path.join(base_dir, header))) + + @parameterized.parameters([ + (tuple(range(3)), tuple(range(3))), + (None, tuple(reversed(range(3)))), + (DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))), + ]) + def test_lowering_layouts(self, layout_spec, expected_layout): + # Regression test to ensure that the lowering rule properly captures + # layouts. + def lowering_rule(ctx, x): + return jax.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], + result_layouts=[layout_spec])(ctx, x) + prim = core.Primitive("test_ffi") + prim.def_impl(lambda x: x) + prim.def_abstract_eval(lambda x: x) + mlir.register_lowering(prim, lowering_rule) + + x = jnp.ones((3,) * len(expected_layout)) + lowered = jax.jit(prim.bind).lower(x) + module = lowered.compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + self.assertIn("operand_layouts", op.attributes) + self.assertIn("result_layouts", op.attributes) + + text = lowered.as_text() + expected = ", ".join(map(str, expected_layout)) + pattern = rf"operand_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) + pattern = rf"result_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) + + @parameterized.parameters([ + (True, mlir.ir.BoolAttr.get), + (1, mlir.i64_attr), + (5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)), + ("param", mlir.ir.StringAttr.get), + (np.float32(0.5), + lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)), + ]) + def test_params(self, param, expected_builder): + def fun(x): + return jax.ffi.ffi_call("test_ffi", x)(x, param=param) + + # Here we inspect the lowered IR to test that the parameter has been + # serialized with the appropriate type. 
+ module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + config = op.attributes["mhlo.backend_config"] + self.assertIsInstance(config, mlir.ir.DictAttr) + self.assertIn("param", config) + with mlir.make_ir_context(), mlir.ir.Location.unknown(): + expected = expected_builder(param) + self.assertEqual(type(config["param"]), type(expected)) + self.assertTrue(expected.type.isinstance(config["param"].type)) + + def test_token(self): + def fun(): + token = lax.create_token() + return jax.ffi.ffi_call("test_ffi", core.abstract_token)(token) + + # Ensure that token inputs and outputs are translated to the correct type + module = jax.jit(fun).lower().compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + + def test_effects_hlo(self): + # The target name must exist on the current platform, but we don't actually + # need to call it with the correct syntax, because we're only checking the + # compiled HLO. + if jtu.test_device_matches(["cpu"]): + target_name = "lapack_sgetrf_ffi" + elif jtu.test_device_matches(["rocm"]): + target_name = "hipsolver_getrf_ffi" + elif jtu.test_device_matches(["cuda", "gpu"]): + target_name = "cusolver_getrf_ffi" + else: + raise unittest.SkipTest("Unsupported device") + def fun(): + jax.ffi.ffi_call(target_name, (), has_side_effect=True)() + hlo = jax.jit(fun).lower() + self.assertIn(target_name, hlo.as_text()) + self.assertIn("has_side_effect = true", hlo.as_text()) + self.assertIn(target_name, hlo.compile().as_text()) + + def test_jvp_error(self): + def fun(x): + return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) + with self.assertRaisesRegex( + ValueError, "The FFI call to `.+` cannot be differentiated."): + jax.jvp(fun, (0.5,), (0.5,)) + + def test_non_hashable_attributes(self): + def fun(x): + return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1}) + + self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5)))) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertIn("non_hashable_arg = {a = 1", hlo) + + # If non-hashable arguments aren't handled properly, this will raise a + # TypeError. We make sure it doesn't. 
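+    # (The call below still fails, presumably because "test_ffi" is never
+    # registered as an FFI target, so we only assert that the resulting
+    # exception is not a TypeError.)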
+ with self.assertRaises(Exception) as manager: + fun(jnp.ones(5)) + self.assertNotIsInstance(manager.exception, TypeError) + + def fun(x): + return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3)) + self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5)))) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertIn("non_hashable_arg = array", hlo) + with self.assertRaises(Exception) as manager: + fun(jnp.ones(5)) + self.assertNotIsInstance(manager.exception, TypeError) + + @jtu.sample_product(shape=[(6, 5), (4, 5, 6)]) + @jtu.run_on_devices("gpu", "cpu") + def test_ffi_call(self, shape): + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = ffi_call_geqrf(x) + for a, b in zip(actual, expected): + self.assertArraysEqual(a, b) + + @jtu.sample_product( + shape=[(6, 5), (4, 5, 6)], + vmap_method=["expand_dims", "broadcast_all", "sequential"], + ) + @jtu.run_on_devices("gpu", "cpu") + def test_ffi_call_batching(self, shape, vmap_method): + shape = (10,) + shape + x = self.rng().randn(*shape).astype(np.float32) + expected = lax_linalg_internal.geqrf(x) + actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x) + for a, b in zip(actual, expected): + if vmap_method == "sequential" and len(shape) == 3: + # On GPU, the batched FFI call to geqrf uses an algorithm with + # different numerics than the unbatched version (which is used when + # vmap_method="sequential"). Therefore, we need to include floating + # point tolerance for this check. + self.assertArraysAllClose(a, b) + else: + self.assertArraysEqual(a, b) + + @jtu.run_on_devices("gpu", "cpu") + def test_vectorized_deprecation(self): + x = self.rng().randn(3, 5, 4).astype(np.float32) + with self.assertWarns(DeprecationWarning): + ffi_call_geqrf(x, vectorized=True) + with self.assertWarns(DeprecationWarning): + jax.vmap(ffi_call_geqrf)(x) + + def test_backward_compat_syntax(self): + def fun(x): + return jax.ffi.ffi_call("test_ffi", x, x, param=0.5) + with self.assertWarns(DeprecationWarning): + jax.jit(fun).lower(jnp.ones(5)) + + def test_input_output_aliases(self): + def fun(x): + return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]") + + def test_invalid_input_output_aliases(self): + def fun(x): + return jax.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x) + with self.assertRaisesRegex(ValueError, "with input index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jax.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x) + with self.assertRaisesRegex(ValueError, "with output index"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape, + x.dtype), + input_output_aliases={0: 0})(x) + with self.assertRaisesRegex(ValueError, + "referring to an input with abstract value"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def test_legacy_backend_config(self): + def fun(x): + return jax.ffi.ffi_call("test", x, custom_call_api_version=2, + legacy_backend_config="12345")(x) + hlo = jax.jit(fun).lower(jnp.ones(5)).as_text() + self.assertRegex(hlo, 
'backend_config = "12345"') + + def test_invalid_backend_config(self): + def fun(x): + return jax.ffi.ffi_call("test", x, legacy_backend_config="12345")(x) + with self.assertRaisesRegex(ValueError, + "The use of the legacy_backend_config"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def fun(x): + return jax.ffi.ffi_call("test", x, + custom_call_api_version=2)(x, attribute=1) + with self.assertRaisesRegex(ValueError, + "The use of ffi_call attributes requires"): + jax.jit(fun).lower(jnp.ones(5)).as_text() + + def test_allow_x64(self): + if config.enable_x64.value: + self.skipTest("Requires enable_x64=False") + def fun(): + return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))() + self.assertIn("tensor", jax.jit(fun).lower().as_text()) + + def test_invalid_result_type(self): + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 0"): + jax.ffi.ffi_call("test", None)() + with self.assertRaisesRegex( + ValueError, "All elements of result_shape_dtypes.*position 1"): + jax.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))() + + @jtu.run_on_devices("gpu", "cpu") + def test_shard_map(self): + mesh = jtu.create_mesh((1,), ("i",)) + x = self.rng().randn(8, 4, 5).astype(np.float32) + + @partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'), + out_specs=shd.PartitionSpec('i')) + def f(x): + return ffi_call_geqrf(x) + + f(x) # eager mode doesn't crash + jax.jit(f)(x) # neither does JIT + self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text()) + + @jtu.run_on_devices("gpu", "cpu") + def test_extend_import_shim(self): + with self.assertWarnsRegex(DeprecationWarning, "jax.extend.ffi.ffi_call"): + ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True) + + + +def ffi_call_geqrf(x, _use_extend=False, **kwargs): + if jtu.test_device_matches(["cpu"]): + lapack._lapack.initialize() + + assert x.dtype == np.float32 + ndim = x.ndim + x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2) + output_types = [ + x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)] + + def call(platform, x): + target_name = dict( + cpu="lapack_sgeqrf_ffi", + rocm="hipsolver_geqrf_ffi", + cuda="cusolver_geqrf_ffi", + )[platform] + f = jex.ffi.ffi_call if _use_extend else jax.ffi.ffi_call + return f( + target_name, output_types, input_output_aliases={0: 0}, + input_layouts=[x_major_to_minor], + output_layouts=[x_major_to_minor, None], + **kwargs)(x) + + return lax.platform_dependent( + x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"), + cuda=partial(call, "cuda")) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())