MaskDINO ResNet50 #3977

Merged: 60 commits, merged Oct 8, 2024

Commits (60)
de84ae5
update
eugene123tw Aug 13, 2024
faa52b4
add pixel encoders and transformer encoders
eugene123tw Aug 13, 2024
8a164c0
update
eugene123tw Aug 13, 2024
e845edd
format update
eugene123tw Aug 13, 2024
81e3c99
reformat
eugene123tw Aug 13, 2024
fd6fb91
add checkpoint loader
eugene123tw Aug 13, 2024
5172d18
add custom forward
eugene123tw Aug 13, 2024
a3cfa15
udpate
eugene123tw Aug 14, 2024
e527126
update postprocess
eugene123tw Aug 14, 2024
1b95f6d
fine tune model
eugene123tw Aug 15, 2024
3e43762
update criterion optimizers etc
eugene123tw Aug 15, 2024
ffe4abb
update loss
eugene123tw Aug 16, 2024
dc40a19
replace MSDeformableAttention with native pytorch ops impl
eugene123tw Aug 19, 2024
cf364b9
add export functionality
eugene123tw Aug 20, 2024
b9f9c9a
roi align mask extraction
eugene123tw Aug 21, 2024
5efad89
Revert "roi align mask extraction"
eugene123tw Aug 21, 2024
3e9179f
Merge branch 'develop' into eugene/maskdino
eugene123tw Aug 21, 2024
eb8379b
update
eugene123tw Aug 21, 2024
d9406f4
Revert "update"
eugene123tw Aug 21, 2024
f84ee81
Merge branch 'develop' into eugene/maskdino
eugene123tw Sep 20, 2024
b7af978
Upgrade OV, MAPI, and NNCF (#3967)
sovrasov Sep 20, 2024
a69a1f1
Tiling Semantic Seg (#3954)
eugene123tw Sep 22, 2024
ac493ad
udpate
eugene123tw Sep 23, 2024
0a6d834
update maskdino from_config
eugene123tw Sep 23, 2024
2c6ebcd
update maskdino config
eugene123tw Sep 24, 2024
183cc1c
update maskdino config 2
eugene123tw Sep 24, 2024
13c9244
Merge branch 'develop' into eugene/maskdino
eugene123tw Sep 24, 2024
1460c9c
decouple detectron resnet50
eugene123tw Sep 24, 2024
1083b2a
decouple detectron resnet50
eugene123tw Sep 24, 2024
c7f17d2
decouple detectron resnet50
eugene123tw Sep 24, 2024
32dba58
mypy pylint
eugene123tw Sep 25, 2024
7224806
* fix all pylint issues
eugene123tw Sep 25, 2024
ed8810a
remove checkpoint
eugene123tw Sep 27, 2024
dbf0c8a
Merge branch 'develop' into eugene/maskdino
eugene123tw Sep 27, 2024
84daa00
update test
eugene123tw Sep 27, 2024
816743d
update test
eugene123tw Sep 27, 2024
ffb9867
address comments
eugene123tw Sep 30, 2024
9814f08
remove inverse_sigmoid
eugene123tw Sep 30, 2024
e962334
update weight clip optimizer
eugene123tw Sep 30, 2024
74dfafd
change get_clones to public func
eugene123tw Sep 30, 2024
4094274
add tile recipe
eugene123tw Sep 30, 2024
8cd9f17
fix export
eugene123tw Sep 30, 2024
bca6948
remove export names
eugene123tw Oct 1, 2024
413b6c0
optimize denoise training
eugene123tw Oct 1, 2024
0ff0075
optimize denoise training
eugene123tw Oct 1, 2024
6b2cd02
refactor inverse sigmoid
eugene123tw Oct 2, 2024
3cce03e
Merge branch 'develop' into eugene/maskdino
eugene123tw Oct 2, 2024
7bc99fd
update inverse sigmoid import
eugene123tw Oct 2, 2024
4bed3f7
update ptq config
eugene123tw Oct 2, 2024
facbd0f
fix export
eugene123tw Oct 3, 2024
3eb253f
remove private underscore
eugene123tw Oct 4, 2024
c3e5f64
update license
eugene123tw Oct 4, 2024
8665486
user lightning gradient clipping
eugene123tw Oct 4, 2024
d5576f8
update docstring
eugene123tw Oct 7, 2024
c8f5578
improve loss docstring
eugene123tw Oct 7, 2024
a071a22
pre commit
eugene123tw Oct 7, 2024
7dd8c35
Merge branch 'develop' into eugene/maskdino
eugene123tw Oct 7, 2024
a2ff755
add unit test
eugene123tw Oct 7, 2024
e7259e6
fix test
eugene123tw Oct 8, 2024
f7d0dd1
reformat
eugene123tw Oct 8, 2024
12 changes: 6 additions & 6 deletions src/otx/algo/detection/heads/rtdetr_decoder.py
@@ -15,9 +15,7 @@
from torch import nn
from torch.nn import init

from otx.algo.detection.utils.utils import (
inverse_sigmoid,
)
from otx.algo.common.utils.utils import inverse_sigmoid
from otx.algo.modules.base_module import BaseModule
from otx.algo.modules.transformer import deformable_attention_core_func

@@ -236,8 +234,6 @@ def forward(
value = self.value_proj(value)
if value_mask is not None:
value = value.masked_fill(value_mask[..., None], float(0))
# value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
# value3 = value * value_mask.unsqueeze(-1)
value = value.reshape(bs, len_v, self.num_heads, self.head_dim)

sampling_offsets = self.sampling_offsets(query).reshape(
@@ -263,7 +259,11 @@
)

if reference_points.shape[-1] == 2:
offset_normalizer = value_spatial_shapes.clone()
offset_normalizer = (
value_spatial_shapes
if isinstance(value_spatial_shapes, torch.Tensor)
else torch.tensor(value_spatial_shapes)
).clone()
offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2)
sampling_locations = (
reference_points.reshape(
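Note on the `offset_normalizer` change above: the new code coerces `value_spatial_shapes` to a tensor before flipping and reshaping, so the same path works whether the spatial shapes arrive as a `torch.Tensor` or as a plain Python list/tuple (presumably to cover export/tracing paths that pass a list). A minimal sketch of the pattern, with made-up level shapes:

```python
import torch

# Hypothetical (H, W) shapes for two feature levels; in the real decoder these
# come from the pixel decoder, not hard-coded values.
value_spatial_shapes = [[64, 64], [32, 32]]

# Coerce to a tensor only when needed, then flip (H, W) -> (W, H) so the
# normalizer matches the (x, y) ordering of the sampling offsets.
offset_normalizer = (
    value_spatial_shapes
    if isinstance(value_spatial_shapes, torch.Tensor)
    else torch.tensor(value_spatial_shapes)
).clone()
offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, len(value_spatial_shapes), 1, 2)
print(offset_normalizer.shape)  # torch.Size([1, 1, 1, 2, 1, 2])
```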
6 changes: 0 additions & 6 deletions src/otx/algo/detection/utils/utils.py
@@ -185,12 +185,6 @@ def backward(ctx, grad_output) -> tuple[Tensor, Tensor]: # noqa: D102, ANN001
sigmoid_geometric_mean = SigmoidGeometricMean.apply


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
"""Compute the inverse of sigmoid function."""
x = x.clip(min=0.0, max=1.0)
return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps))


def auto_pad(kernel_size: int | tuple[int, int], dilation: int | tuple[int, int] = 1, **kwargs) -> tuple[int, int]: # noqa: ARG001
"""Auto Padding for the convolution blocks.

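`inverse_sigmoid` is dropped from `otx.algo.detection.utils.utils`; the decoder above now imports it from `otx.algo.common.utils.utils` instead. For reference, the removed helper computed the clipped logit, i.e. log(x / (1 - x)) with both terms clipped to `eps`; a self-contained sketch of the same behavior:

```python
import torch


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Inverse of sigmoid (logit), clipped for numerical stability."""
    x = x.clip(min=0.0, max=1.0)
    return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps))


p = torch.tensor([0.1, 0.5, 0.9])
print(inverse_sigmoid(p))                 # tensor([-2.1972,  0.0000,  2.1972])
print(torch.sigmoid(inverse_sigmoid(p)))  # recovers tensor([0.1000, 0.5000, 0.9000])
```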
336 changes: 336 additions & 0 deletions src/otx/algo/instance_segmentation/backbones/detectron_resnet.py
@@ -0,0 +1,336 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Facebook, Inc. and its affiliates.

"""Implementation modified from Detectron2 ResNet.

Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/resnet.py
"""


from __future__ import annotations

import numpy as np
import torch.nn.functional as f
from torch import Tensor, nn

from otx.algo.instance_segmentation.layers.batch_norm import CNNBlockBase, get_norm
from otx.algo.instance_segmentation.utils.utils import Conv2d, ShapeSpec, c2_msra_fill

__all__ = [
"BottleneckBlock",
"BasicStem",
"ResNet",
"build_resnet_backbone",
]


class BottleneckBlock(CNNBlockBase):
"""The standard bottleneck residual block used by ResNet-50, 101 and 152."""

def __init__(
self,
in_channels: int,
out_channels: int,
bottleneck_channels: int,
stride: int = 1,
num_groups: int = 1,
norm: str = "BN",
stride_in_1x1: bool = False,
dilation: int = 1,
) -> None:
super().__init__(in_channels, out_channels, stride)

if in_channels != out_channels:
self.shortcut = Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False,
norm=get_norm(norm, out_channels),
)
else:
self.shortcut = None # type: ignore[assignment]

# The original MSRA ResNet models have stride in the first 1x1 conv
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
# stride in the 3x3 conv
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

self.conv1 = Conv2d(
in_channels,
bottleneck_channels,
kernel_size=1,
stride=stride_1x1,
bias=False,
norm=get_norm(norm, bottleneck_channels),
)

self.conv2 = Conv2d(
bottleneck_channels,
bottleneck_channels,
kernel_size=3,
stride=stride_3x3,
padding=1 * dilation,
bias=False,
groups=num_groups,
dilation=dilation,
norm=get_norm(norm, bottleneck_channels),
)

self.conv3 = Conv2d(
bottleneck_channels,
out_channels,
kernel_size=1,
bias=False,
norm=get_norm(norm, out_channels),
)

for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
if layer is not None: # shortcut can be None
c2_msra_fill(layer)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
out = self.conv1(x)
out = f.relu_(out)

out = self.conv2(out)
out = f.relu_(out)

out = self.conv3(out)

shortcut = self.shortcut(x) if self.shortcut is not None else x

out += shortcut
return f.relu_(out)


class BasicStem(CNNBlockBase):
"""The standard ResNet stem (layers before the first residual block), with a conv, relu and max_pool."""

def __init__(self, in_channels: int = 3, out_channels: int = 64, norm: str = "BN") -> None:
super().__init__(in_channels, out_channels, 4)
self.in_channels = in_channels
self.conv1 = Conv2d(
in_channels,
out_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False,
norm=get_norm(norm, out_channels),
)
c2_msra_fill(self.conv1)

def forward(self, x: Tensor) -> Tensor:
"""Forward pass."""
x = self.conv1(x)
x = f.relu_(x)
return f.max_pool2d(x, kernel_size=3, stride=2, padding=1)


class ResNet(nn.Module):
"""Implement :paper:`ResNet`."""

def __init__(
self,
stem: nn.Module,
stages: list[list[CNNBlockBase]],
out_features: tuple[str, ...],
freeze_at: int = 0,
) -> None:
super().__init__()
self.stem = stem

current_stride = self.stem.stride
self._out_feature_strides = {"stem": current_stride}
self._out_feature_channels = {"stem": self.stem.out_channels}

self.stage_names, self.stages = [], []

if out_features is not None:
# Avoid keeping unused layers in this module. They consume extra memory
# and may cause allreduce to fail
num_stages = max(
[{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features],
)
stages = stages[:num_stages]
for i, blocks in enumerate(stages):
name = "res" + str(i + 2)
stage = nn.Sequential(*blocks)

self.add_module(name, stage)
self.stage_names.append(name)
self.stages.append(stage)

self._out_feature_strides[name] = current_stride = int(
current_stride * np.prod([k.stride for k in blocks]),
)
self._out_feature_channels[name] = blocks[-1].out_channels

# Make it static for scripting
self.stage_names = tuple(self.stage_names) # type: ignore [assignment]

self._out_features = out_features
self.freeze(freeze_at)

def forward(self, x: Tensor) -> dict[str, Tensor]:
"""Forward pass."""
outputs = {}
x = self.stem(x)
if "stem" in self._out_features:
outputs["stem"] = x
for name, stage in zip(self.stage_names, self.stages, strict=True):
x = stage(x)
if name in self._out_features:
outputs[name] = x
return outputs

def freeze(self, freeze_at: int = 0) -> nn.Module:
"""Freeze the first several stages of the ResNet. Commonly used in fine-tuning.

Layers that produce the same feature map spatial size are defined as one
"stage" by :paper:`FPN`.

Args:
freeze_at (int): number of stages to freeze.
`1` means freezing the stem. `2` means freezing the stem and
one residual stage, etc.

Returns:
nn.Module: this ResNet itself
"""
if freeze_at >= 1:
self.stem.freeze()
for idx, stage in enumerate(self.stages, start=2):
if freeze_at >= idx:
for block in stage.children():
block.freeze()
return self

@staticmethod
def make_stage(
block_class: nn.Module,
num_blocks: int,
*,
in_channels: int,
out_channels: int,
**kwargs,
) -> list[CNNBlockBase]:
"""Create a list of blocks of the same type that forms one ResNet stage.

Args:
block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
stage. A module of this type must not change spatial resolution of inputs unless its
stride != 1.
num_blocks (int): number of blocks in this stage
in_channels (int): input channels of the entire stage.
out_channels (int): output channels of **every block** in the stage.
kwargs: other arguments passed to the constructor of
`block_class`. If the argument name is "xx_per_block", the
argument is a list of values to be passed to each block in the
stage. Otherwise, the same argument is passed to every block
in the stage.

Returns:
list[CNNBlockBase]: a list of block module.

Examples:
::
stage = ResNet.make_stage(
BottleneckBlock, 3, in_channels=16, out_channels=64,
bottleneck_channels=16, num_groups=1,
stride_per_block=[2, 1, 1],
dilations_per_block=[1, 1, 2]
)

Usually, layers that produce the same feature map spatial size are defined as one
"stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
all be 1.
"""
blocks = []
for i in range(num_blocks):
curr_kwargs = {}
for k, v in kwargs.items():
if k.endswith("_per_block"):
newk = k[: -len("_per_block")]
curr_kwargs[newk] = v[i]
else:
curr_kwargs[k] = v

blocks.append(
block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs),
)
in_channels = out_channels
return blocks

def output_shape(self) -> dict[str, ShapeSpec]:
"""Returns output shapes for each stage."""
return {
name: ShapeSpec(
channels=self._out_feature_channels[name],
stride=self._out_feature_strides[name],
)
for name in self._out_features
}


def build_resnet_backbone(
norm: str,
stem_out_channels: int,
input_shape: ShapeSpec,
freeze_at: int,
out_features: tuple[str, ...],
depth: int,
num_groups: int,
width_per_group: int,
in_channels: int,
out_channels: int,
stride_in_1x1: bool,
res5_dilation: int = 1,
) -> ResNet:
"""Create a ResNet instance from config.

Returns:
ResNet: a :class:`ResNet` instance.
"""
# need registration of new blocks/stems?
stem = BasicStem(
in_channels=input_shape.channels,
out_channels=stem_out_channels,
norm=norm,
)

bottleneck_channels = width_per_group * num_groups
num_blocks_per_stage = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}[depth]

stages = []

for idx, stage_idx in enumerate(range(2, 6)):
# res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
dilation = res5_dilation if stage_idx == 5 else 1
first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
stage_kargs = {
"num_blocks": num_blocks_per_stage[idx],
"stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1),
"in_channels": in_channels,
"out_channels": out_channels,
"norm": norm,
}
stage_kargs["bottleneck_channels"] = bottleneck_channels
stage_kargs["stride_in_1x1"] = stride_in_1x1
stage_kargs["dilation"] = dilation
stage_kargs["num_groups"] = num_groups
stage_kargs["block_class"] = BottleneckBlock
blocks = ResNet.make_stage(**stage_kargs) # type: ignore[arg-type]
in_channels = out_channels
out_channels *= 2
bottleneck_channels *= 2
stages.append(blocks)
return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
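
For orientation, a sketch of how `build_resnet_backbone` might be called to construct the ResNet-50 backbone. The argument values below are illustrative assumptions, not the ones from the MaskDINO recipe, and the `ShapeSpec(channels=3)` call assumes the constructor exposes a `channels` field, as used by the factory above.

```python
import torch

from otx.algo.instance_segmentation.backbones.detectron_resnet import build_resnet_backbone
from otx.algo.instance_segmentation.utils.utils import ShapeSpec

# Illustrative ResNet-50 settings (assumed, not taken from the recipe).
backbone = build_resnet_backbone(
    norm="BN",
    stem_out_channels=64,
    input_shape=ShapeSpec(channels=3),  # assumes a `channels` field on ShapeSpec
    freeze_at=0,
    out_features=("res2", "res3", "res4", "res5"),
    depth=50,
    num_groups=1,
    width_per_group=64,
    in_channels=64,
    out_channels=256,
    stride_in_1x1=False,
)

features = backbone(torch.rand(1, 3, 224, 224))  # dict with keys res2..res5
print(backbone.output_shape())  # e.g. res2: 256 ch / stride 4 ... res5: 2048 ch / stride 32
```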