diff --git a/scenic/projects/vivit/model_utils.py b/scenic/projects/vivit/model_utils.py
index 5b4f1174..f29a9afb 100644
--- a/scenic/projects/vivit/model_utils.py
+++ b/scenic/projects/vivit/model_utils.py
@@ -27,7 +27,6 @@
 from scenic.common_lib import debug_utils
 from scenic.model_lib.base_models import model_utils as base_model_utils
 import scipy
-flax.config.update('flax_return_frozendict', True)
 
 
 def reshape_to_1d_factorized(x: jnp.ndarray, axis: int):
diff --git a/scenic/projects/vivit/trainer.py b/scenic/projects/vivit/trainer.py
index a1e31d76..f02d9564 100644
--- a/scenic/projects/vivit/trainer.py
+++ b/scenic/projects/vivit/trainer.py
@@ -21,6 +21,7 @@
 from absl import logging
 from clu import metric_writers
 from clu import periodic_actions
+import flax
 from flax import jax_utils
 import jax
 import jax.numpy as jnp
@@ -66,6 +67,7 @@ def train(
     and eval_summary which are dict of metrics. These outputs are used for
     regression testing.
   """
+  flax.config.update('flax_return_frozendict', True)
  lead_host = jax.process_index() == 0
   # Build the loss_fn, metrics, and flax_model.
   model = model_cls(config, dataset.meta_data)
@@ -110,8 +112,9 @@ def train(
     restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
         init_checkpoint_path, train_state, assert_exist=True)
   elif checkpoint_format == 'big_vision':
-    restored_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
-        init_checkpoint_path, train_state)
+    restored_train_state = (
+        pretrain_utils.convert_big_vision_to_scenic_checkpoint(
+            init_checkpoint_path, train_state))
     # Config dict in big_vision is not the same format as scenic.
     # Therefore, make sure config match the config of the loaded model!
     restored_model_cfg = copy.deepcopy(config)
@@ -132,7 +135,6 @@ def train(
   # Replicate the optimzier, state, and rng.
   train_state = jax_utils.replicate(train_state)
   del params  # Do not keep a copy of the initial params.
-
   # Calculate the total number of training steps.
   total_steps, steps_per_epoch = train_utils.get_num_training_steps(
       config, dataset.meta_data)
@@ -241,7 +243,6 @@ def train(
       do_memory_defrag = True
     except RuntimeError:
       logging.warn('Memory defragmentation not possible, use the tfrt runtime')
-
   for step in range(start_step + 1, total_steps + 1):
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
       train_batch = next(dataset.train_iter)
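
Reviewer note, not part of the patch: moving flax.config.update('flax_return_frozendict', True) out of module scope in model_utils.py and into train() means the global Flax flag is no longer flipped as a side effect of merely importing model_utils, only when ViViT training actually runs. Below is a minimal sketch of what the flag controls, assuming a recent Flax release where it defaults to False; the nn.Dense toy model is illustrative and not from this change.

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp

# Illustrative only, not from this patch: with the flag set to True,
# Module.init()/apply() return immutable FrozenDicts (the older Flax
# behavior); with the newer default of False they return plain dicts.
flax.config.update('flax_return_frozendict', True)

model = nn.Dense(features=4)  # hypothetical toy model
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
assert isinstance(variables, flax.core.FrozenDict)

Code written against the FrozenDict API (for example, checkpoint-restoration helpers that call .unfreeze()) keeps working while the flag is on, which is presumably why the trainer pins it before building the model.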