Skip to content

Commit

Permalink
Merge pull request #4 from carmelgafa/gan
Browse files Browse the repository at this point in the history
Add Generative Adversarial Network (GAN) Implementation for MNIST Dataset
  • Loading branch information
carmelgafa authored Oct 19, 2024
2 parents 90a6ebd + 4d3e739 commit c51c4f4
Show file tree
Hide file tree
Showing 9 changed files with 315 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
venv
data
Binary file not shown.
Binary file not shown.
42 changes: 42 additions & 0 deletions ml_algorithms/src/algorithms/gan/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@


import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """
    Fully connected discriminator for 784-value (e.g. 1x28x28) images.

    The input is flattened per sample and passed through three linear
    layers (784 -> 512 -> 256 -> 1) with LeakyReLU(0.2) activations and a
    final sigmoid, so the output can be fed directly to ``nn.BCELoss``.
    """

    def __init__(self):
        """
        Initialize the discriminator network.
        The network consists of three fully connected (dense) layers.
        The output of the network is a probability that the input is real.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # Output a probability in [0, 1]
        )

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the discriminator network.

        Parameters
        ----------
        img : torch.Tensor
            Batch of input images. Every dimension after the first is
            flattened, so each sample must contain exactly 784 values.

        Returns
        -------
        validity : torch.Tensor
            Tensor of shape (batch, 1) holding the probability that each
            input image is real.
        """
        # Keep the batch dimension and collapse everything else.
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
102 changes: 102 additions & 0 deletions ml_algorithms/src/algorithms/gan/gan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision
import matplotlib.pyplot as plt

from generator import Generator
from discriminator import Discriminator

# Hyperparameters
latent_dim = 100  # length of the noise vector fed to the generator
lr = 0.0002
batch_size = 64
epochs = 200

# Device configuration (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_data = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Build the networks, optimizers and loss exactly once.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

criterion = nn.BCELoss()  # Binary Cross Entropy Loss

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_loader):
        # The last batch may be smaller than batch_size, so size every
        # per-batch tensor from the batch actually received.
        cur_batch = imgs.size(0)

        # Ground truths, allocated directly on the target device
        real = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Real images
        real_imgs = imgs.to(device)
        real_loss = criterion(discriminator(real_imgs), real)

        # Fake images. detach() stops the discriminator loss from
        # back-propagating into the generator, whose gradients would be
        # discarded by optimizer_G.zero_grad() below anyway.
        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_imgs = generator(z)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake)

        # Total loss for discriminator
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a fresh batch of fake images
        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_imgs = generator(z)

        # The generator wants the discriminator to think these images are real
        g_loss = criterion(discriminator(fake_imgs), real)

        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if i % 200 == 0:
            # Implicit string concatenation keeps source indentation out of
            # the printed message (a backslash continuation would embed it).
            print(
                f"Epoch [{epoch}/{epochs}] Batch {i}/{len(train_loader)} "
                f"Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}"
            )

    # Save generated samples for visualization every few epochs
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated_imgs = generator(z).cpu().view(-1, 1, 28, 28)
            grid_img = torchvision.utils.make_grid(generated_imgs, nrow=4, normalize=True)
            # make_grid returns (C, H, W); imshow expects (H, W, C).
            plt.imshow(grid_img.permute(1, 2, 0))
            plt.show()
37 changes: 37 additions & 0 deletions ml_algorithms/src/algorithms/gan/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn


