Skip to content

Commit

Permalink
prepare dataset outside of child processes
Browse files Browse the repository at this point in the history
  • Loading branch information
NihalHarish committed Oct 3, 2020
1 parent bb8f4b9 commit 19fc0d7
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions tests/zero_code_change/test_pytorch_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,27 +39,17 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


def train(rank, model, device, dataloader_kwargs):
def train(rank, model, device, data_set, dataloader_kwargs):
# Training Settings

batch_size = 64
epochs = 1
lr = 0.01
momentum = 0.5

torch.manual_seed(1 + rank)
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
data_dir,
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=batch_size,
shuffle=True,
num_workers=1,
**dataloader_kwargs
data_set, batch_size=batch_size, shuffle=True, num_workers=1, **dataloader_kwargs
)

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
Expand All @@ -85,6 +75,17 @@ def test_no_failure_with_torch_mp(out_dir):
path = str(path)
os.environ["SMDEBUG_CONFIG_FILE_PATH"] = path
device = "cpu"

# clear data_dir before saving to it
shutil.rmtree(data_dir, ignore_errors=True)
data_set = datasets.MNIST(
data_dir,
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
)
dataloader_kwargs = {}
cpu_count = 2 if mp.cpu_count() > 2 else mp.cpu_count()

Expand All @@ -95,7 +96,7 @@ def test_no_failure_with_torch_mp(out_dir):

processes = []
for rank in range(cpu_count):
p = mp.Process(target=train, args=(rank, model, device, dataloader_kwargs))
p = mp.Process(target=train, args=(rank, model, device, data_set, dataloader_kwargs))
# We first train the model across `num_processes` processes
p.start()
processes.append(p)
Expand Down

0 comments on commit 19fc0d7

Please sign in to comment.