diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index ec9500c67cd7..8a7b64db6454 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -84,6 +84,8 @@ def _get_memory_space_from_aval( return None case tpu_core.TPUMemorySpace.VMEM: return tpu_custom_call.MemorySpace.VMEM + case tpu_core.TPUMemorySpace.SMEM: + return tpu_custom_call.MemorySpace.SMEM case tpu_core.TPUMemorySpace.SEMAPHORE: return tpu_custom_call.MemorySpace.SEMAPHORE_MEM return None diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 2d77acba02da..bb92afebe8e9 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -83,6 +83,7 @@ class MemorySpace(enum.Enum): HBM = enum.auto() VMEM = enum.auto() SEMAPHORE_MEM = enum.auto() + SMEM = enum.auto() @property def color(self) -> int: @@ -92,6 +93,8 @@ def color(self) -> int: return 1 elif self == MemorySpace.SEMAPHORE_MEM: return 2 + elif self == MemorySpace.SMEM: + return 4 else: raise ValueError("invalid memory space: " + str(self))