import torch
from torch import nn


class VAE(nn.Module):
    """Convolutional variational autoencoder for 224x224 RGB images."""

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims=None,
                 ) -> None:
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Encoder: each block halves the spatial resolution, so a
        # 224x224 input is reduced to a 7x7 feature map.
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()))
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # Project the flattened encoder output (512 * 7 * 7 for 224x224
        # inputs) to the mean and log-variance of the latent Gaussian.
        self.fc_mu = nn.Linear(512 * 7 * 7, latent_dim)
        self.fc_var = nn.Linear(512 * 7 * 7, latent_dim)
        # Decoder: mirror the encoder, doubling the resolution at each step.
        modules = []
        self.decoder_input = nn.Sequential(
            nn.Linear(latent_dim, 512 * 7 * 7),
            nn.LeakyReLU())
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU()))
        self.decoder = nn.Sequential(*modules)
        # Final block upsamples to full resolution and maps back to a
        # 3-channel image with values in [0, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid())

    def encode(self, input):
        """Encode input images into the mean and log-variance of q(z|x)."""
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def decode(self, z):
        """Decode latent codes z back into images."""
        result = self.decoder_input(z)
        result = result.view(-1, 512, 7, 7)  # restore the encoder's final feature-map shape
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using N(0, 1).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def reconstruct(self, x):
        """
        Reconstruct input images of shape (b, 3, 224, 224).
        """
        return self.forward(x)[0]

    def get_z(self, x):
        """
        Return the latent embedding of the input images.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return z

    def generate_from_z(self, z):
        """
        Generate images from a latent embedding z.
        """
        return self.decode(z)
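

# ---------------------------------------------------------------------------
# Minimal usage sketch. It assumes 224x224 RGB inputs, which is what the
# hard-coded 512 * 7 * 7 projection above expects, and a standard VAE
# objective (reconstruction loss plus a KL term). The `vae_loss` helper and
# the `beta` weight are illustrative assumptions, not utilities defined
# elsewhere in this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def vae_loss(recon, x, mu, log_var, beta: float = 1.0):
        # Pixel-wise reconstruction error plus the analytic KL divergence
        # between N(mu, var) and the standard normal prior.
        recon_loss = nn.functional.mse_loss(recon, x, reduction="mean")
        kld = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1))
        return recon_loss + beta * kld

    model = VAE(in_channels=3, latent_dim=128)
    x = torch.rand(2, 3, 224, 224)      # dummy batch of two RGB images in [0, 1]
    recon, mu, log_var = model(x)
    print(recon.shape)                  # expected: torch.Size([2, 3, 224, 224])
    print(vae_loss(recon, x, mu, log_var).item())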