Commit 821c763

Fix model weight load bug with multigpu
Signed-off-by: heyufan1995 <[email protected]>
heyufan1995 committed Sep 11, 2024
1 parent f6308df commit 821c763
Showing 2 changed files with 10 additions and 15 deletions.
13 changes: 5 additions & 8 deletions vista3d/scripts/train.py
@@ -216,10 +216,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -229,13 +225,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
-
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
12 changes: 5 additions & 7 deletions vista3d/scripts/train_finetune.py
@@ -149,10 +149,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -162,13 +158,15 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")

+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
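
Why the reordering matters (an illustrative aside, not part of the commit): DistributedDataParallel re-registers the wrapped network under a "module." attribute, so every key in the wrapped model's state dict becomes "module.<name>" and no longer matches a checkpoint saved from the bare model. The previous code wrapped the model first and then called MONAI's copy_model_state, which copies only the keys that match, so the pretrained weights were presumably skipped without any error; loading before wrapping, as both diffs above now do, avoids the mismatch. A minimal sketch of the two orderings, using a toy Linear layer and a single-process gloo group in place of the real VISTA3D model and multi-GPU setup:

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel

    # Single-process "gloo" group so DistributedDataParallel can be constructed
    # here without launching multiple workers (illustration only).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(4, 2)  # stand-in for the VISTA3D network
    pretrained_ckpt = {k: v.clone() for k, v in model.state_dict().items()}

    # New order (what the commit does): load into the bare module, then wrap.
    model.load_state_dict(pretrained_ckpt)      # keys match: "weight", "bias"
    ddp_model = DistributedDataParallel(model)

    # Old order fails: wrapping prefixes every state-dict key with "module.".
    print(list(ddp_model.state_dict())[:2])     # ['module.weight', 'module.bias']
    try:
        ddp_model.load_state_dict(pretrained_ckpt)
    except RuntimeError as err:
        print("loading after wrapping fails:", err)

    dist.destroy_process_group()

An alternative would have been to keep the old order and strip the "module." prefix from the checkpoint keys before loading, but moving the DistributedDataParallel wrap after the load keeps the fine-tuning path identical to single-GPU training.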
