diff --git a/vista3d/scripts/train.py b/vista3d/scripts/train.py
index e9beb8e..e3a1332 100644
--- a/vista3d/scripts/train.py
+++ b/vista3d/scripts/train.py
@@ -216,10 +216,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -229,13 +225,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
-
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
diff --git a/vista3d/scripts/train_finetune.py b/vista3d/scripts/train_finetune.py
index bb945ee..9a59735 100644
--- a/vista3d/scripts/train_finetune.py
+++ b/vista3d/scripts/train_finetune.py
@@ -149,10 +149,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -162,13 +158,15 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
 
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
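
Why the reordering matters (my reading of the change, not part of the patch): wrapping with `DistributedDataParallel` prefixes every parameter name with `module.`, so a checkpoint saved from the plain model no longer matches the wrapped model's `state_dict` keys, and `model.load_state_dict(pretrained_ckpt)` would fail with a key-mismatch error. Loading the checkpoint first and wrapping afterwards avoids that entirely. A minimal sketch of the intended ordering, using a toy `nn.Linear` in place of the real VISTA3D network, a hypothetical `build_model` helper, and assuming the distributed process group is initialized elsewhere:

```python
from typing import Optional

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def build_model(
    device: torch.device, world_size: int, ckpt_path: Optional[str] = None
) -> nn.Module:
    # Toy stand-in for the real network configured by the bundle parser.
    model = nn.Linear(16, 4).to(device)

    # 1) Load pre-trained weights into the *unwrapped* model: the checkpoint
    #    keys ("weight", "bias", ...) match the plain module's state_dict.
    if ckpt_path is not None:
        pretrained_ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(pretrained_ckpt)
        del pretrained_ckpt

    # 2) Only then wrap with DDP; wrapping first would rename every key to
    #    "module.weight", "module.bias", ... and the plain checkpoint would
    #    no longer load without key remapping.
    if world_size > 1:
        model = DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )
    return model
```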