Skip to content

Commit

Permalink
Start a new TPU interpret mode for Pallas.
Browse files Browse the repository at this point in the history
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
  • Loading branch information
jburnim committed Dec 12, 2024
1 parent 3ff5706 commit 2fbe0bb
Show file tree
Hide file tree
Showing 6 changed files with 1,735 additions and 4 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ pytype_strict_library(
":tpu_custom_call",
"//jax/_src/pallas",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:interpret",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",
Expand Down
12 changes: 12 additions & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,15 @@ py_library(
"//jax:typing",
] + py_deps("numpy"),
)

py_library(
name = "interpret",
srcs = ["interpret.py"],
deps = [
":core",
":primitives",
"//jax",
"//jax/_src/lib",
"//jax/_src/pallas",
] + py_deps("numpy"),
)
Loading

0 comments on commit 2fbe0bb

Please sign in to comment.