From 19fc0d7713cbad5024cc0e814075561d0c736689 Mon Sep 17 00:00:00 2001
From: NihalHarish
Date: Fri, 2 Oct 2020 22:22:20 -0700
Subject: [PATCH] prepare dataset outside of child processes

---
 .../test_pytorch_multiprocessing.py | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/tests/zero_code_change/test_pytorch_multiprocessing.py b/tests/zero_code_change/test_pytorch_multiprocessing.py
index 472630c7b..3eba884dd 100644
--- a/tests/zero_code_change/test_pytorch_multiprocessing.py
+++ b/tests/zero_code_change/test_pytorch_multiprocessing.py
@@ -39,8 +39,9 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)
 
 
-def train(rank, model, device, dataloader_kwargs):
+def train(rank, model, device, data_set, dataloader_kwargs):
     # Training Settings
+    batch_size = 64
     epochs = 1
     lr = 0.01
 
@@ -48,18 +49,7 @@ def train(rank, model, device, dataloader_kwargs):
     torch.manual_seed(1 + rank)
 
     train_loader = torch.utils.data.DataLoader(
-        datasets.MNIST(
-            data_dir,
-            train=True,
-            download=True,
-            transform=transforms.Compose(
-                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-            ),
-        ),
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=1,
-        **dataloader_kwargs
+        data_set, batch_size=batch_size, shuffle=True, num_workers=1, **dataloader_kwargs
     )
 
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -85,6 +75,17 @@ def test_no_failure_with_torch_mp(out_dir):
     path = str(path)
     os.environ["SMDEBUG_CONFIG_FILE_PATH"] = path
     device = "cpu"
+
+    # clear data_dir before saving to it
+    shutil.rmtree(data_dir, ignore_errors=True)
+    data_set = datasets.MNIST(
+        data_dir,
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
     dataloader_kwargs = {}
 
     cpu_count = 2 if mp.cpu_count() > 2 else mp.cpu_count()
@@ -95,7 +96,7 @@ def test_no_failure_with_torch_mp(out_dir):
 
     processes = []
     for rank in range(cpu_count):
-        p = mp.Process(target=train, args=(rank, model, device, dataloader_kwargs))
+        p = mp.Process(target=train, args=(rank, model, device, data_set, dataloader_kwargs))
         # We first train the model across `num_processes` processes
         p.start()
         processes.append(p)
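
The pattern behind this patch: construct the dataset once in the parent process and pass it to each torch.multiprocessing worker, so children never race to download or write the same MNIST files; each worker still builds its own DataLoader. Below is a minimal, self-contained sketch of that pattern only, not the test itself: it substitutes a synthetic TensorDataset for MNIST and uses hypothetical names (train, data_set) so it can run without network access or a shared data directory.

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset


def train(rank, data_set, dataloader_kwargs):
    # Each child builds its own DataLoader, but reuses the dataset prepared by the parent.
    torch.manual_seed(1 + rank)
    loader = DataLoader(
        data_set, batch_size=64, shuffle=True, num_workers=1, **dataloader_kwargs
    )
    for data, target in loader:
        pass  # a real worker would run the forward/backward pass here


if __name__ == "__main__":
    # Prepare the dataset once, outside of the child processes,
    # so no child has to download or write shared files on startup.
    data_set = TensorDataset(
        torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,))
    )

    processes = []
    for rank in range(2):
        p = mp.Process(target=train, args=(rank, data_set, {}))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()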