Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summarize changes to support prediction #1

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .impl.numpy import NumpyArrayContext
from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
Expand Down
27 changes: 25 additions & 2 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,33 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU
"to create this kernel?")

all_inames = default_entrypoint.all_inames()

# FIXME: This could be much smarter.
inner_iname = None

if "i0" in all_inames:
# import with underscore to avoid DeprecationWarning
# from arraycontext.metadata import _FirstAxisIsElementsTag
from meshmode.transform_metadata import FirstAxisIsElementsTag

if (len(default_entrypoint.instructions) == 1
and isinstance(default_entrypoint.instructions[0], lp.Assignment)
and any(isinstance(tag, FirstAxisIsElementsTag)
# FIXME: Firedrake branch lacks kernel tags
for tag in getattr(default_entrypoint, "tags", ()))):
stmt, = default_entrypoint.instructions

out_inames = [v.name for v in stmt.assignee.index_tuple]
assert out_inames
outer_iname = out_inames[0]
if len(out_inames) >= 2:
inner_iname = out_inames[1]

elif "iel" in all_inames:
outer_iname = "iel"

if "idof" in all_inames:
inner_iname = "idof"

elif "i0" in all_inames:
outer_iname = "i0"

if "i1" in all_inames:
Expand Down
10 changes: 6 additions & 4 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def zeros(self, shape, dtype):

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(
array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)
# return self._array_context.zeros(
# array.shape, array.dtype).copy(axes=array.axes, tags=array.tags)
return 0*array

return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)
Expand All @@ -104,8 +105,9 @@ def ones_like(self, ary):

def full_like(self, ary, fill_value):
def _full_like(subary):
return pt.full(subary.shape, fill_value, subary.dtype).copy(
axes=subary.axes, tags=subary.tags)
# return pt.full(subary.shape, fill_value, subary.dtype).copy(
# axes=subary.axes, tags=subary.tags)
return fill_value * (0*subary + 1)

return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)
Expand Down
23 changes: 22 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
from arraycontext import NumpyArrayContext
from arraycontext.context import ArrayContext


# {{{ array context factories


class PytestArrayContextFactory:
@classmethod
def is_available(cls) -> bool:
Expand Down Expand Up @@ -226,6 +226,27 @@ def __call__(self):
def __str__(self):
return "<PytatoJAXArrayContext>"

# {{{ _PytestArrayContextFactory


class _NumpyArrayContextForTests(NumpyArrayContext):
def transform_loopy_program(self, t_unit):
return t_unit


class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
def __init__(self, *args, **kwargs):
super().__init__()

def __call__(self):
return _NumpyArrayContextForTests()

def __str__(self):
return "<NumpyArrayContext>"

# }}}



# {{{ _PytestArrayContextFactory

Expand Down
10 changes: 8 additions & 2 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14):
with pytest.raises(TypeError):
ary_of_dofs + dc_of_dofs

with pytest.raises(TypeError):
dc_of_dofs + ary_of_dofs
if not isinstance(actx, NumpyArrayContext):
with pytest.raises(TypeError):
dc_of_dofs + ary_of_dofs

with pytest.raises(TypeError):
ary_dof + dc_of_dofs
Expand Down Expand Up @@ -1014,7 +1015,12 @@ def test_flatten_with_leaf_class(actx_factory):
# {{{ test from_numpy and to_numpy

def test_numpy_conversion(actx_factory):
from arraycontext import NumpyArrayContext

actx = actx_factory()
if isinstance(actx, NumpyArrayContext):
pytest.skip("Irrelevant tests for NumpyArrayContext")

rng = np.random.default_rng()

nelements = 42
Expand Down
Loading