diff --git a/torchsparse/nn/modules/bev.py b/torchsparse/nn/modules/bev.py index dac3a33..10f270b 100644 --- a/torchsparse/nn/modules/bev.py +++ b/torchsparse/nn/modules/bev.py @@ -93,7 +93,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: self.kernel, 0, torch.div(coords[:, self.dim], stride).trunc().long() ) feats = (feats.unsqueeze(dim=-1) * kernel).sum(1) + self.bias - coords = (coords - self.offset).t()[[3] + self.bev_dims].long() + coords = (coords - self.offset).t()[[0] + self.bev_dims].long() coords[1:] = torch.div(coords[1:], stride).trunc().long() indices = ( coords[0] * int(self.bev_shape.prod()) @@ -197,7 +197,7 @@ def forward(self, input: SparseTensor) -> torch.Tensor: assert isinstance(stride, torch.Tensor), type(stride) # [b, x, y, z] - coords = (coords - self.offset).t()[[3] + self.bev_dims + [self.dim]].long() + coords = (coords - self.offset).t()[[0] + self.bev_dims + [self.dim]].long() shape = self.shape[self.bev_dims + [self.dim]] # now stride must be torch.Tensor since input.s is tuple.