Skip to content

Commit

Permalink
Merge pull request #33 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Extend ForwardHookManager
  • Loading branch information
yoshitomo-matsubara authored Dec 4, 2020
2 parents 0eeaf72 + ca16a4a commit a91f063
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,18 @@ def test_pop_io_dict(self):
hooked_y = io_dict[target_module_path]['output']
assert torch.equal(y, hooked_y)
assert len(fhm.io_dict[target_module_path]) == 0

def test_pop_io_dict_from_device(self):
device = torch.device('cpu')
fhm = ForwardHookManager(device)
model = models.resnet18(False)
target_module_path = 'fc'
fhm.add_hook(model, target_module_path, requires_input=False, requires_output=True)
x = torch.rand(1, 3, 224, 224)
y = model(x)
io_dict = fhm.pop_io_dict_from_device(device)
assert len(io_dict) == 1
assert 'output' in io_dict[target_module_path]
hooked_y = io_dict[target_module_path]['output']
assert torch.equal(y, hooked_y)
assert len(fhm.io_dict[target_module_path]['output']) == 0
10 changes: 10 additions & 0 deletions torchdistill/core/forward_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def pop_io_dict(self):
gathered_io_dict[module_path][io_type] = gathered_obj
return gathered_io_dict

def pop_io_dict_from_device(self, device):
device_io_dict = dict()
device_key = device.index if device.type == 'cuda' else device.type
for module_path, module_io_dict in self.io_dict.items():
device_io_dict[module_path] = dict()
for io_type in list(module_io_dict.keys()):
sub_dict = module_io_dict[io_type]
device_io_dict[module_path][io_type] = sub_dict.pop(device_key)
return device_io_dict

def change_target_device(self, target_device):
if self.target_device.type != target_device.type:
for sub_dict in self.io_dict.values():
Expand Down

0 comments on commit a91f063

Please sign in to comment.