Skip to content

Commit

Permalink
Restoring support for scalar factors in KroneckerFactored curvature block super class.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 580468365
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 8, 2023
1 parent f466559 commit c1c1fd7
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 31 deletions.
19 changes: 4 additions & 15 deletions kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,21 +1117,10 @@ def array_ndim(self) -> int:
@property
def grouped_array_shape(self) -> Shape:
"""The shape of the single axis grouped array."""

shape = []
for group in self.axis_groups:

size = utils.product([self.array_shape[i] for i in group])

# filter out groups of size 1
if size != 1:
shape.append(size)

# need at least one group
if not shape:
shape = [1]

return tuple(shape)
return tuple(
utils.product([self.array_shape[i] for i in group])
for group in self.axis_groups
)

@property
def grouped_array_ndim(self) -> int:
Expand Down
58 changes: 42 additions & 16 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,35 +590,62 @@ def pi_adjusted_kronecker_factors(
# Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1
us = [fi / ni for fi, ni in zip(factors, norms)]

# 0-dim scalar factors are not allowed
assert not any(f.ndim == 0 for f in factors)

k = len(factors)

# TODO(jamesmartens,botev): consider making the use of special behavior for
# scalar factors a module-level configurable option. One can argue that scalar
# factors should behave the same as non-scalar factors for the sake of
# consistent behavior as the layer widths shrink to 1.

def regular_case() -> Tuple[Array, ...]:

# Distribute c and damping/c among k factors, where c = jnp.prod(norms),
# satisfying kron(factors) = c * kron(us).
num_non_scalars = sum(1 if f.size != 1 else 0 for f in factors)

if num_non_scalars != 0:

# Distribute c and damping/c among k factors, where c = jnp.prod(norms),
# satisfying kron(factors) = c * kron(us).

# NOTE: c_k (geometric mean of norms) can also be calculated by
# c ** (1/k) = jnp.prod(norms) ** (1 / len(norms)), but this alternative
# can make the result zero due to the multiplication of (potentially)
# small values, i.e. jnp.prod(norms).
c_k = jnp.exp(jnp.mean(jnp.log(norms)))

c_k = jnp.exp(jnp.mean(jnp.log(norms)))
d_k = jnp.power(damping, 1.0 / k) / c_k

# NOTE: c_k (geometric mean of norms) can also be calculated by
# c ** (1/k) = jnp.prod(norms) ** (1 / len(norms)), but this alternative
# can make the result zero due to the multiplication of (potentially) small
# values, i.e. jnp.prod(norms).
if k > num_non_scalars:

d_k = jnp.power(damping, 1.0 / k) / c_k
c_non_scalar = c_k ** (float(k) / num_non_scalars)

# We distribute the damping only inside the non-scalar factors
d_hat = jnp.power(damping, 1.0 / num_non_scalars) / c_non_scalar

else:
d_hat = d_k

else:

# This could cause under/overflow, but it's unavoidable here.
c = jnp.prod(jnp.array(norms))

# In the case where all factors are scalar we need to add the damping and
# then take the k-th root
c_k = jnp.power(c + damping, 1.0 / k)

u_hats = []

for u in us:

if u.ndim == 2:
u_hat = u + d_k * jnp.eye(u.shape[0], dtype=u.dtype)
if u.size == 1: # scalar case
u_hat = jnp.ones_like(u) # damping not used in the scalar factors

elif u.ndim == 2:
u_hat = u + d_hat * jnp.eye(u.shape[0], dtype=u.dtype)

else: # diagonal case
assert u.ndim == 1
u_hat = u + d_k
u_hat = u + d_hat

u_hats.append(u_hat * c_k)

Expand All @@ -638,8 +665,7 @@ def zero_case() -> Tuple[Array, ...]:
if u.ndim == 2:
u_hat = jnp.eye(u.shape[0], dtype=u.dtype)

else: # diagonal case
assert u.ndim == 1
else:
u_hat = jnp.ones_like(u)

u_hats.append(u_hat * c_k)
Expand Down

0 comments on commit c1c1fd7

Please sign in to comment.