Skip to content

Commit

Permalink
Adding support for keyword arguments to staged methods.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700407547
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 26, 2024
1 parent 5e135a6 commit 341f9f3
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions kfac_jax/_src/utils/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""K-FAC utilities for classes with staged methods."""

import functools
import inspect
import numbers
import operator
from typing import Any, Callable, Sequence
Expand Down Expand Up @@ -165,8 +166,15 @@ def staged(
This decorator **should** only be applied to instance methods of classes that
inherit from the `WithStagedMethods` class. The decorator makes the decorated
method staged, which is equivalent to `jax.jit` if `instance.multi_device` is
`False` and to `jax.pmap` otherwise. When specifying static and donated
argunms, the `self` reference **must not** be counted. Example:
`False` and to `jax.pmap` otherwise.
Note that the point of this abstraction around JAX's compilation is to make
sure that jitting/pmapping is only done once, so that if we are already in a
compiled/staged method, we won't initiate a second nested compilation when
calling into second staged method.
Note that when specifying static and donated argunms, the `self` reference
**must not** be counted. Example:
@functools.partial(staged, donate_argunms=0)
def try(self, x):
Expand Down Expand Up @@ -210,12 +218,22 @@ def try(self, x):
donate_argnums=donate_argnums)

@functools.wraps(method)
def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
def decorated(
instance: "WithStagedMethods",
*args: Any,
**kwargs: Any
) -> TArrayTree:

sig = inspect.signature(method)
bound_args = sig.bind(instance, *args, **kwargs)
bound_args.apply_defaults()
args, kwargs = bound_args.args, bound_args.kwargs

if instance.in_staging:
return method(instance, *args)
return method(*args, **kwargs)

with instance.staging_context():

if instance.multi_device and instance.debug:
# In this case we want to call `method` once for each device index.
# Note that this might not always produce sensible behavior, and will
Expand Down Expand Up @@ -248,7 +266,7 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:

elif instance.debug:
with jax.disable_jit():
outs = method(instance, *args)
outs = method(*args, **kwargs)

elif instance.multi_device:

Expand All @@ -274,10 +292,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
)
pmap_funcs[key] = func

outs = func(instance, *args)
outs = func(*args, **kwargs)

else:
outs = jitted_func(instance, *args)
outs = jitted_func(*args, **kwargs)

return outs

Expand Down

0 comments on commit 341f9f3

Please sign in to comment.