Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat discriminator unet mha #580

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

wr0124
Copy link
Collaborator

@wr0124 wr0124 commented Nov 7, 2023

Integrate CutMix into the unet_discriminator_mha which is original UNet.

@wr0124 wr0124 force-pushed the feat_discriminator_unet_mha branch from 439d06e to 01134aa Compare November 9, 2023 10:25
getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
),
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All changes above are not from this PR, I believe they shouldn't be here, right ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those codes were not written by me, and I have no clue where they came from. I had an accidental pull three weeks ago, which was fixed by Louis; I'm wondering if they were imported at that time. Can I delete them without destroying the work of others?"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are being removed

D_ndf (int) -- the number of filters in the first conv layer
D_ngf(int) -- the number of filters in the last cov layer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change this, also I believe D_ngf doesn't exist.

Copy link
Collaborator Author

@wr0124 wr0124 Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

D_ngf is one parameter in the first Unet ( https://github.com/jolibrain/joliGEN/blob/98a09fe6360b396cd3d309227deef47d278cf3c3/models/modules/unet_architecture/unet_generator.py#L14C16-L14C16 ) , but since it is not relevant to this UNet architecture, I'll try to delete it from this pull request

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is being removed

@@ -1,11 +1,24 @@
import functools

import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed I believe.

Copy link
Collaborator Author

@wr0124 wr0124 Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, i'll check on it and delete it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if i remove it, my code does not work.

self.gan_mode, target_real_label=target_real_label
).to(self.device)

self.viz = visdom.Visdom(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove viz from here, viz is taken care of elsewhere.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be left here for testing, but moved/removed afterwards.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I'll delete all the viz

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is done

mixed_images = masks * fake + (1 - masks) * real
return mixed_images, masks

train_use_cutmix = True # change this into True to use cutmix
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to make this an option, in options/train_options.py and then pass it as argument to DualDiscriminatorGANLoss(DiscriminatorLoss) constructor.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll make this an option.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is done

consistent_pred_pixel = torch.mul(
pred_fake_pixel, 1 - cutmix_pixel_label
) + torch.mul(pred_real_pixel, cutmix_pixel_label)
loss_cutmix_pixel = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wr0124 are you sure this is the loss from the paper ?
It should be D(cutmix(inputs)) - cutmix(D(inputs)) no ?

Copy link
Collaborator Author

@wr0124 wr0124 Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you are right. in the paper, it is || D(cutmix(inputs)) - cutmix(D(inputs)) ||² . So, D(cutmix(inputs)) is pred_cutmix_fake_pixel in my code, and cutmix(D(inputs)) is consistent_pred_pixel in my code @beniz

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK cool, you can probably use MSELoss instead of the norm and **2, that´d be more efficient probably.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you are right. The MSELoss is being used here.

@wr0124 wr0124 force-pushed the feat_discriminator_unet_mha branch from 59bd27e to 3f67ad8 Compare November 10, 2023 14:49
@wr0124 wr0124 force-pushed the feat_discriminator_unet_mha branch from 3f67ad8 to dc69f51 Compare November 10, 2023 15:49
@wr0124 wr0124 requested a review from beniz November 10, 2023 15:50
@@ -379,6 +379,20 @@ def initialize(self, parser):
default=64,
help="# of discrim filters in the first conv layer",
)
parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is already done with --D_ndf.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is removed

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is removed


parser.add_argument(
"--train_use_cutmix",
type=bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use action="store_true" instead.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants