diff --git a/alphafold3_pytorch/trainer.py b/alphafold3_pytorch/trainer.py index f469f43f..21e9b4bc 100644 --- a/alphafold3_pytorch/trainer.py +++ b/alphafold3_pytorch/trainer.py @@ -209,10 +209,14 @@ def __init__( self.distributed_eval = distributed_eval self.will_eval_or_test = self.is_main or distributed_eval + # using "switch" ema + + self.switch_ema = exists(ema_update_model_with_ema_every) + # exponential moving average self.ema_model = None - self.has_ema = self.will_eval_or_test and use_ema + self.has_ema = (self.will_eval_or_test or self.switch_ema) and use_ema if self.has_ema: self.ema_model = EMA( diff --git a/pyproject.toml b/pyproject.toml index 7853fee7..8f135db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.6.3" +version = "0.6.4" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" },