Skip to content

Commit

Permalink
Merge pull request #29 from firedrakeproject/ksagiyam/merge_upstream
Browse files Browse the repository at this point in the history
Ksagiyam/merge upstream
  • Loading branch information
connorjward authored Nov 11, 2024
2 parents 6f02d31 + 422c0e9 commit b2c485f
Show file tree
Hide file tree
Showing 28 changed files with 310 additions and 444 deletions.
11 changes: 11 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,15 @@
# As of 2022-06-22, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],

# Reference not found from "<unknown>"? I'm not even sure where to look.
["py:class", r"Expression"],
]

autodoc_type_aliases = {
"ToLoopyTypeConvertible": "ToLoopyTypeConvertible",
"ExpressionT": "ExpressionT",
"InameStr": "InameStr",
"ShapeType": "ShapeType",
"StridesType": "StridesType",
}
5 changes: 4 additions & 1 deletion loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
)
from loopy.translation_unit import TranslationUnit, for_each_kernel, make_program
from loopy.type_inference import infer_unknown_types
from loopy.types import to_loopy_type
from loopy.types import LoopyType, NumpyType, ToLoopyTypeConvertible, to_loopy_type
from loopy.typing import auto
from loopy.version import MOST_RECENT_LANGUAGE_VERSION, VERSION

Expand Down Expand Up @@ -248,12 +248,14 @@
"LinearSubscript",
"LoopKernel",
"LoopyError",
"LoopyType",
"LoopyWarning",
"MemAccess",
"MemoryOrdering",
"MemoryScope",
"MultiAssignmentBase",
"NoOpInstruction",
"NumpyType",
"Op",
"OpenCLTarget",
"Optional",
Expand All @@ -270,6 +272,7 @@
"TemporaryVariable",
"ToCountMap",
"ToCountPolynomialMap",
"ToLoopyTypeConvertible",
"TranslationUnit",
"TypeCast",
"UniqueName",
Expand Down
8 changes: 7 additions & 1 deletion loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic.primitives import Variable, is_arithmetic_expression
from pytools import memoize_method

