# utils.py

import copy
import os
import shutil

import numpy as np
import torch


# To make directories
def mkdir(paths):
    for path in paths:
        if not os.path.isdir(path):
            os.makedirs(path)


# To make cuda tensor
def cuda(xs):
    # Move a tensor (or a list/tuple of tensors) to the GPU when available;
    # otherwise return the input unchanged instead of falling through to None
    if torch.cuda.is_available():
        if not isinstance(xs, (list, tuple)):
            return xs.cuda()
        return [x.cuda() for x in xs]
    return xs
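

# Usage sketch (tensor names are illustrative):
#
#     x = cuda(torch.zeros(2, 3))                   # single tensor
#     a, b = cuda([torch.ones(1), torch.zeros(1)])  # list of tensors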


# For PyTorch data loader
def create_link(dataset_dir):
    dirs = {}
    dirs['trainA'] = os.path.join(dataset_dir, 'trainA')
    dirs['trainB'] = os.path.join(dataset_dir, 'trainB')
    dirs['testA'] = os.path.join(dataset_dir, 'testA')
    dirs['testB'] = os.path.join(dataset_dir, 'testB')
    mkdir(dirs.values())
    for key in dirs:
        # Remove any stale symlink before recreating it
        try:
            os.remove(os.path.join(dirs[key], 'Link'))
        except OSError:
            pass
        os.symlink(os.path.abspath(os.path.join(dataset_dir, key)),
                   os.path.join(dirs[key], 'Link'))
    return dirs
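

# Resulting layout (a sketch; the dataset path is illustrative): each split
# directory gains a 'Link' symlink pointing back at the split itself, e.g.
#
#     datasets/horse2zebra/trainA/Link -> /abs/path/datasets/horse2zebra/trainA
#
# presumably so loaders that expect one subdirectory per class (such as
# torchvision.datasets.ImageFolder) see a single pseudo-class per split.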


def get_traindata_link(dataset_dir):
    dirs = {}
    dirs['trainA'] = os.path.join(dataset_dir, 'trainA')
    dirs['trainB'] = os.path.join(dataset_dir, 'trainB')
    return dirs


def get_testdata_link(dataset_dir):
    dirs = {}
    dirs['testA'] = os.path.join(dataset_dir, 'testA')
    dirs['testB'] = os.path.join(dataset_dir, 'testB')
    return dirs


# To save the checkpoint
def save_checkpoint(state, save_path):
    torch.save(state, save_path)


# To load the checkpoint
def load_checkpoint(ckpt_path, map_location=None):
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeeded!' % ckpt_path)
    return ckpt
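

# Round-trip sketch (dict keys and filenames are illustrative): the saved
# state is whatever dict the training loop passes in, e.g.
#
#     save_checkpoint({'epoch': 5, 'model': net.state_dict()}, 'latest.ckpt')
#     ckpt = load_checkpoint('latest.ckpt', map_location='cpu')
#     net.load_state_dict(ckpt['model'])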


# To store up to 50 generated images in a pool and sample from it once full
# (the history-of-generated-images strategy from Shrivastava et al.)
class Sample_from_Pool(object):
    def __init__(self, max_elements=50):
        self.max_elements = max_elements
        self.cur_elements = 0
        self.items = []

    def __call__(self, in_items):
        return_items = []
        for in_item in in_items:
            if self.cur_elements < self.max_elements:
                # Pool not full yet: store the new item and return it as-is
                self.items.append(in_item)
                self.cur_elements += 1
                return_items.append(in_item)
            else:
                # Pool full: with probability 0.5, swap the new item for a
                # randomly stored one; otherwise return the new item directly
                if np.random.rand() > 0.5:
                    idx = np.random.randint(0, self.max_elements)
                    tmp = copy.copy(self.items[idx])
                    self.items[idx] = in_item
                    return_items.append(tmp)
                else:
                    return_items.append(in_item)
        return return_items
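

# Usage sketch (assumes a CycleGAN-style loop; generator, discriminator, and
# variable names are illustrative): pass freshly generated fakes through the
# pool before the discriminator update, so the discriminator also sees older
# generations rather than only the latest batch.
#
#     fake_pool = Sample_from_Pool()
#     fake_B = G_A2B(real_A)
#     [pooled_fake_B] = fake_pool([fake_B.detach()])
#     d_loss = discriminator_loss(D_B(pooled_fake_B), D_B(real_B))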


# Linear learning-rate decay: the multiplicative factor stays at 1.0 until
# decay_epoch, then falls linearly to 0 at the final epoch
class LambdaLR:
    def __init__(self, epochs, offset, decay_epoch):
        self.epochs = epochs
        self.offset = offset
        self.decay_epoch = decay_epoch

    def step(self, epoch):
        return 1.0 - max(0, epoch + self.offset - self.decay_epoch) / (self.epochs - self.decay_epoch)
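

# Wiring sketch (optimizer and epoch counts are illustrative): the step
# method is intended as the lr_lambda of PyTorch's built-in scheduler.
#
#     scheduler = torch.optim.lr_scheduler.LambdaLR(
#         optimizer,
#         lr_lambda=LambdaLR(epochs=200, offset=0, decay_epoch=100).step)
#     # then call scheduler.step() once per epoch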


def print_networks(nets, names):
    print('------------Number of Parameters---------------')
    for name, net in zip(names, nets):
        num_params = sum(param.numel() for param in net.parameters())
        print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
    print('-----------------------------------------------')
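

# Minimal smoke test (a sketch, not part of the original module): exercises
# mkdir, the checkpoint helpers, the image pool, and print_networks with
# throwaway tensors and a tiny linear model.
if __name__ == '__main__':
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        mkdir([os.path.join(tmp, 'ckpts')])
        path = os.path.join(tmp, 'ckpts', 'demo.ckpt')
        save_checkpoint({'epoch': 0}, path)
        print(load_checkpoint(path, map_location='cpu'))

    pool = Sample_from_Pool(max_elements=2)
    for _ in range(4):
        print(len(pool([torch.zeros(1)])))  # always one item in, one item out

    net = torch.nn.Linear(4, 2)
    print_networks([net], ['demo'])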