Skip to content

Commit

Permalink
created generator and discriminator networks
Browse files Browse the repository at this point in the history
  • Loading branch information
carmelgafa committed Oct 19, 2024
1 parent 9732c56 commit fe2a312
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
42 changes: 42 additions & 0 deletions ml_algorithms/src/algorithms/gan/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@


import torch
import torch.nn as nn


class Discriminator(nn.Module):

def __init__(self):
"""
Initialize the discriminator network.
The network consists of three layers of fully connected (dense) layers.
The output of the network is a probability that the input is real.
"""
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(784, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid() # Output a probability
)

def forward(self, img:torch.tensor) -> torch.tensor:
"""
Forward pass of the discriminator network.
Parameters
----------
img : torch.tensor
The input image to the discriminator network.
Returns
-------
validity : torch.tensor
The probability that the input image is real.
"""
img_flat = img.view(img.size(0), -1) # Flatten the image
validity = self.model(img_flat)
return validity
37 changes: 37 additions & 0 deletions ml_algorithms/src/algorithms/gan/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn


class Generator(nn.Module):
def __init__(self, latent_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.LeakyReLU(0.2),
nn.Linear(128, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.LeakyReLU(0.2),
nn.Linear(512, 784), # 28x28=784
nn.Tanh() # Normalize the output to [-1, 1]
)

def forward(self, z:torch.tensor) -> torch.tensor:
"""
Forward pass of the generator network.
Parameters
----------
z : torch.tensor
The input latent vector to the generator network.
Returns
-------
img : torch.tensor
The generated image, reshaped to 28x28 for MNIST.
"""
img = self.model(z)
img = img.view(img.size(0), 1, 28, 28) # Reshape to 28x28 for MNIST
return img

0 comments on commit fe2a312

Please sign in to comment.