From 9318b11dae8e048176750cba9f7611200d556c84 Mon Sep 17 00:00:00 2001 From: DistraxDev Date: Mon, 19 Jul 2021 12:34:59 -0700 Subject: [PATCH] Handle local measures in TransformedDistribution. This change continues to set up the framework for tracking base measures and computing corrections on transformed densities. In `TransformedDistribution` we update `log_prob` to call a version of `experimental_local_measure` that keeps track of the base measure. We introduce a backwards-compatibility argument to control this rollout. PiperOrigin-RevId: 385616650 --- .../_src/bijectors/tfp_compatible_bijector.py | 28 +++++++++++++++++++ .../tfp_compatible_distribution.py | 27 +++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/distrax/_src/bijectors/tfp_compatible_bijector.py b/distrax/_src/bijectors/tfp_compatible_bijector.py index f70370a8..05718d1f 100644 --- a/distrax/_src/bijectors/tfp_compatible_bijector.py +++ b/distrax/_src/bijectors/tfp_compatible_bijector.py @@ -21,6 +21,7 @@ from distrax._src.utils import math import jax import jax.numpy as jnp +from tensorflow_probability.python.experimental import tangent_spaces from tensorflow_probability.substrates import jax as tfp tfb = tfp.bijectors @@ -28,6 +29,7 @@ Array = chex.Array Bijector = bijector.Bijector +TangentSpace = tangent_spaces.TangentSpace def tfp_compatible_bijector( @@ -175,4 +177,30 @@ def _check_shape( f"{event_shape} which has only {len(event_shape)} " f"dimensions instead.") + def experimental_compute_density_correction( + self, + x: Array, + tangent_space: TangentSpace, + backward_compat: bool = True, + **kwargs): + """Density correction for this transform wrt the tangent space, at x. + + See `tfb.bijector.Bijector.experimental_compute_density_correction`, and + Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its + Solution”, https://arxiv.org/abs/2010.09647. + + Args: + x: `float` or `double` `Array`. + tangent_space: `TangentSpace` or one of its subclasses. The tangent to + the support manifold at `x`. + backward_compat: ignored + **kwargs: Optional keyword arguments forwarded to tangent space methods. + + Returns: + density_correction: `Array` representing the density correction---in log + space---under the transformation that this Bijector denotes. Assumes + the Bijector is dimension-preserving. + """ + return tangent_space.transform_dimension_preserving(x, self, **kwargs) + return TFPCompatibleBijector() diff --git a/distrax/_src/distributions/tfp_compatible_distribution.py b/distrax/_src/distributions/tfp_compatible_distribution.py index 96addef0..0c4da081 100644 --- a/distrax/_src/distributions/tfp_compatible_distribution.py +++ b/distrax/_src/distributions/tfp_compatible_distribution.py @@ -14,12 +14,13 @@ # ============================================================================== """Wrapper to adapt a Distrax distribution for use in TFP.""" -from typing import Dict, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union import chex from distrax._src.distributions import distribution import jax.numpy as jnp import numpy as np +from tensorflow_probability.python.experimental import tangent_spaces from tensorflow_probability.substrates import jax as tfp tfd = tfp.distributions @@ -29,6 +30,7 @@ Distribution = distribution.Distribution IntLike = distribution.IntLike PRNGKey = chex.PRNGKey +TangentSpace = tangent_spaces.TangentSpace def tfp_compatible_distribution( @@ -136,4 +138,27 @@ def sample(self, sample_shape = tuple(sample_shape) return base_distribution.sample(sample_shape=sample_shape, seed=seed) + def experimental_local_measure( + self, + value: Array, + unused_backward_compat: bool = True, + **unused_kwargs) -> Tuple[Array, TangentSpace]: + """Returns a log probability density together with a `TangentSpace`. + + See `tfd.distribution.Distribution.experimental_local_measure`, and + Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its + Solution”, https://arxiv.org/abs/2010.09647. + + Args: + value: `float` or `double` `Array`. + unused_backward_compat: ignored + **unused_kwargs: ignored + + Returns: + log_prob: see `log_prob`. + tangent_space: `tangent_spaces.FullSpace()`, representing R^n with the + standard basis. + """ + return self.log_prob(value), tangent_spaces.FullSpace() + return TFPCompatibleDistribution()