Skip to content

Commit

Permalink
Enable UP rules
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 23, 2024
1 parent 21baa7a commit 550920a
Show file tree
Hide file tree
Showing 38 changed files with 547 additions and 610 deletions.
2 changes: 1 addition & 1 deletion examples/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
memoized = functools.lru_cache(maxsize=None)


class AdvectionOperator(object):
class AdvectionOperator:
"""A class representing a DG advection operator."""

def __init__(self, discr, c, flux_type, dg_ops):
Expand Down
4 changes: 2 additions & 2 deletions examples/dg_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def ortholegval(x, c):
return leg.legval(x, c * factors)


class DGDiscr1D(object):
class DGDiscr1D:
"""A one-dimensional Discontinuous Galerkin discretization."""

def __init__(self, left, right, nelements, nnodes):
Expand Down Expand Up @@ -262,7 +262,7 @@ def elementwise(mat, vec):
return np.einsum("ij,kj->ki", mat, vec)


class AbstractDGOps1D(object):
class AbstractDGOps1D:
def __init__(self, discr):
self.discr = discr

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ extend-select = [
"W", # pycodestyle
"NPY", # numpy
"RUF",
"UP",
]
extend-ignore = [
"E226",
Expand Down
40 changes: 20 additions & 20 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Mapping, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Mapping

from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method
Expand Down Expand Up @@ -84,8 +84,8 @@ class NUserCollector(Mapper):
def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self._visited_ids: Set[int] = set()
self.nusers: Dict[Array, int] = defaultdict(lambda: 0)
self._visited_ids: set[int] = set()
self.nusers: dict[Array, int] = defaultdict(lambda: 0)

# type-ignore reason: NUserCollector.rec's type does not match
# Mapper.rec's type
Expand Down Expand Up @@ -198,7 +198,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> None:
# }}}


def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
def get_nusers(outputs: Array | DictOfNamedArrays) -> Mapping[Array, int]:
"""
For the DAG *outputs*, returns the mapping from each node to the number of
nodes using its value within the DAG given by *outputs*.
Expand All @@ -214,7 +214,7 @@ def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:

def _get_indices_from_input_subscript(subscript: str,
is_output: bool,
) -> Tuple[str, ...]:
) -> tuple[str, ...]:
from pytato.array import EINSUM_FIRST_INDEX

acc = subscript.strip()
Expand Down Expand Up @@ -273,7 +273,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
in_spec, out_spec = subscripts.split("->")

# build up a mapping from index names to axis descriptors
index_to_descrs: Dict[str, EinsumAxisDescriptor] = {}
index_to_descrs: dict[str, EinsumAxisDescriptor] = {}

for idim, idx in enumerate(_get_indices_from_input_subscript(out_spec,
is_output=True)):
Expand Down Expand Up @@ -321,26 +321,26 @@ class DirectPredecessorsGetter(Mapper):
We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> FrozenSet[Array]:
def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]:
return frozenset({dim for dim in shape if isinstance(dim, Array)})

def map_index_lambda(self, expr: IndexLambda) -> FrozenSet[Array]:
def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]:
return (frozenset(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))

def map_stack(self, expr: Stack) -> FrozenSet[Array]:
def map_stack(self, expr: Stack) -> frozenset[Array]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_concatenate(self, expr: Concatenate) -> FrozenSet[Array]:
def map_concatenate(self, expr: Concatenate) -> frozenset[Array]:
return (frozenset(expr.arrays)
| self._get_preds_from_shape(expr.shape))

def map_einsum(self, expr: Einsum) -> FrozenSet[Array]:
def map_einsum(self, expr: Einsum) -> frozenset[Array]:
return (frozenset(expr.args)
| self._get_preds_from_shape(expr.shape))

def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]:
def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
Expand All @@ -349,7 +349,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> FrozenSet[Array]:
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))

def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]:
def _map_index_base(self, expr: IndexBase) -> frozenset[Array]:
return (frozenset([expr.array])
| frozenset(idx for idx in expr.indices
if isinstance(idx, Array))
Expand All @@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]:
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> FrozenSet[Array]:
) -> frozenset[Array]:
return frozenset([expr.array])

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> FrozenSet[Array]:
def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> FrozenSet[Array]:
def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> FrozenSet[Array]:
) -> frozenset[Array]:
return frozenset([expr.passthrough_data])

def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]:
def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]:
raise NotImplementedError(
"DirectPredecessorsGetter does not yet support expressions containing "
"functions.")
Expand Down Expand Up @@ -414,7 +414,7 @@ def post_visit(self, expr: Any) -> None:
self.count += 1


def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
def get_num_nodes(outputs: Array | DictOfNamedArrays) -> int:
"""Returns the number of nodes in DAG *outputs*."""

from pytato.codegen import normalize_outputs
Expand Down Expand Up @@ -465,7 +465,7 @@ def post_visit(self, expr: Any) -> None:
self.count += 1


def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int:
"""Returns the number of nodes in DAG *outputs*."""

from pytato.codegen import normalize_outputs
Expand Down
Loading

0 comments on commit 550920a

Please sign in to comment.