Skip to content

Commit

Permalink
Update chex.assert_type to check concrete types instead of just ass…
Browse files Browse the repository at this point in the history
…erting that the type is a floating/integer sub-type.

Previously, `assert_type` would only check that the input was of the same parent type. For example:
```
x = np.ones((1,), dtype=np.float32)
chex.assert_type(x, np.float64)  # Succeeds
chex.assert_type(x, np.int32)  # Fails.
```

Instead, if a concrete dtype is provided we check that the input has the same type. If `float` or `np.floating` is provided, we continue to only assert that the input is the same parent.

```
x = np.ones((1,), dtype=np.float32)
chex.assert_type(x, np.float64) # Fails
chex.assert_type(x, float) # Succeeds.
```
PiperOrigin-RevId: 607102995
  • Loading branch information
tomwardio authored and DistraxDev committed Feb 16, 2024
1 parent 7c0e1bf commit d3d039f
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 82 deletions.
10 changes: 6 additions & 4 deletions distrax/_src/distributions/deterministic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import deterministic
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -107,10 +108,11 @@ def test_sample_shape(self, loc, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/epsilon_greedy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import chex
from distrax._src.distributions import epsilon_greedy
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -51,11 +52,12 @@ def test_num_categories(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/gamma_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import gamma
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -73,11 +74,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
10 changes: 6 additions & 4 deletions distrax/_src/distributions/greedy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import greedy
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -48,10 +49,11 @@ def test_num_categories(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable((np.array([0., 4., -1., 4.]),))
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/gumbel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import gumbel
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -67,11 +68,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/laplace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import laplace
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -65,11 +66,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/log_stddev_normal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from distrax._src.distributions import log_stddev_normal as lsn
from distrax._src.distributions import normal
import jax
import jax.experimental
import jax.numpy as jnp
import mock
import numpy as np
Expand Down Expand Up @@ -105,11 +106,12 @@ def test_sampling_batched_custom_dim(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = lsn.LogStddevNormal(
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = lsn.LogStddevNormal(
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_kl_versus_normal(self):
loc, scale = jnp.array([2.0]), jnp.array([2.0])
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/logistic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import logistic
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -66,11 +67,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
17 changes: 11 additions & 6 deletions distrax/_src/distributions/multinomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from scipy import stats
Expand Down Expand Up @@ -405,12 +406,16 @@ def test_sample_and_log_prob(self, dist_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'logits': self.logits, 'dtype': dtype, 'total_count': self.total_count}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'logits': self.logits,
'dtype': dtype,
'total_count': self.total_count,
}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
def test_sample_extreme_probs(self):
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_diag_plus_low_rank_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.utils import equivalence

import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp
Expand Down Expand Up @@ -180,13 +181,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_diag_shape,
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_diag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.distributions import normal
from distrax._src.utils import equivalence
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -214,13 +215,14 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_full_covariance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions.mvn_full_covariance import MultivariateNormalFullCovariance
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -106,13 +107,14 @@ def test_sample_shape(self, sample_shape, loc_shape, covariance_matrix_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_tri_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions.mvn_tri import MultivariateNormalTri
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -114,13 +115,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_tri_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/one_hot_categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import scipy
Expand Down Expand Up @@ -178,11 +179,12 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {'logits': self.logits, 'dtype': dtype}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {'logits': self.logits, 'dtype': dtype}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
11 changes: 6 additions & 5 deletions distrax/_src/distributions/softmax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def test_parameters(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
logits=self.logits, temperature=self.temperature, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
logits=self.logits, temperature=self.temperature, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable((np.array([2., 4., 1., 3.]),))
Expand Down

0 comments on commit d3d039f

Please sign in to comment.