Feat discriminator unet mha #580
base: master
Conversation
Force-pushed from 439d06e to 01134aa
models/base_gan_model.py (Outdated)

getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
),
)
The changes above are not from this PR; I believe they shouldn't be here, right?
That code was not written by me, and I have no clue where it came from. I had an accidental pull three weeks ago, which Louis fixed; I'm wondering if it was imported at that time. Can I delete it without destroying the work of others?
They are being removed.
models/gan_networks.py (Outdated)

D_ndf (int) -- the number of filters in the first conv layer
D_ngf(int) -- the number of filters in the last cov layer
Don't change this; also, I believe D_ngf doesn't exist.
D_ngf is a parameter of the first UNet (https://github.com/jolibrain/joliGEN/blob/98a09fe6360b396cd3d309227deef47d278cf3c3/models/modules/unet_architecture/unet_generator.py#L14C16-L14C16), but since it is not relevant to this UNet architecture, I'll try to delete it from this pull request.
It is being removed.
@@ -1,11 +1,24 @@
import functools

import torch
Not needed, I believe.
OK, I'll check on it and delete it.
If I remove it, my code does not work.
models/modules/loss.py (Outdated)

self.gan_mode, target_real_label=target_real_label
).to(self.device)

self.viz = visdom.Visdom(
Remove viz from here; viz is taken care of elsewhere.
This can be left here for testing, but moved/removed afterwards.
OK, I'll delete all the viz code.
It is done.
models/modules/loss.py (Outdated)

mixed_images = masks * fake + (1 - masks) * real
return mixed_images, masks

train_use_cutmix = True  # change this into True to use cutmix
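For context, here is a minimal, hypothetical sketch of what CutMix-style mixing between a real and a fake batch can look like. The helper names (cutmix_masks, mix_real_fake) and the mask convention (1 = fake region) are assumptions for illustration, not the PR's actual implementation; only the final mixing line mirrors the diff above.

```python
# Illustrative only: hypothetical helpers, not the code under review.
import torch


def cutmix_masks(batch_size, height, width, device="cpu"):
    """Per-sample binary masks with one random rectangle set to 1 (assumed fake region)."""
    masks = torch.zeros(batch_size, 1, height, width, device=device)
    for i in range(batch_size):
        lam = torch.rand(1).item()  # area ratio of the cut rectangle
        cut_h, cut_w = int(height * lam**0.5), int(width * lam**0.5)
        y = torch.randint(0, height - cut_h + 1, (1,)).item()
        x = torch.randint(0, width - cut_w + 1, (1,)).item()
        masks[i, :, y : y + cut_h, x : x + cut_w] = 1.0
    return masks


def mix_real_fake(real, fake):
    """Paste fake patches into real images, mirroring the mixing line in the diff."""
    masks = cutmix_masks(real.size(0), real.size(2), real.size(3), device=real.device)
    mixed_images = masks * fake + (1 - masks) * real
    return mixed_images, masks
```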
We need to make this an option in options/train_options.py and then pass it as an argument to the DualDiscriminatorGANLoss(DiscriminatorLoss) constructor.
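A minimal sketch of the suggested wiring, assuming an argparse-style parser and using a stand-in class in place of the real DualDiscriminatorGANLoss; the actual joliGEN option parser and constructor signature will differ.

```python
# Hypothetical sketch of threading the flag through; not the final joliGEN code.
import argparse

# options/train_options.py would gain something like:
parser = argparse.ArgumentParser()
parser.add_argument(
    "--train_use_cutmix",
    action="store_true",
    help="apply CutMix consistency regularization in the discriminator loss",
)
opt = parser.parse_args([])  # defaults to train_use_cutmix=False


class DualDiscriminatorGANLoss:
    """Stand-in for the real class; only shows how the flag could be forwarded."""

    def __init__(self, use_cutmix=False):
        self.use_cutmix = use_cutmix


loss = DualDiscriminatorGANLoss(use_cutmix=opt.train_use_cutmix)
```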
OK, I'll make this an option.
It is done.
models/modules/loss.py (Outdated)

consistent_pred_pixel = torch.mul(
    pred_fake_pixel, 1 - cutmix_pixel_label
) + torch.mul(pred_real_pixel, cutmix_pixel_label)
loss_cutmix_pixel = (
@wr0124 are you sure this is the loss from the paper? It should be D(cutmix(inputs)) - cutmix(D(inputs)), no?
Yes, you are right. In the paper it is || D(cutmix(inputs)) - cutmix(D(inputs)) ||². So D(cutmix(inputs)) is pred_cutmix_fake_pixel in my code, and cutmix(D(inputs)) is consistent_pred_pixel in my code. @beniz
OK cool, you can probably use MSELoss instead of the norm and **2; that'd probably be more efficient.
Yes, you are right. MSELoss is now being used here.
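For reference, a hedged sketch of how the consistency term || D(cutmix(inputs)) - cutmix(D(inputs)) ||² can be written with MSELoss. The netD interface, the mask convention (1 = fake region, matching the mixing sketch above), and the variable names are assumptions for illustration and only loosely mirror the PR's code.

```python
# Illustrative sketch, assuming netD returns a per-pixel prediction map.
import torch
import torch.nn as nn
import torch.nn.functional as F


def cutmix_consistency_loss(netD, real, fake, masks):
    # D applied to the CutMix'ed images: D(cutmix(inputs))
    mixed = masks * fake + (1 - masks) * real
    pred_cutmix_pixel = netD(mixed)

    # CutMix applied to D's per-pixel outputs on the un-mixed images: cutmix(D(inputs))
    pred_real_pixel = netD(real)
    pred_fake_pixel = netD(fake)
    mask_pixel = F.interpolate(masks, size=pred_real_pixel.shape[2:])  # match D's output size
    consistent_pred_pixel = mask_pixel * pred_fake_pixel + (1 - mask_pixel) * pred_real_pixel

    # MSELoss replaces the explicit norm and **2
    return nn.MSELoss()(pred_cutmix_pixel, consistent_pred_pixel)
```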
Force-pushed from 59bd27e to 3f67ad8
Force-pushed from 3f67ad8 to dc69f51
@@ -379,6 +379,20 @@ def initialize(self, parser):
    default=64,
    help="# of discrim filters in the first conv layer",
)
parser.add_argument(
I believe this is already done with --D_ndf.
It is removed.
options/train_options.py (Outdated)

parser.add_argument(
    "--train_use_cutmix",
    type=bool,
Use action="store_true" instead.
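A quick illustration of why type=bool is a trap with argparse: any non-empty string, including "False", parses as True, whereas action="store_true" gives a proper on/off flag. The flag names below are placeholders.

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--flag_bool", type=bool, default=False)  # the pattern being reviewed
p.add_argument("--flag_store", action="store_true")      # the suggested replacement

print(p.parse_args(["--flag_bool", "False"]).flag_bool)  # True, because bool("False") is True
print(p.parse_args([]).flag_store)                       # False when the flag is absent
print(p.parse_args(["--flag_store"]).flag_store)         # True when the flag is given
```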
It is done.
Integrate CutMix into unet_discriminator_mha, which is the original UNet.