Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

diffbase is public #1058

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion scenic/projects/vivit/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from scenic.common_lib import debug_utils
from scenic.model_lib.base_models import model_utils as base_model_utils
import scipy
flax.config.update('flax_return_frozendict', True)


def reshape_to_1d_factorized(x: jnp.ndarray, axis: int):
Expand Down
9 changes: 5 additions & 4 deletions scenic/projects/vivit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl import logging
from clu import metric_writers
from clu import periodic_actions
import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -66,6 +67,7 @@ def train(
and eval_summary which are dict of metrics. These outputs are used for
regression testing.
"""
flax.config.update('flax_return_frozendict', True)
lead_host = jax.process_index() == 0
# Build the loss_fn, metrics, and flax_model.
model = model_cls(config, dataset.meta_data)
Expand Down Expand Up @@ -110,8 +112,9 @@ def train(
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
init_checkpoint_path, train_state, assert_exist=True)
elif checkpoint_format == 'big_vision':
restored_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
init_checkpoint_path, train_state)
restored_train_state = (
pretrain_utils.convert_big_vision_to_scenic_checkpoint(
init_checkpoint_path, train_state))
# Config dict in big_vision is not the same format as scenic.
# Therefore, make sure config match the config of the loaded model!
restored_model_cfg = copy.deepcopy(config)
Expand All @@ -132,7 +135,6 @@ def train(
# Replicate the optimizer, state, and rng.
train_state = jax_utils.replicate(train_state)
del params # Do not keep a copy of the initial params.

# Calculate the total number of training steps.
total_steps, steps_per_epoch = train_utils.get_num_training_steps(
config, dataset.meta_data)
Expand Down Expand Up @@ -241,7 +243,6 @@ def train(
do_memory_defrag = True
except RuntimeError:
logging.warn('Memory defragmentation not possible, use the tfrt runtime')

for step in range(start_step + 1, total_steps + 1):
with jax.profiler.StepTraceAnnotation('train', step_num=step):
train_batch = next(dataset.train_iter)
Expand Down