Skip to content

Commit

Permalink
Finish removing pytato.partition
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Apr 10, 2023
1 parent f0b10f8 commit 13760d9
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 268 deletions.
5 changes: 0 additions & 5 deletions doc/dag.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ Stringifying Expression Graphs

.. _partitioning:

Partitioning Array Expression Graphs
====================================

.. automodule:: pytato.partition

.. _distributed:

Support for Distributed-Memory/Message Passing
Expand Down
8 changes: 3 additions & 5 deletions pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,15 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.distributed.partition import (
find_distributed_partition, DistributedGraphPart, DistributedGraphPartition)
from pytato.distributed.tags import number_distributed_tags
from pytato.distributed.execute import (
generate_code_for_partition, execute_distributed_partition)
from pytato.distributed.verify import verify_distributed_partition
from pytato.distributed.execute import execute_distributed_partition

from pytato.transform.lower_to_index_lambda import to_index_lambda
from pytato.transform.remove_broadcasts_einsum import (
rewrite_einsums_with_no_broadcasts)
from pytato.transform.metadata import unify_axes_tags

from pytato.partition import generate_code_for_partition

__all__ = (
"dtype",

Expand Down Expand Up @@ -162,11 +161,10 @@ def set_debug_enabled(flag: bool) -> None:
"find_distributed_partition",

"number_distributed_tags",
"generate_code_for_partition",
"execute_distributed_partition",
"verify_distributed_partition",

"generate_code_for_partition",

"to_index_lambda",

"rewrite_einsums_with_no_broadcasts",
Expand Down
27 changes: 25 additions & 2 deletions pytato/distributed/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
THE SOFTWARE.
"""

from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING
from typing import Any, Dict, Hashable, Tuple, Optional, TYPE_CHECKING, Mapping


from pytato.array import make_dict_of_named_arrays
from pytato.target import BoundProgram
from pytato.scalar_expr import INT_CLASSES

Expand All @@ -42,7 +43,7 @@
from pytato.distributed.nodes import (
DistributedRecv, DistributedSend)
from pytato.distributed.partition import (
DistributedGraphPartition, DistributedGraphPart)
DistributedGraphPartition, DistributedGraphPart, PartId)

import logging
logger = logging.getLogger(__name__)
Expand All @@ -52,6 +53,28 @@
import mpi4py.MPI


# {{{ generate_code_for_partition

def generate_code_for_partition(partition: DistributedGraphPartition) \
        -> Mapping[PartId, BoundProgram]:
    """Generate one :class:`pytato.target.BoundProgram` per part of *partition*.

    Parts are visited in the order obtained by sorting them on their
    (sorted) output names.

    :arg partition: the partition whose parts are to be turned into code.
    :returns: a mapping from each part's ``pid`` to the program produced by
        :func:`pytato.generate_loopy` for that part's outputs.
    """
    from pytato import generate_loopy

    def _ordering_key(part_):
        # Sorted list of output names, used to order the parts.
        return sorted(part_.output_names)

    ordered_parts = sorted(partition.parts.values(), key=_ordering_key)

    pid_to_program = {}
    for part in ordered_parts:
        # Collect the arrays this part provides, keyed by placeholder name.
        named_outputs = {
            name: partition.var_name_to_result[name]
            for name in part.output_names}
        pid_to_program[part.pid] = generate_loopy(
            make_dict_of_named_arrays(named_outputs))

    return pid_to_program

# }}}


# {{{ distributed execute

def _post_receive(mpi_communicator: mpi4py.MPI.Comm,
Expand Down
101 changes: 83 additions & 18 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
"""
Partitioning of graphs in :mod:`pytato` serves to enable
:ref:`distributed computation <distributed>`, i.e. sending and receiving data
as part of graph evaluation.
Partitioning of expression graphs is based on a few assumptions:
- We must be able to execute parts in any dependency-respecting order.
- Parts are compiled at partitioning time, so what inputs they take from memory
vs. what they compute is decided at that time.
- No part may depend on its own outputs as inputs.
Internal stuff that is only here because the documentation tool wants it
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. class:: T
A type variable for :class:`~pytato.array.AbstractResultWithNamedArrays`.
.. autoclass:: CommunicationOpIdentifier
.. class:: CommunicationDepGraph
Expand Down Expand Up @@ -42,12 +59,13 @@
from functools import reduce
from typing import (
Sequence, Any, Mapping, FrozenSet, Set, Dict, cast,
List, AbstractSet, TypeVar, TYPE_CHECKING)
List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable)

import attrs
from immutables import Map

from pytools.graph import CycleError
from pytools import memoize_method

