Skip to content

Commit

Permalink
wandb fix val img log
Browse files Browse the repository at this point in the history
  • Loading branch information
microhum committed Jun 2, 2024
1 parent d53d53f commit 4383cdb
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 41 deletions.
33 changes: 0 additions & 33 deletions .devcontainer/devcontainer.json

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ inference
inference_model
Font_dataset
venv
.devcontainer

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
11 changes: 6 additions & 5 deletions test_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ def test_main_model(opts):
dir_res = os.path.join(f"{opts.exp_path}", "experiments/", opts.name_exp, "results")

test_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, 'test')
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")

if opts.streamlit:
st.write("Loading Model Weight...")
model_main = ModelMain(opts)
path_ckpt = os.path.join(f"{opts.model_path}")
model_main.load_state_dict(torch.load(path_ckpt)['model'])
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_main.load_state_dict(torch.load(path_ckpt)['model'], map_location=torch.device('cpu'))
model_main.to(device)
model_main.eval()
with torch.no_grad():
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def train_main_model(opts):
# Log loss value to WandB
wandb.log({f'VAL/loss_{loss_cat}_{key}': value})
wandb.log({
'Images/val_trg_img': wandb.Image(ret_dict_val['img']['trg'][0], caption="Val Target"),
'Images/val_img_output': wandb.Image(ret_dict_val['img']['out'][0], caption="Val Output")
}, step=batches_done)
'VAL_Images/val_trg_img': wandb.Image(ret_dict_val['img']['trg'][0], caption="Val Target"),
'VAL_Images/val_img_output': wandb.Image(ret_dict_val['img']['out'][0], caption="Val Output")
})


val_msg = (
Expand Down

0 comments on commit 4383cdb

Please sign in to comment.