Skip to content

Commit

Permalink
Use 3x3 conv for refinement.
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Oct 29, 2024
1 parent 0b950ad commit 0f78913
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/torch_wae/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(self, s: int) -> None:
nn.Conv2d(1 * s, 1 * s, 3, stride=1, padding=1),
nn.BatchNorm2d(1 * s),
nn.LeakyReLU(),
nn.Conv2d(1 * s, 1, 1, stride=1),
nn.Conv2d(1 * s, 1, 3, stride=1, padding=1),
)

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -184,7 +184,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
h_6 = self.decode_2(h_5 + h_1)
h_7 = self.decode_3(h_6 + h_0)

return self.refine(h_7)
return self.refine(h_7 + x)


class Preprocess(nn.Module):
Expand Down

0 comments on commit 0f78913

Please sign in to comment.