From 867f9b213b98ce2a564cfc3fc2f7dfc377d9d227 Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Thu, 3 Dec 2020 10:41:38 -0800 Subject: [PATCH 1/4] added an option to pop I/O dict for a specific device --- torchdistill/core/forward_hook.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchdistill/core/forward_hook.py b/torchdistill/core/forward_hook.py index a694ce3c..6cecf224 100644 --- a/torchdistill/core/forward_hook.py +++ b/torchdistill/core/forward_hook.py @@ -116,6 +116,15 @@ def pop_io_dict(self): gathered_io_dict[module_path][io_type] = gathered_obj return gathered_io_dict + def pop_io_dict_from_device(self, device): + device_io_dict = dict() + for module_path, module_io_dict in self.io_dict.items(): + device_io_dict[module_path] = dict() + for io_type in list(module_io_dict.keys()): + sub_dict = module_io_dict[io_type] + device_io_dict[module_path][io_type] = sub_dict.pop(device.index) + return device_io_dict + def change_target_device(self, target_device): if self.target_device.type != target_device.type: for sub_dict in self.io_dict.values(): From bee840df40384594ceb3676e502c46cf6327ccdf Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Thu, 3 Dec 2020 10:42:07 -0800 Subject: [PATCH 2/4] added a test case --- tests/core_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/core_test.py b/tests/core_test.py index f8da4afe..4ae29ce4 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -34,3 +34,18 @@ def test_pop_io_dict(self): hooked_y = io_dict[target_module_path]['output'] assert torch.equal(y, hooked_y) assert len(fhm.io_dict[target_module_path]) == 0 + + def test_pop_io_dict_from_device(self): + device = torch.device('cpu') + fhm = ForwardHookManager(device) + model = models.resnet18(False) + target_module_path = 'fc' + fhm.add_hook(model, target_module_path, requires_input=False, requires_output=True) + x = torch.rand(1, 3, 224, 224) + y = model(x) + io_dict = fhm.pop_io_dict_from_device(device) + assert len(io_dict) == 1 + assert 'output' in io_dict[target_module_path] + hooked_y = io_dict[target_module_path]['output'] + assert torch.equal(y, hooked_y) + assert len(fhm.io_dict[target_module_path]) == 0 From b0e9415f66f9a699fe8ccbc1a17f15ed14566cf9 Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Thu, 3 Dec 2020 10:57:26 -0800 Subject: [PATCH 3/4] fixed a bug for cpu device --- torchdistill/core/forward_hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdistill/core/forward_hook.py b/torchdistill/core/forward_hook.py index 6cecf224..dc558db2 100644 --- a/torchdistill/core/forward_hook.py +++ b/torchdistill/core/forward_hook.py @@ -118,11 +118,12 @@ def pop_io_dict(self): def pop_io_dict_from_device(self, device): device_io_dict = dict() + device_key = device.index if device.type == 'cuda' else device.type for module_path, module_io_dict in self.io_dict.items(): device_io_dict[module_path] = dict() for io_type in list(module_io_dict.keys()): sub_dict = module_io_dict[io_type] - device_io_dict[module_path][io_type] = sub_dict.pop(device.index) + device_io_dict[module_path][io_type] = sub_dict.pop(device_key) return device_io_dict def change_target_device(self, target_device): From ca16a4acf66e58d62028cafeba0a352830966dad Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Thu, 3 Dec 2020 11:07:09 -0800 Subject: [PATCH 4/4] fixed a typo --- tests/core_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core_test.py b/tests/core_test.py index 4ae29ce4..ff2e6268 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -48,4 +48,4 @@ def test_pop_io_dict_from_device(self): assert 'output' in io_dict[target_module_path] hooked_y = io_dict[target_module_path]['output'] assert torch.equal(y, hooked_y) - assert len(fhm.io_dict[target_module_path]) == 0 + assert len(fhm.io_dict[target_module_path]['output']) == 0