diff --git a/models/networks/loss.py b/models/networks/loss.py index b2485d77..3b849b14 100644 --- a/models/networks/loss.py +++ b/models/networks/loss.py @@ -102,7 +102,7 @@ def __call__(self, input, target_is_real, for_discriminator=True): class VGGLoss(nn.Module): def __init__(self, gpu_ids): super(VGGLoss, self).__init__() - self.vgg = VGG19().cuda() + self.vgg = VGG19() if len(gpu_ids) == 0 else VGG19().cuda() self.criterion = nn.L1Loss() self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]