
Commit

fix(pu): fix DownSample for different obs shape (#254)
puyuan1996 authored Aug 15, 2024
1 parent 3f6cb5a commit 0064381
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions lzero/model/common.py
@@ -166,6 +166,7 @@ class DownSample(nn.Module):
     def __init__(self, observation_shape: SequenceType, out_channels: int,
                  activation: nn.Module = nn.ReLU(inplace=True),
                  norm_type: Optional[str] = 'BN',
+                 num_resblocks: int = 1,
                  ) -> None:
         """
         Overview:
@@ -178,7 +179,8 @@ def __init__(self, observation_shape: SequenceType, out_channels: int,
             - out_channels (:obj:`int`): The output channels of output hidden state.
             - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
                 Use the inplace operation to speed up.
-            - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'.
+            - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'.
+            - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1.
         """
         super().__init__()
         assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
@@ -206,7 +208,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int,
                     norm_type=norm_type,
                     res_type='basic',
                     bias=False
-                ) for _ in range(1)
+                ) for _ in range(num_resblocks)
             ]
         )
         self.downsample_block = ResBlock(
@@ -221,7 +223,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int,
             [
                 ResBlock(
                     in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
-                ) for _ in range(1)
+                ) for _ in range(num_resblocks)
             ]
         )
         self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
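The sketch below illustrates how the new num_resblocks argument could be used; the import path matches the changed file, but the input shapes, channel counts, and the expected output size are assumptions for illustration, not part of the commit.

# Minimal usage sketch (assumes LightZero is installed and DownSample is importable as below).
import torch
from lzero.model.common import DownSample

# Hypothetical Atari-style input: 4 stacked RGB frames (12 channels) at 96x96.
down = DownSample(
    observation_shape=(12, 96, 96),
    out_channels=64,
    norm_type='BN',
    num_resblocks=2,  # new argument introduced by this commit; previously hard-coded to 1
)

obs = torch.randn(8, 12, 96, 96)  # (batch, C, H, W)
feat = down(obs)
print(feat.shape)  # expected (8, 64, 6, 6) for 96x96 inputs, given the strided convs and pooling above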
@@ -261,6 +263,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         elif self.observation_shape[1] == 96:
             x = self.pooling2(x)
             output = x
+        else:
+            raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. "
+                                      f"You should transform the observation shape to 64 or 96 in the env.")
 
         return output
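With the added else branch, an unsupported spatial size now fails fast with an explicit message instead of silently falling through. A short sketch of that behavior, using a hypothetical 84x84 observation (neither 64 nor 96):

# Sketch only; the 84x84 shape is hypothetical and chosen to trigger the new error.
import torch
from lzero.model.common import DownSample

down_84 = DownSample(observation_shape=(4, 84, 84), out_channels=64)
try:
    down_84(torch.randn(1, 4, 84, 84))
except NotImplementedError as err:
    print(err)  # message asks you to transform the observation to 64 or 96 in the env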

@@ -337,7 +342,6 @@ def __init__(
 
         self.sim_norm = SimNorm(simnorm_dim=group_size)
 
-
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Shapes:
@@ -355,8 +359,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         for block in self.resblocks:
             x = block(x)
 
-        # NOTE: very important.
-        # for atari (64,8,8), flatten_size = 4096 -> 768
+        # Important: Transform the output feature plane to the latent state.
+        # For example, for an Atari feature plane of shape (64, 8, 8),
+        # flattening results in a size of 4096, which is then transformed to 768.
         x = self.last_linear(x.reshape(-1, 64 * 8 * 8))
         x = x.view(-1, self.embedding_dim)
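The rewritten comment describes a flatten-and-project step. Below is a standalone sketch of that transformation with the shapes taken from the comment's example; the layer name and its exact configuration are assumptions, not the module's definition.

import torch
import torch.nn as nn

embedding_dim = 768                                   # value from the comment's example
last_linear = nn.Linear(64 * 8 * 8, embedding_dim)    # configuration (bias, init) is assumed

feature_plane = torch.randn(2, 64, 8, 8)              # (batch, C, H, W) feature plane from the backbone
latent = last_linear(feature_plane.reshape(-1, 64 * 8 * 8))  # flatten 64*8*8 = 4096, project to 768
latent = latent.view(-1, embedding_dim)               # latent state of shape (2, 768)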

