Merge pull request #25390 from jakevdp:matvec
PiperOrigin-RevId: 705077854
Google-ML-Automation committed Dec 11, 2024
2 parents 13ce517 + f6d5876 commit b79dae8
Showing 5 changed files with 142 additions and 3 deletions.
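In short: this commit adds jax.numpy.matvec and jax.numpy.vecmat, matching the functions of the same names introduced in NumPy 2.2. A minimal usage sketch, not part of the diff, assuming a JAX build that already contains this change (printed values worked out by hand):

import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)  # matrix of shape (M=2, N=3)
x = jnp.array([1.0, 2.0, 3.0])     # vector of shape (N=3,)
v = jnp.array([1.0, 2.0])          # vector of shape (M=2,)

# matvec contracts the last axis of `a` with the last axis of `x`:
# (..., M, N) with (..., N) -> (..., M)
print(jnp.matvec(a, x))            # [ 8. 26.]

# vecmat conjugates the vector, then contracts it with the second-to-last axis of `a`:
# (..., M) with (..., M, N) -> (..., N)
print(jnp.vecmat(v, a))            # [ 6.  9. 12.]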
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
@@ -274,6 +274,7 @@ namespace; they are listed below.
mask_indices
matmul
matrix_transpose
matvec
max
maximum
mean
@@ -428,6 +429,7 @@ namespace; they are listed below.
var
vdot
vecdot
vecmat
vectorize
vsplit
vstack
84 changes: 84 additions & 0 deletions jax/_src/numpy/lax_numpy.py
@@ -9168,6 +9168,89 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
  return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type)


@export
@jit
def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  """Batched matrix-vector product.

  JAX implementation of :func:`numpy.matvec`.

  Args:
    x1: array of shape ``(..., M, N)``
    x2: array of shape ``(..., N)``. Leading dimensions must be broadcast-compatible
      with leading dimensions of ``x1``.

  Returns:
    An array of shape ``(..., M)`` containing the batched matrix-vector product.

  See also:
    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
    - :func:`jax.numpy.vecmat`: vector-matrix product.
    - :func:`jax.numpy.matmul`: general matrix multiplication.

  Examples:
    Simple matrix-vector product:

    >>> x1 = jnp.array([[1, 2, 3],
    ...                 [4, 5, 6]])
    >>> x2 = jnp.array([7, 8, 9])
    >>> jnp.matvec(x1, x2)
    Array([ 50, 122], dtype=int32)

    Batched matrix-vector product:

    >>> x2 = jnp.array([[7, 8, 9],
    ...                 [5, 6, 7]])
    >>> jnp.matvec(x1, x2)
    Array([[ 50, 122],
           [ 38,  92]], dtype=int32)
  """
  util.check_arraylike("matvec", x1, x2)
  return vectorize(matmul, signature="(n,m),(m)->(n)")(x1, x2)


@export
@jit
def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array:
  """Batched conjugate vector-matrix product.

  JAX implementation of :func:`numpy.vecmat`.

  Args:
    x1: array of shape ``(..., M)``.
    x2: array of shape ``(..., M, N)``. Leading dimensions must be broadcast-compatible
      with leading dimensions of ``x1``.

  Returns:
    An array of shape ``(..., N)`` containing the batched conjugate vector-matrix product.

  See also:
    - :func:`jax.numpy.linalg.vecdot`: batched vector product.
    - :func:`jax.numpy.matvec`: matrix-vector product.
    - :func:`jax.numpy.matmul`: general matrix multiplication.

  Examples:
    Simple vector-matrix product:

    >>> x1 = jnp.array([[1, 2, 3]])
    >>> x2 = jnp.array([[4, 5],
    ...                 [6, 7],
    ...                 [8, 9]])
    >>> jnp.vecmat(x1, x2)
    Array([[40, 46]], dtype=int32)

    Batched vector-matrix product:

    >>> x1 = jnp.array([[1, 2, 3],
    ...                 [4, 5, 6]])
    >>> jnp.vecmat(x1, x2)
    Array([[ 40,  46],
           [ 94, 109]], dtype=int32)
  """
  util.check_arraylike("vecmat", x1, x2)
  return vectorize(matmul, signature="(n),(n,m)->(m)")(ufuncs.conj(x1), x2)


@export
@partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
def vdot(
@@ -9244,6 +9327,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
  See Also:
    - :func:`jax.numpy.vdot`: flattened vector product.
    - :func:`jax.numpy.vecmat`: vector-matrix product.
    - :func:`jax.numpy.matmul`: general matrix multiplication.
    - :func:`jax.lax.dot_general`: general N-dimensional batched dot product.
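For reference, the vectorize-based definitions above are equivalent to einsum contractions over the trailing axes, with batch dimensions broadcast. A small sketch of those identities, not part of the change, with shapes made up for illustration:

import jax.numpy as jnp

a = jnp.ones((4, 1, 2, 3))  # batch (4, 1), matrix (M=2, N=3)
x = jnp.ones((5, 3))        # batch (5,), vector (N=3)
v = jnp.ones((5, 2))        # batch (5,), vector (M=2)

# matvec(a, x)[..., i] == sum_j a[..., i, j] * x[..., j]
assert jnp.allclose(jnp.matvec(a, x), jnp.einsum("...ij,...j->...i", a, x))

# vecmat(v, a)[..., j] == sum_i conj(v[..., i]) * a[..., i, j]
assert jnp.allclose(jnp.vecmat(v, a), jnp.einsum("...i,...ij->...j", jnp.conj(v), a))

After broadcasting the batch dimensions, the two results have shapes (4, 5, 2) and (4, 5, 3) respectively.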
2 changes: 2 additions & 0 deletions jax/numpy/__init__.py
@@ -176,6 +176,7 @@
logspace as logspace,
mask_indices as mask_indices,
matmul as matmul,
matvec as matvec,
matrix_transpose as matrix_transpose,
meshgrid as meshgrid,
moveaxis as moveaxis,
@@ -258,6 +259,7 @@
vander as vander,
vdot as vdot,
vecdot as vecdot,
vecmat as vecmat,
vsplit as vsplit,
vstack as vstack,
where as where,
2 changes: 2 additions & 0 deletions jax/numpy/__init__.pyi
@@ -645,6 +645,7 @@ def matmul(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
preferred_element_type: DTypeLike | None = ...) -> Array: ...
def matrix_transpose(x: ArrayLike, /) -> Array: ...
def matvec(x1: ArrayLike, x2: ArrayLike, /) -> Array: ...
def max(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
where: ArrayLike | None = ...) -> Array: ...
@@ -995,6 +996,7 @@ def vdot(
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ...,
precision: PrecisionLike = ...,
preferred_element_type: DTypeLike | None = ...) -> Array: ...
def vecmat(x1: ArrayLike, x2: ArrayLike, /) -> Array: ...
def vsplit(
ary: ArrayLike, indices_or_sections: int | ArrayLike
) -> list[Array]: ...
55 changes: 52 additions & 3 deletions tests/lax_numpy_test.py
@@ -649,6 +649,57 @@ def np_fn(x, y, axis=axis):
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

  @jtu.sample_product(
      lhs_batch=broadcast_compatible_shapes,
      rhs_batch=broadcast_compatible_shapes,
      mat_size=[1, 2, 3],
      vec_size=[2, 3, 4],
      dtype=number_dtypes,
  )
  @jax.default_matmul_precision("float32")
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testMatvec(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
    rng = jtu.rand_default(self.rng())
    lhs_shape = (*lhs_batch, mat_size, vec_size)
    rhs_shape = (*rhs_batch, vec_size)
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    jnp_fn = jnp.matvec
    @jtu.promote_like_jnp
    def np_fn(x, y):
      f = (np.vectorize(np.matmul, signature="(m,n),(n)->(m)")
           if jtu.numpy_version() < (2, 2, 0) else np.matvec)
      return f(x, y).astype(x.dtype)
    tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
           np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

  @jtu.sample_product(
      lhs_batch=broadcast_compatible_shapes,
      rhs_batch=broadcast_compatible_shapes,
      mat_size=[1, 2, 3],
      vec_size=[2, 3, 4],
      dtype=number_dtypes,
  )
  @jax.default_matmul_precision("float32")
  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
  def testVecmat(self, lhs_batch, rhs_batch, mat_size, vec_size, dtype):
    rng = jtu.rand_default(self.rng())
    lhs_shape = (*lhs_batch, vec_size)
    rhs_shape = (*rhs_batch, vec_size, mat_size)
    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
    jnp_fn = jnp.vecmat
    @jtu.promote_like_jnp
    def np_fn(x, y):
      f = (np.vectorize(lambda x, y: np.matmul(np.conj(x), y),
                        signature="(m),(m,n)->(n)")
           if jtu.numpy_version() < (2, 2, 0) else np.vecmat)
      return f(x, y).astype(x.dtype)
    tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
           np.complex64: 1E-3, np.complex128: 1e-12, jnp.bfloat16: 1e-1}
    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

  @jtu.sample_product(
      [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
       for lhs_shape, rhs_shape, axes in [
@@ -6257,7 +6308,6 @@ def testWrappedSignaturesMatch(self):
'isnat',
'loadtxt',
'matrix',
'matvec',
'may_share_memory',
'memmap',
'min_scalar_type',
@@ -6283,8 +6333,7 @@ def testWrappedSignaturesMatch(self):
'show_runtime',
'test',
'trapz',
'typename',
'vecmat'}
'typename'}

# symbols removed in NumPy 2.0
skip |= {'add_docstring',
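The reference implementations in the new tests fall back to np.vectorize when NumPy is older than 2.2.0, since numpy.matvec and numpy.vecmat only exist from NumPy 2.2 onward. A self-contained sketch of that fallback pattern; the names ref_matvec and ref_vecmat are made up here, and the version check uses numpy.lib.NumpyVersion rather than the jtu helper used in the tests:

import numpy as np

# numpy.matvec / numpy.vecmat were added in NumPy 2.2; on older versions,
# emulate them with np.vectorize over a gufunc-style signature, as the
# tests above do.
if np.lib.NumpyVersion(np.__version__) >= "2.2.0":
    ref_matvec, ref_vecmat = np.matvec, np.vecmat
else:
    ref_matvec = np.vectorize(np.matmul, signature="(m,n),(n)->(m)")
    ref_vecmat = np.vectorize(lambda v, a: np.matmul(np.conj(v), a),
                              signature="(m),(m,n)->(n)")

a = np.arange(6.0).reshape(2, 3)
x = np.array([1.0, 2.0, 3.0])
print(ref_matvec(a, x))      # [ 8. 26.]
print(ref_vecmat(x[:2], a))  # [ 6.  9. 12.]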
