-
Notifications
You must be signed in to change notification settings - Fork 1
/
Vgg19.py
50 lines (40 loc) · 1.78 KB
/
Vgg19.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as utils
from torchvision import models
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts and rescales each of 3 channels.

    Output channel ``c`` is ``(x[c] + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``.
    With the default ``sign=-1`` this subtracts the range-scaled mean and then
    divides by the std. There is no cross-channel mixing: the weight is
    diagonal.

    Args:
        rgb_range: Scale factor applied to ``rgb_mean`` before the shift.
        rgb_mean: Three per-channel means.
        rgb_std: Three per-channel divisors (standard deviations).
        sign: ``-1`` subtracts the mean; ``+1`` adds it back.

    Raises:
        RuntimeError: if ``rgb_mean`` or ``rgb_std`` does not have exactly three
            entries (the values would not fit the 3x3x1x1 weight / 3-element bias).
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        dtype = self.weight.dtype
        std = torch.as_tensor(rgb_std, dtype=dtype)
        mean = torch.as_tensor(rgb_mean, dtype=dtype)
        # Copy into the existing Parameters instead of reassigning ``.data``.
        # This keeps the registered Parameter objects and makes shape
        # mismatches fail here rather than later, inside conv2d.
        with torch.no_grad():
            self.weight.copy_(torch.eye(3, dtype=dtype).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            self.bias.copy_(sign * rgb_range * mean / std)
        # Fixed normalisation: exclude both parameters from gradient updates.
        self.weight.requires_grad_(False)
        self.bias.requires_grad_(False)
class Vgg19(torch.nn.Module):
    """VGG-19 feature extractor that returns activations up to ``relu5_1``.

    The input is first normalised by a frozen ``MeanShift``. Inputs are divided
    by ``rgb_range``, then the per-channel mean (0.485, 0.456, 0.406) is
    subtracted and the result is divided by std (0.229, 0.224, 0.225). These
    are the ImageNet statistics torchvision documents for its pretrained
    models. The normalised input is then passed through the first 30 modules of
    ``vgg19().features`` (indices 0-29). In torchvision's VGG-19 layout,
    index 29 is the ReLU after conv5_1.

    Args:
        requires_grad: If False (default), freeze all VGG layer parameters.
        rgb_range: Divisor applied to inputs before normalisation. Presumably
            inputs lie in ``[0, rgb_range]`` — confirm against callers.
        weights_path: Path to a torchvision VGG-19 state dict. The default is
            the original relative path, resolved against the current working
            directory. Pass ``None`` to skip loading and keep the randomly
            initialised weights.
    """

    def __init__(self, requires_grad=False, rgb_range=1,
                 weights_path='model_save/vgg19-dcbb9e9d.pth'):
        super(Vgg19, self).__init__()
        vgg = models.vgg19()
        if weights_path is not None:
            # map_location='cpu' lets a checkpoint saved from GPU tensors load
            # on any host; move the module with .to(device) afterwards.
            state_dict = torch.load(weights_path, map_location='cpu')
            vgg.load_state_dict(state_dict)
        # Keep only the convolutional trunk through relu5_1 (features[0:30]).
        # Sequential(*layers) names the children '0'..'29', matching the
        # original add_module loop, so existing state_dict keys still load.
        self.slice1 = torch.nn.Sequential(*list(vgg.features)[:30])
        if not requires_grad:
            for param in self.slice1.parameters():
                param.requires_grad = False
        vgg_mean = (0.485, 0.456, 0.406)
        # Folding rgb_range into the std makes MeanShift compute
        # (x / rgb_range - mean) / std.
        vgg_std = tuple(s * rgb_range for s in (0.229, 0.224, 0.225))
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)

    def forward(self, X):
        """Normalise ``X`` and return the ``relu5_1`` feature map.

        Assumes ``X`` is a 3-channel NCHW image batch — TODO confirm against
        callers.
        """
        h = self.sub_mean(X)
        return self.slice1(h)
if __name__ == '__main__':
    # Smoke check: building the frozen extractor loads the local VGG-19
    # checkpoint and raises if the file is missing or incompatible.
    extractor = Vgg19(requires_grad=False)