
Commit

add optimizer load checkpoint
microhum committed Jun 2, 2024
1 parent 4383cdb commit 0069f4d
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion test_few_shot.py
@@ -32,7 +32,7 @@ def test_main_model(opts):
         device = torch.device("cuda")
     else:
         device = torch.device("cpu")
-    model_main.load_state_dict(torch.load(path_ckpt)['model'], map_location=torch.device('cpu'))
+    model_main.load_state_dict(torch.load(path_ckpt)['model'])
     model_main.to(device)
     model_main.eval()
     with torch.no_grad():
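
A note on the fixed call: map_location is an argument of torch.load, not of load_state_dict, so the deleted line would raise a TypeError. If the checkpoint should also open on CPU-only machines, a minimal sketch (reusing path_ckpt, model_main, and device from the code above) would be:

    # map_location remaps the stored tensors onto the chosen device at load time
    ckpt = torch.load(path_ckpt, map_location=device)
    model_main.load_state_dict(ckpt['model'])
    model_main.to(device)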
19 changes: 11 additions & 8 deletions train.py
@@ -37,20 +37,23 @@ def train_main_model(opts):
     run = wandb.init(project=opts.wandb_project_name, config=opts)  # initialize wandb project

     model_main = ModelMain(opts)
-    if torch.cuda.is_available() and opts.multi_gpu:
-        model_main = torch.nn.DataParallel(model_main)
-
-    if opts.continue_training:
-        model_main.load_state_dict(torch.load(opts.continue_ckpt)['model'])
-
-    model_main.cuda()
-
     parameters_all = [{"params": model_main.img_encoder.parameters()}, {"params": model_main.img_decoder.parameters()},
                       {"params": model_main.modality_fusion.parameters()}, {"params": model_main.transformer_main.parameters()},
                       {"params": model_main.transformer_seqdec.parameters()}]

     optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2), eps=opts.eps, weight_decay=opts.weight_decay)

+    if torch.cuda.is_available() and opts.multi_gpu:
+        model_main = torch.nn.DataParallel(model_main)
+
+    if opts.continue_training:
+        checkpoint = torch.load(opts.continue_ckpt)
+        model_main.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['opt'])
+
     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
+    model_main.cuda()

     for epoch in range(opts.init_epoch, opts.n_epochs):
         t0 = time()
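
The resume branch now restores both the model and the optimizer, so the save side presumably writes a dict with matching 'model' and 'opt' keys. A sketch of that counterpart, with ckpt_path as an illustrative name not shown in this diff:

    # hypothetical save-side counterpart: the keys must match what the
    # continue_training branch reads back ('model' and 'opt')
    torch.save({
        'model': model_main.state_dict(),
        'opt': optimizer.state_dict(),
    }, ckpt_path)

One caveat worth noting: the checkpoint is loaded after the optional DataParallel wrap, and a wrapped model's state_dict keys carry a 'module.' prefix, so resuming only matches cleanly when the save and the resume use the same multi_gpu setting.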
