Handle local measures in TransformedDistribution.
This change continues to set up the framework for tracking base measures and computing corrections on transformed densities. In `TransformedDistribution` we update `log_prob` to call a version of `experimental_local_measure` that keeps track of the base measure. We use the backward-compatibility argument (`backward_compat`) to control this rollout.
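A minimal sketch of the new code path, assuming the local-measure protocol described here (illustrative pseudocode, not the actual `tfp.distributions.TransformedDistribution` implementation; `base_dist` and `bijector` are placeholder names):

def log_prob(y, base_dist, bijector):
  # Pull the observed value back to the base space.
  x = bijector.inverse(y)
  # Local measure of the base distribution: its log density together with
  # the tangent space of its support at x.
  base_log_prob, tangent_space = base_dist.experimental_local_measure(
      x, backward_compat=True)
  # Density correction for pushing the base measure forward through the
  # bijector; for a dimension-preserving bijector this equals the forward
  # log-det-Jacobian at x (see the note on forward vs. inverse below).
  correction, _ = bijector.experimental_compute_density_correction(
      x, tangent_space, backward_compat=True)
  return base_log_prob - correction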

Note that this change switches the bijector method called by `transformed_distribution._log_prob` from `inverse_log_det_jacobian` to `forward_log_det_jacobian`, which (i) shifts the numerics, and (ii) changes which functions the test suite exercises. As a result, this change also (i) loosens tolerances in some tests, and (ii) fixes a dtype correctness bug uncovered in `moyal_cdf`.
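Why the numerics shift: analytically `inverse_log_det_jacobian(y)` equals `-forward_log_det_jacobian(inverse(y))`, but the two routes go through different primitives and round differently. A small illustrative check (the bijector choice and tolerance below are assumptions, not part of this commit):

import numpy as np
from tensorflow_probability.substrates import jax as tfp

bij = tfp.bijectors.Exp()
y = np.float32(3.0)
ildj = bij.inverse_log_det_jacobian(y, event_ndims=0)
fldj = bij.forward_log_det_jacobian(bij.inverse(y), event_ndims=0)
# The two routes agree analytically but need not be bit-identical,
# which is why some test tolerances are loosened in this change.
np.testing.assert_allclose(ildj, -fldj, rtol=1e-5)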

PiperOrigin-RevId: 385616650
DistraxDev authored and DistraxDev committed Feb 17, 2022
1 parent cf013e2 commit 72afc80
Showing 4 changed files with 110 additions and 1 deletion.
31 changes: 31 additions & 0 deletions distrax/_src/bijectors/tfp_compatible_bijector.py
@@ -28,6 +28,7 @@

Array = chex.Array
Bijector = bijector.Bijector
TangentSpace = tfp.experimental.tangent_spaces.TangentSpace


def tfp_compatible_bijector(
@@ -175,4 +176,34 @@ def _check_shape(
f"{event_shape} which has only {len(event_shape)} "
f"dimensions instead.")

def experimental_compute_density_correction(
self,
x: Array,
tangent_space: TangentSpace,
backward_compat: bool = True,
**kwargs):
"""Density correction for this transform wrt the tangent space, at x.
See `tfp.bijectors.experimental_compute_density_correction`, and
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
Solution”, https://arxiv.org/abs/2010.09647.
Args:
x: `float` or `double` `Array`.
tangent_space: `TangentSpace` or one of its subclasses. The tangent to
the support manifold at `x`.
backward_compat: unused
**kwargs: Optional keyword arguments forwarded to tangent space methods.
Returns:
density_correction: `Array` representing the density correction---in log
space---under the transformation that this Bijector denotes. Assumes
the Bijector is dimension-preserving.
"""
del backward_compat
# We ignore the `backward_compat` flag and always act as though it's
# true because Distrax bijectors and distributions need not follow the
# base measure protocol from TFP.
return tangent_space.transform_dimension_preserving(x, self, **kwargs)

return TFPCompatibleBijector()
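Hypothetical usage of the wrapper above, mirroring the new subtest added to the test file below; it assumes `tfp_compatible_bijector` and `ScalarAffine` are exported at the top level of `distrax`:

import distrax
import numpy as np
from tensorflow_probability.substrates import jax as tfp

base = distrax.ScalarAffine(shift=1.0, scale=2.0)
wrapped = distrax.tfp_compatible_bijector(base)
x = np.array([0.5, -1.0], dtype=np.float32)

# The density correction is returned together with the transformed
# tangent space.
correction, space = wrapped.experimental_compute_density_correction(
    x, tangent_space=tfp.experimental.tangent_spaces.FullSpace(),
    event_ndims=base.event_ndims_in)

# For a dimension-preserving bijector the correction equals the forward
# log-det-Jacobian, and the full space maps to the full space.
np.testing.assert_allclose(
    correction,
    wrapped.forward_log_det_jacobian(x, event_ndims=base.event_ndims_in),
    rtol=1e-5)
assert isinstance(space, tfp.experimental.tangent_spaces.FullSpace)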
10 changes: 10 additions & 0 deletions distrax/_src/bijectors/tfp_compatible_bijector_test.py
@@ -32,6 +32,7 @@
tfb = tfp.bijectors
tfd = tfp.distributions


RTOL = 3e-3


@@ -209,6 +210,15 @@ def test_log_det_jacobian(self, dx_bijector_fn, tfp_bijector_fn, event):
y, event_ndims=base_bij.event_ndims_out)
np.testing.assert_allclose(dx_out, tfp_out, rtol=RTOL)

with self.subTest('experimental_compute_density_correction'):
dx_out = dx_bij.forward_log_det_jacobian(
event, event_ndims=base_bij.event_ndims_in)
dx_dcorr_out, space = dx_bij.experimental_compute_density_correction(
event, tangent_space=tfp.experimental.tangent_spaces.FullSpace(),
event_ndims=base_bij.event_ndims_in)
np.testing.assert_allclose(dx_out, dx_dcorr_out, rtol=RTOL)
self.assertIsInstance(space, tfp.experimental.tangent_spaces.FullSpace)

@parameterized.named_parameters(
('identity unbatched',
lambda: Lambda(lambda x: x, is_constant_jacobian=True), ()),
32 changes: 31 additions & 1 deletion distrax/_src/distributions/tfp_compatible_distribution.py
@@ -14,7 +14,7 @@
# ==============================================================================
"""Wrapper to adapt a Distrax distribution for use in TFP."""

from typing import Dict, Optional, Sequence, Union
from typing import Dict, Optional, Sequence, Tuple, Union

import chex
from distrax._src.distributions import distribution
@@ -29,6 +29,8 @@
Distribution = distribution.Distribution
IntLike = distribution.IntLike
PRNGKey = chex.PRNGKey
tangent_spaces = tfp.experimental.tangent_spaces
TangentSpace = tangent_spaces.TangentSpace


def tfp_compatible_distribution(
@@ -136,4 +138,32 @@ def sample(self,
sample_shape = tuple(sample_shape)
return base_distribution.sample(sample_shape=sample_shape, seed=seed)

def experimental_local_measure(
self,
value: Array,
backward_compat: bool = True,
**unused_kwargs) -> Tuple[Array, TangentSpace]:
"""Returns a log probability density together with a `TangentSpace`.
See `tfd.distribution.Distribution.experimental_local_measure`, and
Radul and Alexeev, AISTATS 2021, “The Base Measure Problem and its
Solution”, https://arxiv.org/abs/2010.09647.
Args:
value: `float` or `double` `Array`.
backward_compat: unused
**unused_kwargs: unused
Returns:
log_prob: see `log_prob`.
tangent_space: `tangent_spaces.FullSpace()`, representing R^n with the
standard basis.
"""
del backward_compat
# We ignore the `backward_compat` flag and always act as though it's
# true because Distrax bijectors and distributions need not follow the
# base measure protocol from TFP.
del unused_kwargs
return self.log_prob(value), tangent_spaces.FullSpace()

return TFPCompatibleDistribution()
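Hypothetical usage of the distribution wrapper above; it assumes `tfp_compatible_distribution` and `Normal` are exported at the top level of `distrax`, as in the tests below:

import distrax
import numpy as np
from tensorflow_probability.substrates import jax as tfp

dist = distrax.tfp_compatible_distribution(distrax.Normal(loc=0., scale=1.))
x = np.float32(0.3)

log_prob, space = dist.experimental_local_measure(x, backward_compat=True)

# The log density is unchanged; the wrapper additionally reports the tangent
# space of the support, here the full space (R^n with the standard basis).
np.testing.assert_allclose(log_prob, dist.log_prob(x), rtol=1e-6)
assert isinstance(space, tfp.experimental.tangent_spaces.FullSpace)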
38 changes: 38 additions & 0 deletions distrax/_src/distributions/tfp_compatible_distribution_test.py
@@ -219,6 +219,18 @@ def test_with_independent(self):

self.assertion_fn(log_prob, expected_log_prob)

def test_local_measure_with_independent(self):
base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
wrapped_dist = tfp_compatible_distribution(base_dist)
meta_dist = tfd.Independent(wrapped_dist, 1, validate_args=True)
samples = meta_dist.sample((), self._key)
expected_log_prob = meta_dist.log_prob(samples)

log_prob, space = meta_dist.experimental_local_measure(
samples, backward_compat=True)
self.assertion_fn(log_prob, expected_log_prob)
self.assertIsInstance(space, tfp.experimental.tangent_spaces.FullSpace)

def test_with_transformed_distribution(self):
base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
wrapped_dist = tfp_compatible_distribution(base_dist)
@@ -234,13 +246,39 @@ def test_with_transformed_distribution(self):

self.assertion_fn(log_prob, expected_log_prob)

def test_local_measure_with_transformed_distribution(self):
base_dist = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
wrapped_dist = tfp_compatible_distribution(base_dist)
meta_dist = tfd.TransformedDistribution(
distribution=wrapped_dist, bijector=tfb.Exp(), validate_args=True)
samples = meta_dist.sample(seed=self._key)
expected_log_prob = meta_dist.log_prob(samples)

log_prob, space = meta_dist.experimental_local_measure(
samples, backward_compat=True)
self.assertion_fn(log_prob, expected_log_prob)
self.assertIsInstance(space, tfp.experimental.tangent_spaces.FullSpace)

def test_with_sample(self):
base_dist = Normal(0., 1.)
wrapped_dist = tfp_compatible_distribution(base_dist)
meta_dist = tfd.Sample(
wrapped_dist, sample_shape=[1, 3], validate_args=True)
meta_dist.log_prob(meta_dist.sample(2, seed=self._key))

def test_local_measure_with_sample(self):
base_dist = Normal(0., 1.)
wrapped_dist = tfp_compatible_distribution(base_dist)
meta_dist = tfd.Sample(
wrapped_dist, sample_shape=[1, 3], validate_args=True)
samples = meta_dist.sample(2, seed=self._key)
expected_log_prob = meta_dist.log_prob(samples)

log_prob, space = meta_dist.experimental_local_measure(
samples, backward_compat=True)
self.assertion_fn(log_prob, expected_log_prob)
self.assertIsInstance(space, tfp.experimental.tangent_spaces.FullSpace)

def test_with_joint_distribution_named_auto_batched(self):
def laplace(a, b):
return tfp_compatible_distribution(Laplace(a * jnp.ones((2, 1)), b))
