Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy #369

Merged
merged 3 commits into from
Dec 2, 2024
Merged

Mypy #369

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ jobs:
pipx install ruff
ruff check

mypy:
name: Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: "Main Script"
run: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_conda_env
python -m pip install mypy
./run-mypy.sh

typos:
name: Typos
runs-on: ubuntu-latest
Expand Down
15 changes: 14 additions & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Documentation:
tags:
- python3

Flake8:
Ruff:
script:
- pipx install ruff
- ruff check
Expand All @@ -108,6 +108,19 @@ Flake8:
except:
- tags

Mypy:
script: |
EXTRA_INSTALL="Cython mpi4py"
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
build_py_project_in_venv
python -m pip install mypy
./run-mypy.sh
tags:
- python3
except:
- tags

Pylint:
script: |
EXTRA_INSTALL="pybind11 make numpy scipy matplotlib mpi4py"
Expand Down
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@

# index-page demo uses pyopencl via plot_directive
os.environ["PYOPENCL_TEST"] = "port:cpu"

nitpick_ignore_regex = [
["py:class", r"np\.ndarray"],
["py:data|py:class", r"arraycontext.*ContainerTc"],
]
5 changes: 0 additions & 5 deletions doc/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,3 @@ AK also gratefully acknowledges a hardware gift from Nvidia Corporation.

The views and opinions expressed herein do not necessarily reflect those of the
funding agencies.

Deprecated functionality
========================

.. automodule:: grudge.eager
163 changes: 61 additions & 102 deletions grudge/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
.. autofunction:: get_reasonable_array_context_class
"""

from __future__ import annotations


__copyright__ = "Copyright (C) 2020 Andreas Kloeckner"

__license__ = """
Expand Down Expand Up @@ -35,9 +38,11 @@
import logging
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from warnings import warn

from typing_extensions import Self

from meshmode.array_context import (
PyOpenCLArrayContext as _PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase,
Expand All @@ -48,28 +53,6 @@

logger = logging.getLogger(__name__)

try:
# FIXME: temporary workaround while SingleGridWorkBalancingPytatoArrayContext
# is not available in meshmode's main branch
# (it currently needs
# https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms)
from meshmode.array_context import SingleGridWorkBalancingPytatoArrayContext
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthiasdiener This is no longer relevant, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, yes. We have not been using the SingleGridWorkBalancingPytatoArrayContext for quite a while.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!


try:
# Crude check if we have the correct loopy branch
# (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms)
from loopy.codegen.result import get_idis_for_kernel # noqa
except ImportError:
# warn("Your loopy and meshmode branches are mismatched. "
# "Please make sure that you have the "
# "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms " # noqa
# "branch of loopy.")
_HAVE_SINGLE_GRID_WORK_BALANCING = False
else:
_HAVE_SINGLE_GRID_WORK_BALANCING = True

except ImportError:
_HAVE_SINGLE_GRID_WORK_BALANCING = False

try:
# FIXME: temporary workaround while FusionContractorArrayContext
Expand Down Expand Up @@ -119,8 +102,8 @@ class PyOpenCLArrayContext(_PyOpenCLArrayContextBase):
to understand :mod:`grudge`-specific transform metadata. (Of which there isn't
any, for now.)
"""
def __init__(self, queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
def __init__(self, queue: pyopencl.CommandQueue,
allocator: pyopencl.tools.AllocatorBase | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:

Expand Down Expand Up @@ -165,8 +148,8 @@ def __init__(self, queue, allocator=None,
# }}}


class MPIBasedArrayContext:
mpi_communicator: "MPI.Comm"
class MPIBasedArrayContext(ArrayContext):
mpi_communicator: MPI.Intracomm


# {{{ distributed + pytato
Expand Down Expand Up @@ -345,13 +328,13 @@ class _DistributedCompiledFunction:
type of the callable.
"""

actx: "MPISingleGridWorkBalancingPytatoArrayContext"
distributed_partition: "DistributedGraphPartition"
part_id_to_prg: "Mapping[PartId, pt.target.BoundProgram]"
actx: MPIBasedArrayContext
distributed_partition: DistributedGraphPartition
part_id_to_prg: Mapping[PartId, pt.target.BoundProgram]
input_id_to_name_in_program: Mapping[tuple[Any, ...], str]
output_id_to_name_in_program: Mapping[tuple[Any, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple["pt.Axis", ...]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer

def __call__(self, arg_id_to_arg) -> ArrayContainer:
Expand All @@ -368,10 +351,11 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

from pytato import execute_distributed_partition
assert isinstance(self.actx, PytatoPyOpenCLArrayContext | PyOpenCLArrayContext)
out_dict = execute_distributed_partition(
self.distributed_partition, self.part_id_to_prg,
self.actx.queue, self.actx.mpi_communicator,
allocator=self.actx.allocator,
self.actx.queue, self.actx.mpi_communicator, # pylint: disable=no-member
allocator=self.actx.allocator, # pylint: disable=no-member
input_args=input_args_for_prg)

def to_output_template(keys, _):
Expand All @@ -387,42 +371,6 @@ def to_output_template(keys, _):
self.output_template)


class MPIPytatoArrayContextBase(MPIBasedArrayContext):
def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is None:
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)", stacklevel=2)

super().__init__(queue, allocator,
compile_trace_callback=compile_trace_callback)

self.mpi_communicator = mpi_communicator
self.mpi_base_tag = mpi_base_tag

# FIXME: implement distributed-aware freeze

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f)

