# model.py
import torch
import torch.nn as nn
import numpy as np


def get_model(model_load_path):
    '''
    Model arguments found here:
    https://github.com/antonior92/ecg-age-prediction/blob/f9801bbe7eb2ce8c5416f5d3d4182c7302813dec/train.py#L182-L183
    and here:
    https://www.dropbox.com/s/thvqwaryeo8uemo/model.zip?file_subpath=%2Fmodel%2Fconfig.json
    '''
    seq_length = 4096
    net_filter_size = [64, 128, 196, 256, 320]
    net_seq_length = [4096, 1024, 256, 64, 16]
    N_CLASSES = 1
    N_LEADS = 12
    kernel_size = 17
    dropout_rate = 0.8
    model_args = {
        'input_dim': (N_LEADS, seq_length),
        'blocks_dim': list(zip(net_filter_size, net_seq_length)),
        'n_classes': N_CLASSES,
        'kernel_size': kernel_size,
        'dropout_rate': dropout_rate,
    }
    model = ResNet1d(**model_args)
    model.load_state_dict(torch.load(model_load_path, map_location='cpu')['model'])
    # The model originally took 12 channels as input (corresponding to a
    # 12-lead ECG). Here, we construct a new initial convolutional layer.
    # Since each cell only has one channel, we copy the parameter values
    # corresponding to the first channel and ignore the rest.
    n_samples_in, n_samples_out = model_args['input_dim'][1], model_args[
        'blocks_dim'
    ][0][1]
    downsample = _downsample(n_samples_in, n_samples_out)
    padding = _padding(downsample, model_args['kernel_size'])
    newconv1 = nn.Conv1d(
        1, net_filter_size[0], kernel_size, bias=False, stride=downsample,
        padding=padding,
    )
    with torch.no_grad():
        next(newconv1.parameters())[:, :, :] = next(
            model.conv1.parameters()
        )[:, 0:1, :]
    model.conv1 = newconv1
    # The original model had a single output. Here, we replace the last
    # layer with one that has 12 outputs, corresponding to 12 possible
    # classifications.
    # The fact that there were originally 12 input channels and there are
    # now 12 output classes is a numerical coincidence - the two values
    # have nothing to do with each other.
    model.lin = nn.Linear(model.last_layer_dim, 12)
    return model, model_args
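
# Note: the returned model maps a batch of single-channel traces of shape
# (batch, 1, 4096) to logits of shape (batch, 12); see the usage sketch at
# the bottom of this file.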


def _padding(downsample, kernel_size):
    """Compute required padding"""
    padding = max(0, int(np.floor((kernel_size - downsample + 1) / 2)))
    return padding
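# For example, with the default kernel_size of 17: a stride-1 convolution
# gets padding floor((17 - 1 + 1) / 2) = 8, which preserves length, while a
# stride-4 convolution gets floor((17 - 4 + 1) / 2) = 7, which quarters it.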


def _downsample(n_samples_in, n_samples_out):
    """Compute downsample rate"""
    downsample = int(n_samples_in // n_samples_out)
    if downsample < 1:
        raise ValueError("Number of samples should always decrease")
    if n_samples_in % n_samples_out != 0:
        raise ValueError("Number of samples for two consecutive blocks "
                         "should always decrease by an integer factor.")
    return downsample
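# With the default configuration in get_model (stage lengths
# 4096 -> 4096 -> 1024 -> 256 -> 64 -> 16), the initial convolution has
# stride 1 and the five residual blocks have strides 1, 4, 4, 4 and 4.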


class ResBlock1d(nn.Module):
    """
    Residual network unit for unidimensional signals.

    Model code is from
    https://github.com/antonior92/ecg-age-prediction/blob/f9801bbe7eb2ce8c5416f5d3d4182c7302813dec/resnet.py
    """

    def __init__(self, n_filters_in, n_filters_out, downsample, kernel_size, dropout_rate):
        if kernel_size % 2 == 0:
            raise ValueError("The current implementation only supports odd values for `kernel_size`.")
        super(ResBlock1d, self).__init__()
        # Forward path
        padding = _padding(1, kernel_size)
        self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm1d(n_filters_out)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)
        padding = _padding(downsample, kernel_size)
        self.conv2 = nn.Conv1d(n_filters_out, n_filters_out, kernel_size,
                               stride=downsample, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(n_filters_out)
        self.dropout2 = nn.Dropout(dropout_rate)
        # Skip connection
        skip_connection_layers = []
        # Deal with downsampling
        if downsample > 1:
            maxpool = nn.MaxPool1d(downsample, stride=downsample)
            skip_connection_layers += [maxpool]
        # Deal with n_filters dimension increase
        if n_filters_in != n_filters_out:
            conv1x1 = nn.Conv1d(n_filters_in, n_filters_out, 1, bias=False)
            skip_connection_layers += [conv1x1]
        # Build skip connection layer
        if skip_connection_layers:
            self.skip_connection = nn.Sequential(*skip_connection_layers)
        else:
            self.skip_connection = None

    def forward(self, x, y):
        """Residual unit. `x` is the main path; `y` carries the skip state."""
        if self.skip_connection is not None:
            y = self.skip_connection(y)
        # 1st layer
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        # 2nd layer
        x = self.conv2(x)
        x += y  # Sum skip connection and main connection
        y = x   # The pre-activation sum feeds the next block's skip path
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout2(x)
        return x, y


class ResNet1d(nn.Module):
    """Residual network for unidimensional signals.

    Parameters
    ----------
    input_dim : tuple
        Input dimensions. Tuple containing dimensions for the neural network
        input tensor. Should be like: ``(n_filters, n_samples)``.
    blocks_dim : list of tuples
        Dimensions of residual blocks. The i-th tuple gives the output
        dimensions of the i-th residual block (and hence the input to the
        next one). Each tuple should be like: ``(n_filters, n_samples)``.
        ``n_samples`` for two consecutive blocks should always decrease by
        an integer factor.
    dropout_rate : float [0, 1), optional
        Dropout rate used in all Dropout layers. Default is 0.8.
    kernel_size : int, optional
        Kernel size for convolutional layers. The current implementation
        only supports odd kernel sizes. Default is 17.

    References
    ----------
    .. [1] K. He, X. Zhang, S. Ren, and J. Sun, "Identity Mappings in Deep
           Residual Networks," arXiv:1603.05027, Mar. 2016.
           https://arxiv.org/pdf/1603.05027.pdf
    .. [2] K. He, X. Zhang, S. Ren, and J. Sun, "Deep Residual Learning for
           Image Recognition," in 2016 IEEE Conference on Computer Vision
           and Pattern Recognition (CVPR), 2016, pp. 770-778.
           https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, input_dim, blocks_dim, n_classes, kernel_size=17, dropout_rate=0.8):
        super(ResNet1d, self).__init__()
        # First layers
        n_filters_in, n_filters_out = input_dim[0], blocks_dim[0][0]
        n_samples_in, n_samples_out = input_dim[1], blocks_dim[0][1]
        downsample = _downsample(n_samples_in, n_samples_out)
        padding = _padding(downsample, kernel_size)
        self.conv1 = nn.Conv1d(n_filters_in, n_filters_out, kernel_size, bias=False,
                               stride=downsample, padding=padding)
        self.bn1 = nn.BatchNorm1d(n_filters_out)

        # Residual block layers
        for i, (n_filters, n_samples) in enumerate(blocks_dim):
            n_filters_in, n_filters_out = n_filters_out, n_filters
            n_samples_in, n_samples_out = n_samples_out, n_samples
            downsample = _downsample(n_samples_in, n_samples_out)
            setattr(self, f'resblock1d_{i}', ResBlock1d(
                n_filters_in, n_filters_out, downsample,
                kernel_size, dropout_rate
            ))

        # Linear layer
        n_filters_last, n_samples_last = blocks_dim[-1]
        self.last_layer_dim = n_filters_last * n_samples_last
        self.lin = nn.Linear(self.last_layer_dim, n_classes)
        self.n_blk = len(blocks_dim)

    def forward(self, x):
        """Implement ResNet1d forward propagation"""
        # First layers
        x = self.conv1(x)
        x = self.bn1(x)
        # Residual blocks
        y = x
        for i in range(self.n_blk):
            x, y = getattr(self, f'resblock1d_{i}')(x, y)
        # Flatten array
        x = x.view(x.size(0), -1)
        # Fully connected layer
        x = self.lin(x)
        return x
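

# A minimal usage sketch (an addition, not part of the original file). A real
# run would pass the path of the pretrained checkpoint from the Dropbox link
# in get_model's docstring to get_model; without weights at hand, the
# architecture can still be smoke-tested by constructing ResNet1d directly
# with the same configuration that get_model uses.
if __name__ == '__main__':
    model = ResNet1d(
        input_dim=(1, 4096),
        blocks_dim=list(zip([64, 128, 196, 256, 320], [4096, 1024, 256, 64, 16])),
        n_classes=12,
        kernel_size=17,
        dropout_rate=0.8,
    )
    model.eval()  # disable dropout for a deterministic check
    x = torch.zeros(8, 1, 4096)  # batch of 8 single-channel traces
    print(model(x).shape)  # expected: torch.Size([8, 12])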