Skip to content

Commit

Permalink
Enabled greater range of preconditioner powers. Some math utilities a…
Browse files Browse the repository at this point in the history
…dded.

PiperOrigin-RevId: 695286988
  • Loading branch information
timothyn617 authored and KfacJaxDev committed Nov 15, 2024
1 parent 59fea08 commit 776d484
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 5 deletions.
16 changes: 15 additions & 1 deletion kfac_jax/_src/curvature_blocks/kronecker_factored.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
from typing import Any, Sequence

import jax
import jax.numpy as jnp
from kfac_jax._src import layers_and_loss_tags as tags
from kfac_jax._src import patches_second_moment as psm
Expand Down Expand Up @@ -280,7 +281,7 @@ def _multiply_matpower_unscaled(

else:

if power != -1 and power != -0.5:
if power not in [-1, -0.5, 0.5]:
raise NotImplementedError(
f"Approximations for power {power} is not yet implemented."
)
Expand All @@ -305,6 +306,19 @@ def _multiply_matpower_unscaled(
factors = utils.invert_psd_matrices(factors)
elif power == -0.5:
factors = utils.inverse_sqrt_psd_matrices(factors)
# TODO(timothycnguyen): Hacky psd square root. Will find a better way.
elif power == 0.5:
inverse_sqrt_factors = utils.inverse_sqrt_psd_matrices(factors)

def matmul(x, y):
if x.ndim == y.ndim == 2:
return jnp.dot(x, y)
assert x.ndim == y.ndim == 1
return x * y

factors = jax.tree_util.tree_map(
matmul, factors, inverse_sqrt_factors
)
else:
raise NotImplementedError()

Expand Down
2 changes: 2 additions & 0 deletions kfac_jax/_src/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
scalar_div = math.scalar_div
weighted_sum_of_objects = math.weighted_sum_of_objects
sum_of_objects = math.sum_objects
pytree_size = math.pytree_size
inner_product = math.inner_product
symmetric_matrix_inner_products = math.symmetric_matrix_inner_products
matrix_of_inner_products = math.matrix_of_inner_products
Expand All @@ -131,6 +132,7 @@
invert_psd_matrices = math.invert_psd_matrices
inverse_sqrt_psd_matrices = math.inverse_sqrt_psd_matrices
stable_sqrt = math.stable_sqrt
cosine_similarity = math.cosine_similarity

del math

Expand Down
16 changes: 13 additions & 3 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
"""K-FAC utilities for various mathematical operations."""
import functools
import string
from typing import Callable, Sequence, Iterable, TypeVar
from typing import Callable, Iterable, Sequence, TypeVar

import jax
from jax import lax
from jax.experimental.sparse import linalg as experimental_splinalg
import jax.numpy as jnp
from jax.scipy import linalg

from kfac_jax._src.utils import types

import numpy as np
import optax
import tree
Expand Down Expand Up @@ -156,6 +154,13 @@ def sum_objects(objects: Sequence[TArrayTree]) -> TArrayTree:
return weighted_sum_of_objects(objects, [1] * len(objects))


def pytree_size(pytree):
"""Computes total size of pytree leaves."""
return jax.tree_util.tree_reduce(
lambda x, y: x + y, jax.tree_util.tree_map(jnp.size, pytree), 0
)


def _inner_product_float64(obj1: ArrayTree, obj2: ArrayTree) -> Array:
"""Computes inner product explicitly in float64 precision."""

Expand Down Expand Up @@ -1155,3 +1160,8 @@ def _stable_sqrt_fwd(
_sqrt_bound_derivative.defjvp(_stable_sqrt_fwd)

stable_sqrt = functools.partial(_sqrt_bound_derivative, max_gradient=1000.0)


def cosine_similarity(v1: ArrayTree, v2: ArrayTree) -> Array:
"""Computes the cosine similarity between flattened pytrees."""
return inner_product(v1, v2) / (norm(v1) * norm(v2))
16 changes: 15 additions & 1 deletion kfac_jax/_src/utils/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""K-FAC utilities for classes with staged methods."""

import functools
import numbers
from typing import Any, Callable, Sequence

import jax

from jax import lax
from kfac_jax._src.utils import misc
from kfac_jax._src.utils import parallel
from kfac_jax._src.utils import types


TArrayTree = types.TArrayTree


Expand Down Expand Up @@ -129,6 +131,18 @@ def replicate(self, obj: TArrayTree) -> TArrayTree:
else:
return obj

def pmean_if_pmap_wrapper(
self,
func: Callable[..., TArrayTree],
) -> Callable[..., TArrayTree]:
"""Wraps a function to perform a pmean if `multi_device`."""
if self.multi_device:
return lambda *args, **kwargs: lax.pmean(
func(*args, **kwargs), self.pmap_axis_name
)
else:
return func


def staged(
method: Callable[..., TArrayTree],
Expand Down

0 comments on commit 776d484

Please sign in to comment.