Skip to content

Commit

Permalink
Accelerate deprecation of legacy JAX FFI calling convention.
Browse files Browse the repository at this point in the history
In #24370, `ffi_call` was updated to return a callable, and the original calling convention was deprecated. This change is part of the deprecation cycle for this calling convention.

PiperOrigin-RevId: 708424223
  • Loading branch information
dfm authored and Google-ML-Automation committed Dec 20, 2024
1 parent 3a35155 commit 4216f8f
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
Expand Down Expand Up @@ -279,8 +280,13 @@ def testVectorizedDeprecation(self):
def testBackwardCompatSyntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
with self.assertWarns(DeprecationWarning):
jax.jit(fun).lower(jnp.ones(5))
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))

def testInputOutputAliases(self):
def fun(x):
Expand Down

0 comments on commit 4216f8f

Please sign in to comment.