Adding TCN decoder. #9

Open · wants to merge 4 commits into base: alpha
19 changes: 15 additions & 4 deletions oatomobile/baselines/torch/dim/model.py
@@ -30,7 +30,7 @@
from oatomobile.torch import types
from oatomobile.torch.networks.mlp import MLP
from oatomobile.torch.networks.perception import MobileNetV2
-from oatomobile.torch.networks.sequence import AutoregressiveFlow
+from oatomobile.torch.networks.sequence import AutoregressiveFlow, TCN


class ImitativeModel(nn.Module):
@@ -62,9 +62,18 @@ def __init__(
)

# The decoder recurrent network used for the sequence generation.
-self._decoder = AutoregressiveFlow(
-    output_shape=self._output_shape,
-    hidden_size=64,
+# self._decoder = AutoregressiveFlow(
+#     output_shape=self._output_shape,
+#     hidden_size=64,
+# )
+
+self._decoder = TCN(
+    input_channels=1,
+    num_input_features=64,
+    num_output_features=2,
+    num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
+    kernel_size=7,
+    dropout=0.2
)

def to(self, *args, **kwargs):
@@ -92,6 +101,8 @@ def forward(
Returns:
A mode from the posterior, with shape `[D, 2]`.
"""


if not "visual_features" in context:
raise ValueError("Missing `visual_features` keyword argument.")
batch_size = context["visual_features"].shape[0]
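For reference, the decoder swap above changes the interface: the TCN treats the 64-dimensional scene encoding as a single-channel, length-64 sequence and emits num_channels[-1] = 4 waypoints of 2 coordinates each. A minimal shape check (a standalone sketch, not part of the PR; the batch size of 8 is arbitrary):

import torch

from oatomobile.torch.networks.sequence import TCN

decoder = TCN(
    input_channels=1,
    num_input_features=64,
    num_output_features=2,
    num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
    kernel_size=7,
    dropout=0.2,
)
z = torch.randn(8, 64)         # scene encodings from the MobileNetV2 backbone
out = decoder(z.unsqueeze(1))  # add the channel dimension: [8, 1, 64]
assert out.shape == (8, 4, 2)  # 4 future waypoints, 2 coordinates each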
32 changes: 17 additions & 15 deletions oatomobile/baselines/torch/dim/train.py
@@ -47,7 +47,7 @@
)
flags.DEFINE_integer(
name="batch_size",
-    default=512,
+    default=50,
help="The batch size used for training the neural network.",
)
flags.DEFINE_integer(
@@ -80,6 +80,11 @@
default=False,
help="If True it clips the gradients norm to 1.0.",
)
+flags.DEFINE_bool(
+    name="use_tcn",
+    default=True,
+    help="If True, use the TCN decoder. Otherwise, use the autoregressive decoder.",
+)


def main(argv):
@@ -97,6 +102,7 @@ def main(argv):
num_timesteps_to_keep = FLAGS.num_timesteps_to_keep
weight_decay = FLAGS.weight_decay
clip_gradients = FLAGS.clip_gradients
+use_tcn = FLAGS.use_tcn
noise_level = 1e-2

# Determines device, accelerator.
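Review note: the new use_tcn flag is read into main above, but none of the hunks shown branch on it — model.py constructs the TCN unconditionally. A hypothetical way to wire it up (the decoder keyword below does not exist in this PR; it is illustration only):

# Hypothetical dispatch; ImitativeModel would need to accept and act on `decoder`.
model = ImitativeModel(
    output_shape=[num_timesteps_to_keep, 2],
    decoder="tcn" if use_tcn else "flow",
).to(device)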
@@ -151,7 +157,7 @@ def transform(batch: Mapping[str, types.Array]) -> Mapping[str, torch.Tensor]:
dataset_train,
batch_size=batch_size,
shuffle=True,
-    num_workers=50,
+    num_workers=5,
)
dataset_val = CARLADataset.as_torch(
dataset_dir=os.path.join(dataset_dir, "val"),
@@ -161,7 +167,7 @@ def transform(batch: Mapping[str, types.Array]) -> Mapping[str, torch.Tensor]:
dataset_val,
batch_size=batch_size * 5,
shuffle=True,
-    num_workers=50,
+    num_workers=5,
)

# Theoretical limit of NLL.
@@ -182,10 +188,11 @@ def train_step(
# Resets optimizer's gradients.
optimizer.zero_grad()

+target_mean = batch["player_future"][..., :2]
# Perturb target.
y = torch.normal(  # pylint: disable=no-member
-    mean=batch["player_future"][..., :2],
-    std=torch.ones_like(batch["player_future"][..., :2]) * noise_level,  # pylint: disable=no-member
+    mean=target_mean,
+    std=torch.ones_like(target_mean) * noise_level,  # pylint: disable=no-member
)

# Forward pass from the model.
@@ -195,10 +202,9 @@
is_at_traffic_light=batch["is_at_traffic_light"],
traffic_light_state=batch["traffic_light_state"],
)
-_, log_prob, logabsdet = model._decoder._inverse(y=y, z=z)

# Calculates loss (NLL).
-loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member
+target_mean = batch["player_future"][..., :2]
+loss = model._decoder.compute_loss(y=target_mean, z=z)

# Backward pass.
loss.backward()
@@ -240,13 +246,9 @@ def evaluate_step(
is_at_traffic_light=batch["is_at_traffic_light"],
traffic_light_state=batch["traffic_light_state"],
)
-_, log_prob, logabsdet = model._decoder._inverse(
-    y=batch["player_future"][..., :2],
-    z=z,
-)

# Calculates loss (NLL).
-loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member

+target_mean = batch["player_future"][..., :2]
+loss = model._decoder.compute_loss(y=target_mean, z=z)

return loss

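Note on the two loss refactors above: both decoders now expose the same compute_loss(y, z) entry point, so train_step and evaluate_step no longer depend on the flow-specific _inverse. Two behavioral details worth flagging: with the TCN the objective becomes an MSE on the clean target rather than an NLL, and in train_step the perturbed target y is still computed but no longer consumed by the loss. A minimal sketch of the shared contract (shapes assumed from the DIM baseline):

# y: ground-truth future waypoints, [B, T, 2]; z: scene encoding, [B, 64].
loss = model._decoder.compute_loss(y=target_mean, z=z)  # scalar tensor
loss.backward()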
115 changes: 114 additions & 1 deletion oatomobile/torch/networks/sequence.py
@@ -14,12 +14,13 @@
# ==============================================================================
"""Sequence generation."""

-from typing import Tuple
+from typing import Callable, List, Tuple

import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
+from torch.nn.utils import weight_norm

from oatomobile.torch import types
from oatomobile.torch.networks.mlp import MLP
@@ -214,3 +215,115 @@ def _inverse(
        logabsdet = torch.sum(logabsdet, dim=-1)  # sum over T dimension # pylint: disable=no-member

        return x, log_prob, logabsdet

    def compute_loss(
        self,
        y: torch.Tensor,
        z: torch.Tensor,
    ) -> torch.Tensor:
        """Returns the negative log-likelihood (NLL) of `y` under the flow, conditioned on `z`."""
        _, log_prob, logabsdet = self._inverse(y=y, z=z)
        loss = -torch.mean(log_prob - logabsdet, dim=0)  # pylint: disable=no-member
        return loss


class Chomp1d(nn.Module):
    """Trims the trailing `chomp_size` steps so the padded convolution stays causal."""

    def __init__(self, chomp_size: int):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """A residual block of two weight-normalized, causal, dilated 1D convolutions."""

    def __init__(self, n_inputs: int, n_outputs: int, kernel_size: int, stride: int,
                 dilation: int, padding: int, dropout: float = 0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding,
                                           dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding,
                                           dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        # A 1x1 convolution matches channel counts on the residual path when needed.
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Stacks TemporalBlocks with exponentially growing dilation (1, 2, 4, ...)."""

    def __init__(self, num_inputs: int, num_channels: List[int],
                 kernel_size: int = 2, dropout: float = 0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1,
                                     dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size,
                                     dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)

class TCN(nn.Module):
    """A generic TCN decoder that converts the image + modalities representation
    into a sequence of grid points.

    Args:
      input_channels: number of channels of the input sequence; in our case 1,
        since the scene encoding is treated as a single-channel sequence.
      num_input_features: length of the input sequence; in our case 64, the
        size of the scene encoding.
      num_output_features: number of features per output step; in our case 2
        (2D grid points).
      num_channels: sizes of the hidden layers, one entry per level.
      kernel_size: kernel size of each dilated 1D convolution.
      dropout: dropout rate applied after every dilated 1D convolution.

    Usage:
      In our setting, we initialize the TCN as:
        TCN(
            input_channels=1,
            num_input_features=64,
            num_output_features=2,
            num_channels=[30, 30, 30, 30, 30, 30, 30, 30, 4],
            kernel_size=7,
            dropout=0.2,
        )
      This converts (batch_size, 1, num_input_features) inputs into
      (batch_size, num_channels[-1], num_output_features) outputs; in our
      setting, (1, 1, 64) inputs become (1, 4, 2) outputs.
    """
    def __init__(self, input_channels: int, num_input_features: int,
                 num_output_features: int, num_channels: List[int],
                 kernel_size: int, dropout: float):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(input_channels, num_channels,
                                   kernel_size=kernel_size, dropout=dropout)
        # Maps the 64 image features at each output step to the 2 grid coordinates.
        self.linear = nn.Linear(num_input_features, num_output_features)
        self.init_weights()

    def init_weights(self):
        self.linear.weight.data.normal_(0, 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y1 = self.tcn(x)
        return self.linear(y1)

    def compute_loss(self, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Returns the MSE between the predicted sequence and the target `y`."""
        z = z.unsqueeze(1)  # [B, 64] -> [B, 1, 64]: a one-channel sequence for the TCN.
        seq_prediction = self.forward(z)
        loss = F.mse_loss(seq_prediction, y)
        return loss
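A closing sanity check (not part of the PR): with kernel_size=7 and nine levels of exponentially growing dilation, the TCN's receptive field far exceeds its length-64 input, so every output step can see the entire scene encoding:

# Each TemporalBlock applies two causal convolutions with dilation 2**i, so:
# receptive_field = 1 + 2 * (kernel_size - 1) * sum(2**i for i in range(levels))
kernel_size, levels = 7, 9
receptive_field = 1 + 2 * (kernel_size - 1) * sum(2 ** i for i in range(levels))
print(receptive_field)  # 6133, versus an input length of 64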