Skip to content

Commit

Permalink
[JAX] Add a test using inputs with different device orders for a sing…
Browse files Browse the repository at this point in the history
…le colocated Python call

PiperOrigin-RevId: 703295336
  • Loading branch information
hyeontaek authored and Google-ML-Automation committed Dec 20, 2024
1 parent 4216f8f commit b2b37e8
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/colocated_python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,41 @@ def sleep_and_wait(x: jax.Array) -> None:
# around 15 seconds.
self.assertLess(elapsed_time, 10)

def testInputsWithDifferentDeviceOrders(self):
cpu_devices = _colocated_cpu_devices(jax.local_devices())[:2]

@colocated_python.colocated_python
def add(x: jax.Array, y: jax.Array) -> jax.Array:
arrays = [
x.addressable_shards[1].data + y.addressable_shards[0].data,
x.addressable_shards[0].data + y.addressable_shards[1].data,
]
return jax.make_array_from_single_device_arrays(
y.shape, y.sharding, arrays
)

# The execution will use mixed device orders. We should specialize the
# function with devices to avoid the argument-dependent device selection.
add = add.specialize(devices=cpu_devices)

mesh1 = jax.sharding.Mesh([cpu_devices[0], cpu_devices[1]], "x")
sharding1 = jax.sharding.NamedSharding(
mesh1, jax.sharding.PartitionSpec("x")
)
mesh2 = jax.sharding.Mesh([cpu_devices[1], cpu_devices[0]], "x")
sharding2 = jax.sharding.NamedSharding(
mesh2, jax.sharding.PartitionSpec("x")
)

x = np.array([0, 2])
x = jax.device_put(x, sharding1)
y = np.array([4, 8])
y = jax.device_put(y, sharding2)

out = add(x, y)
out = jax.device_get(out)
np.testing.assert_equal(out, np.array([2 + 4, 0 + 8]))


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b2b37e8

Please sign in to comment.