Skip to content

Commit

Permalink
Add jax.random.multinomial.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 27, 2024
1 parent 6dbda90 commit 5320d55
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Added {func}`jax.random.multinomial`.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
Expand Down
1 change: 1 addition & 0 deletions docs/jax.random.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Random Samplers
logistic
lognormal
maxwell
multinomial
multivariate_normal
normal
orthogonal
Expand Down
55 changes: 55 additions & 0 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,61 @@ def binomial(
batching.defvectorized(random_clone_p)
mlir.register_lowering(random_clone_p, lambda _, k: [k])


def multinomial(
key: Array,
n: RealArray,
p: RealArray,
axis: int = -1,
):
r"""Sample from a multinomial distribution.
The probability mass function is
.. math::
f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}
Args:
key: a PRNG key used as the random key.
n: a float array-like representing the number of trials.
p: a float array-like representing the probabilities of each outcome.
axis: axis along which probabilities are defined for each outcome.
Returns:
An array of counts for each outcome.
"""

key, _ = _check_prng_key("multinomial", key)
check_arraylike("multinomial", n, p)

def f(remainder, p_r_key):
p, r, key = p_r_key
count = binomial(key, remainder, p / r)
count = jnp.where(r == 0, 0, count)
return remainder - count, count

p = jnp.moveaxis(p, axis, 0)

p_shape = jnp.shape(p)

shape = jnp.broadcast_shapes(jnp.shape(n), p_shape[1:])
n = jnp.broadcast_to(n, shape)
p = jnp.broadcast_to(p, (p_shape[0],) + shape)

# remaining probabilities
r = lax.cumsum(p, 0, reverse=True)

keys = split(key, p_shape[0])

remainder, counts = lax.scan(f, n, (p, r, keys), unroll=True)

# remainder should end up as zeros

counts = jnp.moveaxis(counts, 0, axis)

return counts


def clone(key):
"""Clone a key for reuse
Expand Down
1 change: 1 addition & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multinomial as multinomial,
multivariate_normal as multivariate_normal,
normal as normal,
orthogonal as orthogonal,
Expand Down
13 changes: 13 additions & 0 deletions tests/random_lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,19 @@ def testBinomialCornerCases(self):
self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False)
self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False)

def testMultinomial(self):
key = random.key(0)
probs = jnp.array([
[0.5, 0.2, 0.3],
[0.1, 0.2, 0.7],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
])
trials = jnp.array(10**8).astype(float)
counts = random.multinomial(key, trials, probs)
freqs = counts / trials
self.assertAllClose(freqs, probs, atol=1e-3)

def test_batched_key_errors(self):
keys = lambda: jax.random.split(self.make_key(0))
msg = "{} accepts a single key, but was given a key array of shape.*"
Expand Down

0 comments on commit 5320d55

Please sign in to comment.