diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 3443e6366f01..7cf8f30ecc44 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -67,6 +67,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): @@ -221,6 +222,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):