Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Dec 17, 2024
2 parents ee2f1d1 + c4f00b8 commit f66d33f
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 42 deletions.
3 changes: 2 additions & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def _deserialize_ndarray_container( # type: ignore[misc]

result = type(template)(template.shape, dtype=object)
for i, subary in serialized:
result[i] = subary
# FIXME: numpy annotations don't seem to handle object arrays very well
result[i] = subary # type: ignore[call-overload]

return result

Expand Down
44 changes: 22 additions & 22 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ def _map_array_container_impl(
specific container classes. By default, the recursion is stopped when
a non-:class:`ArrayContainer` class is encountered.
"""
def rec(_ary: ArrayOrContainer) -> ArrayOrContainer:
if type(_ary) is leaf_cls: # type(ary) is never None
return f(_ary)
def rec(ary_: ArrayOrContainer) -> ArrayOrContainer:
if type(ary_) is leaf_cls: # type(ary) is never None
return f(ary_)

try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return f(_ary)
return f(ary_)
else:
return deserialize_container(_ary, [
return deserialize_container(ary_, [
(key, frec(subary)) for key, subary in iterable
])

Expand All @@ -144,28 +144,28 @@ def _multimap_array_container_impl(

# {{{ recursive traversal

def rec(*_args: Any) -> Any:
template_ary = _args[container_indices[0]]
def rec(*args_: Any) -> Any:
template_ary = args_[container_indices[0]]
if type(template_ary) is leaf_cls:
return f(*_args)
return f(*args_)

try:
iterable_template = serialize_container(template_ary)
except NotAnArrayContainerError:
return f(*_args)
return f(*args_)
else:
pass

assert all(
type(_args[i]) is type(template_ary) for i in container_indices[1:]
type(args_[i]) is type(template_ary) for i in container_indices[1:]
), f"expected type '{type(template_ary).__name__}'"

result = []
new_args = list(_args)
new_args = list(args_)

for subarys in zip(
iterable_template,
*[serialize_container(_args[i]) for i in container_indices[1:]],
*[serialize_container(args_[i]) for i in container_indices[1:]],
strict=True
):
key = None
Expand Down Expand Up @@ -415,13 +415,13 @@ def rec_keyed_map_array_container(
"""

def rec(keys: tuple[SerializationKey, ...],
_ary: ArrayOrContainerT) -> ArrayOrContainerT:
ary_: ArrayOrContainerT) -> ArrayOrContainerT:
try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, _ary)))
return cast(ArrayOrContainerT, f(keys, cast(ArrayT, ary_)))
else:
return deserialize_container(_ary, [
return deserialize_container(ary_, [
(key, rec((*keys, key), subary)) for key, subary in iterable
])

Expand Down Expand Up @@ -522,14 +522,14 @@ def rec_map_reduce_array_container(
or any other such traversal.
"""
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
if type(_ary) is leaf_class:
return map_func(_ary)
def rec(ary_: ArrayOrContainerT) -> ArrayOrContainerT:
if type(ary_) is leaf_class:
return map_func(ary_)
else:
try:
iterable = serialize_container(_ary)
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return map_func(_ary)
return map_func(ary_)
else:
return reduce_func([
rec(subary) for _, subary in iterable
Expand Down
32 changes: 17 additions & 15 deletions arraycontext/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def get_default_entrypoint(t_unit):
def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
@memoize_in(actx, _get_scalar_func_loopy_program)
def get(c_name, nargs, naxes):
from pymbolic import var
from pymbolic.primitives import Subscript, Variable

var_names = [f"i{i}" for i in range(naxes)]
size_names = [f"n{i}" for i in range(naxes)]
subscript = tuple(var(vname) for vname in var_names)
subscript = tuple(Variable(vname) for vname in var_names)

from islpy import make_zero_and_vars

v = make_zero_and_vars(var_names, params=size_names)
domain = v[0].domain()
for vname, sname in zip(var_names, size_names, strict=True):
Expand All @@ -98,22 +100,22 @@ def get(c_name, nargs, naxes):

import loopy as lp

from .loopy import make_loopy_program
from arraycontext.transform_metadata import ElementwiseMapKernelTag

def sub(name: str) -> Variable | Subscript:
return Subscript(Variable(name), subscript) if subscript else Variable(name)

return make_loopy_program(
[domain_bset],
[
[domain_bset], [
lp.Assignment(
var("out")[subscript],
var(c_name)(*[
var(f"inp{i}")[subscript] for i in range(nargs)]))
],
[
lp.GlobalArg("out",
dtype=None, shape=lp.auto, offset=lp.auto)] + [
lp.GlobalArg(f"inp{i}",
dtype=None, shape=lp.auto, offset=lp.auto)
for i in range(nargs)] + [...],
sub("out"),
Variable(c_name)(*[sub(f"inp{i}") for i in range(nargs)]))
], [
lp.GlobalArg("out", dtype=None, shape=lp.auto, offset=lp.auto)
] + [
lp.GlobalArg(f"inp{i}", dtype=None, shape=lp.auto, offset=lp.auto)
for i in range(nargs)
] + [...],
name=f"actx_special_{c_name}",
tags=(ElementwiseMapKernelTag(),))

Expand Down
8 changes: 4 additions & 4 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
"atan2": "arctan2",
}

def evaluate(_np, *_args):
func = getattr(_np, sym_name,
getattr(_np, c_to_numpy_arc_functions.get(sym_name, sym_name)))
def evaluate(np_, *args_):
func = getattr(np_, sym_name,
getattr(np_, c_to_numpy_arc_functions.get(sym_name, sym_name)))

return func(*_args)
return func(*args_)

assert_close_to_numpy_in_containers(actx, evaluate, args)

Expand Down

0 comments on commit f66d33f

Please sign in to comment.