Skip to content

Commit

Permalink
Add a test for Kaushik's distributed MWE
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Apr 10, 2023
1 parent c11cc03 commit f0b10f8
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,62 @@ def test_deterministic_partitioning():
# }}}


# {{{ test Kaushik's MWE

def test_kaushik_mwe():
run_test_with_mpi(2, _do_test_kaushik_mwe)


def _do_test_kaushik_mwe(ctx_factory):
# from https://github.com/inducer/pytato/pull/393#issuecomment-1324642248
from mpi4py import MPI

comm = MPI.COMM_WORLD

if comm.rank == 0:
send_rank = 1
recv_rank = 1

recv = pt.make_distributed_recv(
src_rank=recv_rank, comm_tag=42,
shape=(10,), dtype=np.float64)
y = 2*recv

send = pt.staple_distributed_send(
y, dest_rank=send_rank, comm_tag=43,
stapled_to=pt.ones(10))
out = pt.make_dict_of_named_arrays({"out": send})
elif comm.rank == 1:
send_rank = 0
recv_rank = 0
x = pt.make_data_wrapper(np.ones(10))

send = pt.staple_distributed_send(
2*x, dest_rank=send_rank, comm_tag=42,
stapled_to=pt.zeros(10))
recv = pt.make_distributed_recv(
src_rank=recv_rank, comm_tag=43,
shape=(10,), dtype=np.float64)
out = pt.make_dict_of_named_arrays({"out1": send, "out2": recv})
else:
raise AssertionError()

distributed_parts = pt.find_distributed_partition(comm, out)

pt.verify_distributed_partition(comm, distributed_parts)
prg_per_partition = pt.generate_code_for_partition(distributed_parts)

# Execute the distributed partition
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)

pt.execute_distributed_partition(distributed_parts, prg_per_partition,
queue, comm,
input_args={})

# }}}


# {{{ test verify_distributed_partition

def test_verify_distributed_partition():
Expand Down

0 comments on commit f0b10f8

Please sign in to comment.