Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Dec 2, 2024
2 parents b7dbab8 + d8e8683 commit ee2f1d1
Show file tree
Hide file tree
Showing 28 changed files with 512 additions and 274 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
schedule:
- cron: '17 3 * * 0'

concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true

jobs:
typos:
name: Typos
Expand Down
11 changes: 11 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
An array context is an abstraction that helps you dispatch between multiple
implementations of :mod:`numpy`-like :math:`n`-dimensional arrays.
"""
from __future__ import annotations


__copyright__ = """
Expand Down Expand Up @@ -29,6 +30,7 @@
"""

from .container import (
ArithArrayContainer,
ArrayContainer,
ArrayContainerT,
NotAnArrayContainerError,
Expand Down Expand Up @@ -72,6 +74,10 @@
from .context import (
Array,
ArrayContext,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
ArrayOrArithContainerT,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
Expand All @@ -96,10 +102,15 @@


__all__ = (
"ArithArrayContainer",
"Array",
"ArrayContainer",
"ArrayContainerT",
"ArrayContext",
"ArrayOrArithContainer",
"ArrayOrArithContainerOrScalar",
"ArrayOrArithContainerOrScalarT",
"ArrayOrArithContainerT",
"ArrayOrContainer",
"ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
Expand Down
31 changes: 28 additions & 3 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.. currentmodule:: arraycontext
.. autoclass:: ArrayContainer
.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT
A type variable with a lower bound of :class:`ArrayContainer`.
Expand Down Expand Up @@ -81,14 +82,15 @@

from collections.abc import Hashable, Sequence
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar

# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np
from typing_extensions import Self

from arraycontext.context import ArrayContext
from arraycontext.context import ArrayContext, ArrayOrScalar


if TYPE_CHECKING:
Expand Down Expand Up @@ -145,6 +147,29 @@ class ArrayContainer(Protocol):
# that are container-typed.


class ArithArrayContainer(ArrayContainer, Protocol):
"""
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
"""

# This is loose and permissive, assuming that any array can be added
# to any container. The alternative would be to plaster type-ignores
# on all those uses. Achieving typing precision on what broadcasting is
# allowable seems like a huge endeavor and is likely not feasible without
# a mypy plugin. Maybe some day? -AK, November 2024

def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...
def __add__(self, other: ArrayOrScalar | Self) -> Self: ...
def __radd__(self, other: ArrayOrScalar | Self) -> Self: ...
def __sub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ...
def __mul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ...
def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ...


ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)


Expand Down Expand Up @@ -219,7 +244,7 @@ def is_array_container_type(cls: type) -> bool:
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]


def is_array_container(ary: Any) -> bool:
def is_array_container(ary: object) -> bool:
"""
:returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`.
Expand Down
56 changes: 41 additions & 15 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
.. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container
"""
from __future__ import annotations


__copyright__ = """
Expand All @@ -30,6 +31,7 @@
THE SOFTWARE.
"""

from collections.abc import Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from typing import Union, get_args, get_origin

Expand Down Expand Up @@ -57,13 +59,21 @@ def dataclass_array_container(cls: type) -> type:
* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
.. note::
When type annotations are strings (e.g. because of
``from __future__ import annotations``),
this function relies on :func:`inspect.get_annotations`
(with ``eval_str=True``) to obtain type annotations. This
means that *cls* must live in a module that is importable.
"""

from types import GenericAlias, UnionType

assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
def is_array_field(f: Field, field_type: type) -> bool:
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
Expand All @@ -76,17 +86,17 @@ def is_array_field(f: Field) -> bool:
#
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
origin = get_origin(field_type)
# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(f.type)):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")

if isinstance(f.type, str):
if isinstance(field_type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported. "
"(this may be due to 'from __future__ import annotations')")
Expand All @@ -104,33 +114,49 @@ def is_array_field(f: Field) -> bool:
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Typing annotation not supported on field '{f.name}': "
f"'{f.type!r}'")
f"'{field_type!r}'")

if not isinstance(f.type, type):
if not isinstance(field_type, type):
raise TypeError(
f"Field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")
f"'{field_type!r}'")

return is_array_type(field_type)

from inspect import get_annotations

return is_array_type(f.type)
array_fields: list[Field] = []
non_array_fields: list[Field] = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
field_type = cls_ann[field.name]
else:
field_type = field_type_or_str

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
if is_array_field(field, field_type):
array_fields.append(field)
else:
non_array_fields.append(field)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return inject_dataclass_serialization(cls, array_fields, non_array_fields)
return _inject_dataclass_serialization(cls, array_fields, non_array_fields)


def inject_dataclass_serialization(
def _inject_dataclass_serialization(
cls: type,
array_fields: tuple[Field, ...],
non_array_fields: tuple[Field, ...]) -> type:
array_fields: Sequence[Field],
non_array_fields: Sequence[Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
Expand Down
Loading

0 comments on commit ee2f1d1

Please sign in to comment.