From cf3acc0fff007937f96db41cd79acdd341ce1650 Mon Sep 17 00:00:00 2001 From: James Martens Date: Tue, 3 Dec 2024 08:05:05 -0800 Subject: [PATCH] Adding support for keyword arguments to staged methods. PiperOrigin-RevId: 702342993 --- kfac_jax/_src/utils/staging.py | 36 ++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/kfac_jax/_src/utils/staging.py b/kfac_jax/_src/utils/staging.py index 24d7500..cd91784 100644 --- a/kfac_jax/_src/utils/staging.py +++ b/kfac_jax/_src/utils/staging.py @@ -14,6 +14,7 @@ """K-FAC utilities for classes with staged methods.""" import functools +import inspect import numbers import operator from typing import Any, Callable, Sequence @@ -165,8 +166,15 @@ def staged( This decorator **should** only be applied to instance methods of classes that inherit from the `WithStagedMethods` class. The decorator makes the decorated method staged, which is equivalent to `jax.jit` if `instance.multi_device` is - `False` and to `jax.pmap` otherwise. When specifying static and donated - argunms, the `self` reference **must not** be counted. Example: + `False` and to `jax.pmap` otherwise. + + Note that the point of this abstraction around JAX's compilation is to make + sure that jitting/pmapping is only done once, so that if we are already in a + compiled/staged method, we won't initiate a second nested compilation when + calling into a second staged method. + + Note that when specifying static and donated argnums, the `self` reference + **must not** be counted. 
Example: @functools.partial(staged, donate_argunms=0) def try(self, x): @@ -210,12 +218,22 @@ def try(self, x): donate_argnums=donate_argnums) @functools.wraps(method) - def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree: + def decorated( + instance: "WithStagedMethods", + *args: Any, + **kwargs: Any + ) -> TArrayTree: + + sig = inspect.signature(method) + bound_args = sig.bind(instance, *args, **kwargs) + bound_args.apply_defaults() + args, kwargs = bound_args.args[1:], bound_args.kwargs if instance.in_staging: - return method(instance, *args) + return method(instance, *args, **kwargs) with instance.staging_context(): + if instance.multi_device and instance.debug: # In this case we want to call `method` once for each device index. # Note that this might not always produce sensible behavior, and will @@ -241,14 +259,16 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree: for j in range(len(args)) ] + kwargs_i = jax.tree_util.tree_map(operator.itemgetter(i), kwargs) + with jax.disable_jit(): - outs.append(method(instance, *args_i)) + outs.append(method(instance, *args_i, **kwargs_i)) outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs) elif instance.debug: with jax.disable_jit(): - outs = method(instance, *args) + outs = method(instance, *args, **kwargs) elif instance.multi_device: @@ -274,10 +294,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree: ) pmap_funcs[key] = func - outs = func(instance, *args) + outs = func(instance, *args, **kwargs) else: - outs = jitted_func(instance, *args) + outs = jitted_func(instance, *args, **kwargs) return outs