def clone(self):
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
mpi_base_tag=self.mpi_base_tag,
allocator=self.allocator)

# }}}


Expand All @@ -437,8 +385,8 @@ class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):

def __init__(self,
mpi_communicator,
queue: "pyopencl.CommandQueue",
*, allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
queue: pyopencl.CommandQueue,
*, allocator: pyopencl.tools.AllocatorBase | None = None,
wait_event_queue_length: int | None = None,
force_device_scalars: bool = True) -> None:
"""
Expand All @@ -451,7 +399,7 @@ def __init__(self,

self.mpi_communicator = mpi_communicator

def clone(self):
def clone(self) -> Self:
# type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member
# pylint: disable=no-member
return type(self)(self.mpi_communicator, self.queue,
Expand All @@ -476,7 +424,7 @@ def __init__(self, mpi_communicator) -> None:

self.mpi_communicator = mpi_communicator

def clone(self):
def clone(self) -> Self:
return type(self)(self.mpi_communicator)

# }}}
Expand All @@ -485,28 +433,50 @@ def clone(self):
# {{{ distributed + pytato array context subclasses

class MPIBasePytatoPyOpenCLArrayContext(
MPIPytatoArrayContextBase, PytatoPyOpenCLArrayContext):
MPIBasedArrayContext, PytatoPyOpenCLArrayContext):
"""
.. autofunction:: __init__
"""
pass


if _HAVE_SINGLE_GRID_WORK_BALANCING:
class MPISingleGridWorkBalancingPytatoArrayContext(
MPIPytatoArrayContextBase, SingleGridWorkBalancingPytatoArrayContext):
def __init__(
self, mpi_communicator, queue, *, mpi_base_tag, allocator=None,
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
) -> None:
"""
.. autofunction:: __init__
:arg compile_trace_callback: A function of three arguments
*(what, stage, ir)*, where *what* identifies the object
being compiled, *stage* is a string describing the compilation
pass, and *ir* is an object containing the intermediate
representation. This interface should be considered
unstable.
"""
if allocator is None:
warn("No memory allocator specified, please pass one. "
"(Preferably a pyopencl.tools.MemoryPool in order "
"to reduce device allocations)", stacklevel=2)

MPIPytatoArrayContext = MPISingleGridWorkBalancingPytatoArrayContext
else:
MPIPytatoArrayContext = MPIBasePytatoPyOpenCLArrayContext
super().__init__(queue, allocator,
compile_trace_callback=compile_trace_callback)

self.mpi_communicator = mpi_communicator
self.mpi_base_tag = mpi_base_tag

# FIXME: implement distributed-aware freeze

def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f)

def clone(self) -> Self:
return type(self)(self.mpi_communicator, self.queue,
mpi_base_tag=self.mpi_base_tag,
allocator=self.allocator)


MPIPytatoArrayContext: type[MPIBasedArrayContext] = MPIBasePytatoPyOpenCLArrayContext


if _HAVE_FUSION_ACTX:
class MPIFusionContractorArrayContext(
MPIPytatoArrayContextBase, FusionContractorArrayContext):
MPIBasePytatoPyOpenCLArrayContext, FusionContractorArrayContext):
"""
.. autofunction:: __init__
"""
Expand Down Expand Up @@ -570,25 +540,14 @@ def __call__(self):


def _get_single_grid_pytato_actx_class(distributed: bool) -> type[ArrayContext]:
if not _HAVE_SINGLE_GRID_WORK_BALANCING:
warn("No device-parallel actx available, execution will be slow. "
"Please make sure you have the right branches for loopy "
"(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa
"and meshmode "
"(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).",
stacklevel=1)
warn("No device-parallel actx available, execution will be slow.",
stacklevel=1)
# lazy, non-distributed
if not distributed:
if _HAVE_SINGLE_GRID_WORK_BALANCING:
return SingleGridWorkBalancingPytatoArrayContext
else:
return PytatoPyOpenCLArrayContext
return PytatoPyOpenCLArrayContext
else:
# distributed+lazy:
if _HAVE_SINGLE_GRID_WORK_BALANCING:
return MPISingleGridWorkBalancingPytatoArrayContext
else:
return MPIBasePytatoPyOpenCLArrayContext
return MPIBasePytatoPyOpenCLArrayContext


def get_reasonable_array_context_class(
Expand All @@ -603,7 +562,7 @@ def get_reasonable_array_context_class(
if numpy:
assert not (lazy or fusion)
if distributed:
actx_class = MPINumpyArrayContext
actx_class: type[ArrayContext] = MPINumpyArrayContext
else:
actx_class = NumpyArrayContext

Expand Down Expand Up @@ -641,7 +600,7 @@ def get_reasonable_array_context_class(
"device-parallel=%r",
actx_class.__name__, lazy, distributed,
# eager is always device-parallel:
(_HAVE_SINGLE_GRID_WORK_BALANCING or _HAVE_FUSION_ACTX or not lazy))
(_HAVE_FUSION_ACTX or not lazy))
return actx_class

# }}}
Expand Down
Loading
Loading