diff --git a/src/otx/algo/classification/backbones/timm.py b/src/otx/algo/classification/backbones/timm.py index 7bafa0b1dbb..0dc7ea1fce5 100644 --- a/src/otx/algo/classification/backbones/timm.py +++ b/src/otx/algo/classification/backbones/timm.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""EfficientNetV2 model. +"""Timm Backbone Class for OTX classification. Original papers: - 'EfficientNetV2: Smaller Models and Faster Training,' https://arxiv.org/abs/2104.00298, @@ -9,49 +9,39 @@ """ from __future__ import annotations -from typing import Literal - import timm import torch from torch import nn -TimmModelType = Literal[ - "mobilenetv3_large_100_miil_in21k", - "mobilenetv3_large_100_miil", - "tresnet_m", - "tf_efficientnetv2_s.in21k", - "tf_efficientnetv2_s.in21ft1k", - "tf_efficientnetv2_m.in21k", - "tf_efficientnetv2_m.in21ft1k", - "tf_efficientnetv2_b0", -] - class TimmBackbone(nn.Module): - """Timm backbone model.""" + """Timm backbone model. + + Args: + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to False. + """ def __init__( self, - backbone: TimmModelType, + model_name: str, pretrained: bool = False, - pooling_type: str = "avg", **kwargs, ): super().__init__(**kwargs) - self.backbone = backbone + self.model_name = model_name self.pretrained: bool | dict = pretrained - self.is_mobilenet = backbone.startswith("mobilenet") self.model = timm.create_model( - self.backbone, + self.model_name, pretrained=pretrained, num_classes=1000, ) self.model.classifier = None # Detach classifier. Only use 'backbone' part in otx. 
self.num_head_features = self.model.num_features - self.num_features = self.model.conv_head.in_channels if self.is_mobilenet else self.model.num_features - self.pooling_type = pooling_type + self.num_features = self.model.num_features def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: """Forward.""" @@ -60,11 +50,6 @@ def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]: def extract_features(self, x: torch.Tensor) -> torch.Tensor: """Extract features.""" - if self.is_mobilenet: - x = self.model.conv_stem(x) - x = self.model.bn1(x) - x = self.model.act1(x) - return self.model.blocks(x) return self.model.forward_features(x) def get_config_optim(self, lrs: list[float] | float) -> list[dict[str, float]]: diff --git a/src/otx/algo/classification/timm_model.py b/src/otx/algo/classification/timm_model.py index d7e171565a7..bfc8960dd85 100644 --- a/src/otx/algo/classification/timm_model.py +++ b/src/otx/algo/classification/timm_model.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""EfficientNetV2 model implementation.""" +"""TIMM wrapper model class for OTX.""" from __future__ import annotations @@ -12,7 +12,7 @@ import torch from torch import nn -from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType +from otx.algo.classification.backbones.timm import TimmBackbone from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier from otx.algo.classification.heads import ( HierarchicalCBAMClsHead, @@ -50,12 +50,38 @@ class TimmModelForMulticlassCls(OTXMulticlassClsModel): - """TimmModel for multi-class classification task.""" + """TimmModel for multi-class classification task. + + Args: + label_info (LabelInfoTypes): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). 
+ input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to MultiClassClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model with torch.compile. Defaults to False. + train_type (Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED], optional): The training type. + + Example: + 1. API + >>> model = TimmModelForMulticlassCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=10, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForMulticlassCls \ + ... --model.model_name tf_efficientnetv2_s.in21k + """ def __init__( self, label_info: LabelInfoTypes, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -64,7 +90,7 @@ def __init__( torch_compile: bool = False, train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -92,7 +118,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) neck = GlobalAveragePooling(dim=2) if self.train_type == OTXTrainType.SEMI_SUPERVISED: return SemiSLClassifier( @@ -142,12 +168,37
@@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t class TimmModelForMultilabelCls(OTXMultilabelClsModel): - """TimmModel for multi-label classification task.""" + """TimmModel for multi-label classification task. + + Args: + label_info (LabelInfoTypes): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to MultiLabelClsMetricCallable. + torch_compile (bool, optional): Whether to compile the model with torch.compile. Defaults to False. + + Example: + 1. API + >>> model = TimmModelForMultilabelCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=10, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForMultilabelCls \ + ...
--model.model_name tf_efficientnetv2_s.in21k + """ def __init__( self, label_info: LabelInfoTypes, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -155,7 +206,7 @@ def __init__( metric: MetricCallable = MultiLabelClsMetricCallable, torch_compile: bool = False, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -182,7 +233,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, num_classes: int) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) return ImageClassifier( backbone=backbone, neck=GlobalAveragePooling(dim=2), @@ -222,14 +273,39 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, t class TimmModelForHLabelCls(OTXHlabelClsModel): - """EfficientNetV2 Model for hierarchical label classification task.""" + """Timm Model for hierarchical label classification task. + + Args: + label_info (HLabelInfo): The label information for the classification task. + model_name (str): The name of the model. + You can find available models at timm.list_models() or timm.list_pretrained(). + input_size (tuple[int, int], optional): Model input size in the order of height and width. + Defaults to (224, 224). + pretrained (bool, optional): Whether to load pretrained weights. Defaults to True. + optimizer (OptimizerCallable, optional): The optimizer callable for training the model. + scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable. + metric (MetricCallable, optional): The metric callable for evaluating the model. + Defaults to HLabelClsMetricCallable. 
+ torch_compile (bool, optional): Whether to compile the model with torch.compile. Defaults to False. + + Example: + 1. API + >>> model = TimmModelForHLabelCls( + ... model_name="tf_efficientnetv2_s.in21k", + ... label_info=hlabel_info, + ... ) + 2. CLI + >>> otx train \ + ... --model otx.algo.classification.timm_model.TimmModelForHLabelCls \ + ... --model.model_name tf_efficientnetv2_s.in21k + """ label_info: HLabelInfo def __init__( self, label_info: HLabelInfo, - backbone: TimmModelType, + model_name: str, input_size: tuple[int, int] = (224, 224), # input size of default classification data recipe pretrained: bool = True, optimizer: OptimizerCallable = DefaultOptimizerCallable, @@ -237,7 +313,7 @@ def __init__( metric: MetricCallable = HLabelClsMetricCallable, torch_compile: bool = False, ) -> None: - self.backbone = backbone + self.model_name = model_name self.pretrained = pretrained super().__init__( @@ -267,7 +343,7 @@ def _create_model(self) -> nn.Module: return model def _build_model(self, head_config: dict) -> nn.Module: - backbone = TimmBackbone(backbone=self.backbone, pretrained=self.pretrained) + backbone = TimmBackbone(model_name=self.model_name, pretrained=self.pretrained) copied_head_config = copy(head_config) copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32)) return HLabelClassifier( diff --git a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml index fc3f6abeab8..1dc6f209979 100644 --- a/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/h_label_cls/efficientnet_v2.yaml @@ -1,7 +1,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForHLabelCls init_args: - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml
b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml index 0cb77ef8852..0dd6daf26f7 100644 --- a/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_class_cls/efficientnet_v2.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml index 0bf81b8e05d..1187d41d3ba 100644 --- a/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml +++ b/src/otx/recipe/classification/multi_class_cls/semisl/efficientnet_v2_semisl.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMulticlassCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k train_type: SEMI_SUPERVISED optimizer: diff --git a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml index 87177eb1e17..14f3b605f12 100644 --- a/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml +++ b/src/otx/recipe/classification/multi_label_cls/efficientnet_v2.yaml @@ -2,7 +2,7 @@ model: class_path: otx.algo.classification.timm_model.TimmModelForMultilabelCls init_args: label_info: 1000 - backbone: tf_efficientnetv2_s.in21k + model_name: tf_efficientnetv2_s.in21k optimizer: class_path: torch.optim.SGD diff --git a/tests/unit/algo/classification/backbones/test_timm.py b/tests/unit/algo/classification/backbones/test_timm.py index 800f45520f3..b53eda0aa43 100644 --- a/tests/unit/algo/classification/backbones/test_timm.py +++ b/tests/unit/algo/classification/backbones/test_timm.py @@ -11,11 +11,11 @@ 
class TestOTXEfficientNetV2: def test_forward(self): - model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k") + model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k") assert model(torch.randn(1, 3, 244, 244))[0].shape == torch.Size([1, 1280, 8, 8]) def test_get_config_optim(self): - model = TimmBackbone(backbone="tf_efficientnetv2_s.in21k") + model = TimmBackbone(model_name="tf_efficientnetv2_s.in21k") assert model.get_config_optim([0.01])[0]["lr"] == 0.01 assert model.get_config_optim(0.01)[0]["lr"] == 0.01 @@ -24,5 +24,5 @@ def test_check_pretrained_weight_download(self): if target.exists(): shutil.rmtree(target) assert not target.exists() - TimmBackbone(backbone="tf_efficientnetv2_s.in21k", pretrained=True) + TimmBackbone(model_name="tf_efficientnetv2_s.in21k", pretrained=True) assert target.exists() diff --git a/tests/unit/algo/classification/test_timm_model.py b/tests/unit/algo/classification/test_timm_model.py index 1cacecf2b5b..0b055f57c1d 100644 --- a/tests/unit/algo/classification/test_timm_model.py +++ b/tests/unit/algo/classification/test_timm_model.py @@ -21,7 +21,7 @@ def fxt_multi_class_cls_model(): return TimmModelForMulticlassCls( label_info=10, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", ) @@ -59,7 +59,7 @@ def test_predict_step(self, fxt_multi_class_cls_model, fxt_multiclass_cls_batch_ def fxt_multi_label_cls_model(): return TimmModelForMultilabelCls( label_info=10, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", ) @@ -97,7 +97,7 @@ def test_predict_step(self, fxt_multi_label_cls_model, fxt_multilabel_cls_batch_ def fxt_h_label_cls_model(fxt_hlabel_cifar): return TimmModelForHLabelCls( label_info=fxt_hlabel_cifar, - backbone="tf_efficientnetv2_s.in21k", + model_name="tf_efficientnetv2_s.in21k", )