Skip to content

Commit

Permalink
[pallas:mosaic_gpu] Added some runtime type checking to copy_* and …
Browse files Browse the repository at this point in the history
…`barrier_*` primitives

PiperOrigin-RevId: 709279972
  • Loading branch information
superbobry authored and Google-ML-Automation committed Dec 24, 2024
1 parent fa9c7ed commit 09210c1
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import lowering
from jax._src.pallas.mosaic_gpu.core import state_types
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
Expand Down Expand Up @@ -113,10 +114,23 @@ def _extract_smem_copy_params(transforms):
)


_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef


def _check_ref(
value: object, name: str, memory_space: gpu_core.GPUMemorySpace
) -> None:
if not isinstance(value, _Ref):
raise TypeError(f"{name} must be a reference, got {value}")
value_memory_space = getattr(value, "memory_space", None) or gpu_core.GMEM
if value_memory_space is not memory_space:
raise ValueError(
f"{name} must be a {memory_space.name.upper()} reference, got {value}"
)


def copy_smem_to_gmem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
predicate: jax.Array | None = None,
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
) -> None:
"""Asynchronously copies a SMEM reference to a GMEM reference.
Expand All @@ -130,10 +144,8 @@ def copy_smem_to_gmem(
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
:func:`jax.experimental.mosaic.gpu.commit_smem`
"""
if src.memory_space is not gpu_core.SMEM:
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
_check_ref(src, "src", gpu_core.SMEM)
_check_ref(dst, "dst", gpu_core.GMEM)
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
)
Expand Down Expand Up @@ -217,21 +229,16 @@ def _copy_gmem_to_smem_lowering(
return ()


def copy_gmem_to_smem(
src: pallas_core.AbstractMemoryRef,
dst: pallas_core.AbstractMemoryRef,
barrier: pallas_core.AbstractMemoryRef,
) -> None:
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
"""Asynchronously copies a GMEM reference to a SMEM reference.
See also:
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
:func:`jax.experimental.mosaic.gpu.barrier_wait`
"""
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
if dst.memory_space is not gpu_core.SMEM:
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
_check_ref(src, "src", gpu_core.GMEM)
_check_ref(dst, "dst", gpu_core.SMEM)
_check_ref(barrier, "barrier", gpu_core.SMEM)
src, src_transforms = state_primitives.get_ref_and_transforms(
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
)
Expand Down Expand Up @@ -314,6 +321,7 @@ def _barrier_arrive_lowering(

def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
"""Arrives at the given barrier."""
_check_ref(barrier, "barrier", gpu_core.SMEM)
barrier, transforms = state_primitives.get_ref_and_transforms(
barrier, None, "barrier_arrive", force_trailing_indexer=False,
)
Expand Down Expand Up @@ -351,6 +359,7 @@ def _barrier_wait_lowering(

def barrier_wait(barrier: pallas_core.AbstractMemoryRef) -> None:
"""Waits on the given barrier."""
_check_ref(barrier, "barrier", gpu_core.SMEM)
barrier, transforms = state_primitives.get_ref_and_transforms(
barrier, None, "barrier_wait", force_trailing_indexer=False,
)
Expand Down

0 comments on commit 09210c1

Please sign in to comment.