Skip to content

Commit

Permalink
Adding NaN/Inf guard on call to matrix inverses/solves since LU decomp on GPU can cause an infinite loop when the matrix has these values.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 701379417
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 30, 2024
1 parent face046 commit 6c0cf40
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions kfac_jax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,13 @@ def psd_inv(matrix: Array) -> Array:
identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
return linalg.solve(matrix, identity, assume_a="pos")
else:
return linalg.inv(matrix)
# Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
# possibly Infs, so we need to check for that before calling it.
return lax.cond(
jnp.logical_or(jnp.any(jnp.isnan(matrix)), jnp.any(jnp.isinf(matrix))),
lambda: jnp.full(matrix.shape, jnp.nan, dtype=matrix.dtype),
lambda: linalg.inv(matrix),
)


def psd_solve(matrix: Array, vector: Array) -> Array:
Expand All @@ -385,9 +391,14 @@ def psd_solve(matrix: Array, vector: Array) -> Array:

if get_use_cholesky_inversion():
return linalg.solve(matrix, vector, assume_a="pos")

else:
return linalg.solve(matrix, vector)
# Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
# possibly Infs, so we need to check for that before calling it.
return lax.cond(
jnp.logical_or(jnp.any(jnp.isnan(matrix)), jnp.any(jnp.isinf(matrix))),
lambda: jnp.full(vector.shape, jnp.nan, dtype=vector.dtype),
lambda: linalg.solve(matrix, vector),
)


def psd_solve_without_last_idx(a: Array, b: Array) -> Array:
Expand Down

0 comments on commit 6c0cf40

Please sign in to comment.