-
Notifications
You must be signed in to change notification settings - Fork 10
/
net_utils.py
284 lines (248 loc) · 11.5 KB
/
net_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""Utilities for building Inflated 3D ConvNets """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import collections
from snets.scopes import *
slim = tf.contrib.slim
@add_arg_scope
def unit3D(inputs, output_channels,
           kernel_shape=(1, 1, 1),
           strides=(1, 1, 1),
           activation_fn=tf.nn.relu,
           use_batch_norm=True,
           use_bias=False,
           padding='same',
           is_training=True,
           name=None):
    """Build one Conv3D layer, optionally followed by BatchNorm and a non-linearity.

    Args:
      inputs: Input tensor handed to `tf.layers.conv3d`.
      output_channels: Number of convolution filters.
      kernel_shape: Kernel size, forwarded as `kernel_size` to `tf.layers.conv3d`.
      strides: Convolution strides, forwarded unchanged to `tf.layers.conv3d`.
      activation_fn: Callable applied last; pass `None` to skip it.
      use_batch_norm: If true, `tf.contrib.layers.batch_norm` follows the conv.
      use_bias: Whether the convolution carries a bias term.
      padding: Padding mode forwarded to `tf.layers.conv3d`.
      is_training: Forwarded to the batch-norm layer.
      name: Variable-scope name; the scope defaults to 'unit3D' when None.

    Returns:
      The output tensor of the unit.
    """
    conv_kwargs = dict(filters=output_channels,
                       kernel_size=kernel_shape,
                       strides=strides,
                       padding=padding,
                       use_bias=use_bias)
    with tf.variable_scope(name, 'unit3D', [inputs]):
        out = tf.layers.conv3d(inputs, **conv_kwargs)
        out = (tf.contrib.layers.batch_norm(out, is_training=is_training)
               if use_batch_norm else out)
        out = out if activation_fn is None else activation_fn(out)
    return out
@add_arg_scope
def sep3D(inputs, output_channels,
          kernel_shape=(1, 1, 1),
          strides=(1, 1, 1),
          activation_fn=tf.nn.relu,
          use_batch_norm=True,
          use_bias=False,
          padding='same',
          is_training=True,
          name=None):
    """Separable Conv3D layer with BatchNorm + non-linearity.

    A (k_t, k_h, k_w) kernel is factorised into a spatial (1, k_h, k_w)
    convolution followed by a temporal (k_t, 1, 1) convolution. Each of the two
    convolutions is followed by optional BatchNorm and the optional activation.

    Args:
      inputs: Input tensor handed to `tf.layers.conv3d`.
      output_channels: Number of filters of both convolutions.
      kernel_shape: An int (used for all three axes) or a (k_t, k_h, k_w) sequence.
      strides: An int (used for all three axes) or a (s_t, s_h, s_w) sequence.
        The spatial strides go to the first conv, the temporal one to the second.
      activation_fn: Applied after each convolution (+BN); `None` skips it.
      use_batch_norm: Whether to apply `tf.contrib.layers.batch_norm` after each conv.
      use_bias: Whether the convolutions carry a bias term.
      padding: Padding mode forwarded to both convolutions.
      is_training: Forwarded to the batch-norm layers.
      name: Variable-scope name; defaults to 'sep3D'.

    Returns:
      The output tensor of the temporal convolution stage.
    """
    def _as_triple(value):
        # Scalars apply to every axis; sequences are taken per axis (t, h, w).
        try:
            return tuple(value)
        except TypeError:
            return (value,) * 3

    k_t, k_h, k_w = _as_triple(kernel_shape)
    s_t, s_h, s_w = _as_triple(strides)

    def _conv_bn_act(x, kernel, stride):
        # One conv -> (BN) -> (activation) stage. Called twice in the same order
        # as before, so auto-generated variable names are unchanged.
        x = tf.layers.conv3d(x, filters=output_channels,
                             kernel_size=kernel,
                             strides=stride,
                             padding=padding,
                             use_bias=use_bias)
        if use_batch_norm:
            x = tf.contrib.layers.batch_norm(x, is_training=is_training)
        if activation_fn is not None:
            x = activation_fn(x)
        return x

    with tf.variable_scope(name, 'sep3D', [inputs]):
        net = _conv_bn_act(inputs, (1, k_h, k_w), (1, s_h, s_w))
        # Previously `net` was only bound when activation_fn was set, so
        # activation_fn=None raised UnboundLocalError; now it is always bound.
        net = _conv_bn_act(net, (k_t, 1, 1), (s_t, 1, 1))
    return net
@add_arg_scope
def unit3D_same(inputs, output_channels,
                kernel_shape=(1, 1, 1),
                strides=(1, 1, 1),
                activation_fn=tf.nn.relu,
                use_batch_norm=True,
                use_bias=False,
                padding='same',
                rate=1,
                is_training=True,
                name=None):
    """Conv3D unit whose padding depends only on the kernel size, not the input.

    * rate == 1 and all strides 1: a plain `unit3D` call with `padding`.
    * rate == 1 and some stride > 1: the input is zero-padded by (k - 1) per
      axis (split as evenly as possible, the extra element at the end) and
      `unit3D` runs with 'VALID' padding.
    * rate > 1: a dilated convolution with 'same' padding is built here, with the
      same scope layout as `unit3D`. `unit3D` itself has no dilation argument.

    Args:
      inputs: 5-D input tensor (presumably [batch, time, height, width, channels]
        -- TODO confirm against callers).
      output_channels: Number of filters.
      kernel_shape: An int (all axes) or a (k_t, k_h, k_w) sequence.
      strides: An int (all axes) or a (s_t, s_h, s_w) sequence.
      activation_fn, use_batch_norm, use_bias, is_training: Used only by the
        dilated (rate > 1) path. The rate-1 paths delegate to `unit3D` without
        forwarding them, as before, so `unit3D`'s defaults / arg_scope apply.
      padding: Forwarded to `unit3D` on the stride-1 path only.
      rate: Dilation rate, applied to all three axes.
      name: Variable-scope name.

    Returns:
      The output tensor of the unit.

    Raises:
      ValueError: If rate != 1 is combined with a stride other than 1.
    """
    def _as_triple(value):
        # Scalars apply to every axis; sequences are taken per axis (t, h, w).
        try:
            return tuple(value)
        except TypeError:
            return (value,) * 3

    kernel = _as_triple(kernel_shape)
    stride = _as_triple(strides)

    if rate != 1:
        if stride != (1, 1, 1):
            raise ValueError('unit3D_same does not support rate != 1 together '
                             'with strides other than 1.')
        # Previously `rate` only enlarged the padding and never reached the
        # convolution, so the output grew by (k - 1) * (rate - 1) per axis.
        # NOTE(review): this path uses this function's own is_training etc.;
        # an arg_scope set only on `unit3D` does not reach it.
        with tf.variable_scope(name, 'unit3D', [inputs]):
            net = tf.layers.conv3d(inputs, filters=output_channels,
                                   kernel_size=kernel,
                                   strides=stride,
                                   padding='same',
                                   dilation_rate=rate,
                                   use_bias=use_bias)
            if use_batch_norm:
                net = tf.contrib.layers.batch_norm(net, is_training=is_training)
            if activation_fn is not None:
                net = activation_fn(net)
        return net

    if stride == (1, 1, 1):
        # Int stride 1 (as callers pass) now also takes this path. For stride 1,
        # 'same' padding equals the explicit padding below.
        return unit3D(inputs, output_channels, kernel_shape, strides,
                      padding=padding, name=name)

    # Per-axis explicit padding, so tuple kernels work as well as ints.
    pads = [[0, 0]]
    for k in kernel:
        pad_total = k - 1
        pad_beg = pad_total // 2
        pads.append([pad_beg, pad_total - pad_beg])
    pads.append([0, 0])
    inputs = tf.pad(inputs, pads)
    return unit3D(inputs, output_channels, kernel_shape,
                  strides, padding='VALID', name=name)
def subsample3D(inputs, factor, name=None):
    """Subsamples the input along the three pooled (temporal and spatial) axes.

    Implemented as a max-pool with a 1x1x1 window and stride `factor`, i.e. it
    keeps every `factor`-th element along each pooled axis.

    Args:
      inputs: A 5-D `Tensor` (presumably [batch, time, height, width, channels];
        the batch and last axes are never pooled).
      factor: The subsampling factor. An int applies to all three axes; a
        (t, h, w) sequence gives one factor per axis.
      name: Optional name for the pooling op.

    Returns:
      `inputs` unchanged when every factor is 1, otherwise the subsampled tensor.
    """
    try:
        factors = tuple(factor)
    except TypeError:  # scalar factor: same subsampling on every axis
        factors = (factor,) * 3
    if factors == (1, 1, 1):
        return inputs
    return tf.nn.max_pool3d(inputs, [1, 1, 1, 1, 1],
                            strides=[1] + list(factors) + [1],
                            padding='SAME', name=name)
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
    """A named tuple describing a ResNet block.

    Fields:
      scope: The variable-scope name of the `Block`.
      unit_fn: The ResNet unit function. It takes a `Tensor` and returns another
        `Tensor` with the output of the ResNet unit.
      args: A list with one entry per unit in the `Block`. Each entry is a
        (depth, depth_bottleneck, stride) tuple used as arguments to unit_fn.
    """
    # Keep instances as lightweight as the underlying tuple (no per-instance __dict__).
    __slots__ = ()
@add_arg_scope
def bottleneck3D(inputs, depth, depth_bottleneck, stride, rate=1,
                 outputs_collections=None, scope=None):
    """3-D bottleneck residual unit with BatchNorm after the convolutions (v1 style).

    The residual branch is 1x1x1 conv -> 3x3x3 conv (strided / dilated) ->
    1x1x1 conv without activation. It is added to the shortcut, and a ReLU is
    applied to the sum.

    When putting together two consecutive ResNet blocks that use this unit, use
    stride = 2 in the last unit of the first block.

    Args:
      inputs: A 5-D tensor (presumably [batch, time, height, width, channels]).
      depth: The depth (channels) of the unit output.
      depth_bottleneck: The depth of the bottleneck layers.
      stride: The unit's stride. It determines how much the unit output is
        downsampled relative to its input.
      rate: An integer rate for atrous (dilated) convolution in the 3x3x3 conv.
      outputs_collections: Collection to add the unit output to.
      scope: Optional variable_scope.

    Returns:
      The unit's output tensor.
    """
    with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
        # 3-D conv units need 5-D inputs; check the rank up front.
        depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=5)
        if depth == depth_in:
            # Identity shortcut; subsample only when the unit strides.
            shortcut = subsample3D(inputs, stride, 'shortcut')
        else:
            # Projection shortcut (conv + BN, no activation) to match depth/stride.
            shortcut = unit3D(inputs, depth, [1, 1, 1], strides=stride,
                              activation_fn=None, name='shortcut')
        residual = unit3D(inputs, depth_bottleneck, [1, 1, 1], strides=1,
                          name='conv1')
        residual = unit3D_same(residual, depth_bottleneck, 3, stride,
                               rate=rate, name='conv2')
        residual = unit3D(residual, depth, [1, 1, 1], strides=1,
                          activation_fn=None, name='conv3')
        output = tf.nn.relu(shortcut + residual)
        return slim.utils.collect_named_outputs(outputs_collections,
                                                sc.original_name_scope,
                                                output)
@add_arg_scope
def bottleneck3D_v2(inputs, depth, depth_bottleneck, stride, rate=1,
                    outputs_collections=None, scope=None):
    """3-D bottleneck residual unit with pre-activation (v2 style).

    BN + ReLU is applied once to the input ("preact"). That pre-activation feeds
    the residual branch and, when depths differ, the projection shortcut. The
    identity shortcut uses the raw `inputs`. The sum is returned without a
    final activation.

    Args:
      inputs: A 5-D tensor (presumably [batch, time, height, width, channels]).
      depth: The depth (channels) of the unit output.
      depth_bottleneck: The depth of the bottleneck layers.
      stride: The unit's stride (downsampling of output vs. input).
      rate: An integer rate for atrous (dilated) convolution in the 3x3x3 conv.
      outputs_collections: Collection to add the unit output to.
      scope: Optional variable_scope.

    Returns:
      The unit's output tensor.
    """
    with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
        # 3-D conv units need 5-D inputs; check the rank up front.
        depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=5)
        # NOTE(review): this BN takes is_training from slim's arg_scope, while
        # the unit3D BNs use unit3D's is_training; confirm both are configured.
        preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
        if depth == depth_in:
            shortcut = subsample3D(inputs, stride, 'shortcut')
        else:
            shortcut = unit3D(preact, depth, [1, 1, 1], strides=stride,
                              use_batch_norm=False, activation_fn=None,
                              name='shortcut')
        residual = unit3D(preact, depth_bottleneck, [1, 1, 1], strides=1,
                          name='conv1')
        residual = unit3D_same(residual, depth_bottleneck, 3, stride,
                               rate=rate, name='conv2')
        # Last conv has neither BN nor activation; the next unit's preact supplies them.
        residual = unit3D(residual, depth, [1, 1, 1], strides=1,
                          use_batch_norm=False, activation_fn=None,
                          name='conv3')
        output = shortcut + residual
        return slim.utils.collect_named_outputs(outputs_collections,
                                                sc.original_name_scope,
                                                output)
@add_arg_scope
def stack_blocks_dense(net, blocks, output_stride=None,
                       outputs_collections=None):
    """Stacks ResNet `Blocks` and controls output feature density.

    First, this function creates scopes for the ResNet in the form of
    'block_name/unit_1', 'block_name/unit_2', etc.

    Second, it lets the user explicitly control the ResNet output_stride, which
    is the ratio of the input to output resolution. With the units in this file,
    an integer stride applies to the temporal and both spatial axes alike. This
    is useful for dense prediction tasks.

    Most ResNets consist of 4 ResNet blocks and subsample the activations by a
    factor of 2 when transitioning between consecutive blocks. This results in
    a nominal output_stride of 8. If output_stride is set to half the nominal
    stride (e.g., output_stride=4), responses are computed at twice the density.

    Once the target output_stride is reached, further striding is replaced by
    atrous (dilated) convolution.

    Args:
      net: A 5-D `Tensor` (presumably [batch, time, height, width, channels]).
      blocks: A list of ResNet `Block` objects, one per block. Each `Block.args`
        entry is a (depth, depth_bottleneck, stride) tuple for one unit.
      output_stride: If `None`, the output is computed at the nominal network
        stride. Otherwise it is the requested ratio of input to output
        resolution. It must equal the product of unit strides from the start up
        to some level of the ResNet. For example, if the units have strides
        1, 2, 1, 3, 4, 1, valid values are 1, 2, 6, 24 or None (equivalent to 24).
      outputs_collections: Collection to add the ResNet block outputs to.

    Returns:
      net: Output tensor with stride equal to the specified output_stride.

    Raises:
      ValueError: If the target output_stride is not valid.
    """
    # current_stride is the effective stride of the activations so far. Once it
    # equals output_stride, units run with stride 1 and growing atrous rate.
    current_stride = 1
    # The atrous convolution rate for the next unit.
    rate = 1

    for block in blocks:
        with tf.variable_scope(block.scope, 'block', [net]) as sc:
            for i, unit in enumerate(block.args):
                if output_stride is not None and current_stride > output_stride:
                    raise ValueError('The target output_stride cannot be reached.')
                with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
                    unit_depth, unit_depth_bottleneck, unit_stride = unit
                    if output_stride is not None and current_stride == output_stride:
                        # Target reached: keep resolution and dilate instead. The
                        # unit's stride is folded into the rate for later units.
                        conv_stride, conv_rate = 1, rate
                        rate *= unit_stride
                    else:
                        conv_stride, conv_rate = unit_stride, 1
                        current_stride *= unit_stride
                    net = block.unit_fn(net, depth=unit_depth,
                                        depth_bottleneck=unit_depth_bottleneck,
                                        stride=conv_stride,
                                        rate=conv_rate)
            net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

    if output_stride is not None and current_stride != output_stride:
        raise ValueError('The target output_stride cannot be reached.')
    return net