Skip to content

Commit

Permalink
Rework test_deterministic_partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Apr 10, 2023
1 parent 8e04ff8 commit 7b38c26
Showing 1 changed file with 30 additions and 48 deletions.
78 changes: 30 additions & 48 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,68 +254,50 @@ def test_dag_with_no_comm_nodes():

# {{{ test deterministic partitioning

def _check_deterministic_partition(dag, ref_partition,
iproc, results):
def _gather_random_dist_partitions(ctx_factory):
import mpi4py.MPI as MPI

# FIXME: This test is limited to single-rank
partition = pt.find_distributed_partition(MPI.COMM_WORLD, dag)

are_equal = int(partition == ref_partition)
print(iproc, are_equal)
results[iproc] = are_equal


def test_deterministic_partitioning():
pytest.skip("this test needs to be rewritten to spawn multiple MPI processes")
comm = MPI.COMM_WORLD

import multiprocessing as mp
import os
seed = int(os.environ["PYTATO_DAG_SEED"])
from testlib import get_random_pt_dag_with_send_recv_nodes
dag = get_random_pt_dag_with_send_recv_nodes(
seed, rank=comm.rank, size=comm.size,
convert_dws_to_placeholders=True)

original_hash_seed = os.environ.pop("PYTHONHASHSEED", None)

nprocs = 4
my_partition = pt.find_distributed_partition(comm, dag)

mp_ctx = mp.get_context("spawn")
all_partitions = comm.gather(my_partition)

ntests = 10
for i in range(ntests):
seed = 120 + i
results = mp_ctx.Array("i", (0, ) * nprocs)
print(f"Step {i} {seed}")

# FIXME: This test no longer makes sense; it does not generate
# DAGs on ranks 1..6.
ref_dag = get_random_pt_dag_with_send_recv_nodes(
seed, rank=0, size=7,
convert_dws_to_placeholders=True)

ref_partition = pt.find_distributed_partition(comm, ref_dag)
from pickle import dump
if comm.rank == 0:
with open(os.environ["PYTATO_PARTITIONS_DUMP_FN"], "wb") as outf:
dump(all_partitions, outf)

# {{{ spawn nprocs-processes and verify they all compare equally

procs = [mp_ctx.Process(target=_check_deterministic_partition,
args=(ref_dag,
ref_partition,
iproc, results))
for iproc in range(nprocs)]
@pytest.mark.parametrize("seed", list(range(10)))
def test_deterministic_partitioning(seed):
import os
from pickle import load
from pytools import is_single_valued

for iproc, proc in enumerate(procs):
# See
# https://replit.com/@KaushikKulkarn1/spawningprocswithhashseedv2?v=1#main.py
os.environ["PYTHONHASHSEED"] = str(iproc)
proc.start()
partitions_across_seeds = []
partitions_dump_fn = f"tmp-partitions-{os.getpid()}.pkl"

for proc in procs:
proc.join()
for hashseed in [234, 241, 9222, 5]:
run_test_with_mpi(2, _gather_random_dist_partitions, extra_env_vars={
"PYTATO_DAG_SEED": str(seed),
"PYTHONHASHSEED": str(hashseed),
"PYTATO_PARTITIONS_DUMP_FN": partitions_dump_fn,
})

if original_hash_seed is not None:
os.environ["PYTHONHASHSEED"] = original_hash_seed
with open(partitions_dump_fn, "rb") as inf:
partitions_across_seeds.append(load(inf))
os.unlink(partitions_dump_fn)

assert set(results[:]) == {1}
pu.db

# }}}
assert is_single_valued(partitions_across_seeds)

# }}}

Expand Down

0 comments on commit 7b38c26

Please sign in to comment.