Cannot load the model when trained with batch_size == 1 #140

Open
SysDevHayes opened this issue Dec 7, 2024 · 3 comments
Comments

@SysDevHayes

Hi @ZhengPeng7

Thank you so much for your amazing work!

To verify that training was going well, I made a small change to the train_epoch function:

    def train_epoch(self, epoch):
        global logger_loss_idx
        self.model.train()
        self.loss_dict = {}
        iteration_save_interval = 3000
        iteration_count = (epoch - 1) * len(self.train_loader)  # Track total iterations across epochs
        if epoch > args.epochs + config.finetune_last_epochs:
            if config.task == 'Matting':
                self.pix_loss.lambdas_pix_last['mae'] *= 1
                self.pix_loss.lambdas_pix_last['mse'] *= 0.9
                self.pix_loss.lambdas_pix_last['ssim'] *= 0.9
            else:
                self.pix_loss.lambdas_pix_last['bce'] *= 0
                self.pix_loss.lambdas_pix_last['ssim'] *= 1
                self.pix_loss.lambdas_pix_last['iou'] *= 0.5
                self.pix_loss.lambdas_pix_last['mae'] *= 0.9

        for batch_idx, batch in enumerate(self.train_loader):
            # with nullcontext if not args.use_accelerate or accelerator.gradient_accumulation_steps <= 1 else accelerator.accumulate(self.model):
            self._train_batch(batch)

            # Update iteration count
            iteration_count += 1

            # Save model after x iterations with unique name
            if iteration_count % iteration_save_interval == 0:
                save_path = os.path.join(
                    args.ckpt_dir,
                    f'epoch_{epoch}_batch_{batch_idx}_iteration_{iteration_count}.pth'
                )
                torch.save(
                    self.model.module.state_dict() if to_be_distributed or args.use_accelerate else self.model.state_dict(),
                    save_path
                )
                print(f"Model saved at {save_path} after {iteration_count} iterations.")

            # Logger
            if batch_idx % 20 == 0:
                info_progress = 'Epoch[{0}/{1}] Iter[{2}/{3}].'.format(epoch, args.epochs, batch_idx, len(self.train_loader))
                info_loss = 'Training Losses'
                for loss_name, loss_value in self.loss_dict.items():
                    info_loss += ', {}: {:.3f}'.format(loss_name, loss_value)
                logger.info(' '.join((info_progress, info_loss)))
        info_loss = '@==Final== Epoch[{0}/{1}]  Training Loss: {loss.avg:.3f}  '.format(epoch, args.epochs, loss=self.loss_log)
        logger.info(info_loss)

        self.lr_scheduler.step()
        return self.loss_log.avg

The exported model is slightly lighter than your trained model (yours is 885.1 MB, while mine came out to 884.8 MB).

When I load the model using this snippet:

import torch
from models.birefnet import BiRefNet  # imports added; paths assume the BiRefNet repo layout
from utils import check_state_dict

birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load('./trained/epoch_1_batch_2999_iteration_3000.pth', map_location='cpu', weights_only=True)
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.set_float32_matmul_precision(['high', 'highest'][0])

birefnet.to(device)
birefnet.eval()
print('BiRefNet is ready to use.')

It raised the following error:

RuntimeError: Error(s) in loading state_dict for BiRefNet:
	Missing key(s) in state_dict: "squeeze_module.0.dec_att.aspp1.bn.weight", "squeeze_module.0.dec_att.aspp1.bn.bias", "squeeze_module.0.dec_att.aspp1.bn.running_mean", "squeeze_module.0.dec_att.aspp1.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.0.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.0.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.0.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.0.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.1.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.1.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.1.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.1.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.2.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.2.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.2.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.2.bn.running_var", "squeeze_module.0.dec_att.global_avg_pool.2.weight", "squeeze_module.0.dec_att.global_avg_pool.2.bias", "squeeze_module.0.dec_att.global_avg_pool.2.running_mean", "squeeze_module.0.dec_att.global_avg_pool.2.running_var", "squeeze_module.0.dec_att.bn1.weight", "squeeze_module.0.dec_att.bn1.bias", "squeeze_module.0.dec_att.bn1.running_mean", "squeeze_module.0.dec_att.bn1.running_var", "squeeze_module.0.bn_in.weight", "squeeze_module.0.bn_in.bias", "squeeze_module.0.bn_in.running_mean", "squeeze_module.0.bn_in.running_var", "squeeze_module.0.bn_out.weight", "squeeze_module.0.bn_out.bias", "squeeze_module.0.bn_out.running_mean", "squeeze_module.0.bn_out.running_var", "decoder.decoder_block4.dec_att.aspp1.bn.weight", "decoder.decoder_block4.dec_att.aspp1.bn.bias", "decoder.decoder_block4.dec_att.aspp1.bn.running_mean", "decoder.decoder_block4.dec_att.aspp1.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block4.dec_att.global_avg_pool.2.weight", "decoder.decoder_block4.dec_att.global_avg_pool.2.bias", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block4.dec_att.bn1.weight", "decoder.decoder_block4.dec_att.bn1.bias", "decoder.decoder_block4.dec_att.bn1.running_mean", "decoder.decoder_block4.dec_att.bn1.running_var", "decoder.decoder_block4.bn_in.weight", "decoder.decoder_block4.bn_in.bias", "decoder.decoder_block4.bn_in.running_mean", "decoder.decoder_block4.bn_in.running_var", "decoder.decoder_block4.bn_out.weight", "decoder.decoder_block4.bn_out.bias", "decoder.decoder_block4.bn_out.running_mean", "decoder.decoder_block4.bn_out.running_var", "decoder.decoder_block3.dec_att.aspp1.bn.weight", "decoder.decoder_block3.dec_att.aspp1.bn.bias", "decoder.decoder_block3.dec_att.aspp1.bn.running_mean", "decoder.decoder_block3.dec_att.aspp1.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.weight", 
"decoder.decoder_block3.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block3.dec_att.global_avg_pool.2.weight", "decoder.decoder_block3.dec_att.global_avg_pool.2.bias", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block3.dec_att.bn1.weight", "decoder.decoder_block3.dec_att.bn1.bias", "decoder.decoder_block3.dec_att.bn1.running_mean", "decoder.decoder_block3.dec_att.bn1.running_var", "decoder.decoder_block3.bn_in.weight", "decoder.decoder_block3.bn_in.bias", "decoder.decoder_block3.bn_in.running_mean", "decoder.decoder_block3.bn_in.running_var", "decoder.decoder_block3.bn_out.weight", "decoder.decoder_block3.bn_out.bias", "decoder.decoder_block3.bn_out.running_mean", "decoder.decoder_block3.bn_out.running_var", "decoder.decoder_block2.dec_att.aspp1.bn.weight", "decoder.decoder_block2.dec_att.aspp1.bn.bias", "decoder.decoder_block2.dec_att.aspp1.bn.running_mean", "decoder.decoder_block2.dec_att.aspp1.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block2.dec_att.global_avg_pool.2.weight", "decoder.decoder_block2.dec_att.global_avg_pool.2.bias", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block2.dec_att.bn1.weight", "decoder.decoder_block2.dec_att.bn1.bias", "decoder.decoder_block2.dec_att.bn1.running_mean", "decoder.decoder_block2.dec_att.bn1.running_var", "decoder.decoder_block2.bn_in.weight", "decoder.decoder_block2.bn_in.bias", "decoder.decoder_block2.bn_in.running_mean", "decoder.decoder_block2.bn_in.running_var", "decoder.decoder_block2.bn_out.weight", "decoder.decoder_block2.bn_out.bias", "decoder.decoder_block2.bn_out.running_mean", "decoder.decoder_block2.bn_out.running_var", "decoder.decoder_block1.dec_att.aspp1.bn.weight", "decoder.decoder_block1.dec_att.aspp1.bn.bias", "decoder.decoder_block1.dec_att.aspp1.bn.running_mean", "decoder.decoder_block1.dec_att.aspp1.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.running_mean", 
"decoder.decoder_block1.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block1.dec_att.global_avg_pool.2.weight", "decoder.decoder_block1.dec_att.global_avg_pool.2.bias", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block1.dec_att.bn1.weight", "decoder.decoder_block1.dec_att.bn1.bias", "decoder.decoder_block1.dec_att.bn1.running_mean", "decoder.decoder_block1.dec_att.bn1.running_var", "decoder.decoder_block1.bn_in.weight", "decoder.decoder_block1.bn_in.bias", "decoder.decoder_block1.bn_in.running_mean", "decoder.decoder_block1.bn_in.running_var", "decoder.decoder_block1.bn_out.weight", "decoder.decoder_block1.bn_out.bias", "decoder.decoder_block1.bn_out.running_mean", "decoder.decoder_block1.bn_out.running_var", "decoder.gdt_convs_4.1.weight", "decoder.gdt_convs_4.1.bias", "decoder.gdt_convs_4.1.running_mean", "decoder.gdt_convs_4.1.running_var", "decoder.gdt_convs_3.1.weight", "decoder.gdt_convs_3.1.bias", "decoder.gdt_convs_3.1.running_mean", "decoder.gdt_convs_3.1.running_var", "decoder.gdt_convs_2.1.weight", "decoder.gdt_convs_2.1.bias", "decoder.gdt_convs_2.1.running_mean", "decoder.gdt_convs_2.1.running_var". 

I assume my code is missing something; can you please guide me? Thank you so much!

@SysDevHayes (Author) commented Dec 7, 2024

Additional information that I just found:
I used the swin_l weights swin_large_patch4_window12_384_22kto1k.pth. I guess this is more or less similar to this issue. Can you please confirm? If that is the case, can you show me how to ignore the values from the BN layers? Or is there another way to do it, maybe training from scratch without having to load these weights?

Otherwise, do you think adding track_running_stats=False to all BatchNorm2d layers would work (see the sketch below)? Or maybe replacing BatchNorm2d with InstanceNorm2d?
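Something like this rough sketch is what I have in mind for the track_running_stats idea (untested, and disable_bn_running_stats is just a hypothetical helper name, not from the repo):

    import torch.nn as nn

    def disable_bn_running_stats(model: nn.Module) -> None:
        # Drop the running statistics of every BatchNorm2d so that no
        # running_mean / running_var / num_batches_tracked buffers are kept
        # in (or expected from) the state dict; BN then always normalizes
        # with the current batch statistics, even in eval mode.
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
                m.num_batches_tracked = None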

SysDevHayes changed the title from "Cannot load the trained model" to "Cannot load the model when trained with batch_size == 1" on Dec 7, 2024
@ZhengPeng7 (Owner) commented

Yeah, when batch size == 1, BN is not added to the network, so the architectures under these two conditions are not consistent. Therefore, I strongly suggest not using batch size == 1. In the previous issue you mentioned, I only described some possible trade-offs (which should still be avoided if possible).
And thanks for the suggestion. I also realized the incompatibility of BN here. In later versions, I will change it to another norm method, one that is batch-agnostic.
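For instance, a batch-agnostic replacement could look like the sketch below (an illustration only, not the actual repo code; batch_agnostic_norm is a hypothetical helper). GroupNorm computes statistics per sample, so it behaves identically for batch size 1 and larger batches and keeps no running buffers:

    import math

    import torch.nn as nn

    def batch_agnostic_norm(num_channels: int, num_groups: int = 32) -> nn.Module:
        # GroupNorm normalizes over channel groups within each sample and has
        # no running statistics, so checkpoints stay compatible regardless of
        # the training batch size. gcd guarantees the group count divides
        # num_channels, as GroupNorm requires.
        return nn.GroupNorm(math.gcd(num_groups, num_channels), num_channels)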

@ZhengPeng7 (Owner) commented

BTW, I mean loading each key: value pair of the saved torch state dict and skipping those with BN in their names.
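A minimal sketch of that workaround, reusing the birefnet and state_dict variables from the loading snippet above (note that the model's BN layers then keep their freshly initialized values, so the result may differ from a model trained with BN):

    # Since the batch_size == 1 checkpoint simply lacks the BN entries,
    # strict=False loads everything else and skips the missing keys:
    missing, unexpected = birefnet.load_state_dict(state_dict, strict=False)

    # Every missing key should be a BN weight or buffer; extra keys in the
    # checkpoint would signal a real architecture mismatch.
    print(f'{len(missing)} missing (BN) keys left at their initialized values.')
    assert not unexpected, unexpected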
