From 850dcb47ea4f36f6bcde5b614dc159c212f28926 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Sep 2024 10:59:07 -0500 Subject: [PATCH] use numpy type promotion directly (#540) * use numpy 2 type promotion directly * fix lint * add comment & test regarding np.result_type with pytato arrays * adjust README * revert old_numpy test adjustment (tested manually with numpy 1.26) * remove comment --- README.rst | 11 +++--- pytato/array.py | 88 +++----------------------------------------- pytato/utils.py | 26 +++---------- test/test_codegen.py | 18 +-------- test/test_pytato.py | 15 ++++++++ 5 files changed, 32 insertions(+), 126 deletions(-) diff --git a/README.rst b/README.rst index 5d588f37c..05873b6fb 100644 --- a/README.rst +++ b/README.rst @@ -32,9 +32,8 @@ Numpy compatibility Pytato is written to pose no particular restrictions on the version of numpy used for execution. To use mypy-based type checking on Pytato itself or packages using Pytato, numpy 1.20 or newer is required, due to the -typing-based changes to numpy in that release. Furthermore, pytato -now uses type promotion rules aiming to match those in -`numpy 2 `__. -This will not break compatibility with older numpy versions, but may -result in differing data types between computations carried out in -numpy and pytato. +typing-based changes to numpy in that release. + +Furthermore, pytato now uses type promotion rules based on those in +`numpy `__ that should result in the same +data types as the currently installed version of numpy. diff --git a/pytato/array.py b/pytato/array.py index 0ee917d1e..6f9221ae8 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -172,7 +172,6 @@ import operator import re from abc import ABC, abstractmethod -from enum import IntEnum from functools import cached_property, partialmethod from typing import ( TYPE_CHECKING, @@ -306,87 +305,10 @@ def normalize_shape_component( ArrayOrScalar = Union["Array", Scalar] -# https://numpy.org/neps/nep-0050-scalar-promotion.html -class DtypeKindCategory(IntEnum): - BOOLEAN = 0 - INTEGRAL = 1 - INEXACT = 2 - - -_dtype_kind_char_to_kind_cat = { - "b": DtypeKindCategory.BOOLEAN, - "i": DtypeKindCategory.INTEGRAL, - "u": DtypeKindCategory.INTEGRAL, - "f": DtypeKindCategory.INEXACT, - "c": DtypeKindCategory.INEXACT, -} - - -_py_type_to_kind_cat = { - bool: DtypeKindCategory.BOOLEAN, - int: DtypeKindCategory.INTEGRAL, - float: DtypeKindCategory.INEXACT, - complex: DtypeKindCategory.INEXACT, -} - - -_float_dtype_to_complex: dict[np.dtype[Any], np.dtype[Any]] = { - np.dtype(np.float32): np.dtype(np.complex64), - np.dtype(np.float64): np.dtype(np.complex128), -} - - -def _complexify_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - if dtype.kind == "c": - return dtype - elif dtype.kind == "f": - return _float_dtype_to_complex[dtype] - else: - raise ValueError("can only complexify types that are already inexact") - - -def _np_result_dtype(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - # For numpy 2.0, np.result_type does not implement numpy's type - # promotion behavior. Weird. Hence all this nonsense is needed. - - py_types = [dtype for dtype in dtypes if isinstance(dtype, type)] - - if not py_types: - return np.result_type(*dtypes) - - np_dtypes = [dtype for dtype in dtypes if isinstance(dtype, np.dtype)] - np_kind_cats = { - _dtype_kind_char_to_kind_cat[dtype.kind] for dtype in np_dtypes} - py_kind_cats = {_py_type_to_kind_cat[tp] for tp in py_types} - kind_cats = np_kind_cats | py_kind_cats - - res_kind_cat = max(kind_cats) - max_py_kind_cats = max(py_kind_cats) - max_np_kind_cats = max(np_kind_cats) - - is_complex = (complex in py_types - or any(dtype.kind == "c" for dtype in np_dtypes)) - - if max_py_kind_cats > max_np_kind_cats: - if res_kind_cat == DtypeKindCategory.INTEGRAL: - # FIXME: Perhaps this should be int32 "on some systems, e.g. Windows" - py_promotion_dtype: np.dtype[Any] = np.dtype(np.int64) - elif res_kind_cat == DtypeKindCategory.INEXACT: - if is_complex: - py_promotion_dtype = np.dtype(np.complex128) - else: - py_promotion_dtype = np.dtype(np.float64) - else: - # bool won't ever be promoted to - raise AssertionError() - return np.result_type(*([*np_dtypes, py_promotion_dtype])) - - else: - # Just ignore the python types for promotion. - result = np.result_type(*np_dtypes) - if is_complex: - result = _complexify_dtype(result) - return result +def _np_result_dtype( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, + ) -> np.dtype[Any]: + return np.result_type(*arrays_and_dtypes) def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: @@ -656,7 +578,7 @@ def _binary_op( op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], other: ArrayOrScalar, get_result_type: Callable[ - [DtypeOrPyScalarType, DtypeOrPyScalarType], + [ArrayOrScalar, ArrayOrScalar], np.dtype[Any]] = _np_result_dtype, reverse: bool = False, cast_to_result_dtype: bool = True, diff --git a/pytato/utils.py b/pytato/utils.py index 3da2d3893..f4261685c 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -46,7 +46,6 @@ ArrayOrScalar, BasicIndex, ConvertibleToIndexExpr, - DtypeOrPyScalarType, Einsum, IndexExpr, IndexLambda, @@ -58,7 +57,6 @@ ) from pytato.scalar_expr import ( INT_CLASSES, - PYTHON_SCALAR_CLASSES, SCALAR_CLASSES, BoolT, IntegralScalarExpression, @@ -164,22 +162,6 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, return val[get_indexing_expression(shape, result_shape)] -def _extract_dtypes( - exprs: Sequence[ArrayOrScalar]) -> list[DtypeOrPyScalarType]: - dtypes: list[DtypeOrPyScalarType] = [] - for expr in exprs: - if isinstance(expr, Array): - dtypes.append(expr.dtype) - elif isinstance(expr, np.generic): - dtypes.append(expr.dtype) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): - dtypes.append(type(expr)) - else: - raise TypeError(f"unexpected expression type: '{type(expr)}'") - - return dtypes - - def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, bnd_name: str, bindings: dict[str, Array], @@ -208,7 +190,7 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 - get_result_type: Callable[[DtypeOrPyScalarType, DtypeOrPyScalarType], np.dtype[Any]], # noqa:E501 + get_result_type: Callable[[ArrayOrScalar, ArrayOrScalar], np.dtype[Any]], # noqa:E501 *, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], @@ -222,8 +204,10 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, result_shape = get_shape_after_broadcasting([a1, a2]) - dtypes = _extract_dtypes([a1, a2]) - result_dtype = get_result_type(*dtypes) + # Note: get_result_type calls np.result_type by default, which means + # that we are passing a pytato array to numpy. Luckily, np.result_type + # only looks at the dtype of input arrays as of numpy v2.1. + result_dtype = get_result_type(a1, a2) bindings: dict[str, Array] = {} diff --git a/test/test_codegen.py b/test/test_codegen.py index 5df31063b..5380f0676 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -272,9 +272,6 @@ def wrapper(*args): "logical_or")) @pytest.mark.parametrize("reverse", (False, True)) def test_scalar_array_binary_arith(ctx_factory, which, reverse): - from numpy.lib import NumpyVersion - is_old_numpy = NumpyVersion(np.__version__) < "2.0.0" - cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -319,18 +316,8 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): out = outputs[dtype] out_ref = np_op(x_in, y_orig.astype(dtype)) - if not is_old_numpy: - assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) - - # In some cases ops are done in float32 in loopy but float64 in numpy. - is_allclose = np.allclose(out, out_ref), (out, out_ref) - if not is_old_numpy: - assert is_allclose - else: - if out_ref.dtype.itemsize == 1: - pass - else: - assert is_allclose + assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) + assert np.allclose(out, out_ref), (out, out_ref) @pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv", "pow", @@ -389,7 +376,6 @@ def test_array_array_binary_arith(ctx_factory, which, reverse): out_ref = np_op(x_in, y_orig.astype(dtype)) assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) - # In some cases ops are done in float32 in loopy but float64 in numpy. assert np.allclose(out, out_ref), (out, out_ref) diff --git a/test/test_pytato.py b/test/test_pytato.py index f67e7e5f1..271c8fb01 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1364,6 +1364,21 @@ def test_dot_visualizers(): # }}} +def test_numpy_type_promotion_with_pytato_arrays(): + class NotReallyAnArray: + @property + def dtype(self): + return np.dtype("float64") + + # Make sure that np.result_type accesses only the dtype attribute of the + # class, not (e.g.) its data. + assert np.result_type(42, NotReallyAnArray()) == np.float64 + + from pytato.array import _np_result_dtype + assert _np_result_dtype(42, NotReallyAnArray()) == np.float64 + assert _np_result_dtype(42.0, NotReallyAnArray()) == np.float64 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])