diff --git a/oatomobile/baselines/torch/dim/model.py b/oatomobile/baselines/torch/dim/model.py
index 2c35638..60a6bbf 100644
--- a/oatomobile/baselines/torch/dim/model.py
+++ b/oatomobile/baselines/torch/dim/model.py
@@ -30,7 +30,7 @@
 from oatomobile.torch import types
 from oatomobile.torch.networks.mlp import MLP
 from oatomobile.torch.networks.perception import MobileNetV2
-from oatomobile.torch.networks.sequence import AutoregressiveFlow
+from oatomobile.torch.networks.sequence import AutoregressiveFlow, TCN
 
 
 class ImitativeModel(nn.Module):
@@ -62,9 +62,18 @@ def __init__(
     )
 
     # The decoder recurrent network used for the sequence generation.
-    self._decoder = AutoregressiveFlow(
-        output_shape=self._output_shape,
-        hidden_size=64,
+    # self._decoder = AutoregressiveFlow(
+    #     output_shape=self._output_shape,
+    #     hidden_size=64,
+    # )
+
+    self._decoder = TCN(
+        input_channels=1,
+        num_input_features=64,
+        num_output_features=2,
+        num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
+        kernel_size=7,
+        dropout=0.2
     )
 
   def to(self, *args, **kwargs):
@@ -92,6 +101,8 @@ def forward(
     Returns:
       A mode from the posterior, with shape `[D, 2]`.
     """
+
+
     if not "visual_features" in context:
       raise ValueError("Missing `visual_features` keyword argument.")
     batch_size = context["visual_features"].shape[0]
diff --git a/oatomobile/baselines/torch/dim/train.py b/oatomobile/baselines/torch/dim/train.py
index 07b5400..1216875 100644
--- a/oatomobile/baselines/torch/dim/train.py
+++ b/oatomobile/baselines/torch/dim/train.py
@@ -47,7 +47,7 @@
 )
 flags.DEFINE_integer(
     name="batch_size",
-    default=512,
+    default=50,
     help="The batch size used for training the neural network.",
 )
 flags.DEFINE_integer(
@@ -80,6 +80,11 @@
     default=False,
     help="If True it clips the gradients norm to 1.0.",
 )
+flags.DEFINE_bool(
+    name="use_tcn",
+    default=True,
+    help="If True, use the TCN decoder. Otherwise, use the autoregressive decoder.",
+)
 
 
 def main(argv):
@@ -97,6 +102,7 @@ def main(argv):
   num_timesteps_to_keep = FLAGS.num_timesteps_to_keep
   weight_decay = FLAGS.weight_decay
   clip_gradients = FLAGS.clip_gradients
+  use_tcn = FLAGS.use_tcn
   noise_level = 1e-2
 
   # Determines device, accelerator.
@@ -151,7 +157,7 @@ def transform(batch: Mapping[str, types.Array]) -> Mapping[str, torch.Tensor]:
       dataset_train,
       batch_size=batch_size,
       shuffle=True,
-      num_workers=50,
+      num_workers=5,
   )
   dataset_val = CARLADataset.as_torch(
       dataset_dir=os.path.join(dataset_dir, "val"),
@@ -161,7 +167,7 @@ def transform(batch: Mapping[str, types.Array]) -> Mapping[str, torch.Tensor]:
       dataset_val,
       batch_size=batch_size * 5,
       shuffle=True,
-      num_workers=50,
+      num_workers=5,
   )
 
   # Theoretical limit of NLL.
@@ -182,10 +188,11 @@ def train_step(
     # Resets optimizer's gradients.
     optimizer.zero_grad()
 
+    target_mean = batch["player_future"][..., :2]
     # Perturb target.
     y = torch.normal(  # pylint: disable=no-member
-        mean=batch["player_future"][..., :2],
-        std=torch.ones_like(batch["player_future"][..., :2]) * noise_level,  # pylint: disable=no-member
+        mean=target_mean,
+        std=torch.ones_like(target_mean) * noise_level,  # pylint: disable=no-member
     )
 
     # Forward pass from the model.
@@ -195,10 +202,9 @@
         is_at_traffic_light=batch["is_at_traffic_light"],
         traffic_light_state=batch["traffic_light_state"],
     )
-    _, log_prob, logabsdet = model._decoder._inverse(y=y, z=z)
-
-    # Calculates loss (NLL).
-    loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member
+
+    target_mean = batch["player_future"][..., :2]
+    loss = model._decoder.compute_loss(y=target_mean, z=z)
 
     # Backward pass.
     loss.backward()
@@ -240,13 +246,9 @@ def evaluate_step(
         is_at_traffic_light=batch["is_at_traffic_light"],
         traffic_light_state=batch["traffic_light_state"],
     )
-    _, log_prob, logabsdet = model._decoder._inverse(
-        y=batch["player_future"][..., :2],
-        z=z,
-    )
-
-    # Calculates loss (NLL).
-    loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member
+
+    target_mean = batch["player_future"][..., :2]
+    loss = model._decoder.compute_loss(y=target_mean, z=z)
 
     return loss
diff --git a/oatomobile/torch/networks/sequence.py b/oatomobile/torch/networks/sequence.py
index 3fe982c..c39f2b1 100644
--- a/oatomobile/torch/networks/sequence.py
+++ b/oatomobile/torch/networks/sequence.py
@@ -14,12 +14,13 @@
 # ==============================================================================
 """Sequence generation."""
 
-from typing import Tuple
+from typing import Callable, List, Tuple
 
 import torch
 import torch.distributions as D
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm
 
 from oatomobile.torch import types
 from oatomobile.torch.networks.mlp import MLP
@@ -214,3 +215,115 @@ def _inverse(
     logabsdet = torch.sum(logabsdet, dim=-1)  # sum over T dimension  # pylint: disable=no-member
 
     return x, log_prob, logabsdet
+
+  def compute_loss(
+      self,
+      y: torch.Tensor,
+      z: torch.Tensor,
+  ) -> torch.Tensor:
+    """Computes the negative log-likelihood of `y` under the flow, given `z`."""
+    _, log_prob, logabsdet = self._inverse(y, z)
+    loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member
+    return loss
+
+
+class Chomp1d(nn.Module):
+  """Removes the trailing padding so that the convolution stays causal."""
+
+  def __init__(self, chomp_size: int):
+    super(Chomp1d, self).__init__()
+    self.chomp_size = chomp_size
+
+  def forward(self, x: torch.Tensor):
+    return x[:, :, :-self.chomp_size].contiguous()
+
+
+class TemporalBlock(nn.Module):
+  """Two dilated causal convolutions with weight norm, dropout and a residual connection."""
+
+  def __init__(self, n_inputs: int, n_outputs: int, kernel_size: int, stride: int,
+               dilation: int, padding: int, dropout: float = 0.2):
+    super(TemporalBlock, self).__init__()
+    self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
+                                       stride=stride, padding=padding, dilation=dilation))
+    self.chomp1 = Chomp1d(padding)
+    self.relu1 = nn.ReLU()
+    self.dropout1 = nn.Dropout(dropout)
+
+    self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
+                                       stride=stride, padding=padding, dilation=dilation))
+    self.chomp2 = Chomp1d(padding)
+    self.relu2 = nn.ReLU()
+    self.dropout2 = nn.Dropout(dropout)
+
+    self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
+                             self.conv2, self.chomp2, self.relu2, self.dropout2)
+    self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
+    self.relu = nn.ReLU()
+    self.init_weights()
+
+  def init_weights(self):
+    self.conv1.weight.data.normal_(0, 0.01)
+    self.conv2.weight.data.normal_(0, 0.01)
+    if self.downsample is not None:
+      self.downsample.weight.data.normal_(0, 0.01)
+
+  def forward(self, x: torch.Tensor):
+    out = self.net(x)
+    res = x if self.downsample is None else self.downsample(x)
+    return self.relu(out + res)
+
+
+class TemporalConvNet(nn.Module):
+  """Stack of `TemporalBlock`s with exponentially increasing dilation."""
+
+  def __init__(self, num_inputs: int, num_channels: List[int], kernel_size: int = 2,
+               dropout: float = 0.2):
+    super(TemporalConvNet, self).__init__()
+    layers = []
+    num_levels = len(num_channels)
+    for i in range(num_levels):
+      dilation_size = 2 ** i
+      in_channels = num_inputs if i == 0 else num_channels[i - 1]
+      out_channels = num_channels[i]
+      layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1,
+                               dilation=dilation_size,
+                               padding=(kernel_size - 1) * dilation_size, dropout=dropout)]
+
+    self.network = nn.Sequential(*layers)
+
+  def forward(self, x: torch.Tensor):
+    return self.network(x)
+
+
+class TCN(nn.Module):
+  """A generic TCN decoder that converts the image/modalities representation into a
+  sequence of grid points.
+
+  Args:
+    input_channels: length of the input sequence; in our case 1 (a single embedding).
+    num_input_features: number of features per input step; in our case 64.
+    num_output_features: number of features per output step; in our case 2 (2D grid points).
+    num_channels: list of hidden-layer sizes, one entry per temporal block.
+    kernel_size: kernel size of the 1D dilated convolutions.
+    dropout: dropout rate applied after every 1D dilated convolution.
+
+  Usage:
+    In our setting, we initialize the TCN as:
+      TCN(
+          input_channels=1,
+          num_input_features=64,
+          num_output_features=2,
+          num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
+          kernel_size=7,
+          dropout=0.2,
+      )
+    This maps (batch_size, 1, num_input_features) inputs to
+    (batch_size, num_channels[-1], num_output_features) outputs; in our setting,
+    (1, 1, 64) inputs become (1, 4, 2) outputs.
+  """
+
+  def __init__(self, input_channels: int, num_input_features: int, num_output_features: int,
+               num_channels: List[int], kernel_size: int, dropout: float):
+    super(TCN, self).__init__()
+    self.tcn = TemporalConvNet(input_channels, num_channels, kernel_size=kernel_size,
+                               dropout=dropout)
+    # Going from 64 image features to 2 grid points per output step.
+    self.linear = nn.Linear(num_input_features, num_output_features)
+    self.init_weights()
+
+  def init_weights(self):
+    self.linear.weight.data.normal_(0, 0.01)
+
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    y1 = self.tcn(x)
+    return self.linear(y1)
+
+  def compute_loss(self, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
+    """Computes the MSE between the predicted trajectory and the target `y`."""
+    z = z.unsqueeze(1)  # Convert to a length-1 sequence for TCN processing.
+    seq_prediction = self.forward(z)
+    loss = F.mse_loss(y, seq_prediction)
+    return loss
\ No newline at end of file
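
Note: below is a minimal standalone sketch of how the new TCN decoder is exercised, assuming the
patched oatomobile.torch.networks.sequence shown above is importable. The batch size (8) and the
4-step horizon (chosen to match num_channels[-1] = 4, i.e. num_timesteps_to_keep = 4) are
illustrative values, not part of the patch.

    import torch
    from oatomobile.torch.networks.sequence import TCN

    # Same constructor arguments as in the modified ImitativeModel.__init__.
    decoder = TCN(
        input_channels=1,
        num_input_features=64,
        num_output_features=2,
        num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
        kernel_size=7,
        dropout=0.2,
    )

    z = torch.randn(8, 64)    # 64-dim fused scene embedding, as used for `z` in train_step.
    y = torch.randn(8, 4, 2)  # Stand-in for batch["player_future"][..., :2] with 4 timesteps.

    prediction = decoder(z.unsqueeze(1))   # (8, 1, 64) -> (8, 4, 2)
    loss = decoder.compute_loss(y=y, z=z)  # MSE between predicted and target trajectories.
    print(prediction.shape, loss.item())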