from loopy.diagnostic import (
Expand Down Expand Up @@ -1669,6 +1670,8 @@ def _are_sub_array_refs_equivalent(
if len(sar1.swept_inames) != len(sar2.swept_inames):
return False

assert isinstance(sar1.subscript.aggregate, Variable)
assert isinstance(sar2.subscript.aggregate, Variable)
if sar1.subscript.aggregate.name != sar2.subscript.aggregate.name:
return False

Expand All @@ -1692,7 +1695,10 @@ def _are_sub_array_refs_equivalent(

for idx1, idx2 in zip(sar1.subscript.index_tuple,
sar2.subscript.index_tuple):
if simplify_via_aff(subst_mapper(idx1) - idx2) != 0:
subst_idx1 = subst_mapper(idx1)
assert is_arithmetic_expression(subst_idx1)
assert is_arithmetic_expression(idx2)
if simplify_via_aff(subst_idx1 - idx2) != 0:
return False
return True

Expand Down
16 changes: 15 additions & 1 deletion loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from immutables import Map

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
from loopy.translation_unit import CallablesTable, TranslationUnit


Expand Down Expand Up @@ -86,6 +87,12 @@
.. automodule:: loopy.codegen.result
.. automodule:: loopy.codegen.tools
References
^^^^^^^^^^
.. class:: Expression
See :class:`pymbolic.Expression`.
"""


Expand Down Expand Up @@ -661,8 +668,15 @@ def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
ast=t_unit.target.get_device_ast_builder().ast_module.Collection(
callee_fdecls+[device_programs[0].ast]))] +
device_programs[1:])

def not_reduction_op(name: str | ReductionOpFunction) -> str:
assert isinstance(name, str)
return name

cgr = TranslationUnitCodeGenerationResult(
host_programs=host_programs,
host_programs={
not_reduction_op(name): prg
for name, prg in host_programs.items()},
device_programs=device_programs,
device_preambles=device_preambles)

Expand Down
3 changes: 2 additions & 1 deletion loopy/frontend/fortran/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class FortranExpressionParser(ExpressionParserBase):
(_not, pytools.lex.RE(r"\.not\.", re.I)),
(_and, pytools.lex.RE(r"\.and\.", re.I)),
(_or, pytools.lex.RE(r"\.or\.", re.I)),
] + ExpressionParserBase.lex_table
*ExpressionParserBase.lex_table,
]

def __init__(self, tree_walker):
self.tree_walker = tree_walker
Expand Down
5 changes: 4 additions & 1 deletion loopy/kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic import ArithmeticExpressionT
from pytools import (
UniqueNameGenerator,
generate_unique_names,
Expand Down Expand Up @@ -1042,7 +1043,9 @@ def get_grid_size_upper_bounds(self, callables_table, ignore_auto=False,
def get_grid_size_upper_bounds_as_exprs(
self, callables_table,
ignore_auto=False, return_dict=False
) -> Tuple[Tuple[ExpressionT, ...], Tuple[ExpressionT, ...]]:
) -> Tuple[
Tuple[ArithmeticExpressionT, ...],
Tuple[ArithmeticExpressionT, ...]]:
"""Return a tuple (global_size, local_size) containing a grid that
could accommodate execution of *all* instructions in the kernel.
Expand Down
55 changes: 32 additions & 23 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from loopy.symbolic import flatten


__copyright__ = "Copyright (C) 2012 Andreas Kloeckner"

Expand Down Expand Up @@ -47,13 +45,15 @@
import numpy as np # noqa
from typing_extensions import TypeAlias

from pymbolic import ArithmeticExpressionT
from pymbolic.primitives import is_arithmetic_expression
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable

from loopy.diagnostic import LoopyError
from loopy.tools import is_integer
from loopy.symbolic import flatten
from loopy.types import LoopyType
from loopy.typing import ExpressionT, ShapeType, auto
from loopy.typing import ExpressionT, ShapeType, auto, is_integer


if TYPE_CHECKING:
Expand Down Expand Up @@ -625,17 +625,35 @@ def _parse_shape_or_strides(
if x is auto:
return auto

if isinstance(x, str):
x = parse(x)
if not isinstance(x, str):
x_parsed = x
else:
x_parsed = parse(x)

if isinstance(x, list):
if isinstance(x_parsed, list):
raise ValueError("shape can't be a list")

if not isinstance(x, tuple):
assert x is not auto
x = (x,)
if isinstance(x_parsed, tuple):
x_tup: tuple[ExpressionT | str, ...] = x_parsed
else:
assert x_parsed is not auto
x_tup = (cast(ExpressionT, x_parsed),)

def parse_arith(x: ExpressionT | str) -> ArithmeticExpressionT:
if isinstance(x, str):
res = parse(x)
else:
res = x

# The Fortran parser may do this, but this is (deliberately) outside
# the behavior allowed by types, because the hope is to phase it out.
if x is None:
return x

assert is_arithmetic_expression(res)
return res

return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x)
return tuple(parse_arith(xi) for xi in x_tup)


class ArrayBase(ImmutableRecord, Taggable):
Expand Down Expand Up @@ -1026,16 +1044,6 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.__str__()

def update_persistent_hash_for_shape(self, key_hash, key_builder, shape):
if isinstance(shape, tuple):
for shape_i in shape:
if shape_i is None:
key_builder.rec(key_hash, shape_i)
else:
key_builder.update_for_pymbolic_expression(key_hash, shape_i)
else:
key_builder.rec(key_hash, shape)

def update_persistent_hash(self, key_hash, key_builder):
"""Custom hash computation function for use with
:class:`pytools.persistent_dict.PersistentDict`.
Expand All @@ -1044,7 +1052,7 @@ def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.dtype)
self.update_persistent_hash_for_shape(key_hash, key_builder, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.dim_tags)
key_builder.rec(key_hash, self.offset)
key_builder.rec(key_hash, self.dim_names)
Expand Down Expand Up @@ -1232,11 +1240,12 @@ def get_access_info(kernel: "LoopKernel",

import loopy as lp

def eval_expr_assert_integer_constant(i, expr):
def eval_expr_assert_integer_constant(i, expr) -> int:
from pymbolic.mapper.evaluator import UnknownVariableError
try:
result = eval_expr(expr)
except UnknownVariableError as e:
assert ary.dim_tags is not None
raise LoopyError("When trying to index the array '%s' along axis "
"%d (tagged '%s'), the index was not a compile-time "
"constant (but it has to be in order for code to be "
Expand Down
15 changes: 11 additions & 4 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import numpy as np
from immutables import Map

from pymbolic import ArithmeticExpressionT
from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable, UniqueTag as UniqueTagBase

Expand Down Expand Up @@ -87,6 +88,13 @@
.. autoclass:: UnrollTag
.. autoclass:: Iname
References
^^^^^^^^^^
.. class:: ToLoopyTypeConvertible
See :class:`loopy.ToLoopyTypeConvertible`.
"""

# This docstring is included in ref_internals. Do not include parts of the public
Expand Down Expand Up @@ -809,7 +817,7 @@ def nbytes(self) -> ExpressionT:
raise ValueError("shape is None")
if self.shape is auto:
raise ValueError("shape is auto")
shape = cast(Tuple[ExpressionT], self.shape)
shape = cast(Tuple[ArithmeticExpressionT], self.shape)

if self.dtype is None:
raise ValueError("data type is indeterminate")
Expand Down Expand Up @@ -853,8 +861,7 @@ def update_persistent_hash(self, key_hash, key_builder):
"""

super().update_persistent_hash(key_hash, key_builder)
self.update_persistent_hash_for_shape(key_hash, key_builder,
self.storage_shape)
key_builder.rec(key_hash, self.storage_shape)
key_builder.rec(key_hash, self.base_indices)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.base_storage)
Expand Down Expand Up @@ -899,7 +906,7 @@ def copy(self, **kwargs: Any) -> SubstitutionRule:
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.arguments)
key_builder.update_for_pymbolic_expression(key_hash, self.expression)
key_builder.rec(key_hash, self.expression)


# }}}
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def depends_on(self):
return frozenset(var.name for var in result)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.dim_tags)

Expand Down
Loading

0 comments on commit b2c485f

Please sign in to comment.