Skip to content

Commit

Permalink
Use bucket sort to schedule comm batches in distributed-memory
Browse files Browse the repository at this point in the history
This avoids quadratic runtime in the previous "batchy toposort".

Co-authored-by: Andreas Kloeckner <[email protected]>
  • Loading branch information
nkoskelo and inducer authored Nov 1, 2023
1 parent dea2373 commit d79fc1c
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 28 deletions.
103 changes: 75 additions & 28 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import collections
from typing import (
Iterator, Iterable, Sequence, Any, Mapping, FrozenSet, Set, Dict, cast,
List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)
List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional, Tuple)

import attrs
from immutabledict import immutabledict
Expand Down Expand Up @@ -509,48 +509,96 @@ def map_distributed_recv(
# }}}


# {{{ _schedule_comm_batches
TaskType = TypeVar("TaskType")

def _schedule_comm_batches(
comm_ids_to_needed_comm_ids: CommunicationDepGraph
) -> Sequence[AbstractSet[CommunicationOpIdentifier]]:
"""For each :class:`CommunicationOpIdentifier`, determine the

# {{{ _schedule_task_batches (and related)

def _schedule_task_batches(
task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
-> Sequence[AbstractSet[TaskType]]:
"""For each :type:`TaskType`, determine the
'round'/'batch' during which it will be performed. A 'batch'
of communication consists of sends and receives. Computation
occurs between batches. (So, from the perspective of the
:class:`DistributedGraphPartition`, communication batches
sit *between* parts.)
of tasks consists of tasks which do not depend on each other.
A task may only be in a batch if all of its dependents have already been
completed.
"""
# FIXME: I'm an O(n^2) algorithm.
return _schedule_task_batches_counted(task_ids_to_needed_task_ids)[0]
# }}}


# {{{ _schedule_task_batches_counted

def _schedule_task_batches_counted(
task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
-> Tuple[Sequence[AbstractSet[TaskType]], int]:
"""
Static type checkers need the functions to return the same type regardless
of the input. The testing code needs to know about the number of tasks visited
during the scheduling algorithm's execution. However, nontesting code does not.
"""
task_to_dep_level, visits_in_depend = \
_calculate_dependency_levels(task_ids_to_needed_task_ids)
nlevels = 1 + max(task_to_dep_level.values(), default=-1)
task_batches: Sequence[Set[TaskType]] = [set() for _ in range(nlevels)]

for task_id, dep_level in task_to_dep_level.items():
task_batches[dep_level].add(task_id)

comm_batches: List[AbstractSet[CommunicationOpIdentifier]] = []
return task_batches, visits_in_depend + len(task_to_dep_level.keys())

scheduled_comm_ids: Set[CommunicationOpIdentifier] = set()
comms_to_schedule = set(comm_ids_to_needed_comm_ids)
# }}}


# {{{ _calculate_dependency_levels

all_comm_ids = frozenset(comm_ids_to_needed_comm_ids)
def _calculate_dependency_levels(
task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]
) -> Tuple[Mapping[TaskType, int], int]:
"""Calculate the minimum dependendency level needed before a task of
type TaskType can be scheduled. We assume that any number of tasks
can be scheduled at the same time. To attain complexity linear in the
number of nodes, we assume that each task has a constant number of direct
dependents.
The minimum dependency level for a task, i, is defined as
1 + the maximum dependency level for its children.
"""
task_to_dep_level: Dict[TaskType, int] = {}
seen: set[TaskType] = set()
nodes_visited: int = 0

# FIXME In order for this to work, comm tags must be unique
while len(scheduled_comm_ids) < len(all_comm_ids):
comm_ids_this_batch = {
comm_id for comm_id in comms_to_schedule
if comm_ids_to_needed_comm_ids[comm_id] <= scheduled_comm_ids}
def _dependency_level_dfs(task_id: TaskType) -> int:
"""Helper function to do depth first search on a graph."""

if not comm_ids_this_batch:
raise CycleError("cycle detected in communication graph")
if task_id in task_to_dep_level:
return task_to_dep_level[task_id]

scheduled_comm_ids.update(comm_ids_this_batch)
comms_to_schedule = comms_to_schedule - comm_ids_this_batch
# If node has been 'seen', but dep level is not yet known, that's a cycle.
if task_id in seen:
raise CycleError("Cycle detected in your input graph.")
seen.add(task_id)

comm_batches.append(comm_ids_this_batch)
nonlocal nodes_visited
nodes_visited += 1

return comm_batches
dep_level = 1 + max(
[_dependency_level_dfs(dep)
for dep in task_ids_to_needed_task_ids[task_id]] or [-1])
task_to_dep_level[task_id] = dep_level
return dep_level

for task_id in task_ids_to_needed_task_ids:
_dependency_level_dfs(task_id)

return task_to_dep_level, nodes_visited

# }}}


# {{{ _MaterializedArrayCollector


@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class _MaterializedArrayCollector(CachedWalkMapper):
"""
Expand Down Expand Up @@ -751,7 +799,7 @@ def find_distributed_partition(
# The comm_batches correspond one-to-one to DistributedGraphParts
# in the output.
try:
comm_batches = _schedule_comm_batches(comm_ids_to_needed_comm_ids)
comm_batches = _schedule_task_batches(comm_ids_to_needed_comm_ids)
except Exception as exc:
mpi_communicator.bcast(exc)
raise
Expand All @@ -771,7 +819,6 @@ def find_distributed_partition(
# {{{ create (local) parts out of batch ids

part_comm_ids: List[_PartCommIDs] = []

if comm_batches:
recv_ids: FrozenSet[CommunicationOpIdentifier] = frozenset()
for batch in comm_batches:
Expand Down
118 changes: 118 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,124 @@ def _do_test_distributed_execution_basic(ctx_factory):
# }}}


# {{{ Scheduler Algorithm update tests.

def test_distributed_scheduler_counts():
""" Test that the scheduling algorithm runs in `O(n)` time when
operating on a DAG which is just a stick with the dependencies
implied and not directly listed.
"""
from pytato.distributed.partition import _schedule_task_batches_counted
sizes = np.logspace(0, 6, 10, dtype=int)
count_list = np.zeros(len(sizes))
for i, tree_size in enumerate(sizes):
needed_ids = {i: set() for i in range(int(tree_size))}
for key in needed_ids.keys():
needed_ids[key] = {key-1} if key > 0 else set()
_, count_list[i] = _schedule_task_batches_counted(needed_ids)

# Now to do the fitting.
coefficients = np.polyfit(sizes, count_list, 4)
import numpy.linalg as la
nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
assert nonlinear_norm_frac < 0.0001

# }}}


# {{{ test_distributed_scheduler_has_minimum_num_of_levels

def test_distributed_scheduler_returns_minimum_num_of_levels():
from pytato.distributed.partition import _schedule_task_batches_counted
max_size = 10
needed_ids = {j: set() for j in range(max_size)}
for i in range(1, max_size-1):
needed_ids[i].add(i-1)

batches, _ = _schedule_task_batches_counted(needed_ids)
# The last task has no dependences listed so it can be placed anywhere.
assert len(batches) == (max_size - 1)

# }}}


# {{{ test_distributed_scheduling_alg_can_find_cycle

def test_distributed_scheduling_alg_can_find_cycle():
from pytato.distributed.partition import _schedule_task_batches_counted
sizes = 100
my_graph = {i: {i-1} for i in range(int(sizes))}
my_graph[0] = {}
my_graph[60].add(95) # Here is the cycle. 60 - 95 -94 - 93 ... - 60
with pytest.raises(CycleError):
_schedule_task_batches_counted(my_graph)

# }}}


# {{{ test scheduling based upon a tree with dependents listed out.

def test_distributed_scheduling_o_n_direct_dependents():
""" Check that the temporal complexity of the scheduling algorithm
in the case that there are `O(n)` direct dependents for each task
is not cubic.
"""
from pytato.distributed.partition import _schedule_task_batches_counted
sizes = np.logspace(0, 4, 10, dtype=int)
count_list = np.zeros(len(sizes))
for i, tree_size in enumerate(sizes):
needed_ids = {i: set() for i in range(int(tree_size))}
for key in needed_ids.keys():
for j in range(key):
needed_ids[key].add(j)
_, count_list[i] = _schedule_task_batches_counted(needed_ids)

# Now to do the fitting.
coefficients = np.polyfit(sizes, count_list, 4)
import numpy.linalg as la
# We are expecting less then cubic scaling.
nonquadratic_norm_frac = la.norm(coefficients[:-3], 2)/la.norm(coefficients, 2)
assert nonquadratic_norm_frac < 0.0001

# }}}


# {{{ test scheduling constant branching tree

def test_distributed_scheduling_constant_look_back_tree():
"""Test that the scheduling algorithm scales in linear time if the input DAG
is a constant look back tree. This tree has a single root and then 5 tendrils
off of this root. Along the tendril each node has a direct dependence on the
previous one in the tendril but no other direct dependencies. This is intended
to confirm that the scheduling algorithm utilizing the minimum number of batch
levels possible.
"""
from pytato.distributed.partition import _schedule_task_batches_counted
import math
sizes = np.logspace(0, 6, 10, dtype=int)
count_list = np.zeros(len(sizes))
branching_factor = 5
for i, tree_size in enumerate(sizes):
needed_ids = {j: set() for j in range(int(tree_size))}
for j in range(1, int(tree_size)):
if j < branching_factor:
needed_ids[j+1] = {0}
else:
needed_ids[j] = {j - branching_factor}
batches, count_list[i] = _schedule_task_batches_counted(needed_ids)

# Test that the number of batches is the expected minimum number.
assert len(batches) == math.ceil((tree_size - 1) / branching_factor) + 1

# Now to do the fitting.
coefficients = np.polyfit(sizes, count_list, 4)
import numpy.linalg as la
nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
assert nonlinear_norm_frac < 0.0001

# }}}


# {{{ test based on random dag

def test_distributed_execution_random_dag():
Expand Down

0 comments on commit d79fc1c

Please sign in to comment.