trying to fix wandb logging
clemsgrs committed Dec 19, 2024
1 parent a4d86ad commit 9e22796
Showing 2 changed files with 29 additions and 42 deletions.
dinov2/configs/ssl_default_config.yaml (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ train:
   centering: "centering" # or "sinkhorn_knopp"
   save_frequency: 0.1 # save every x% of an epoch
   tune:
-    tune_every: # run tuning every x% of an epoch, leave empty to disable tuning
+    tune_every_pct: # run tuning every x% of an epoch, leave empty to disable tuning (1 = every epoch)
     query_dataset_path:
     test_dataset_path:
     tile_size: 256
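The renamed key makes the unit explicit: the value is a fraction of an epoch, not an iteration count. As a rough sketch (not code from this commit), such a percentage could be converted into an iteration interval along these lines; OFFICIAL_EPOCH_LENGTH appears in train.py below, while the helper name and rounding policy here are assumptions:

# Hypothetical helper, not part of this commit: turn a fraction of an epoch
# into a whole number of iterations.
def pct_to_iterations(pct, official_epoch_length):
    if pct is None:
        return None  # an empty config value disables tuning
    return max(1, round(pct * official_epoch_length))

# e.g. tune_every_pct: 0.5 with a 1250-iteration epoch:
tune_every_iter = pct_to_iterations(0.5, 1250)  # -> 625, i.e. tune twice per epoch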
dinov2/train/train.py (69 changes: 28 additions & 41 deletions)
@@ -389,15 +389,14 @@ def do_train(cfg, model, resume=False):
     metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
     metric_logger = MetricLogger(delimiter="  ", output_file=metrics_file)
     log_freq = 10  # log_freq has to be smaller than the window_size used with instantiating SmoothedValue (here and in MetricLogger)
-    header = "Train"
 
     forward_backward_time = SmoothedValue(fmt="{avg:.6f}")
 
     for data in metric_logger.log_every(
         data_loader,
         distributed.get_global_rank(),
         log_freq,
-        header,
+        "Train",
         max_iter,
         start_iter,
     ):
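The log_freq comment above refers to the sliding-window averaging that SmoothedValue performs: if the window were smaller than the logging interval, some updates would age out of the window before ever being reported. A simplified sketch of the idea, assuming the usual deque-based implementation (the real class lives in dinov2's logging utilities and may differ):

from collections import deque

# Simplified stand-in for dinov2's SmoothedValue (an assumption, not a copy):
# values are averaged over a sliding window of the most recent updates.
class SmoothedValue:
    def __init__(self, window_size=20, fmt="{avg:.6f}"):
        self.values = deque(maxlen=window_size)  # older values fall off the back
        self.fmt = fmt

    def update(self, value):
        self.values.append(value)

    @property
    def avg(self):
        return sum(self.values) / max(len(self.values), 1)

    def __str__(self):
        return self.fmt.format(avg=self.avg)

With the assumed default window of 20, logging every 10 iterations keeps the reported average covering at least one full logging interval, which is what the comment asks for.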
@@ -468,55 +467,43 @@
         metric_logger.update(current_batch_size=current_batch_size)
         metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
 
+        epoch = iteration // OFFICIAL_EPOCH_LENGTH
 
         # logging
         if distributed.is_main_process() and cfg.wandb.enable:
-            log_dict = {"iteration": iteration}
-            update_log_dict(log_dict, f"{header.lower()}/lr", lr, step="iteration")
-            update_log_dict(log_dict, f"{header.lower()}/wd", wd, step="iteration")
-            update_log_dict(log_dict, f"{header.lower()}/loss", losses_reduced, step="iteration")
+            log_dict = {"iteration": iteration, "epoch": epoch}
+            update_log_dict(log_dict, "train/lr", lr, step="iteration")
+            update_log_dict(log_dict, "train/wd", wd, step="iteration")
+            update_log_dict(log_dict, "train/loss", losses_reduced, step="iteration")
             for loss_name, loss_value in loss_dict.items():
-                update_log_dict(log_dict, f"{header.lower()}/{loss_name}", loss_value, step="iteration")
-            wandb.log(log_dict, step=iteration)
-
-        epoch = iteration // OFFICIAL_EPOCH_LENGTH
+                update_log_dict(log_dict, f"train/{loss_name}", loss_value, step="iteration")
 
-        # addtional logging at the end of each epoch
-        if iteration % OFFICIAL_EPOCH_LENGTH == 0:
-            if distributed.is_main_process() and cfg.wandb.enable:
-                # log the total loss and each individual loss to wandb
-                log_dict = {"epoch": epoch}
-                update_log_dict(log_dict, f"{header.lower()}/lr", lr, step="epoch")
-                update_log_dict(log_dict, f"{header.lower()}/wd", wd, step="epoch")
-                update_log_dict(log_dict, f"{header.lower()}/loss", losses_reduced, step="epoch")
-                for loss_name, loss_value in loss_dict.items():
-                    update_log_dict(log_dict, f"{header.lower()}/{loss_name}", loss_value, step="epoch")
-
-            # optionally run tuning
-            tune_results = None
-            if tune_every_iter and iteration % tune_every_iter == 0:
-                tune_results = do_tune(
-                    cfg,
-                    iteration,
-                    model,
-                    query_dataset,
-                    test_dataset,
-                    results_save_dir,
-                    verbose=False,
-                )
-
-                if distributed.is_main_process() and cfg.wandb.enable:
-                    for model_name, metrics_dict in tune_results.items():
-                        for name, value in metrics_dict.items():
-                            update_log_dict(log_dict, f"tune/{model_name}.{name}", value, step="epoch")
+        # optionally run tuning
+        tune_results = None
+        if distributed.is_main_process() and tune_every_iter and iteration % tune_every_iter == 0:
+            # only run tuning on rank 0, otherwise one has to take care of gathering knn metrics from multiple gpus
+            tune_results = do_tune(
+                cfg,
+                iteration,
+                model,
+                query_dataset,
+                test_dataset,
+                results_save_dir,
+                verbose=False,
+            )
+
+            if cfg.wandb.enable:
+                for model_name, metrics_dict in tune_results.items():
+                    for name, value in metrics_dict.items():
+                        update_log_dict(log_dict, f"tune/{model_name}.{name}", value, step="iteration")
 
             early_stopper(epoch, tune_results, periodic_checkpointer, run_distributed, iteration)
             if early_stopper.early_stop and cfg.tune.early_stopping.enable:
                 stop = True
 
-            # log to wandb
-            if distributed.is_main_process() and cfg.wandb.enable:
-                wandb.log(log_dict, step=epoch)
+        # log to wandb
+        if distributed.is_main_process() and cfg.wandb.enable:
+            wandb.log(log_dict, step=iteration)
 
         # checkpointing and testing
 
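This hunk is the heart of the fix: the old code alternated wandb.log calls between step=iteration and step=epoch, but wandb expects the step passed to wandb.log to increase monotonically, so logging a small epoch value after a much larger iteration value causes rows to be dropped or misplaced. A minimal standalone sketch of the pattern the new code adopts, with one monotonic step axis and the epoch carried as an ordinary logged value (the project name and epoch length are illustrative):

import wandb

OFFICIAL_EPOCH_LENGTH = 1250  # illustrative; the real value comes from the config

run = wandb.init(project="dinov2-debug", mode="offline")  # hypothetical project
for iteration in range(1, 2 * OFFICIAL_EPOCH_LENGTH + 1):
    epoch = iteration // OFFICIAL_EPOCH_LENGTH
    # one monotonic axis (iteration); epoch logged as a plain value
    wandb.log({"train/loss": 1.0 / iteration, "epoch": epoch}, step=iteration)
run.finish()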
