diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index 426a8cff5..433fda1fc 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -65,7 +65,7 @@ import collections
 from typing import (
         Iterator, Iterable, Sequence, Any, Mapping, FrozenSet, Set, Dict, cast,
-        List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)
+        List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional, Tuple)
 
 import attrs
 from immutabledict import immutabledict
@@ -509,48 +509,96 @@ def map_distributed_recv(
 # }}}
 
 
-# {{{ _schedule_comm_batches
+TaskType = TypeVar("TaskType")
 
-def _schedule_comm_batches(
-        comm_ids_to_needed_comm_ids: CommunicationDepGraph
-        ) -> Sequence[AbstractSet[CommunicationOpIdentifier]]:
-    """For each :class:`CommunicationOpIdentifier`, determine the
+
+# {{{ _schedule_task_batches (and related)
+
+def _schedule_task_batches(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
+        -> Sequence[AbstractSet[TaskType]]:
+    """For each task in *task_ids_to_needed_task_ids*, determine the
     'round'/'batch' during which it will be performed. A 'batch'
-    of communication consists of sends and receives. Computation
-    occurs between batches. (So, from the perspective of the
-    :class:`DistributedGraphPartition`, communication batches
-    sit *between* parts.)
+    consists of tasks that do not depend on each other.
+    A task may only be placed in a batch once all of its dependencies
+    have been completed.
     """
-    # FIXME: I'm an O(n^2) algorithm.
+    return _schedule_task_batches_counted(task_ids_to_needed_task_ids)[0]
+# }}}
+
+
+# {{{ _schedule_task_batches_counted
+
+def _schedule_task_batches_counted(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \
+        -> Tuple[Sequence[AbstractSet[TaskType]], int]:
+    """
+    Static type checkers need functions to return the same type regardless
+    of the input. The test code needs to know the number of tasks visited
+    during the scheduling algorithm's execution, but non-test code does not,
+    hence this variant, which additionally returns the visit count.
+    """
+    task_to_dep_level, visits_in_depend = \
+        _calculate_dependency_levels(task_ids_to_needed_task_ids)
+    nlevels = 1 + max(task_to_dep_level.values(), default=-1)
+    task_batches: Sequence[Set[TaskType]] = [set() for _ in range(nlevels)]
+
+    for task_id, dep_level in task_to_dep_level.items():
+        task_batches[dep_level].add(task_id)
 
-    comm_batches: List[AbstractSet[CommunicationOpIdentifier]] = []
+    return task_batches, visits_in_depend + len(task_to_dep_level)
 
-    scheduled_comm_ids: Set[CommunicationOpIdentifier] = set()
-    comms_to_schedule = set(comm_ids_to_needed_comm_ids)
+# }}}
+
+
+# {{{ _calculate_dependency_levels
 
-    all_comm_ids = frozenset(comm_ids_to_needed_comm_ids)
+def _calculate_dependency_levels(
+        task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]
+        ) -> Tuple[Mapping[TaskType, int], int]:
+    """Calculate the minimum dependency level needed before a task of
+    type TaskType can be scheduled. We assume that any number of tasks
+    can be scheduled at the same time. To attain complexity linear in the
+    number of nodes, we assume that each task has a constant number of direct
+    dependencies.
+
+    The minimum dependency level of a task is defined as 1 plus the maximum
+    dependency level of its dependencies; a task with no dependencies has
+    level 0.
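+
+    For example, given the input
+    ``{"a": set(), "b": {"a"}, "c": {"a"}, "d": {"b", "c"}}``,
+    the computed levels are ``a: 0``, ``b: 1``, ``c: 1``, ``d: 2``.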
+ """ + task_to_dep_level: Dict[TaskType, int] = {} + seen: set[TaskType] = set() + nodes_visited: int = 0 - # FIXME In order for this to work, comm tags must be unique - while len(scheduled_comm_ids) < len(all_comm_ids): - comm_ids_this_batch = { - comm_id for comm_id in comms_to_schedule - if comm_ids_to_needed_comm_ids[comm_id] <= scheduled_comm_ids} + def _dependency_level_dfs(task_id: TaskType) -> int: + """Helper function to do depth first search on a graph.""" - if not comm_ids_this_batch: - raise CycleError("cycle detected in communication graph") + if task_id in task_to_dep_level: + return task_to_dep_level[task_id] - scheduled_comm_ids.update(comm_ids_this_batch) - comms_to_schedule = comms_to_schedule - comm_ids_this_batch + # If node has been 'seen', but dep level is not yet known, that's a cycle. + if task_id in seen: + raise CycleError("Cycle detected in your input graph.") + seen.add(task_id) - comm_batches.append(comm_ids_this_batch) + nonlocal nodes_visited + nodes_visited += 1 - return comm_batches + dep_level = 1 + max( + [_dependency_level_dfs(dep) + for dep in task_ids_to_needed_task_ids[task_id]] or [-1]) + task_to_dep_level[task_id] = dep_level + return dep_level + + for task_id in task_ids_to_needed_task_ids: + _dependency_level_dfs(task_id) + + return task_to_dep_level, nodes_visited # }}} # {{{ _MaterializedArrayCollector + @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class _MaterializedArrayCollector(CachedWalkMapper): """ @@ -751,7 +799,7 @@ def find_distributed_partition( # The comm_batches correspond one-to-one to DistributedGraphParts # in the output. try: - comm_batches = _schedule_comm_batches(comm_ids_to_needed_comm_ids) + comm_batches = _schedule_task_batches(comm_ids_to_needed_comm_ids) except Exception as exc: mpi_communicator.bcast(exc) raise @@ -771,7 +819,6 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids part_comm_ids: List[_PartCommIDs] = [] - if comm_batches: recv_ids: FrozenSet[CommunicationOpIdentifier] = frozenset() for batch in comm_batches: diff --git a/test/test_distributed.py b/test/test_distributed.py index 8d7cd50dd..f7a8e5b4c 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -118,6 +118,124 @@ def _do_test_distributed_execution_basic(ctx_factory): # }}} +# {{{ Scheduler Algorithm update tests. + +def test_distributed_scheduler_counts(): + """ Test that the scheduling algorithm runs in `O(n)` time when + operating on a DAG which is just a stick with the dependencies + implied and not directly listed. + """ + from pytato.distributed.partition import _schedule_task_batches_counted + sizes = np.logspace(0, 6, 10, dtype=int) + count_list = np.zeros(len(sizes)) + for i, tree_size in enumerate(sizes): + needed_ids = {i: set() for i in range(int(tree_size))} + for key in needed_ids.keys(): + needed_ids[key] = {key-1} if key > 0 else set() + _, count_list[i] = _schedule_task_batches_counted(needed_ids) + + # Now to do the fitting. 
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
+    assert nonlinear_norm_frac < 0.0001
+
+# }}}
+
+
+# {{{ test_distributed_scheduler_returns_minimum_num_of_levels
+
+def test_distributed_scheduler_returns_minimum_num_of_levels():
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    max_size = 10
+    needed_ids = {j: set() for j in range(max_size)}
+    for i in range(1, max_size-1):
+        needed_ids[i].add(i-1)
+
+    batches, _ = _schedule_task_batches_counted(needed_ids)
+    # Tasks 0 through max_size-2 form a chain and need max_size-1 batches.
+    # The last task has no dependencies listed, so it can be placed in any
+    # batch and must not add a level of its own.
+    assert len(batches) == (max_size - 1)
+
+# }}}
+
+
+# {{{ test_distributed_scheduling_alg_can_find_cycle
+
+def test_distributed_scheduling_alg_can_find_cycle():
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    size = 100
+    my_graph = {i: {i-1} for i in range(size)}
+    my_graph[0] = set()
+    # Introduce a cycle: 60 -> 95 -> 94 -> ... -> 61 -> 60.
+    my_graph[60].add(95)
+    with pytest.raises(CycleError):
+        _schedule_task_batches_counted(my_graph)
+
+# }}}
+
+
+# {{{ test scheduling based on a DAG with all dependencies listed out
+
+def test_distributed_scheduling_o_n_direct_dependents():
+    """Check that the time complexity of the scheduling algorithm is not
+    cubic when each task has O(n) direct dependencies, i.e. when the DAG
+    has O(n^2) edges.
+    """
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    sizes = np.logspace(0, 4, 10, dtype=int)
+    count_list = np.zeros(len(sizes))
+    for i, tree_size in enumerate(sizes):
+        # Each task directly depends on all earlier tasks.
+        needed_ids = {k: set(range(k)) for k in range(int(tree_size))}
+        _, count_list[i] = _schedule_task_batches_counted(needed_ids)
+
+    # Now to do the fitting.
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    # The visit count should scale with the number of edges, i.e. at most
+    # quadratically, so the cubic and quartic coefficients should be
+    # negligible.
+    superquadratic_norm_frac = \
+        la.norm(coefficients[:-3], 2)/la.norm(coefficients, 2)
+    assert superquadratic_norm_frac < 0.0001
+
+# }}}
+
+
+# {{{ test scheduling constant look-back tree
+
+def test_distributed_scheduling_constant_look_back_tree():
+    """Test that the scheduling algorithm scales in linear time if the input
+    DAG is a constant look-back tree. This tree has a single root and 5
+    tendrils off of that root. Along a tendril, each node directly depends
+    on the previous node in the tendril but on nothing else. This also
+    confirms that the scheduling algorithm uses the minimum possible number
+    of batch levels.
+    """
+    from pytato.distributed.partition import _schedule_task_batches_counted
+    import math
+    sizes = np.logspace(0, 6, 10, dtype=int)
+    count_list = np.zeros(len(sizes))
+    branching_factor = 5
+    for i, tree_size in enumerate(sizes):
+        needed_ids = {j: set() for j in range(int(tree_size))}
+        for j in range(1, int(tree_size)):
+            if j <= branching_factor:
+                needed_ids[j] = {0}
+            else:
+                needed_ids[j] = {j - branching_factor}
+        batches, count_list[i] = _schedule_task_batches_counted(needed_ids)
+
+        # Test that the number of batches is the expected minimum number.
+        assert len(batches) == math.ceil((tree_size - 1) / branching_factor) + 1
+
+    # Now to do the fitting.
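+    # As in test_distributed_scheduler_counts, a linear-time algorithm
+    # should leave the quadratic-and-higher fit coefficients negligible.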
+    coefficients = np.polyfit(sizes, count_list, 4)
+    import numpy.linalg as la
+    nonlinear_norm_frac = la.norm(coefficients[:-2], 2)/la.norm(coefficients, 2)
+    assert nonlinear_norm_frac < 0.0001
+
+# }}}
+
+
 # {{{ test based on random dag
 
 def test_distributed_execution_random_dag():