Skip to content

Commit

Permalink
Adding support for keyword arguments to staged methods.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702342993
  • Loading branch information
james-martens authored and KfacJaxDev committed Dec 3, 2024
1 parent bc000c6 commit cf3acc0
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions kfac_jax/_src/utils/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""K-FAC utilities for classes with staged methods."""

import functools
import inspect
import numbers
import operator
from typing import Any, Callable, Sequence
Expand Down Expand Up @@ -165,8 +166,15 @@ def staged(
This decorator **should** only be applied to instance methods of classes that
inherit from the `WithStagedMethods` class. The decorator makes the decorated
method staged, which is equivalent to `jax.jit` if `instance.multi_device` is
`False` and to `jax.pmap` otherwise. When specifying static and donated
argunms, the `self` reference **must not** be counted. Example:
`False` and to `jax.pmap` otherwise.
Note that the point of this abstraction around JAX's compilation is to make
sure that jitting/pmapping is only done once, so that if we are already in a
compiled/staged method, we won't initiate a second nested compilation when
calling into second staged method.
Note that when specifying static and donated argunms, the `self` reference
**must not** be counted. Example:
@functools.partial(staged, donate_argunms=0)
def try(self, x):
Expand Down Expand Up @@ -210,12 +218,22 @@ def try(self, x):
donate_argnums=donate_argnums)

@functools.wraps(method)
def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
def decorated(
instance: "WithStagedMethods",
*args: Any,
**kwargs: Any
) -> TArrayTree:

sig = inspect.signature(method)
bound_args = sig.bind(instance, *args, **kwargs)
bound_args.apply_defaults()
args, kwargs = bound_args.args[1:], bound_args.kwargs

if instance.in_staging:
return method(instance, *args)
return method(instance, *args, **kwargs)

with instance.staging_context():

if instance.multi_device and instance.debug:
# In this case we want to call `method` once for each device index.
# Note that this might not always produce sensible behavior, and will
Expand All @@ -241,14 +259,16 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
for j in range(len(args))
]

kwargs_i = jax.tree_util.tree_map(operator.itemgetter(i), kwargs)

with jax.disable_jit():
outs.append(method(instance, *args_i))
outs.append(method(instance, *args_i, **kwargs_i))

outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs)

elif instance.debug:
with jax.disable_jit():
outs = method(instance, *args)
outs = method(instance, *args, **kwargs)

elif instance.multi_device:

Expand All @@ -274,10 +294,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
)
pmap_funcs[key] = func

outs = func(instance, *args)
outs = func(instance, *args, **kwargs)

else:
outs = jitted_func(instance, *args)
outs = jitted_func(instance, *args, **kwargs)

return outs

Expand Down

0 comments on commit cf3acc0

Please sign in to comment.