Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU, while simulating a TPU's shared memory, multiple devices/cores, remote DMAs, and synchronization. The basic approach is to execute the kernel's Jaxpr on CPU, but to replace all load/store, DMA, and synchronization primitives with io_callbacks to a Python functions that simulate these primitives. When this interpret mode is run inside of shard_map and jit, the shards will run in parallel, simulating the parallel execution of the kernel on multiple TPU devices. The initial version in this PR can successfully interpret the examples in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html , but is still missing a lot of functionality, including: - Executing DMAs asynchronously. - Padding in pallas_call. - Propagating source info.
- Loading branch information