Skip to content

Commit

Permalink
Merge pull request #35 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Fix bugs in process of I/O dict for post_forward
  • Loading branch information
yoshitomo-matsubara authored Dec 4, 2020
2 parents a91f063 + 6d61b24 commit a2ca5c6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
with open('README.md', 'r') as f:
long_description = f.read()


description = 'A Modular, Configuration-Driven Framework for Knowledge Distillation. ' \
'Trained models, training logs and configurations are available for ensuring the reproducibility.'
setup(
name='torchdistill',
version='0.0.1',
description='A unified knowledge distillation framework.',
version='0.0.2',
description=description,
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/yoshitomo-matsubara/torchdistill',
Expand Down
13 changes: 8 additions & 5 deletions torchdistill/core/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchdistill.common.module_util import check_if_wrapped, freeze_module_params, get_module, unfreeze_module_params
from torchdistill.core.forward_proc import get_forward_proc_func
from torchdistill.core.util import set_hooks, wrap_model, change_device, tensor2numpy2tensor, extract_io_dict, \
extract_sub_model_output_dict
update_io_dict, extract_sub_model_output_dict
from torchdistill.datasets.util import build_data_loaders
from torchdistill.losses.custom import get_custom_loss
from torchdistill.losses.single import KDLoss, get_single_loss
Expand Down Expand Up @@ -227,10 +227,11 @@ def get_teacher_output(self, sample_batch, targets, supp_dict):
# Deep copy of teacher info dict if teacher special module contains trainable module(s)
teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
if self.teacher_updatable and isinstance(cache_file_paths, (list, tuple)) is not None else None
extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
if isinstance(self.teacher_model, SpecialModule):
self.teacher_model.post_forward(self.teacher_io_dict)
self.teacher_model.post_forward(extracted_teacher_io_dict)

extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
# Write cache files if output file paths (cache_file_paths) are given
if isinstance(cache_file_paths, (list, tuple)):
if teacher_io_dict4cache is None:
Expand All @@ -249,13 +250,15 @@ def forward(self, sample_batch, targets, supp_dict):
teacher_outputs, extracted_teacher_io_dict =\
self.get_teacher_output(sample_batch, targets, supp_dict=supp_dict)
student_outputs = self.student_forward_proc(self.student_model, sample_batch, targets, supp_dict)
extracted_student_io_dict = extract_io_dict(self.student_io_dict, self.device)
if isinstance(self.student_model, SpecialModule):
self.student_model.post_forward(self.student_io_dict)
self.student_model.post_forward(extracted_student_io_dict)

org_loss_dict = self.extract_org_loss(self.org_criterion, student_outputs, teacher_outputs, targets,
uses_teacher_output=self.uses_teacher_output, supp_dict=supp_dict)
update_io_dict(extracted_student_io_dict, extract_io_dict(self.student_io_dict, self.device))
output_dict = {'teacher': extracted_teacher_io_dict,
'student': extract_io_dict(self.student_io_dict, self.device)}
'student': extracted_student_io_dict}
total_loss = self.criterion(output_dict, org_loss_dict, targets)
return total_loss

Expand Down
2 changes: 1 addition & 1 deletion torchdistill/core/forward_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def pop_io_dict(self):
for io_type in list(module_io_dict.keys()):
sub_dict = module_io_dict.pop(io_type)
values = [sub_dict[key] for key in sorted(sub_dict.keys())]
gathered_obj = gather(values, self.target_device) if self.uses_cuda else values[-1]
gathered_obj = gather(values, self.target_device) if self.uses_cuda and len(values) > 1 else values[-1]
gathered_io_dict[module_path][io_type] = gathered_obj
return gathered_io_dict

Expand Down
11 changes: 9 additions & 2 deletions torchdistill/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,25 @@ def tensor2numpy2tensor(data, device):


def extract_io_dict(model_io_dict, target_device):
uses_cuda = target_device.type.startswith('cuda')
uses_cuda = target_device.type == 'cuda'
gathered_io_dict = dict()
for module_path, module_io_dict in model_io_dict.items():
gathered_io_dict[module_path] = dict()
for io_type in list(module_io_dict.keys()):
sub_dict = module_io_dict.pop(io_type)
values = [sub_dict[key] for key in sorted(sub_dict.keys())]
gathered_obj = gather(values, target_device) if uses_cuda else values[-1]
gathered_obj = gather(values, target_device) if uses_cuda and len(values) > 1 else values[-1]
gathered_io_dict[module_path][io_type] = gathered_obj
return gathered_io_dict


def update_io_dict(main_io_dict, new_io_dict):
    """Merge freshly captured I/O entries into an accumulated I/O dict in place.

    Only non-empty captures overwrite existing entries, so values stored
    earlier survive when a later extraction produced nothing for a hook.
    """
    for module_path, io_entries in new_io_dict.items():
        for io_type in io_entries:
            captured = io_entries[io_type]
            if len(captured):
                # Skip empty captures to avoid clobbering earlier data.
                main_io_dict[module_path][io_type] = captured


def extract_sub_model_output_dict(model_output_dict, index):
sub_model_output_dict = dict()
for module_path, sub_model_io_dict in model_output_dict.items():
Expand Down

0 comments on commit a2ca5c6

Please sign in to comment.