From 8475130ced15fd08117ff8ccfb1fbc01d00ab2d7 Mon Sep 17 00:00:00 2001 From: Harim Kang Date: Fri, 27 Sep 2024 14:14:22 +0900 Subject: [PATCH] Refactor to allow for a wider model in TIMM (#3976) * update for releases 2.2.0rc0 * Fix Classification explain forward issue (#3867) Fix bug * Fix e2e code error (#3871) * Update test_cli.py * Update tests/e2e/cli/test_cli.py Co-authored-by: Eunwoo Shin * Update test_cli.py * Update test_cli.py --------- Co-authored-by: Eunwoo Shin * Add documentation about configurable input size (#3870) * add docs about configurable input size * update api usecase and fix bug * Fix zero-shot e2e (#3876) Fix * Fix DeiT for multi-label classification (#3881) Remove init_args * Fix Semi-SL for ViT accuracy drop (#3883) Remove init_args * Update docs for 2.2 (#3884) Update docs * Fix mean and scale for segmentation task (#3885) fix mean and scale * Update MAPI in 2.2 (#3889) * Bump MAPI * Update exportable code requirements * Improve Semi-SL for LiteHRNet (small-medium case) (#3891) * change drop pixels value * go safe, change only tested models * minor * Improve h-cls for eff models (#3893) * Update step size for eff v2 * Update effb0 recipe * Fix maskrcnn swin nncf acc drop (#3900) update maskrcnn swimt model type to transformer * Add keypoint detection recipe for single object cases (#3903) * add rtmpose_tiny for single obj * add rtmpose_tiny for single obj * modify test subset name * fix unit test * update recipe with reset * Improve acc drop of efficientnetv2 for h-label cls (#3907) * Add warmup_iters for effv2 * Update max_epochs * Fix pretrained weight cached dir for timm (#3909) * Fix pretrained_weight for timm * Fix unit-test * Fix keypoint detection single obj recipe (#3915) * add rtmpose_tiny for single obj * modify test subset name * fix unit test * property for pck * Fix cached dir for timm & hugging-face (#3914) * Fix cached dir * Pretrained weight download unit-test * Fix pre-commit * Fix wrong template id mapping for anomaly 
(#3916) * Update script to allow setting otx version using env. variable (#3913) * Fix Datamodule creation for OV in AutoConfigurator (#3920) Fix datamodule for ov * Update tpp file for 2.2.0 (#3921) * Fix names for ignored scope [HOT-FIX, 2.2.0] (#3924) fix names for ignored scope * Fix classification rt_info (#3922) * Restore output_raw_scores for classification * Add uts * Fix linter * Update label info (#3925) add label info to init Signed-off-by: Ashwin Vaidya * Fix binary classification metric task (#3928) * Fix binary classification * Add unit-tests * Improve MaskRCNN SwinT NNCF (#3929) * ignore heads and disable smooth quant * add activations_range_estimator_params * update changelog * Fix get_item for Chained Tasks in Classification (#3931) * Fix Task Chain * Add multi-label case as well * Add multi-label case as well2 * Add H-label case * Correct Keyerror for h-label cls in label_groups for dm_label_categories using label's id/key (#3932) Modify label_groups for dm_label_categories with id/key of label * Remove datumaro attribute id from tiling, add subset names (#3933) * remove datumaro attribute id from tiling * add subset names * Fix soft predictions for Semantic Segmentation (#3934) fix soft preds * Update STFPM config (#3935) * Add missing pretrained weights when creating a docker image (#3938) * Fix pre-trained weight downloader * Remove if condition for pretrained weight download * Change default option 'full' to 'base' in otx install (#3937) * Change option full to base for otx install * Fix wrong code * Fix issue * Fix docs * Fix auto adapt batch size in Converter (#3939) * Enable auto adapt batch size into converter * Fix wrong * Fix hpo converter (#3940) * save best hp after hpo * add test * Fix tiling XAI out of range (#3943) - Fix tile merge XAI out of range * enable model export (#3952) Signed-off-by: Ashwin Vaidya * Move templates from OTX1.X to OTX2.X (#3951) * add otx1.6 templates * added new models * delete entrypoints and nncf cfg * 
updated some hyperparams * fix for rtmdet_tiny * updated converter * Update classification templates * Update det, r-det, vpm * Update template.yaml * changed warmup value in train.yaml --------- Co-authored-by: Kang, Harim Co-authored-by: Kim, Sungchul * Add missing tile recipes and various tile recipe changes (#3942) * add missing tile recipes * Fix tiling XAI out of range (#3943) - Fix tile merge XAI out of range * update xai tile merge * update rtdetr * update tile recipes * update rtdetr tile postprocess * update rtdetr recipes and tile recipes * update tile recipes * fix rtdetr unittest * update recipes * refactor tile unit test * address pr reviews * remove unnecessary files * update color channel * fix image channel passing * include tiling in cli integration test * remove transform_bbox --------- Co-authored-by: Vladislav Sovrasov * Support ImageFromBytes (#3948) * add image_from_bytes Signed-off-by: Ashwin Vaidya * refactor code Signed-off-by: Ashwin Vaidya * allow empty anomalous masks Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya * Change categories mapping logic (#3946) * change pre-filtering logic * Update src/otx/core/data/pre_filtering.py Co-authored-by: Eunwoo Shin --------- Co-authored-by: Eunwoo Shin * Update for 2.2.0rc1 (#3956) * Include Geti arrow dataset subset names (#3962) * restricted number of output masks by tiling * add geti subset name * update num of max pred * Include full image with anno in case there's no tile in tile dataset (#3964) * include full image with anno in case there's no tile in dataset * update test * Add type checker in converter for callable functions (optimizer, scheduler) (#3968) Fix converter callable functions (optimizer, scheduler) * Update for 2.2.0rc2 (#3969) update for 2.2.0rc2 * Refactor TIMM * Remove experimental recipes * Revert timm version * Fix conflict2 * Fix unit-test --------- Signed-off-by: Ashwin Vaidya Co-authored-by: Yunchu Lee Co-authored-by: Emily Chun Co-authored-by: 
Eunwoo Shin Co-authored-by: Kim, Sungchul Co-authored-by: Prokofiev Kirill Co-authored-by: Vladislav Sovrasov Co-authored-by: Sooah Lee Co-authored-by: Eugene Liu Co-authored-by: Wonju Lee Co-authored-by: Ashwin Vaidya --- src/otx/algo/classification/backbones/timm.py | 39 ++----- src/otx/algo/classification/timm_model.py | 104 +++++++++++++++--- .../h_label_cls/efficientnet_v2.yaml | 2 +- .../multi_class_cls/efficientnet_v2.yaml | 2 +- .../semisl/efficientnet_v2_semisl.yaml | 2 +- .../multi_label_cls/efficientnet_v2.yaml | 2 +- .../classification/backbones/test_timm.py | 6 +- .../algo/classification/test_timm_model.py | 6 +- 8 files changed, 112 insertions(+), 51 deletions(-) diff --git a/src/otx/algo/classification/backbones/timm.py b/src/otx/algo/classification/backbones/timm.py index 7bafa0b1dbb..0dc7ea1fce5 100644 --- a/src/otx/algo/classification/backbones/timm.py +++ b/src/otx/algo/classification/backbones/timm.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""EfficientNetV2 model. +"""Timm Backbone Class for OTX classification. Original papers: - 'EfficientNetV2: Smaller Models and Faster Training,' https://arxiv.org/abs/2104.00298, @@ -9,49 +9,39 @@ """ from __future__ import annotations -from typing import Literal - import timm import torch from torch import nn -TimmModelType = Literal[ - "mobilenetv3_large_100_miil_in21k", - "mobilenetv3_large_100_miil", - "tresnet_m", - "tf_efficientnetv2_s.in21k", - "tf_efficientnetv2_s.in21ft1k", - "tf_efficientnetv2_m.in21k", - "tf_efficientnetv2_m.in21ft1k", - "tf_efficientnetv2_b0", -] - class TimmBackbone(nn.Module): - """Timm backbone model.""" + """Timm backbone model. + + Args: + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to False. 
+ """ def __init__( self, - backbone: TimmModelType, + model_name: str, pretrained: bool = False, - pooling_type: str = "avg", **kwargs, ): super().__init__(**kwargs) - self.backbone = backbone + self.model_name = model_name self.pretrained: bool | dict = pretrained - self.is_mobilenet = backbone.startswith("mobilenet") self.model = timm.create_model( - self.backbone, + self.model_name, pretrained=pretrained, num_classes=1000, ) self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx. self.num_head_features = self.model.num_features - self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features - self.pooling_type = pooling_type + self.num_features = self.model.num_features def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: """Forward.""" @@ -60,11 +50,6 @@ def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: def extract_features(self, x: torch.Tensor) -> torch.Tensor: """Extract features.""" - if self.is_mobilenet: - x = self.model.conv_stem(x) - x = self.model.bn1(x) - x = self.model.act1(x) - return self.model.blocks(x) return self.model.forward_features(x) def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]: diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py index d7e171565a7..bfc8960dd85 100644 --- a/src/otx/algo/classification/timm_model.py +++ b/src/otx/algo/classification/timm_model.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""EfficientNetV2 model implementation.""" +"""TIMM wrapper model class for OTX.""" from __future__ import annotations @@ -12,7 +12,7 @@ import torch from torch import nn -from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType +from otx.algo.classification.backbones.timm import TimmBackbone from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, 
SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, @@ -50,12 +50,38 @@ class TimmModelForMulticlassCls(OTXMulticlassClsModel): - """TimmModel for multi-class classification task.""" + """TimmModel for multi-class classification task. + + Args: + label_info (LabelInfoTypes): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to MultiClassClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The training type. + + Example: + 1. API + >>> model = TimmModelForMulticlassCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \ + ... 
--model.model_name tf_efficientnetv2_s.in21k + """ def __init__( self, label_info: LabelInfoTypes, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -64,7 +90,7 @@ def __init__( torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -92,7 +118,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) neck = GlobalAveragePooling(dim=2) if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( @@ -142,12 +168,37 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t class TimmModelForMultilabelCls(OTXMultilabelClsModel): - """TimmModel for multi-label classification task.""" + """TimmModel for multi-label classification task. + + Args: + label_info (LabelInfoTypes): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. 
+ Defaults to MultiLabelClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + + Example: + 1. API + >>> model = TimmModelForMultilabelCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \ + ... --model.model_name tf_efficientnetv2_s.in21k + """ def __init__( self, label_info: LabelInfoTypes, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -155,7 +206,7 @@ def __init__( metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -182,7 +233,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, neck=GlobalAveragePooling(dim=2), @@ -222,14 +273,39 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t class TimmModelForHLabelCls(OTXHlabelClsModel): - """EfficientNetV2 Model for hierarchical label classification task.""" + """Timm Model for hierarchical label classification task. + + Args: + label_info (HLabelInfo): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. 
Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to HLabelClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False. + + Example: + 1. API + >>> model = TimmModelForHLabelCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForHLabelCls \ + ... --model.model_name tf_efficientnetv2_s.in21k + """ label_info: HLabelInfo def __init__( self, label_info: HLabelInfo, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -237,7 +313,7 @@ def __init__( metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -267,7 +343,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, head_config: dict) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) return HLabelClassifier( diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml index fc3f6abeab8..1dc6f209979 100644 --- a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml +++ 
b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml @@ -1,7 +1,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls init_args: - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml index 0cb77ef8852..0dd6daf26f7 100644 --- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml index 0bf81b8e05d..1187d41d3ba 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k train_type: SEMI_SUPERVISED optimizer: diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml index 87177eb1e17..14f3b605f12 100644 --- a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: 
tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/tests/unit/algo/classification/backbones/test_timm.py b/tests/unit/algo/classification/backbones/test_timm.py index 800f45520f3..b53eda0aa43 100644 --- a/tests/unit/algo/classification/backbones/test_timm.py +++ b/tests/unit/algo/classification/backbones/test_timm.py @@ -11,11 +11,11 @@ class TestOTXEfficientNetV2: def test_forward(self): - model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k") + model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k") assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8]) def test_get_config_optim(self): - model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k") + model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k") assert model.get_config_optim([0.01])[0]["lr"] == 0.01 assert model.get_config_optim(0.01)[0]["lr"] == 0.01 @@ -24,5 +24,5 @@ def test_check_pretrained_weight_download(self): if target.exists(): shutil.rmtree(target) assert not target.exists() - TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True) + TimmBackbone(model_name="tf_efficientnetv2_s.in21k", pretrained=True) assert target.exists() diff --git a/tests/unit/algo/classification/test_timm_model.py b/tests/unit/algo/classification/test_timm_model.py index 1cacecf2b5b..0b055f57c1d 100644 --- a/tests/unit/algo/classification/test_timm_model.py +++ b/tests/unit/algo/classification/test_timm_model.py @@ -21,7 +21,7 @@ def fxt_multi_class_cls_model(): return TimmModelForMulticlassCls( label_info=10, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", ) @@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_ def fxt_multi_label_cls_model(): return TimmModelForMultilabelCls( label_info=10, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", ) @@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, 
fxt_multilabel_cls_batch_ def fxt_h_label_cls_model(fxt_hlabel_cifar): return TimmModelForHLabelCls( label_info=fxt_hlabel_cifar, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", )