Skip to content

Commit

Permalink
feat : format
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Nov 9, 2023
1 parent cd21148 commit 01134aa
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
38 changes: 18 additions & 20 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchviz import make_dot

# for FID
#from data.base_dataset import get_transform
# from data.base_dataset import get_transform
from util.diff_aug import DiffAugment
from util.discriminator import DiscriminatorInfo

Expand Down Expand Up @@ -438,9 +438,8 @@ def compute_G_loss(self):
getattr(self, loss_function)()

def compute_G_loss_GAN(self):
"""Calculate GAN losses for generator(s)"""
"""Calculate GAN losses for generator(s)"""


for discriminator in self.discriminators:
if "mask" in discriminator.name:
continue
Expand All @@ -460,7 +459,7 @@ def compute_G_loss_GAN(self):
netD,
domain,
loss,
fake_name=fake_name,
fake_name=fake_name,
real_name=real_name,
)

Expand Down Expand Up @@ -587,29 +586,28 @@ def set_discriminators_info(self):

elif "unet_discriminator_mha" in discriminator_name:
loss_calculator = loss.DualDiscriminatorGANLoss(
netD=getattr(self, "net"+ discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)
netD=getattr(self, "net" + discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)
fake_name = None
real_name = None
compute_every = 1
else :

else:
fake_name = None
real_name = None
compute_every = 1


if self.opt.train_use_contrastive_loss_D:
loss_calculator = (
loss.DiscriminatorContrastiveLoss(
Expand Down
2 changes: 0 additions & 2 deletions models/modules/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,3 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
def forward(self, input):
"""Standard forward."""
return self.net(input)


2 changes: 1 addition & 1 deletion models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def cutmix_real_fake_pairwise(real, fake):
for _ in range(self.real.size(0))
]
)
masks = masks.unsqueeze(1)
masks = masks.unsqueeze(1)
mixed_images = masks * fake + (1 - masks) * real
return mixed_images, masks

Expand Down

0 comments on commit 01134aa

Please sign in to comment.