class Generator(nn.Module):
    """
    Fully connected generator that maps latent vectors to 1x28x28 images.

    The network widens the latent vector through linear layers
    (latent_dim -> 128 -> 256 -> 512 -> 784) with LeakyReLU(0.2)
    activations, batch normalization on the 256- and 512-wide layers, and
    a final Tanh so that outputs lie in [-1, 1].
    """

    def __init__(self, latent_dim):
        """
        Initialize the generator network.

        Parameters
        ----------
        latent_dim : int
            Number of features in each input latent vector.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),  # 28x28=784
            nn.Tanh(),  # Squash the output to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the generator network.

        Parameters
        ----------
        z : torch.Tensor
            Batch of latent vectors of shape (batch, latent_dim).

        Returns
        -------
        img : torch.Tensor
            The generated images, reshaped to (batch, 1, 28, 28) for MNIST.
        """
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)  # Reshape to 1x28x28 for MNIST
        return img
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
x,y
0.0,159.93428306022466
1.0,297.2347139765763
2.0,462.95377076201385
3.0,630.4605971281605
4.0,745.3169325055333
5.0,895.3172608610164
6.0,1081.5842563101478
7.0,1215.3486945830582
8.0,1340.610512281301
9.0,1510.8512008717194
10.0,1640.7316461437508
11.0,1790.6854049285948
12.0,1954.8392454313207
13.0,2061.734395106844
14.0,2215.5016433497394
15.0,2388.7542494151808
16.0,2529.7433775933114
17.0,2706.2849466519056
18.0,2831.8395184895758
19.0,2971.7539259732944
20.0,3179.312975378431
21.0,3295.484473990269
22.0,3451.3505640937583
23.0,3571.505036275731
24.0,3739.112345509496
25.0,3902.2184517941973
26.0,4026.980128451554
27.0,4207.513960366913
28.0,4337.987226201624
29.0,4494.166125004134
30.0,4637.965867755412
31.0,4837.045563690179
32.0,4949.730055505242
33.0,5078.845781420882
34.0,5266.450898242064
35.0,5375.583127000579
36.0,5554.177271900096
37.0,5660.806597522404
38.0,5823.436279022031
39.0,6003.937224717382
40.0,6164.769331599909
41.0,6303.427365623799
42.0,6447.687034352235
43.0,6593.977926088214
44.0,6720.4295601926515
45.0,6885.603115832106
46.0,7040.787224580804
47.0,7221.142444524378
48.0,7356.872365791369
49.0,7464.739196892745
50.0,7656.481679387896
51.0,7792.2983543916735
52.0,7936.461559993881
53.0,8112.233525776817
54.0,8270.619990449919
55.0,8418.625602382324
56.0,8533.215649535547
57.0,8693.815752482975
58.0,8856.625268628071
59.0,9019.510902542446
60.0,9140.416515243094
61.0,9296.286820466723
62.0,9427.87330051988
63.0,9576.075867518386
64.0,9766.250516447884
65.0,9927.124800571417
66.0,10048.559797568394
67.0,10220.07065795784
68.0,10357.232720500953
69.0,10487.097604907898
70.0,10657.227912110167
71.0,10830.76073132932
72.0,10949.283479217802
73.0,11131.29287311628
74.0,11197.605097918206
75.0,11416.438050087505
76.0,11551.740941364764
77.0,11694.019852990683
78.0,11851.83521553071
79.0,11960.248621707982
80.0,12145.60656224325
81.0,12307.142251430236
82.0,12479.55788089483
83.0,12589.634595634527
84.0,12733.830127942136
85.0,12889.96485912831
86.0,13068.308042354041
87.0,13206.575022193194
88.0,13339.40479592466
89.0,13510.265348662268
90.0,13651.941550986961
91.0,13819.372899810658
92.0,13935.958938122452
93.0,14093.446757068044
94.0,14242.157836937356
95.0,14370.729701037357
96.0,14555.922405541292
97.0,14705.221105443597
98.0,14850.102269132849
99.0,14995.308257332497
100.0,15121.692585158991
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
'''
single feature data generation
'''
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

def generate_data(a0, a1, noise_sigma, plot=False):
def generate_data(a0, a1, noise_sigma, file_name, plot=False):
'''
Generates 100 points with m slope and c intercept
and adds noise with sigma
Expand All @@ -21,7 +21,7 @@ def generate_data(a0, a1, noise_sigma, plot=False):
e = np.random.randn(len(x))*noise_sigma
y = l + e

file_path = os.path.join(os.path.dirname(__file__), 'data_1f.csv')
file_path = os.path.join(os.path.dirname(__file__), file_name)
# save the data to a csv file
df = pd.DataFrame(data=[x, y]).T
df.columns = ['x', 'y']
Expand All @@ -36,4 +36,4 @@ def generate_data(a0, a1, noise_sigma, plot=False):
plt.show()

if __name__=='__main__':
generate_data(a0=150, a1=20, noise_sigma=20, plot=True)
generate_data(a0=150, a1=20, noise_sigma=20, file_name="data_1f.csv", plot=True)
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,27 @@ def plot_univariate_gd_analysis(
plt.rcParams['text.usetex'] = True
fig = plt.figure()
ax = plt.axes(projection='3d')
# ax.plot_surface(
# a0,
# a1,
# np.array(costs),
# rstride=1,
# cstride=1,
# cmap='cividis',
# edgecolor='none',
# alpha=0.5)
ax.plot_surface(
a0,
a1,
np.array(costs),
rstride=1,
cstride=1,
cmap='cividis',
cmap='viridis', # or 'plasma'
edgecolor='none',
alpha=0.5)
alpha=0.6)



ax.contour(a0, a1, np.array(costs), zdir='z', offset=-0.5, cmap=cm.coolwarm)
ax.plot(xx, yy, zz, 'r.--', alpha=1)
ax.set_xlabel(r'$a_0$')
Expand All @@ -96,9 +108,18 @@ def plot_univariate_gd_analysis(
if __name__=='__main__':
import os

# plot_univariate_gd_analysis(
# file=os.path.join(os.path.dirname(__file__), 'data_generation', 'data_1f.csv'),
# a0_range=(125,175,0.2),
# a1_range=(18,22,0.2),
# gd_points= [],
# plot_slices=True)


plot_univariate_gd_analysis(
file=os.path.join(os.path.dirname(__file__), 'data_generation', 'data_1f.csv'),
a0_range=(125,175,0.2),
a1_range=(18,22,0.2),
gd_points= [],
plot_slices=True)
a0_range=(125, 175, 0.1), # finer grid
a1_range=(18, 22, 0.1), # finer grid
gd_points=[],
plot_slices=True
)

0 comments on commit c51c4f4

Please sign in to comment.