Skip to content

Commit

Permalink
use numpy type promotion directly (inducer#540)
Browse files Browse the repository at this point in the history
* use numpy 2 type promotion directly

* fix lint

* add comment & test regarding np.result_type with pytato arrays

* adjust README

* revert old_numpy test adjustment (tested manually with numpy 1.26)

* remove comment
  • Loading branch information
matthiasdiener authored Sep 12, 2024
1 parent 3fedafc commit 850dcb4
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 126 deletions.
11 changes: 5 additions & 6 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ Numpy compatibility
Pytato is written to pose no particular restrictions on the version of numpy
used for execution. To use mypy-based type checking on Pytato itself or
packages using Pytato, numpy 1.20 or newer is required, due to the
typing-based changes to numpy in that release. Furthermore, pytato
now uses type promotion rules aiming to match those in
`numpy 2 <https://numpy.org/devdocs/numpy_2_0_migration_guide.html#changes-to-numpy-data-type-promotion>`__.
This will not break compatibility with older numpy versions, but may
result in differing data types between computations carried out in
numpy and pytato.
typing-based changes to numpy in that release.

Furthermore, pytato now applies the type promotion rules of the installed
version of `numpy <https://numpy.org/devdocs/numpy_2_0_migration_guide.html#changes-to-numpy-data-type-promotion>`__,
so computations in pytato should produce the same data types as the
equivalent computations in numpy.
88 changes: 5 additions & 83 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@
import operator
import re
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import cached_property, partialmethod
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -306,87 +305,10 @@ def normalize_shape_component(
ArrayOrScalar = Union["Array", Scalar]


# Coarse categories of scalar types used when promoting Python scalars
# against numpy dtypes. See
# https://numpy.org/neps/nep-0050-scalar-promotion.html
class DtypeKindCategory(IntEnum):
    """Promotion category of a scalar type.

    The members are integers, so categories can be ordered and compared
    (``BOOLEAN < INTEGRAL < INEXACT``), e.g. by taking ``max`` over a
    collection of categories.
    """
    BOOLEAN = 0
    INTEGRAL = 1
    INEXACT = 2


# Maps the ``kind`` character of a numpy dtype to its promotion category.
# Only the kinds listed here are supported; looking up any other kind
# raises a KeyError.
_dtype_kind_char_to_kind_cat = {
    "b": DtypeKindCategory.BOOLEAN,
    "i": DtypeKindCategory.INTEGRAL,
    "u": DtypeKindCategory.INTEGRAL,
    "f": DtypeKindCategory.INEXACT,
    "c": DtypeKindCategory.INEXACT,
}


# Maps a Python scalar type to its promotion category.
_py_type_to_kind_cat = {
    bool: DtypeKindCategory.BOOLEAN,
    int: DtypeKindCategory.INTEGRAL,
    float: DtypeKindCategory.INEXACT,
    complex: DtypeKindCategory.INEXACT,
}


# Maps a real floating-point dtype to the complex dtype whose real and
# imaginary parts have that same precision.
_float_dtype_to_complex: dict[np.dtype[Any], np.dtype[Any]] = {
    np.dtype(np.float32): np.dtype(np.complex64),
    np.dtype(np.float64): np.dtype(np.complex128),
}


def _complexify_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]:
if dtype.kind == "c":
return dtype
elif dtype.kind == "f":
return _float_dtype_to_complex[dtype]
else:
raise ValueError("can only complexify types that are already inexact")


def _np_result_dtype(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]:
    """Return the dtype obtained by promoting all of *dtypes* together.

    :arg dtypes: numpy dtypes and/or Python scalar types (``bool``,
        ``int``, ``float``, ``complex``). Entries that are neither a
        :class:`type` nor a :class:`numpy.dtype` are ignored once at least
        one Python scalar type is present.
    :returns: the promoted dtype. A Python scalar type only influences the
        result when its category (boolean < integral < inexact) is higher
        than that of every numpy dtype given, or when it is ``complex``.
    :raises KeyError: for a Python type or dtype kind that has no entry in
        the category tables.
    """
    # For numpy 2.0, np.result_type does not implement numpy's type
    # promotion behavior when handed Python scalar *types*, so that
    # behavior is emulated here.

    py_types = [dtype for dtype in dtypes if isinstance(dtype, type)]

    if not py_types:
        return np.result_type(*dtypes)

    np_dtypes = [dtype for dtype in dtypes if isinstance(dtype, np.dtype)]
    max_py_kind_cat = max(_py_type_to_kind_cat[tp] for tp in py_types)

    is_complex = (complex in py_types
                  or any(dtype.kind == "c" for dtype in np_dtypes))

    # np_dtypes may be empty (only Python scalar types were passed); taking
    # max() over an empty collection would raise, so guard for that case.
    if np_dtypes:
        max_np_kind_cat = max(
            _dtype_kind_char_to_kind_cat[dtype.kind] for dtype in np_dtypes)

        if max_py_kind_cat <= max_np_kind_cat:
            # Just ignore the python types for promotion, except that a
            # complex Python scalar still forces a complex result.
            result = np.result_type(*np_dtypes)
            if is_complex:
                result = _complexify_dtype(result)
            return result

    # The Python scalar types dominate: promote the numpy dtypes (if any)
    # with the default dtype of the highest Python category present.
    if max_py_kind_cat == DtypeKindCategory.INTEGRAL:
        # FIXME: Perhaps this should be int32 "on some systems, e.g. Windows"
        py_promotion_dtype: np.dtype[Any] = np.dtype(np.int64)
    elif max_py_kind_cat == DtypeKindCategory.INEXACT:
        if is_complex:
            py_promotion_dtype = np.dtype(np.complex128)
        else:
            py_promotion_dtype = np.dtype(np.float64)
    else:
        # Only reachable when every argument is the Python type bool.
        py_promotion_dtype = np.dtype(np.bool_)

    return np.result_type(*np_dtypes, py_promotion_dtype)
def _np_result_dtype(
        *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
        ) -> np.dtype[Any]:
    """Return the dtype that :func:`numpy.result_type` computes for
    *arrays_and_dtypes*.

    This is a thin wrapper: all arguments are forwarded unchanged, so the
    promotion rules are those of the installed numpy.
    """
    return np.result_type(*arrays_and_dtypes)


def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]:
Expand Down Expand Up @@ -656,7 +578,7 @@ def _binary_op(
op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression],
other: ArrayOrScalar,
get_result_type: Callable[
[DtypeOrPyScalarType, DtypeOrPyScalarType],
[ArrayOrScalar, ArrayOrScalar],
np.dtype[Any]] = _np_result_dtype,
reverse: bool = False,
cast_to_result_dtype: bool = True,
Expand Down
26 changes: 5 additions & 21 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
ArrayOrScalar,
BasicIndex,
ConvertibleToIndexExpr,
DtypeOrPyScalarType,
Einsum,
IndexExpr,
IndexLambda,
Expand All @@ -58,7 +57,6 @@
)
from pytato.scalar_expr import (
INT_CLASSES,
PYTHON_SCALAR_CLASSES,
SCALAR_CLASSES,
BoolT,
IntegralScalarExpression,
Expand Down Expand Up @@ -164,22 +162,6 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType,
return val[get_indexing_expression(shape, result_shape)]


def _extract_dtypes(
        exprs: Sequence[ArrayOrScalar]) -> list[DtypeOrPyScalarType]:
    """Return, for each entry of *exprs*, what describes its scalar type.

    Arrays and numpy scalars contribute their ``dtype``; Python scalars
    contribute their Python type.

    :raises TypeError: if an entry is none of the above.
    """

    def _dtype_of(expr: ArrayOrScalar) -> DtypeOrPyScalarType:
        # Arrays and numpy scalars are checked first so that numpy scalars
        # which also subclass a Python scalar type report their dtype.
        if isinstance(expr, (Array, np.generic)):
            return expr.dtype
        if isinstance(expr, PYTHON_SCALAR_CLASSES):
            return type(expr)
        raise TypeError(f"unexpected expression type: '{type(expr)}'")

    return [_dtype_of(expr) for expr in exprs]


def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar,
bnd_name: str,
bindings: dict[str, Array],
Expand Down Expand Up @@ -208,7 +190,7 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar,

def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501
get_result_type: Callable[[DtypeOrPyScalarType, DtypeOrPyScalarType], np.dtype[Any]], # noqa:E501
get_result_type: Callable[[ArrayOrScalar, ArrayOrScalar], np.dtype[Any]], # noqa:E501
*,
tags: frozenset[Tag],
non_equality_tags: frozenset[Tag],
Expand All @@ -222,8 +204,10 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,

result_shape = get_shape_after_broadcasting([a1, a2])

dtypes = _extract_dtypes([a1, a2])
result_dtype = get_result_type(*dtypes)
# Note: get_result_type calls np.result_type by default, which means
# that we are passing a pytato array to numpy. Luckily, np.result_type
# only looks at the dtype of input arrays as of numpy v2.1.
result_dtype = get_result_type(a1, a2)

bindings: dict[str, Array] = {}

Expand Down
18 changes: 2 additions & 16 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,6 @@ def wrapper(*args):
"logical_or"))
@pytest.mark.parametrize("reverse", (False, True))
def test_scalar_array_binary_arith(ctx_factory, which, reverse):
from numpy.lib import NumpyVersion
is_old_numpy = NumpyVersion(np.__version__) < "2.0.0"

cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal",
Expand Down Expand Up @@ -319,18 +316,8 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
out = outputs[dtype]
out_ref = np_op(x_in, y_orig.astype(dtype))

if not is_old_numpy:
assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype)

# In some cases ops are done in float32 in loopy but float64 in numpy.
is_allclose = np.allclose(out, out_ref), (out, out_ref)
if not is_old_numpy:
assert is_allclose
else:
if out_ref.dtype.itemsize == 1:
pass
else:
assert is_allclose
assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype)
assert np.allclose(out, out_ref), (out, out_ref)


@pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv", "pow",
Expand Down Expand Up @@ -389,7 +376,6 @@ def test_array_array_binary_arith(ctx_factory, which, reverse):
out_ref = np_op(x_in, y_orig.astype(dtype))

assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype)
# In some cases ops are done in float32 in loopy but float64 in numpy.
assert np.allclose(out, out_ref), (out, out_ref)


Expand Down
15 changes: 15 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,21 @@ def test_dot_visualizers():
# }}}


def test_numpy_type_promotion_with_pytato_arrays():
    """Check that type promotion works for an object that exposes nothing
    but a ``dtype`` attribute, both through :func:`numpy.result_type`
    directly and through ``pytato.array._np_result_dtype``.
    """
    class NotReallyAnArray:
        # Deliberately provides no data or array interface, only a dtype.
        @property
        def dtype(self):
            return np.dtype("float64")

    # Make sure that np.result_type accesses only the dtype attribute of the
    # class, not (e.g.) its data.
    assert np.result_type(42, NotReallyAnArray()) == np.float64

    from pytato.array import _np_result_dtype
    assert _np_result_dtype(42, NotReallyAnArray()) == np.float64
    assert _np_result_dtype(42.0, NotReallyAnArray()) == np.float64


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 850dcb4

Please sign in to comment.