from pymbolic.mapper.optimize import optimize_mapper
from pytools import UniqueNameGenerator
Expand All @@ -57,14 +75,13 @@
from pytato.transform import (ArrayOrNames, CopyMapper,
CachedWalkMapper,
CombineMapper)
from pytato.partition import GraphPart, GraphPartition, PartId
from pytato.distributed.nodes import (
DistributedRecv, DistributedSend, DistributedSendRefHolder)
from pytato.distributed.nodes import CommTagType
from pytato.analysis import DirectPredecessorsGetter

if TYPE_CHECKING:
import mpi4py.MPI as MPI
import mpi4py.MPI


@attrs.define(frozen=True)
Expand All @@ -78,10 +95,10 @@ class CommunicationOpIdentifier:
.. note::
In :func:`find_distributed_partition`, we use instances of this type as
though they identify sends or receives, i.e. just a single end of the
communication. Realize that this is only true given the additional
context of which rank is the local rank.
In :func:`~pytato.find_distributed_partition`, we use instances of this
type as though they identify sends or receives, i.e. just a single end
of the communication. Realize that this is only true given the
additional context of which rank is the local rank.
"""
src_rank: int
dest_rank: int
Expand All @@ -96,32 +113,80 @@ class CommunicationOpIdentifier:
_ValueT = TypeVar("_ValueT")


# {{{ distributed graph partition
# {{{ distributed graph part

PartId = Hashable


@attrs.define(frozen=True, slots=False)
class DistributedGraphPart(GraphPart):
class DistributedGraphPart:
"""For one graph part, record send/receive information for input/
output names.
Names that occur as keys in :attr:`name_to_recv_node` and
:attr:`name_to_send_node` are usable as input names by other
parts, or in the result of the computation.
.. attribute:: pid
An identifier for this part of the graph.
.. attribute:: needed_pids
The IDs of parts that are required to be evaluated before this
part can be evaluated.
.. attribute:: user_input_names
A :class:`frozenset` of names representing input to the computational
graph, i.e. which were *not* introduced by partitioning.
.. attribute:: partition_input_names
A :class:`frozenset` of names of placeholders the part requires as
input from other parts in the partition.
.. attribute:: output_names
Names of placeholders this part provides as output.
.. attribute:: name_to_recv_node
.. attribute:: name_to_send_node
.. automethod:: all_input_names
"""
pid: PartId
needed_pids: FrozenSet[PartId]
user_input_names: FrozenSet[str]
partition_input_names: FrozenSet[str]
output_names: FrozenSet[str]

name_to_recv_node: Mapping[str, DistributedRecv]
name_to_send_node: Mapping[str, DistributedSend]

@memoize_method
def all_input_names(self) -> FrozenSet[str]:
return self.user_input_names | self. partition_input_names

# }}}


# {{{ distributed graph partition

@attrs.define(frozen=True, slots=False)
class DistributedGraphPartition(GraphPartition):
"""Store information about distributed graph partitions. This
has the same attributes as :class:`~pytato.partition.GraphPartition`,
however :attr:`~pytato.partition.GraphPartition.parts` now maps to
instances of :class:`DistributedGraphPart`.
class DistributedGraphPartition:
"""
.. attribute:: parts
Mapping from part IDs to instances of :class:`DistributedGraphPart`.
.. attribute:: var_name_to_result
Mapping of placeholder names to the respective :class:`pytato.array.Array`
they represent.
"""
parts: Dict[PartId, DistributedGraphPart]
parts: Mapping[PartId, DistributedGraphPart]
var_name_to_result: Mapping[str, Array]

# }}}

Expand Down Expand Up @@ -433,7 +498,7 @@ def post_visit(self, expr: Any) -> None: # type: ignore[override]
def _set_dict_union_mpi(
dict_a: Mapping[_KeyT, FrozenSet[_ValueT]],
dict_b: Mapping[_KeyT, FrozenSet[_ValueT]],
mpi_data_type: MPI.Datatype) -> Mapping[_KeyT, FrozenSet[_ValueT]]:
mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, FrozenSet[_ValueT]]:
assert mpi_data_type is None
result = dict(dict_a)
for key, values in dict_b.items():
Expand All @@ -446,7 +511,7 @@ def _set_dict_union_mpi(
# {{{ find_distributed_partition

def find_distributed_partition(
mpi_communicator: MPI.Comm,
mpi_communicator: mpi4py.MPI.Comm,
outputs: DictOfNamedArrays
) -> DistributedGraphPartition:
r"""
Expand Down Expand Up @@ -836,7 +901,7 @@ def get_stored_predecessors(ary: Array) -> FrozenSet[Array]:
lsrdg.local_recv_id_to_recv_node,
lsrdg.local_send_id_to_send_node)

from pytato.partition import _run_partition_diagnostics
from pytato.distributed.verify import _run_partition_diagnostics
_run_partition_diagnostics(outputs, partition)

if __debug__:
Expand Down
Loading

0 comments on commit 13760d9

Please sign in to comment.