diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py
index b83639cb4d..534595a19c 100644
--- a/mmaction/models/action_segmentors/asformer.py
+++ b/mmaction/models/action_segmentors/asformer.py
@@ -60,8 +60,6 @@ def forward(self, inputs, data_samples, mode, **kwargs):
             - If ``mode="loss"``, return a dict of tensor.
         """
         input = torch.stack(inputs)
-        if mode == 'tensor':
-            return self._forward(inputs, **kwargs)
         if mode == 'predict':
             return self.predict(input, data_samples, **kwargs)
         elif mode == 'loss':
@@ -169,19 +167,6 @@ def predict(self, batch_inputs, batch_data_samples, **kwargs):
         output = [dict(ground=ground, recognition=recognition)]
         return output
 
-    def _forward(self, x):
-        """Define the computation performed at every call.
-
-        Args:
-            x (torch.Tensor): The input data.
-        Returns:
-            torch.Tensor: The output of the module.
-        """
-        print(x.shape)
-
-        return x.shape
-
-
 def exponential_descrease(idx_decoder, p=3):
     return math.exp(-p * idx_decoder)
 
@@ -448,6 +433,13 @@ def __init__(self, dilation, in_channels, out_channels):
                 dilation=dilation), nn.ReLU())
 
     def forward(self, x):
+        """Define the computation performed at every call.
+
+        Args:
+            x (torch.Tensor): The input data.
+        Returns:
+            torch.Tensor: The output of the module.
+        """
         return self.layer(x)
 
 
@@ -579,7 +571,7 @@ def forward(self, x, fencoder, mask):
 
 
 class MyTransformer(nn.Module):
-
+    """An encoder-decoder transformer."""
     def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
                  num_classes, channel_masking_rate):
         super(MyTransformer, self).__init__()
@@ -608,6 +600,13 @@ def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps, input_dim,
         ])  # num_decoders
 
     def forward(self, x, mask):
+        """Define the computation performed at every call.
+
+        Args:
+            x (torch.Tensor): The input data.
+        Returns:
+            torch.Tensor: The output of the module.
+        """
         out, feature = self.encoder(x, mask)
         outputs = out.unsqueeze(0)
 
@@ -617,4 +616,4 @@ def forward(self, x, mask):
                                              feature * mask[:, 0:1, :], mask)
             outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
 
-        return outputs
+        return outputs
\ No newline at end of file
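
A minimal usage sketch (not part of the patch) for the `MyTransformer` forward documented above. The constructor and `forward(x, mask)` signatures come from the diff; the hyperparameter values, import path, and tensor shapes below are illustrative assumptions, not mmaction-provided defaults.

```python
import torch

# Assumes the module touched by this patch is importable as-is.
from mmaction.models.action_segmentors.asformer import MyTransformer

model = MyTransformer(
    num_decoders=3,           # values here are assumptions chosen to
    num_layers=10,            # resemble a typical ASFormer setup
    r1=2,
    r2=2,
    num_f_maps=64,
    input_dim=2048,
    num_classes=11,
    channel_masking_rate=0.3)
model.eval()  # disable the channel-masking dropout for a deterministic pass

x = torch.randn(1, 2048, 100)   # (batch, feature_dim, num_frames), assumed shape
mask = torch.ones(1, 11, 100)   # all frames valid; zeros would mark padded frames

with torch.no_grad():
    outputs = model(x, mask)

# One prediction per stage: the encoder output plus one refinement per decoder,
# i.e. expected shape (num_decoders + 1, batch, num_classes, num_frames).
print(outputs.shape)
```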